#include <linux/mm.h>
 #include <linux/xarray.h>
 #include <linux/cdx/cdx_bus.h>
+#include <linux/iommu.h>
+#include <linux/dma-map-ops.h>
 #include "cdx.h"
 
 /* Default DMA mask for devices on a CDX bus */
 
 static int cdx_dma_configure(struct device *dev)
 {
+       struct cdx_driver *cdx_drv = to_cdx_driver(dev->driver);
        struct cdx_device *cdx_dev = to_cdx_device(dev);
        u32 input_id = cdx_dev->req_id;
        int ret;
                return ret;
        }
 
+       if (!ret && !cdx_drv->driver_managed_dma) {
+               ret = iommu_device_use_default_domain(dev);
+               if (ret)
+                       arch_teardown_dma_ops(dev);
+       }
+
        return 0;
 }
 
+static void cdx_dma_cleanup(struct device *dev)
+{
+       struct cdx_driver *cdx_drv = to_cdx_driver(dev->driver);
+
+       if (!cdx_drv->driver_managed_dma)
+               iommu_device_unuse_default_domain(dev);
+}
+
 /* show configuration fields */
 #define cdx_config_attr(field, format_string)  \
 static ssize_t \
        .remove         = cdx_remove,
        .shutdown       = cdx_shutdown,
        .dma_configure  = cdx_dma_configure,
+       .dma_cleanup    = cdx_dma_cleanup,
        .bus_groups     = cdx_bus_groups,
        .dev_groups     = cdx_dev_groups,
 };