{
        struct mm_struct *mm = current->mm;
        struct vm_area_struct *vma;
-       struct gru_thread_state *gts = NULL;
+       struct gru_thread_state *gts = ERR_PTR(-EINVAL);
 
        down_write(&mm->mmap_sem);
        vma = gru_find_vma(vaddr);
-       if (vma)
-               gts = gru_alloc_thread_state(vma, TSID(vaddr, vma));
-       if (!IS_ERR(gts)) {
-               mutex_lock(>s->ts_ctxlock);
-               downgrade_write(&mm->mmap_sem);
-       } else {
-               up_write(&mm->mmap_sem);
-       }
+       if (!vma)
+               goto err;
+
+       gts = gru_alloc_thread_state(vma, TSID(vaddr, vma));
+       if (IS_ERR(gts))
+               goto err;
+       mutex_lock(>s->ts_ctxlock);
+       downgrade_write(&mm->mmap_sem);
+       return gts;
 
+err:
+       up_write(&mm->mmap_sem);
        return gts;
 }