return get_unaligned_be16(&buffer[2]) + 4;
 }
 
+static int scsi_get_vpd_size(struct scsi_device *sdev, u8 page)
+{
+       unsigned char vpd_header[SCSI_VPD_HEADER_SIZE] __aligned(4);
+       int result;
+
+       /*
+        * Fetch the VPD page header to find out how big the page
+        * is. This is done to prevent problems on legacy devices
+        * which can not handle allocation lengths as large as
+        * potentially requested by the caller.
+        */
+       result = scsi_vpd_inquiry(sdev, vpd_header, page, sizeof(vpd_header));
+       if (result < 0)
+               return 0;
+
+       if (result < SCSI_VPD_HEADER_SIZE) {
+               dev_warn_once(&sdev->sdev_gendev,
+                             "%s: short VPD page 0x%02x length: %d bytes\n",
+                             __func__, page, result);
+               return 0;
+       }
+
+       return result;
+}
+
 /**
  * scsi_get_vpd_page - Get Vital Product Data from a SCSI device
  * @sdev: The device to ask
  *
  * SCSI devices may optionally supply Vital Product Data.  Each 'page'
  * of VPD is defined in the appropriate SCSI document (eg SPC, SBC).
- * If the device supports this VPD page, this routine returns a pointer
- * to a buffer containing the data from that page.  The caller is
- * responsible for calling kfree() on this pointer when it is no longer
- * needed.  If we cannot retrieve the VPD page this routine returns %NULL.
+ * If the device supports this VPD page, this routine fills @buf
+ * with the data from that page and return 0. If the VPD page is not
+ * supported or its content cannot be retrieved, -EINVAL is returned.
  */
 int scsi_get_vpd_page(struct scsi_device *sdev, u8 page, unsigned char *buf,
                      int buf_len)
 {
-       int i, result;
-
-       if (sdev->skip_vpd_pages)
-               goto fail;
+       int result, vpd_len;
 
-       /* Ask for all the pages supported by this device */
-       result = scsi_vpd_inquiry(sdev, buf, 0, buf_len);
-       if (result < 4)
-               goto fail;
-
-       /* If the user actually wanted this page, we can skip the rest */
-       if (page == 0)
-               return 0;
+       if (!scsi_device_supports_vpd(sdev))
+               return -EINVAL;
 
-       for (i = 4; i < min(result, buf_len); i++)
-               if (buf[i] == page)
-                       goto found;
+       vpd_len = scsi_get_vpd_size(sdev, page);
+       if (vpd_len <= 0)
+               return -EINVAL;
 
-       if (i < result && i >= buf_len)
-               /* ran off the end of the buffer, give us benefit of doubt */
-               goto found;
-       /* The device claims it doesn't support the requested page */
-       goto fail;
+       vpd_len = min(vpd_len, buf_len);
 
- found:
-       result = scsi_vpd_inquiry(sdev, buf, page, buf_len);
+       /*
+        * Fetch the actual page. Since the appropriate size was reported
+        * by the device it is now safe to ask for something bigger.
+        */
+       memset(buf, 0, buf_len);
+       result = scsi_vpd_inquiry(sdev, buf, page, vpd_len);
        if (result < 0)
-               goto fail;
+               return -EINVAL;
+       else if (result > vpd_len)
+               dev_warn_once(&sdev->sdev_gendev,
+                             "%s: VPD page 0x%02x result %d > %d bytes\n",
+                             __func__, page, result, vpd_len);
 
        return 0;
-
- fail:
-       return -EINVAL;
 }
 EXPORT_SYMBOL_GPL(scsi_get_vpd_page);
 
 static struct scsi_vpd *scsi_get_vpd_buf(struct scsi_device *sdev, u8 page)
 {
        struct scsi_vpd *vpd_buf;
-       int vpd_len = SCSI_VPD_PG_LEN, result;
+       int vpd_len, result;
+
+       vpd_len = scsi_get_vpd_size(sdev, page);
+       if (vpd_len <= 0)
+               return NULL;
 
 retry_pg:
+       /*
+        * Fetch the actual page. Since the appropriate size was reported
+        * by the device it is now safe to ask for something bigger.
+        */
        vpd_buf = kmalloc(sizeof(*vpd_buf) + vpd_len, GFP_KERNEL);
        if (!vpd_buf)
                return NULL;
                return NULL;
        }
        if (result > vpd_len) {
+               dev_warn_once(&sdev->sdev_gendev,
+                             "%s: VPD page 0x%02x result %d > %d bytes\n",
+                             __func__, page, result, vpd_len);
                vpd_len = result;
                kfree(vpd_buf);
                goto retry_pg;