static int vti_net_id __read_mostly;
 static int vti_tunnel_init(struct net_device *dev);
 
-/* We dont digest the packet therefore let the packet pass */
-static int vti_rcv(struct sk_buff *skb)
+static int vti_input(struct sk_buff *skb, int nexthdr, __be32 spi,
+                    int encap_type)
 {
        struct ip_tunnel *tunnel;
        const struct iphdr *iph = ip_hdr(skb);
        tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY,
                                  iph->saddr, iph->daddr, 0);
        if (tunnel != NULL) {
-               struct pcpu_sw_netstats *tstats;
-               u32 oldmark = skb->mark;
-               int ret;
-
-
-               /* temporarily mark the skb with the tunnel o_key, to
-                * only match policies with this mark.
-                */
-               skb->mark = be32_to_cpu(tunnel->parms.o_key);
-               ret = xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb);
-               skb->mark = oldmark;
-               if (!ret)
-                       return -1;
-
-               tstats = this_cpu_ptr(tunnel->dev->tstats);
-               u64_stats_update_begin(&tstats->syncp);
-               tstats->rx_packets++;
-               tstats->rx_bytes += skb->len;
-               u64_stats_update_end(&tstats->syncp);
-
-               secpath_reset(skb);
-               skb->dev = tunnel->dev;
+               if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
+                       goto drop;
+
+               XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = tunnel;
+               skb->mark = be32_to_cpu(tunnel->parms.i_key);
+
+               return xfrm_input(skb, nexthdr, spi, encap_type);
+       }
+
+       return -EINVAL;
+drop:
+       kfree_skb(skb);
+       return 0;
+}
+
+static int vti_rcv(struct sk_buff *skb)
+{
+       XFRM_SPI_SKB_CB(skb)->family = AF_INET;
+       XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr);
+
+       return vti_input(skb, ip_hdr(skb)->protocol, 0, 0);
+}
+
+static int vti_rcv_cb(struct sk_buff *skb, int err)
+{
+       unsigned short family;
+       struct net_device *dev;
+       struct pcpu_sw_netstats *tstats;
+       struct xfrm_state *x;
+       struct ip_tunnel *tunnel = XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4;
+
+       if (!tunnel)
                return 1;
+
+       dev = tunnel->dev;
+
+       if (err) {
+               dev->stats.rx_errors++;
+               dev->stats.rx_dropped++;
+
+               return 0;
        }
 
-       return -1;
+       x = xfrm_input_state(skb);
+       family = x->inner_mode->afinfo->family;
+
+       if (!xfrm_policy_check(NULL, XFRM_POLICY_IN, skb, family))
+               return -EPERM;
+
+       skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(skb->dev)));
+       skb->dev = dev;
+
+       tstats = this_cpu_ptr(dev->tstats);
+
+       u64_stats_update_begin(&tstats->syncp);
+       tstats->rx_packets++;
+       tstats->rx_bytes += skb->len;
+       u64_stats_update_end(&tstats->syncp);
+
+       return 0;
 }
 
 /* This function assumes it is being called from dev_queue_xmit()
  * and that skb is filled properly by that function.
  */
-
 static netdev_tx_t vti_tunnel_xmit(struct sk_buff *skb, struct net_device *dev)
 {
        struct ip_tunnel *tunnel = netdev_priv(dev);
-       struct iphdr  *tiph = &tunnel->parms.iph;
-       u8     tos;
        struct rtable *rt;              /* Route to the other host */
        struct net_device *tdev;        /* Device to other host */
-       struct iphdr  *old_iph = ip_hdr(skb);
-       __be32 dst = tiph->daddr;
-       struct flowi4 fl4;
+       struct flowi fl;
        int err;
 
        if (skb->protocol != htons(ETH_P_IP))
                goto tx_error;
 
-       tos = old_iph->tos;
+       memset(&fl, 0, sizeof(fl));
+       skb->mark = be32_to_cpu(tunnel->parms.o_key);
+       xfrm_decode_session(skb, &fl, AF_INET);
+
+       if (!skb_dst(skb)) {
+               dev->stats.tx_carrier_errors++;
+               goto tx_error_icmp;
+       }
 
-       memset(&fl4, 0, sizeof(fl4));
-       flowi4_init_output(&fl4, tunnel->parms.link,
-                          be32_to_cpu(tunnel->parms.o_key), RT_TOS(tos),
-                          RT_SCOPE_UNIVERSE,
-                          IPPROTO_IPIP, 0,
-                          dst, tiph->saddr, 0, 0);
-       rt = ip_route_output_key(dev_net(dev), &fl4);
+       dst_hold(skb_dst(skb));
+       rt = (struct rtable *)xfrm_lookup(tunnel->net, skb_dst(skb), &fl, NULL, 0);
        if (IS_ERR(rt)) {
                dev->stats.tx_carrier_errors++;
                goto tx_error_icmp;
        }
+
        /* if there is no transform then this tunnel is not functional.
         * Or if the xfrm is not mode tunnel.
         */
        }
 
        memset(IPCB(skb), 0, sizeof(*IPCB(skb)));
-       skb_dst_drop(skb);
+       skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(dev)));
        skb_dst_set(skb, &rt->dst);
-       nf_reset(skb);
        skb->dev = skb_dst(skb)->dev;
 
        err = dst_output(skb);
        return NETDEV_TX_OK;
 }
 
+static int vti4_err(struct sk_buff *skb, u32 info)
+{
+       __be32 spi;
+       struct xfrm_state *x;
+       struct ip_tunnel *tunnel;
+       struct ip_esp_hdr *esph;
+       struct ip_auth_hdr *ah ;
+       struct ip_comp_hdr *ipch;
+       struct net *net = dev_net(skb->dev);
+       const struct iphdr *iph = (const struct iphdr *)skb->data;
+       int protocol = iph->protocol;
+       struct ip_tunnel_net *itn = net_generic(net, vti_net_id);
+
+       tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY,
+                                 iph->daddr, iph->saddr, 0);
+       if (!tunnel)
+               return -1;
+
+       switch (protocol) {
+       case IPPROTO_ESP:
+               esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2));
+               spi = esph->spi;
+               break;
+       case IPPROTO_AH:
+               ah = (struct ip_auth_hdr *)(skb->data+(iph->ihl<<2));
+               spi = ah->spi;
+               break;
+       case IPPROTO_COMP:
+               ipch = (struct ip_comp_hdr *)(skb->data+(iph->ihl<<2));
+               spi = htonl(ntohs(ipch->cpi));
+               break;
+       default:
+               return 0;
+       }
+
+       switch (icmp_hdr(skb)->type) {
+       case ICMP_DEST_UNREACH:
+               if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED)
+                       return 0;
+       case ICMP_REDIRECT:
+               break;
+       default:
+               return 0;
+       }
+
+       x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr,
+                             spi, protocol, AF_INET);
+       if (!x)
+               return 0;
+
+       if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH)
+               ipv4_update_pmtu(skb, net, info, 0, 0, protocol, 0);
+       else
+               ipv4_redirect(skb, net, 0, 0, protocol, 0);
+       xfrm_state_put(x);
+
+       return 0;
+}
+
 static int
 vti_tunnel_ioctl(struct net_device *dev, struct ifreq *ifr, int cmd)
 {
                        return -EINVAL;
        }
 
+       p.i_flags |= VTI_ISVTI;
        err = ip_tunnel_ioctl(dev, &p, cmd);
        if (err)
                return err;
 
        if (cmd != SIOCDELTUNNEL) {
-               p.i_flags |= GRE_KEY | VTI_ISVTI;
+               p.i_flags |= GRE_KEY;
                p.o_flags |= GRE_KEY;
        }
 
        iph->ihl                = 5;
 }
 
-static struct xfrm_tunnel_notifier vti_handler __read_mostly = {
+static struct xfrm4_protocol vti_esp4_protocol __read_mostly = {
        .handler        =       vti_rcv,
-       .priority       =       1,
+       .input_handler  =       vti_input,
+       .cb_handler     =       vti_rcv_cb,
+       .err_handler    =       vti4_err,
+       .priority       =       100,
+};
+
+static struct xfrm4_protocol vti_ah4_protocol __read_mostly = {
+       .handler        =       vti_rcv,
+       .input_handler  =       vti_input,
+       .cb_handler     =       vti_rcv_cb,
+       .err_handler    =       vti4_err,
+       .priority       =       100,
+};
+
+static struct xfrm4_protocol vti_ipcomp4_protocol __read_mostly = {
+       .handler        =       vti_rcv,
+       .input_handler  =       vti_input,
+       .cb_handler     =       vti_rcv_cb,
+       .err_handler    =       vti4_err,
+       .priority       =       100,
 };
 
 static int __net_init vti_init_net(struct net *net)
        if (!data)
                return;
 
+       parms->i_flags = VTI_ISVTI;
+
        if (data[IFLA_VTI_LINK])
                parms->link = nla_get_u32(data[IFLA_VTI_LINK]);
 
        err = register_pernet_device(&vti_net_ops);
        if (err < 0)
                return err;
-       err = xfrm4_mode_tunnel_input_register(&vti_handler);
+       err = xfrm4_protocol_register(&vti_esp4_protocol, IPPROTO_ESP);
+       if (err < 0) {
+               unregister_pernet_device(&vti_net_ops);
+               pr_info("vti init: can't register tunnel\n");
+
+               return err;
+       }
+
+       err = xfrm4_protocol_register(&vti_ah4_protocol, IPPROTO_AH);
+       if (err < 0) {
+               xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP);
+               unregister_pernet_device(&vti_net_ops);
+               pr_info("vti init: can't register tunnel\n");
+
+               return err;
+       }
+
+       err = xfrm4_protocol_register(&vti_ipcomp4_protocol, IPPROTO_COMP);
        if (err < 0) {
+               xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH);
+               xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP);
                unregister_pernet_device(&vti_net_ops);
                pr_info("vti init: can't register tunnel\n");
+
+               return err;
        }
 
        err = rtnl_link_register(&vti_link_ops);
        return err;
 
 rtnl_link_failed:
-       xfrm4_mode_tunnel_input_deregister(&vti_handler);
+       xfrm4_protocol_deregister(&vti_ipcomp4_protocol, IPPROTO_COMP);
+       xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH);
+       xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP);
        unregister_pernet_device(&vti_net_ops);
        return err;
 }
 static void __exit vti_fini(void)
 {
        rtnl_link_unregister(&vti_link_ops);
-       if (xfrm4_mode_tunnel_input_deregister(&vti_handler))
+       if (xfrm4_protocol_deregister(&vti_ipcomp4_protocol, IPPROTO_COMP))
+               pr_info("vti close: can't deregister tunnel\n");
+       if (xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH))
                pr_info("vti close: can't deregister tunnel\n");
+       if (xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP))
+               pr_info("vti close: can't deregister tunnel\n");
+
 
        unregister_pernet_device(&vti_net_ops);
 }