#include <linux/delay.h>
 #include <linux/sizes.h>
 #include <linux/ntb.h>
+#include <linux/mutex.h>
 
 #define DRIVER_NAME            "ntb_perf"
 #define DRIVER_DESCRIPTION     "PCIe NTB Performance Measurement Tool"
        int                     dma_prep_err;
        int                     src_idx;
        void                    *srcs[MAX_SRCS];
+       wait_queue_head_t       *wq;
 };
 
 struct perf_ctx {
        struct dentry           *debugfs_run;
        struct dentry           *debugfs_threads;
        u8                      perf_threads;
-       bool                    run;
+       /* mutex ensures only one set of threads run at once */
+       struct mutex            run_mutex;
        struct pthr_ctx         pthr_ctx[MAX_THREADS];
        atomic_t                tsync;
+       atomic_t                tdone;
 };
 
 enum {
                        set_current_state(TASK_INTERRUPTIBLE);
                        schedule_timeout(1);
                }
+
+               if (unlikely(kthread_should_stop()))
+                       break;
        }
 
        if (use_dma) {
                pr_info("%s: All DMA descriptors submitted\n", current->comm);
-               while (atomic_read(&pctx->dma_sync) != 0)
+               while (atomic_read(&pctx->dma_sync) != 0) {
+                       if (kthread_should_stop())
+                               break;
                        msleep(20);
+               }
        }
 
        kstop = ktime_get();
                pctx->srcs[i] = NULL;
        }
 
-       return 0;
+       atomic_inc(&perf->tdone);
+       wake_up(pctx->wq);
+       rc = 0;
+       goto done;
 
 err:
        for (i = 0; i < MAX_SRCS; i++) {
                pctx->dma_chan = NULL;
        }
 
+done:
+       /* Wait until we are told to stop */
+       for (;;) {
+               set_current_state(TASK_INTERRUPTIBLE);
+               if (kthread_should_stop())
+                       break;
+               schedule();
+       }
+       __set_current_state(TASK_RUNNING);
+
        return rc;
 }
 
        struct perf_ctx *perf = filp->private_data;
        char *buf;
        ssize_t ret, out_offset;
+       int running;
 
        if (!perf)
                return 0;
        buf = kmalloc(64, GFP_KERNEL);
        if (!buf)
                return -ENOMEM;
-       out_offset = snprintf(buf, 64, "%d\n", perf->run);
+
+       running = mutex_is_locked(&perf->run_mutex);
+       out_offset = snprintf(buf, 64, "%d\n", running);
        ret = simple_read_from_buffer(ubuf, count, offp, buf, out_offset);
        kfree(buf);
 
        struct pthr_ctx *pctx;
        int i;
 
-       perf->run = false;
        for (i = 0; i < MAX_THREADS; i++) {
                pctx = &perf->pthr_ctx[i];
                if (pctx->thread) {
 {
        struct perf_ctx *perf = filp->private_data;
        int node, i;
+       DECLARE_WAIT_QUEUE_HEAD(wq);
 
        if (!perf->link_is_up)
-               return 0;
+               return -ENOLINK;
 
        if (perf->perf_threads == 0)
-               return 0;
+               return -EINVAL;
 
-       if (atomic_read(&perf->tsync) == 0)
-               perf->run = false;
+       if (!mutex_trylock(&perf->run_mutex))
+               return -EBUSY;
 
-       if (perf->run)
-               threads_cleanup(perf);
-       else {
-               perf->run = true;
+       if (perf->perf_threads > MAX_THREADS) {
+               perf->perf_threads = MAX_THREADS;
+               pr_info("Reset total threads to: %u\n", MAX_THREADS);
+       }
 
-               if (perf->perf_threads > MAX_THREADS) {
-                       perf->perf_threads = MAX_THREADS;
-                       pr_info("Reset total threads to: %u\n", MAX_THREADS);
-               }
+       /* no greater than 1M */
+       if (seg_order > MAX_SEG_ORDER) {
+               seg_order = MAX_SEG_ORDER;
+               pr_info("Fix seg_order to %u\n", seg_order);
+       }
 
-               /* no greater than 1M */
-               if (seg_order > MAX_SEG_ORDER) {
-                       seg_order = MAX_SEG_ORDER;
-                       pr_info("Fix seg_order to %u\n", seg_order);
-               }
+       if (run_order < seg_order) {
+               run_order = seg_order;
+               pr_info("Fix run_order to %u\n", run_order);
+       }
 
-               if (run_order < seg_order) {
-                       run_order = seg_order;
-                       pr_info("Fix run_order to %u\n", run_order);
-               }
+       node = dev_to_node(&perf->ntb->pdev->dev);
+       atomic_set(&perf->tdone, 0);
 
-               node = dev_to_node(&perf->ntb->pdev->dev);
-               /* launch kernel thread */
-               for (i = 0; i < perf->perf_threads; i++) {
-                       struct pthr_ctx *pctx;
-
-                       pctx = &perf->pthr_ctx[i];
-                       atomic_set(&pctx->dma_sync, 0);
-                       pctx->perf = perf;
-                       pctx->thread =
-                               kthread_create_on_node(ntb_perf_thread,
-                                                      (void *)pctx,
-                                                      node, "ntb_perf %d", i);
-                       if (IS_ERR(pctx->thread)) {
-                               pctx->thread = NULL;
-                               goto err;
-                       } else
-                               wake_up_process(pctx->thread);
-
-                       if (perf->run == false)
-                               return -ENXIO;
-               }
+       /* launch kernel thread */
+       for (i = 0; i < perf->perf_threads; i++) {
+               struct pthr_ctx *pctx;
 
+               pctx = &perf->pthr_ctx[i];
+               atomic_set(&pctx->dma_sync, 0);
+               pctx->perf = perf;
+               pctx->wq = &wq;
+               pctx->thread =
+                       kthread_create_on_node(ntb_perf_thread,
+                                              (void *)pctx,
+                                              node, "ntb_perf %d", i);
+               if (IS_ERR(pctx->thread)) {
+                       pctx->thread = NULL;
+                       goto err;
+               } else {
+                       wake_up_process(pctx->thread);
+               }
        }
 
+       wait_event_interruptible(wq,
+               atomic_read(&perf->tdone) == perf->perf_threads);
+
+       threads_cleanup(perf);
+       mutex_unlock(&perf->run_mutex);
        return count;
 
 err:
        threads_cleanup(perf);
+       mutex_unlock(&perf->run_mutex);
        return -ENXIO;
 }
 
        perf->ntb = ntb;
        perf->perf_threads = 1;
        atomic_set(&perf->tsync, 0);
-       perf->run = false;
+       mutex_init(&perf->run_mutex);
        spin_lock_init(&perf->db_lock);
        perf_setup_mw(ntb, perf);
        INIT_DELAYED_WORK(&perf->link_work, perf_link_work);
 
        dev_dbg(&perf->ntb->dev, "%s called\n", __func__);
 
+       mutex_lock(&perf->run_mutex);
+
        cancel_delayed_work_sync(&perf->link_work);
        cancel_work_sync(&perf->link_cleanup);