#include <linux/clk.h>
 #include <linux/component.h>
 #include <linux/device.h>
+#include <linux/dma-direct.h>
 #include <linux/dma-iommu.h>
 #include <linux/err.h>
 #include <linux/interrupt.h>
        return IRQ_HANDLED;
 }
 
+static int mtk_iommu_get_domain_id(struct device *dev,
+                                  const struct mtk_iommu_plat_data *plat_data)
+{
+       const struct mtk_iommu_iova_region *rgn = plat_data->iova_region;
+       const struct bus_dma_region *dma_rgn = dev->dma_range_map;
+       int i, candidate = -1;
+       dma_addr_t dma_end;
+
+       if (!dma_rgn || plat_data->iova_region_nr == 1)
+               return 0;
+
+       dma_end = dma_rgn->dma_start + dma_rgn->size - 1;
+       for (i = 0; i < plat_data->iova_region_nr; i++, rgn++) {
+               /* Best fit. */
+               if (dma_rgn->dma_start == rgn->iova_base &&
+                   dma_end == rgn->iova_base + rgn->size - 1)
+                       return i;
+               /* ok if it is inside this region. */
+               if (dma_rgn->dma_start >= rgn->iova_base &&
+                   dma_end < rgn->iova_base + rgn->size)
+                       candidate = i;
+       }
+
+       if (candidate >= 0)
+               return candidate;
+       dev_err(dev, "Can NOT find the iommu domain id(%pad 0x%llx).\n",
+               &dma_rgn->dma_start, dma_rgn->size);
+       return -EINVAL;
+}
+
 static void mtk_iommu_config(struct mtk_iommu_data *data,
                             struct device *dev, bool enable)
 {
        struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
        struct mtk_iommu_domain *dom = to_mtk_domain(domain);
        struct device *m4udev = data->dev;
-       int ret;
+       int ret, domid;
 
        if (!data)
                return -ENODEV;
 
+       domid = mtk_iommu_get_domain_id(dev, data->plat_data);
+       if (domid < 0)
+               return domid;
+
        if (!dom->data) {
                if (mtk_iommu_domain_finalise(dom, data))
                        return -ENODEV;
 static struct iommu_group *mtk_iommu_device_group(struct device *dev)
 {
        struct mtk_iommu_data *data = mtk_iommu_get_m4u_data();
+       int domid;
 
        if (!data)
                return ERR_PTR(-ENODEV);
 
+       domid = mtk_iommu_get_domain_id(dev, data->plat_data);
+       if (domid < 0)
+               return ERR_PTR(domid);
+
        /* All the client devices are in the same m4u iommu-group */
        if (!data->m4u_group) {
                data->m4u_group = iommu_group_alloc();