diff options
Diffstat (limited to 'arch/x86/net/bpf_jit_comp.c')
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 161 |
1 files changed, 131 insertions, 30 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index d25d81c8ecc0..06b080b61aa5 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -64,6 +64,56 @@ static bool is_imm8(int value) return value <= 127 && value >= -128; } +/* + * Let us limit the positive offset to be <= 123. + * This is to ensure eventual jit convergence For the following patterns: + * ... + * pass4, final_proglen=4391: + * ... + * 20e: 48 85 ff test rdi,rdi + * 211: 74 7d je 0x290 + * 213: 48 8b 77 00 mov rsi,QWORD PTR [rdi+0x0] + * ... + * 289: 48 85 ff test rdi,rdi + * 28c: 74 17 je 0x2a5 + * 28e: e9 7f ff ff ff jmp 0x212 + * 293: bf 03 00 00 00 mov edi,0x3 + * Note that insn at 0x211 is 2-byte cond jump insn for offset 0x7d (-125) + * and insn at 0x28e is 5-byte jmp insn with offset -129. + * + * pass5, final_proglen=4392: + * ... + * 20e: 48 85 ff test rdi,rdi + * 211: 0f 84 80 00 00 00 je 0x297 + * 217: 48 8b 77 00 mov rsi,QWORD PTR [rdi+0x0] + * ... + * 28d: 48 85 ff test rdi,rdi + * 290: 74 1a je 0x2ac + * 292: eb 84 jmp 0x218 + * 294: bf 03 00 00 00 mov edi,0x3 + * Note that insn at 0x211 is 6-byte cond jump insn now since its offset + * becomes 0x80 based on previous round (0x293 - 0x213 = 0x80). + * At the same time, insn at 0x292 is a 2-byte insn since its offset is + * -124. + * + * pass6 will repeat the same code as in pass4 and this will prevent + * eventual convergence. + * + * To fix this issue, we need to break je (2->6 bytes) <-> jmp (5->2 bytes) + * cycle in the above. In the above example je offset <= 0x7c should work. + * + * For other cases, je <-> je needs offset <= 0x7b to avoid no convergence + * issue. For jmp <-> je and jmp <-> jmp cases, jmp offset <= 0x7c should + * avoid no convergence issue. + * + * Overall, let us limit the positive offset for 8bit cond/uncond jmp insn + * to maximum 123 (0x7b). This way, the jit pass can eventually converge. + */ +static bool is_imm8_jmp_offset(int value) +{ + return value <= 123 && value >= -128; +} + static bool is_simm32(s64 value) { return value == (s64)(s32)value; @@ -273,7 +323,7 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 /* Number of bytes that will be skipped on tailcall */ -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) +#define X86_TAIL_CALL_OFFSET (12 + ENDBR_INSN_SIZE) static void push_r12(u8 **pprog) { @@ -403,6 +453,37 @@ static void emit_cfi(u8 **pprog, u32 hash) *pprog = prog; } +static void emit_prologue_tail_call(u8 **pprog, bool is_subprog) +{ + u8 *prog = *pprog; + + if (!is_subprog) { + /* cmp rax, MAX_TAIL_CALL_CNT */ + EMIT4(0x48, 0x83, 0xF8, MAX_TAIL_CALL_CNT); + EMIT2(X86_JA, 6); /* ja 6 */ + /* rax is tail_call_cnt if <= MAX_TAIL_CALL_CNT. + * case1: entry of main prog. + * case2: tail callee of main prog. + */ + EMIT1(0x50); /* push rax */ + /* Make rax as tail_call_cnt_ptr. */ + EMIT3(0x48, 0x89, 0xE0); /* mov rax, rsp */ + EMIT2(0xEB, 1); /* jmp 1 */ + /* rax is tail_call_cnt_ptr if > MAX_TAIL_CALL_CNT. + * case: tail callee of subprog. + */ + EMIT1(0x50); /* push rax */ + /* push tail_call_cnt_ptr */ + EMIT1(0x50); /* push rax */ + } else { /* is_subprog */ + /* rax is tail_call_cnt_ptr. */ + EMIT1(0x50); /* push rax */ + EMIT1(0x50); /* push rax */ + } + + *pprog = prog; +} + /* * Emit x86-64 prologue code for BPF program. * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes @@ -424,10 +505,10 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, /* When it's the entry of the whole tailcall context, * zeroing rax means initialising tail_call_cnt. */ - EMIT2(0x31, 0xC0); /* xor eax, eax */ + EMIT3(0x48, 0x31, 0xC0); /* xor rax, rax */ else /* Keep the same instruction layout. */ - EMIT2(0x66, 0x90); /* nop2 */ + emit_nops(&prog, 3); /* nop3 */ } /* Exception callback receives FP as third parameter */ if (is_exception_cb) { @@ -453,7 +534,7 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, if (stack_depth) EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); if (tail_call_reachable) - EMIT1(0x50); /* push rax */ + emit_prologue_tail_call(&prog, is_subprog); *pprog = prog; } @@ -589,13 +670,15 @@ static void emit_return(u8 **pprog, u8 *ip) *pprog = prog; } +#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (-16 - round_up(stack, 8)) + /* * Generate the following code: * * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ... * if (index >= array->map.max_entries) * goto out; - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) * goto out; * prog = array->ptrs[index]; * if (prog == NULL) @@ -608,7 +691,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, u32 stack_depth, u8 *ip, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); + int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack_depth); u8 *prog = *pprog, *start = *pprog; int offset; @@ -630,16 +713,14 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, EMIT2(X86_JBE, offset); /* jbe out */ /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off); /* mov rax, qword ptr [rbp - tcc_ptr_off] */ + EMIT4(0x48, 0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp qword ptr [rax], MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ @@ -654,6 +735,9 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JE, offset); /* je out */ + /* Inc tail_call_cnt if the slot is populated. */ + EMIT4(0x48, 0x83, 0x00, 0x01); /* add qword ptr [rax], 1 */ + if (bpf_prog->aux->exception_boundary) { pop_callee_regs(&prog, all_callee_regs_used); pop_r12(&prog); @@ -663,6 +747,11 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, pop_r12(&prog); } + /* Pop tail_call_cnt_ptr. */ + EMIT1(0x58); /* pop rax */ + /* Pop tail_call_cnt, if it's main prog. + * Pop tail_call_cnt_ptr, if it's subprog. + */ EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ @@ -691,21 +780,19 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, bool *callee_regs_used, u32 stack_depth, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); + int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack_depth); u8 *prog = *pprog, *start = *pprog; int offset; /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off); /* mov rax, qword ptr [rbp - tcc_ptr_off] */ + EMIT4(0x48, 0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp qword ptr [rax], MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_direct_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ poke->tailcall_bypass = ip + (prog - start); poke->adj_off = X86_TAIL_CALL_OFFSET; @@ -715,6 +802,9 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE, poke->tailcall_bypass); + /* Inc tail_call_cnt if the slot is populated. */ + EMIT4(0x48, 0x83, 0x00, 0x01); /* add qword ptr [rax], 1 */ + if (bpf_prog->aux->exception_boundary) { pop_callee_regs(&prog, all_callee_regs_used); pop_r12(&prog); @@ -724,6 +814,11 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, pop_r12(&prog); } + /* Pop tail_call_cnt_ptr. */ + EMIT1(0x58); /* pop rax */ + /* Pop tail_call_cnt, if it's main prog. + * Pop tail_call_cnt_ptr, if it's subprog. + */ EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); @@ -1311,9 +1406,11 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op) #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) -/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ -#define RESTORE_TAIL_CALL_CNT(stack) \ - EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) +#define __LOAD_TCC_PTR(off) \ + EMIT3_off32(0x48, 0x8B, 0x85, off) +/* mov rax, qword ptr [rbp - rounded_stack_depth - 16] */ +#define LOAD_TAIL_CALL_CNT_PTR(stack) \ + __LOAD_TCC_PTR(BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack)) static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, int oldproglen, struct jit_context *ctx, bool jmp_padding) @@ -2031,7 +2128,7 @@ populate_extable: func = (u8 *) __bpf_call_base + imm32; if (tail_call_reachable) { - RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); + LOAD_TAIL_CALL_CNT_PTR(bpf_prog->aux->stack_depth); ip += 7; } if (!imm32) @@ -2184,7 +2281,7 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */ return -EFAULT; } jmp_offset = addrs[i + insn->off] - addrs[i]; - if (is_imm8(jmp_offset)) { + if (is_imm8_jmp_offset(jmp_offset)) { if (jmp_padding) { /* To keep the jmp_offset valid, the extra bytes are * padded before the jump insn, so we subtract the @@ -2266,7 +2363,7 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */ break; } emit_jmp: - if (is_imm8(jmp_offset)) { + if (is_imm8_jmp_offset(jmp_offset)) { if (jmp_padding) { /* To avoid breaking jmp_offset, the extra bytes * are padded before the actual jmp insn, so @@ -2706,6 +2803,10 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog, return 0; } +/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ +#define LOAD_TRAMP_TAIL_CALL_CNT_PTR(stack) \ + __LOAD_TCC_PTR(-round_up(stack, 8) - 8) + /* Example: * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev); * its 'struct btf_func_model' will be nr_args=2 @@ -2826,7 +2927,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im * [ ... ] * [ stack_arg2 ] * RBP - arg_stack_off [ stack_arg1 ] - * RSP [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX + * RSP [ tail_call_cnt_ptr ] BPF_TRAMP_F_TAIL_CALL_CTX */ /* room for return value of orig_call or fentry prog */ @@ -2955,10 +3056,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im save_args(m, &prog, arg_stack_off, true); if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before calling the original function, restore the - * tail_call_cnt from stack to rax. + /* Before calling the original function, load the + * tail_call_cnt_ptr from stack to rax. */ - RESTORE_TAIL_CALL_CNT(stack_size); + LOAD_TRAMP_TAIL_CALL_CNT_PTR(stack_size); } if (flags & BPF_TRAMP_F_ORIG_STACK) { @@ -3017,10 +3118,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im goto cleanup; } } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before running the original function, restore the - * tail_call_cnt from stack to rax. + /* Before running the original function, load the + * tail_call_cnt_ptr from stack to rax. */ - RESTORE_TAIL_CALL_CNT(stack_size); + LOAD_TRAMP_TAIL_CALL_CNT_PTR(stack_size); } /* restore return value of orig_call or fentry prog back into RAX */ |