#include "class.h"
 #include "mux.h"
 
+#define TYPEC_MUX_MAX_DEVS     3
+
 struct typec_switch {
-       struct typec_switch_dev *sw_dev;
+       struct typec_switch_dev *sw_devs[TYPEC_MUX_MAX_DEVS];
+       unsigned int num_sw_devs;
 };
 
 static int switch_fwnode_match(struct device *dev, const void *fwnode)
  */
 struct typec_switch *fwnode_typec_switch_get(struct fwnode_handle *fwnode)
 {
-       struct typec_switch_dev *sw_dev;
+       struct typec_switch_dev *sw_devs[TYPEC_MUX_MAX_DEVS];
        struct typec_switch *sw;
+       int count;
+       int err;
+       int i;
 
        sw = kzalloc(sizeof(*sw), GFP_KERNEL);
        if (!sw)
                return ERR_PTR(-ENOMEM);
 
-       sw_dev = fwnode_connection_find_match(fwnode, "orientation-switch", NULL,
-                                             typec_switch_match);
-       if (IS_ERR_OR_NULL(sw_dev)) {
+       count = fwnode_connection_find_matches(fwnode, "orientation-switch", NULL,
+                                              typec_switch_match,
+                                              (void **)sw_devs,
+                                              ARRAY_SIZE(sw_devs));
+       if (count <= 0) {
                kfree(sw);
-               return ERR_CAST(sw_dev);
+               return NULL;
        }
 
-       WARN_ON(!try_module_get(sw_dev->dev.parent->driver->owner));
+       for (i = 0; i < count; i++) {
+               if (IS_ERR(sw_devs[i])) {
+                       err = PTR_ERR(sw_devs[i]);
+                       goto put_sw_devs;
+               }
+       }
 
-       sw->sw_dev = sw_dev;
+       for (i = 0; i < count; i++) {
+               WARN_ON(!try_module_get(sw_devs[i]->dev.parent->driver->owner));
+               sw->sw_devs[i] = sw_devs[i];
+       }
+
+       sw->num_sw_devs = count;
 
        return sw;
+
+put_sw_devs:
+       for (i = 0; i < count; i++) {
+               if (!IS_ERR(sw_devs[i]))
+                       put_device(&sw_devs[i]->dev);
+       }
+
+       kfree(sw);
+
+       return ERR_PTR(err);
 }
 EXPORT_SYMBOL_GPL(fwnode_typec_switch_get);
 
 void typec_switch_put(struct typec_switch *sw)
 {
        struct typec_switch_dev *sw_dev;
+       unsigned int i;
 
        if (IS_ERR_OR_NULL(sw))
                return;
 
-       sw_dev = sw->sw_dev;
+       for (i = 0; i < sw->num_sw_devs; i++) {
+               sw_dev = sw->sw_devs[i];
 
-       module_put(sw_dev->dev.parent->driver->owner);
-       put_device(&sw_dev->dev);
+               module_put(sw_dev->dev.parent->driver->owner);
+               put_device(&sw_dev->dev);
+       }
        kfree(sw);
 }
 EXPORT_SYMBOL_GPL(typec_switch_put);
                     enum typec_orientation orientation)
 {
        struct typec_switch_dev *sw_dev;
+       unsigned int i;
+       int ret;
 
        if (IS_ERR_OR_NULL(sw))
                return 0;
 
-       sw_dev = sw->sw_dev;
+       for (i = 0; i < sw->num_sw_devs; i++) {
+               sw_dev = sw->sw_devs[i];
+
+               ret = sw_dev->set(sw_dev, orientation);
+               if (ret)
+                       return ret;
+       }
 
-       return sw_dev->set(sw_dev, orientation);
+       return 0;
 }
 EXPORT_SYMBOL_GPL(typec_switch_set);
 
 /* ------------------------------------------------------------------------- */
 
 struct typec_mux {
-       struct typec_mux_dev *mux_dev;
+       struct typec_mux_dev *mux_devs[TYPEC_MUX_MAX_DEVS];
+       unsigned int num_mux_devs;
 };
 
 static int mux_fwnode_match(struct device *dev, const void *fwnode)
 struct typec_mux *fwnode_typec_mux_get(struct fwnode_handle *fwnode,
                                       const struct typec_altmode_desc *desc)
 {
-       struct typec_mux_dev *mux_dev;
+       struct typec_mux_dev *mux_devs[TYPEC_MUX_MAX_DEVS];
        struct typec_mux *mux;
+       int count;
+       int err;
+       int i;
 
        mux = kzalloc(sizeof(*mux), GFP_KERNEL);
        if (!mux)
                return ERR_PTR(-ENOMEM);
 
-       mux_dev = fwnode_connection_find_match(fwnode, "mode-switch", (void *)desc,
-                                              typec_mux_match);
-       if (IS_ERR_OR_NULL(mux_dev)) {
+       count = fwnode_connection_find_matches(fwnode, "mode-switch",
+                                              (void *)desc, typec_mux_match,
+                                              (void **)mux_devs,
+                                              ARRAY_SIZE(mux_devs));
+       if (count <= 0) {
                kfree(mux);
-               return ERR_CAST(mux_dev);
+               return NULL;
        }
 
-       WARN_ON(!try_module_get(mux_dev->dev.parent->driver->owner));
+       for (i = 0; i < count; i++) {
+               if (IS_ERR(mux_devs[i])) {
+                       err = PTR_ERR(mux_devs[i]);
+                       goto put_mux_devs;
+               }
+       }
+
+       for (i = 0; i < count; i++) {
+               WARN_ON(!try_module_get(mux_devs[i]->dev.parent->driver->owner));
+               mux->mux_devs[i] = mux_devs[i];
+       }
 
-       mux->mux_dev = mux_dev;
+       mux->num_mux_devs = count;
 
        return mux;
+
+put_mux_devs:
+       for (i = 0; i < count; i++) {
+               if (!IS_ERR(mux_devs[i]))
+                       put_device(&mux_devs[i]->dev);
+       }
+
+       kfree(mux);
+
+       return ERR_PTR(err);
 }
 EXPORT_SYMBOL_GPL(fwnode_typec_mux_get);
 
 void typec_mux_put(struct typec_mux *mux)
 {
        struct typec_mux_dev *mux_dev;
+       unsigned int i;
 
        if (IS_ERR_OR_NULL(mux))
                return;
 
-       mux_dev = mux->mux_dev;
-       module_put(mux_dev->dev.parent->driver->owner);
-       put_device(&mux_dev->dev);
+       for (i = 0; i < mux->num_mux_devs; i++) {
+               mux_dev = mux->mux_devs[i];
+               module_put(mux_dev->dev.parent->driver->owner);
+               put_device(&mux_dev->dev);
+       }
        kfree(mux);
 }
 EXPORT_SYMBOL_GPL(typec_mux_put);
 int typec_mux_set(struct typec_mux *mux, struct typec_mux_state *state)
 {
        struct typec_mux_dev *mux_dev;
+       unsigned int i;
+       int ret;
 
        if (IS_ERR_OR_NULL(mux))
                return 0;
 
-       mux_dev = mux->mux_dev;
+       for (i = 0; i < mux->num_mux_devs; i++) {
+               mux_dev = mux->mux_devs[i];
+
+               ret = mux_dev->set(mux_dev, state);
+               if (ret)
+                       return ret;
+       }
 
-       return mux_dev->set(mux_dev, state);
+       return 0;
 }
 EXPORT_SYMBOL_GPL(typec_mux_set);