 #include <crypto/sha.h>
 #include <linux/fsverity.h>
+#include <linux/mempool.h>
 
 struct ahash_request;
 
        const char *name;         /* crypto API name, e.g. sha256 */
        unsigned int digest_size; /* digest size in bytes, e.g. 32 for SHA-256 */
        unsigned int block_size;  /* block size in bytes, e.g. 64 for SHA-256 */
+       mempool_t req_pool;       /* mempool with a preallocated hash request */
 };
 
 /* Merkle tree parameters: hash algorithm, initial hash state, and topology */
 struct merkle_tree_params {
-       const struct fsverity_hash_alg *hash_alg; /* the hash algorithm */
+       struct fsverity_hash_alg *hash_alg; /* the hash algorithm */
        const u8 *hashstate;            /* initial hash state or NULL */
        unsigned int digest_size;       /* same as hash_alg->digest_size */
        unsigned int block_size;        /* size of data and tree blocks */
 
 extern struct fsverity_hash_alg fsverity_hash_algs[];
 
-const struct fsverity_hash_alg *fsverity_get_hash_alg(const struct inode *inode,
-                                                     unsigned int num);
-const u8 *fsverity_prepare_hash_state(const struct fsverity_hash_alg *alg,
+struct fsverity_hash_alg *fsverity_get_hash_alg(const struct inode *inode,
+                                               unsigned int num);
+struct ahash_request *fsverity_alloc_hash_request(struct fsverity_hash_alg *alg,
+                                                 gfp_t gfp_flags);
+void fsverity_free_hash_request(struct fsverity_hash_alg *alg,
+                               struct ahash_request *req);
+const u8 *fsverity_prepare_hash_state(struct fsverity_hash_alg *alg,
                                      const u8 *salt, size_t salt_size);
 int fsverity_hash_page(const struct merkle_tree_params *params,
                       const struct inode *inode,
                       struct ahash_request *req, struct page *page, u8 *out);
-int fsverity_hash_buffer(const struct fsverity_hash_alg *alg,
+int fsverity_hash_buffer(struct fsverity_hash_alg *alg,
                         const void *data, size_t size, u8 *out);
 void __init fsverity_check_hash_algs(void);
 
 
        },
 };
 
+/* Serializes lazy initialization of the fsverity_hash_algs[] entries */
+static DEFINE_MUTEX(fsverity_hash_alg_init_mutex);
+
 /**
  * fsverity_get_hash_alg() - validate and prepare a hash algorithm
  * @inode: optional inode for logging purposes
  * @num: the hash algorithm number
  *
  * Return: pointer to the hash alg on success, else an ERR_PTR()
  */
-const struct fsverity_hash_alg *fsverity_get_hash_alg(const struct inode *inode,
-                                                     unsigned int num)
+struct fsverity_hash_alg *fsverity_get_hash_alg(const struct inode *inode,
+                                               unsigned int num)
 {
        struct fsverity_hash_alg *alg;
        struct crypto_ahash *tfm;
        }
        alg = &fsverity_hash_algs[num];
 
-       /* pairs with cmpxchg() below */
-       tfm = READ_ONCE(alg->tfm);
-       if (likely(tfm != NULL))
+       /* pairs with smp_store_release() below */
+       if (likely(smp_load_acquire(&alg->tfm) != NULL))
                return alg;
+
+       mutex_lock(&fsverity_hash_alg_init_mutex);
+
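+       /* Another task may have initialized alg->tfm while we waited. */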
+       if (alg->tfm != NULL)
+               goto out_unlock;
+
        /*
         * Using the shash API would make things a bit simpler, but the ahash
         * API is preferable as it allows the use of crypto accelerators.
                        fsverity_warn(inode,
                                      "Missing crypto API support for hash algorithm \"%s\"",
                                      alg->name);
-                       return ERR_PTR(-ENOPKG);
+                       alg = ERR_PTR(-ENOPKG);
+                       goto out_unlock;
                }
                fsverity_err(inode,
                             "Error allocating hash algorithm \"%s\": %ld",
                             alg->name, PTR_ERR(tfm));
-               return ERR_CAST(tfm);
+               alg = ERR_CAST(tfm);
+               goto out_unlock;
        }
 
        err = -EINVAL;
        if (WARN_ON(alg->block_size != crypto_ahash_blocksize(tfm)))
                goto err_free_tfm;
 
+       /*
+        * A single preallocated request guarantees forward progress for
+        * allocations that use __GFP_DIRECT_RECLAIM.
+        */
+       err = mempool_init_kmalloc_pool(&alg->req_pool, 1,
+                                       sizeof(struct ahash_request) +
+                                       crypto_ahash_reqsize(tfm));
+       if (err)
+               goto err_free_tfm;
+
        pr_info("%s using implementation \"%s\"\n",
                alg->name, crypto_ahash_driver_name(tfm));
 
-       /* pairs with READ_ONCE() above */
-       if (cmpxchg(&alg->tfm, NULL, tfm) != NULL)
-               crypto_free_ahash(tfm);
-
-       return alg;
+       /* pairs with smp_load_acquire() above */
+       smp_store_release(&alg->tfm, tfm);
+       goto out_unlock;
 
 err_free_tfm:
        crypto_free_ahash(tfm);
-       return ERR_PTR(err);
+       alg = ERR_PTR(err);
+out_unlock:
+       mutex_unlock(&fsverity_hash_alg_init_mutex);
+       return alg;
+}
+
+/**
+ * fsverity_alloc_hash_request() - allocate a hash request object
+ * @alg: the hash algorithm for which to allocate the request
+ * @gfp_flags: memory allocation flags
+ *
+ * This is mempool-backed, so this never fails if __GFP_DIRECT_RECLAIM is set in
+ * @gfp_flags.  However, in that case this might need to wait for all
+ * previously-allocated requests to be freed.  So to avoid deadlocks, callers
+ * must never need multiple requests at a time to make forward progress.
+ *
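+ * The intended usage pattern, as in the verify.c callers in this patch, is:
+ *
+ *    req = fsverity_alloc_hash_request(alg, GFP_NOFS);
+ *    ...hash data using req...
+ *    fsverity_free_hash_request(alg, req);
+ *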
+ * Return: the request object on success; NULL on failure (but see above)
+ */
+struct ahash_request *fsverity_alloc_hash_request(struct fsverity_hash_alg *alg,
+                                                 gfp_t gfp_flags)
+{
+       struct ahash_request *req = mempool_alloc(&alg->req_pool, gfp_flags);
+
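+       /* Attach the per-algorithm tfm to the request before returning it. */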
+       if (req)
+               ahash_request_set_tfm(req, alg->tfm);
+       return req;
+}
+
+/**
+ * fsverity_free_hash_request() - free a hash request object
+ * @alg: the hash algorithm
+ * @req: the hash request object to free
+ */
+void fsverity_free_hash_request(struct fsverity_hash_alg *alg,
+                               struct ahash_request *req)
+{
+       if (req) {
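+               /* Zeroize the request so no hash state lingers in the pool. */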
+               ahash_request_zero(req);
+               mempool_free(req, &alg->req_pool);
+       }
 }
 
 /**
  * Return: NULL if the salt is empty, otherwise the kmalloc()'ed precomputed
  *        initial hash state on success or an ERR_PTR() on failure.
  */
-const u8 *fsverity_prepare_hash_state(const struct fsverity_hash_alg *alg,
+const u8 *fsverity_prepare_hash_state(struct fsverity_hash_alg *alg,
                                      const u8 *salt, size_t salt_size)
 {
        u8 *hashstate = NULL;
        if (!hashstate)
                return ERR_PTR(-ENOMEM);
 
-       req = ahash_request_alloc(alg->tfm, GFP_KERNEL);
-       if (!req) {
-               err = -ENOMEM;
-               goto err_free;
-       }
+       /* This allocation never fails, since it's mempool-backed. */
+       req = fsverity_alloc_hash_request(alg, GFP_KERNEL);
 
        /*
         * Zero-pad the salt to the next multiple of the input size of the hash
        if (err)
                goto err_free;
 out:
-       ahash_request_free(req);
+       fsverity_free_hash_request(alg, req);
        kfree(padded_salt);
        return hashstate;
 
  *
  * Return: 0 on success, -errno on failure
  */
-int fsverity_hash_buffer(const struct fsverity_hash_alg *alg,
+int fsverity_hash_buffer(struct fsverity_hash_alg *alg,
                         const void *data, size_t size, u8 *out)
 {
        struct ahash_request *req;
        DECLARE_CRYPTO_WAIT(wait);
        int err;
 
-       req = ahash_request_alloc(alg->tfm, GFP_KERNEL);
-       if (!req)
-               return -ENOMEM;
+       /* This allocation never fails, since it's mempool-backed. */
+       req = fsverity_alloc_hash_request(alg, GFP_KERNEL);
 
        sg_init_one(&sg, data, size);
        ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP |
 
        err = crypto_wait_req(crypto_ahash_digest(req), &wait);
 
-       ahash_request_free(req);
+       fsverity_free_hash_request(alg, req);
        return err;
 }
 
 
        struct ahash_request *req;
        bool valid;
 
-       req = ahash_request_alloc(vi->tree_params.hash_alg->tfm, GFP_NOFS);
-       if (unlikely(!req))
-               return false;
+       /* This allocation never fails, since it's mempool-backed. */
+       req = fsverity_alloc_hash_request(vi->tree_params.hash_alg, GFP_NOFS);
 
        valid = verify_page(inode, vi, req, page, 0);
 
-       ahash_request_free(req);
+       fsverity_free_hash_request(vi->tree_params.hash_alg, req);
 
        return valid;
 }
        struct bvec_iter_all iter_all;
        unsigned long max_ra_pages = 0;
 
-       req = ahash_request_alloc(params->hash_alg->tfm, GFP_NOFS);
-       if (unlikely(!req)) {
-               bio_for_each_segment_all(bv, bio, iter_all)
-                       SetPageError(bv->bv_page);
-               return;
-       }
+       /* This allocation never fails, since it's mempool-backed. */
+       req = fsverity_alloc_hash_request(params->hash_alg, GFP_NOFS);
 
        if (bio->bi_opf & REQ_RAHEAD) {
                /*
                        SetPageError(page);
        }
 
-       ahash_request_free(req);
+       fsverity_free_hash_request(params->hash_alg, req);
 }
 EXPORT_SYMBOL_GPL(fsverity_verify_bio);
 #endif /* CONFIG_BLOCK */