EXPORT_SYMBOL_GPL(_copy_mc_to_iter);
 #endif /* CONFIG_ARCH_HAS_COPY_MC */
 
-static size_t memcpy_from_iter_mc(void *iter_from, size_t progress,
-                                 size_t len, void *to, void *priv2)
+static __always_inline
+size_t memcpy_from_iter_mc(void *iter_from, size_t progress,
+                          size_t len, void *to, void *priv2)
+{
+       return copy_mc_to_kernel(to + progress, iter_from, len);
+}
+
+static size_t __copy_from_iter_mc(void *addr, size_t bytes, struct iov_iter *i)
 {
-       struct iov_iter *iter = priv2;
+       if (unlikely(i->count < bytes))
+               bytes = i->count;
+       if (unlikely(!bytes))
+               return 0;
+       return iterate_bvec(i, bytes, addr, NULL, memcpy_from_iter_mc);
+}
 
-       if (iov_iter_is_copy_mc(iter))
-               return copy_mc_to_kernel(to + progress, iter_from, len);
-       return memcpy_from_iter(iter_from, progress, len, to, priv2);
+static __always_inline
+size_t __copy_from_iter(void *addr, size_t bytes, struct iov_iter *i)
+{
+       if (unlikely(iov_iter_is_copy_mc(i)))
+               return __copy_from_iter_mc(addr, bytes, i);
+       return iterate_and_advance(i, bytes, addr,
+                                  copy_from_user_iter, memcpy_from_iter);
 }
 
 size_t _copy_from_iter(void *addr, size_t bytes, struct iov_iter *i)
 
        if (user_backed_iter(i))
                might_fault();
-       return iterate_and_advance2(i, bytes, addr, i,
-                                   copy_from_user_iter,
-                                   memcpy_from_iter_mc);
+       return __copy_from_iter(addr, bytes, i);
 }
 EXPORT_SYMBOL(_copy_from_iter);
 
                }
 
                p = kmap_atomic(page) + offset;
-               n = iterate_and_advance2(i, n, p, i,
-                                        copy_from_user_iter,
-                                        memcpy_from_iter_mc);
+               n = __copy_from_iter(p, n, i);
                kunmap_atomic(p);
                copied += n;
                offset += n;