hl_debugfs_remove_cb(cb);
 
+       hl_ctx_put(cb->ctx);
+
        cb_do_release(hdev, cb);
 }
 
 }
 
 int hl_cb_create(struct hl_device *hdev, struct hl_cb_mgr *mgr,
-                       u32 cb_size, u64 *handle, int ctx_id, bool internal_cb)
+                       struct hl_ctx *ctx, u32 cb_size, bool internal_cb,
+                       u64 *handle)
 {
        struct hl_cb *cb;
        bool alloc_new_cb = true;
-       int rc;
+       int rc, ctx_id = ctx->asid;
 
        /*
         * Can't use generic function to check this because of special case
        }
 
        cb->hdev = hdev;
-       cb->ctx_id = ctx_id;
+       cb->ctx = ctx;
+       hl_ctx_get(hdev, cb->ctx);
 
        spin_lock(&mgr->cb_lock);
        rc = idr_alloc(&mgr->cb_handles, cb, 1, 0, GFP_ATOMIC);
        return 0;
 
 release_cb:
+       hl_ctx_put(cb->ctx);
        cb_do_release(hdev, cb);
 out_err:
        *handle = 0;
                                args->in.cb_size, HL_MAX_CB_SIZE);
                        rc = -EINVAL;
                } else {
-                       rc = hl_cb_create(hdev, &hpriv->cb_mgr,
-                                       args->in.cb_size, &handle,
-                                       hpriv->ctx->asid, false);
+                       rc = hl_cb_create(hdev, &hpriv->cb_mgr, hpriv->ctx,
+                                       args->in.cb_size, false, &handle);
                }
 
                memset(args, 0, sizeof(*args));
                if (kref_put(&cb->refcount, cb_release) != 1)
                        dev_err(hdev->dev,
                                "CB %d for CTX ID %d is still alive\n",
-                               id, cb->ctx_id);
+                               id, cb->ctx->asid);
        }
 
        idr_destroy(&mgr->cb_handles);
        struct hl_cb *cb;
        int rc;
 
-       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, cb_size, &cb_handle,
-                       HL_KERNEL_ASID_ID, internal_cb);
+       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, hdev->kernel_ctx, cb_size,
+                               internal_cb, &cb_handle);
        if (rc) {
                dev_err(hdev->dev,
                        "Failed to allocate CB for the kernel driver %d\n", rc);
 
                }
                seq_printf(s,
                        "   %03llu        %d    0x%08x      %d          %d          %d\n",
-                       cb->id, cb->ctx_id, cb->size,
+                       cb->id, cb->ctx->asid, cb->size,
                        kref_read(&cb->refcount),
                        cb->mmap, cb->cs_cnt);
        }
 
  * struct hl_cb - describes a Command Buffer.
  * @refcount: reference counter for usage of the CB.
  * @hdev: pointer to device this CB belongs to.
+ * @ctx: pointer to the CB owner's context.
  * @lock: spinlock to protect mmap/cs flows.
  * @debugfs_list: node in debugfs list of command buffers.
  * @pool_list: node in pool list of command buffers.
  * @mmap_size: Holds the CB's size that was mmaped.
  * @size: holds the CB's size.
  * @cs_cnt: holds number of CS that this CB participates in.
- * @ctx_id: holds the ID of the owner's context.
  * @mmap: true if the CB is currently mmaped to user.
  * @is_pool: true if CB was acquired from the pool, false otherwise.
  * @is_internal: internaly allocated
 struct hl_cb {
        struct kref             refcount;
        struct hl_device        *hdev;
+       struct hl_ctx           *ctx;
        spinlock_t              lock;
        struct list_head        debugfs_list;
        struct list_head        pool_list;
        u32                     mmap_size;
        u32                     size;
        u32                     cs_cnt;
-       u32                     ctx_id;
        u8                      mmap;
        u8                      is_pool;
        u8                      is_internal;
 int hl_hwmon_init(struct hl_device *hdev);
 void hl_hwmon_fini(struct hl_device *hdev);
 
-int hl_cb_create(struct hl_device *hdev, struct hl_cb_mgr *mgr, u32 cb_size,
-               u64 *handle, int ctx_id, bool internal_cb);
+int hl_cb_create(struct hl_device *hdev, struct hl_cb_mgr *mgr,
+                       struct hl_ctx *ctx, u32 cb_size, bool internal_cb,
+                       u64 *handle);
 int hl_cb_destroy(struct hl_device *hdev, struct hl_cb_mgr *mgr, u64 cb_handle);
 int hl_cb_mmap(struct hl_fpriv *hpriv, struct vm_area_struct *vma);
 struct hl_cb *hl_cb_get(struct hl_device *hdev,        struct hl_cb_mgr *mgr,
 
        parser->patched_cb_size = parser->user_cb_size +
                        sizeof(struct packet_msg_prot) * 2;
 
-       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, parser->patched_cb_size,
-                       &patched_cb_handle, HL_KERNEL_ASID_ID, false);
+       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, hdev->kernel_ctx,
+                               parser->patched_cb_size, false,
+                               &patched_cb_handle);
 
        if (rc) {
                dev_err(hdev->dev,
        if (rc)
                goto free_userptr;
 
-       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, parser->patched_cb_size,
-                       &patched_cb_handle, HL_KERNEL_ASID_ID, false);
+       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, hdev->kernel_ctx,
+                               parser->patched_cb_size, false,
+                               &patched_cb_handle);
        if (rc) {
                dev_err(hdev->dev,
                        "Failed to allocate patched CB for DMA CS %d\n", rc);
 
        parser->patched_cb_size = parser->user_cb_size +
                        sizeof(struct packet_msg_prot) * 2;
 
-       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, parser->patched_cb_size,
-                       &patched_cb_handle, HL_KERNEL_ASID_ID, false);
+       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, hdev->kernel_ctx,
+                               parser->patched_cb_size, false,
+                               &patched_cb_handle);
 
        if (rc) {
                dev_err(hdev->dev,
        if (rc)
                goto free_userptr;
 
-       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, parser->patched_cb_size,
-                       &patched_cb_handle, HL_KERNEL_ASID_ID, false);
+       rc = hl_cb_create(hdev, &hdev->kernel_cb_mgr, hdev->kernel_ctx,
+                               parser->patched_cb_size, false,
+                               &patched_cb_handle);
        if (rc) {
                dev_err(hdev->dev,
                        "Failed to allocate patched CB for DMA CS %d\n", rc);