static int attempt_writeback(const char *cgroup, void *arg)
 {
        long pagesize = sysconf(_SC_PAGESIZE);
-       char *test_group = arg;
        size_t memsize = MB(4);
        char buf[pagesize];
        long zswap_usage;
-       bool wb_enabled;
+       bool wb_enabled = *(bool *) arg;
        int ret = -1;
        char *mem;
 
-       wb_enabled = cg_read_long(test_group, "memory.zswap.writeback");
        mem = (char *)malloc(memsize);
        if (!mem)
                return ret;
                memcpy(&mem[i], buf, pagesize);
 
        /* Try and reclaim allocated memory */
-       if (cg_write_numeric(test_group, "memory.reclaim", memsize)) {
+       if (cg_write_numeric(cgroup, "memory.reclaim", memsize)) {
                ksft_print_msg("Failed to reclaim all of the requested memory\n");
                goto out;
        }
 
-       zswap_usage = cg_read_long(test_group, "memory.zswap.current");
+       zswap_usage = cg_read_long(cgroup, "memory.zswap.current");
 
        /* zswpin */
        for (int i = 0; i < memsize; i += pagesize) {
                }
        }
 
-       if (cg_write_numeric(test_group, "memory.zswap.max", zswap_usage/2))
+       if (cg_write_numeric(cgroup, "memory.zswap.max", zswap_usage/2))
                goto out;
 
        /*
         * If writeback is disabled, memory reclaim will fail as zswap is limited and
         * it can't writeback to swap.
         */
-       ret = cg_write_numeric(test_group, "memory.reclaim", memsize);
+       ret = cg_write_numeric(cgroup, "memory.reclaim", memsize);
        if (!wb_enabled)
                ret = (ret == -EAGAIN) ? 0 : -1;
 
        return ret;
 }
 
+static int test_zswap_writeback_one(const char *cgroup, bool wb)
+{
+       long zswpwb_before, zswpwb_after;
+
+       zswpwb_before = get_cg_wb_count(cgroup);
+       if (zswpwb_before != 0) {
+               ksft_print_msg("zswpwb_before = %ld instead of 0\n", zswpwb_before);
+               return -1;
+       }
+
+       if (cg_run(cgroup, attempt_writeback, (void *) &wb))
+               return -1;
+
+       /* Verify that zswap writeback occurred only if writeback was enabled */
+       zswpwb_after = get_cg_wb_count(cgroup);
+       if (zswpwb_after < 0)
+               return -1;
+
+       if (wb != !!zswpwb_after) {
+               ksft_print_msg("zswpwb_after is %ld while wb is %s",
+                               zswpwb_after, wb ? "enabled" : "disabled");
+               return -1;
+       }
+
+       return 0;
+}
+
 /* Test to verify the zswap writeback path */
 static int test_zswap_writeback(const char *root, bool wb)
 {
-       long zswpwb_before, zswpwb_after;
        int ret = KSFT_FAIL;
-       char *test_group;
+       char *test_group, *test_group_child = NULL;
+
+       if (cg_read_strcmp(root, "memory.zswap.writeback", "1"))
+               return KSFT_SKIP;
 
        test_group = cg_name(root, "zswap_writeback_test");
        if (!test_group)
        if (cg_write(test_group, "memory.zswap.writeback", wb ? "1" : "0"))
                goto out;
 
-       zswpwb_before = get_cg_wb_count(test_group);
-       if (zswpwb_before != 0) {
-               ksft_print_msg("zswpwb_before = %ld instead of 0\n", zswpwb_before);
+       if (test_zswap_writeback_one(test_group, wb))
                goto out;
-       }
 
-       if (cg_run(test_group, attempt_writeback, (void *) test_group))
+       /* Reset memory.zswap.max to max (modified by attempt_writeback), and
+        * set up child cgroup, whose memory.zswap.writeback is hardcoded to 1.
+        * Thus, the parent's setting shall be what's in effect. */
+       if (cg_write(test_group, "memory.zswap.max", "max"))
+               goto out;
+       if (cg_write(test_group, "cgroup.subtree_control", "+memory"))
                goto out;
 
-       /* Verify that zswap writeback occurred only if writeback was enabled */
-       zswpwb_after = get_cg_wb_count(test_group);
-       if (zswpwb_after < 0)
+       test_group_child = cg_name(test_group, "zswap_writeback_test_child");
+       if (!test_group_child)
+               goto out;
+       if (cg_create(test_group_child))
+               goto out;
+       if (cg_write(test_group_child, "memory.zswap.writeback", "1"))
                goto out;
 
-       if (wb != !!zswpwb_after) {
-               ksft_print_msg("zswpwb_after is %ld while wb is %s",
-                               zswpwb_after, wb ? "enabled" : "disabled");
+       if (test_zswap_writeback_one(test_group_child, wb))
                goto out;
-       }
 
        ret = KSFT_PASS;
 
 out:
+       if (test_group_child) {
+               cg_destroy(test_group_child);
+               free(test_group_child);
+       }
        cg_destroy(test_group);
        free(test_group);
        return ret;