#define SEEN_FUNC      0x20000000 /* might call external helpers */
 #define SEEN_TAILCALL  0x40000000 /* uses tail calls */
 
-#define SEEN_VREG_MASK 0x1ff80000 /* Volatile registers r3-r12 */
-#define SEEN_NVREG_MASK        0x0003ffff /* Non volatile registers r14-r31 */
-
 #ifdef CONFIG_PPC64
 extern const int b2p[MAX_BPF_JIT_REG + 2];
 #else
 
        return BPF_PPC_STACKFRAME(ctx) - 4;
 }
 
+#define SEEN_VREG_MASK         0x1ff80000 /* Volatile registers r3-r12 */
+#define SEEN_NVREG_FULL_MASK   0x0003ffff /* Non volatile registers r14-r31 */
+#define SEEN_NVREG_TEMP_MASK   0x00001e01 /* BPF_REG_5, BPF_REG_AX, TMP_REG */
+
 void bpf_jit_realloc_regs(struct codegen_context *ctx)
 {
+       unsigned int nvreg_mask;
+
        if (ctx->seen & SEEN_FUNC)
-               return;
+               nvreg_mask = SEEN_NVREG_TEMP_MASK;
+       else
+               nvreg_mask = SEEN_NVREG_FULL_MASK;
 
-       while (ctx->seen & SEEN_NVREG_MASK &&
+       while (ctx->seen & nvreg_mask &&
              (ctx->seen & SEEN_VREG_MASK) != SEEN_VREG_MASK) {
-               int old = 32 - fls(ctx->seen & (SEEN_NVREG_MASK & 0xaaaaaaab));
+               int old = 32 - fls(ctx->seen & (nvreg_mask & 0xaaaaaaab));
                int new = 32 - fls(~ctx->seen & (SEEN_VREG_MASK & 0xaaaaaaaa));
                int i;