static void *stage2_memcache_zalloc_page(void *arg)
 {
        struct kvm_mmu_memory_cache *mc = arg;
+       void *virt;
 
        /* Allocated with __GFP_ZERO, so no need to zero */
-       return kvm_mmu_memory_cache_alloc(mc);
+       virt = kvm_mmu_memory_cache_alloc(mc);
+       if (virt)
+               kvm_account_pgtable_pages(virt, 1);
+       return virt;
 }
 
 static void *kvm_host_zalloc_pages_exact(size_t size)
        return alloc_pages_exact(size, GFP_KERNEL_ACCOUNT | __GFP_ZERO);
 }
 
+static void *kvm_s2_zalloc_pages_exact(size_t size)
+{
+       void *virt = kvm_host_zalloc_pages_exact(size);
+
+       if (virt)
+               kvm_account_pgtable_pages(virt, (size >> PAGE_SHIFT));
+       return virt;
+}
+
+static void kvm_s2_free_pages_exact(void *virt, size_t size)
+{
+       kvm_account_pgtable_pages(virt, -(size >> PAGE_SHIFT));
+       free_pages_exact(virt, size);
+}
+
 static void kvm_host_get_page(void *addr)
 {
        get_page(virt_to_page(addr));
        put_page(virt_to_page(addr));
 }
 
+static void kvm_s2_put_page(void *addr)
+{
+       struct page *p = virt_to_page(addr);
+       /* Dropping last refcount, the page will be freed */
+       if (page_count(p) == 1)
+               kvm_account_pgtable_pages(addr, -1);
+       put_page(p);
+}
+
 static int kvm_host_page_count(void *addr)
 {
        return page_count(virt_to_page(addr));
 
 static struct kvm_pgtable_mm_ops kvm_s2_mm_ops = {
        .zalloc_page            = stage2_memcache_zalloc_page,
-       .zalloc_pages_exact     = kvm_host_zalloc_pages_exact,
-       .free_pages_exact       = free_pages_exact,
+       .zalloc_pages_exact     = kvm_s2_zalloc_pages_exact,
+       .free_pages_exact       = kvm_s2_free_pages_exact,
        .get_page               = kvm_host_get_page,
-       .put_page               = kvm_host_put_page,
+       .put_page               = kvm_s2_put_page,
        .page_count             = kvm_host_page_count,
        .phys_to_virt           = kvm_host_va,
        .virt_to_phys           = kvm_host_pa,