*/
 void ieee80211_report_low_ack(struct ieee80211_sta *sta, u32 num_packets);
 
+#define IEEE80211_MAX_CSA_COUNTERS_NUM 2
+
 /**
  * struct ieee80211_mutable_offsets - mutable beacon offsets
  * @tim_offset: position of TIM element
  * @tim_length: size of TIM element
+ * @csa_offs: array of IEEE80211_MAX_CSA_COUNTERS_NUM offsets to CSA counters.
+ *     This array can contain zero values which should be ignored.
  */
 struct ieee80211_mutable_offsets {
        u16 tim_offset;
        u16 tim_length;
+
+       u16 csa_counter_offs[IEEE80211_MAX_CSA_COUNTERS_NUM];
 };
 
 /**
  *
  * This function should be used if the beacon frames are generated by the
  * device, and then the driver must use the returned beacon as the template
- * The driver is responsible to update the DTIM count.
+ * The driver or the device are responsible to update the DTIM and, when
+ * applicable, the CSA count.
  *
  * The driver is responsible for freeing the returned skb.
  *
        return ieee80211_beacon_get_tim(hw, vif, NULL, NULL);
 }
 
+/**
+ * ieee80211_csa_update_counter - request mac80211 to decrement the csa counter
+ * @vif: &struct ieee80211_vif pointer from the add_interface callback.
+ *
+ * The csa counter should be updated after each beacon transmission.
+ * This function is called implicitly when
+ * ieee80211_beacon_get/ieee80211_beacon_get_tim are called, however if the
+ * beacon frames are generated by the device, the driver should call this
+ * function after each beacon transmission to sync mac80211's csa counters.
+ *
+ * Return: new csa counter value
+ */
+u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif);
+
 /**
  * ieee80211_csa_finish - notify mac80211 about channel switch
  * @vif: &struct ieee80211_vif pointer from the add_interface callback.
 
        return 0;
 }
 
-static void ieee80211_update_csa(struct ieee80211_sub_if_data *sdata,
-                                struct beacon_data *beacon)
+static void ieee80211_set_csa(struct ieee80211_sub_if_data *sdata,
+                             struct beacon_data *beacon)
 {
        struct probe_resp *resp;
        u8 *beacon_data;
        size_t beacon_data_len;
        int i;
+       u8 count = sdata->csa_current_counter;
 
        switch (sdata->vif.type) {
        case NL80211_IFTYPE_AP:
                        if (WARN_ON(counter_offset_beacon >= beacon_data_len))
                                return;
 
-                       /* Warn if the driver did not check for/react to csa
-                        * completeness.  A beacon with CSA counter set to 0
-                        * should never occur, because a counter of 1 means
-                        * switch just before the next beacon.
-                        */
-                       if (WARN_ON(beacon_data[counter_offset_beacon] == 1))
-                               return;
-
-                       beacon_data[counter_offset_beacon] =
-                               sdata->csa_current_counter - 1;
+                       beacon_data[counter_offset_beacon] = count;
                }
 
                if (sdata->vif.type == NL80211_IFTYPE_AP &&
                                rcu_read_unlock();
                                return;
                        }
-                       resp->data[counter_offset_presp] =
-                               sdata->csa_current_counter - 1;
+                       resp->data[counter_offset_presp] = count;
                        rcu_read_unlock();
                }
        }
+}
+
+u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif)
+{
+       struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif);
 
        sdata->csa_current_counter--;
+
+       /* the counter should never reach 0 */
+       WARN_ON(!sdata->csa_current_counter);
+
+       return sdata->csa_current_counter;
 }
+EXPORT_SYMBOL(ieee80211_csa_update_counter);
 
 bool ieee80211_csa_is_complete(struct ieee80211_vif *vif)
 {
        enum ieee80211_band band;
        struct ieee80211_tx_rate_control txrc;
        struct ieee80211_chanctx_conf *chanctx_conf;
+       int csa_off_base = 0;
 
        rcu_read_lock();
 
                struct beacon_data *beacon = rcu_dereference(ap->beacon);
 
                if (beacon) {
-                       if (sdata->vif.csa_active)
-                               ieee80211_update_csa(sdata, beacon);
+                       if (sdata->vif.csa_active) {
+                               if (!is_template)
+                                       ieee80211_csa_update_counter(vif);
+
+                               ieee80211_set_csa(sdata, beacon);
+                       }
 
                        /*
                         * headroom, head length,
                        if (offs) {
                                offs->tim_offset = beacon->head_len;
                                offs->tim_length = skb->len - beacon->head_len;
+
+                               /* for AP the csa offsets are from tail */
+                               csa_off_base = skb->len;
                        }
 
                        if (beacon->tail)
                if (!presp)
                        goto out;
 
-               if (sdata->vif.csa_active)
-                       ieee80211_update_csa(sdata, presp);
+               if (sdata->vif.csa_active) {
+                       if (!is_template)
+                               ieee80211_csa_update_counter(vif);
 
+                       ieee80211_set_csa(sdata, presp);
+               }
 
                skb = dev_alloc_skb(local->tx_headroom + presp->head_len +
                                    local->hw.extra_beacon_tailroom);
                if (!bcn)
                        goto out;
 
-               if (sdata->vif.csa_active)
-                       ieee80211_update_csa(sdata, bcn);
+               if (sdata->vif.csa_active) {
+                       if (!is_template)
+                               /* TODO: For mesh csa_counter is in TU, so
+                                * decrementing it by one isn't correct, but
+                                * for now we leave it consistent with overall
+                                * mac80211's behavior.
+                                */
+                               ieee80211_csa_update_counter(vif);
+
+                       ieee80211_set_csa(sdata, bcn);
+               }
 
                if (ifmsh->sync_ops)
                        ifmsh->sync_ops->adjust_tbtt(sdata, bcn);
                goto out;
        }
 
+       /* CSA offsets */
+       if (offs) {
+               int i;
+
+               for (i = 0; i < IEEE80211_MAX_CSA_COUNTERS_NUM; i++) {
+                       u16 csa_off = sdata->csa_counter_offset_beacon[i];
+
+                       if (!csa_off)
+                               continue;
+
+                       offs->csa_counter_offs[i] = csa_off_base + csa_off;
+               }
+       }
+
        band = chanctx_conf->def.chan->band;
 
        info = IEEE80211_SKB_CB(skb);