 {
        struct kvm_host_map map;
        struct kvm_steal_time *st;
+       int idx;
 
        if (!(vcpu->arch.st.msr_val & KVM_MSR_ENABLED))
                return;
        if (vcpu->arch.st.preempted)
                return;
 
+       /*
+        * Take the srcu lock as memslots will be accessed to check the gfn
+        * cache generation against the memslots generation.
+        */
+       idx = srcu_read_lock(&vcpu->kvm->srcu);
+
        if (kvm_map_gfn(vcpu, vcpu->arch.st.msr_val >> PAGE_SHIFT, &map,
                        &vcpu->arch.st.cache, true))
-               return;
+               goto out;
 
        st = map.hva +
                offset_in_page(vcpu->arch.st.msr_val & KVM_STEAL_VALID_BITS);
        st->preempted = vcpu->arch.st.preempted = KVM_VCPU_PREEMPTED;
 
        kvm_unmap_gfn(vcpu, &map, &vcpu->arch.st.cache, true, true);
+
+out:
+       srcu_read_unlock(&vcpu->kvm->srcu, idx);
 }
 
 void kvm_arch_vcpu_put(struct kvm_vcpu *vcpu)
 {
-       int idx;
-
        if (vcpu->preempted && !vcpu->arch.guest_state_protected)
                vcpu->arch.preempted_in_kernel = !kvm_x86_ops.get_cpl(vcpu);
 
-       /*
-        * kvm_memslots() will be called by
-        * kvm_write_guest_offset_cached() so take the srcu lock.
-        */
-       idx = srcu_read_lock(&vcpu->kvm->srcu);
        kvm_steal_time_set_preempted(vcpu);
-       srcu_read_unlock(&vcpu->kvm->srcu, idx);
        kvm_x86_ops.vcpu_put(vcpu);
        vcpu->arch.last_host_tsc = rdtsc();
        /*