if (WARN_ON_ONCE(!data != !buf_len))
                return -EINVAL;
 
-       if (data && WARN_ON_ONCE(!virt_addr_valid(data)))
-               return -EINVAL;
+       /*
+        * Copy the incoming data to driver's scratch buffer as __pa() will not
+        * work for some memory, e.g. vmalloc'd addresses, and @data may not be
+        * physically contiguous.
+        */
+       if (data)
+               memcpy(sev->cmd_buf, data, buf_len);
 
        /* Get the physical address of the command buffer */
-       phys_lsb = data ? lower_32_bits(__psp_pa(data)) : 0;
-       phys_msb = data ? upper_32_bits(__psp_pa(data)) : 0;
+       phys_lsb = data ? lower_32_bits(__psp_pa(sev->cmd_buf)) : 0;
+       phys_msb = data ? upper_32_bits(__psp_pa(sev->cmd_buf)) : 0;
 
        dev_dbg(sev->dev, "sev command id %#x buffer 0x%08x%08x timeout %us\n",
                cmd, phys_msb, phys_lsb, psp_timeout);
        print_hex_dump_debug("(out): ", DUMP_PREFIX_OFFSET, 16, 2, data,
                             buf_len, false);
 
+       /*
+        * Copy potential output from the PSP back to data.  Do this even on
+        * failure in case the caller wants to glean something from the error.
+        */
+       if (data)
+               memcpy(data, sev->cmd_buf, buf_len);
+
        return ret;
 }
 
        if (!sev)
                goto e_err;
 
+       sev->cmd_buf = (void *)devm_get_free_pages(dev, GFP_KERNEL, 0);
+       if (!sev->cmd_buf)
+               goto e_sev;
+
        psp->sev_data = sev;
 
        sev->dev = dev;
        if (!sev->vdata) {
                ret = -ENODEV;
                dev_err(dev, "sev: missing driver data\n");
-               goto e_sev;
+               goto e_buf;
        }
 
        psp_set_sev_irq_handler(psp, sev_irq_handler, sev);
 
 e_irq:
        psp_clear_sev_irq_handler(psp);
+e_buf:
+       devm_free_pages(dev, (unsigned long)sev->cmd_buf);
 e_sev:
        devm_kfree(dev, sev);
 e_err: