/* Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
  * Copyright (c) 2016 Facebook
+ * Copyright (c) 2018 Covalent IO, Inc. http://covalent.io
  *
  * This program is free software; you can redistribute it and/or
  * modify it under the terms of version 2 of the GNU General Public
  *
  * After the call R0 is set to return type of the function and registers R1-R5
  * are set to NOT_INIT to indicate that they are no longer readable.
+ *
+ * The following reference types represent a potential reference to a kernel
+ * resource which, after first being allocated, must be checked and freed by
+ * the BPF program:
+ * - PTR_TO_SOCKET_OR_NULL, PTR_TO_SOCKET
+ *
+ * When the verifier sees a helper call return a reference type, it allocates a
+ * pointer id for the reference and stores it in the current function state.
+ * Similar to the way that PTR_TO_MAP_VALUE_OR_NULL is converted into
+ * PTR_TO_MAP_VALUE, PTR_TO_SOCKET_OR_NULL becomes PTR_TO_SOCKET when the type
+ * passes through a NULL-check conditional. For the branch wherein the state is
+ * changed to CONST_IMM, the verifier releases the reference.
  */
 
 /* verifier_state + insn_idx are pushed to stack when branch is encountered */
        int access_size;
        s64 msize_smax_value;
        u64 msize_umax_value;
+       int ptr_id;
 };
 
 static DEFINE_MUTEX(bpf_verifier_lock);
 
+/* Pointer types which may hold NULL and therefore require a NULL-check
+ * before they can be dereferenced.
+ */
 static bool reg_type_may_be_null(enum bpf_reg_type type)
 {
-       return type == PTR_TO_MAP_VALUE_OR_NULL;
+       return type == PTR_TO_MAP_VALUE_OR_NULL ||
+              type == PTR_TO_SOCKET_OR_NULL;
+}
+
+/* Register types holding a reference to a kernel resource that the program
+ * must eventually release (currently only sockets).
+ */
+static bool type_is_refcounted(enum bpf_reg_type type)
+{
+       return type == PTR_TO_SOCKET;
+}
+
+/* As type_is_refcounted(), but also matching the not-yet-NULL-checked
+ * variant of the type.
+ */
+static bool type_is_refcounted_or_null(enum bpf_reg_type type)
+{
+       return type == PTR_TO_SOCKET || type == PTR_TO_SOCKET_OR_NULL;
+}
+
+/* True if @reg currently holds a (NULL-checked) reference. */
+static bool reg_is_refcounted(const struct bpf_reg_state *reg)
+{
+       return type_is_refcounted(reg->type);
+}
+
+/* True if @reg holds a reference, whether or not it has been NULL-checked. */
+static bool reg_is_refcounted_or_null(const struct bpf_reg_state *reg)
+{
+       return type_is_refcounted_or_null(reg->type);
+}
+
+/* True if a helper argument of this type passes a reference into the helper. */
+static bool arg_type_is_refcounted(enum bpf_arg_type type)
+{
+       return type == ARG_PTR_TO_SOCKET;
+}
+
+/* Determine whether the function releases some resources allocated by another
+ * function call. The first reference type argument will be assumed to be
+ * released by release_reference().
+ */
+static bool is_release_function(enum bpf_func_id func_id)
+{
+       /* No helper currently releases references. */
+       return false;
 }
 
 /* string representation of 'enum bpf_reg_type' */
                else
                        verbose(env, "=%s", types_buf);
        }
+       if (state->acquired_refs && state->refs[0].id) {
+               verbose(env, " refs=%d", state->refs[0].id);
+               for (i = 1; i < state->acquired_refs; i++)
+                       if (state->refs[i].id)
+                               verbose(env, ",%d", state->refs[i].id);
+       }
        verbose(env, "\n");
 }
 
               sizeof(*src->FIELD) * (src->COUNT / SIZE));              \
        return 0;                                                       \
 }
+/* copy_reference_state() */
+COPY_STATE_FN(reference, acquired_refs, refs, 1)
 /* copy_stack_state() */
 COPY_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
 #undef COPY_STATE_FN
        state->FIELD = new_##FIELD;                                     \
        return 0;                                                       \
 }
+/* realloc_reference_state() */
+REALLOC_STATE_FN(reference, acquired_refs, refs, 1)
 /* realloc_stack_state() */
 REALLOC_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
 #undef REALLOC_STATE_FN
  * which realloc_stack_state() copies over. It points to previous
  * bpf_verifier_state which is never reallocated.
  */
-static int realloc_func_state(struct bpf_func_state *state, int size,
-                             bool copy_old)
+static int realloc_func_state(struct bpf_func_state *state, int stack_size,
+                             int refs_size, bool copy_old)
 {
-       return realloc_stack_state(state, size, copy_old);
+       /* Resize the reference list first; on failure the stack state is
+        * left untouched.
+        */
+       int err = realloc_reference_state(state, refs_size, copy_old);
+       if (err)
+               return err;
+       return realloc_stack_state(state, stack_size, copy_old);
+}
+
+/* Acquire a pointer id from the env and update the state->refs to include
+ * this new pointer reference.
+ * On success, returns a valid pointer id to associate with the register
+ * On failure, returns a negative errno.
+ */
+static int acquire_reference_state(struct bpf_verifier_env *env, int insn_idx)
+{
+       struct bpf_func_state *state = cur_func(env);
+       int new_ofs = state->acquired_refs;
+       int id, err;
+
+       err = realloc_reference_state(state, state->acquired_refs + 1, true);
+       if (err)
+               return err;
+       /* ids come from the same counter used for PTR_TO_MAP_VALUE_OR_NULL */
+       id = ++env->id_gen;
+       state->refs[new_ofs].id = id;
+       /* record the allocating instruction for leak reporting */
+       state->refs[new_ofs].insn_idx = insn_idx;
+
+       return id;
+}
+
+/* release function corresponding to acquire_reference_state(). Idempotent. */
+static int __release_reference_state(struct bpf_func_state *state, int ptr_id)
+{
+       int i, last_idx;
+
+       /* id 0 is never handed out by acquire_reference_state() */
+       if (!ptr_id)
+               return -EFAULT;
+
+       last_idx = state->acquired_refs - 1;
+       for (i = 0; i < state->acquired_refs; i++) {
+               if (state->refs[i].id == ptr_id) {
+                       /* swap the last entry into the released slot ... */
+                       if (last_idx && i != last_idx)
+                               memcpy(&state->refs[i], &state->refs[last_idx],
+                                      sizeof(*state->refs));
+                       /* ... and clear the now-unused last slot */
+                       memset(&state->refs[last_idx], 0, sizeof(*state->refs));
+                       state->acquired_refs--;
+                       return 0;
+               }
+       }
+       return -EFAULT;
+}
+
+/* variation on the above for cases where we expect that there must be an
+ * outstanding reference for the specified ptr_id.
+ */
+static int release_reference_state(struct bpf_verifier_env *env, int ptr_id)
+{
+       struct bpf_func_state *state = cur_func(env);
+       int err;
+
+       err = __release_reference_state(state, ptr_id);
+       /* A missing reference here is a verifier bug, not a program bug. */
+       if (WARN_ON_ONCE(err != 0))
+               verbose(env, "verifier internal error: can't release reference\n");
+       return err;
+}
+
+/* Copy the outstanding references from @src into @dst, used when crossing a
+ * BPF-to-BPF call boundary in either direction.
+ */
+static int transfer_reference_state(struct bpf_func_state *dst,
+                                   struct bpf_func_state *src)
+{
+       int err = realloc_reference_state(dst, src->acquired_refs, false);
+       if (err)
+               return err;
+       err = copy_reference_state(dst, src);
+       if (err)
+               return err;
+       return 0;
 }
 
 static void free_func_state(struct bpf_func_state *state)
 {
        if (!state)
                return;
+       /* refs and stack are separately allocated; free both before state */
+       kfree(state->refs);
        kfree(state->stack);
        kfree(state);
 }
 {
        int err;
 
-       err = realloc_func_state(dst, src->allocated_stack, false);
+       err = realloc_func_state(dst, src->allocated_stack, src->acquired_refs,
+                                false);
+       if (err)
+               return err;
+       memcpy(dst, src, offsetof(struct bpf_func_state, acquired_refs));
+       err = copy_reference_state(dst, src);
        if (err)
                return err;
-       memcpy(dst, src, offsetof(struct bpf_func_state, allocated_stack));
        return copy_stack_state(dst, src);
 }
 
        enum bpf_reg_type type;
 
        err = realloc_func_state(state, round_up(slot + 1, BPF_REG_SIZE),
-                                true);
+                                state->acquired_refs, true);
        if (err)
                return err;
        /* caller checked that off % size == 0 and -MAX_BPF_STACK <= off < 0,
 {
        const struct bpf_reg_state *reg = cur_regs(env) + regno;
 
-       return reg->type == PTR_TO_CTX;
+       return reg->type == PTR_TO_CTX ||
+              reg->type == PTR_TO_SOCKET;
 }
 
 static bool is_pkt_reg(struct bpf_verifier_env *env, int regno)
                expected_type = PTR_TO_SOCKET;
                if (type != expected_type)
                        goto err_type;
+               if (meta->ptr_id || !reg->id) {
+                       verbose(env, "verifier internal error: mismatched references meta=%d, reg=%d\n",
+                               meta->ptr_id, reg->id);
+                       return -EFAULT;
+               }
+               meta->ptr_id = reg->id;
        } else if (arg_type_is_mem_ptr(arg_type)) {
                expected_type = PTR_TO_STACK;
                /* One exception here. In case function allows for NULL to be
        return true;
 }
 
+/* A helper prototype may take at most one referenced argument, since
+ * release_reference() only releases the first one.
+ */
+static bool check_refcount_ok(const struct bpf_func_proto *fn)
+{
+       int count = 0;
+
+       if (arg_type_is_refcounted(fn->arg1_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg2_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg3_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg4_type))
+               count++;
+       if (arg_type_is_refcounted(fn->arg5_type))
+               count++;
+
+       /* We only support one arg being referenced at the moment,
+        * which is sufficient for the helper functions we have right now.
+        */
+       return count <= 1;
+}
+
 static int check_func_proto(const struct bpf_func_proto *fn)
 {
+       /* All prototype sanity checks must pass, including the new
+        * refcount constraint.
+        */
        return check_raw_mode_ok(fn) &&
-              check_arg_pair_ok(fn) ? 0 : -EINVAL;
+              check_arg_pair_ok(fn) &&
+              check_refcount_ok(fn) ? 0 : -EINVAL;
 }
 
 /* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
                __clear_all_pkt_pointers(env, vstate->frame[i]);
 }
 
+/* Invalidate, within one frame, every register and spilled stack slot that
+ * still carries the released pointer id, by marking it unknown.
+ */
+static void release_reg_references(struct bpf_verifier_env *env,
+                                  struct bpf_func_state *state, int id)
+{
+       struct bpf_reg_state *regs = state->regs, *reg;
+       int i;
+
+       for (i = 0; i < MAX_BPF_REG; i++)
+               if (regs[i].id == id)
+                       mark_reg_unknown(env, regs, i);
+
+       bpf_for_each_spilled_reg(i, state, reg) {
+               if (!reg)
+                       continue;
+               if (reg_is_refcounted(reg) && reg->id == id)
+                       __mark_reg_unknown(reg);
+       }
+}
+
+/* The pointer with the specified id has released its reference to kernel
+ * resources. Identify all copies of the same pointer and clear the reference.
+ */
+static int release_reference(struct bpf_verifier_env *env,
+                            struct bpf_call_arg_meta *meta)
+{
+       struct bpf_verifier_state *vstate = env->cur_state;
+       int i;
+
+       /* Clear copies in every call frame, then drop the tracked reference. */
+       for (i = 0; i <= vstate->curframe; i++)
+               release_reg_references(env, vstate->frame[i], meta->ptr_id);
+
+       return release_reference_state(env, meta->ptr_id);
+}
+
 static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
                           int *insn_idx)
 {
        struct bpf_verifier_state *state = env->cur_state;
        struct bpf_func_state *caller, *callee;
-       int i, subprog, target_insn;
+       int i, err, subprog, target_insn;
 
        if (state->curframe + 1 >= MAX_CALL_FRAMES) {
                verbose(env, "the call stack of %d frames is too deep\n",
                        state->curframe + 1 /* frameno within this callchain */,
                        subprog /* subprog number within this prog */);
 
+       /* Transfer references to the callee */
+       err = transfer_reference_state(callee, caller);
+       if (err)
+               return err;
+
        /* copy r1 - r5 args that callee can access.  The copy includes parent
         * pointers, which connects us up to the liveness chain
         */
        struct bpf_verifier_state *state = env->cur_state;
        struct bpf_func_state *caller, *callee;
        struct bpf_reg_state *r0;
+       int err;
 
        callee = state->frame[state->curframe];
        r0 = &callee->regs[BPF_REG_0];
        /* return to the caller whatever r0 had in the callee */
        caller->regs[BPF_REG_0] = *r0;
 
+       /* Transfer references to the caller */
+       err = transfer_reference_state(caller, callee);
+       if (err)
+               return err;
+
        *insn_idx = callee->callsite + 1;
        if (env->log.level) {
                verbose(env, "returning from callee:\n");
        return 0;
 }
 
+/* Fail if any acquired references are still outstanding; called at points
+ * where the program could terminate without releasing them (program exit,
+ * tail calls, BPF_LD_[ABS|IND]).
+ */
+static int check_reference_leak(struct bpf_verifier_env *env)
+{
+       struct bpf_func_state *state = cur_func(env);
+       int i;
+
+       for (i = 0; i < state->acquired_refs; i++) {
+               verbose(env, "Unreleased reference id=%d alloc_insn=%d\n",
+                       state->refs[i].id, state->refs[i].insn_idx);
+       }
+       return state->acquired_refs ? -EINVAL : 0;
+}
+
 static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
 {
        const struct bpf_func_proto *fn = NULL;
                        return err;
        }
 
+       if (func_id == BPF_FUNC_tail_call) {
+               err = check_reference_leak(env);
+               if (err) {
+                       verbose(env, "tail_call would lead to reference leak\n");
+                       return err;
+               }
+       } else if (is_release_function(func_id)) {
+               err = release_reference(env, &meta);
+               if (err)
+                       return err;
+       }
+
        regs = cur_regs(env);
 
        /* check that flags argument in get_local_storage(map, flags) is 0,
                regs[BPF_REG_0].map_ptr = meta.map_ptr;
                regs[BPF_REG_0].id = ++env->id_gen;
        } else if (fn->ret_type == RET_PTR_TO_SOCKET_OR_NULL) {
+               int id = acquire_reference_state(env, insn_idx);
+               if (id < 0)
+                       return id;
                mark_reg_known_zero(env, regs, BPF_REG_0);
                regs[BPF_REG_0].type = PTR_TO_SOCKET_OR_NULL;
-               regs[BPF_REG_0].id = ++env->id_gen;
+               regs[BPF_REG_0].id = id;
        } else {
                verbose(env, "unknown return type %d of func %s#%d\n",
                        fn->ret_type, func_id_name(func_id), func_id);
        }
 }
 
-static void mark_ptr_or_null_reg(struct bpf_reg_state *reg, u32 id,
+static void mark_ptr_or_null_reg(struct bpf_func_state *state,
+                                struct bpf_reg_state *reg, u32 id,
                                 bool is_null)
 {
        if (reg_type_may_be_null(reg->type) && reg->id == id) {
                } else if (reg->type == PTR_TO_SOCKET_OR_NULL) {
                        reg->type = PTR_TO_SOCKET;
                }
-               /* We don't need id from this point onwards anymore, thus we
-                * should better reset it, so that state pruning has chances
-                * to take effect.
-                */
-               reg->id = 0;
+               if (is_null || !reg_is_refcounted(reg)) {
+                       /* We don't need id from this point onwards anymore,
+                        * thus we should better reset it, so that state
+                        * pruning has chances to take effect.
+                        */
+                       reg->id = 0;
+               }
        }
 }
 
        u32 id = regs[regno].id;
        int i, j;
 
+       if (reg_is_refcounted_or_null(&regs[regno]) && is_null)
+               __release_reference_state(state, id);
+
        for (i = 0; i < MAX_BPF_REG; i++)
-               mark_ptr_or_null_reg(&regs[i], id, is_null);
+               mark_ptr_or_null_reg(state, &regs[i], id, is_null);
 
        for (j = 0; j <= vstate->curframe; j++) {
                state = vstate->frame[j];
                bpf_for_each_spilled_reg(i, state, reg) {
                        if (!reg)
                                continue;
-                       mark_ptr_or_null_reg(reg, id, is_null);
+                       mark_ptr_or_null_reg(state, reg, id, is_null);
                }
        }
 }
        if (err)
                return err;
 
+       /* Disallow usage of BPF_LD_[ABS|IND] with reference tracking, as
+        * gen_ld_abs() may terminate the program at runtime, leading to
+        * reference leak.
+        */
+       err = check_reference_leak(env);
+       if (err) {
+               verbose(env, "BPF_LD_[ABS|IND] cannot be mixed with socket references\n");
+               return err;
+       }
+
        if (regs[BPF_REG_6].type != PTR_TO_CTX) {
                verbose(env,
                        "at the time of BPF_LD_ABS|IND R6 != pointer to skb\n");
        return true;
 }
 
+/* For state pruning, two states are only reference-equivalent if they track
+ * the same acquired references, compared bytewise (same ids in same slots).
+ */
+static bool refsafe(struct bpf_func_state *old, struct bpf_func_state *cur)
+{
+       if (old->acquired_refs != cur->acquired_refs)
+               return false;
+       return !memcmp(old->refs, cur->refs,
+                      sizeof(*old->refs) * old->acquired_refs);
+}
+
 /* compare two verifier states
  *
  * all states stored in state_list are known to be valid, since
 
        if (!stacksafe(old, cur, idmap))
                goto out_free;
+
+       if (!refsafe(old, cur))
+               goto out_free;
        ret = true;
 out_free:
        kfree(idmap);
 
                regs = cur_regs(env);
                env->insn_aux_data[insn_idx].seen = true;
+
                if (class == BPF_ALU || class == BPF_ALU64) {
                        err = check_alu_op(env, insn);
                        if (err)
                                        continue;
                                }
 
+                               err = check_reference_leak(env);
+                               if (err)
+                                       return err;
+
                                /* eBPF calling convetion is such that R0 is used
                                 * to return the value from eBPF program.
                                 * Make sure that it's readable at this time