offset += va_block->size;
        }
 
-       hdev->asic_funcs->mmu_invalidate_cache(hdev, false,
-               MMU_OP_USERPTR | MMU_OP_SKIP_LOW_CACHE_INV);
+       rc = hl_mmu_invalidate_cache(hdev, false, MMU_OP_USERPTR | MMU_OP_SKIP_LOW_CACHE_INV);
 
        mutex_unlock(&ctx->mmu_lock);
 
        cb->is_mmu_mapped = true;
 
-       return 0;
+       return rc;
 
 err_va_umap:
        list_for_each_entry(va_block, &cb->va_block_list, node) {
                offset -= va_block->size;
        }
 
-       hdev->asic_funcs->mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
+       hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
 
        mutex_unlock(&ctx->mmu_lock);
 
                                        "Failed to unmap CB's va 0x%llx\n",
                                        va_block->start);
 
-       hdev->asic_funcs->mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
+       hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
 
        mutex_unlock(&ctx->mmu_lock);
 
 
 int hl_mmu_map_contiguous(struct hl_ctx *ctx, u64 virt_addr,
                                        u64 phys_addr, u32 size);
 int hl_mmu_unmap_contiguous(struct hl_ctx *ctx, u64 virt_addr, u32 size);
+int hl_mmu_invalidate_cache(struct hl_device *hdev, bool is_hard, u32 flags);
+int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard,
+                                       u32 flags, u32 asid, u64 va, u64 size);
 void hl_mmu_swap_out(struct hl_ctx *ctx);
 void hl_mmu_swap_in(struct hl_ctx *ctx);
 int hl_mmu_if_set_funcs(struct hl_device *hdev);
 
                goto map_err;
        }
 
-       rc = hdev->asic_funcs->mmu_invalidate_cache_range(hdev, false,
-               *vm_type | MMU_OP_SKIP_LOW_CACHE_INV,
-               ctx->asid, ret_vaddr, phys_pg_pack->total_size);
+       rc = hl_mmu_invalidate_cache_range(hdev, false, *vm_type | MMU_OP_SKIP_LOW_CACHE_INV,
+                               ctx->asid, ret_vaddr, phys_pg_pack->total_size);
 
        mutex_unlock(&ctx->mmu_lock);
 
-       if (rc) {
-               dev_err(hdev->dev,
-                       "mapping handle %u failed due to MMU cache invalidation\n",
-                       handle);
+       if (rc)
                goto map_err;
-       }
 
        ret_vaddr += phys_pg_pack->offset;
 
         * at the loop end rather than for each iteration
         */
        if (!ctx_free)
-               rc = hdev->asic_funcs->mmu_invalidate_cache_range(hdev, true,
-                               *vm_type, ctx->asid, vaddr,
-                               phys_pg_pack->total_size);
+               rc = hl_mmu_invalidate_cache_range(hdev, true, *vm_type, ctx->asid, vaddr,
+                                                       phys_pg_pack->total_size);
 
        mutex_unlock(&ctx->mmu_lock);
 
        if (!ctx_free) {
                int tmp_rc;
 
-               if (rc)
-                       dev_err(hdev->dev,
-                               "unmapping vaddr 0x%llx failed due to MMU cache invalidation\n",
-                               vaddr);
-
                tmp_rc = add_va_block(hdev, va_range, vaddr,
                                        vaddr + phys_pg_pack->total_size - 1);
                if (tmp_rc) {
        mutex_lock(&ctx->mmu_lock);
 
        /* invalidate the cache once after the unmapping loop */
-       hdev->asic_funcs->mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
-       hdev->asic_funcs->mmu_invalidate_cache(hdev, true, MMU_OP_PHYS_PACK);
+       hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
+       hl_mmu_invalidate_cache(hdev, true, MMU_OP_PHYS_PACK);
 
        mutex_unlock(&ctx->mmu_lock);
 
 
 {
        return addr;
 }
+
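+/**
+ * hl_mmu_invalidate_cache() - invalidate the whole MMU cache.
+ * @hdev: habanalabs device structure.
+ * @is_hard: is a hard invalidation required.
+ * @flags: MMU invalidation flags (e.g. MMU_OP_USERPTR, MMU_OP_PHYS_PACK).
+ *
+ * Calls the ASIC-specific invalidation callback and prints a rate-limited
+ * error message on failure, so callers no longer need to do it themselves.
+ */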
+int hl_mmu_invalidate_cache(struct hl_device *hdev, bool is_hard, u32 flags)
+{
+       int rc;
+
+       rc = hdev->asic_funcs->mmu_invalidate_cache(hdev, is_hard, flags);
+       if (rc)
+               dev_err_ratelimited(hdev->dev, "MMU cache invalidation failed, rc=%d\n", rc);
+
+       return rc;
+}
+
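+/**
+ * hl_mmu_invalidate_cache_range() - invalidate an MMU cache range.
+ * @hdev: habanalabs device structure.
+ * @is_hard: is a hard invalidation required.
+ * @flags: MMU invalidation flags.
+ * @asid: ASID of the context that owns the range.
+ * @va: start virtual address of the range to invalidate.
+ * @size: size of the range in bytes.
+ *
+ * Same as hl_mmu_invalidate_cache() but for a specific address range.
+ */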
+int hl_mmu_invalidate_cache_range(struct hl_device *hdev, bool is_hard,
+                                       u32 flags, u32 asid, u64 va, u64 size)
+{
+       int rc;
+
+       rc = hdev->asic_funcs->mmu_invalidate_cache_range(hdev, is_hard, flags,
+                                                               asid, va, size);
+       if (rc)
+               dev_err_ratelimited(hdev->dev, "MMU cache range invalidation failed, rc=%d\n", rc);
+
+       return rc;
+}
+
 
 
        WREG32(mmSTLB_INV_SET, 0);
 
-       if (rc) {
-               dev_err_ratelimited(hdev->dev,
-                                       "MMU cache invalidation timeout\n");
-               hl_device_reset(hdev, HL_DRV_RESET_HARD);
-       }
-
        return rc;
 }
 
 
                1000,
                timeout_usec);
 
-       if (rc) {
-               dev_err_ratelimited(hdev->dev,
-                                       "MMU cache invalidation timeout\n");
-               hl_device_reset(hdev, HL_DRV_RESET_HARD);
-       }
-
        return rc;
 }