 #include <net/pkt_sched.h>
 #include <net/pkt_cls.h>
 #include <net/sch_generic.h>
+#include <net/sock.h>
 
 static LIST_HEAD(taprio_list);
 static DEFINE_SPINLOCK(taprio_list_lock);
 
 #define TAPRIO_ALL_GATES_OPEN -1
 
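+/* The only flag currently supported is txtime-assist mode. */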
+#define FLAGS_VALID(flags) (!((flags) & ~TCA_TAPRIO_ATTR_FLAG_TXTIME_ASSIST))
+#define TXTIME_ASSIST_IS_ENABLED(flags) ((flags) & TCA_TAPRIO_ATTR_FLAG_TXTIME_ASSIST)
+
 struct sched_entry {
        struct list_head list;
 
         * packet leaves after this time.
         */
        ktime_t close_time;
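+       /* In txtime-assist mode, the earliest launch time for the next
+        * packet that selects this entry.
+        */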
+       ktime_t next_txtime;
        atomic_t budget;
        int index;
        u32 gate_mask;
 struct taprio_sched {
        struct Qdisc **qdiscs;
        struct Qdisc *root;
+       u32 flags;
        int clockid;
        atomic64_t picos_per_byte; /* Using picoseconds because for 10Gbps+
                                    * speeds it's sub-nanoseconds per byte
        ktime_t (*get_time)(void);
        struct hrtimer advance_timer;
        struct list_head taprio_list;
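+       /* In txtime-assist mode, an offset (in ns) added to the current
+        * time when computing a packet's minimum launch time.
+        */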
+       int txtime_delay;
 };
 
 static ktime_t sched_base_time(const struct sched_gate_list *sched)
        *admin = NULL;
 }
 
+/* Get how much time has elapsed in the current cycle. */
+static s32 get_cycle_time_elapsed(struct sched_gate_list *sched, ktime_t time)
+{
+       ktime_t time_since_sched_start;
+       s32 time_elapsed;
+
+       time_since_sched_start = ktime_sub(time, sched->base_time);
+       div_s64_rem(time_since_sched_start, sched->cycle_time, &time_elapsed);
+
+       return time_elapsed;
+}
+
+static ktime_t get_interval_end_time(struct sched_gate_list *sched,
+                                    struct sched_gate_list *admin,
+                                    struct sched_entry *entry,
+                                    ktime_t intv_start)
+{
+       s32 cycle_elapsed = get_cycle_time_elapsed(sched, intv_start);
+       ktime_t intv_end, cycle_ext_end, cycle_end;
+
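+       /* An interval normally ends when the entry's own interval elapses,
+        * but is truncated at the end of the cycle; if the admin schedule
+        * starts during the cycle-time extension, the interval is stretched
+        * up to the admin base time instead.
+        */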
+       cycle_end = ktime_add_ns(intv_start, sched->cycle_time - cycle_elapsed);
+       intv_end = ktime_add_ns(intv_start, entry->interval);
+       cycle_ext_end = ktime_add(cycle_end, sched->cycle_time_extension);
+
+       if (ktime_before(intv_end, cycle_end))
+               return intv_end;
+       else if (admin && admin != sched &&
+                ktime_after(admin->base_time, cycle_end) &&
+                ktime_before(admin->base_time, cycle_ext_end))
+               return admin->base_time;
+       else
+               return cycle_end;
+}
+
+static int length_to_duration(struct taprio_sched *q, int len)
+{
+       return div_u64(len * atomic64_read(&q->picos_per_byte), 1000);
+}
+
+/* Returns the entry corresponding to the next available interval. If
+ * validate_interval is set, it only validates whether the timestamp occurs
+ * while the gate corresponding to the skb's traffic class is open.
+ */
+static struct sched_entry *find_entry_to_transmit(struct sk_buff *skb,
+                                                 struct Qdisc *sch,
+                                                 struct sched_gate_list *sched,
+                                                 struct sched_gate_list *admin,
+                                                 ktime_t time,
+                                                 ktime_t *interval_start,
+                                                 ktime_t *interval_end,
+                                                 bool validate_interval)
+{
+       ktime_t curr_intv_start, curr_intv_end, cycle_end, packet_transmit_time;
+       ktime_t earliest_txtime = KTIME_MAX, txtime, cycle, transmit_end_time;
+       struct sched_entry *entry = NULL, *entry_found = NULL;
+       struct taprio_sched *q = qdisc_priv(sch);
+       struct net_device *dev = qdisc_dev(sch);
+       bool entry_available = false;
+       s32 cycle_elapsed;
+       int tc, n;
+
+       tc = netdev_get_prio_tc_map(dev, skb->priority);
+       packet_transmit_time = length_to_duration(q, qdisc_pkt_len(skb));
+
+       *interval_start = 0;
+       *interval_end = 0;
+
+       if (!sched)
+               return NULL;
+
+       cycle = sched->cycle_time;
+       cycle_elapsed = get_cycle_time_elapsed(sched, time);
+       curr_intv_end = ktime_sub_ns(time, cycle_elapsed);
+       cycle_end = ktime_add_ns(curr_intv_end, cycle);
+
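+       /* curr_intv_end starts at the beginning of the current cycle; each
+        * iteration below advances it by one entry's interval, rebuilding
+        * the interval boundaries for this cycle.
+        */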
+       list_for_each_entry(entry, &sched->entries, list) {
+               curr_intv_start = curr_intv_end;
+               curr_intv_end = get_interval_end_time(sched, admin, entry,
+                                                     curr_intv_start);
+
+               if (ktime_after(curr_intv_start, cycle_end))
+                       break;
+
+               if (!(entry->gate_mask & BIT(tc)) ||
+                   packet_transmit_time > entry->interval)
+                       continue;
+
+               txtime = entry->next_txtime;
+
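+               /* If the entry's txtime is in the past (or we are only
+                * validating), check whether the packet fits in the current
+                * occurrence of this interval; otherwise remember the first
+                * usable interval of the next cycle as a fallback.
+                */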
+               if (ktime_before(txtime, time) || validate_interval) {
+                       transmit_end_time = ktime_add_ns(time, packet_transmit_time);
+                       if ((ktime_before(curr_intv_start, time) &&
+                            ktime_before(transmit_end_time, curr_intv_end)) ||
+                           (ktime_after(curr_intv_start, time) && !validate_interval)) {
+                               entry_found = entry;
+                               *interval_start = curr_intv_start;
+                               *interval_end = curr_intv_end;
+                               break;
+                       } else if (!entry_available && !validate_interval) {
+                               /* Here, we are just trying to find the
+                                * first available interval in the next cycle.
+                                */
+                               entry_available = true;
+                               entry_found = entry;
+                               *interval_start = ktime_add_ns(curr_intv_start, cycle);
+                               *interval_end = ktime_add_ns(curr_intv_end, cycle);
+                       }
+               } else if (ktime_before(txtime, earliest_txtime) &&
+                          !entry_available) {
+                       earliest_txtime = txtime;
+                       entry_found = entry;
+                       n = div_s64(ktime_sub(txtime, curr_intv_start), cycle);
+                       *interval_start = ktime_add(curr_intv_start, n * cycle);
+                       *interval_end = ktime_add(curr_intv_end, n * cycle);
+               }
+       }
+
+       return entry_found;
+}
+
+static bool is_valid_interval(struct sk_buff *skb, struct Qdisc *sch)
+{
+       struct taprio_sched *q = qdisc_priv(sch);
+       struct sched_gate_list *sched, *admin;
+       ktime_t interval_start, interval_end;
+       struct sched_entry *entry;
+
+       rcu_read_lock();
+       sched = rcu_dereference(q->oper_sched);
+       admin = rcu_dereference(q->admin_sched);
+
+       entry = find_entry_to_transmit(skb, sch, sched, admin, skb->tstamp,
+                                      &interval_start, &interval_end, true);
+       rcu_read_unlock();
+
+       return entry;
+}
+
+/* There are a few scenarios where we will have to modify the txtime from
+ * what is read from next_txtime in sched_entry. They are:
+ * 1. If txtime is in the past,
+ *    a. The gate for the traffic class is currently open and the packet
+ *       can be transmitted before it closes; schedule the packet right
+ *       away.
+ *    b. If the gate corresponding to the traffic class is going to open
+ *       later in the cycle, set the packet's txtime to the interval start.
+ * 2. If txtime is in the future, there are packets corresponding to the
+ *    current traffic class waiting to be transmitted. So, the following
+ *    possibilities exist:
+ *    a. We can transmit the packet before the window containing the txtime
+ *       closes.
+ *    b. The window might close before the transmission can be completed
+ *       successfully; in that case, schedule the packet in the next open
+ *       window.
+ */
+static long get_packet_txtime(struct sk_buff *skb, struct Qdisc *sch)
+{
+       ktime_t transmit_end_time, interval_end, interval_start;
+       struct taprio_sched *q = qdisc_priv(sch);
+       struct sched_gate_list *sched, *admin;
+       ktime_t minimum_time, now, txtime;
+       int len, packet_transmit_time;
+       struct sched_entry *entry;
+       bool sched_changed;
+
+       now = q->get_time();
+       minimum_time = ktime_add_ns(now, q->txtime_delay);
+
+       rcu_read_lock();
+       admin = rcu_dereference(q->admin_sched);
+       sched = rcu_dereference(q->oper_sched);
+       if (admin && ktime_after(minimum_time, admin->base_time))
+               switch_schedules(q, &admin, &sched);
+
+       /* Until the schedule starts, all the queues are open */
+       if (!sched || ktime_before(minimum_time, sched->base_time)) {
+               txtime = minimum_time;
+               goto done;
+       }
+
+       len = qdisc_pkt_len(skb);
+       packet_transmit_time = length_to_duration(q, len);
+
+       do {
+               sched_changed = false;
+
+               entry = find_entry_to_transmit(skb, sch, sched, admin,
+                                              minimum_time,
+                                              &interval_start, &interval_end,
+                                              false);
+               if (!entry) {
+                       txtime = 0;
+                       goto done;
+               }
+
+               txtime = entry->next_txtime;
+               txtime = max_t(ktime_t, txtime, minimum_time);
+               txtime = max_t(ktime_t, txtime, interval_start);
+
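+               /* If the chosen txtime lands after the admin schedule takes
+                * effect, redo the search against the admin schedule.
+                */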
+               if (admin && admin != sched &&
+                   ktime_after(txtime, admin->base_time)) {
+                       sched = admin;
+                       sched_changed = true;
+                       continue;
+               }
+
+               transmit_end_time = ktime_add(txtime, packet_transmit_time);
+               minimum_time = transmit_end_time;
+
+               /* Update the txtime of the current entry to the next time
+                * its interval starts.
+                */
+               if (ktime_after(transmit_end_time, interval_end))
+                       entry->next_txtime = ktime_add(interval_start, sched->cycle_time);
+       } while (sched_changed || ktime_after(transmit_end_time, interval_end));
+
+       entry->next_txtime = transmit_end_time;
+
+done:
+       rcu_read_unlock();
+       return txtime;
+}
+
 static int taprio_enqueue(struct sk_buff *skb, struct Qdisc *sch,
                          struct sk_buff **to_free)
 {
        if (unlikely(!child))
                return qdisc_drop(skb, sch, to_free);
 
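+       /* If the socket requested SO_TXTIME, validate the user-supplied
+        * launch time against the gate schedule; otherwise, in txtime-assist
+        * mode, compute a launch time for the packet here.
+        */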
+       if (skb->sk && sock_flag(skb->sk, SOCK_TXTIME)) {
+               if (!is_valid_interval(skb, sch))
+                       return qdisc_drop(skb, sch, to_free);
+       } else if (TXTIME_ASSIST_IS_ENABLED(q->flags)) {
+               skb->tstamp = get_packet_txtime(skb, sch);
+               if (!skb->tstamp)
+                       return qdisc_drop(skb, sch, to_free);
+       }
+
        qdisc_qstats_backlog_inc(sch, skb);
        sch->q.qlen++;
 
                if (!skb)
                        continue;
 
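+               /* In txtime-assist mode the launch time already encodes the
+                * gate schedule, so any queued packet may be returned.
+                */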
+               if (TXTIME_ASSIST_IS_ENABLED(q->flags))
+                       return skb;
+
                prio = skb->priority;
                tc = netdev_get_prio_tc_map(dev, prio);
 
        return NULL;
 }
 
-static int length_to_duration(struct taprio_sched *q, int len)
-{
-       return div_u64(len * atomic64_read(&q->picos_per_byte), 1000);
-}
-
 static void taprio_set_budget(struct taprio_sched *q, struct sched_entry *entry)
 {
        atomic_set(&entry->budget,
                if (unlikely(!child))
                        continue;
 
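+               /* In txtime-assist mode, gating is enforced through launch
+                * times, so dequeue without checking gate state or budget.
+                */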
+               if (TXTIME_ASSIST_IS_ENABLED(q->flags)) {
+                       skb = child->ops->dequeue(child);
+                       if (!skb)
+                               continue;
+                       goto skb_found;
+               }
+
                skb = child->ops->peek(child);
                if (!skb)
                        continue;
                if (unlikely(!skb))
                        goto done;
 
+skb_found:
                qdisc_bstats_update(sch, skb);
                qdisc_qstats_backlog_dec(sch, skb);
                sch->q.qlen--;
 
 static int taprio_parse_mqprio_opt(struct net_device *dev,
                                   struct tc_mqprio_qopt *qopt,
-                                  struct netlink_ext_ack *extack)
+                                  struct netlink_ext_ack *extack,
+                                  u32 taprio_flags)
 {
        int i, j;
 
                        return -EINVAL;
                }
 
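+               /* Queue ranges may overlap in txtime-assist mode, so skip
+                * the overlap check below.
+                */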
+               if (TXTIME_ASSIST_IS_ENABLED(taprio_flags))
+                       continue;
+
                /* Verify that the offset and counts do not overlap */
                for (j = i + 1; j < qopt->num_tc; j++) {
                        if (last > qopt->offset[j]) {
        return NOTIFY_DONE;
 }
 
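+/* Assign each entry its initial launch time: the schedule's base time plus
+ * the sum of the intervals that precede the entry.
+ */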
+static void setup_txtime(struct taprio_sched *q,
+                        struct sched_gate_list *sched, ktime_t base)
+{
+       struct sched_entry *entry;
+       u32 interval = 0;
+
+       list_for_each_entry(entry, &sched->entries, list) {
+               entry->next_txtime = ktime_add_ns(base, interval);
+               interval += entry->interval;
+       }
+}
+
 static int taprio_change(struct Qdisc *sch, struct nlattr *opt,
                         struct netlink_ext_ack *extack)
 {
        struct taprio_sched *q = qdisc_priv(sch);
        struct net_device *dev = qdisc_dev(sch);
        struct tc_mqprio_qopt *mqprio = NULL;
+       u32 taprio_flags = 0;
        int i, err, clockid;
        unsigned long flags;
        ktime_t start;
        if (tb[TCA_TAPRIO_ATTR_PRIOMAP])
                mqprio = nla_data(tb[TCA_TAPRIO_ATTR_PRIOMAP]);
 
-       err = taprio_parse_mqprio_opt(dev, mqprio, extack);
+       if (tb[TCA_TAPRIO_ATTR_FLAGS]) {
+               taprio_flags = nla_get_u32(tb[TCA_TAPRIO_ATTR_FLAGS]);
+
+               if (q->flags != 0 && q->flags != taprio_flags) {
+                       NL_SET_ERR_MSG_MOD(extack, "Changing 'flags' of a running schedule is not supported");
+                       return -EOPNOTSUPP;
+               } else if (!FLAGS_VALID(taprio_flags)) {
+                       NL_SET_ERR_MSG_MOD(extack, "Specified 'flags' are not valid");
+                       return -EINVAL;
+               }
+
+               q->flags = taprio_flags;
+       }
+
+       err = taprio_parse_mqprio_opt(dev, mqprio, extack, taprio_flags);
        if (err < 0)
                return err;
 
        /* Protects against enqueue()/dequeue() */
        spin_lock_bh(qdisc_lock(sch));
 
-       if (!hrtimer_active(&q->advance_timer)) {
+       if (tb[TCA_TAPRIO_ATTR_TXTIME_DELAY]) {
+               if (!TXTIME_ASSIST_IS_ENABLED(q->flags)) {
+                       NL_SET_ERR_MSG_MOD(extack, "txtime-delay can only be set when txtime-assist mode is enabled");
+                       err = -EINVAL;
+                       goto unlock;
+               }
+
+               q->txtime_delay = nla_get_s32(tb[TCA_TAPRIO_ATTR_TXTIME_DELAY]);
+       }
+
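+       /* The advance timer drives the gate state machine and is not needed
+        * in txtime-assist mode.
+        */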
+       if (!TXTIME_ASSIST_IS_ENABLED(taprio_flags) &&
+           !hrtimer_active(&q->advance_timer)) {
                hrtimer_init(&q->advance_timer, q->clockid, HRTIMER_MODE_ABS);
                q->advance_timer.function = advance_sched;
        }
                goto unlock;
        }
 
-       setup_first_close_time(q, new_admin, start);
+       if (TXTIME_ASSIST_IS_ENABLED(taprio_flags)) {
+               setup_txtime(q, new_admin, start);
 
-       /* Protects against advance_sched() */
-       spin_lock_irqsave(&q->current_entry_lock, flags);
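+               /* If no schedule is currently running, the new one becomes
+                * operational immediately.
+                */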
+               if (!oper) {
+                       rcu_assign_pointer(q->oper_sched, new_admin);
+                       err = 0;
+                       new_admin = NULL;
+                       goto unlock;
+               }
 
-       taprio_start_sched(sch, start, new_admin);
+               rcu_assign_pointer(q->admin_sched, new_admin);
+               if (admin)
+                       call_rcu(&admin->rcu, taprio_free_sched_cb);
+       } else {
+               setup_first_close_time(q, new_admin, start);
 
-       rcu_assign_pointer(q->admin_sched, new_admin);
-       if (admin)
-               call_rcu(&admin->rcu, taprio_free_sched_cb);
-       new_admin = NULL;
+               /* Protects against advance_sched() */
+               spin_lock_irqsave(&q->current_entry_lock, flags);
+
+               taprio_start_sched(sch, start, new_admin);
 
-       spin_unlock_irqrestore(&q->current_entry_lock, flags);
+               rcu_assign_pointer(q->admin_sched, new_admin);
+               if (admin)
+                       call_rcu(&admin->rcu, taprio_free_sched_cb);
 
+               spin_unlock_irqrestore(&q->current_entry_lock, flags);
+       }
+
+       new_admin = NULL;
        err = 0;
 
 unlock:
        if (nla_put_s32(skb, TCA_TAPRIO_ATTR_SCHED_CLOCKID, q->clockid))
                goto options_error;
 
+       if (q->flags && nla_put_u32(skb, TCA_TAPRIO_ATTR_FLAGS, q->flags))
+               goto options_error;
+
+       if (q->txtime_delay &&
+           nla_put_s32(skb, TCA_TAPRIO_ATTR_TXTIME_DELAY, q->txtime_delay))
+               goto options_error;
+
        if (oper && dump_schedule(skb, oper))
                goto options_error;