]> www.infradead.org Git - users/hch/misc.git/commitdiff
psp: provide decapsulation and receive helper for drivers
authorRaed Salem <raeds@nvidia.com>
Wed, 17 Sep 2025 00:09:44 +0000 (17:09 -0700)
committerPaolo Abeni <pabeni@redhat.com>
Thu, 18 Sep 2025 10:32:07 +0000 (12:32 +0200)
Create psp_dev_rcv(), which drivers can call to psp decapsulate and attach
a psp_skb_ext to an skb.

psp_dev_rcv() only supports what the PSP architecture specification
refers to as "transport mode" packets, where the L3 header is either
IPv6 or IPv4.

Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Raed Salem <raeds@nvidia.com>
Signed-off-by: Rahul Rameshbabu <rrameshbabu@nvidia.com>
Signed-off-by: Cosmin Ratiu <cratiu@nvidia.com>
Co-developed-by: Daniel Zahka <daniel.zahka@gmail.com>
Signed-off-by: Daniel Zahka <daniel.zahka@gmail.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20250917000954.859376-18-daniel.zahka@gmail.com
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
include/net/psp/functions.h
net/psp/psp_main.c

index 0a539e1b39f4beda06b524e3c303c6b604c2629d..91ba067333219e37ee744b98f3c301a7d0ebd9c9 100644 (file)
@@ -19,6 +19,7 @@ psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
 void psp_dev_unregister(struct psp_dev *psd);
 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
                         u8 ver, __be16 sport);
+int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);
 
 /* Kernel-facing API */
 void psp_assoc_put(struct psp_assoc *pas);
index e026880fa1a281f930f519ffa2bafbd72657ac35..b4b756f87382c46815f0d74e8b765c962e74eba5 100644 (file)
@@ -223,6 +223,94 @@ bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
 }
 EXPORT_SYMBOL(psp_dev_encapsulate);
 
+/* Receive handler for PSP packets.
+ *
+ * Presently it accepts only already-authenticated packets and does not
+ * support optional fields, such as virtualization cookies. The caller should
+ * ensure that skb->data is pointing to the mac header, and that skb->mac_len
+ * is set.
+ */
+int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
+{
+       int l2_hlen = 0, l3_hlen, encap;
+       struct psp_skb_ext *pse;
+       struct psphdr *psph;
+       struct ethhdr *eth;
+       struct udphdr *uh;
+       __be16 proto;
+       bool is_udp;
+
+       eth = (struct ethhdr *)skb->data;
+       proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
+       if (proto == htons(ETH_P_IP))
+               l3_hlen = sizeof(struct iphdr);
+       else if (proto == htons(ETH_P_IPV6))
+               l3_hlen = sizeof(struct ipv6hdr);
+       else
+               return -EINVAL;
+
+       if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
+               return -EINVAL;
+
+       if (proto == htons(ETH_P_IP)) {
+               struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
+
+               is_udp = iph->protocol == IPPROTO_UDP;
+               l3_hlen = iph->ihl * 4;
+               if (l3_hlen != sizeof(struct iphdr) &&
+                   !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
+                       return -EINVAL;
+       } else {
+               struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
+
+               is_udp = ipv6h->nexthdr == IPPROTO_UDP;
+       }
+
+       if (unlikely(!is_udp))
+               return -EINVAL;
+
+       uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
+       if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
+               return -EINVAL;
+
+       pse = skb_ext_add(skb, SKB_EXT_PSP);
+       if (!pse)
+               return -EINVAL;
+
+       psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
+                                sizeof(struct udphdr));
+       pse->spi = psph->spi;
+       pse->dev_id = dev_id;
+       pse->generation = generation;
+       pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
+
+       encap = PSP_ENCAP_HLEN;
+       encap += strip_icv ? PSP_TRL_SIZE : 0;
+
+       if (proto == htons(ETH_P_IP)) {
+               struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
+
+               iph->protocol = psph->nexthdr;
+               iph->tot_len = htons(ntohs(iph->tot_len) - encap);
+               iph->check = 0;
+               iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
+       } else {
+               struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
+
+               ipv6h->nexthdr = psph->nexthdr;
+               ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
+       }
+
+       memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
+       skb_pull(skb, PSP_ENCAP_HLEN);
+
+       if (strip_icv)
+               pskb_trim(skb, skb->len - PSP_TRL_SIZE);
+
+       return 0;
+}
+EXPORT_SYMBOL(psp_dev_rcv);
+
 static int __init psp_init(void)
 {
        mutex_init(&psp_devs_lock);