extern const struct amd_ip_funcs gmc_v9_0_ip_funcs;
 extern const struct amdgpu_ip_block_version gmc_v9_0_ip_block;
 
+/* amdgpu_amdkfd*.c */
+void gfxhub_v1_0_setup_vm_pt_regs(struct amdgpu_device *adev, uint32_t vmid,
+                               uint64_t value);
+void mmhub_v1_0_setup_vm_pt_regs(struct amdgpu_device *adev, uint32_t vmid,
+                               uint64_t value);
+void mmhub_v9_4_setup_vm_pt_regs(struct amdgpu_device *adev, int hubid,
+                               uint32_t vmid, uint64_t value);
 #endif
 
        return base;
 }
 
-static void mmhub_v9_4_init_gart_pt_regs(struct amdgpu_device *adev, int hubid)
+void mmhub_v9_4_setup_vm_pt_regs(struct amdgpu_device *adev, int hubid,
+                               uint32_t vmid, uint64_t value)
 {
-       uint64_t value = amdgpu_gmc_pd_addr(adev->gart.bo);
+       /* two registers distance between mmVML2VC0_VM_CONTEXT0_* to
+        * mmVML2VC0_VM_CONTEXT1_*
+        */
+       int dist = mmVML2VC0_VM_CONTEXT1_PAGE_TABLE_BASE_ADDR_LO32
+                       - mmVML2VC0_VM_CONTEXT0_PAGE_TABLE_BASE_ADDR_LO32;
 
        WREG32_SOC15_OFFSET(MMHUB, 0,
                            mmVML2VC0_VM_CONTEXT0_PAGE_TABLE_BASE_ADDR_LO32,
-                           hubid * MMHUB_INSTANCE_REGISTER_OFFSET,
+                           dist * vmid + hubid * MMHUB_INSTANCE_REGISTER_OFFSET,
                            lower_32_bits(value));
 
        WREG32_SOC15_OFFSET(MMHUB, 0,
                            mmVML2VC0_VM_CONTEXT0_PAGE_TABLE_BASE_ADDR_HI32,
-                           hubid * MMHUB_INSTANCE_REGISTER_OFFSET,
+                           dist * vmid + hubid * MMHUB_INSTANCE_REGISTER_OFFSET,
                            upper_32_bits(value));
 
 }
 static void mmhub_v9_4_init_gart_aperture_regs(struct amdgpu_device *adev,
                                               int hubid)
 {
-       mmhub_v9_4_init_gart_pt_regs(adev, hubid);
+       uint64_t pt_base = amdgpu_gmc_pd_addr(adev->gart.bo);
+
+       mmhub_v9_4_setup_vm_pt_regs(adev, hubid, 0, pt_base);
 
        WREG32_SOC15_OFFSET(MMHUB, 0,
                            mmVML2VC0_VM_CONTEXT0_PAGE_TABLE_START_ADDR_LO32,