uint64_t aligned_size;
        u64 alloc_flags;
        int ret;
+       int mem_id = 0; /* Fixme : to be changed when mem_id support patch lands, until then NPS1, SPX only */
 
        /*
         * Check on which domain to allocate BO
            ((*mem)->alloc_flags & KFD_IOC_ALLOC_MEM_FLAGS_VRAM)) {
                bo->allowed_domains = AMDGPU_GEM_DOMAIN_GTT;
                bo->preferred_domains = AMDGPU_GEM_DOMAIN_GTT;
+               ret = amdgpu_ttm_tt_set_mem_pool(&bo->tbo, mem_id);
+               if (ret) {
+                       pr_debug("failed to set ttm mem pool %d\n", ret);
+                       goto err_set_mem_partition;
+               }
        }
 
        add_kgd_mem_to_kfd_bo_list(*mem, avm->process_info, user_addr);
 allocate_init_user_pages_failed:
 err_pin_bo:
        remove_kgd_mem_from_kfd_bo_list(*mem, avm->process_info);
+err_set_mem_partition:
        drm_vma_node_revoke(&gobj->vma_node, drm_priv);
 err_node_allow:
        /* Don't unreserve system mem limit twice */
 
        return ttm_pool_free(pool, ttm);
 }
 
+/**
+ * amdgpu_ttm_tt_set_mem_pool - Set the TTM memory pool for the TTM BO
+ * @tbo: The ttm_buffer_object that backs the VRAM bo
+ * @mem_id: to select the initialized ttm pool corresponding to the memory partition
+ */
+int amdgpu_ttm_tt_set_mem_pool(struct ttm_buffer_object *tbo, int mem_id)
+{
+       struct ttm_tt *ttm = tbo->ttm;
+       struct amdgpu_ttm_tt *gtt;
+
+       if (!ttm && !ttm_tt_is_populated(ttm))
+               return -EINVAL;
+
+       gtt = ttm_to_amdgpu_ttm_tt(ttm);
+       gtt->pool_id = mem_id;
+       return 0;
+}
+
 /**
  * amdgpu_ttm_tt_get_userptr - Return the userptr GTT ttm_tt for the current
  * task
 
 struct mm_struct *amdgpu_ttm_tt_get_usermm(struct ttm_tt *ttm);
 bool amdgpu_ttm_tt_affect_userptr(struct ttm_tt *ttm, unsigned long start,
                                  unsigned long end, unsigned long *userptr);
+int amdgpu_ttm_tt_set_mem_pool(struct ttm_buffer_object *tbo, int mem_id);
 bool amdgpu_ttm_tt_userptr_invalidated(struct ttm_tt *ttm,
                                       int *last_invalidated);
 bool amdgpu_ttm_tt_is_userptr(struct ttm_tt *ttm);