* Copyright (C) 2002 Andi Kleen
  *
  * This handles calls from both 32bit and 64bit mode.
+ *
+ * Lock order:
+ *     contex.ldt_usr_sem
+ *       mmap_sem
+ *         context.lock
  */
 
 #include <linux/errno.h>
 #endif
 }
 
-/* context.lock is held for us, so we don't need any locking. */
+/* context.lock is held by the task which issued the smp function call */
 static void flush_ldt(void *__mm)
 {
        struct mm_struct *mm = __mm;
        paravirt_alloc_ldt(ldt->entries, ldt->nr_entries);
 }
 
-/* context.lock is held */
-static void install_ldt(struct mm_struct *current_mm,
-                       struct ldt_struct *ldt)
+static void install_ldt(struct mm_struct *mm, struct ldt_struct *ldt)
 {
+       mutex_lock(&mm->context.lock);
+
        /* Synchronizes with READ_ONCE in load_mm_ldt. */
-       smp_store_release(¤t_mm->context.ldt, ldt);
+       smp_store_release(&mm->context.ldt, ldt);
 
-       /* Activate the LDT for all CPUs using current_mm. */
-       on_each_cpu_mask(mm_cpumask(current_mm), flush_ldt, current_mm, true);
+       /* Activate the LDT for all CPUs using currents mm. */
+       on_each_cpu_mask(mm_cpumask(mm), flush_ldt, mm, true);
+
+       mutex_unlock(&mm->context.lock);
 }
 
 static void free_ldt_struct(struct ldt_struct *ldt)
        struct mm_struct *old_mm;
        int retval = 0;
 
-       mutex_init(&mm->context.lock);
+       init_rwsem(&mm->context.ldt_usr_sem);
+
        old_mm = current->mm;
        if (!old_mm) {
                mm->context.ldt = NULL;
        unsigned long entries_size;
        int retval;
 
-       mutex_lock(&mm->context.lock);
+       down_read(&mm->context.ldt_usr_sem);
 
        if (!mm->context.ldt) {
                retval = 0;
        retval = bytecount;
 
 out_unlock:
-       mutex_unlock(&mm->context.lock);
+       up_read(&mm->context.ldt_usr_sem);
        return retval;
 }
 
                        ldt.avl = 0;
        }
 
-       mutex_lock(&mm->context.lock);
+       if (down_write_killable(&mm->context.ldt_usr_sem))
+               return -EINTR;
 
        old_ldt       = mm->context.ldt;
        old_nr_entries = old_ldt ? old_ldt->nr_entries : 0;
        error = 0;
 
 out_unlock:
-       mutex_unlock(&mm->context.lock);
+       up_write(&mm->context.ldt_usr_sem);
 out:
        return error;
 }