#include <linux/clk.h>
 #include <linux/delay.h>
 #include <linux/io.h>
+#include <linux/iopoll.h>
 #include <linux/kernel.h>
 #include <linux/module.h>
 #include <linux/of.h>
        struct watchdog_device wdev;
        unsigned long clk_rate;
        u8 cks;
+       struct clk *clk;
 };
 
 static void rwdt_write(struct rwdt_priv *priv, u32 val, unsigned int reg)
        return DIV_BY_CLKS_PER_SEC(priv, 65536 - val);
 }
 
+/* needs to be atomic - no RPM, no usleep_range, no scheduling! */
 static int rwdt_restart(struct watchdog_device *wdev, unsigned long action,
                        void *data)
 {
        struct rwdt_priv *priv = watchdog_get_drvdata(wdev);
+       u8 val;
+
+       clk_prepare_enable(priv->clk);
+
+       /* Stop the timer before we modify any register */
+       val = readb_relaxed(priv->base + RWTCSRA) & ~RWTCSRA_TME;
+       rwdt_write(priv, val, RWTCSRA);
+       /* Delay 2 cycles before setting watchdog counter */
+       udelay(DIV_ROUND_UP(2 * 1000000, priv->clk_rate));
 
-       rwdt_start(wdev);
        rwdt_write(priv, 0xffff, RWTCNT);
+       /* smallest divider to reboot soon */
+       rwdt_write(priv, 0, RWTCSRA);
+
+       readb_poll_timeout_atomic(priv->base + RWTCSRA, val,
+                                 !(val & RWTCSRA_WRFLG), 1, 100);
+
+       rwdt_write(priv, RWTCSRA_TME, RWTCSRA);
+
        return 0;
 }
 
 {
        struct device *dev = &pdev->dev;
        struct rwdt_priv *priv;
-       struct clk *clk;
        unsigned long clks_per_sec;
        int ret, i;
        u8 csra;
        if (IS_ERR(priv->base))
                return PTR_ERR(priv->base);
 
-       clk = devm_clk_get(dev, NULL);
-       if (IS_ERR(clk))
-               return PTR_ERR(clk);
+       priv->clk = devm_clk_get(dev, NULL);
+       if (IS_ERR(priv->clk))
+               return PTR_ERR(priv->clk);
 
        pm_runtime_enable(dev);
        pm_runtime_get_sync(dev);
-       priv->clk_rate = clk_get_rate(clk);
+       priv->clk_rate = clk_get_rate(priv->clk);
        csra = readb_relaxed(priv->base + RWTCSRA);
        priv->wdev.bootstatus = csra & RWTCSRA_WOVF ? WDIOF_CARDRESET : 0;
        pm_runtime_put(dev);