#include <asm/unwind_hints.h>
 #include <asm/cpufeatures.h>
 #include <asm/page_types.h>
+#include <asm/percpu.h>
+#include <asm/asm-offsets.h>
+#include <asm/processor-flags.h>
 
 /*
 
 
 #ifdef CONFIG_PAGE_TABLE_ISOLATION
 
-/* PAGE_TABLE_ISOLATION PGDs are 8k.  Flip bit 12 to switch between the two halves: */
-#define PTI_SWITCH_MASK (1<<PAGE_SHIFT)
+/*
+ * PAGE_TABLE_ISOLATION PGDs are 8k.  Flip bit 12 to switch between the two
+ * halves:
+ */
+#define PTI_SWITCH_PGTABLES_MASK       (1<<PAGE_SHIFT)
+#define PTI_SWITCH_MASK                (PTI_SWITCH_PGTABLES_MASK|(1<<X86_CR3_PTI_SWITCH_BIT))
 
-.macro ADJUST_KERNEL_CR3 reg:req
-       /* Clear "PAGE_TABLE_ISOLATION bit", point CR3 at kernel pagetables: */
-       andq    $(~PTI_SWITCH_MASK), \reg
+.macro SET_NOFLUSH_BIT reg:req
+       bts     $X86_CR3_PCID_NOFLUSH_BIT, \reg
 .endm
 
-.macro ADJUST_USER_CR3 reg:req
-       /* Move CR3 up a page to the user page tables: */
-       orq     $(PTI_SWITCH_MASK), \reg
+.macro ADJUST_KERNEL_CR3 reg:req
+       ALTERNATIVE "", "SET_NOFLUSH_BIT \reg", X86_FEATURE_PCID
+       /* Clear PCID and "PAGE_TABLE_ISOLATION bit", point CR3 at kernel pagetables: */
+       andq    $(~PTI_SWITCH_MASK), \reg
 .endm
 
 .macro SWITCH_TO_KERNEL_CR3 scratch_reg:req
 .Lend_\@:
 .endm
 
-.macro SWITCH_TO_USER_CR3 scratch_reg:req
+#define THIS_CPU_user_pcid_flush_mask   \
+       PER_CPU_VAR(cpu_tlbstate) + TLB_STATE_user_pcid_flush_mask
+
+.macro SWITCH_TO_USER_CR3_NOSTACK scratch_reg:req scratch_reg2:req
        ALTERNATIVE "jmp .Lend_\@", "", X86_FEATURE_PTI
        mov     %cr3, \scratch_reg
-       ADJUST_USER_CR3 \scratch_reg
+
+       ALTERNATIVE "jmp .Lwrcr3_\@", "", X86_FEATURE_PCID
+
+       /*
+        * Test if the ASID needs a flush.
+        */
+       movq    \scratch_reg, \scratch_reg2
+       andq    $(0x7FF), \scratch_reg          /* mask ASID */
+       bt      \scratch_reg, THIS_CPU_user_pcid_flush_mask
+       jnc     .Lnoflush_\@
+
+       /* Flush needed, clear the bit */
+       btr     \scratch_reg, THIS_CPU_user_pcid_flush_mask
+       movq    \scratch_reg2, \scratch_reg
+       jmp     .Lwrcr3_\@
+
+.Lnoflush_\@:
+       movq    \scratch_reg2, \scratch_reg
+       SET_NOFLUSH_BIT \scratch_reg
+
+.Lwrcr3_\@:
+       /* Flip the PGD and ASID to the user version */
+       orq     $(PTI_SWITCH_MASK), \scratch_reg
        mov     \scratch_reg, %cr3
 .Lend_\@:
 .endm
 
+.macro SWITCH_TO_USER_CR3_STACK        scratch_reg:req
+       pushq   %rax
+       SWITCH_TO_USER_CR3_NOSTACK scratch_reg=\scratch_reg scratch_reg2=%rax
+       popq    %rax
+.endm
+
 .macro SAVE_AND_SWITCH_TO_KERNEL_CR3 scratch_reg:req save_reg:req
        ALTERNATIVE "jmp .Ldone_\@", "", X86_FEATURE_PTI
        movq    %cr3, \scratch_reg
        movq    \scratch_reg, \save_reg
        /*
-        * Is the switch bit zero?  This means the address is
-        * up in real PAGE_TABLE_ISOLATION patches in a moment.
+        * Is the "switch mask" all zero?  That means that both of
+        * these are zero:
+        *
+        *      1. The user/kernel PCID bit, and
+        *      2. The user/kernel "bit" that points CR3 to the
+        *         bottom half of the 8k PGD
+        *
+        * That indicates a kernel CR3 value, not a user CR3.
         */
        testq   $(PTI_SWITCH_MASK), \scratch_reg
        jz      .Ldone_\@
 
 .macro SWITCH_TO_KERNEL_CR3 scratch_reg:req
 .endm
-.macro SWITCH_TO_USER_CR3 scratch_reg:req
+.macro SWITCH_TO_USER_CR3_NOSTACK scratch_reg:req scratch_reg2:req
+.endm
+.macro SWITCH_TO_USER_CR3_STACK scratch_reg:req
 .endm
 .macro SAVE_AND_SWITCH_TO_KERNEL_CR3 scratch_reg:req save_reg:req
 .endm
 
 #include <asm/segment.h>
 #include <asm/cache.h>
 #include <asm/errno.h>
-#include "calling.h"
 #include <asm/asm-offsets.h>
 #include <asm/msr.h>
 #include <asm/unistd.h>
 #include <asm/frame.h>
 #include <linux/err.h>
 
+#include "calling.h"
+
 .code64
 .section .entry.text, "ax"
 
         * We are on the trampoline stack.  All regs except RDI are live.
         * We can do future final exit work right here.
         */
-       SWITCH_TO_USER_CR3 scratch_reg=%rdi
+       SWITCH_TO_USER_CR3_STACK scratch_reg=%rdi
 
        popq    %rdi
        popq    %rsp
         * We can do future final exit work right here.
         */
 
-       SWITCH_TO_USER_CR3 scratch_reg=%rdi
+       SWITCH_TO_USER_CR3_STACK scratch_reg=%rdi
 
        /* Restore RDI. */
        popq    %rdi
         */
        orq     PER_CPU_VAR(espfix_stack), %rax
 
-       SWITCH_TO_USER_CR3 scratch_reg=%rdi     /* to user CR3 */
+       SWITCH_TO_USER_CR3_STACK scratch_reg=%rdi
        SWAPGS                                  /* to user GS */
        popq    %rdi                            /* Restore user RDI */
 
 
         * switch until after after the last reference to the process
         * stack.
         *
-        * %r8 is zeroed before the sysret, thus safe to clobber.
+        * %r8/%r9 are zeroed before the sysret, thus safe to clobber.
         */
-       SWITCH_TO_USER_CR3 scratch_reg=%r8
+       SWITCH_TO_USER_CR3_NOSTACK scratch_reg=%r8 scratch_reg2=%r9
 
        xorq    %r8, %r8
        xorq    %r9, %r9
 
 #define CR3_ADDR_MASK  __sme_clr(0x7FFFFFFFFFFFF000ull)
 #define CR3_PCID_MASK  0xFFFull
 #define CR3_NOFLUSH    BIT_ULL(63)
+
+#ifdef CONFIG_PAGE_TABLE_ISOLATION
+# define X86_CR3_PTI_SWITCH_BIT        11
+#endif
+
 #else
 /*
  * CR3_ADDR_MASK needs at least bits 31:5 set on PAE systems, and we save
 
 #include <asm/special_insns.h>
 #include <asm/smp.h>
 #include <asm/invpcid.h>
+#include <asm/pti.h>
+#include <asm/processor-flags.h>
 
 static inline u64 inc_mm_tlb_gen(struct mm_struct *mm)
 {
 
 /* There are 12 bits of space for ASIDS in CR3 */
 #define CR3_HW_ASID_BITS               12
+
 /*
  * When enabled, PAGE_TABLE_ISOLATION consumes a single bit for
  * user/kernel switches
  */
-#define PTI_CONSUMED_ASID_BITS         0
+#ifdef CONFIG_PAGE_TABLE_ISOLATION
+# define PTI_CONSUMED_PCID_BITS        1
+#else
+# define PTI_CONSUMED_PCID_BITS        0
+#endif
+
+#define CR3_AVAIL_PCID_BITS (X86_CR3_PCID_BITS - PTI_CONSUMED_PCID_BITS)
 
-#define CR3_AVAIL_ASID_BITS (CR3_HW_ASID_BITS - PTI_CONSUMED_ASID_BITS)
 /*
  * ASIDs are zero-based: 0->MAX_AVAIL_ASID are valid.  -1 below to account
  * for them being zero-based.  Another -1 is because ASID 0 is reserved for
  * use by non-PCID-aware users.
  */
-#define MAX_ASID_AVAILABLE ((1 << CR3_AVAIL_ASID_BITS) - 2)
+#define MAX_ASID_AVAILABLE ((1 << CR3_AVAIL_PCID_BITS) - 2)
+
+/*
+ * 6 because 6 should be plenty and struct tlb_state will fit in two cache
+ * lines.
+ */
+#define TLB_NR_DYN_ASIDS       6
 
 static inline u16 kern_pcid(u16 asid)
 {
        VM_WARN_ON_ONCE(asid > MAX_ASID_AVAILABLE);
+
+#ifdef CONFIG_PAGE_TABLE_ISOLATION
+       /*
+        * Make sure that the dynamic ASID space does not confict with the
+        * bit we are using to switch between user and kernel ASIDs.
+        */
+       BUILD_BUG_ON(TLB_NR_DYN_ASIDS >= (1 << X86_CR3_PTI_SWITCH_BIT));
+
        /*
+        * The ASID being passed in here should have respected the
+        * MAX_ASID_AVAILABLE and thus never have the switch bit set.
+        */
+       VM_WARN_ON_ONCE(asid & (1 << X86_CR3_PTI_SWITCH_BIT));
+#endif
+       /*
+        * The dynamically-assigned ASIDs that get passed in are small
+        * (<TLB_NR_DYN_ASIDS).  They never have the high switch bit set,
+        * so do not bother to clear it.
+        *
         * If PCID is on, ASID-aware code paths put the ASID+1 into the
         * PCID bits.  This serves two purposes.  It prevents a nasty
         * situation in which PCID-unaware code saves CR3, loads some other
        return !static_cpu_has(X86_FEATURE_PCID);
 }
 
-/*
- * 6 because 6 should be plenty and struct tlb_state will fit in
- * two cache lines.
- */
-#define TLB_NR_DYN_ASIDS 6
-
 struct tlb_context {
        u64 ctx_id;
        u64 tlb_gen;
         */
        bool invalidate_other;
 
+       /*
+        * Mask that contains TLB_NR_DYN_ASIDS+1 bits to indicate
+        * the corresponding user PCID needs a flush next time we
+        * switch to it; see SWITCH_TO_USER_CR3.
+        */
+       unsigned short user_pcid_flush_mask;
+
        /*
         * Access to this CR4 shadow and to H/W CR4 is protected by
         * disabling interrupts when modifying either one.
 
 extern void initialize_tlbstate_and_flush(void);
 
+/*
+ * Given an ASID, flush the corresponding user ASID.  We can delay this
+ * until the next time we switch to it.
+ *
+ * See SWITCH_TO_USER_CR3.
+ */
+static inline void invalidate_user_asid(u16 asid)
+{
+       /* There is no user ASID if address space separation is off */
+       if (!IS_ENABLED(CONFIG_PAGE_TABLE_ISOLATION))
+               return;
+
+       /*
+        * We only have a single ASID if PCID is off and the CR3
+        * write will have flushed it.
+        */
+       if (!cpu_feature_enabled(X86_FEATURE_PCID))
+               return;
+
+       if (!static_cpu_has(X86_FEATURE_PTI))
+               return;
+
+       __set_bit(kern_pcid(asid),
+                 (unsigned long *)this_cpu_ptr(&cpu_tlbstate.user_pcid_flush_mask));
+}
+
 /*
  * flush the entire current user mapping
  */
 static inline void __native_flush_tlb(void)
 {
+       invalidate_user_asid(this_cpu_read(cpu_tlbstate.loaded_mm_asid));
        /*
-        * If current->mm == NULL then we borrow a mm which may change during a
-        * task switch and therefore we must not be preempted while we write CR3
-        * back:
+        * If current->mm == NULL then we borrow a mm which may change
+        * during a task switch and therefore we must not be preempted
+        * while we write CR3 back:
         */
        preempt_disable();
        native_write_cr3(__native_read_cr3());
  */
 static inline void __native_flush_tlb_single(unsigned long addr)
 {
+       u32 loaded_mm_asid = this_cpu_read(cpu_tlbstate.loaded_mm_asid);
+
        asm volatile("invlpg (%0)" ::"r" (addr) : "memory");
+
+       if (!static_cpu_has(X86_FEATURE_PTI))
+               return;
+
+       invalidate_user_asid(loaded_mm_asid);
 }
 
 /*
 
 #define X86_CR3_PWT            _BITUL(X86_CR3_PWT_BIT)
 #define X86_CR3_PCD_BIT                4 /* Page Cache Disable */
 #define X86_CR3_PCD            _BITUL(X86_CR3_PCD_BIT)
-#define X86_CR3_PCID_MASK      _AC(0x00000fff,UL) /* PCID Mask */
+
+#define X86_CR3_PCID_BITS      12
+#define X86_CR3_PCID_MASK      (_AC((1UL << X86_CR3_PCID_BITS) - 1, UL))
+
+#define X86_CR3_PCID_NOFLUSH_BIT 63 /* Preserve old PCID */
+#define X86_CR3_PCID_NOFLUSH    _BITULL(X86_CR3_PCID_NOFLUSH_BIT)
 
 /*
  * Intel CPU features in CR4
 
 #include <asm/sigframe.h>
 #include <asm/bootparam.h>
 #include <asm/suspend.h>
+#include <asm/tlbflush.h>
 
 #ifdef CONFIG_XEN
 #include <xen/interface/xen.h>
        BLANK();
        DEFINE(PTREGS_SIZE, sizeof(struct pt_regs));
 
+       /* TLB state for the entry code */
+       OFFSET(TLB_STATE_user_pcid_flush_mask, tlb_state, user_pcid_flush_mask);
+
        /* Layout info for cpu_entry_area */
        OFFSET(CPU_ENTRY_AREA_tss, cpu_entry_area, tss);
        OFFSET(CPU_ENTRY_AREA_entry_trampoline, cpu_entry_area, entry_trampoline);
 
        free_area_init_nodes(max_zone_pfns);
 }
 
-DEFINE_PER_CPU_SHARED_ALIGNED(struct tlb_state, cpu_tlbstate) = {
+__visible DEFINE_PER_CPU_SHARED_ALIGNED(struct tlb_state, cpu_tlbstate) = {
        .loaded_mm = &init_mm,
        .next_asid = 1,
        .cr4 = ~0UL,    /* fail hard if we screw up cr4 shadow initialization */
 
        unsigned long new_mm_cr3;
 
        if (need_flush) {
+               invalidate_user_asid(new_asid);
                new_mm_cr3 = build_cr3(pgdir, new_asid);
        } else {
                new_mm_cr3 = build_cr3_noflush(pgdir, new_asid);