struct ib_sa_sm_ah *new_ah;
        struct ib_port_attr port_attr;
        struct rdma_ah_attr   ah_attr;
+       bool grh_required;
 
        if (ib_query_port(port->agent->device, port->port_num, &port_attr)) {
                pr_warn("Couldn't query port\n");
        rdma_ah_set_sl(&ah_attr, port_attr.sm_sl);
        rdma_ah_set_port_num(&ah_attr, port->port_num);
 
+       grh_required = rdma_is_grh_required(port->agent->device,
+                                           port->port_num);
+
        /*
         * The OPA sm_lid of 0xFFFF needs special handling so that it can be
         * differentiated from a permissive LID of 0xFFFF.  We set the
         * address handle appropriately
         */
        if (ah_attr.type == RDMA_AH_ATTR_TYPE_OPA &&
-           (port_attr.grh_required ||
+           (grh_required ||
             port_attr.sm_lid == be16_to_cpu(IB_LID_PERMISSIVE)))
                rdma_ah_set_make_grd(&ah_attr, true);
 
-       if (ah_attr.type == RDMA_AH_ATTR_TYPE_IB && port_attr.grh_required) {
+       if (ah_attr.type == RDMA_AH_ATTR_TYPE_IB && grh_required) {
                rdma_ah_set_ah_flags(&ah_attr, IB_AH_GRH);
                rdma_ah_set_subnet_prefix(&ah_attr,
                                          cpu_to_be64(port_attr.subnet_prefix));
 
        if (!rdma_is_port_valid(device, ah_attr->port_num))
                return -EINVAL;
 
-       if (ah_attr->type == RDMA_AH_ATTR_TYPE_ROCE &&
+       if ((rdma_is_grh_required(device, ah_attr->port_num) ||
+            ah_attr->type == RDMA_AH_ATTR_TYPE_ROCE) &&
            !(ah_attr->ah_flags & IB_AH_GRH))
                return -EINVAL;
 
 
        props->qkey_viol_cntr   = rep->qkey_violation_counter;
        props->subnet_timeout   = rep->subnet_timeout;
        props->init_type_reply  = rep->init_type_reply;
-       props->grh_required     = rep->grh_required;
 
        err = mlx5_query_port_link_width_oper(mdev, &ib_link_width_oper, port);
        if (err)
                cancel_work_sync(&devr->ports[port].pkey_change_work);
 }
 
-static u32 get_core_cap_flags(struct ib_device *ibdev)
+static u32 get_core_cap_flags(struct ib_device *ibdev,
+                             struct mlx5_hca_vport_context *rep)
 {
        struct mlx5_ib_dev *dev = to_mdev(ibdev);
        enum rdma_link_layer ll = mlx5_ib_port_link_layer(ibdev, 1);
        bool raw_support = !mlx5_core_mp_enabled(dev->mdev);
        u32 ret = 0;
 
+       if (rep->grh_required)
+               ret |= RDMA_CORE_CAP_IB_GRH_REQUIRED;
+
        if (ll == IB_LINK_LAYER_INFINIBAND)
-               return RDMA_CORE_PORT_IBA_IB;
+               return ret | RDMA_CORE_PORT_IBA_IB;
 
        if (raw_support)
-               ret = RDMA_CORE_PORT_RAW_PACKET;
+               ret |= RDMA_CORE_PORT_RAW_PACKET;
 
        if (!(l3_type_cap & MLX5_ROCE_L3_TYPE_IPV4_CAP))
                return ret;
        struct ib_port_attr attr;
        struct mlx5_ib_dev *dev = to_mdev(ibdev);
        enum rdma_link_layer ll = mlx5_ib_port_link_layer(ibdev, port_num);
+       struct mlx5_hca_vport_context rep = {0};
        int err;
 
-       immutable->core_cap_flags = get_core_cap_flags(ibdev);
-
        err = ib_query_port(ibdev, port_num, &attr);
        if (err)
                return err;
 
+       if (ll == IB_LINK_LAYER_INFINIBAND) {
+               err = mlx5_query_hca_vport_context(dev->mdev, 0, port_num, 0,
+                                                  &rep);
+               if (err)
+                       return err;
+       }
+
        immutable->pkey_tbl_len = attr.pkey_tbl_len;
        immutable->gid_tbl_len = attr.gid_tbl_len;
-       immutable->core_cap_flags = get_core_cap_flags(ibdev);
+       immutable->core_cap_flags = get_core_cap_flags(ibdev, &rep);
        if ((ll == IB_LINK_LAYER_INFINIBAND) || MLX5_CAP_GEN(dev->mdev, roce))
                immutable->max_mad_size = IB_MGMT_MAD_SIZE;
 
 
 #define RDMA_CORE_CAP_AF_IB             0x00001000
 #define RDMA_CORE_CAP_ETH_AH            0x00002000
 #define RDMA_CORE_CAP_OPA_AH            0x00004000
+#define RDMA_CORE_CAP_IB_GRH_REQUIRED   0x00008000
 
 /* Protocol                             0xFFF00000 */
 #define RDMA_CORE_CAP_PROT_IB           0x00100000
 #define RDMA_CORE_CAP_PROT_RAW_PACKET   0x01000000
 #define RDMA_CORE_CAP_PROT_USNIC        0x02000000
 
+#define RDMA_CORE_PORT_IB_GRH_REQUIRED (RDMA_CORE_CAP_IB_GRH_REQUIRED \
+                                       | RDMA_CORE_CAP_PROT_ROCE     \
+                                       | RDMA_CORE_CAP_PROT_ROCE_UDP_ENCAP)
+
 #define RDMA_CORE_PORT_IBA_IB          (RDMA_CORE_CAP_PROT_IB  \
                                        | RDMA_CORE_CAP_IB_MAD \
                                        | RDMA_CORE_CAP_IB_SMI \
        enum ib_mtu             max_mtu;
        enum ib_mtu             active_mtu;
        int                     gid_tbl_len;
-       unsigned int            grh_required:1;
        unsigned int            ip_gids:1;
        /* This is the value from PortInfo CapabilityMask, defined by IBA */
        u32                     port_cap_flags;
                port <= rdma_end_port(device));
 }
 
+static inline bool rdma_is_grh_required(const struct ib_device *device,
+                                       u8 port_num)
+{
+       return device->port_immutable[port_num].core_cap_flags &
+               RDMA_CORE_PORT_IB_GRH_REQUIRED;
+}
+
 static inline bool rdma_protocol_ib(const struct ib_device *device, u8 port_num)
 {
        return device->port_immutable[port_num].core_cap_flags & RDMA_CORE_CAP_PROT_IB;