* pass on to access_ok(), for instance.
  */
 #define untagged_addr(addr)    \
-       ((__typeof__(addr))sign_extend64((u64)(addr), 55))
+       ((__typeof__(addr))sign_extend64((__force u64)(addr), 55))
 
 #ifdef CONFIG_KASAN_SW_TAGS
 #define __tag_shifted(tag)     ((u64)(tag) << 56)
 
 {
        unsigned long ret, limit = current_thread_info()->addr_limit;
 
+       addr = untagged_addr(addr);
+
        __chk_user_ptr(addr);
        asm volatile(
        // A + B <= C + 1 for all A,B,C, in four easy steps:
 
 /*
  * Sanitise a uaccess pointer such that it becomes NULL if above the
- * current addr_limit.
+ * current addr_limit. In case the pointer is tagged (has the top byte set),
+ * untag the pointer before checking.
  */
 #define uaccess_mask_ptr(ptr) (__typeof__(ptr))__uaccess_mask_ptr(ptr)
 static inline void __user *__uaccess_mask_ptr(const void __user *ptr)
        void __user *safe_ptr;
 
        asm volatile(
-       "       bics    xzr, %1, %2\n"
+       "       bics    xzr, %3, %2\n"
        "       csel    %0, %1, xzr, eq\n"
        : "=&r" (safe_ptr)
-       : "r" (ptr), "r" (current_thread_info()->addr_limit)
+       : "r" (ptr), "r" (current_thread_info()->addr_limit),
+         "r" (untagged_addr(ptr))
        : "cc");
 
        csdb();