#define ALL_AUTO_MODE_MASKS (RDMA_COUNTER_MASK_QP_TYPE | RDMA_COUNTER_MASK_PID)
 
-static int __counter_set_mode(struct rdma_counter_mode *curr,
+static int __counter_set_mode(struct rdma_port_counter *port_counter,
                              enum rdma_nl_counter_mode new_mode,
                              enum rdma_nl_counter_mask new_mask)
 {
-       if ((new_mode == RDMA_COUNTER_MODE_AUTO) &&
-           ((new_mask & (~ALL_AUTO_MODE_MASKS)) ||
-            (curr->mode != RDMA_COUNTER_MODE_NONE)))
-               return -EINVAL;
+       if (new_mode == RDMA_COUNTER_MODE_AUTO && port_counter->num_counters)
+               if (new_mask & ~ALL_AUTO_MODE_MASKS ||
+                   port_counter->mode.mode != RDMA_COUNTER_MODE_NONE)
+                       return -EINVAL;
 
-       curr->mode = new_mode;
-       curr->mask = new_mask;
+       port_counter->mode.mode = new_mode;
+       port_counter->mode.mask = new_mask;
        return 0;
 }
 
 /**
  * rdma_counter_set_auto_mode() - Turn on/off per-port auto mode
  *
- * When @on is true, the @mask must be set; When @on is false, it goes
- * into manual mode if there's any counter, so that the user is able to
- * manually access them.
+ * @dev: Device to operate
+ * @port: Port to use
+ * @mask: Mask to configure
+ * @extack: Message to the user
+ *
+ * Return 0 on success.
  */
 int rdma_counter_set_auto_mode(struct ib_device *dev, u8 port,
-                              bool on, enum rdma_nl_counter_mask mask)
+                              enum rdma_nl_counter_mask mask,
+                              struct netlink_ext_ack *extack)
 {
+       enum rdma_nl_counter_mode mode = RDMA_COUNTER_MODE_AUTO;
        struct rdma_port_counter *port_counter;
        int ret;
 
                return -EOPNOTSUPP;
 
        mutex_lock(&port_counter->lock);
-       if (on) {
-               ret = __counter_set_mode(&port_counter->mode,
-                                        RDMA_COUNTER_MODE_AUTO, mask);
-       } else {
-               if (port_counter->mode.mode != RDMA_COUNTER_MODE_AUTO) {
-                       ret = -EINVAL;
-                       goto out;
-               }
+       if (mask) {
+               ret = __counter_set_mode(port_counter, mode, mask);
+               if (ret)
+                       NL_SET_ERR_MSG(
+                               extack,
+                               "Turning on auto mode is not allowed when there is bound QP");
+               goto out;
+       }
 
-               if (port_counter->num_counters)
-                       ret = __counter_set_mode(&port_counter->mode,
-                                                RDMA_COUNTER_MODE_MANUAL, 0);
-               else
-                       ret = __counter_set_mode(&port_counter->mode,
-                                                RDMA_COUNTER_MODE_NONE, 0);
+       if (port_counter->mode.mode != RDMA_COUNTER_MODE_AUTO) {
+               ret = -EINVAL;
+               goto out;
        }
 
+       mode = (port_counter->num_counters) ? RDMA_COUNTER_MODE_MANUAL :
+                                                   RDMA_COUNTER_MODE_NONE;
+       ret = __counter_set_mode(port_counter, mode, 0);
 out:
        mutex_unlock(&port_counter->lock);
        return ret;
        mutex_lock(&port_counter->lock);
        switch (mode) {
        case RDMA_COUNTER_MODE_MANUAL:
-               ret = __counter_set_mode(&port_counter->mode,
-                                        RDMA_COUNTER_MODE_MANUAL, 0);
+               ret = __counter_set_mode(port_counter, RDMA_COUNTER_MODE_MANUAL,
+                                        0);
                if (ret) {
                        mutex_unlock(&port_counter->lock);
                        goto err_mode;
        port_counter->num_counters--;
        if (!port_counter->num_counters &&
            (port_counter->mode.mode == RDMA_COUNTER_MODE_MANUAL))
-               __counter_set_mode(&port_counter->mode, RDMA_COUNTER_MODE_NONE,
-                                  0);
+               __counter_set_mode(port_counter, RDMA_COUNTER_MODE_NONE, 0);
 
        mutex_unlock(&port_counter->lock);