return !!(*irq_state);
 }
 
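+/*
+ * Flags for kvm_mmu_free_roots(): select which of the vCPU's cached MMU
+ * roots to free.  KVM_MMU_ROOTS_ALL is a superset of every individual flag.
+ */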
+#define KVM_MMU_ROOT_CURRENT   BIT(0)
+#define KVM_MMU_ROOT_PREVIOUS  BIT(1)
+#define KVM_MMU_ROOTS_ALL      (~0UL)
+
 int kvm_pic_set_irq(struct kvm_pic *pic, int irq, int irq_source_id, int level);
 void kvm_pic_clear_all(struct kvm_pic *pic, int irq_source_id);
 
 int kvm_mmu_load(struct kvm_vcpu *vcpu);
 void kvm_mmu_unload(struct kvm_vcpu *vcpu);
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root);
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, ulong roots_to_free);
 gpa_t translate_nested_gpa(struct kvm_vcpu *vcpu, gpa_t gpa, u32 access,
                           struct x86_exception *exception);
 gpa_t kvm_mmu_gva_to_gpa_read(struct kvm_vcpu *vcpu, gva_t gva,
 
        *root_hpa = INVALID_PAGE;
 }
 
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root)
+/* roots_to_free must be some combination of the KVM_MMU_ROOT_* flags */
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, ulong roots_to_free)
 {
        int i;
        LIST_HEAD(invalid_list);
        struct kvm_mmu *mmu = &vcpu->arch.mmu;
+       bool free_active_root = roots_to_free & KVM_MMU_ROOT_CURRENT;
+       bool free_prev_root = roots_to_free & KVM_MMU_ROOT_PREVIOUS;
 
-       if (!VALID_PAGE(mmu->root_hpa) &&
-           (!VALID_PAGE(mmu->prev_root.hpa) || !free_prev_root))
+       /* Before acquiring the MMU lock, see if we need to do any real work. */
+       if (!(free_active_root && VALID_PAGE(mmu->root_hpa)) &&
+           !(free_prev_root && VALID_PAGE(mmu->prev_root.hpa)))
                return;
 
        spin_lock(&vcpu->kvm->mmu_lock);
                mmu_free_root_page(vcpu->kvm, &mmu->prev_root.hpa,
                                   &invalid_list);
 
-       if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
-           (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
-               mmu_free_root_page(vcpu->kvm, &mmu->root_hpa, &invalid_list);
-       } else {
-               for (i = 0; i < 4; ++i)
-                       if (mmu->pae_root[i] != 0)
-                               mmu_free_root_page(vcpu->kvm, &mmu->pae_root[i],
-                                                  &invalid_list);
-               mmu->root_hpa = INVALID_PAGE;
+       if (free_active_root) {
+               if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
+                   (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
+                       mmu_free_root_page(vcpu->kvm, &mmu->root_hpa,
+                                          &invalid_list);
+               } else {
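+                       /*
+                        * With PAE-format shadow roots there is no single
+                        * root page; free each of the four page-directory
+                        * roots in pae_root[] instead.
+                        */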
+                       for (i = 0; i < 4; ++i)
+                               if (mmu->pae_root[i] != 0)
+                                       mmu_free_root_page(vcpu->kvm,
+                                                          &mmu->pae_root[i],
+                                                          &invalid_list);
+                       mmu->root_hpa = INVALID_PAGE;
+               }
        }
 
        kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
                              bool skip_tlb_flush)
 {
        if (!fast_cr3_switch(vcpu, new_cr3, new_role, skip_tlb_flush))
-               kvm_mmu_free_roots(vcpu, false);
+               kvm_mmu_free_roots(vcpu, KVM_MMU_ROOT_CURRENT);
 }
 
 void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3, bool skip_tlb_flush)
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
-       kvm_mmu_free_roots(vcpu, true);
+       kvm_mmu_free_roots(vcpu, KVM_MMU_ROOTS_ALL);
        WARN_ON(VALID_PAGE(vcpu->arch.mmu.root_hpa));
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unload);
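
For reference, a minimal usage sketch (not part of the patch; vcpu is assumed
to be a valid struct kvm_vcpu pointer): the KVM_MMU_ROOT_* flags may be OR'd
together, mirroring the call sites changed above.

	kvm_mmu_free_roots(vcpu, KVM_MMU_ROOT_CURRENT);	/* active root only, as in __kvm_mmu_new_cr3() */
	kvm_mmu_free_roots(vcpu, KVM_MMU_ROOT_CURRENT |
				 KVM_MMU_ROOT_PREVIOUS);	/* both named roots */
	kvm_mmu_free_roots(vcpu, KVM_MMU_ROOTS_ALL);		/* every root, as in kvm_mmu_unload() */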