int kvm_cpu_dirty_log_size(void);
 
+int alloc_all_memslots_rmaps(struct kvm *kvm);
+
 #endif /* _ASM_X86_KVM_HOST_H */
 
 
 static inline bool kvm_memslots_have_rmaps(struct kvm *kvm)
 {
-       return kvm->arch.memslots_have_rmaps;
+       /*
+        * Read memslot_have_rmaps before rmap pointers.  Hence, threads reading
+        * memslots_have_rmaps in any lock context are guaranteed to see the
+        * pointers.  Pairs with smp_store_release in alloc_all_memslots_rmaps.
+        */
+       return smp_load_acquire(&kvm->arch.memslots_have_rmaps);
 }
 
 #endif
 
                }
        }
 
+       r = alloc_all_memslots_rmaps(vcpu->kvm);
+       if (r)
+               return r;
+
        write_lock(&vcpu->kvm->mmu_lock);
        r = make_mmu_pages_available(vcpu);
        if (r < 0)
 {
        struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
 
-       kvm_mmu_init_tdp_mmu(kvm);
-
-       kvm->arch.memslots_have_rmaps = true;
+       if (!kvm_mmu_init_tdp_mmu(kvm))
+               /*
+                * No smp_load/store wrappers needed here as we are in
+                * VM init and there cannot be any memslots / other threads
+                * accessing this struct kvm yet.
+                */
+               kvm->arch.memslots_have_rmaps = true;
 
        node->track_write = kvm_mmu_pte_write;
        node->track_flush_slot = kvm_mmu_invalidate_zap_pages_in_memslot;
 
 module_param_named(tdp_mmu, tdp_mmu_enabled, bool, 0644);
 
 /* Initializes the TDP MMU for the VM, if enabled. */
-void kvm_mmu_init_tdp_mmu(struct kvm *kvm)
+bool kvm_mmu_init_tdp_mmu(struct kvm *kvm)
 {
        if (!tdp_enabled || !READ_ONCE(tdp_mmu_enabled))
-               return;
+               return false;
 
        /* This should not be changed for the lifetime of the VM. */
        kvm->arch.tdp_mmu_enabled = true;
        INIT_LIST_HEAD(&kvm->arch.tdp_mmu_roots);
        spin_lock_init(&kvm->arch.tdp_mmu_pages_lock);
        INIT_LIST_HEAD(&kvm->arch.tdp_mmu_pages);
+
+       return true;
 }
 
 static __always_inline void kvm_lockdep_assert_mmu_lock_held(struct kvm *kvm,
 
                         int *root_level);
 
 #ifdef CONFIG_X86_64
-void kvm_mmu_init_tdp_mmu(struct kvm *kvm);
+bool kvm_mmu_init_tdp_mmu(struct kvm *kvm);
 void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm);
 static inline bool is_tdp_mmu_enabled(struct kvm *kvm) { return kvm->arch.tdp_mmu_enabled; }
 static inline bool is_tdp_mmu_page(struct kvm_mmu_page *sp) { return sp->tdp_mmu_page; }
 #else
-static inline void kvm_mmu_init_tdp_mmu(struct kvm *kvm) {}
+static inline bool kvm_mmu_init_tdp_mmu(struct kvm *kvm) { return false; }
 static inline void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm) {}
 static inline bool is_tdp_mmu_enabled(struct kvm *kvm) { return false; }
 static inline bool is_tdp_mmu_page(struct kvm_mmu_page *sp) { return false; }
 
                int lpages = gfn_to_index(slot->base_gfn + npages - 1,
                                          slot->base_gfn, level) + 1;
 
+               WARN_ON(slot->arch.rmap[i]);
+
                slot->arch.rmap[i] = kvcalloc(lpages, sz, GFP_KERNEL_ACCOUNT);
                if (!slot->arch.rmap[i]) {
                        memslot_rmap_free(slot);
        return 0;
 }
 
+int alloc_all_memslots_rmaps(struct kvm *kvm)
+{
+       struct kvm_memslots *slots;
+       struct kvm_memory_slot *slot;
+       int r, i;
+
+       /*
+        * Check if memslots alreday have rmaps early before acquiring
+        * the slots_arch_lock below.
+        */
+       if (kvm_memslots_have_rmaps(kvm))
+               return 0;
+
+       mutex_lock(&kvm->slots_arch_lock);
+
+       /*
+        * Read memslots_have_rmaps again, under the slots arch lock,
+        * before allocating the rmaps
+        */
+       if (kvm_memslots_have_rmaps(kvm)) {
+               mutex_unlock(&kvm->slots_arch_lock);
+               return 0;
+       }
+
+       for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+               slots = __kvm_memslots(kvm, i);
+               kvm_for_each_memslot(slot, slots) {
+                       r = memslot_rmap_alloc(slot, slot->npages);
+                       if (r) {
+                               mutex_unlock(&kvm->slots_arch_lock);
+                               return r;
+                       }
+               }
+       }
+
+       /*
+        * Ensure that memslots_have_rmaps becomes true strictly after
+        * all the rmap pointers are set.
+        */
+       smp_store_release(&kvm->arch.memslots_have_rmaps, true);
+       mutex_unlock(&kvm->slots_arch_lock);
+       return 0;
+}
+
 static int kvm_alloc_memslot_metadata(struct kvm *kvm,
                                      struct kvm_memory_slot *slot,
                                      unsigned long npages)