/*
  * Allocate a thread state structure.
  */
-static struct gru_thread_state *gru_alloc_gts(struct vm_area_struct *vma,
-                                             struct gru_vma_data *vdata,
-                                             int tsid)
+struct gru_thread_state *gru_alloc_gts(struct vm_area_struct *vma,
+               int cbr_au_count, int dsr_au_count, int options, int tsid)
 {
        struct gru_thread_state *gts;
        int bytes;
 
-       bytes = DSR_BYTES(vdata->vd_dsr_au_count) +
-                               CBR_BYTES(vdata->vd_cbr_au_count);
+       bytes = DSR_BYTES(dsr_au_count) + CBR_BYTES(cbr_au_count);
        bytes += sizeof(struct gru_thread_state);
        gts = kzalloc(bytes, GFP_KERNEL);
        if (!gts)
        STAT(gts_alloc);
        atomic_set(&gts->ts_refcnt, 1);
        mutex_init(&gts->ts_ctxlock);
-       gts->ts_cbr_au_count = vdata->vd_cbr_au_count;
-       gts->ts_dsr_au_count = vdata->vd_dsr_au_count;
-       gts->ts_user_options = vdata->vd_user_options;
+       gts->ts_cbr_au_count = cbr_au_count;
+       gts->ts_dsr_au_count = dsr_au_count;
+       gts->ts_user_options = options;
        gts->ts_tsid = tsid;
-       gts->ts_user_options = vdata->vd_user_options;
        gts->ts_ctxnum = NULLCTX;
-       gts->ts_mm = current->mm;
-       gts->ts_vma = vma;
        gts->ts_tlb_int_select = -1;
-       gts->ts_gms = gru_register_mmu_notifier();
        gts->ts_sizeavail = GRU_SIZEAVAIL(PAGE_SHIFT);
-       if (!gts->ts_gms)
-               goto err;
+       if (vma) {
+               gts->ts_mm = current->mm;
+               gts->ts_vma = vma;
+               gts->ts_gms = gru_register_mmu_notifier();
+               if (!gts->ts_gms)
+                       goto err;
+       }
 
-       gru_dbg(grudev, "alloc vdata %p, new gts %p\n", vdata, gts);
+       gru_dbg(grudev, "alloc gts %p\n", gts);
        return gts;
 
 err:
        struct gru_vma_data *vdata = vma->vm_private_data;
        struct gru_thread_state *gts, *ngts;
 
-       gts = gru_alloc_gts(vma, vdata, tsid);
+       gts = gru_alloc_gts(vma, vdata->vd_cbr_au_count, vdata->vd_dsr_au_count,
+                           vdata->vd_user_options, tsid);
        if (!gts)
                return NULL;
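
With the vma argument now optional, gru_alloc_gts() can also be called for contexts that have no user mapping. A minimal sketch of such a caller, assuming purely illustrative resource counts, options and tsid (none of the values below come from the driver):

	/* Hypothetical caller: allocate a gts with no backing vma. */
	static struct gru_thread_state *example_alloc_vmaless_gts(void)
	{
		/* NULL vma: ts_mm/ts_vma stay unset and no mmu notifier is
		 * registered; 1 CBR AU, 1 DSR AU, options 0 and tsid 0 are
		 * placeholder values only.
		 */
		return gru_alloc_gts(NULL, 1, 1, 0, 0);
	}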
 
 #define next_gru(b, g) (((g) < &(b)->bs_grus[GRU_CHIPLETS_PER_BLADE - 1]) ?  \
                                 ((g)+1) : &(b)->bs_grus[0])
 
-static void gru_steal_context(struct gru_thread_state *gts)
+static void gru_steal_context(struct gru_thread_state *gts, int blade_id)
 {
        struct gru_blade_state *blade;
        struct gru_state *gru, *gru0;
        cbr = gts->ts_cbr_au_count;
        dsr = gts->ts_dsr_au_count;
 
-       preempt_disable();
-       blade = gru_base[uv_numa_blade_id()];
+       blade = gru_base[blade_id];
        spin_lock(&blade->bs_lock);
 
        ctxnum = next_ctxnum(blade->bs_lru_ctxnum);
        blade->bs_lru_gru = gru;
        blade->bs_lru_ctxnum = ctxnum;
        spin_unlock(&blade->bs_lock);
-       preempt_enable();
 
        if (ngts) {
                STAT(steal_context);
 /*
  * Scan the GRUs on the local blade & assign a GRU context.
  */
-static struct gru_state *gru_assign_gru_context(struct gru_thread_state *gts)
+static struct gru_state *gru_assign_gru_context(struct gru_thread_state *gts,
+                                               int blade)
 {
        struct gru_state *gru, *grux;
        int i, max_active_contexts;
 
-       preempt_disable();
 
 again:
        gru = NULL;
        max_active_contexts = GRU_NUM_CCH;
-       for_each_gru_on_blade(grux, uv_numa_blade_id(), i) {
+       for_each_gru_on_blade(grux, blade, i) {
                if (check_gru_resources(grux, gts->ts_cbr_au_count,
                                        gts->ts_dsr_au_count,
                                        max_active_contexts)) {
                STAT(assign_context_failed);
        }
 
-       preempt_enable();
        return gru;
 }
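
gru_assign_gru_context() and gru_steal_context() no longer call preempt_disable()/uv_numa_blade_id() themselves; the caller is expected to pin itself to a blade and pass the blade id down, as the fault-handler hunk below does. A minimal sketch of that convention (hypothetical helper; the real back-off and steal path is omitted):

	/* Hypothetical caller illustrating the new blade_id contract. */
	static void example_assign_on_local_blade(struct gru_thread_state *gts)
	{
		int blade_id;

		preempt_disable();		/* keep the task on this blade */
		blade_id = uv_numa_blade_id();
		if (gru_assign_gru_context(gts, blade_id))
			gru_load_context(gts);
		preempt_enable();
		/* on failure the real code backs off and may call
		 * gru_steal_context(gts, blade_id) before retrying
		 */
	}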
 
 {
        struct gru_thread_state *gts;
        unsigned long paddr, vaddr;
+       int blade_id;
 
        vaddr = (unsigned long)vmf->virtual_address;
        gru_dbg(grudev, "vma %p, vaddr 0x%lx (0x%lx)\n",
 again:
        mutex_lock(&gts->ts_ctxlock);
        preempt_disable();
+       blade_id = uv_numa_blade_id();
+
        if (gts->ts_gru) {
-               if (gts->ts_gru->gs_blade_id != uv_numa_blade_id()) {
+               if (gts->ts_gru->gs_blade_id != blade_id) {
                        STAT(migrated_nopfn_unload);
                        gru_unload_context(gts, 1);
                } else {
        }
 
        if (!gts->ts_gru) {
-               if (!gru_assign_gru_context(gts)) {
-                       mutex_unlock(&gts->ts_ctxlock);
+               if (!gru_assign_gru_context(gts, blade_id)) {
                        preempt_enable();
+                       mutex_unlock(&gts->ts_ctxlock);
+                       set_current_state(TASK_INTERRUPTIBLE);
                        schedule_timeout(GRU_ASSIGN_DELAY);  /* true hack ZZZ */
+                       blade_id = uv_numa_blade_id();
                        if (gts->ts_steal_jiffies + GRU_STEAL_DELAY < jiffies)
-                               gru_steal_context(gts);
+                               gru_steal_context(gts, blade_id);
                        goto again;
                }
                gru_load_context(gts);
                                vma->vm_page_prot);
        }
 
-       mutex_unlock(&gts->ts_ctxlock);
        preempt_enable();
+       mutex_unlock(&gts->ts_ctxlock);
 
        return VM_FAULT_NOPAGE;
 }
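
A note on the two lines added to the retry path above: schedule_timeout() only sleeps if the task state has been changed first; called while still in TASK_RUNNING it returns immediately, so setting TASK_INTERRUPTIBLE makes the GRU_ASSIGN_DELAY back-off real. A small sketch of the idiom (hypothetical helper, jiffies-based timeout assumed):

	/* Sleep for 'timeout' jiffies, waking early on a signal. */
	static void example_assign_delay(long timeout)
	{
		set_current_state(TASK_INTERRUPTIBLE);
		schedule_timeout(timeout);	/* sleeps only because of the line above */

		/* schedule_timeout_interruptible(timeout) is the stock
		 * one-line equivalent.
		 */
	}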