#include <linux/kernel.h>
 #include <linux/irqreturn.h>
+#include <linux/mutex.h>
+#include <linux/bitfield.h>
+#include <linux/delay.h>
 
 #include "sp-dev.h"
 #include "psp-dev.h"
 
 struct psp_device *psp_master;
 
+#define PSP_C2PMSG_17_CMDRESP_CMD      GENMASK(19, 16)
+
+static int psp_mailbox_poll(const void __iomem *cmdresp_reg, unsigned int *cmdresp,
+                           unsigned int timeout_msecs)
+{
+       while (true) {
+               *cmdresp = ioread32(cmdresp_reg);
+               if (FIELD_GET(PSP_CMDRESP_RESP, *cmdresp))
+                       return 0;
+
+               if (!timeout_msecs--)
+                       break;
+
+               usleep_range(1000, 1100);
+       }
+
+       return -ETIMEDOUT;
+}
+
+int psp_mailbox_command(struct psp_device *psp, enum psp_cmd cmd, void *cmdbuff,
+                       unsigned int timeout_msecs, unsigned int *cmdresp)
+{
+       void __iomem *cmdresp_reg, *cmdbuff_lo_reg, *cmdbuff_hi_reg;
+       int ret;
+
+       if (!psp || !psp->vdata || !psp->vdata->cmdresp_reg ||
+           !psp->vdata->cmdbuff_addr_lo_reg || !psp->vdata->cmdbuff_addr_hi_reg)
+               return -ENODEV;
+
+       cmdresp_reg    = psp->io_regs + psp->vdata->cmdresp_reg;
+       cmdbuff_lo_reg = psp->io_regs + psp->vdata->cmdbuff_addr_lo_reg;
+       cmdbuff_hi_reg = psp->io_regs + psp->vdata->cmdbuff_addr_hi_reg;
+
+       mutex_lock(&psp->mailbox_mutex);
+
+       /* Ensure mailbox is ready for a command */
+       ret = -EBUSY;
+       if (psp_mailbox_poll(cmdresp_reg, cmdresp, 0))
+               goto unlock;
+
+       if (cmdbuff) {
+               iowrite32(lower_32_bits(__psp_pa(cmdbuff)), cmdbuff_lo_reg);
+               iowrite32(upper_32_bits(__psp_pa(cmdbuff)), cmdbuff_hi_reg);
+       }
+
+       *cmdresp = FIELD_PREP(PSP_C2PMSG_17_CMDRESP_CMD, cmd);
+       iowrite32(*cmdresp, cmdresp_reg);
+
+       ret = psp_mailbox_poll(cmdresp_reg, cmdresp, timeout_msecs);
+
+unlock:
+       mutex_unlock(&psp->mailbox_mutex);
+
+       return ret;
+}
+
 static struct psp_device *psp_alloc_struct(struct sp_device *sp)
 {
        struct device *dev = sp->dev;
        }
 
        psp->io_regs = sp->io_map;
+       mutex_init(&psp->mailbox_mutex);
 
        ret = psp_get_capability(psp);
        if (ret)
 
 #include <linux/list.h>
 #include <linux/bits.h>
 #include <linux/interrupt.h>
+#include <linux/mutex.h>
+#include <linux/psp.h>
 
 #include "sp-dev.h"
 
        struct sp_device *sp;
 
        void __iomem *io_regs;
+       struct mutex mailbox_mutex;
 
        psp_irq_handler_t sev_irq_handler;
        void *sev_irq_data;
 #define PSP_SECURITY_HSP_TPM_AVAILABLE         BIT(10)
 #define PSP_SECURITY_ROM_ARMOR_ENFORCED                BIT(11)
 
+/**
+ * enum psp_cmd - PSP mailbox commands
+ * @PSP_CMD_TEE_RING_INIT:     Initialize TEE ring buffer
+ * @PSP_CMD_TEE_RING_DESTROY:  Destroy TEE ring buffer
+ * @PSP_CMD_MAX:               Maximum command id
+ */
+enum psp_cmd {
+       PSP_CMD_TEE_RING_INIT           = 1,
+       PSP_CMD_TEE_RING_DESTROY        = 2,
+       PSP_CMD_MAX                     = 15,
+};
+
+int psp_mailbox_command(struct psp_device *psp, enum psp_cmd cmd, void *cmdbuff,
+                       unsigned int timeout_msecs, unsigned int *cmdresp);
+
 #endif /* __PSP_DEV_H */
 
        const struct sev_vdata *sev;
        const struct tee_vdata *tee;
        const struct platform_access_vdata *platform_access;
+       const unsigned int cmdresp_reg;
+       const unsigned int cmdbuff_addr_lo_reg;
+       const unsigned int cmdbuff_addr_hi_reg;
        const unsigned int feature_reg;
        const unsigned int inten_reg;
        const unsigned int intsts_reg;
 
 };
 
 static const struct tee_vdata teev1 = {
-       .cmdresp_reg            = 0x10544,      /* C2PMSG_17 */
-       .cmdbuff_addr_lo_reg    = 0x10548,      /* C2PMSG_18 */
-       .cmdbuff_addr_hi_reg    = 0x1054c,      /* C2PMSG_19 */
        .ring_wptr_reg          = 0x10550,      /* C2PMSG_20 */
        .ring_rptr_reg          = 0x10554,      /* C2PMSG_21 */
        .info_reg               = 0x109e8,      /* C2PMSG_58 */
 };
 
 static const struct tee_vdata teev2 = {
-       .cmdresp_reg            = 0x10944,      /* C2PMSG_17 */
-       .cmdbuff_addr_lo_reg    = 0x10948,      /* C2PMSG_18 */
-       .cmdbuff_addr_hi_reg    = 0x1094c,      /* C2PMSG_19 */
        .ring_wptr_reg          = 0x10950,      /* C2PMSG_20 */
        .ring_rptr_reg          = 0x10954,      /* C2PMSG_21 */
 };
 static const struct psp_vdata pspv3 = {
        .tee                    = &teev1,
        .platform_access        = &pa_v1,
+       .cmdresp_reg            = 0x10544,      /* C2PMSG_17 */
+       .cmdbuff_addr_lo_reg    = 0x10548,      /* C2PMSG_18 */
+       .cmdbuff_addr_hi_reg    = 0x1054c,      /* C2PMSG_19 */
        .bootloader_info_reg    = 0x109ec,      /* C2PMSG_59 */
        .feature_reg            = 0x109fc,      /* C2PMSG_63 */
        .inten_reg              = 0x10690,      /* P2CMSG_INTEN */
 static const struct psp_vdata pspv4 = {
        .sev                    = &sevv2,
        .tee                    = &teev1,
+       .cmdresp_reg            = 0x10544,      /* C2PMSG_17 */
+       .cmdbuff_addr_lo_reg    = 0x10548,      /* C2PMSG_18 */
+       .cmdbuff_addr_hi_reg    = 0x1054c,      /* C2PMSG_19 */
        .bootloader_info_reg    = 0x109ec,      /* C2PMSG_59 */
        .feature_reg            = 0x109fc,      /* C2PMSG_63 */
        .inten_reg              = 0x10690,      /* P2CMSG_INTEN */
 static const struct psp_vdata pspv5 = {
        .tee                    = &teev2,
        .platform_access        = &pa_v2,
+       .cmdresp_reg            = 0x10944,      /* C2PMSG_17 */
+       .cmdbuff_addr_lo_reg    = 0x10948,      /* C2PMSG_18 */
+       .cmdbuff_addr_hi_reg    = 0x1094c,      /* C2PMSG_19 */
        .feature_reg            = 0x109fc,      /* C2PMSG_63 */
        .inten_reg              = 0x10510,      /* P2CMSG_INTEN */
        .intsts_reg             = 0x10514,      /* P2CMSG_INTSTS */
 static const struct psp_vdata pspv6 = {
        .sev                    = &sevv2,
        .tee                    = &teev2,
+       .cmdresp_reg            = 0x10944,      /* C2PMSG_17 */
+       .cmdbuff_addr_lo_reg    = 0x10948,      /* C2PMSG_18 */
+       .cmdbuff_addr_hi_reg    = 0x1094c,      /* C2PMSG_19 */
        .feature_reg            = 0x109fc,      /* C2PMSG_63 */
        .inten_reg              = 0x10510,      /* P2CMSG_INTEN */
        .intsts_reg             = 0x10514,      /* P2CMSG_INTSTS */
 
        mutex_destroy(&rb_mgr->mutex);
 }
 
-static int tee_wait_cmd_poll(struct psp_tee_device *tee, unsigned int timeout,
-                            unsigned int *reg)
-{
-       /* ~10ms sleep per loop => nloop = timeout * 100 */
-       int nloop = timeout * 100;
-
-       while (--nloop) {
-               *reg = ioread32(tee->io_regs + tee->vdata->cmdresp_reg);
-               if (FIELD_GET(PSP_CMDRESP_RESP, *reg))
-                       return 0;
-
-               usleep_range(10000, 10100);
-       }
-
-       dev_err(tee->dev, "tee: command timed out, disabling PSP\n");
-       psp_dead = true;
-
-       return -ETIMEDOUT;
-}
-
 static
 struct tee_init_ring_cmd *tee_alloc_cmd_buffer(struct psp_tee_device *tee)
 {
 {
        int ring_size = MAX_RING_BUFFER_ENTRIES * sizeof(struct tee_ring_cmd);
        struct tee_init_ring_cmd *cmd;
-       phys_addr_t cmd_buffer;
        unsigned int reg;
        int ret;
 
                return -ENOMEM;
        }
 
-       cmd_buffer = __psp_pa((void *)cmd);
-
        /* Send command buffer details to Trusted OS by writing to
         * CPU-PSP message registers
         */
-
-       iowrite32(lower_32_bits(cmd_buffer),
-                 tee->io_regs + tee->vdata->cmdbuff_addr_lo_reg);
-       iowrite32(upper_32_bits(cmd_buffer),
-                 tee->io_regs + tee->vdata->cmdbuff_addr_hi_reg);
-       iowrite32(TEE_RING_INIT_CMD,
-                 tee->io_regs + tee->vdata->cmdresp_reg);
-
-       ret = tee_wait_cmd_poll(tee, TEE_DEFAULT_TIMEOUT, ®);
+       ret = psp_mailbox_command(tee->psp, PSP_CMD_TEE_RING_INIT, cmd,
+                                 TEE_DEFAULT_CMD_TIMEOUT, ®);
        if (ret) {
-               dev_err(tee->dev, "tee: ring init command timed out\n");
+               dev_err(tee->dev, "tee: ring init command timed out, disabling TEE support\n");
                tee_free_ring(tee);
+               psp_dead = true;
                goto free_buf;
        }
 
        if (psp_dead)
                goto free_ring;
 
-       iowrite32(TEE_RING_DESTROY_CMD,
-                 tee->io_regs + tee->vdata->cmdresp_reg);
-
-       ret = tee_wait_cmd_poll(tee, TEE_DEFAULT_TIMEOUT, ®);
+       ret = psp_mailbox_command(tee->psp, PSP_CMD_TEE_RING_DESTROY, NULL,
+                                 TEE_DEFAULT_CMD_TIMEOUT, ®);
        if (ret) {
-               dev_err(tee->dev, "tee: ring destroy command timed out\n");
+               dev_err(tee->dev, "tee: ring destroy command timed out, disabling TEE support\n");
+               psp_dead = true;
        } else if (FIELD_GET(PSP_CMDRESP_STS, reg)) {
                dev_err(tee->dev, "tee: ring destroy command failed (%#010lx)\n",
                        FIELD_GET(PSP_CMDRESP_STS, reg));
        if (ret)
                return ret;
 
-       ret = tee_wait_cmd_completion(tee, resp, TEE_DEFAULT_TIMEOUT);
+       ret = tee_wait_cmd_completion(tee, resp, TEE_DEFAULT_RING_TIMEOUT);
        if (ret) {
                resp->flag = CMD_RESPONSE_TIMEDOUT;
                return ret;
 
 #include <linux/device.h>
 #include <linux/mutex.h>
 
-#define TEE_DEFAULT_TIMEOUT            10
+#define TEE_DEFAULT_CMD_TIMEOUT                (10 * MSEC_PER_SEC)
+#define TEE_DEFAULT_RING_TIMEOUT       10
 #define MAX_BUFFER_SIZE                        988
 
-/**
- * enum tee_ring_cmd_id - TEE interface commands for ring buffer configuration
- * @TEE_RING_INIT_CMD:         Initialize ring buffer
- * @TEE_RING_DESTROY_CMD:      Destroy ring buffer
- * @TEE_RING_MAX_CMD:          Maximum command id
- */
-enum tee_ring_cmd_id {
-       TEE_RING_INIT_CMD               = 0x00010000,
-       TEE_RING_DESTROY_CMD            = 0x00020000,
-       TEE_RING_MAX_CMD                = 0x000F0000,
-};
-
 /**
  * struct tee_init_ring_cmd - Command to init TEE ring buffer
  * @low_addr:  bits [31:0] of the physical address of ring buffer