extern const ulong vmx_return;
 
 static DEFINE_STATIC_KEY_FALSE(vmx_l1d_should_flush);
+static DEFINE_STATIC_KEY_FALSE(vmx_l1d_flush_always);
 
 /* Storage for pre module init parameter parsing */
 static enum vmx_l1d_flush_state __read_mostly vmentry_l1d_flush_param = VMENTER_L1D_FLUSH_AUTO;
 
        l1tf_vmx_mitigation = l1tf;
 
-       if (l1tf != VMENTER_L1D_FLUSH_NEVER)
-               static_branch_enable(&vmx_l1d_should_flush);
+       if (l1tf == VMENTER_L1D_FLUSH_NEVER)
+               return 0;
+
+       static_branch_enable(&vmx_l1d_should_flush);
+       if (l1tf == VMENTER_L1D_FLUSH_ALWAYS)
+               static_branch_enable(&vmx_l1d_flush_always);
        return 0;
 }
 
 static void vmx_l1d_flush(struct kvm_vcpu *vcpu)
 {
        int size = PAGE_SIZE << L1D_CACHE_ORDER;
-       bool always;
 
        /*
         * This code is only executed when the flush mode is 'cond' or
         * it. The flush bit gets set again either from vcpu_run() or from
         * one of the unsafe VMEXIT handlers.
         */
-       always = l1tf_vmx_mitigation == VMENTER_L1D_FLUSH_ALWAYS;
-       vcpu->arch.l1tf_flush_l1d = always;
+       if (static_branch_unlikely(&vmx_l1d_flush_always))
+               vcpu->arch.l1tf_flush_l1d = true;
+       else
+               vcpu->arch.l1tf_flush_l1d = false;
 
        vcpu->stat.l1d_flush++;