if (!sext)
                return -ENOENT;
 
-       scontext->ext_status[sext->ext_idx] = reg_val ?
-               KVM_RISCV_SBI_EXT_AVAILABLE : KVM_RISCV_SBI_EXT_UNAVAILABLE;
+       /*
+        * We can't set the extension status to available here, since it may
+        * have a probe() function which needs to confirm availability first,
+        * but it may be too early to call that here. We can set the status to
+        * unavailable, though.
+        */
+       if (!reg_val)
+               scontext->ext_status[sext->ext_idx] =
+                       KVM_RISCV_SBI_EXT_UNAVAILABLE;
 
        return 0;
 }
        if (!sext)
                return -ENOENT;
 
-       *reg_val = scontext->ext_status[sext->ext_idx] ==
-                               KVM_RISCV_SBI_EXT_AVAILABLE;
+       /*
+        * If the extension status is still uninitialized, then we should probe
+        * to determine if it's available, but it may be too early to do that
+        * here. The best we can do is report that the extension has not been
+        * disabled, i.e. we return 1 when the extension is available and also
+        * when it only may be available.
+        */
+       *reg_val = scontext->ext_status[sext->ext_idx] !=
+                               KVM_RISCV_SBI_EXT_UNAVAILABLE;
 
        return 0;
 }
 const struct kvm_vcpu_sbi_extension *kvm_vcpu_sbi_find_ext(
                                struct kvm_vcpu *vcpu, unsigned long extid)
 {
-       int i;
-       const struct kvm_riscv_sbi_extension_entry *sext;
        struct kvm_vcpu_sbi_context *scontext = &vcpu->arch.sbi_context;
+       const struct kvm_riscv_sbi_extension_entry *entry;
+       const struct kvm_vcpu_sbi_extension *ext;
+       int i;
 
        for (i = 0; i < ARRAY_SIZE(sbi_ext); i++) {
-               sext = &sbi_ext[i];
-               if (sext->ext_ptr->extid_start <= extid &&
-                   sext->ext_ptr->extid_end >= extid) {
-                       if (sext->ext_idx < KVM_RISCV_SBI_EXT_MAX &&
-                           scontext->ext_status[sext->ext_idx] ==
+               entry = &sbi_ext[i];
+               ext = entry->ext_ptr;
+
+               if (ext->extid_start <= extid && ext->extid_end >= extid) {
+                       if (entry->ext_idx >= KVM_RISCV_SBI_EXT_MAX ||
+                           scontext->ext_status[entry->ext_idx] ==
+                                               KVM_RISCV_SBI_EXT_AVAILABLE)
+                               return ext;
+                       if (scontext->ext_status[entry->ext_idx] ==
                                                KVM_RISCV_SBI_EXT_UNAVAILABLE)
                                return NULL;
-                       return sbi_ext[i].ext_ptr;
+                       if (ext->probe && !ext->probe(vcpu)) {
+                               scontext->ext_status[entry->ext_idx] =
+                                       KVM_RISCV_SBI_EXT_UNAVAILABLE;
+                               return NULL;
+                       }
+
+                       scontext->ext_status[entry->ext_idx] =
+                               KVM_RISCV_SBI_EXT_AVAILABLE;
+                       return ext;
                }
        }