/* length is in bytes */
 static struct nvme_prps *nvme_setup_prps(struct nvme_dev *dev,
                                        struct nvme_common_command *cmd,
-                                       struct scatterlist *sg, int length)
+                                       struct scatterlist *sg, int *len,
+                                       gfp_t gfp)
 {
        struct dma_pool *pool;
+       int length = *len;
        int dma_len = sg_dma_len(sg);
        u64 dma_addr = sg_dma_address(sg);
        int offset = offset_in_page(dma_addr);
 
        nprps = DIV_ROUND_UP(length, PAGE_SIZE);
        npages = DIV_ROUND_UP(8 * nprps, PAGE_SIZE);
-       prps = kmalloc(sizeof(*prps) + sizeof(__le64 *) * npages, GFP_ATOMIC);
+       prps = kmalloc(sizeof(*prps) + sizeof(__le64 *) * npages, gfp);
+       if (!prps) {
+               cmd->prp2 = cpu_to_le64(dma_addr);
+               *len = (*len - length) + PAGE_SIZE;
+               return prps;
+       }
        prp_page = 0;
        if (nprps <= (256 / 8)) {
                pool = dev->prp_small_pool;
                prps->npages = npages;
        }
 
-       prp_list = dma_pool_alloc(pool, GFP_ATOMIC, &prp_dma);
+       prp_list = dma_pool_alloc(pool, gfp, &prp_dma);
+       if (!prp_list) {
+               cmd->prp2 = cpu_to_le64(dma_addr);
+               *len = (*len - length) + PAGE_SIZE;
+               kfree(prps);
+               return NULL;
+       }
        prps->list[prp_page++] = prp_list;
        prps->first_dma = prp_dma;
        cmd->prp2 = cpu_to_le64(prp_dma);
        for (;;) {
                if (i == PAGE_SIZE / 8) {
                        __le64 *old_prp_list = prp_list;
-                       prp_list = dma_pool_alloc(pool, GFP_ATOMIC, &prp_dma);
+                       prp_list = dma_pool_alloc(pool, gfp, &prp_dma);
+                       if (!prp_list) {
+                               *len = (*len - length);
+                               return prps;
+                       }
                        prps->list[prp_page++] = prp_list;
                        prp_list[0] = old_prp_list[i - 1];
                        old_prp_list[i - 1] = cpu_to_le64(prp_dma);
        cmnd->rw.command_id = cmdid;
        cmnd->rw.nsid = cpu_to_le32(ns->ns_id);
        nbio->prps = nvme_setup_prps(nvmeq->dev, &cmnd->common, nbio->sg,
-                                                               length);
+                                                       &length, GFP_ATOMIC);
        cmnd->rw.slba = cpu_to_le64(bio->bi_sector >> (ns->lba_shift - 9));
        cmnd->rw.length = cpu_to_le16((length >> ns->lba_shift) - 1);
        cmnd->rw.control = cpu_to_le16(control);
                                        unsigned long addr, unsigned length,
                                        struct nvme_command *cmd)
 {
-       int err, nents;
+       int err, nents, tmplen = length;
        struct scatterlist *sg;
        struct nvme_prps *prps;
 
        nents = nvme_map_user_pages(dev, 0, addr, length, &sg);
        if (nents < 0)
                return nents;
-       prps = nvme_setup_prps(dev, &cmd->common, sg, length);
-       err = nvme_submit_admin_cmd(dev, cmd, NULL);
+       prps = nvme_setup_prps(dev, &cmd->common, sg, &tmplen, GFP_KERNEL);
+       if (tmplen != length)
+               err = -ENOMEM;
+       else
+               err = nvme_submit_admin_cmd(dev, cmd, NULL);
        nvme_unmap_user_pages(dev, 0, addr, length, sg, nents);
        nvme_free_prps(dev, prps);
        return err ? -EIO : 0;
        c.rw.apptag = io.apptag;
        c.rw.appmask = io.appmask;
        /* XXX: metadata */
-       prps = nvme_setup_prps(dev, &c.common, sg, length);
+       prps = nvme_setup_prps(dev, &c.common, sg, &length, GFP_KERNEL);
 
        nvmeq = get_nvmeq(ns);
        /*
         * additional races since q_lock already protects against other CPUs.
         */
        put_nvmeq(nvmeq);
-       status = nvme_submit_sync_cmd(nvmeq, &c, NULL, IO_TIMEOUT);
+       if (length != (io.nblocks + 1) << ns->lba_shift)
+               status = -ENOMEM;
+       else
+               status = nvme_submit_sync_cmd(nvmeq, &c, NULL, IO_TIMEOUT);
 
        nvme_unmap_user_pages(dev, io.opcode & 1, io.addr, length, sg, nents);
        nvme_free_prps(dev, prps);
        struct nvme_dev *dev = ns->dev;
        struct nvme_dlfw dlfw;
        struct nvme_command c;
-       int nents, status;
+       int nents, status, length;
        struct scatterlist *sg;
        struct nvme_prps *prps;
 
                return -EFAULT;
        if (dlfw.length >= (1 << 30))
                return -EINVAL;
+       length = dlfw.length * 4;
 
-       nents = nvme_map_user_pages(dev, 1, dlfw.addr, dlfw.length * 4, &sg);
+       nents = nvme_map_user_pages(dev, 1, dlfw.addr, length, &sg);
        if (nents < 0)
                return nents;
 
        c.dlfw.opcode = nvme_admin_download_fw;
        c.dlfw.numd = cpu_to_le32(dlfw.length);
        c.dlfw.offset = cpu_to_le32(dlfw.offset);
-       prps = nvme_setup_prps(dev, &c.common, sg, dlfw.length * 4);
-
-       status = nvme_submit_admin_cmd(dev, &c, NULL);
+       prps = nvme_setup_prps(dev, &c.common, sg, &length, GFP_KERNEL);
+       if (length != dlfw.length * 4)
+               status = -ENOMEM;
+       else
+               status = nvme_submit_admin_cmd(dev, &c, NULL);
        nvme_unmap_user_pages(dev, 0, dlfw.addr, dlfw.length * 4, sg, nents);
        nvme_free_prps(dev, prps);
        return status;