#include "mvm.h"
 #include "iwl-debug.h"
 #include <linux/timekeeping.h>
+#include <linux/math64.h>
 
 #define IWL_PTP_GP2_WRAP       0x100000000ULL
 #define IWL_PTP_WRAP_TIME      (3600 * HZ)
 
+/* The scaled_ppm parameter is ppm (parts per million) with a 16-bit fractional
+ * part, which means that a value of 1 in one of those fields actually means
+ * 2^-16 ppm, and 2^16=65536 is 1 ppm.
+ */
+#define SCALE_FACTOR   65536000000ULL
+#define IWL_PTP_WRAP_THRESHOLD_USEC    (5000)
+
 static void iwl_mvm_ptp_update_new_read(struct iwl_mvm *mvm, u32 gp2)
 {
+       /* If the difference is above the threshold, assume it's a wraparound.
+        * Otherwise assume it's an old read and ignore it.
+        */
+       if (gp2 < mvm->ptp_data.last_gp2 &&
+           mvm->ptp_data.last_gp2 - gp2 < IWL_PTP_WRAP_THRESHOLD_USEC) {
+               IWL_DEBUG_INFO(mvm,
+                              "PTP: ignore old read (gp2=%u, last_gp2=%u)\n",
+                              gp2, mvm->ptp_data.last_gp2);
+               return;
+       }
+
        if (gp2 < mvm->ptp_data.last_gp2) {
                mvm->ptp_data.wrap_counter++;
                IWL_DEBUG_INFO(mvm,
        schedule_delayed_work(&mvm->ptp_data.dwork, IWL_PTP_WRAP_TIME);
 }
 
+u64 iwl_mvm_ptp_get_adj_time(struct iwl_mvm *mvm, u64 base_time_ns)
+{
+       struct ptp_data *data = &mvm->ptp_data;
+       u64 last_gp2_ns = mvm->ptp_data.scale_update_gp2 * NSEC_PER_USEC;
+       u64 res;
+       u64 diff;
+
+       iwl_mvm_ptp_update_new_read(mvm,
+                                   div64_u64(base_time_ns, NSEC_PER_USEC));
+
+       IWL_DEBUG_INFO(mvm, "base_time_ns=%llu, wrap_counter=%u\n",
+                      (unsigned long long)base_time_ns, data->wrap_counter);
+
+       base_time_ns = base_time_ns +
+               (data->wrap_counter * IWL_PTP_GP2_WRAP * NSEC_PER_USEC);
+
+       /* It is possible that a GP2 timestamp was received from fw before the
+        * last scale update. Since we don't know how to scale - ignore it.
+        */
+       if (base_time_ns < last_gp2_ns) {
+               IWL_DEBUG_INFO(mvm, "Time before scale update - ignore\n");
+               return 0;
+       }
+
+       diff = base_time_ns - last_gp2_ns;
+       IWL_DEBUG_INFO(mvm, "diff ns=%llu\n", (unsigned long long)diff);
+
+       diff = mul_u64_u64_div_u64(diff, data->scaled_freq,
+                                  SCALE_FACTOR);
+       IWL_DEBUG_INFO(mvm, "scaled diff ns=%llu\n", (unsigned long long)diff);
+
+       res = data->scale_update_adj_time_ns + data->delta + diff;
+
+       IWL_DEBUG_INFO(mvm, "base=%llu delta=%lld adj=%llu\n",
+                      (unsigned long long)base_time_ns, (long long)data->delta,
+                      (unsigned long long)res);
+       return res;
+}
+
 static int
 iwl_mvm_get_crosstimestamp_fw(struct iwl_mvm *mvm, u32 *gp2, u64 *sys_time)
 {
                                      &sys_time);
        }
 
-       iwl_mvm_ptp_update_new_read(mvm, gp2);
-
-       gp2_ns = (gp2 + (mvm->ptp_data.wrap_counter * IWL_PTP_GP2_WRAP)) *
-               NSEC_PER_USEC;
+       gp2_ns = iwl_mvm_ptp_get_adj_time(mvm, (u64)gp2 * NSEC_PER_USEC);
 
        IWL_INFO(mvm, "Got Sync Time: GP2:%u, last_GP2: %u, GP2_ns: %lld, sys_time: %lld\n",
                 gp2, mvm->ptp_data.last_gp2, gp2_ns, (s64)sys_time);
        mutex_unlock(&mvm->mutex);
 }
 
+static int iwl_mvm_ptp_gettime(struct ptp_clock_info *ptp,
+                              struct timespec64 *ts)
+{
+       struct iwl_mvm *mvm = container_of(ptp, struct iwl_mvm,
+                                          ptp_data.ptp_clock_info);
+       u64 gp2;
+       u64 ns;
+
+       mutex_lock(&mvm->mutex);
+       gp2 = iwl_mvm_get_systime(mvm);
+       ns = iwl_mvm_ptp_get_adj_time(mvm, gp2 * NSEC_PER_USEC);
+       mutex_unlock(&mvm->mutex);
+
+       *ts = ns_to_timespec64(ns);
+       return 0;
+}
+
+static int iwl_mvm_ptp_adjtime(struct ptp_clock_info *ptp, s64 delta)
+{
+       struct iwl_mvm *mvm = container_of(ptp, struct iwl_mvm,
+                                          ptp_data.ptp_clock_info);
+       struct ptp_data *data = container_of(ptp, struct ptp_data,
+                                            ptp_clock_info);
+
+       mutex_lock(&mvm->mutex);
+       data->delta += delta;
+       IWL_DEBUG_INFO(mvm, "delta=%lld, new delta=%lld\n", (long long)delta,
+                      (long long)data->delta);
+       mutex_unlock(&mvm->mutex);
+       return 0;
+}
+
+static int iwl_mvm_ptp_adjfine(struct ptp_clock_info *ptp, long scaled_ppm)
+{
+       struct iwl_mvm *mvm = container_of(ptp, struct iwl_mvm,
+                                          ptp_data.ptp_clock_info);
+       struct ptp_data *data = &mvm->ptp_data;
+       u32 gp2;
+
+       mutex_lock(&mvm->mutex);
+
+       /* Must call _iwl_mvm_ptp_get_adj_time() before updating
+        * data->scale_update_gp2 or data->scaled_freq since
+        * scale_update_adj_time_ns should reflect the previous scaled_freq.
+        */
+       gp2 = iwl_mvm_get_systime(mvm);
+       data->scale_update_adj_time_ns =
+               iwl_mvm_ptp_get_adj_time(mvm, gp2 * NSEC_PER_USEC);
+       data->scale_update_gp2 = gp2;
+       data->wrap_counter = 0;
+       data->delta = 0;
+
+       data->scaled_freq = SCALE_FACTOR + scaled_ppm;
+       IWL_DEBUG_INFO(mvm, "adjfine: scaled_ppm=%ld new=%llu\n",
+                      scaled_ppm, (unsigned long long)data->scaled_freq);
+
+       mutex_unlock(&mvm->mutex);
+       return 0;
+}
+
 /* iwl_mvm_ptp_init - initialize PTP for devices which support it.
  * @mvm: internal mvm structure, see &struct iwl_mvm.
  *
        mvm->ptp_data.ptp_clock_info.max_adj = 0x7fffffff;
        mvm->ptp_data.ptp_clock_info.getcrosststamp =
                                        iwl_mvm_phc_get_crosstimestamp;
+       mvm->ptp_data.ptp_clock_info.adjfine = iwl_mvm_ptp_adjfine;
+       mvm->ptp_data.ptp_clock_info.adjtime = iwl_mvm_ptp_adjtime;
+       mvm->ptp_data.ptp_clock_info.gettime64 = iwl_mvm_ptp_gettime;
+       mvm->ptp_data.scaled_freq = SCALE_FACTOR;
 
        /* Give a short 'friendly name' to identify the PHC clock */
        snprintf(mvm->ptp_data.ptp_clock_info.name,
 
        struct sk_buff *skb =
                iwl_mvm_time_sync_find_skb(mvm, notif->peer_addr,
                                           le32_to_cpu(notif->dialog_token));
+       u64 adj_time;
 
        if (!skb) {
                IWL_DEBUG_INFO(mvm, "Time sync event but no pending skb\n");
                return;
        }
 
-       ts_10ns = iwl_mvm_get_64_bit(notif->t3_hi, notif->t3_lo);
-       rx_status = IEEE80211_SKB_RXCB(skb);
-       rx_status->ack_tx_hwtstamp = ktime_set(0, ts_10ns * 10);
-
        ts_10ns = iwl_mvm_get_64_bit(notif->t2_hi, notif->t2_lo);
+       adj_time = iwl_mvm_ptp_get_adj_time(mvm, ts_10ns * 10);
        shwt = skb_hwtstamps(skb);
-       shwt->hwtstamp = ktime_set(0, ts_10ns * 10);
+       shwt->hwtstamp = ktime_set(0, adj_time);
+
+       ts_10ns = iwl_mvm_get_64_bit(notif->t3_hi, notif->t3_lo);
+       adj_time = iwl_mvm_ptp_get_adj_time(mvm, ts_10ns * 10);
+       rx_status = IEEE80211_SKB_RXCB(skb);
+       rx_status->ack_tx_hwtstamp = ktime_set(0, adj_time);
 
        IWL_DEBUG_INFO(mvm,
                       "Time sync: RX event - report frame t2=%llu t3=%llu\n",
        struct iwl_time_msmt_cfm_notify *notif = (void *)pkt->data;
        struct ieee80211_tx_status status = {};
        struct skb_shared_hwtstamps *shwt;
-       u64 ts_10ns;
+       u64 ts_10ns, adj_time;
 
        status.skb =
                iwl_mvm_time_sync_find_skb(mvm, notif->peer_addr,
                return;
        }
 
-       status.info = IEEE80211_SKB_CB(status.skb);
-
-       ts_10ns = iwl_mvm_get_64_bit(notif->t4_hi, notif->t4_lo);
-       status.ack_hwtstamp = ktime_set(0, ts_10ns * 10);
-
        ts_10ns = iwl_mvm_get_64_bit(notif->t1_hi, notif->t1_lo);
+       adj_time = iwl_mvm_ptp_get_adj_time(mvm, ts_10ns * 10);
        shwt = skb_hwtstamps(status.skb);
-       shwt->hwtstamp = ktime_set(0, ts_10ns * 10);
+       shwt->hwtstamp = ktime_set(0, adj_time);
+
+       ts_10ns = iwl_mvm_get_64_bit(notif->t4_hi, notif->t4_lo);
+       adj_time = iwl_mvm_ptp_get_adj_time(mvm, ts_10ns * 10);
+       status.info = IEEE80211_SKB_CB(status.skb);
+       status.ack_hwtstamp = ktime_set(0, adj_time);
 
        IWL_DEBUG_INFO(mvm,
                       "Time sync: TX event - report frame t1=%llu t4=%llu\n",