}
 
 #ifdef CONFIG_NET_CLS_ACT
-DECLARE_STATIC_KEY_FALSE(tcf_bypass_check_needed_key);
+DECLARE_STATIC_KEY_FALSE(tcf_sw_enabled_key);
 
 static inline bool tcf_block_bypass_sw(struct tcf_block *block)
 {
-       return block && block->bypass_wanted;
+       /* Bypass only for a non-NULL block whose useswcnt currently reads
+        * zero; a NULL block never requests bypass.
+        */
+       return block && !atomic_read(&block->useswcnt);
 }
 #endif
 
                cls_common->extack = extack;
 }
 
+static inline void tcf_proto_update_usesw(struct tcf_proto *tp, u32 flags)
+{
+       /* Set tp->usesw based on one filter's @flags. This helper only ever
+        * sets the field, so there is nothing to do once it is already true.
+        */
+       if (tp->usesw)
+               return;
+       /* Flags for which both tc_skip_sw() and tc_in_hw() hold leave usesw
+        * untouched; any other combination marks the tp.
+        */
+       if (tc_skip_sw(flags) && tc_in_hw(flags))
+               return;
+       tp->usesw = true;
+}
+
 #if IS_ENABLED(CONFIG_NET_TC_SKB_EXT)
 static inline struct tc_skb_ext *tc_skb_ext_alloc(struct sk_buff *skb)
 {
 
 #endif
 
 #ifdef CONFIG_NET_CLS_ACT
-DEFINE_STATIC_KEY_FALSE(tcf_bypass_check_needed_key);
-EXPORT_SYMBOL(tcf_bypass_check_needed_key);
+/* Initially-false static key. A reference is taken when a block's useswcnt
+ * goes 0 -> 1 and dropped when it returns to 0 (tcf_proto_count_usesw()).
+ */
+DEFINE_STATIC_KEY_FALSE(tcf_sw_enabled_key);
+EXPORT_SYMBOL(tcf_sw_enabled_key);
 #endif
 
 DEFINE_STATIC_KEY_FALSE(netstamp_needed_key);
        if (!miniq)
                return ret;
 
-       if (static_branch_unlikely(&tcf_bypass_check_needed_key)) {
-               if (tcf_block_bypass_sw(miniq->block))
-                       return ret;
-       }
+       /* Global bypass */
+       if (!static_branch_likely(&tcf_sw_enabled_key))
+               return ret;
+
+       /* Block-wise bypass */
+       if (tcf_block_bypass_sw(miniq->block))
+               return ret;
 
        tc_skb_cb(skb)->mru = 0;
        tc_skb_cb(skb)->post_ct = false;
 
        tp->protocol = protocol;
        tp->prio = prio;
        tp->chain = chain;
+       tp->usesw = !tp->ops->reoffload;
        spin_lock_init(&tp->lock);
        refcount_set(&tp->refcnt, 1);
 
        refcount_inc(&tp->refcnt);
 }
 
-static void tcf_maintain_bypass(struct tcf_block *block)
+static void tcf_proto_count_usesw(struct tcf_proto *tp, bool add)
 {
-       int filtercnt = atomic_read(&block->filtercnt);
-       int skipswcnt = atomic_read(&block->skipswcnt);
-       bool bypass_wanted = filtercnt > 0 && filtercnt == skipswcnt;
-
-       if (bypass_wanted != block->bypass_wanted) {
 #ifdef CONFIG_NET_CLS_ACT
-               if (bypass_wanted)
-                       static_branch_inc(&tcf_bypass_check_needed_key);
-               else
-                       static_branch_dec(&tcf_bypass_check_needed_key);
-#endif
-               block->bypass_wanted = bypass_wanted;
+       struct tcf_block *block = tp->chain->block;
+       bool counted = false;
+
+       /* Removal (@add false): if this tp has usesw set and was counted,
+        * drop its contribution to block->useswcnt; when the counter reaches
+        * zero, release the reference on tcf_sw_enabled_key.
+        */
+       if (!add) {
+               if (tp->usesw && tp->counted) {
+                       if (!atomic_dec_return(&block->useswcnt))
+                               static_branch_dec(&tcf_sw_enabled_key);
+                       tp->counted = false;
+               }
+               return;
        }
-}
-
-static void tcf_block_filter_cnt_update(struct tcf_block *block, bool *counted, bool add)
-{
-       lockdep_assert_not_held(&block->cb_lock);
 
-       down_write(&block->cb_lock);
-       if (*counted != add) {
-               if (add) {
-                       atomic_inc(&block->filtercnt);
-                       *counted = true;
-               } else {
-                       atomic_dec(&block->filtercnt);
-                       *counted = false;
-               }
+       /* Addition (@add true): test-and-set tp->counted under tp->lock; only
+        * a tp with usesw set that was not yet counted goes on to bump
+        * block->useswcnt below.
+        */
+       spin_lock(&tp->lock);
+       if (tp->usesw && !tp->counted) {
+               counted = true;
+               tp->counted = true;
        }
-       tcf_maintain_bypass(block);
-       up_write(&block->cb_lock);
+       spin_unlock(&tp->lock);
+
+       /* useswcnt went 0 -> 1: take a reference on tcf_sw_enabled_key. */
+       if (counted && atomic_inc_return(&block->useswcnt) == 1)
+               static_branch_inc(&tcf_sw_enabled_key);
+#endif
 }
 
 static void tcf_chain_put(struct tcf_chain *chain);
                              bool sig_destroy, struct netlink_ext_ack *extack)
 {
        tp->ops->destroy(tp, rtnl_held, extack);
-       tcf_block_filter_cnt_update(tp->chain->block, &tp->counted, false);
+       tcf_proto_count_usesw(tp, false);
        if (sig_destroy)
                tcf_proto_signal_destroyed(tp->chain, tp);
        tcf_chain_put(tp->chain);
                tfilter_notify(net, skb, n, tp, block, q, parent, fh,
                               RTM_NEWTFILTER, false, rtnl_held, extack);
                tfilter_put(tp, fh);
-               tcf_block_filter_cnt_update(block, &tp->counted, true);
+               tcf_proto_count_usesw(tp, true);
                /* q pointer is NULL for shared blocks */
                if (q)
                        q->flags &= ~TCQ_F_CAN_BYPASS;
        if (*flags & TCA_CLS_FLAGS_IN_HW)
                return;
        *flags |= TCA_CLS_FLAGS_IN_HW;
-       if (tc_skip_sw(*flags))
-               atomic_inc(&block->skipswcnt);
        atomic_inc(&block->offloadcnt);
 }
 
        if (!(*flags & TCA_CLS_FLAGS_IN_HW))
                return;
        *flags &= ~TCA_CLS_FLAGS_IN_HW;
-       if (tc_skip_sw(*flags))
-               atomic_dec(&block->skipswcnt);
        atomic_dec(&block->offloadcnt);
 }
 
 
                if (!tc_in_hw(new->flags))
                        new->flags |= TCA_CLS_FLAGS_NOT_IN_HW;
 
+               tcf_proto_update_usesw(tp, new->flags);
+
                u32_replace_knode(tp, tp_c, new);
                tcf_unbind_filter(tp, &n->res);
                tcf_exts_get_net(&n->exts);
                if (!tc_in_hw(n->flags))
                        n->flags |= TCA_CLS_FLAGS_NOT_IN_HW;
 
+               tcf_proto_update_usesw(tp, n->flags);
+
                ins = &ht->ht[TC_U32_HASH(handle)];
                for (pins = rtnl_dereference(*ins); pins;
                     ins = &pins->next, pins = rtnl_dereference(*ins))