gsi_evt_ring_id_free(gsi, evt_ring_id);
 }
 
-static bool gsi_channel_data_valid(struct gsi *gsi,
+static bool gsi_channel_data_valid(struct gsi *gsi, bool command,
                                   const struct ipa_gsi_endpoint_data *data)
 {
+       const struct gsi_channel_data *channel_data;
        u32 channel_id = data->channel_id;
        struct device *dev = gsi->dev;
 
                return false;
        }
 
-       if (!data->channel.tlv_count ||
-           data->channel.tlv_count > GSI_TLV_MAX) {
+       if (command && !data->toward_ipa) {
+               dev_err(dev, "command channel %u is not TX\n", channel_id);
+               return false;
+       }
+
+       channel_data = &data->channel;
+
+       if (!channel_data->tlv_count ||
+           channel_data->tlv_count > GSI_TLV_MAX) {
                dev_err(dev, "channel %u bad tlv_count %u; must be 1..%u\n",
-                       channel_id, data->channel.tlv_count, GSI_TLV_MAX);
+                       channel_id, channel_data->tlv_count, GSI_TLV_MAX);
+               return false;
+       }
+
+       if (command && IPA_COMMAND_TRANS_TRE_MAX > channel_data->tlv_count) {
+               dev_err(dev, "command TRE max too big for channel %u (%u > %u)\n",
+                       channel_id, IPA_COMMAND_TRANS_TRE_MAX,
+                       channel_data->tlv_count);
                return false;
        }
 
         * gsi_channel_tre_max() is computed, tre_count has to be almost
         * twice the TLV FIFO size to satisfy this requirement.
         */
-       if (data->channel.tre_count < 2 * data->channel.tlv_count - 1) {
+       if (channel_data->tre_count < 2 * channel_data->tlv_count - 1) {
                dev_err(dev, "channel %u TLV count %u exceeds TRE count %u\n",
-                       channel_id, data->channel.tlv_count,
-                       data->channel.tre_count);
+                       channel_id, channel_data->tlv_count,
+                       channel_data->tre_count);
                return false;
        }
 
-       if (!is_power_of_2(data->channel.tre_count)) {
+       if (!is_power_of_2(channel_data->tre_count)) {
                dev_err(dev, "channel %u bad tre_count %u; not power of 2\n",
-                       channel_id, data->channel.tre_count);
+                       channel_id, channel_data->tre_count);
                return false;
        }
 
-       if (!is_power_of_2(data->channel.event_count)) {
+       if (!is_power_of_2(channel_data->event_count)) {
                dev_err(dev, "channel %u bad event_count %u; not power of 2\n",
-                       channel_id, data->channel.event_count);
+                       channel_id, channel_data->event_count);
                return false;
        }
 
        u32 tre_count;
        int ret;
 
-       if (!gsi_channel_data_valid(gsi, data))
+       if (!gsi_channel_data_valid(gsi, command, data))
                return -EINVAL;
 
        /* Worst case we need an event for every outstanding TRE */