struct fl_flow_mask {
        struct fl_flow_key key;
        struct fl_flow_mask_range range;
-       struct rcu_head rcu;
+       struct rhash_head ht_node;
+       struct rhashtable ht;
+       struct rhashtable_params filter_ht_params;
+       struct flow_dissector dissector;
+       struct list_head filters;
+       struct rcu_head rcu;
+       struct list_head list;
 };
 
 struct cls_fl_head {
        struct rhashtable ht;
-       struct fl_flow_mask mask;
-       struct flow_dissector dissector;
-       bool mask_assigned;
-       struct list_head filters;
-       struct rhashtable_params ht_params;
+       struct list_head masks;
        union {
                struct work_struct work;
                struct rcu_head rcu;
 };
 
 struct cls_fl_filter {
+       struct fl_flow_mask *mask;
        struct rhash_head ht_node;
        struct fl_flow_key mkey;
        struct tcf_exts exts;
        struct net_device *hw_dev;
 };
 
+static const struct rhashtable_params mask_ht_params = {
+       .key_offset = offsetof(struct fl_flow_mask, key),
+       .key_len = sizeof(struct fl_flow_key),
+       .head_offset = offsetof(struct fl_flow_mask, ht_node),
+       .automatic_shrinking = true,
+};
+
 static unsigned short int fl_mask_range(const struct fl_flow_mask *mask)
 {
        return mask->range.end - mask->range.start;
 {
        const u8 *bytes = (const u8 *) &mask->key;
        size_t size = sizeof(mask->key);
-       size_t i, first = 0, last = size - 1;
+       size_t i, first = 0, last;
 
-       for (i = 0; i < sizeof(mask->key); i++) {
+       for (i = 0; i < size; i++) {
+               if (bytes[i]) {
+                       first = i;
+                       break;
+               }
+       }
+       last = first;
+       for (i = size - 1; i != first; i--) {
                if (bytes[i]) {
-                       if (!first && i)
-                               first = i;
                        last = i;
+                       break;
                }
        }
        mask->range.start = rounddown(first, sizeof(long));
        memset(fl_key_get_start(key, mask), 0, fl_mask_range(mask));
 }
 
-static struct cls_fl_filter *fl_lookup(struct cls_fl_head *head,
+static struct cls_fl_filter *fl_lookup(struct fl_flow_mask *mask,
                                       struct fl_flow_key *mkey)
 {
-       return rhashtable_lookup_fast(&head->ht,
-                                     fl_key_get_start(mkey, &head->mask),
-                                     head->ht_params);
+       return rhashtable_lookup_fast(&mask->ht, fl_key_get_start(mkey, mask),
+                                     mask->filter_ht_params);
 }
 
 static int fl_classify(struct sk_buff *skb, const struct tcf_proto *tp,
 {
        struct cls_fl_head *head = rcu_dereference_bh(tp->root);
        struct cls_fl_filter *f;
+       struct fl_flow_mask *mask;
        struct fl_flow_key skb_key;
        struct fl_flow_key skb_mkey;
 
-       if (!atomic_read(&head->ht.nelems))
-               return -1;
-
-       fl_clear_masked_range(&skb_key, &head->mask);
+       list_for_each_entry_rcu(mask, &head->masks, list) {
+               fl_clear_masked_range(&skb_key, mask);
 
-       skb_key.indev_ifindex = skb->skb_iif;
-       /* skb_flow_dissect() does not set n_proto in case an unknown protocol,
-        * so do it rather here.
-        */
-       skb_key.basic.n_proto = skb->protocol;
-       skb_flow_dissect_tunnel_info(skb, &head->dissector, &skb_key);
-       skb_flow_dissect(skb, &head->dissector, &skb_key, 0);
+               skb_key.indev_ifindex = skb->skb_iif;
+               /* skb_flow_dissect() does not set n_proto in case an unknown
+                * protocol, so do it rather here.
+                */
+               skb_key.basic.n_proto = skb->protocol;
+               skb_flow_dissect_tunnel_info(skb, &mask->dissector, &skb_key);
+               skb_flow_dissect(skb, &mask->dissector, &skb_key, 0);
 
-       fl_set_masked_key(&skb_mkey, &skb_key, &head->mask);
+               fl_set_masked_key(&skb_mkey, &skb_key, mask);
 
-       f = fl_lookup(head, &skb_mkey);
-       if (f && !tc_skip_sw(f->flags)) {
-               *res = f->res;
-               return tcf_exts_exec(skb, &f->exts, res);
+               f = fl_lookup(mask, &skb_mkey);
+               if (f && !tc_skip_sw(f->flags)) {
+                       *res = f->res;
+                       return tcf_exts_exec(skb, &f->exts, res);
+               }
        }
        return -1;
 }
        if (!head)
                return -ENOBUFS;
 
-       INIT_LIST_HEAD_RCU(&head->filters);
+       INIT_LIST_HEAD_RCU(&head->masks);
        rcu_assign_pointer(tp->root, head);
        idr_init(&head->handle_idr);
 
-       return 0;
+       return rhashtable_init(&head->ht, &mask_ht_params);
+}
+
+static bool fl_mask_put(struct cls_fl_head *head, struct fl_flow_mask *mask,
+                       bool async)
+{
+       if (!list_empty(&mask->filters))
+               return false;
+
+       rhashtable_remove_fast(&head->ht, &mask->ht_node, mask_ht_params);
+       rhashtable_destroy(&mask->ht);
+       list_del_rcu(&mask->list);
+       if (async)
+               kfree_rcu(mask, rcu);
+       else
+               kfree(mask);
+
+       return true;
 }
 
 static void __fl_destroy_filter(struct cls_fl_filter *f)
 }
 
 static int fl_hw_replace_filter(struct tcf_proto *tp,
-                               struct flow_dissector *dissector,
-                               struct fl_flow_key *mask,
                                struct cls_fl_filter *f,
                                struct netlink_ext_ack *extack)
 {
        tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack);
        cls_flower.command = TC_CLSFLOWER_REPLACE;
        cls_flower.cookie = (unsigned long) f;
-       cls_flower.dissector = dissector;
-       cls_flower.mask = mask;
+       cls_flower.dissector = &f->mask->dissector;
+       cls_flower.mask = &f->mask->key;
        cls_flower.key = &f->mkey;
        cls_flower.exts = &f->exts;
        cls_flower.classid = f->res.classid;
                         &cls_flower, false);
 }
 
-static void __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f,
+static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f,
                        struct netlink_ext_ack *extack)
 {
        struct cls_fl_head *head = rtnl_dereference(tp->root);
+       bool async = tcf_exts_get_net(&f->exts);
+       bool last;
 
        idr_remove(&head->handle_idr, f->handle);
        list_del_rcu(&f->list);
+       last = fl_mask_put(head, f->mask, async);
        if (!tc_skip_hw(f->flags))
                fl_hw_destroy_filter(tp, f, extack);
        tcf_unbind_filter(tp, &f->res);
-       if (tcf_exts_get_net(&f->exts))
+       if (async)
                call_rcu(&f->rcu, fl_destroy_filter);
        else
                __fl_destroy_filter(f);
+
+       return last;
 }
 
 static void fl_destroy_sleepable(struct work_struct *work)
 {
        struct cls_fl_head *head = container_of(work, struct cls_fl_head,
                                                work);
-       if (head->mask_assigned)
-               rhashtable_destroy(&head->ht);
        kfree(head);
        module_put(THIS_MODULE);
 }
 static void fl_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
        struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct fl_flow_mask *mask, *next_mask;
        struct cls_fl_filter *f, *next;
 
-       list_for_each_entry_safe(f, next, &head->filters, list)
-               __fl_delete(tp, f, extack);
+       list_for_each_entry_safe(mask, next_mask, &head->masks, list) {
+               list_for_each_entry_safe(f, next, &mask->filters, list) {
+                       if (__fl_delete(tp, f, extack))
+                               break;
+               }
+       }
        idr_destroy(&head->handle_idr);
 
        __module_get(THIS_MODULE);
        return ret;
 }
 
-static bool fl_mask_eq(struct fl_flow_mask *mask1,
-                      struct fl_flow_mask *mask2)
+static void fl_mask_copy(struct fl_flow_mask *dst,
+                        struct fl_flow_mask *src)
 {
-       const long *lmask1 = fl_key_get_start(&mask1->key, mask1);
-       const long *lmask2 = fl_key_get_start(&mask2->key, mask2);
+       const void *psrc = fl_key_get_start(&src->key, src);
+       void *pdst = fl_key_get_start(&dst->key, src);
 
-       return !memcmp(&mask1->range, &mask2->range, sizeof(mask1->range)) &&
-              !memcmp(lmask1, lmask2, fl_mask_range(mask1));
+       memcpy(pdst, psrc, fl_mask_range(src));
+       dst->range = src->range;
 }
 
 static const struct rhashtable_params fl_ht_params = {
        .automatic_shrinking = true,
 };
 
-static int fl_init_hashtable(struct cls_fl_head *head,
-                            struct fl_flow_mask *mask)
+static int fl_init_mask_hashtable(struct fl_flow_mask *mask)
 {
-       head->ht_params = fl_ht_params;
-       head->ht_params.key_len = fl_mask_range(mask);
-       head->ht_params.key_offset += mask->range.start;
+       mask->filter_ht_params = fl_ht_params;
+       mask->filter_ht_params.key_len = fl_mask_range(mask);
+       mask->filter_ht_params.key_offset += mask->range.start;
 
-       return rhashtable_init(&head->ht, &head->ht_params);
+       return rhashtable_init(&mask->ht, &mask->filter_ht_params);
 }
 
 #define FL_KEY_MEMBER_OFFSET(member) offsetof(struct fl_flow_key, member)
                        FL_KEY_SET(keys, cnt, id, member);                      \
        } while(0);
 
-static void fl_init_dissector(struct cls_fl_head *head,
-                             struct fl_flow_mask *mask)
+static void fl_init_dissector(struct fl_flow_mask *mask)
 {
        struct flow_dissector_key keys[FLOW_DISSECTOR_KEY_MAX];
        size_t cnt = 0;
        FL_KEY_SET_IF_MASKED(&mask->key, keys, cnt,
                             FLOW_DISSECTOR_KEY_ENC_PORTS, enc_tp);
 
-       skb_flow_dissector_init(&head->dissector, keys, cnt);
+       skb_flow_dissector_init(&mask->dissector, keys, cnt);
+}
+
+static struct fl_flow_mask *fl_create_new_mask(struct cls_fl_head *head,
+                                              struct fl_flow_mask *mask)
+{
+       struct fl_flow_mask *newmask;
+       int err;
+
+       newmask = kzalloc(sizeof(*newmask), GFP_KERNEL);
+       if (!newmask)
+               return ERR_PTR(-ENOMEM);
+
+       fl_mask_copy(newmask, mask);
+
+       err = fl_init_mask_hashtable(newmask);
+       if (err)
+               goto errout_free;
+
+       fl_init_dissector(newmask);
+
+       INIT_LIST_HEAD_RCU(&newmask->filters);
+
+       err = rhashtable_insert_fast(&head->ht, &newmask->ht_node,
+                                    mask_ht_params);
+       if (err)
+               goto errout_destroy;
+
+       list_add_tail_rcu(&newmask->list, &head->masks);
+
+       return newmask;
+
+errout_destroy:
+       rhashtable_destroy(&newmask->ht);
+errout_free:
+       kfree(newmask);
+
+       return ERR_PTR(err);
 }
 
 static int fl_check_assign_mask(struct cls_fl_head *head,
+                               struct cls_fl_filter *fnew,
+                               struct cls_fl_filter *fold,
                                struct fl_flow_mask *mask)
 {
-       int err;
+       struct fl_flow_mask *newmask;
 
-       if (head->mask_assigned) {
-               if (!fl_mask_eq(&head->mask, mask))
+       fnew->mask = rhashtable_lookup_fast(&head->ht, mask, mask_ht_params);
+       if (!fnew->mask) {
+               if (fold)
                        return -EINVAL;
-               else
-                       return 0;
-       }
 
-       /* Mask is not assigned yet. So assign it and init hashtable
-        * according to that.
-        */
-       err = fl_init_hashtable(head, mask);
-       if (err)
-               return err;
-       memcpy(&head->mask, mask, sizeof(head->mask));
-       head->mask_assigned = true;
+               newmask = fl_create_new_mask(head, mask);
+               if (IS_ERR(newmask))
+                       return PTR_ERR(newmask);
 
-       fl_init_dissector(head, mask);
+               fnew->mask = newmask;
+       } else if (fold && fold->mask == fnew->mask) {
+               return -EINVAL;
+       }
 
        return 0;
 }
        if (err)
                goto errout_idr;
 
-       err = fl_check_assign_mask(head, &mask);
+       err = fl_check_assign_mask(head, fnew, fold, &mask);
        if (err)
                goto errout_idr;
 
        if (!tc_skip_sw(fnew->flags)) {
-               if (!fold && fl_lookup(head, &fnew->mkey)) {
+               if (!fold && fl_lookup(fnew->mask, &fnew->mkey)) {
                        err = -EEXIST;
-                       goto errout_idr;
+                       goto errout_mask;
                }
 
-               err = rhashtable_insert_fast(&head->ht, &fnew->ht_node,
-                                            head->ht_params);
+               err = rhashtable_insert_fast(&fnew->mask->ht, &fnew->ht_node,
+                                            fnew->mask->filter_ht_params);
                if (err)
-                       goto errout_idr;
+                       goto errout_mask;
        }
 
        if (!tc_skip_hw(fnew->flags)) {
-               err = fl_hw_replace_filter(tp,
-                                          &head->dissector,
-                                          &mask.key,
-                                          fnew,
-                                          extack);
+               err = fl_hw_replace_filter(tp, fnew, extack);
                if (err)
-                       goto errout_idr;
+                       goto errout_mask;
        }
 
        if (!tc_in_hw(fnew->flags))
 
        if (fold) {
                if (!tc_skip_sw(fold->flags))
-                       rhashtable_remove_fast(&head->ht, &fold->ht_node,
-                                              head->ht_params);
+                       rhashtable_remove_fast(&fold->mask->ht,
+                                              &fold->ht_node,
+                                              fold->mask->filter_ht_params);
                if (!tc_skip_hw(fold->flags))
                        fl_hw_destroy_filter(tp, fold, NULL);
        }
                tcf_exts_get_net(&fold->exts);
                call_rcu(&fold->rcu, fl_destroy_filter);
        } else {
-               list_add_tail_rcu(&fnew->list, &head->filters);
+               list_add_tail_rcu(&fnew->list, &fnew->mask->filters);
        }
 
        kfree(tb);
        return 0;
 
+errout_mask:
+       fl_mask_put(head, fnew->mask, false);
+
 errout_idr:
        if (fnew->handle)
                idr_remove(&head->handle_idr, fnew->handle);
        struct cls_fl_filter *f = arg;
 
        if (!tc_skip_sw(f->flags))
-               rhashtable_remove_fast(&head->ht, &f->ht_node,
-                                      head->ht_params);
+               rhashtable_remove_fast(&f->mask->ht, &f->ht_node,
+                                      f->mask->filter_ht_params);
        __fl_delete(tp, f, extack);
-       *last = list_empty(&head->filters);
+       *last = list_empty(&head->masks);
        return 0;
 }
 
 {
        struct cls_fl_head *head = rtnl_dereference(tp->root);
        struct cls_fl_filter *f;
-
-       list_for_each_entry_rcu(f, &head->filters, list) {
-               if (arg->count < arg->skip)
-                       goto skip;
-               if (arg->fn(tp, f, arg) < 0) {
-                       arg->stop = 1;
-                       break;
-               }
+       struct fl_flow_mask *mask;
+
+       list_for_each_entry_rcu(mask, &head->masks, list) {
+               list_for_each_entry_rcu(f, &mask->filters, list) {
+                       if (arg->count < arg->skip)
+                               goto skip;
+                       if (arg->fn(tp, f, arg) < 0) {
+                               arg->stop = 1;
+                               break;
+                       }
 skip:
-               arg->count++;
+                       arg->count++;
+               }
        }
 }
 
 static int fl_dump(struct net *net, struct tcf_proto *tp, void *fh,
                   struct sk_buff *skb, struct tcmsg *t)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
        struct cls_fl_filter *f = fh;
        struct nlattr *nest;
        struct fl_flow_key *key, *mask;
                goto nla_put_failure;
 
        key = &f->key;
-       mask = &head->mask.key;
+       mask = &f->mask->key;
 
        if (mask->indev_ifindex) {
                struct net_device *dev;