static void __percpu_ref_switch_to_atomic(struct percpu_ref *ref,
                                          percpu_ref_func_t *confirm_switch)
 {
-       /*
-        * If the previous ATOMIC switching hasn't finished yet, wait for
-        * its completion.  If the caller ensures that ATOMIC switching
-        * isn't in progress, this function can be called from any context.
-        * Do an extra confirm_switch test to circumvent the unconditional
-        * might_sleep() in wait_event().
-        */
-       if (ref->confirm_switch)
-               wait_event(percpu_ref_switch_waitq, !ref->confirm_switch);
-
        if (ref->percpu_count_ptr & __PERCPU_REF_ATOMIC) {
                if (confirm_switch)
                        confirm_switch(ref);
        unsigned long __percpu *percpu_count = percpu_count_ptr(ref);
        int cpu;
 
-       /*
-        * If the previous ATOMIC switching hasn't finished yet, wait for
-        * its completion.  If the caller ensures that ATOMIC switching
-        * isn't in progress, this function can be called from any context.
-        * Do an extra confirm_switch test to circumvent the unconditional
-        * might_sleep() in wait_event().
-        */
-       if (ref->confirm_switch)
-               wait_event(percpu_ref_switch_waitq, !ref->confirm_switch);
-
        BUG_ON(!percpu_count);
 
        if (!(ref->percpu_count_ptr & __PERCPU_REF_ATOMIC))
                          ref->percpu_count_ptr & ~__PERCPU_REF_ATOMIC);
 }
 
+static void __percpu_ref_switch_mode(struct percpu_ref *ref,
+                                    percpu_ref_func_t *confirm_switch)
+{
+       /*
+        * If the previous ATOMIC switching hasn't finished yet, wait for
+        * its completion.  If the caller ensures that ATOMIC switching
+        * isn't in progress, this function can be called from any context.
+        * Do an extra confirm_switch test to circumvent the unconditional
+        * might_sleep() in wait_event().
+        */
+       if (ref->confirm_switch)
+               wait_event(percpu_ref_switch_waitq, !ref->confirm_switch);
+
+       if (ref->force_atomic || (ref->percpu_count_ptr & __PERCPU_REF_DEAD))
+               __percpu_ref_switch_to_atomic(ref, confirm_switch);
+       else
+               __percpu_ref_switch_to_percpu(ref);
+}
+
 /**
  * percpu_ref_switch_to_atomic - switch a percpu_ref to atomic mode
  * @ref: percpu_ref to switch to atomic mode
  * operations.  Note that @ref will stay in atomic mode across kill/reinit
  * cycles until percpu_ref_switch_to_percpu() is called.
  *
- * This function normally doesn't block and can be called from any context
- * but it may block if @confirm_kill is specified and @ref is already in
- * the process of switching to atomic mode.  In such cases, @confirm_switch
- * will be invoked after the switching is complete.
+ * This function may block if @ref is in the process of switching to atomic
+ * mode.  If the caller ensures that @ref is not in the process of
+ * switching to atomic mode, this function can be called from any context.
  */
 void percpu_ref_switch_to_atomic(struct percpu_ref *ref,
                                 percpu_ref_func_t *confirm_switch)
 {
        ref->force_atomic = true;
-       __percpu_ref_switch_to_atomic(ref, confirm_switch);
+       __percpu_ref_switch_mode(ref, confirm_switch);
 }
 
 /**
  * dying or dead, the actual switching takes place on the following
  * percpu_ref_reinit().
  *
- * This function normally doesn't block and can be called from any context
- * but it may block if @ref is in the process of switching to atomic mode
- * by percpu_ref_switch_atomic().
+ * This function may block if @ref is in the process of switching to atomic
+ * mode.  If the caller ensures that @ref is not in the process of
+ * switching to atomic mode, this function can be called from any context.
  */
 void percpu_ref_switch_to_percpu(struct percpu_ref *ref)
 {
        ref->force_atomic = false;
-
-       /* a dying or dead ref can't be switched to percpu mode w/o reinit */
-       if (!(ref->percpu_count_ptr & __PERCPU_REF_DEAD))
-               __percpu_ref_switch_to_percpu(ref);
+       __percpu_ref_switch_mode(ref, NULL);
 }
 
 /**
                  "%s called more than once on %pf!", __func__, ref->release);
 
        ref->percpu_count_ptr |= __PERCPU_REF_DEAD;
-       __percpu_ref_switch_to_atomic(ref, confirm_kill);
+       __percpu_ref_switch_mode(ref, confirm_kill);
        percpu_ref_put(ref);
 }
 EXPORT_SYMBOL_GPL(percpu_ref_kill_and_confirm);
 
        ref->percpu_count_ptr &= ~__PERCPU_REF_DEAD;
        percpu_ref_get(ref);
-       if (!ref->force_atomic)
-               __percpu_ref_switch_to_percpu(ref);
+       __percpu_ref_switch_mode(ref, NULL);
 }
 EXPORT_SYMBOL_GPL(percpu_ref_reinit);