}
 }
 
-/**
- * of_msi_map_rid - Map a MSI requester ID for a device.
- * @dev: device for which the mapping is to be done.
- * @msi_np: device node of the expected msi controller.
- * @rid_in: unmapped MSI requester ID for the device.
- *
- * Walk up the device hierarchy looking for devices with a "msi-map"
- * property.  If found, apply the mapping to @rid_in.
- *
- * Returns the mapped MSI requester ID.
- */
-u32 of_msi_map_rid(struct device *dev, struct device_node *msi_np, u32 rid_in)
+static u32 __of_msi_map_rid(struct device *dev, struct device_node **np,
+                           u32 rid_in)
 {
        struct device *parent_dev;
        struct device_node *msi_controller_node;
+       struct device_node *msi_np = *np;
        u32 map_mask, masked_rid, rid_base, msi_base, rid_len, phandle;
        int msi_map_len;
        bool matched;
 
                msi_controller_node = of_find_node_by_phandle(phandle);
 
-               matched = masked_rid >= rid_base &&
-                       masked_rid < rid_base + rid_len &&
-                       msi_np == msi_controller_node;
+               matched = (masked_rid >= rid_base &&
+                          masked_rid < rid_base + rid_len);
+               if (msi_np)
+                       matched &= msi_np == msi_controller_node;
+
+               if (matched && !msi_np) {
+                       *np = msi_np = msi_controller_node;
+                       break;
+               }
 
                of_node_put(msi_controller_node);
                msi_map_len -= 4 * sizeof(__be32);
        return rid_out;
 }
 
+/**
+ * of_msi_map_rid - Map a MSI requester ID for a device.
+ * @dev: device for which the mapping is to be done.
+ * @msi_np: device node of the expected msi controller.
+ * @rid_in: unmapped MSI requester ID for the device.
+ *
+ * Walk up the device hierarchy looking for devices with a "msi-map"
+ * property.  If found, apply the mapping to @rid_in.
+ *
+ * Returns the mapped MSI requester ID.
+ */
+u32 of_msi_map_rid(struct device *dev, struct device_node *msi_np, u32 rid_in)
+{
+       return __of_msi_map_rid(dev, &msi_np, rid_in);
+}
+
 static struct irq_domain *__of_get_msi_domain(struct device_node *np,
                                              enum irq_domain_bus_token token)
 {