mctp_test_route_input_multiple_nets_key_fini(test, &t2);
 }
 
+#if IS_ENABLED(CONFIG_MCTP_FLOWS)
+
+static void mctp_test_flow_init(struct kunit *test,
+                               struct mctp_test_dev **devp,
+                               struct mctp_test_route **rtp,
+                               struct socket **sock,
+                               struct sk_buff **skbp,
+                               unsigned int len)
+{
+       struct mctp_test_route *rt;
+       struct mctp_test_dev *dev;
+       struct sk_buff *skb;
+
+       /* we have a slightly odd routing setup here; the test route
+        * is for EID 8, which is our local EID. We don't do a routing
+        * lookup, so that's fine - all we require is a path through
+        * mctp_local_output, which will call rt->output on whatever
+        * route we provide
+        */
+       __mctp_route_test_init(test, &dev, &rt, sock, MCTP_NET_ANY);
+
+       /* Assign a single EID. ->addrs is freed on mctp netdev release */
+       dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL);
+       dev->mdev->num_addrs = 1;
+       dev->mdev->addrs[0] = 8;
+
+       skb = alloc_skb(len + sizeof(struct mctp_hdr) + 1, GFP_KERNEL);
+       KUNIT_ASSERT_TRUE(test, skb);
+       __mctp_cb(skb);
+       skb_reserve(skb, sizeof(struct mctp_hdr) + 1);
+       memset(skb_put(skb, len), 0, len);
+
+       /* take a ref for the route, we'll decrement in local output */
+       refcount_inc(&rt->rt.refs);
+
+       *devp = dev;
+       *rtp = rt;
+       *skbp = skb;
+}
+
+static void mctp_test_flow_fini(struct kunit *test,
+                               struct mctp_test_dev *dev,
+                               struct mctp_test_route *rt,
+                               struct socket *sock)
+{
+       __mctp_route_test_fini(test, dev, rt, sock);
+}
+
+/* test that an outgoing skb has the correct MCTP extension data set */
+static void mctp_test_packet_flow(struct kunit *test)
+{
+       struct sk_buff *skb, *skb2;
+       struct mctp_test_route *rt;
+       struct mctp_test_dev *dev;
+       struct mctp_flow *flow;
+       struct socket *sock;
+       u8 dst = 8;
+       int n, rc;
+
+       mctp_test_flow_init(test, &dev, &rt, &sock, &skb, 30);
+
+       rc = mctp_local_output(sock->sk, &rt->rt, skb, dst, MCTP_TAG_OWNER);
+       KUNIT_ASSERT_EQ(test, rc, 0);
+
+       n = rt->pkts.qlen;
+       KUNIT_ASSERT_EQ(test, n, 1);
+
+       skb2 = skb_dequeue(&rt->pkts);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2);
+
+       flow = skb_ext_find(skb2, SKB_EXT_MCTP);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow->key);
+       KUNIT_ASSERT_PTR_EQ(test, flow->key->sk, sock->sk);
+
+       kfree_skb(skb2);
+       mctp_test_flow_fini(test, dev, rt, sock);
+}
+
+/* test that outgoing skbs, after fragmentation, all have the correct MCTP
+ * extension data set.
+ */
+static void mctp_test_fragment_flow(struct kunit *test)
+{
+       struct mctp_flow *flows[2];
+       struct sk_buff *tx_skbs[2];
+       struct mctp_test_route *rt;
+       struct mctp_test_dev *dev;
+       struct sk_buff *skb;
+       struct socket *sock;
+       u8 dst = 8;
+       int n, rc;
+
+       mctp_test_flow_init(test, &dev, &rt, &sock, &skb, 100);
+
+       rc = mctp_local_output(sock->sk, &rt->rt, skb, dst, MCTP_TAG_OWNER);
+       KUNIT_ASSERT_EQ(test, rc, 0);
+
+       n = rt->pkts.qlen;
+       KUNIT_ASSERT_EQ(test, n, 2);
+
+       /* both resulting packets should have the same flow data */
+       tx_skbs[0] = skb_dequeue(&rt->pkts);
+       tx_skbs[1] = skb_dequeue(&rt->pkts);
+
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[0]);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[1]);
+
+       flows[0] = skb_ext_find(tx_skbs[0], SKB_EXT_MCTP);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]->key);
+       KUNIT_ASSERT_PTR_EQ(test, flows[0]->key->sk, sock->sk);
+
+       flows[1] = skb_ext_find(tx_skbs[1], SKB_EXT_MCTP);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[1]);
+       KUNIT_ASSERT_PTR_EQ(test, flows[1]->key, flows[0]->key);
+
+       kfree_skb(tx_skbs[0]);
+       kfree_skb(tx_skbs[1]);
+       mctp_test_flow_fini(test, dev, rt, sock);
+}
+
+#else
+static void mctp_test_packet_flow(struct kunit *test)
+{
+       kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y");
+}
+
+static void mctp_test_fragment_flow(struct kunit *test)
+{
+       kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y");
+}
+#endif
+
 static struct kunit_case mctp_test_cases[] = {
        KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
        KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
                         mctp_route_input_sk_keys_gen_params),
        KUNIT_CASE(mctp_test_route_input_multiple_nets_bind),
        KUNIT_CASE(mctp_test_route_input_multiple_nets_key),
+       KUNIT_CASE(mctp_test_packet_flow),
+       KUNIT_CASE(mctp_test_fragment_flow),
        {}
 };