* IP Payload Compression Protocol (IPComp) - RFC3173.
  *
  * Copyright (c) 2003 James Morris <jmorris@intercode.com.au>
- * Copyright (c) 2003-2008 Herbert Xu <herbert@gondor.apana.org.au>
+ * Copyright (c) 2003-2025 Herbert Xu <herbert@gondor.apana.org.au>
  *
  * Todo:
  *   - Tunable compression parameters.
  *   - Adaptive compression.
  */
 
-#include <linux/crypto.h>
+#include <crypto/acompress.h>
 #include <linux/err.h>
-#include <linux/list.h>
 #include <linux/module.h>
-#include <linux/mutex.h>
-#include <linux/percpu.h>
+#include <linux/skbuff_ref.h>
 #include <linux/slab.h>
-#include <linux/smp.h>
-#include <linux/vmalloc.h>
-#include <net/ip.h>
 #include <net/ipcomp.h>
 #include <net/xfrm.h>
 
-struct ipcomp_tfms {
-       struct list_head list;
-       struct crypto_comp * __percpu *tfms;
-       int users;
+#define IPCOMP_SCRATCH_SIZE 65400
+
+struct ipcomp_skb_cb {
+       struct xfrm_skb_cb xfrm;
+       struct acomp_req *req;
 };
 
-static DEFINE_MUTEX(ipcomp_resource_mutex);
-static void * __percpu *ipcomp_scratches;
-static int ipcomp_scratch_users;
-static LIST_HEAD(ipcomp_tfms_list);
+struct ipcomp_data {
+       u16 threshold;
+       struct crypto_acomp *tfm;
+};
 
-static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb)
+struct ipcomp_req_extra {
+       struct xfrm_state *x;
+       struct scatterlist sg[];
+};
+
+static inline struct ipcomp_skb_cb *ipcomp_cb(struct sk_buff *skb)
 {
-       struct ipcomp_data *ipcd = x->data;
-       const int plen = skb->len;
-       int dlen = IPCOMP_SCRATCH_SIZE;
-       const u8 *start = skb->data;
-       u8 *scratch = *this_cpu_ptr(ipcomp_scratches);
-       struct crypto_comp *tfm = *this_cpu_ptr(ipcd->tfms);
-       int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen);
-       int len;
+       struct ipcomp_skb_cb *cb = (void *)skb->cb;
 
-       if (err)
-               return err;
+       BUILD_BUG_ON(sizeof(*cb) > sizeof(skb->cb));
+       return cb;
+}
 
-       if (dlen < (plen + sizeof(struct ip_comp_hdr)))
-               return -EINVAL;
+static int ipcomp_post_acomp(struct sk_buff *skb, int err, int hlen)
+{
+       struct acomp_req *req = ipcomp_cb(skb)->req;
+       struct ipcomp_req_extra *extra;
+       const int plen = skb->data_len;
+       struct scatterlist *dsg;
+       int len, dlen;
+
+       if (unlikely(err))
+               goto out_free_req;
 
-       len = dlen - plen;
-       if (len > skb_tailroom(skb))
-               len = skb_tailroom(skb);
+       extra = acomp_request_extra(req);
+       dsg = extra->sg;
+       dlen = req->dlen;
 
-       __skb_put(skb, len);
+       pskb_trim_unique(skb, 0);
+       __skb_put(skb, hlen);
 
-       len += plen;
-       skb_copy_to_linear_data(skb, scratch, len);
+       /* Only update truesize on input. */
+       if (!hlen)
+               skb->truesize += dlen - plen;
+       skb->data_len = dlen;
+       skb->len += dlen;
 
-       while ((scratch += len, dlen -= len) > 0) {
+       do {
                skb_frag_t *frag;
                struct page *page;
 
-               if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
-                       return -EMSGSIZE;
-
                frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags;
-               page = alloc_page(GFP_ATOMIC);
-
-               if (!page)
-                       return -ENOMEM;
+               page = sg_page(dsg);
+               dsg = sg_next(dsg);
 
                len = PAGE_SIZE;
                if (dlen < len)
                        len = dlen;
 
                skb_frag_fill_page_desc(frag, page, 0, len);
-               memcpy(skb_frag_address(frag), scratch, len);
-
-               skb->truesize += len;
-               skb->data_len += len;
-               skb->len += len;
 
                skb_shinfo(skb)->nr_frags++;
-       }
+       } while ((dlen -= len));
 
-       return 0;
+       for (; dsg; dsg = sg_next(dsg))
+               __free_page(sg_page(dsg));
+
+out_free_req:
+       acomp_request_free(req);
+       return err;
 }
 
-int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb)
+static int ipcomp_input_done2(struct sk_buff *skb, int err)
 {
-       int nexthdr;
-       int err = -ENOMEM;
-       struct ip_comp_hdr *ipch;
-
-       if (!pskb_may_pull(skb, sizeof(*ipch)))
-               return -EINVAL;
-
-       if (skb_linearize_cow(skb))
-               goto out;
-
-       skb->ip_summed = CHECKSUM_NONE;
+       struct ip_comp_hdr *ipch = ip_comp_hdr(skb);
+       const int plen = skb->len;
 
-       /* Remove ipcomp header and decompress original payload */
-       ipch = (void *)skb->data;
-       nexthdr = ipch->nexthdr;
+       skb_reset_transport_header(skb);
 
-       skb->transport_header = skb->network_header + sizeof(*ipch);
-       __skb_pull(skb, sizeof(*ipch));
-       err = ipcomp_decompress(x, skb);
-       if (err)
-               goto out;
+       return ipcomp_post_acomp(skb, err, 0) ?:
+              skb->len < (plen + sizeof(ip_comp_hdr)) ? -EINVAL :
+              ipch->nexthdr;
+}
 
-       err = nexthdr;
+static void ipcomp_input_done(void *data, int err)
+{
+       struct sk_buff *skb = data;
 
-out:
-       return err;
+       xfrm_input_resume(skb, ipcomp_input_done2(skb, err));
 }
-EXPORT_SYMBOL_GPL(ipcomp_input);
 
-static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb)
+static struct acomp_req *ipcomp_setup_req(struct xfrm_state *x,
+                                         struct sk_buff *skb, int minhead,
+                                         int dlen)
 {
+       const int dnfrags = min(MAX_SKB_FRAGS, 16);
        struct ipcomp_data *ipcd = x->data;
+       struct ipcomp_req_extra *extra;
+       struct scatterlist *sg, *dsg;
        const int plen = skb->len;
-       int dlen = IPCOMP_SCRATCH_SIZE;
-       u8 *start = skb->data;
-       struct crypto_comp *tfm;
-       u8 *scratch;
+       struct crypto_acomp *tfm;
+       struct acomp_req *req;
+       int nfrags;
+       int total;
        int err;
+       int i;
 
-       local_bh_disable();
-       scratch = *this_cpu_ptr(ipcomp_scratches);
-       tfm = *this_cpu_ptr(ipcd->tfms);
-       err = crypto_comp_compress(tfm, start, plen, scratch, &dlen);
-       if (err)
-               goto out;
-
-       if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) {
-               err = -EMSGSIZE;
-               goto out;
-       }
+       ipcomp_cb(skb)->req = NULL;
 
-       memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen);
-       local_bh_enable();
+       do {
+               struct sk_buff *trailer;
 
-       pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr));
-       return 0;
+               if (skb->len > PAGE_SIZE) {
+                       if (skb_linearize_cow(skb))
+                               return ERR_PTR(-ENOMEM);
+                       nfrags = 1;
+                       break;
+               }
 
-out:
-       local_bh_enable();
-       return err;
-}
+               if (!skb_cloned(skb) && skb_headlen(skb) >= minhead) {
+                       if (!skb_is_nonlinear(skb)) {
+                               nfrags = 1;
+                               break;
+                       } else if (!skb_has_frag_list(skb)) {
+                               nfrags = skb_shinfo(skb)->nr_frags;
+                               nfrags++;
+                               break;
+                       }
+               }
 
-int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb)
-{
-       int err;
-       struct ip_comp_hdr *ipch;
-       struct ipcomp_data *ipcd = x->data;
+               nfrags = skb_cow_data(skb, skb_headlen(skb) < minhead ?
+                                          minhead - skb_headlen(skb) : 0,
+                                     &trailer);
+               if (nfrags < 0)
+                       return ERR_PTR(nfrags);
+       } while (0);
+
+       tfm = ipcd->tfm;
+       req = acomp_request_alloc_extra(
+               tfm, sizeof(*extra) + sizeof(*sg) * (nfrags + dnfrags),
+               GFP_ATOMIC);
+       ipcomp_cb(skb)->req = req;
+       if (!req)
+               return ERR_PTR(-ENOMEM);
+
+       extra = acomp_request_extra(req);
+       extra->x = x;
+
+       dsg = extra->sg;
+       sg = dsg + dnfrags;
+       sg_init_table(sg, nfrags);
+       err = skb_to_sgvec(skb, sg, 0, plen);
+       if (unlikely(err < 0))
+               return ERR_PTR(err);
+
+       sg_init_table(dsg, dnfrags);
+       total = 0;
+       for (i = 0; i < dnfrags && total < dlen; i++) {
+               struct page *page;
 
-       if (skb->len < ipcd->threshold) {
-               /* Don't bother compressing */
-               goto out_ok;
+               page = alloc_page(GFP_ATOMIC);
+               if (!page)
+                       break;
+               sg_set_page(dsg + i, page, PAGE_SIZE, 0);
+               total += PAGE_SIZE;
        }
+       if (!i)
+               return ERR_PTR(-ENOMEM);
+       sg_mark_end(dsg + i - 1);
+       dlen = min(dlen, total);
 
-       if (skb_linearize_cow(skb))
-               goto out_ok;
-
-       err = ipcomp_compress(x, skb);
+       acomp_request_set_params(req, sg, dsg, plen, dlen);
 
-       if (err) {
-               goto out_ok;
-       }
-
-       /* Install ipcomp header, convert into ipcomp datagram. */
-       ipch = ip_comp_hdr(skb);
-       ipch->nexthdr = *skb_mac_header(skb);
-       ipch->flags = 0;
-       ipch->cpi = htons((u16 )ntohl(x->id.spi));
-       *skb_mac_header(skb) = IPPROTO_COMP;
-out_ok:
-       skb_push(skb, -skb_network_offset(skb));
-       return 0;
+       return req;
 }
-EXPORT_SYMBOL_GPL(ipcomp_output);
 
-static void ipcomp_free_scratches(void)
+static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb)
 {
-       int i;
-       void * __percpu *scratches;
-
-       if (--ipcomp_scratch_users)
-               return;
+       struct acomp_req *req;
+       int err;
 
-       scratches = ipcomp_scratches;
-       if (!scratches)
-               return;
+       req = ipcomp_setup_req(x, skb, 0, IPCOMP_SCRATCH_SIZE);
+       err = PTR_ERR(req);
+       if (IS_ERR(req))
+               goto out;
 
-       for_each_possible_cpu(i)
-               vfree(*per_cpu_ptr(scratches, i));
+       acomp_request_set_callback(req, 0, ipcomp_input_done, skb);
+       err = crypto_acomp_decompress(req);
+       if (err == -EINPROGRESS)
+               return err;
 
-       free_percpu(scratches);
-       ipcomp_scratches = NULL;
+out:
+       return ipcomp_input_done2(skb, err);
 }
 
-static void * __percpu *ipcomp_alloc_scratches(void)
+int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb)
 {
-       void * __percpu *scratches;
-       int i;
-
-       if (ipcomp_scratch_users++)
-               return ipcomp_scratches;
+       struct ip_comp_hdr *ipch __maybe_unused;
 
-       scratches = alloc_percpu(void *);
-       if (!scratches)
-               return NULL;
-
-       ipcomp_scratches = scratches;
+       if (!pskb_may_pull(skb, sizeof(*ipch)))
+               return -EINVAL;
 
-       for_each_possible_cpu(i) {
-               void *scratch;
+       skb->ip_summed = CHECKSUM_NONE;
 
-               scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i));
-               if (!scratch)
-                       return NULL;
-               *per_cpu_ptr(scratches, i) = scratch;
-       }
+       /* Remove ipcomp header and decompress original payload */
+       __skb_pull(skb, sizeof(*ipch));
 
-       return scratches;
+       return ipcomp_decompress(x, skb);
 }
+EXPORT_SYMBOL_GPL(ipcomp_input);
 
-static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms)
+static int ipcomp_output_push(struct sk_buff *skb)
 {
-       struct ipcomp_tfms *pos;
-       int cpu;
-
-       list_for_each_entry(pos, &ipcomp_tfms_list, list) {
-               if (pos->tfms == tfms)
-                       break;
-       }
-
-       WARN_ON(list_entry_is_head(pos, &ipcomp_tfms_list, list));
-
-       if (--pos->users)
-               return;
+       skb_push(skb, -skb_network_offset(skb));
+       return 0;
+}
 
-       list_del(&pos->list);
-       kfree(pos);
+static int ipcomp_output_done2(struct xfrm_state *x, struct sk_buff *skb,
+                              int err)
+{
+       struct ip_comp_hdr *ipch;
 
-       if (!tfms)
-               return;
+       err = ipcomp_post_acomp(skb, err, sizeof(*ipch));
+       if (err)
+               goto out_ok;
 
-       for_each_possible_cpu(cpu) {
-               struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu);
-               crypto_free_comp(tfm);
-       }
-       free_percpu(tfms);
+       /* Install ipcomp header, convert into ipcomp datagram. */
+       ipch = ip_comp_hdr(skb);
+       ipch->nexthdr = *skb_mac_header(skb);
+       ipch->flags = 0;
+       ipch->cpi = htons((u16 )ntohl(x->id.spi));
+       *skb_mac_header(skb) = IPPROTO_COMP;
+out_ok:
+       return ipcomp_output_push(skb);
 }
 
-static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name)
+static void ipcomp_output_done(void *data, int err)
 {
-       struct ipcomp_tfms *pos;
-       struct crypto_comp * __percpu *tfms;
-       int cpu;
+       struct ipcomp_req_extra *extra;
+       struct sk_buff *skb = data;
+       struct acomp_req *req;
 
+       req = ipcomp_cb(skb)->req;
+       extra = acomp_request_extra(req);
 
-       list_for_each_entry(pos, &ipcomp_tfms_list, list) {
-               struct crypto_comp *tfm;
+       xfrm_output_resume(skb_to_full_sk(skb), skb,
+                          ipcomp_output_done2(extra->x, skb, err));
+}
 
-               /* This can be any valid CPU ID so we don't need locking. */
-               tfm = this_cpu_read(*pos->tfms);
+static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb)
+{
+       struct ip_comp_hdr *ipch __maybe_unused;
+       struct acomp_req *req;
+       int err;
 
-               if (!strcmp(crypto_comp_name(tfm), alg_name)) {
-                       pos->users++;
-                       return pos->tfms;
-               }
-       }
+       req = ipcomp_setup_req(x, skb, sizeof(*ipch),
+                              skb->len - sizeof(*ipch));
+       err = PTR_ERR(req);
+       if (IS_ERR(req))
+               goto out;
 
-       pos = kmalloc(sizeof(*pos), GFP_KERNEL);
-       if (!pos)
-               return NULL;
+       acomp_request_set_callback(req, 0, ipcomp_output_done, skb);
+       err = crypto_acomp_compress(req);
+       if (err == -EINPROGRESS)
+               return err;
 
-       pos->users = 1;
-       INIT_LIST_HEAD(&pos->list);
-       list_add(&pos->list, &ipcomp_tfms_list);
+out:
+       return ipcomp_output_done2(x, skb, err);
+}
 
-       pos->tfms = tfms = alloc_percpu(struct crypto_comp *);
-       if (!tfms)
-               goto error;
+int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb)
+{
+       struct ipcomp_data *ipcd = x->data;
 
-       for_each_possible_cpu(cpu) {
-               struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0,
-                                                           CRYPTO_ALG_ASYNC);
-               if (IS_ERR(tfm))
-                       goto error;
-               *per_cpu_ptr(tfms, cpu) = tfm;
+       if (skb->len < ipcd->threshold) {
+               /* Don't bother compressing */
+               return ipcomp_output_push(skb);
        }
 
-       return tfms;
-
-error:
-       ipcomp_free_tfms(tfms);
-       return NULL;
+       return ipcomp_compress(x, skb);
 }
+EXPORT_SYMBOL_GPL(ipcomp_output);
 
 static void ipcomp_free_data(struct ipcomp_data *ipcd)
 {
-       if (ipcd->tfms)
-               ipcomp_free_tfms(ipcd->tfms);
-       ipcomp_free_scratches();
+       crypto_free_acomp(ipcd->tfm);
 }
 
 void ipcomp_destroy(struct xfrm_state *x)
        if (!ipcd)
                return;
        xfrm_state_delete_tunnel(x);
-       mutex_lock(&ipcomp_resource_mutex);
        ipcomp_free_data(ipcd);
-       mutex_unlock(&ipcomp_resource_mutex);
        kfree(ipcd);
 }
 EXPORT_SYMBOL_GPL(ipcomp_destroy);
        if (!ipcd)
                goto out;
 
-       mutex_lock(&ipcomp_resource_mutex);
-       if (!ipcomp_alloc_scratches())
-               goto error;
-
-       ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name);
-       if (!ipcd->tfms)
+       ipcd->tfm = crypto_alloc_acomp(x->calg->alg_name, 0, 0);
+       if (IS_ERR(ipcd->tfm))
                goto error;
-       mutex_unlock(&ipcomp_resource_mutex);
 
        calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0);
        BUG_ON(!calg_desc);
 
 error:
        ipcomp_free_data(ipcd);
-       mutex_unlock(&ipcomp_resource_mutex);
        kfree(ipcd);
        goto out;
 }