#include <asm/kvm_page_track.h>
 #include "trace.h"
 
+extern bool itlb_multihit_kvm_mitigation;
+
+static int __read_mostly nx_huge_pages = -1;
+
+static int set_nx_huge_pages(const char *val, const struct kernel_param *kp);
+
+static struct kernel_param_ops nx_huge_pages_ops = {
+       .set = set_nx_huge_pages,
+       .get = param_get_bool,
+};
+
+module_param_cb(nx_huge_pages, &nx_huge_pages_ops, &nx_huge_pages, 0644);
+__MODULE_PARM_TYPE(nx_huge_pages, "bool");
+
 /*
  * When setting this variable to true it enables Two-Dimensional-Paging
  * where the hardware walks 2 page tables:
        return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_ENABLED_MASK;
 }
 
+static bool is_nx_huge_page_enabled(void)
+{
+       return READ_ONCE(nx_huge_pages);
+}
+
 static inline u64 spte_shadow_accessed_mask(u64 spte)
 {
        MMU_WARN_ON(is_mmio_spte(spte));
        kvm_mmu_gfn_disallow_lpage(slot, gfn);
 }
 
+static void account_huge_nx_page(struct kvm *kvm, struct kvm_mmu_page *sp)
+{
+       if (sp->lpage_disallowed)
+               return;
+
+       ++kvm->stat.nx_lpage_splits;
+       sp->lpage_disallowed = true;
+}
+
 static void unaccount_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        struct kvm_memslots *slots;
        kvm_mmu_gfn_allow_lpage(slot, gfn);
 }
 
+static void unaccount_huge_nx_page(struct kvm *kvm, struct kvm_mmu_page *sp)
+{
+       --kvm->stat.nx_lpage_splits;
+       sp->lpage_disallowed = false;
+}
+
 static bool __mmu_gfn_lpage_is_disallowed(gfn_t gfn, int level,
                                          struct kvm_memory_slot *slot)
 {
                        kvm_reload_remote_mmus(kvm);
        }
 
+       if (sp->lpage_disallowed)
+               unaccount_huge_nx_page(kvm, sp);
+
        sp->role.invalid = 1;
        return list_unstable;
 }
        if (!speculative)
                spte |= spte_shadow_accessed_mask(spte);
 
+       if (level > PT_PAGE_TABLE_LEVEL && (pte_access & ACC_EXEC_MASK) &&
+           is_nx_huge_page_enabled()) {
+               pte_access &= ~ACC_EXEC_MASK;
+       }
+
        if (pte_access & ACC_EXEC_MASK)
                spte |= shadow_x_mask;
        else
        __direct_pte_prefetch(vcpu, sp, sptep);
 }
 
+static void disallowed_hugepage_adjust(struct kvm_shadow_walk_iterator it,
+                                      gfn_t gfn, kvm_pfn_t *pfnp, int *levelp)
+{
+       int level = *levelp;
+       u64 spte = *it.sptep;
+
+       if (it.level == level && level > PT_PAGE_TABLE_LEVEL &&
+           is_nx_huge_page_enabled() &&
+           is_shadow_present_pte(spte) &&
+           !is_large_pte(spte)) {
+               /*
+                * A small SPTE exists for this pfn, but FNAME(fetch)
+                * and __direct_map would like to create a large PTE
+                * instead: just force them to go down another level,
+                * patching back for them into pfn the next 9 bits of
+                * the address.
+                */
+               u64 page_mask = KVM_PAGES_PER_HPAGE(level) - KVM_PAGES_PER_HPAGE(level - 1);
+               *pfnp |= gfn & page_mask;
+               (*levelp)--;
+       }
+}
+
 static int __direct_map(struct kvm_vcpu *vcpu, gpa_t gpa, int write,
                        int map_writable, int level, kvm_pfn_t pfn,
-                       bool prefault)
+                       bool prefault, bool lpage_disallowed)
 {
        struct kvm_shadow_walk_iterator it;
        struct kvm_mmu_page *sp;
 
        trace_kvm_mmu_spte_requested(gpa, level, pfn);
        for_each_shadow_entry(vcpu, gpa, it) {
+               /*
+                * We cannot overwrite existing page tables with an NX
+                * large page, as the leaf could be executable.
+                */
+               disallowed_hugepage_adjust(it, gfn, &pfn, &level);
+
                base_gfn = gfn & ~(KVM_PAGES_PER_HPAGE(it.level) - 1);
                if (it.level == level)
                        break;
                                              it.level - 1, true, ACC_ALL);
 
                        link_shadow_page(vcpu, it.sptep, sp);
+                       if (lpage_disallowed)
+                               account_huge_nx_page(vcpu->kvm, sp);
                }
        }
 
 {
        int r;
        int level;
-       bool force_pt_level = false;
+       bool force_pt_level;
        kvm_pfn_t pfn;
        unsigned long mmu_seq;
        bool map_writable, write = error_code & PFERR_WRITE_MASK;
+       bool lpage_disallowed = (error_code & PFERR_FETCH_MASK) &&
+                               is_nx_huge_page_enabled();
 
+       force_pt_level = lpage_disallowed;
        level = mapping_level(vcpu, gfn, &force_pt_level);
        if (likely(!force_pt_level)) {
                /*
                goto out_unlock;
        if (likely(!force_pt_level))
                transparent_hugepage_adjust(vcpu, gfn, &pfn, &level);
-       r = __direct_map(vcpu, v, write, map_writable, level, pfn, prefault);
+       r = __direct_map(vcpu, v, write, map_writable, level, pfn,
+                        prefault, false);
 out_unlock:
        spin_unlock(&vcpu->kvm->mmu_lock);
        kvm_release_pfn_clean(pfn);
        unsigned long mmu_seq;
        int write = error_code & PFERR_WRITE_MASK;
        bool map_writable;
+       bool lpage_disallowed = (error_code & PFERR_FETCH_MASK) &&
+                               is_nx_huge_page_enabled();
 
        MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu->root_hpa));
 
        if (r)
                return r;
 
-       force_pt_level = !check_hugepage_cache_consistency(vcpu, gfn,
-                                                          PT_DIRECTORY_LEVEL);
+       force_pt_level =
+               lpage_disallowed ||
+               !check_hugepage_cache_consistency(vcpu, gfn, PT_DIRECTORY_LEVEL);
        level = mapping_level(vcpu, gfn, &force_pt_level);
        if (likely(!force_pt_level)) {
                if (level > PT_DIRECTORY_LEVEL &&
                goto out_unlock;
        if (likely(!force_pt_level))
                transparent_hugepage_adjust(vcpu, gfn, &pfn, &level);
-       r = __direct_map(vcpu, gpa, write, map_writable, level, pfn, prefault);
+       r = __direct_map(vcpu, gpa, write, map_writable, level, pfn,
+                        prefault, lpage_disallowed);
 out_unlock:
        spin_unlock(&vcpu->kvm->mmu_lock);
        kvm_release_pfn_clean(pfn);
        kvm_mmu_set_mmio_spte_mask(mask, mask, ACC_WRITE_MASK | ACC_USER_MASK);
 }
 
+static bool get_nx_auto_mode(void)
+{
+       /* Return true when CPU has the bug, and mitigations are ON */
+       return boot_cpu_has_bug(X86_BUG_ITLB_MULTIHIT) && !cpu_mitigations_off();
+}
+
+static void __set_nx_huge_pages(bool val)
+{
+       nx_huge_pages = itlb_multihit_kvm_mitigation = val;
+}
+
+static int set_nx_huge_pages(const char *val, const struct kernel_param *kp)
+{
+       bool old_val = nx_huge_pages;
+       bool new_val;
+
+       /* In "auto" mode deploy workaround only if CPU has the bug. */
+       if (sysfs_streq(val, "off"))
+               new_val = 0;
+       else if (sysfs_streq(val, "force"))
+               new_val = 1;
+       else if (sysfs_streq(val, "auto"))
+               new_val = get_nx_auto_mode();
+       else if (strtobool(val, &new_val) < 0)
+               return -EINVAL;
+
+       __set_nx_huge_pages(new_val);
+
+       if (new_val != old_val) {
+               struct kvm *kvm;
+               int idx;
+
+               mutex_lock(&kvm_lock);
+
+               list_for_each_entry(kvm, &vm_list, vm_list) {
+                       idx = srcu_read_lock(&kvm->srcu);
+                       kvm_mmu_zap_all_fast(kvm);
+                       srcu_read_unlock(&kvm->srcu, idx);
+               }
+               mutex_unlock(&kvm_lock);
+       }
+
+       return 0;
+}
+
 int kvm_mmu_module_init(void)
 {
        int ret = -ENOMEM;
 
+       if (nx_huge_pages == -1)
+               __set_nx_huge_pages(get_nx_auto_mode());
+
        /*
         * MMU roles use union aliasing which is, generally speaking, an
         * undefined behavior. However, we supposedly know how compilers behave
 
 static int FNAME(fetch)(struct kvm_vcpu *vcpu, gva_t addr,
                         struct guest_walker *gw,
                         int write_fault, int hlevel,
-                        kvm_pfn_t pfn, bool map_writable, bool prefault)
+                        kvm_pfn_t pfn, bool map_writable, bool prefault,
+                        bool lpage_disallowed)
 {
        struct kvm_mmu_page *sp = NULL;
        struct kvm_shadow_walk_iterator it;
        unsigned direct_access, access = gw->pt_access;
        int top_level, ret;
-       gfn_t base_gfn;
+       gfn_t gfn, base_gfn;
 
        direct_access = gw->pte_access;
 
                        link_shadow_page(vcpu, it.sptep, sp);
        }
 
-       base_gfn = gw->gfn;
+       /*
+        * FNAME(page_fault) might have clobbered the bottom bits of
+        * gw->gfn, restore them from the virtual address.
+        */
+       gfn = gw->gfn | ((addr & PT_LVL_OFFSET_MASK(gw->level)) >> PAGE_SHIFT);
+       base_gfn = gfn;
 
        trace_kvm_mmu_spte_requested(addr, gw->level, pfn);
 
        for (; shadow_walk_okay(&it); shadow_walk_next(&it)) {
                clear_sp_write_flooding_count(it.sptep);
-               base_gfn = gw->gfn & ~(KVM_PAGES_PER_HPAGE(it.level) - 1);
+
+               /*
+                * We cannot overwrite existing page tables with an NX
+                * large page, as the leaf could be executable.
+                */
+               disallowed_hugepage_adjust(it, gfn, &pfn, &hlevel);
+
+               base_gfn = gfn & ~(KVM_PAGES_PER_HPAGE(it.level) - 1);
                if (it.level == hlevel)
                        break;
 
                        sp = kvm_mmu_get_page(vcpu, base_gfn, addr,
                                              it.level - 1, true, direct_access);
                        link_shadow_page(vcpu, it.sptep, sp);
+                       if (lpage_disallowed)
+                               account_huge_nx_page(vcpu->kvm, sp);
                }
        }
 
        int r;
        kvm_pfn_t pfn;
        int level = PT_PAGE_TABLE_LEVEL;
-       bool force_pt_level = false;
        unsigned long mmu_seq;
        bool map_writable, is_self_change_mapping;
+       bool lpage_disallowed = (error_code & PFERR_FETCH_MASK) &&
+                               is_nx_huge_page_enabled();
+       bool force_pt_level = lpage_disallowed;
 
        pgprintk("%s: addr %lx err %x\n", __func__, addr, error_code);
 
        if (!force_pt_level)
                transparent_hugepage_adjust(vcpu, walker.gfn, &pfn, &level);
        r = FNAME(fetch)(vcpu, addr, &walker, write_fault,
-                        level, pfn, map_writable, prefault);
+                        level, pfn, map_writable, prefault, lpage_disallowed);
        kvm_mmu_audit(vcpu, AUDIT_POST_PAGE_FAULT);
 
 out_unlock: