return NULL;
 }
 
-void iommu_dma_map_msi_msg(int irq, struct msi_msg *msg)
+int iommu_dma_prepare_msi(struct msi_desc *desc, phys_addr_t msi_addr)
 {
-       struct device *dev = msi_desc_to_dev(irq_get_msi_desc(irq));
+       struct device *dev = msi_desc_to_dev(desc);
        struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
        struct iommu_dma_cookie *cookie;
        struct iommu_dma_msi_page *msi_page;
-       phys_addr_t msi_addr = (u64)msg->address_hi << 32 | msg->address_lo;
        unsigned long flags;
 
-       if (!domain || !domain->iova_cookie)
-               return;
+       if (!domain || !domain->iova_cookie) {
+               desc->iommu_cookie = NULL;
+               return 0;
+       }
 
        cookie = domain->iova_cookie;
 
        msi_page = iommu_dma_get_msi_page(dev, msi_addr, domain);
        spin_unlock_irqrestore(&cookie->msi_lock, flags);
 
-       if (WARN_ON(!msi_page)) {
+       msi_desc_set_iommu_cookie(desc, msi_page);
+
+       if (!msi_page)
+               return -ENOMEM;
+       return 0;
+}
+
+void iommu_dma_compose_msi_msg(struct msi_desc *desc,
+                              struct msi_msg *msg)
+{
+       struct device *dev = msi_desc_to_dev(desc);
+       const struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
+       const struct iommu_dma_msi_page *msi_page;
+
+       msi_page = msi_desc_get_iommu_cookie(desc);
+
+       if (!domain || !domain->iova_cookie || WARN_ON(!msi_page))
+               return;
+
+       msg->address_hi = upper_32_bits(msi_page->iova);
+       msg->address_lo &= cookie_msi_granule(domain->iova_cookie) - 1;
+       msg->address_lo += lower_32_bits(msi_page->iova);
+}
+
+void iommu_dma_map_msi_msg(int irq, struct msi_msg *msg)
+{
+       struct msi_desc *desc = irq_get_msi_desc(irq);
+       phys_addr_t msi_addr = (u64)msg->address_hi << 32 | msg->address_lo;
+
+       if (WARN_ON(iommu_dma_prepare_msi(desc, msi_addr))) {
                /*
                 * We're called from a void callback, so the best we can do is
                 * 'fail' by filling the message with obviously bogus values.
                msg->address_lo = ~0U;
                msg->data = ~0U;
        } else {
-               msg->address_hi = upper_32_bits(msi_page->iova);
-               msg->address_lo &= cookie_msi_granule(cookie) - 1;
-               msg->address_lo += lower_32_bits(msi_page->iova);
+               iommu_dma_compose_msi_msg(desc, msg);
        }
 }
 
                size_t size, enum dma_data_direction dir, unsigned long attrs);
 
 /* The DMA API isn't _quite_ the whole story, though... */
+/*
+ * iommu_dma_prepare_msi() - Map the MSI page in the IOMMU device
+ *
+ * The MSI page will be stored in @desc.
+ *
+ * Return: 0 on success otherwise an error describing the failure.
+ */
+int iommu_dma_prepare_msi(struct msi_desc *desc, phys_addr_t msi_addr);
+
+/* Update the MSI message if required. */
+void iommu_dma_compose_msi_msg(struct msi_desc *desc,
+                              struct msi_msg *msg);
+
 void iommu_dma_map_msi_msg(int irq, struct msi_msg *msg);
 void iommu_dma_get_resv_regions(struct device *dev, struct list_head *list);
 
 #else
 
 struct iommu_domain;
+struct msi_desc;
 struct msi_msg;
 struct device;
 
 {
 }
 
+static inline int iommu_dma_prepare_msi(struct msi_desc *desc,
+                                       phys_addr_t msi_addr)
+{
+       return 0;
+}
+
+static inline void iommu_dma_compose_msi_msg(struct msi_desc *desc,
+                                            struct msi_msg *msg)
+{
+}
+
 static inline void iommu_dma_map_msi_msg(int irq, struct msi_msg *msg)
 {
 }