#include <linux/fs.h>
 #include <linux/fs_struct.h>
 #include <linux/psp.h>
+#include <linux/amd-iommu.h>
 
 #include <asm/smp.h>
 #include <asm/cacheflush.h>
                return ret;
        }
 
+       /*
+        * SNP_SHUTDOWN_EX with IOMMU_SNP_SHUTDOWN set to 1 disables SNP
+        * enforcement by the IOMMU and also transitions all pages
+        * associated with the IOMMU to the Reclaim state.
+        * Firmware was transitioning the IOMMU pages to Hypervisor state
+        * before version 1.53. But, accounting for the number of assigned
+        * 4kB pages in a 2M page was done incorrectly by not transitioning
+        * to the Reclaim state. This resulted in RMP #PF when later accessing
+        * the 2M page containing those pages during kexec boot. Hence, the
+        * firmware now transitions these pages to Reclaim state and hypervisor
+        * needs to transition these pages to shared state. SNP Firmware
+        * version 1.53 and above are needed for kexec boot.
+        */
+       ret = amd_iommu_snp_disable();
+       if (ret) {
+               dev_err(sev->dev, "SNP IOMMU shutdown failed\n");
+               return ret;
+       }
+
        sev->snp_initialized = false;
        dev_dbg(sev->dev, "SEV-SNP firmware shutdown\n");
 
 
 #include <asm/io_apic.h>
 #include <asm/irq_remapping.h>
 #include <asm/set_memory.h>
+#include <asm/sev.h>
 
 #include <linux/crash_dump.h>
 
 
        return iommu_pc_get_set_reg(iommu, bank, cntr, fxn, value, true);
 }
+
+#ifdef CONFIG_KVM_AMD_SEV
+static int iommu_page_make_shared(void *page)
+{
+       unsigned long paddr, pfn;
+
+       paddr = iommu_virt_to_phys(page);
+       /* Cbit maybe set in the paddr */
+       pfn = __sme_clr(paddr) >> PAGE_SHIFT;
+
+       if (!(pfn % PTRS_PER_PMD)) {
+               int ret, level;
+               bool assigned;
+
+               ret = snp_lookup_rmpentry(pfn, &assigned, &level);
+               if (ret)
+                       pr_warn("IOMMU PFN %lx RMP lookup failed, ret %d\n",
+                               pfn, ret);
+
+               if (!assigned)
+                       pr_warn("IOMMU PFN %lx not assigned in RMP table\n",
+                               pfn);
+
+               if (level > PG_LEVEL_4K) {
+                       ret = psmash(pfn);
+                       if (ret) {
+                               pr_warn("IOMMU PFN %lx had a huge RMP entry, but attempted psmash failed, ret: %d, level: %d\n",
+                                       pfn, ret, level);
+                       }
+               }
+       }
+
+       return rmp_make_shared(pfn, PG_LEVEL_4K);
+}
+
+static int iommu_make_shared(void *va, size_t size)
+{
+       void *page;
+       int ret;
+
+       if (!va)
+               return 0;
+
+       for (page = va; page < (va + size); page += PAGE_SIZE) {
+               ret = iommu_page_make_shared(page);
+               if (ret)
+                       return ret;
+       }
+
+       return 0;
+}
+
+int amd_iommu_snp_disable(void)
+{
+       struct amd_iommu *iommu;
+       int ret;
+
+       if (!amd_iommu_snp_en)
+               return 0;
+
+       for_each_iommu(iommu) {
+               ret = iommu_make_shared(iommu->evt_buf, EVT_BUFFER_SIZE);
+               if (ret)
+                       return ret;
+
+               ret = iommu_make_shared(iommu->ppr_log, PPR_LOG_SIZE);
+               if (ret)
+                       return ret;
+
+               ret = iommu_make_shared((void *)iommu->cmd_sem, PAGE_SIZE);
+               if (ret)
+                       return ret;
+       }
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(amd_iommu_snp_disable);
+#endif