#include <net/nexthop.h>
 #include "internal.h"
 
-#define MAX_NEW_LABELS 2
+/* put a reasonable limit on the number of labels
+ * we will accept from userspace
+ */
+#define MAX_NEW_LABELS 30
 
 /* max memory we will use for mpls_route */
 #define MAX_MPLS_ROUTE_MEM     4096
                return -ENOMEM;
 
        err = -EINVAL;
-       /* Ensure only a supported number of labels are present */
-       if (cfg->rc_output_labels > MAX_NEW_LABELS)
-               goto errout;
 
        nh->nh_labels = cfg->rc_output_labels;
        for (i = 0; i < nh->nh_labels; i++)
 
 static int mpls_nh_build(struct net *net, struct mpls_route *rt,
                         struct mpls_nh *nh, int oif, struct nlattr *via,
-                        struct nlattr *newdst)
+                        struct nlattr *newdst, u8 max_labels)
 {
        int err = -ENOMEM;
 
                goto errout;
 
        if (newdst) {
-               err = nla_get_labels(newdst, MAX_NEW_LABELS,
+               err = nla_get_labels(newdst, max_labels,
                                     &nh->nh_labels, nh->nh_label);
                if (err)
                        goto errout;
 }
 
 static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len,
-                             u8 cfg_via_alen, u8 *max_via_alen)
+                             u8 cfg_via_alen, u8 *max_via_alen,
+                             u8 *max_labels)
 {
        int remaining = len;
        u8 nhs = 0;
 
-       if (!rtnh) {
-               *max_via_alen = cfg_via_alen;
-               return 1;
-       }
-
        *max_via_alen = 0;
+       *max_labels = 0;
 
        while (rtnh_ok(rtnh, remaining)) {
                struct nlattr *nla, *attrs = rtnh_attrs(rtnh);
                int attrlen;
+               u8 n_labels = 0;
 
                attrlen = rtnh_attrlen(rtnh);
                nla = nla_find(attrs, attrlen, RTA_VIA);
                                                      via_alen);
                }
 
+               nla = nla_find(attrs, attrlen, RTA_NEWDST);
+               if (nla &&
+                   nla_get_labels(nla, MAX_NEW_LABELS, &n_labels, NULL) != 0)
+                       return 0;
+
+               *max_labels = max_t(u8, *max_labels, n_labels);
+
                /* number of nexthops is tracked by a u8.
                 * Check for overflow.
                 */
 }
 
 static int mpls_nh_build_multi(struct mpls_route_config *cfg,
-                              struct mpls_route *rt)
+                              struct mpls_route *rt, u8 max_labels)
 {
        struct rtnexthop *rtnh = cfg->rc_mp;
        struct nlattr *nla_via, *nla_newdst;
                }
 
                err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh,
-                                   rtnh->rtnh_ifindex, nla_via, nla_newdst);
+                                   rtnh->rtnh_ifindex, nla_via, nla_newdst,
+                                   max_labels);
                if (err)
                        goto errout;
 
        int err = -EINVAL;
        u8 max_via_alen;
        unsigned index;
+       u8 max_labels;
        u8 nhs;
 
        index = cfg->rc_label;
                goto errout;
 
        err = -EINVAL;
-       nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
-                                 cfg->rc_via_alen, &max_via_alen);
+       if (cfg->rc_mp) {
+               nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
+                                         cfg->rc_via_alen, &max_via_alen,
+                                         &max_labels);
+       } else {
+               max_via_alen = cfg->rc_via_alen;
+               max_labels = cfg->rc_output_labels;
+               nhs = 1;
+       }
+
        if (nhs == 0)
                goto errout;
 
        err = -ENOMEM;
-       rt = mpls_rt_alloc(nhs, max_via_alen, MAX_NEW_LABELS);
+       rt = mpls_rt_alloc(nhs, max_via_alen, max_labels);
        if (IS_ERR(rt)) {
                err = PTR_ERR(rt);
                goto errout;
        rt->rt_ttl_propagate = cfg->rc_ttl_propagate;
 
        if (cfg->rc_mp)
-               err = mpls_nh_build_multi(cfg, rt);
+               err = mpls_nh_build_multi(cfg, rt, max_labels);
        else
                err = mpls_nh_build_from_cfg(cfg, rt);
        if (err)
 EXPORT_SYMBOL_GPL(nla_put_labels);
 
 int nla_get_labels(const struct nlattr *nla,
-                  u32 max_labels, u8 *labels, u32 label[])
+                  u8 max_labels, u8 *labels, u32 label[])
 {
        unsigned len = nla_len(nla);
-       unsigned nla_labels;
        struct mpls_shim_hdr *nla_label;
+       u8 nla_labels;
        bool bos;
        int i;
 
-       /* len needs to be an even multiple of 4 (the label size) */
-       if (len & 3)
+       /* len needs to be an even multiple of 4 (the label size). Number
+        * of labels is a u8 so check for overflow.
+        */
+       if (len & 3 || len / 4 > 255)
                return -EINVAL;
 
        /* Limit the number of new labels allowed */
        if (nla_labels > max_labels)
                return -EINVAL;
 
+       /* when label == NULL, caller wants number of labels */
+       if (!label)
+               goto out;
+
        nla_label = nla_data(nla);
        bos = true;
        for (i = nla_labels - 1; i >= 0; i--, bos = false) {
 
                label[i] = dec.label;
        }
+out:
        *labels = nla_labels;
        return 0;
 }
 
        err = -EINVAL;
        rtm = nlmsg_data(nlh);
-       memset(cfg, 0, sizeof(*cfg));
 
        if (rtm->rtm_family != AF_MPLS)
                goto errout;
 
 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
 {
-       struct mpls_route_config cfg;
+       struct mpls_route_config *cfg;
        int err;
 
-       err = rtm_to_route_config(skb, nlh, &cfg);
+       cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
+       if (!cfg)
+               return -ENOMEM;
+
+       err = rtm_to_route_config(skb, nlh, cfg);
        if (err < 0)
-               return err;
+               goto out;
 
-       return mpls_route_del(&cfg);
+       err = mpls_route_del(cfg);
+out:
+       kfree(cfg);
+
+       return err;
 }
 
 
 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
 {
-       struct mpls_route_config cfg;
+       struct mpls_route_config *cfg;
        int err;
 
-       err = rtm_to_route_config(skb, nlh, &cfg);
+       cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
+       if (!cfg)
+               return -ENOMEM;
+
+       err = rtm_to_route_config(skb, nlh, cfg);
        if (err < 0)
-               return err;
+               goto out;
 
-       return mpls_route_add(&cfg);
+       err = mpls_route_add(cfg);
+out:
+       kfree(cfg);
+
+       return err;
 }
 
 static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
        /* In case the predefined labels need to be populated */
        if (limit > MPLS_LABEL_IPV4NULL) {
                struct net_device *lo = net->loopback_dev;
-               rt0 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS);
+               rt0 = mpls_rt_alloc(1, lo->addr_len, 0);
                if (IS_ERR(rt0))
                        goto nort0;
                RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
        }
        if (limit > MPLS_LABEL_IPV6NULL) {
                struct net_device *lo = net->loopback_dev;
-               rt2 = mpls_rt_alloc(1, lo->addr_len, MAX_NEW_LABELS);
+               rt2 = mpls_rt_alloc(1, lo->addr_len, 0);
                if (IS_ERR(rt2))
                        goto nort2;
                RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);