}
 
 static void
-nvkm_gsp_mem_dtor(struct nvkm_gsp *gsp, struct nvkm_gsp_mem *mem)
+nvkm_gsp_mem_dtor(struct nvkm_gsp_mem *mem)
 {
        if (mem->data) {
                /*
                 */
                memset(mem->data, 0xFF, mem->size);
 
-               dma_free_coherent(gsp->subdev.device->dev, mem->size, mem->data, mem->addr);
+               dma_free_coherent(mem->dev, mem->size, mem->data, mem->addr);
+               put_device(mem->dev);
+
                memset(mem, 0, sizeof(*mem));
        }
 }
 
+/**
+ * nvkm_gsp_mem_ctor - constructor for nvkm_gsp_mem objects
+ * @gsp: gsp pointer
+ * @size: number of bytes to allocate
+ * @mem: nvkm_gsp_mem object to initialize
+ *
+ * Allocates a block of memory for use with GSP.
+ *
+ * This memory block can potentially out-live the driver's remove() callback,
+ * so we take a device reference to ensure its lifetime. The reference is
+ * dropped in the destructor.
+ */
 static int
 nvkm_gsp_mem_ctor(struct nvkm_gsp *gsp, size_t size, struct nvkm_gsp_mem *mem)
 {
-       mem->size = size;
        mem->data = dma_alloc_coherent(gsp->subdev.device->dev, size, &mem->addr, GFP_KERNEL);
        if (WARN_ON(!mem->data))
                return -ENOMEM;
 
+       mem->size = size;
+       mem->dev = get_device(gsp->subdev.device->dev);
+
        return 0;
 }
 
        nvkm_wr32(device, 0x110004, 0x00000040);
 
        /* Release the DMA buffers that were needed only for boot and init */
-       nvkm_gsp_mem_dtor(gsp, &gsp->boot.fw);
-       nvkm_gsp_mem_dtor(gsp, &gsp->libos);
+       nvkm_gsp_mem_dtor(&gsp->boot.fw);
+       nvkm_gsp_mem_dtor(&gsp->libos);
 
        return ret;
 }
 nvkm_gsp_radix3_dtor(struct nvkm_gsp *gsp, struct nvkm_gsp_radix3 *rx3)
 {
        nvkm_gsp_sg_free(gsp->subdev.device, &rx3->lvl2);
-       nvkm_gsp_mem_dtor(gsp, &rx3->lvl1);
-       nvkm_gsp_mem_dtor(gsp, &rx3->lvl0);
+       nvkm_gsp_mem_dtor(&rx3->lvl1);
+       nvkm_gsp_mem_dtor(&rx3->lvl0);
 }
 
 /**
 
        if (ret) {
 lvl2_fail:
-               nvkm_gsp_mem_dtor(gsp, &rx3->lvl1);
+               nvkm_gsp_mem_dtor(&rx3->lvl1);
 lvl1_fail:
-               nvkm_gsp_mem_dtor(gsp, &rx3->lvl0);
+               nvkm_gsp_mem_dtor(&rx3->lvl0);
        }
 
        return ret;
 
 done:
        if (gsp->sr.meta.data) {
-               nvkm_gsp_mem_dtor(gsp, &gsp->sr.meta);
+               nvkm_gsp_mem_dtor(&gsp->sr.meta);
                nvkm_gsp_radix3_dtor(gsp, &gsp->sr.radix3);
                nvkm_gsp_sg_free(gsp->subdev.device, &gsp->sr.sgt);
                return ret;
        mutex_destroy(&gsp->client_id.mutex);
 
        nvkm_gsp_radix3_dtor(gsp, &gsp->radix3);
-       nvkm_gsp_mem_dtor(gsp, &gsp->sig);
+       nvkm_gsp_mem_dtor(&gsp->sig);
        nvkm_firmware_dtor(&gsp->fw);
 
        nvkm_falcon_fw_dtor(&gsp->booter.unload);
 
        r535_gsp_dtor_fws(gsp);
 
-       nvkm_gsp_mem_dtor(gsp, &gsp->rmargs);
-       nvkm_gsp_mem_dtor(gsp, &gsp->wpr_meta);
-       nvkm_gsp_mem_dtor(gsp, &gsp->shm.mem);
-       nvkm_gsp_mem_dtor(gsp, &gsp->loginit);
-       nvkm_gsp_mem_dtor(gsp, &gsp->logintr);
-       nvkm_gsp_mem_dtor(gsp, &gsp->logrm);
+       nvkm_gsp_mem_dtor(&gsp->rmargs);
+       nvkm_gsp_mem_dtor(&gsp->wpr_meta);
+       nvkm_gsp_mem_dtor(&gsp->shm.mem);
+       nvkm_gsp_mem_dtor(&gsp->loginit);
+       nvkm_gsp_mem_dtor(&gsp->logintr);
+       nvkm_gsp_mem_dtor(&gsp->logrm);
 }
 
 int