static void dump_command(unsigned long phys_addr)
 {
-       struct iommu_cmd *cmd = phys_to_virt(phys_addr);
+       struct iommu_cmd *cmd = iommu_phys_to_virt(phys_addr);
        int i;
 
        for (i = 0; i < 4; ++i)
 
 static void build_completion_wait(struct iommu_cmd *cmd, u64 address)
 {
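+       /*
+        * The semaphore is written back by the IOMMU itself, so the
+        * address programmed into the command must carry the SME
+        * encryption mask for the write to land in encrypted memory.
+        */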
+       u64 paddr = iommu_virt_to_phys((void *)address);
+
        WARN_ON(address & 0x7ULL);
 
        memset(cmd, 0, sizeof(*cmd));
-       cmd->data[0] = lower_32_bits(__pa(address)) | CMD_COMPL_WAIT_STORE_MASK;
-       cmd->data[1] = upper_32_bits(__pa(address));
+       cmd->data[0] = lower_32_bits(paddr) | CMD_COMPL_WAIT_STORE_MASK;
+       cmd->data[1] = upper_32_bits(paddr);
        cmd->data[2] = 1;
        CMD_SET_TYPE(cmd, CMD_COMPL_WAIT);
 }
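Both directions of the conversion above go through new iommu_*() helpers
rather than calling virt_to_phys()/phys_to_virt() directly. The helpers
themselves are not in the hunks shown here; as a minimal sketch, assuming
the __sme_set()/__sme_clr() mask helpers from <linux/mem_encrypt.h>, they
simply apply or strip the SME encryption mask around the usual conversion:

	/* Sketch: convert for the IOMMU, applying/stripping the SME mask. */
	static inline u64 iommu_virt_to_phys(void *vaddr)
	{
		return (u64)__sme_set(virt_to_phys(vaddr));
	}

	static inline void *iommu_phys_to_virt(unsigned long paddr)
	{
		return phys_to_virt(__sme_clr(paddr));
	}

The IOMMU is a DMA device: with SME active, every physical address handed
to it must include the encryption bit, and every address read back out of
an IOMMU-visible table must have that bit cleared before the CPU can turn
it back into a kernel virtual address.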
                return false;
 
        *pte             = PM_LEVEL_PDE(domain->mode,
-                                       virt_to_phys(domain->pt_root));
+                                       iommu_virt_to_phys(domain->pt_root));
        domain->pt_root  = pte;
        domain->mode    += 1;
        domain->updated  = true;
                        if (!page)
                                return NULL;
 
-                       __npte = PM_LEVEL_PDE(level, virt_to_phys(page));
+                       __npte = PM_LEVEL_PDE(level, iommu_virt_to_phys(page));
 
                        /* pte could have been changed somewhere. */
                        if (cmpxchg64(pte, __pte, __npte) != __pte) {
                        return -EBUSY;
 
        if (count > 1) {
-               __pte = PAGE_SIZE_PTE(phys_addr, page_size);
+               __pte = PAGE_SIZE_PTE(__sme_set(phys_addr), page_size);
                __pte |= PM_LEVEL_ENC(7) | IOMMU_PTE_P | IOMMU_PTE_FC;
        } else
-               __pte = phys_addr | IOMMU_PTE_P | IOMMU_PTE_FC;
+               __pte = __sme_set(phys_addr) | IOMMU_PTE_P | IOMMU_PTE_FC;
 
        if (prot & IOMMU_PROT_IR)
                __pte |= IOMMU_PTE_IR;
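Here iommu_map_page() is handed a physical address by its caller, so there
is no virtual address to feed through iommu_virt_to_phys(); the encryption
mask is ORed in directly with __sme_set(). For reference, a minimal sketch
of the two mask helpers as <linux/mem_encrypt.h> expresses them in terms
of sme_me_mask:

	/* Apply or strip the SME encryption mask on a physical address. */
	#define __sme_set(x)	((x) | sme_me_mask)
	#define __sme_clr(x)	((x) & ~sme_me_mask)

When SME is inactive, sme_me_mask is zero and both are no-ops, so the PTEs
built here come out bit-for-bit identical on non-SME systems.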
                if (!(tbl[i] & GCR3_VALID))
                        continue;
 
-               ptr = __va(tbl[i] & PAGE_MASK);
+               ptr = iommu_phys_to_virt(tbl[i] & PAGE_MASK);
 
                free_page((unsigned long)ptr);
        }
                if (!(tbl[i] & GCR3_VALID))
                        continue;
 
-               ptr = __va(tbl[i] & PAGE_MASK);
+               ptr = iommu_phys_to_virt(tbl[i] & PAGE_MASK);
 
                free_gcr3_tbl_level1(ptr);
        }
        u64 flags = 0;
 
        if (domain->mode != PAGE_MODE_NONE)
-               pte_root = virt_to_phys(domain->pt_root);
+               pte_root = iommu_virt_to_phys(domain->pt_root);
 
        pte_root |= (domain->mode & DEV_ENTRY_MODE_MASK)
                    << DEV_ENTRY_MODE_SHIFT;
                flags |= DTE_FLAG_IOTLB;
 
        if (domain->flags & PD_IOMMUV2_MASK) {
-               u64 gcr3 = __pa(domain->gcr3_tbl);
+               u64 gcr3 = iommu_virt_to_phys(domain->gcr3_tbl);
                u64 glx  = domain->glx;
                u64 tmp;
 
                        if (root == NULL)
                                return NULL;
 
-                       *pte = __pa(root) | GCR3_VALID;
+                       *pte = iommu_virt_to_phys(root) | GCR3_VALID;
                }
 
-               root = __va(*pte & PAGE_MASK);
+               root = iommu_phys_to_virt(*pte & PAGE_MASK);
 
                level -= 1;
        }
 
        dte     = amd_iommu_dev_table[devid].data[2];
        dte     &= ~DTE_IRQ_PHYS_ADDR_MASK;
-       dte     |= virt_to_phys(table->table);
+       dte     |= iommu_virt_to_phys(table->table);
        dte     |= DTE_IRQ_REMAP_INTCTL;
        dte     |= DTE_IRQ_TABLE_LEN;
        dte     |= DTE_IRQ_REMAP_ENABLE;
 
 #include <linux/iommu.h>
 #include <linux/kmemleak.h>
 #include <linux/crash_dump.h>
+#include <linux/mem_encrypt.h>
 #include <asm/pci-direct.h>
 #include <asm/iommu.h>
 #include <asm/gart.h>
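The new <linux/mem_encrypt.h> include is what brings in sme_active() and
the SME encryption-mask definitions that the conversions above rely on.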
 
        BUG_ON(iommu->mmio_base == NULL);
 
-       entry = virt_to_phys(amd_iommu_dev_table);
+       entry = iommu_virt_to_phys(amd_iommu_dev_table);
        entry |= (dev_table_size >> 12) - 1;
        memcpy_toio(iommu->mmio_base + MMIO_DEV_TABLE_OFFSET,
                        &entry, sizeof(entry));
 
        BUG_ON(iommu->cmd_buf == NULL);
 
-       entry = (u64)virt_to_phys(iommu->cmd_buf);
+       entry = iommu_virt_to_phys(iommu->cmd_buf);
        entry |= MMIO_CMD_SIZE_512;
 
        memcpy_toio(iommu->mmio_base + MMIO_CMD_BUF_OFFSET,
 
        BUG_ON(iommu->evt_buf == NULL);
 
-       entry = (u64)virt_to_phys(iommu->evt_buf) | EVT_LEN_MASK;
+       entry = iommu_virt_to_phys(iommu->evt_buf) | EVT_LEN_MASK;
 
        memcpy_toio(iommu->mmio_base + MMIO_EVT_BUF_OFFSET,
                    &entry, sizeof(entry));
        if (iommu->ppr_log == NULL)
                return;
 
-       entry = (u64)virt_to_phys(iommu->ppr_log) | PPR_LOG_SIZE_512;
+       entry = iommu_virt_to_phys(iommu->ppr_log) | PPR_LOG_SIZE_512;
 
        memcpy_toio(iommu->mmio_base + MMIO_PPR_LOG_OFFSET,
                    &entry, sizeof(entry));
        if (!iommu->ga_log_tail)
                goto err_out;
 
-       entry = (u64)virt_to_phys(iommu->ga_log) | GA_LOG_SIZE_512;
+       entry = iommu_virt_to_phys(iommu->ga_log) | GA_LOG_SIZE_512;
        memcpy_toio(iommu->mmio_base + MMIO_GA_LOG_BASE_OFFSET,
                    &entry, sizeof(entry));
-       entry = ((u64)virt_to_phys(iommu->ga_log) & 0xFFFFFFFFFFFFFULL) & ~7ULL;
+       entry = (iommu_virt_to_phys(iommu->ga_log) & 0xFFFFFFFFFFFFFULL) & ~7ULL;
        memcpy_toio(iommu->mmio_base + MMIO_GA_LOG_TAIL_OFFSET,
                    &entry, sizeof(entry));
        writel(0x00, iommu->mmio_base + MMIO_GA_HEAD_OFFSET);
        return ret;
 }
 
+static bool amd_iommu_sme_check(void)
+{
+       if (!sme_active() || (boot_cpu_data.x86 != 0x17))
+               return true;
+
+       /* For Fam17h, a specific level of support is required */
+       if (boot_cpu_data.microcode >= 0x08001205)
+               return true;
+
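+       /* ... or a late-enough release on the earlier 0x080011xx train */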
+       if ((boot_cpu_data.microcode >= 0x08001126) &&
+           (boot_cpu_data.microcode <= 0x080011ff))
+               return true;
+
+       pr_notice("AMD-Vi: IOMMU not currently supported when SME is active\n");
+
+       return false;
+}
+
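The gate keys off sme_active(), which only reports whether the kernel set
up a memory encryption mask during early boot. A minimal sketch, assuming
the x86 implementation of this era (the real function lives under
arch/x86/mm/ and is untouched by this patch):

	/* True when Secure Memory Encryption was activated at boot,
	 * i.e. a non-zero encryption mask is in use.
	 */
	bool sme_active(void)
	{
		return sme_me_mask != 0;
	}

So on anything other than family 0x17, or with SME inactive, the check
passes unconditionally; only Fam17h with SME active is held to the
microcode-level requirement.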
 /****************************************************************************
  *
  * Early detect code. This code runs at IOMMU detection time in the DMA
        if (no_iommu || (iommu_detected && !gart_iommu_aperture))
                return -ENODEV;
 
+       if (!amd_iommu_sme_check())
+               return -ENODEV;
+
        ret = iommu_go_to_state(IOMMU_IVRS_DETECTED);
        if (ret)
                return ret;