#include <linux/eventfd.h>
 
 #include <asm/apicdef.h>
+#include <asm/mshyperv.h>
 #include <trace/events/kvm.h>
 
 #include "trace.h"
                                  entries, consumed_xmm_halves, offset);
 }
 
-static void hv_tlb_flush_enqueue(struct kvm_vcpu *vcpu, u64 *entries, int count)
+static void hv_tlb_flush_enqueue(struct kvm_vcpu *vcpu,
+                                struct kvm_vcpu_hv_tlb_flush_fifo *tlb_flush_fifo,
+                                u64 *entries, int count)
 {
-       struct kvm_vcpu_hv_tlb_flush_fifo *tlb_flush_fifo;
        struct kvm_vcpu_hv *hv_vcpu = to_hv_vcpu(vcpu);
        u64 flush_all_entry = KVM_HV_TLB_FLUSHALL_ENTRY;
 
        if (!hv_vcpu)
                return;
 
-       /* kvm_hv_flush_tlb() is not ready to handle requests for L2s yet */
-       tlb_flush_fifo = &hv_vcpu->tlb_flush_fifo[HV_L1_TLB_FLUSH_FIFO];
-
        spin_lock(&tlb_flush_fifo->write_lock);
 
        /*
        struct hv_tlb_flush_ex flush_ex;
        struct hv_tlb_flush flush;
        DECLARE_BITMAP(vcpu_mask, KVM_MAX_VCPUS);
+       struct kvm_vcpu_hv_tlb_flush_fifo *tlb_flush_fifo;
        /*
         * Normally, there can be no more than 'KVM_HV_TLB_FLUSH_FIFO_SIZE'
         * entries on the TLB flush fifo. The last entry, however, needs to be
                }
 
                trace_kvm_hv_flush_tlb(flush.processor_mask,
-                                      flush.address_space, flush.flags);
+                                      flush.address_space, flush.flags,
+                                      is_guest_mode(vcpu));
 
                valid_bank_mask = BIT_ULL(0);
                sparse_banks[0] = flush.processor_mask;
                trace_kvm_hv_flush_tlb_ex(flush_ex.hv_vp_set.valid_bank_mask,
                                          flush_ex.hv_vp_set.format,
                                          flush_ex.address_space,
-                                         flush_ex.flags);
+                                         flush_ex.flags, is_guest_mode(vcpu));
 
                valid_bank_mask = flush_ex.hv_vp_set.valid_bank_mask;
                all_cpus = flush_ex.hv_vp_set.format !=
         * vcpu->arch.cr3 may not be up-to-date for running vCPUs so we can't
         * analyze it here, flush TLB regardless of the specified address space.
         */
-       if (all_cpus) {
-               kvm_for_each_vcpu(i, v, kvm)
-                       hv_tlb_flush_enqueue(v, tlb_flush_entries, hc->rep_cnt);
+       if (all_cpus && !is_guest_mode(vcpu)) {
+               kvm_for_each_vcpu(i, v, kvm) {
+                       tlb_flush_fifo = kvm_hv_get_tlb_flush_fifo(v, false);
+                       hv_tlb_flush_enqueue(v, tlb_flush_fifo,
+                                            tlb_flush_entries, hc->rep_cnt);
+               }
 
                kvm_make_all_cpus_request(kvm, KVM_REQ_HV_TLB_FLUSH);
-       } else {
+       } else if (!is_guest_mode(vcpu)) {
                sparse_set_to_vcpu_mask(kvm, sparse_banks, valid_bank_mask, vcpu_mask);
 
                for_each_set_bit(i, vcpu_mask, KVM_MAX_VCPUS) {
                        v = kvm_get_vcpu(kvm, i);
                        if (!v)
                                continue;
-                       hv_tlb_flush_enqueue(v, tlb_flush_entries, hc->rep_cnt);
+                       tlb_flush_fifo = kvm_hv_get_tlb_flush_fifo(v, false);
+                       hv_tlb_flush_enqueue(v, tlb_flush_fifo,
+                                            tlb_flush_entries, hc->rep_cnt);
+               }
+
+               kvm_make_vcpus_request_mask(kvm, KVM_REQ_HV_TLB_FLUSH, vcpu_mask);
+       } else {
+               struct kvm_vcpu_hv *hv_v;
+
+               bitmap_zero(vcpu_mask, KVM_MAX_VCPUS);
+
+               kvm_for_each_vcpu(i, v, kvm) {
+                       hv_v = to_hv_vcpu(v);
+
+                       /*
+                        * The following check races with nested vCPUs entering/exiting
+                        * and/or migrating between L1's vCPUs, however the only case when
+                        * KVM *must* flush the TLB is when the target L2 vCPU keeps
+                        * running on the same L1 vCPU from the moment of the request until
+                        * kvm_hv_flush_tlb() returns. TLB is fully flushed in all other
+                        * cases, e.g. when the target L2 vCPU migrates to a different L1
+                        * vCPU or when the corresponding L1 vCPU temporary switches to a
+                        * different L2 vCPU while the request is being processed.
+                        */
+                       if (!hv_v || hv_v->nested.vm_id != hv_vcpu->nested.vm_id)
+                               continue;
+
+                       if (!all_cpus &&
+                           !hv_is_vp_in_sparse_set(hv_v->nested.vp_id, valid_bank_mask,
+                                                   sparse_banks))
+                               continue;
+
+                       __set_bit(i, vcpu_mask);
+                       tlb_flush_fifo = kvm_hv_get_tlb_flush_fifo(v, true);
+                       hv_tlb_flush_enqueue(v, tlb_flush_fifo,
+                                            tlb_flush_entries, hc->rep_cnt);
                }
 
                kvm_make_vcpus_request_mask(kvm, KVM_REQ_HV_TLB_FLUSH, vcpu_mask);
 
 static int kvm_hv_hypercall_complete(struct kvm_vcpu *vcpu, u64 result)
 {
+       u32 tlb_lock_count = 0;
+       int ret;
+
+       if (hv_result_success(result) && is_guest_mode(vcpu) &&
+           kvm_hv_is_tlb_flush_hcall(vcpu) &&
+           kvm_read_guest(vcpu->kvm, to_hv_vcpu(vcpu)->nested.pa_page_gpa,
+                          &tlb_lock_count, sizeof(tlb_lock_count)))
+               result = HV_STATUS_INVALID_HYPERCALL_INPUT;
+
        trace_kvm_hv_hypercall_done(result);
        kvm_hv_hypercall_set_result(vcpu, result);
        ++vcpu->stat.hypercalls;
-       return kvm_skip_emulated_instruction(vcpu);
+
+       ret = kvm_skip_emulated_instruction(vcpu);
+
+       if (tlb_lock_count)
+               kvm_x86_ops.nested_ops->hv_inject_synthetic_vmexit_post_tlb_flush(vcpu);
+
+       return ret;
 }
 
 static int kvm_hv_hypercall_complete_userspace(struct kvm_vcpu *vcpu)
 
  * Tracepoint for kvm_hv_flush_tlb.
  */
 TRACE_EVENT(kvm_hv_flush_tlb,
-       TP_PROTO(u64 processor_mask, u64 address_space, u64 flags),
-       TP_ARGS(processor_mask, address_space, flags),
+       TP_PROTO(u64 processor_mask, u64 address_space, u64 flags, bool guest_mode),
+       TP_ARGS(processor_mask, address_space, flags, guest_mode),
 
        TP_STRUCT__entry(
                __field(u64, processor_mask)
                __field(u64, address_space)
                __field(u64, flags)
+               __field(bool, guest_mode)
        ),
 
        TP_fast_assign(
                __entry->processor_mask = processor_mask;
                __entry->address_space = address_space;
                __entry->flags = flags;
+               __entry->guest_mode = guest_mode;
        ),
 
-       TP_printk("processor_mask 0x%llx address_space 0x%llx flags 0x%llx",
+       TP_printk("processor_mask 0x%llx address_space 0x%llx flags 0x%llx %s",
                  __entry->processor_mask, __entry->address_space,
-                 __entry->flags)
+                 __entry->flags, __entry->guest_mode ? "(L2)" : "")
 );
 
 /*
  * Tracepoint for kvm_hv_flush_tlb_ex.
  */
 TRACE_EVENT(kvm_hv_flush_tlb_ex,
-       TP_PROTO(u64 valid_bank_mask, u64 format, u64 address_space, u64 flags),
-       TP_ARGS(valid_bank_mask, format, address_space, flags),
+       TP_PROTO(u64 valid_bank_mask, u64 format, u64 address_space, u64 flags, bool guest_mode),
+       TP_ARGS(valid_bank_mask, format, address_space, flags, guest_mode),
 
        TP_STRUCT__entry(
                __field(u64, valid_bank_mask)
                __field(u64, format)
                __field(u64, address_space)
                __field(u64, flags)
+               __field(bool, guest_mode)
        ),
 
        TP_fast_assign(
                __entry->format = format;
                __entry->address_space = address_space;
                __entry->flags = flags;
+               __entry->guest_mode = guest_mode;
        ),
 
        TP_printk("valid_bank_mask 0x%llx format 0x%llx "
-                 "address_space 0x%llx flags 0x%llx",
+                 "address_space 0x%llx flags 0x%llx %s",
                  __entry->valid_bank_mask, __entry->format,
-                 __entry->address_space, __entry->flags)
+                 __entry->address_space, __entry->flags,
+                 __entry->guest_mode ? "(L2)" : "")
 );
 
 /*