return kvm->arch.nested_guests[lpid];
 }
 
+static pte_t *find_kvm_nested_guest_pte(struct kvm *kvm, unsigned long lpid,
+                                       unsigned long ea, unsigned *hshift)
+{
+       struct kvm_nested_guest *gp;
+       pte_t *pte;
+
+       gp = kvmhv_find_nested(kvm, lpid);
+       if (!gp)
+               return NULL;
+
+       VM_WARN(!spin_is_locked(&kvm->mmu_lock),
+               "%s called with kvm mmu_lock not held \n", __func__);
+       pte = __find_linux_pte(gp->shadow_pgtable, ea, NULL, hshift);
+
+       return pte;
+}
+
+
 static inline bool kvmhv_n_rmap_is_equal(u64 rmap_1, u64 rmap_2)
 {
        return !((rmap_1 ^ rmap_2) & (RMAP_NESTED_LPID_MASK |
                                      unsigned long clr, unsigned long set,
                                      unsigned long hpa, unsigned long mask)
 {
-       struct kvm_nested_guest *gp;
        unsigned long gpa;
        unsigned int shift, lpid;
        pte_t *ptep;
 
        gpa = n_rmap & RMAP_NESTED_GPA_MASK;
        lpid = (n_rmap & RMAP_NESTED_LPID_MASK) >> RMAP_NESTED_LPID_SHIFT;
-       gp = kvmhv_find_nested(kvm, lpid);
-       if (!gp)
-               return;
 
        /* Find the pte */
-       ptep = __find_linux_pte(gp->shadow_pgtable, gpa, NULL, &shift);
+       ptep = find_kvm_nested_guest_pte(kvm, lpid, gpa, &shift);
        /*
         * If the pte is present and the pfn is still the same, update the pte.
         * If the pfn has changed then this is a stale rmap entry, the nested
                return;
 
        /* Find and invalidate the pte */
-       ptep = __find_linux_pte(gp->shadow_pgtable, gpa, NULL, &shift);
+       ptep = find_kvm_nested_guest_pte(kvm, lpid, gpa, &shift);
        /* Don't spuriously invalidate ptes if the pfn has changed */
        if (ptep && pte_present(*ptep) && ((pte_val(*ptep) & mask) == hpa))
                kvmppc_unmap_pte(kvm, ptep, gpa, shift, NULL, gp->shadow_lpid);
        int shift;
 
        spin_lock(&kvm->mmu_lock);
-       ptep = __find_linux_pte(gp->shadow_pgtable, gpa, NULL, &shift);
+       ptep = find_kvm_nested_guest_pte(kvm, gp->l1_lpid, gpa, &shift);
        if (!shift)
                shift = PAGE_SHIFT;
        if (ptep && pte_present(*ptep)) {