#include <linux/bug.h>
 #include <asm/asm.h>
 
-void __xchg_called_with_bad_pointer(void);
-
-static __always_inline unsigned long
-__arch_xchg(unsigned long x, unsigned long address, int size)
-{
-       unsigned long old;
-       int shift;
-
-       switch (size) {
-       case 1:
-               shift = (3 ^ (address & 3)) << 3;
-               address ^= address & 3;
-               asm volatile(
-                       "       l       %0,%1\n"
-                       "0:     lr      0,%0\n"
-                       "       nr      0,%3\n"
-                       "       or      0,%2\n"
-                       "       cs      %0,0,%1\n"
-                       "       jl      0b\n"
-                       : "=&d" (old), "+Q" (*(int *) address)
-                       : "d" ((x & 0xff) << shift), "d" (~(0xff << shift))
-                       : "memory", "cc", "0");
-               return old >> shift;
-       case 2:
-               shift = (2 ^ (address & 2)) << 3;
-               address ^= address & 2;
-               asm volatile(
-                       "       l       %0,%1\n"
-                       "0:     lr      0,%0\n"
-                       "       nr      0,%3\n"
-                       "       or      0,%2\n"
-                       "       cs      %0,0,%1\n"
-                       "       jl      0b\n"
-                       : "=&d" (old), "+Q" (*(int *) address)
-                       : "d" ((x & 0xffff) << shift), "d" (~(0xffff << shift))
-                       : "memory", "cc", "0");
-               return old >> shift;
-       case 4:
-               asm volatile(
-                       "       l       %0,%1\n"
-                       "0:     cs      %0,%2,%1\n"
-                       "       jl      0b\n"
-                       : "=&d" (old), "+Q" (*(int *) address)
-                       : "d" (x)
-                       : "memory", "cc");
-               return old;
-       case 8:
-               asm volatile(
-                       "       lg      %0,%1\n"
-                       "0:     csg     %0,%2,%1\n"
-                       "       jl      0b\n"
-                       : "=&d" (old), "+QS" (*(long *) address)
-                       : "d" (x)
-                       : "memory", "cc");
-               return old;
-       }
-       __xchg_called_with_bad_pointer();
-       return x;
-}
-
-#define arch_xchg(ptr, x)                                              \
-({                                                                     \
-       __typeof__(*(ptr)) __ret;                                       \
-                                                                       \
-       __ret = (__typeof__(*(ptr)))                                    \
-               __arch_xchg((unsigned long)(x), (unsigned long)(ptr),   \
-                           sizeof(*(ptr)));                            \
-       __ret;                                                          \
-})
-
 void __cmpxchg_called_with_bad_pointer(void);
 
 static __always_inline u32 __cs_asm(u64 ptr, u32 old, u32 new)
        likely(__cc == 0);                                              \
 })
 
+#else /* __HAVE_ASM_FLAG_OUTPUTS__ */
+
+/*
+ * Fallback arch_try_cmpxchg(), provided by the #else branch of the
+ * conditional labelled __HAVE_ASM_FLAG_OUTPUTS__ and built on arch_cmpxchg().
+ *
+ * Passes the expected value *oldp and the replacement @new to
+ * arch_cmpxchg() for @ptr. If the value arch_cmpxchg() returns differs from
+ * the expected one, that value is stored back into *oldp so the caller can
+ * retry with it. The macro evaluates to non-zero when the returned value
+ * equals the expected one, and to zero otherwise.
+ *
+ * @oldp is cast to the type of @ptr; @new is evaluated once.
+ */
+#define arch_try_cmpxchg(ptr, oldp, new)                               \
+({                                                                     \
+       __typeof__((ptr)) __oldp = (__typeof__(ptr))(oldp);             \
+       __typeof__(*(ptr)) __old = *__oldp;                             \
+       __typeof__(*(ptr)) __new = (new);                               \
+       __typeof__(*(ptr)) __prev;                                      \
+                                                                       \
+       __prev = arch_cmpxchg((ptr), (__old), (__new));                 \
+       if (unlikely(__prev != __old))                                  \
+               *__oldp = __prev;                                       \
+       likely(__prev == __old);                                        \
+})
+
+#endif /* __HAVE_ASM_FLAG_OUTPUTS__ */
+
 #define arch_try_cmpxchg64             arch_try_cmpxchg
 #define arch_try_cmpxchg_local         arch_try_cmpxchg
 #define arch_try_cmpxchg64_local       arch_try_cmpxchg
 
-#endif /* __HAVE_ASM_FLAG_OUTPUTS__ */
+void __xchg_called_with_bad_pointer(void);
+
+/*
+ * Exchange the byte at address @ptr with @x and return the previous byte.
+ *
+ * The update is performed on the 4-byte aligned 32-bit word that contains
+ * the byte: the word is read once, the target byte lane is replaced, and
+ * the new word is installed with arch_try_cmpxchg(). The loop repeats until
+ * arch_try_cmpxchg() reports success.
+ */
+static inline u8 __arch_xchg1(u64 ptr, u8 x)
+{
+       /* Byte offset 0 in the word maps to shift 24, offset 3 to shift 0. */
+       int shift = (3 ^ (ptr & 3)) << 3;
+       u32 mask, old, new;
+
+       /* Round the address down to the containing 32-bit word. */
+       ptr &= ~0x3;
+       /* Unsigned constant: 0xff << 24 is not representable in int. */
+       mask = ~(0xffU << shift);
+       old = READ_ONCE(*(u32 *)ptr);
+       do {
+               new = old & mask;
+               /* Widen to u32 first so the shift is never done on int. */
+               new |= (u32)x << shift;
+       } while (!arch_try_cmpxchg((u32 *)ptr, &old, new));
+       /* Truncation to u8 selects the byte that was replaced. */
+       return old >> shift;
+}
+
+/*
+ * Exchange the 16-bit value at address @ptr with @x and return the
+ * previous value.
+ *
+ * The update is performed on the 4-byte aligned 32-bit word that contains
+ * the halfword: the word is read once, the target halfword lane is
+ * replaced, and the new word is installed with arch_try_cmpxchg(). The loop
+ * repeats until arch_try_cmpxchg() reports success.
+ */
+static inline u16 __arch_xchg2(u64 ptr, u16 x)
+{
+       /* Halfword at offset 0 maps to shift 16, at offset 2 to shift 0. */
+       int shift = (2 ^ (ptr & 2)) << 3;
+       u32 mask, old, new;
+
+       /* Round the address down to the containing 32-bit word. */
+       ptr &= ~0x3;
+       /* Unsigned constant: 0xffff << 16 is not representable in int. */
+       mask = ~(0xffffU << shift);
+       old = READ_ONCE(*(u32 *)ptr);
+       do {
+               new = old & mask;
+               /* Widen to u32 first so the shift is never done on int. */
+               new |= (u32)x << shift;
+       } while (!arch_try_cmpxchg((u32 *)ptr, &old, new));
+       /* Truncation to u16 selects the halfword that was replaced. */
+       return old >> shift;
+}
+
+/*
+ * Size-dispatching helper behind arch_xchg(): store @x at address @ptr and
+ * return the value that was found there.
+ *
+ * Sizes 1 and 2 are handed to __arch_xchg1() and __arch_xchg2(). Sizes 4
+ * and 8 read the current value once and then spin on arch_try_cmpxchg()
+ * until it succeeds. Any other @size falls through to
+ * __xchg_called_with_bad_pointer() and returns @x.
+ */
+static __always_inline u64 __arch_xchg(u64 ptr, u64 x, int size)
+{
+       switch (size) {
+       case 1:
+               return __arch_xchg1(ptr, x & 0xff);
+       case 2:
+               return __arch_xchg2(ptr, x & 0xffff);
+       case 4: {
+               u32 *word = (u32 *)ptr;
+               u32 seen = READ_ONCE(*word);
+
+               /* Empty body: the retry state lives entirely in 'seen'. */
+               while (!arch_try_cmpxchg(word, &seen, x & 0xffffffff))
+                       ;
+               return seen;
+       }
+       case 8: {
+               u64 *dword = (u64 *)ptr;
+               u64 seen = READ_ONCE(*dword);
+
+               /* Empty body: the retry state lives entirely in 'seen'. */
+               while (!arch_try_cmpxchg(dword, &seen, x))
+                       ;
+               return seen;
+       }
+       }
+       __xchg_called_with_bad_pointer();
+       return x;
+}
+
+/*
+ * arch_xchg(ptr, x) - hand @ptr and @x to __arch_xchg() and evaluate to its
+ * return value cast to the type of *ptr.
+ *
+ * Both arguments are converted to unsigned long before the call; the
+ * access size passed to __arch_xchg() is sizeof(*(ptr)).
+ */
+#define arch_xchg(ptr, x)                                              \
+({                                                                     \
+       (__typeof__(*(ptr)))__arch_xchg((unsigned long)(ptr),           \
+                                       (unsigned long)(x),             \
+                                       sizeof(*(ptr)));                \
+})
 
 #define system_has_cmpxchg128()                1