#include <linux/list.h>
 #include <linux/cpu.h>
 #include <linux/smp.h>
+#include <linux/bug.h>
 
 #include <linux/hw_breakpoint.h>
 /*
        mutex_unlock(&nr_bp_mutex);
 }
 
+static int __modify_bp_slot(struct perf_event *bp, u64 old_type)
+{
+       int err;
+
+       __release_bp_slot(bp, old_type);
+
+       err = __reserve_bp_slot(bp, bp->attr.bp_type);
+       if (err) {
+               /*
+                * Reserve the old_type slot back in case
+                * there's no space for the new type.
+                *
+                * This must succeed, because we just released
+                * the old_type slot in the __release_bp_slot
+                * call above. If not, something is broken.
+                */
+               WARN_ON(__reserve_bp_slot(bp, old_type));
+       }
+
+       return err;
+}
+
+static int modify_bp_slot(struct perf_event *bp, u64 old_type)
+{
+       int ret;
+
+       mutex_lock(&nr_bp_mutex);
+       ret = __modify_bp_slot(bp, old_type);
+       mutex_unlock(&nr_bp_mutex);
+       return ret;
+}
+
 /*
  * Allow the kernel debugger to reserve breakpoint slots without
  * taking a lock using the dbg_* variant of for the reserve and
        u64 old_addr = bp->attr.bp_addr;
        u64 old_len = bp->attr.bp_len;
        int old_type = bp->attr.bp_type;
+       bool modify = attr->bp_type != old_type;
        int err = 0;
 
        /*
        bp->attr.bp_type = attr->bp_type;
        bp->attr.bp_len = attr->bp_len;
 
-       if (attr->disabled)
-               goto end;
-
        err = validate_hw_breakpoint(bp);
-       if (!err)
-               perf_event_enable(bp);
+       if (!err && modify)
+               err = modify_bp_slot(bp, old_type);
 
        if (err) {
                bp->attr.bp_addr = old_addr;
                return err;
        }
 
-end:
-       bp->attr.disabled = attr->disabled;
+       if (!attr->disabled)
+               perf_event_enable(bp);
 
+       bp->attr.disabled = attr->disabled;
        return 0;
 }
 EXPORT_SYMBOL_GPL(modify_user_hw_breakpoint);