.ad_disabled = 1,
 };
 
-#define for_each_shadow_entry(_vcpu, _addr, _walker)    \
+#define for_each_shadow_entry_using_root(_vcpu, _root, _addr, _walker)     \
+       for (shadow_walk_init_using_root(&(_walker), (_vcpu),              \
+                                        (_root), (_addr));                \
+            shadow_walk_okay(&(_walker));                                 \
+            shadow_walk_next(&(_walker)))
+
+#define for_each_shadow_entry(_vcpu, _addr, _walker)            \
        for (shadow_walk_init(&(_walker), _vcpu, _addr);        \
             shadow_walk_okay(&(_walker));                      \
             shadow_walk_next(&(_walker)))
        return 0;
 }
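A minimal sketch of how the new walker macro is meant to be used, with root_hpa and gva assumed to be in scope and the loop body purely illustrative (is_shadow_present_pte() and the iterator fields are the existing ones from mmu.c):

        struct kvm_shadow_walk_iterator it;

        /* Walk the shadow page table rooted at root_hpa, not the active root. */
        for_each_shadow_entry_using_root(vcpu, root_hpa, gva, it) {
                /* it.sptep points at the SPTE mapping gva at level it.level. */
                if (!is_shadow_present_pte(*it.sptep))
                        break;  /* stop at the first non-present entry */
        }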
 
-static void nonpaging_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
+static void nonpaging_invlpg(struct kvm_vcpu *vcpu, gva_t gva, hpa_t root)
 {
 }
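The nonpaging stub stays empty on purpose: a direct-mapped root has no shadow PTEs derived from guest page tables to sync on INVLPG, so the TLB flush issued by the caller is all that is needed. Its signature changes anyway so that every implementation matches the updated kvm_mmu->invlpg() hook.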
 
        return sp;
 }
 
-static void shadow_walk_init(struct kvm_shadow_walk_iterator *iterator,
-                            struct kvm_vcpu *vcpu, u64 addr)
+static void shadow_walk_init_using_root(struct kvm_shadow_walk_iterator *iterator,
+                                       struct kvm_vcpu *vcpu, hpa_t root,
+                                       u64 addr)
 {
        iterator->addr = addr;
-       iterator->shadow_addr = vcpu->arch.mmu.root_hpa;
+       iterator->shadow_addr = root;
        iterator->level = vcpu->arch.mmu.shadow_root_level;
 
        if (iterator->level == PT64_ROOT_4LEVEL &&
            vcpu->arch.mmu.root_level < PT64_ROOT_4LEVEL &&
            !vcpu->arch.mmu.direct_map)
                --iterator->level;
 
        if (iterator->level == PT32E_ROOT_LEVEL) {
+               /*
+                * prev_root is currently only used for 64-bit hosts. So only
+                * the active root_hpa is valid here.
+                */
+               BUG_ON(root != vcpu->arch.mmu.root_hpa);
+
                iterator->shadow_addr
                        = vcpu->arch.mmu.pae_root[(addr >> 30) & 3];
                iterator->shadow_addr &= PT64_BASE_ADDR_MASK;
        }
 }
 
+static void shadow_walk_init(struct kvm_shadow_walk_iterator *iterator,
+                            struct kvm_vcpu *vcpu, u64 addr)
+{
+       shadow_walk_init_using_root(iterator, vcpu, vcpu->arch.mmu.root_hpa,
+                                   addr);
+}
+
 static bool shadow_walk_okay(struct kvm_shadow_walk_iterator *iterator)
 {
        if (iterator->level < PT_PAGE_TABLE_LEVEL)
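Two notes on the init path above. First, shadow_walk_init() is kept as a thin wrapper that passes the active root, so none of the existing for_each_shadow_entry() call sites need to change; only callers that want to walk a different root use the new _using_root variants. Second, on PT32E roots the walk cannot start from an arbitrary hpa_t, because the four PDPTEs live in the per-vCPU pae_root table, hence the BUG_ON above; (addr >> 30) & 3 selects the PDPTE covering the 1 GiB region that contains addr. A worked example with a hypothetical address:

        /* addr = 0xc0100000: bits 31:30 are 0b11, so pae_root[3] is used. */
        u64 addr = 0xc0100000;
        int idx  = (addr >> 30) & 3;    /* == 3 */
        hpa_t pd = vcpu->arch.mmu.pae_root[idx] & PT64_BASE_ADDR_MASK;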
 
 void kvm_mmu_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
 {
-       vcpu->arch.mmu.invlpg(vcpu, gva);
+       struct kvm_mmu *mmu = &vcpu->arch.mmu;
+
+       mmu->invlpg(vcpu, gva, mmu->root_hpa);
        kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
        ++vcpu->stat.invlpg;
 }
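Passing the root explicitly is what lets the same hook sync a cached previous root later on. A sketch of how that could look, assuming mmu->prev_root caches a valid hpa (the prev_root mentioned in the comment in shadow_walk_init_using_root() above):

        /* Hypothetical: sync the cached previous root as well, if any. */
        if (VALID_PAGE(mmu->prev_root.hpa))
                mmu->invlpg(vcpu, gva, mmu->prev_root.hpa);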
        struct kvm_mmu *mmu = &vcpu->arch.mmu;
 
        if (pcid == kvm_get_active_pcid(vcpu)) {
-               mmu->invlpg(vcpu, gva);
+               mmu->invlpg(vcpu, gva, mmu->root_hpa);
                kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
        }
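Only the root belonging to the active PCID is synced here. A translation for gva cached under another PCID would need the same call against that root; a sketch, again assuming prev_root records the CR3 value it was built from:

        /* Hypothetical: gva may also be cached under prev_root's PCID. */
        if (VALID_PAGE(mmu->prev_root.hpa) &&
            pcid == kvm_get_pcid(vcpu, mmu->prev_root.cr3))
                mmu->invlpg(vcpu, gva, mmu->prev_root.hpa);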
 
 
        return gfn_to_gpa(sp->gfn) + offset * sizeof(pt_element_t);
 }
 
-static void FNAME(invlpg)(struct kvm_vcpu *vcpu, gva_t gva)
+static void FNAME(invlpg)(struct kvm_vcpu *vcpu, gva_t gva, hpa_t root_hpa)
 {
        struct kvm_shadow_walk_iterator iterator;
        struct kvm_mmu_page *sp;
         */
        mmu_topup_memory_caches(vcpu);
 
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa)) {
+       if (!VALID_PAGE(root_hpa)) {
                WARN_ON(1);
                return;
        }
 
        spin_lock(&vcpu->kvm->mmu_lock);
-       for_each_shadow_entry(vcpu, gva, iterator) {
+       for_each_shadow_entry_using_root(vcpu, root_hpa, gva, iterator) {
                level = iterator.level;
                sptep = iterator.sptep;
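The rest of the loop (truncated here) is unchanged from the old FNAME(invlpg): it descends to the leaf SPTE for gva and zaps or re-syncs it. The only behavioral difference is that the walk now starts from the caller-supplied root_hpa instead of implicitly from vcpu->arch.mmu.root_hpa.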