#include <linux/mmu_context.h>
 #include <linux/bsearch.h>
 #include <linux/sync_core.h>
+#include <linux/execmem.h>
 #include <asm/text-patching.h>
 #include <asm/alternative.h>
 #include <asm/sections.h>
 #include <asm/asm-prototypes.h>
 #include <asm/cfi.h>
 #include <asm/ibt.h>
+#include <asm/set_memory.h>
 
 int __read_mostly alternatives_patched;
 
 #endif
 };
 
+#ifdef CONFIG_MITIGATION_ITS
+
+static struct module *its_mod;
+static void *its_page;
+static unsigned int its_offset;
+
+/* Initialize a thunk with the "jmp *reg; int3" instructions. */
+static void *its_init_thunk(void *thunk, int reg)
+{
+       u8 *bytes = thunk;
+       int i = 0;
+
+       if (reg >= 8) {
+               bytes[i++] = 0x41; /* REX.B prefix */
+               reg -= 8;
+       }
+       bytes[i++] = 0xff;
+       bytes[i++] = 0xe0 + reg; /* jmp *reg */
+       bytes[i++] = 0xcc;
+
+       return thunk;
+}
+
+void its_init_mod(struct module *mod)
+{
+       if (!cpu_feature_enabled(X86_FEATURE_INDIRECT_THUNK_ITS))
+               return;
+
+       mutex_lock(&text_mutex);
+       its_mod = mod;
+       its_page = NULL;
+}
+
+void its_fini_mod(struct module *mod)
+{
+       if (!cpu_feature_enabled(X86_FEATURE_INDIRECT_THUNK_ITS))
+               return;
+
+       WARN_ON_ONCE(its_mod != mod);
+
+       its_mod = NULL;
+       its_page = NULL;
+       mutex_unlock(&text_mutex);
+
+       for (int i = 0; i < mod->its_num_pages; i++) {
+               void *page = mod->its_page_array[i];
+               execmem_restore_rox(page, PAGE_SIZE);
+       }
+}
+
+void its_free_mod(struct module *mod)
+{
+       if (!cpu_feature_enabled(X86_FEATURE_INDIRECT_THUNK_ITS))
+               return;
+
+       for (int i = 0; i < mod->its_num_pages; i++) {
+               void *page = mod->its_page_array[i];
+               execmem_free(page);
+       }
+       kfree(mod->its_page_array);
+}
+
+static void *its_alloc(void)
+{
+       void *page __free(execmem) = execmem_alloc(EXECMEM_MODULE_TEXT, PAGE_SIZE);
+
+       if (!page)
+               return NULL;
+
+       if (its_mod) {
+               void *tmp = krealloc(its_mod->its_page_array,
+                                    (its_mod->its_num_pages+1) * sizeof(void *),
+                                    GFP_KERNEL);
+               if (!tmp)
+                       return NULL;
+
+               its_mod->its_page_array = tmp;
+               its_mod->its_page_array[its_mod->its_num_pages++] = page;
+
+               execmem_make_temp_rw(page, PAGE_SIZE);
+       }
+
+       return no_free_ptr(page);
+}
+
+static void *its_allocate_thunk(int reg)
+{
+       int size = 3 + (reg / 8);
+       void *thunk;
+
+       if (!its_page || (its_offset + size - 1) >= PAGE_SIZE) {
+               its_page = its_alloc();
+               if (!its_page) {
+                       pr_err("ITS page allocation failed\n");
+                       return NULL;
+               }
+               memset(its_page, INT3_INSN_OPCODE, PAGE_SIZE);
+               its_offset = 32;
+       }
+
+       /*
+        * If the indirect branch instruction will be in the lower half
+        * of a cacheline, then update the offset to reach the upper half.
+        */
+       if ((its_offset + size - 1) % 64 < 32)
+               its_offset = ((its_offset - 1) | 0x3F) + 33;
+
+       thunk = its_page + its_offset;
+       its_offset += size;
+
+       return its_init_thunk(thunk, reg);
+}
+
+#endif
+
 /*
  * Nomenclature for variable names to simplify and clarify this code and ease
  * any potential staring at it:
 #ifdef CONFIG_MITIGATION_ITS
 static int emit_its_trampoline(void *addr, struct insn *insn, int reg, u8 *bytes)
 {
-       return __emit_trampoline(addr, insn, bytes,
-                                __x86_indirect_its_thunk_array[reg],
-                                __x86_indirect_its_thunk_array[reg]);
+       u8 *thunk = __x86_indirect_its_thunk_array[reg];
+       u8 *tmp = its_allocate_thunk(reg);
+
+       if (tmp)
+               thunk = tmp;
+
+       return __emit_trampoline(addr, insn, bytes, thunk, thunk);
 }
 
 /* Check if an indirect branch is at ITS-unsafe address */