return false;
 }
 
-static inline
-int drm_dp_mst_atomic_check_bw_limit(struct drm_dp_mst_branch *branch,
-                                    struct drm_dp_mst_topology_state *mst_state)
+static int
+drm_dp_mst_atomic_check_port_bw_limit(struct drm_dp_mst_port *port,
+                                     struct drm_dp_mst_topology_state *state);
+
+static int
+drm_dp_mst_atomic_check_mstb_bw_limit(struct drm_dp_mst_branch *mstb,
+                                     struct drm_dp_mst_topology_state *state)
 {
-       struct drm_dp_mst_port *port;
        struct drm_dp_vcpi_allocation *vcpi;
-       int pbn_limit = 0, pbn_used = 0;
+       struct drm_dp_mst_port *port;
+       int pbn_used = 0, ret;
+       bool found = false;
 
-       list_for_each_entry(port, &branch->ports, next) {
-               if (port->mstb)
-                       if (drm_dp_mst_atomic_check_bw_limit(port->mstb, mst_state))
-                               return -ENOSPC;
+       /* Check that we have at least one port in our state that's downstream
+        * of this branch, otherwise we can skip this branch
+        */
+       list_for_each_entry(vcpi, &state->vcpis, next) {
+               if (!vcpi->pbn ||
+                   !drm_dp_mst_port_downstream_of_branch(vcpi->port, mstb))
+                       continue;
 
-               if (port->full_pbn > 0)
-                       pbn_limit = port->full_pbn;
+               found = true;
+               break;
        }
-       DRM_DEBUG_ATOMIC("[MST BRANCH:%p] branch has %d PBN available\n",
-                        branch, pbn_limit);
+       if (!found)
+               return 0;
 
-       list_for_each_entry(vcpi, &mst_state->vcpis, next) {
-               if (!vcpi->pbn)
-                       continue;
+       if (mstb->port_parent)
+               DRM_DEBUG_ATOMIC("[MSTB:%p] [MST PORT:%p] Checking bandwidth limits on [MSTB:%p]\n",
+                                mstb->port_parent->parent, mstb->port_parent,
+                                mstb);
+       else
+               DRM_DEBUG_ATOMIC("[MSTB:%p] Checking bandwidth limits\n",
+                                mstb);
+
+       list_for_each_entry(port, &mstb->ports, next) {
+               ret = drm_dp_mst_atomic_check_port_bw_limit(port, state);
+               if (ret < 0)
+                       return ret;
 
-               if (drm_dp_mst_port_downstream_of_branch(vcpi->port, branch))
-                       pbn_used += vcpi->pbn;
+               pbn_used += ret;
        }
-       DRM_DEBUG_ATOMIC("[MST BRANCH:%p] branch used %d PBN\n",
-                        branch, pbn_used);
 
-       if (pbn_used > pbn_limit) {
-               DRM_DEBUG_ATOMIC("[MST BRANCH:%p] No available bandwidth\n",
-                                branch);
+       return pbn_used;
+}
+
+static int
+drm_dp_mst_atomic_check_port_bw_limit(struct drm_dp_mst_port *port,
+                                     struct drm_dp_mst_topology_state *state)
+{
+       struct drm_dp_vcpi_allocation *vcpi;
+       int pbn_used = 0;
+
+       if (port->pdt == DP_PEER_DEVICE_NONE)
+               return 0;
+
+       if (drm_dp_mst_is_end_device(port->pdt, port->mcs)) {
+               bool found = false;
+
+               list_for_each_entry(vcpi, &state->vcpis, next) {
+                       if (vcpi->port != port)
+                               continue;
+                       if (!vcpi->pbn)
+                               return 0;
+
+                       found = true;
+                       break;
+               }
+               if (!found)
+                       return 0;
+
+               /* This should never happen, as it means we tried to
+                * set a mode before querying the full_pbn
+                */
+               if (WARN_ON(!port->full_pbn))
+                       return -EINVAL;
+
+               pbn_used = vcpi->pbn;
+       } else {
+               pbn_used = drm_dp_mst_atomic_check_mstb_bw_limit(port->mstb,
+                                                                state);
+               if (pbn_used <= 0)
+                       return pbn_used;
+       }
+
+       if (pbn_used > port->full_pbn) {
+               DRM_DEBUG_ATOMIC("[MSTB:%p] [MST PORT:%p] required PBN of %d exceeds port limit of %d\n",
+                                port->parent, port, pbn_used,
+                                port->full_pbn);
                return -ENOSPC;
        }
-       return 0;
+
+       DRM_DEBUG_ATOMIC("[MSTB:%p] [MST PORT:%p] uses %d out of %d PBN\n",
+                        port->parent, port, pbn_used, port->full_pbn);
+
+       return pbn_used;
 }
 
 static inline int
                ret = drm_dp_mst_atomic_check_vcpi_alloc_limit(mgr, mst_state);
                if (ret)
                        break;
-               ret = drm_dp_mst_atomic_check_bw_limit(mgr->mst_primary, mst_state);
-               if (ret)
+
+               mutex_lock(&mgr->lock);
+               ret = drm_dp_mst_atomic_check_mstb_bw_limit(mgr->mst_primary,
+                                                           mst_state);
+               mutex_unlock(&mgr->lock);
+               if (ret < 0)
                        break;
+               else
+                       ret = 0;
        }
 
        return ret;