page_counter_read(&memcg->memory);
 }
 
+void mem_cgroup_commit_charge(struct folio *folio, struct mem_cgroup *memcg);
+
 int __mem_cgroup_charge(struct folio *folio, struct mm_struct *mm, gfp_t gfp);
 
 /**
        __mem_cgroup_uncharge_list(page_list);
 }
 
+void mem_cgroup_cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages);
+
 void mem_cgroup_migrate(struct folio *old, struct folio *new);
 
 /**
 
 struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm);
 
+struct mem_cgroup *get_mem_cgroup_from_current(void);
+
 struct lruvec *folio_lruvec_lock(struct folio *folio);
 struct lruvec *folio_lruvec_lock_irq(struct folio *folio);
 struct lruvec *folio_lruvec_lock_irqsave(struct folio *folio,
        return false;
 }
 
+static inline void mem_cgroup_commit_charge(struct folio *folio,
+               struct mem_cgroup *memcg)
+{
+}
+
 static inline int mem_cgroup_charge(struct folio *folio,
                struct mm_struct *mm, gfp_t gfp)
 {
 {
 }
 
+static inline void mem_cgroup_cancel_charge(struct mem_cgroup *memcg,
+               unsigned int nr_pages)
+{
+}
+
 static inline void mem_cgroup_migrate(struct folio *old, struct folio *new)
 {
 }
        return NULL;
 }
 
+static inline struct mem_cgroup *get_mem_cgroup_from_current(void)
+{
+       return NULL;
+}
+
 static inline
 struct mem_cgroup *mem_cgroup_from_css(struct cgroup_subsys_state *css)
 {
 
        return false;
 }
 
+/**
+ * get_mem_cgroup_from_current - Obtain a reference on current task's memcg.
+ */
+struct mem_cgroup *get_mem_cgroup_from_current(void)
+{
+       struct mem_cgroup *memcg;
+
+       if (mem_cgroup_disabled())
+               return NULL;
+
+again:
+       rcu_read_lock();
+       memcg = mem_cgroup_from_task(current);
+       if (!css_tryget(&memcg->css)) {
+               rcu_read_unlock();
+               goto again;
+       }
+       rcu_read_unlock();
+       return memcg;
+}
+
 /**
  * mem_cgroup_iter - iterate over memory cgroup hierarchy
  * @root: hierarchy root
        return try_charge_memcg(memcg, gfp_mask, nr_pages);
 }
 
-static inline void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
+/**
+ * mem_cgroup_cancel_charge() - cancel an uncommitted try_charge() call.
+ * @memcg: memcg previously charged.
+ * @nr_pages: number of pages previously charged.
+ */
+void mem_cgroup_cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
        if (mem_cgroup_is_root(memcg))
                return;
        folio->memcg_data = (unsigned long)memcg;
 }
 
+/**
+ * mem_cgroup_commit_charge - commit a previously successful try_charge().
+ * @folio: folio to commit the charge to.
+ * @memcg: memcg previously charged.
+ */
+void mem_cgroup_commit_charge(struct folio *folio, struct mem_cgroup *memcg)
+{
+       css_get(&memcg->css);
+       commit_charge(folio, memcg);
+
+       local_irq_disable();
+       mem_cgroup_charge_statistics(memcg, folio_nr_pages(folio));
+       memcg_check_events(memcg, folio_nid(folio));
+       local_irq_enable();
+}
+
 #ifdef CONFIG_MEMCG_KMEM
 /*
  * The allocated objcg pointers array is not accounted directly.
 
        /* we must uncharge all the leftover precharges from mc.to */
        if (mc.precharge) {
-               cancel_charge(mc.to, mc.precharge);
+               mem_cgroup_cancel_charge(mc.to, mc.precharge);
                mc.precharge = 0;
        }
        /*
         * we must uncharge here.
         */
        if (mc.moved_charge) {
-               cancel_charge(mc.from, mc.moved_charge);
+               mem_cgroup_cancel_charge(mc.from, mc.moved_charge);
                mc.moved_charge = 0;
        }
        /* we must fixup refcnts and charges */
 static int charge_memcg(struct folio *folio, struct mem_cgroup *memcg,
                        gfp_t gfp)
 {
-       long nr_pages = folio_nr_pages(folio);
        int ret;
 
-       ret = try_charge(memcg, gfp, nr_pages);
+       ret = try_charge(memcg, gfp, folio_nr_pages(folio));
        if (ret)
                goto out;
 
-       css_get(&memcg->css);
-       commit_charge(folio, memcg);
-
-       local_irq_disable();
-       mem_cgroup_charge_statistics(memcg, nr_pages);
-       memcg_check_events(memcg, folio_nid(folio));
-       local_irq_enable();
+       mem_cgroup_commit_charge(folio, memcg);
 out:
        return ret;
 }