#include <linux/init.h>
 #include <linux/io.h>
 #include <linux/jiffies.h>
+#include <linux/notifier.h>
 #include <linux/of_address.h>
 #include <linux/of_irq.h>
 #include <linux/of_platform.h>
 #include <linux/smp.h>
 #include <linux/mfd/syscon.h>
 
-#include <asm/system_misc.h>
-
 #define RESET_SOURCE_ENABLE_REG 1
 #define SW_MASTER_RESET_REG 2
 
 static u32 rst_src_en;
 static u32 sw_mstr_rst;
 
-static void brcmstb_reboot(enum reboot_mode mode, const char *cmd)
+static int brcmstb_restart_handler(struct notifier_block *this,
+                                  unsigned long mode, void *cmd)
 {
        int rc;
        u32 tmp;
        rc = regmap_write(regmap, rst_src_en, 1);
        if (rc) {
                pr_err("failed to write rst_src_en (%d)\n", rc);
-               return;
+               return NOTIFY_DONE;
        }
 
        rc = regmap_read(regmap, rst_src_en, &tmp);
        if (rc) {
                pr_err("failed to read rst_src_en (%d)\n", rc);
-               return;
+               return NOTIFY_DONE;
        }
 
        rc = regmap_write(regmap, sw_mstr_rst, 1);
        if (rc) {
                pr_err("failed to write sw_mstr_rst (%d)\n", rc);
-               return;
+               return NOTIFY_DONE;
        }
 
        rc = regmap_read(regmap, sw_mstr_rst, &tmp);
        if (rc) {
                pr_err("failed to read sw_mstr_rst (%d)\n", rc);
-               return;
+               return NOTIFY_DONE;
        }
 
        while (1)
                ;
+
+       return NOTIFY_DONE;
 }
 
+static struct notifier_block brcmstb_restart_nb = {
+       .notifier_call = brcmstb_restart_handler,
+       .priority = 128,
+};
+
 static int brcmstb_reboot_probe(struct platform_device *pdev)
 {
        int rc;
                return -EINVAL;
        }
 
-       arm_pm_restart = brcmstb_reboot;
+       rc = register_restart_handler(&brcmstb_restart_nb);
+       if (rc)
+               dev_err(&pdev->dev,
+                       "cannot register restart handler (err=%d)\n", rc);
 
-       return 0;
+       return rc;
 }
 
 static const struct of_device_id of_match[] = {