#define REG_MASK (KVM_REG_ARCH_MASK | KVM_REG_SIZE_MASK)
 
+static bool isa_ext_cant_disable[KVM_RISCV_ISA_EXT_MAX];
+
 bool filter_reg(__u64 reg)
 {
        switch (reg & ~REG_MASK) {
        case KVM_REG_RISCV_ISA_EXT | KVM_RISCV_ISA_EXT_ZIFENCEI:
        case KVM_REG_RISCV_ISA_EXT | KVM_RISCV_ISA_EXT_ZIHPM:
                return true;
+       /* AIA registers are always available when Ssaia can't be disabled */
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(siselect):
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(iprio1):
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(iprio2):
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(sieh):
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(siph):
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(iprio1h):
+       case KVM_REG_RISCV_CSR | KVM_REG_RISCV_CSR_AIA | KVM_REG_RISCV_CSR_AIA_REG(iprio2h):
+               return isa_ext_cant_disable[KVM_RISCV_ISA_EXT_SSAIA];
        default:
                break;
        }
 
 void finalize_vcpu(struct kvm_vcpu *vcpu, struct vcpu_reg_list *c)
 {
+       unsigned long isa_ext_state[KVM_RISCV_ISA_EXT_MAX] = { 0 };
        struct vcpu_reg_sublist *s;
+       int rc;
+
+       for (int i = 0; i < KVM_RISCV_ISA_EXT_MAX; i++)
+               __vcpu_get_reg(vcpu, RISCV_ISA_EXT_REG(i), &isa_ext_state[i]);
 
        /*
         * Disable all extensions which were enabled by default
         * if they were available in the risc-v host.
         */
-       for (int i = 0; i < KVM_RISCV_ISA_EXT_MAX; i++)
-               __vcpu_set_reg(vcpu, RISCV_ISA_EXT_REG(i), 0);
+       for (int i = 0; i < KVM_RISCV_ISA_EXT_MAX; i++) {
+               rc = __vcpu_set_reg(vcpu, RISCV_ISA_EXT_REG(i), 0);
+               if (rc && isa_ext_state[i])
+                       isa_ext_cant_disable[i] = true;
+       }
 
        for_each_sublist(c, s) {
                if (!s->feature)