prdt_length = le16_to_cpu(
                        lrbp->utr_descriptor_ptr->prd_table_length);
                if (hba->quirks & UFSHCD_QUIRK_PRDT_BYTE_GRAN)
-                       prdt_length /= sizeof(struct ufshcd_sg_entry);
+                       prdt_length /= ufshcd_sg_entry_size(hba);
 
                dev_err(hba->dev,
                        "UPIU[%d] - PRDT - %d entries  phys@0x%llx\n",
 
                if (pr_prdt)
                        ufshcd_hex_dump("UPIU PRDT: ", lrbp->ucd_prdt_ptr,
-                               sizeof(struct ufshcd_sg_entry) * prdt_length);
+                               ufshcd_sg_entry_size(hba) * prdt_length);
        }
 }
 
  */
 static int ufshcd_map_sg(struct ufs_hba *hba, struct ufshcd_lrb *lrbp)
 {
-       struct ufshcd_sg_entry *prd_table;
+       struct ufshcd_sg_entry *prd;
        struct scatterlist *sg;
        struct scsi_cmnd *cmd;
        int sg_segments;
 
                if (hba->quirks & UFSHCD_QUIRK_PRDT_BYTE_GRAN)
                        lrbp->utr_descriptor_ptr->prd_table_length =
-                               cpu_to_le16((sg_segments *
-                                       sizeof(struct ufshcd_sg_entry)));
+                               cpu_to_le16(sg_segments * ufshcd_sg_entry_size(hba));
                else
                        lrbp->utr_descriptor_ptr->prd_table_length =
                                cpu_to_le16(sg_segments);
 
-               prd_table = lrbp->ucd_prdt_ptr;
+               prd = lrbp->ucd_prdt_ptr;
 
                scsi_for_each_sg(cmd, sg, sg_segments, i) {
                        const unsigned int len = sg_dma_len(sg);
                         * indicates 4 bytes, '7' indicates 8 bytes, etc."
                         */
                        WARN_ONCE(len > 256 * 1024, "len = %#x\n", len);
-                       prd_table[i].size = cpu_to_le32(len - 1);
-                       prd_table[i].addr = cpu_to_le64(sg->dma_address);
-                       prd_table[i].reserved = 0;
+                       prd->size = cpu_to_le32(len - 1);
+                       prd->addr = cpu_to_le64(sg->dma_address);
+                       prd->reserved = 0;
+                       prd = (void *)prd + ufshcd_sg_entry_size(hba);
                }
        } else {
                lrbp->utr_descriptor_ptr->prd_table_length = 0;
 
 static void ufshcd_init_lrb(struct ufs_hba *hba, struct ufshcd_lrb *lrb, int i)
 {
-       struct utp_transfer_cmd_desc *cmd_descp = hba->ucdl_base_addr;
+       /*
+        * Command descriptors may now carry a variant-sized PRDT, so slot i
+        * is located by byte-stride arithmetic instead of pointer indexing
+        * (cmd_descp[i] would use the compile-time struct size).
+        */
+       struct utp_transfer_cmd_desc *cmd_descp = (void *)hba->ucdl_base_addr +
+               i * sizeof_utp_transfer_cmd_desc(hba);
        struct utp_transfer_req_desc *utrdlp = hba->utrdl_base_addr;
        dma_addr_t cmd_desc_element_addr = hba->ucdl_dma_addr +
-               i * sizeof(struct utp_transfer_cmd_desc);
+               i * sizeof_utp_transfer_cmd_desc(hba);
+       /*
+        * Offsets of the fixed-size header fields are unchanged: only the
+        * trailing prd_table varies per host, so offsetof() is still valid.
+        */
        u16 response_offset = offsetof(struct utp_transfer_cmd_desc,
                                       response_upiu);
        u16 prdt_offset = offsetof(struct utp_transfer_cmd_desc, prd_table);
        lrb->utr_descriptor_ptr = utrdlp + i;
        lrb->utrd_dma_addr = hba->utrdl_dma_addr +
                i * sizeof(struct utp_transfer_req_desc);
-       lrb->ucd_req_ptr = (struct utp_upiu_req *)(cmd_descp + i);
+       /* cmd_descp already points at slot i; do not index it again. */
+       lrb->ucd_req_ptr = (struct utp_upiu_req *)cmd_descp->command_upiu;
        lrb->ucd_req_dma_addr = cmd_desc_element_addr;
-       lrb->ucd_rsp_ptr = (struct utp_upiu_rsp *)cmd_descp[i].response_upiu;
+       lrb->ucd_rsp_ptr = (struct utp_upiu_rsp *)cmd_descp->response_upiu;
        lrb->ucd_rsp_dma_addr = cmd_desc_element_addr + response_offset;
-       lrb->ucd_prdt_ptr = cmd_descp[i].prd_table;
+       lrb->ucd_prdt_ptr = (struct ufshcd_sg_entry *)cmd_descp->prd_table;
        lrb->ucd_prdt_dma_addr = cmd_desc_element_addr + prdt_offset;
 }
 
        size_t utmrdl_size, utrdl_size, ucdl_size;
 
        /* Allocate memory for UTP command descriptors */
-       ucdl_size = (sizeof(struct utp_transfer_cmd_desc) * hba->nutrs);
+       ucdl_size = sizeof_utp_transfer_cmd_desc(hba) * hba->nutrs;
        hba->ucdl_base_addr = dmam_alloc_coherent(hba->dev,
                                                  ucdl_size,
                                                  &hba->ucdl_dma_addr,
        prdt_offset =
                offsetof(struct utp_transfer_cmd_desc, prd_table);
 
-       cmd_desc_size = sizeof(struct utp_transfer_cmd_desc);
+       cmd_desc_size = sizeof_utp_transfer_cmd_desc(hba);
        cmd_desc_dma_addr = hba->ucdl_dma_addr;
 
        for (i = 0; i < hba->nutrs; i++) {
        hba->dev = dev;
        hba->dev_ref_clk_freq = REF_CLK_FREQ_INVAL;
        hba->nop_out_timeout = NOP_OUT_TIMEOUT;
+       ufshcd_set_sg_entry_size(hba, sizeof(struct ufshcd_sg_entry));
        INIT_LIST_HEAD(&hba->clk_list_head);
        spin_lock_init(&hba->outstanding_lock);
 
 {
        int ret;
 
-       /* Verify that there are no gaps in struct utp_transfer_cmd_desc. */
-       static_assert(sizeof(struct utp_transfer_cmd_desc) ==
-                     2 * ALIGNED_UPIU_SIZE +
-                             SG_ALL * sizeof(struct ufshcd_sg_entry));
-
        ufs_debugfs_init();
 
        ret = scsi_register_driver(&ufs_dev_wlun_template.gendrv);
 
  * @vops: pointer to variant specific operations
  * @vps: pointer to variant specific parameters
  * @priv: pointer to variant specific private data
+ * @sg_entry_size: size of struct ufshcd_sg_entry (may include variant fields)
  * @irq: Irq number of the controller
  * @is_irq_enabled: whether or not the UFS controller interrupt is enabled.
  * @dev_ref_clk_freq: reference clock frequency
        const struct ufs_hba_variant_ops *vops;
        struct ufs_hba_variant_params *vps;
        void *priv;
+#ifdef CONFIG_SCSI_UFS_VARIABLE_SG_ENTRY_SIZE
+       size_t sg_entry_size;
+#endif
        unsigned int irq;
        bool is_irq_enabled;
        enum ufs_ref_clk_freq dev_ref_clk_freq;
        bool complete_put;
 };
 
+#ifdef CONFIG_SCSI_UFS_VARIABLE_SG_ENTRY_SIZE
+/* Size of one PRDT entry for this host, including any variant-specific
+ * fields appended after the standard struct ufshcd_sg_entry.
+ */
+static inline size_t ufshcd_sg_entry_size(const struct ufs_hba *hba)
+{
+       return hba->sg_entry_size;
+}
+
+/* Set the per-host PRDT entry size. The size must cover at least the
+ * standard struct ufshcd_sg_entry; anything smaller is a driver bug.
+ */
+static inline void ufshcd_set_sg_entry_size(struct ufs_hba *hba, size_t sg_entry_size)
+{
+       WARN_ON_ONCE(sg_entry_size < sizeof(struct ufshcd_sg_entry));
+       hba->sg_entry_size = sg_entry_size;
+}
+#else
+static inline size_t ufshcd_sg_entry_size(const struct ufs_hba *hba)
+{
+       return sizeof(struct ufshcd_sg_entry);
+}
+
+/* Without CONFIG_SCSI_UFS_VARIABLE_SG_ENTRY_SIZE only the standard entry
+ * size is permitted, enforced at compile time. The macro argument is
+ * parenthesized so expression arguments (e.g. a conditional) parse as the
+ * caller intended.
+ */
+#define ufshcd_set_sg_entry_size(hba, sg_entry_size)                   \
+       ({ (void)(hba); BUILD_BUG_ON((sg_entry_size) != sizeof(struct ufshcd_sg_entry)); })
+#endif
+
+/* Total size of one UTP transfer command descriptor: the fixed header plus
+ * room for SG_ALL PRDT entries at this host's (possibly variant) entry size.
+ * NOTE(review): this only avoids double-counting if prd_table was converted
+ * to a flexible array member in struct utp_transfer_cmd_desc — that struct
+ * change is not visible in this hunk; confirm against the full patch.
+ */
+static inline size_t sizeof_utp_transfer_cmd_desc(const struct ufs_hba *hba)
+{
+       return sizeof(struct utp_transfer_cmd_desc) + SG_ALL * ufshcd_sg_entry_size(hba);
+}
+
 /* Returns true if clocks can be gated. Otherwise false */
 static inline bool ufshcd_is_clkgating_allowed(struct ufs_hba *hba)
 {