#include <linux/etherdevice.h>
 #include <linux/interrupt.h>
 #include <linux/kernel.h>
+#include <linux/types.h>
 #include <net/addrconf.h>
 #include <rdma/ib_umem.h>
 
        dseg->len  = cpu_to_le32(sg->length);
 }
 
+static void set_extend_sge(struct hns_roce_qp *qp, struct ib_send_wr *wr,
+                          unsigned int *sge_ind)
+{
+       struct hns_roce_v2_wqe_data_seg *dseg;
+       struct ib_sge *sg;
+       int num_in_wqe = 0;
+       int extend_sge_num;
+       int fi_sge_num;
+       int se_sge_num;
+       int shift;
+       int i;
+
+       if (qp->ibqp.qp_type == IB_QPT_RC || qp->ibqp.qp_type == IB_QPT_UC)
+               num_in_wqe = HNS_ROCE_V2_UC_RC_SGE_NUM_IN_WQE;
+       extend_sge_num = wr->num_sge - num_in_wqe;
+       sg = wr->sg_list + num_in_wqe;
+       shift = qp->hr_buf.page_shift;
+
+       /*
+        * Check whether wr->num_sge sges are in the same page. If not, we
+        * should calculate how many sges in the first page and the second
+        * page.
+        */
+       dseg = get_send_extend_sge(qp, (*sge_ind) & (qp->sge.sge_cnt - 1));
+       fi_sge_num = (round_up((uintptr_t)dseg, 1 << shift) -
+                     (uintptr_t)dseg) /
+                     sizeof(struct hns_roce_v2_wqe_data_seg);
+       if (extend_sge_num > fi_sge_num) {
+               se_sge_num = extend_sge_num - fi_sge_num;
+               for (i = 0; i < fi_sge_num; i++) {
+                       set_data_seg_v2(dseg++, sg + i);
+                       (*sge_ind)++;
+               }
+               dseg = get_send_extend_sge(qp,
+                                          (*sge_ind) & (qp->sge.sge_cnt - 1));
+               for (i = 0; i < se_sge_num; i++) {
+                       set_data_seg_v2(dseg++, sg + fi_sge_num + i);
+                       (*sge_ind)++;
+               }
+       } else {
+               for (i = 0; i < extend_sge_num; i++) {
+                       set_data_seg_v2(dseg++, sg + i);
+                       (*sge_ind)++;
+               }
+       }
+}
+
 static int set_rwqe_data_seg(struct ib_qp *ibqp, struct ib_send_wr *wr,
                             struct hns_roce_v2_rc_send_wqe *rc_sq_wqe,
                             void *wqe, unsigned int *sge_ind,
                roce_set_bit(rc_sq_wqe->byte_4, V2_RC_SEND_WQE_BYTE_4_INLINE_S,
                             1);
        } else {
-               if (wr->num_sge <= 2) {
+               if (wr->num_sge <= HNS_ROCE_V2_UC_RC_SGE_NUM_IN_WQE) {
                        for (i = 0; i < wr->num_sge; i++) {
                                if (likely(wr->sg_list[i].length)) {
                                        set_data_seg_v2(dseg, wr->sg_list + i);
                                     V2_RC_SEND_WQE_BYTE_20_MSG_START_SGE_IDX_S,
                                     (*sge_ind) & (qp->sge.sge_cnt - 1));
 
-                       for (i = 0; i < 2; i++) {
+                       for (i = 0; i < HNS_ROCE_V2_UC_RC_SGE_NUM_IN_WQE; i++) {
                                if (likely(wr->sg_list[i].length)) {
                                        set_data_seg_v2(dseg, wr->sg_list + i);
                                        dseg++;
                                }
                        }
 
-                       dseg = get_send_extend_sge(qp,
-                                           (*sge_ind) & (qp->sge.sge_cnt - 1));
-
-                       for (i = 0; i < wr->num_sge - 2; i++) {
-                               if (likely(wr->sg_list[i + 2].length)) {
-                                       set_data_seg_v2(dseg,
-                                                       wr->sg_list + 2 + i);
-                                       dseg++;
-                                       (*sge_ind)++;
-                               }
-                       }
+                       set_extend_sge(qp, wr, sge_ind);
                }
 
                roce_set_field(rc_sq_wqe->byte_16,
                        memcpy(&ud_sq_wqe->dgid[0], &ah->av.dgid[0],
                               GID_LEN_V2);
 
-                       dseg = get_send_extend_sge(qp,
-                                           sge_ind & (qp->sge.sge_cnt - 1));
-                       for (i = 0; i < wr->num_sge; i++) {
-                               set_data_seg_v2(dseg + i, wr->sg_list + i);
-                               sge_ind++;
-                       }
-
+                       set_extend_sge(qp, wr, &sge_ind);
                        ind++;
                } else if (ibqp->qp_type == IB_QPT_RC) {
                        rc_sq_wqe = wqe;