int amdgpu_xcp_switch_partition_mode(struct amdgpu_xcp_mgr *xcp_mgr, int mode)
 {
-       int ret, num_xcps = 0;
+       int ret, curr_mode, num_xcps = 0;
 
        if (!xcp_mgr || mode == AMDGPU_XCP_MODE_NONE)
                return -EINVAL;
 
        mutex_lock(&xcp_mgr->xcp_lock);
 
+       curr_mode = xcp_mgr->mode;
+       /* State set to transient mode */
+       xcp_mgr->mode = AMDGPU_XCP_MODE_TRANS;
+
        ret = xcp_mgr->funcs->switch_partition_mode(xcp_mgr, mode, &num_xcps);
 
-       if (ret)
+       if (ret) {
+               /* Failed, get whatever mode it's at now */
+               if (xcp_mgr->funcs->query_partition_mode)
+                       xcp_mgr->mode = amdgpu_xcp_query_partition_mode(
+                               xcp_mgr, AMDGPU_XCP_FL_LOCKED);
+               else
+                       xcp_mgr->mode = curr_mode;
+
                goto out;
+       }
 
        if (!num_xcps || num_xcps > MAX_XCP) {
                ret = -EINVAL;
        if (!(flags & AMDGPU_XCP_FL_LOCKED))
                mutex_lock(&xcp_mgr->xcp_lock);
        mode = xcp_mgr->funcs->query_partition_mode(xcp_mgr);
-       if (mode != xcp_mgr->mode)
+       if (xcp_mgr->mode != AMDGPU_XCP_MODE_TRANS && mode != xcp_mgr->mode)
                dev_WARN(
                        xcp_mgr->adev->dev,
                        "Cached partition mode %d not matching with device mode %d",