*     Return
  *             0
  *
- * int bpf_setsockopt(struct bpf_sock_ops *bpf_socket, int level, int optname, void *optval, int optlen)
+ * int bpf_setsockopt(void *bpf_socket, int level, int optname, void *optval, int optlen)
  *     Description
  *             Emulate a call to **setsockopt()** on the socket associated to
  *             *bpf_socket*, which must be a full socket. The *level* at
  *             must be specified, see **setsockopt(2)** for more information.
  *             The option value of length *optlen* is pointed by *optval*.
  *
+ *             *bpf_socket* should be one of the following:
+ *             * **struct bpf_sock_ops** for **BPF_PROG_TYPE_SOCK_OPS**.
+ *             * **struct bpf_sock_addr** for **BPF_CGROUP_INET4_CONNECT**
+ *               and **BPF_CGROUP_INET6_CONNECT**.
+ *
  *             This helper actually implements a subset of **setsockopt()**.
  *             It supports the following *level*\ s:
  *
  *     Return
  *             0 on success, or a negative error in case of failure.
  *
- * int bpf_getsockopt(struct bpf_sock_ops *bpf_socket, int level, int optname, void *optval, int optlen)
+ * int bpf_getsockopt(void *bpf_socket, int level, int optname, void *optval, int optlen)
  *     Description
  *             Emulate a call to **getsockopt()** on the socket associated to
  *             *bpf_socket*, which must be a full socket. The *level* at
  *             The retrieved value is stored in the structure pointed by
  *             *opval* and of length *optlen*.
  *
+ *             *bpf_socket* should be one of the following:
+ *             * **struct bpf_sock_ops** for **BPF_PROG_TYPE_SOCK_OPS**.
+ *             * **struct bpf_sock_addr** for **BPF_CGROUP_INET4_CONNECT**
+ *               and **BPF_CGROUP_INET6_CONNECT**.
+ *
  *             This helper actually implements a subset of **getsockopt()**.
  *             It supports the following *level*\ s:
  *
 
        .arg1_type      = ARG_PTR_TO_CTX,
 };
 
-BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
-          int, level, int, optname, char *, optval, int, optlen)
+#define SOCKOPT_CC_REINIT (1 << 0)
+
+static int _bpf_setsockopt(struct sock *sk, int level, int optname,
+                          char *optval, int optlen, u32 flags)
 {
-       struct sock *sk = bpf_sock->sk;
        int ret = 0;
        int val;
 
        if (!sk_fullsock(sk))
                return -EINVAL;
 
+       sock_owned_by_me(sk);
+
        if (level == SOL_SOCKET) {
                if (optlen != sizeof(int))
                        return -EINVAL;
                   sk->sk_prot->setsockopt == tcp_setsockopt) {
                if (optname == TCP_CONGESTION) {
                        char name[TCP_CA_NAME_MAX];
-                       bool reinit = bpf_sock->op > BPF_SOCK_OPS_NEEDS_ECN;
+                       bool reinit = flags & SOCKOPT_CC_REINIT;
 
                        strncpy(name, optval, min_t(long, optlen,
                                                    TCP_CA_NAME_MAX-1));
        return ret;
 }
 
-static const struct bpf_func_proto bpf_setsockopt_proto = {
-       .func           = bpf_setsockopt,
-       .gpl_only       = false,
-       .ret_type       = RET_INTEGER,
-       .arg1_type      = ARG_PTR_TO_CTX,
-       .arg2_type      = ARG_ANYTHING,
-       .arg3_type      = ARG_ANYTHING,
-       .arg4_type      = ARG_PTR_TO_MEM,
-       .arg5_type      = ARG_CONST_SIZE,
-};
-
-BPF_CALL_5(bpf_getsockopt, struct bpf_sock_ops_kern *, bpf_sock,
-          int, level, int, optname, char *, optval, int, optlen)
+static int _bpf_getsockopt(struct sock *sk, int level, int optname,
+                          char *optval, int optlen)
 {
-       struct sock *sk = bpf_sock->sk;
-
        if (!sk_fullsock(sk))
                goto err_clear;
+
+       sock_owned_by_me(sk);
+
 #ifdef CONFIG_INET
        if (level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
                struct inet_connection_sock *icsk;
        return -EINVAL;
 }
 
-static const struct bpf_func_proto bpf_getsockopt_proto = {
-       .func           = bpf_getsockopt,
+BPF_CALL_5(bpf_sock_addr_setsockopt, struct bpf_sock_addr_kern *, ctx,
+          int, level, int, optname, char *, optval, int, optlen)
+{
+       u32 flags = 0;
+       return _bpf_setsockopt(ctx->sk, level, optname, optval, optlen,
+                              flags);
+}
+
+static const struct bpf_func_proto bpf_sock_addr_setsockopt_proto = {
+       .func           = bpf_sock_addr_setsockopt,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_ANYTHING,
+       .arg4_type      = ARG_PTR_TO_MEM,
+       .arg5_type      = ARG_CONST_SIZE,
+};
+
+BPF_CALL_5(bpf_sock_addr_getsockopt, struct bpf_sock_addr_kern *, ctx,
+          int, level, int, optname, char *, optval, int, optlen)
+{
+       return _bpf_getsockopt(ctx->sk, level, optname, optval, optlen);
+}
+
+static const struct bpf_func_proto bpf_sock_addr_getsockopt_proto = {
+       .func           = bpf_sock_addr_getsockopt,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_ANYTHING,
+       .arg4_type      = ARG_PTR_TO_UNINIT_MEM,
+       .arg5_type      = ARG_CONST_SIZE,
+};
+
+BPF_CALL_5(bpf_sock_ops_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
+          int, level, int, optname, char *, optval, int, optlen)
+{
+       u32 flags = 0;
+       if (bpf_sock->op > BPF_SOCK_OPS_NEEDS_ECN)
+               flags |= SOCKOPT_CC_REINIT;
+       return _bpf_setsockopt(bpf_sock->sk, level, optname, optval, optlen,
+                              flags);
+}
+
+static const struct bpf_func_proto bpf_sock_ops_setsockopt_proto = {
+       .func           = bpf_sock_ops_setsockopt,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_ANYTHING,
+       .arg4_type      = ARG_PTR_TO_MEM,
+       .arg5_type      = ARG_CONST_SIZE,
+};
+
+BPF_CALL_5(bpf_sock_ops_getsockopt, struct bpf_sock_ops_kern *, bpf_sock,
+          int, level, int, optname, char *, optval, int, optlen)
+{
+       return _bpf_getsockopt(bpf_sock->sk, level, optname, optval, optlen);
+}
+
+static const struct bpf_func_proto bpf_sock_ops_getsockopt_proto = {
+       .func           = bpf_sock_ops_getsockopt,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
                return &bpf_sk_storage_get_proto;
        case BPF_FUNC_sk_storage_delete:
                return &bpf_sk_storage_delete_proto;
+       case BPF_FUNC_setsockopt:
+               switch (prog->expected_attach_type) {
+               case BPF_CGROUP_INET4_CONNECT:
+               case BPF_CGROUP_INET6_CONNECT:
+                       return &bpf_sock_addr_setsockopt_proto;
+               default:
+                       return NULL;
+               }
+       case BPF_FUNC_getsockopt:
+               switch (prog->expected_attach_type) {
+               case BPF_CGROUP_INET4_CONNECT:
+               case BPF_CGROUP_INET6_CONNECT:
+                       return &bpf_sock_addr_getsockopt_proto;
+               default:
+                       return NULL;
+               }
        default:
                return bpf_base_func_proto(func_id);
        }
 {
        switch (func_id) {
        case BPF_FUNC_setsockopt:
-               return &bpf_setsockopt_proto;
+               return &bpf_sock_ops_setsockopt_proto;
        case BPF_FUNC_getsockopt:
-               return &bpf_getsockopt_proto;
+               return &bpf_sock_ops_getsockopt_proto;
        case BPF_FUNC_sock_ops_cb_flags_set:
                return &bpf_sock_ops_cb_flags_set_proto;
        case BPF_FUNC_sock_map_update:
 
  *     Return
  *             0
  *
- * int bpf_setsockopt(struct bpf_sock_ops *bpf_socket, int level, int optname, void *optval, int optlen)
+ * int bpf_setsockopt(void *bpf_socket, int level, int optname, void *optval, int optlen)
  *     Description
  *             Emulate a call to **setsockopt()** on the socket associated to
  *             *bpf_socket*, which must be a full socket. The *level* at
  *             must be specified, see **setsockopt(2)** for more information.
  *             The option value of length *optlen* is pointed by *optval*.
  *
+ *             *bpf_socket* should be one of the following:
+ *             * **struct bpf_sock_ops** for **BPF_PROG_TYPE_SOCK_OPS**.
+ *             * **struct bpf_sock_addr** for **BPF_CGROUP_INET4_CONNECT**
+ *               and **BPF_CGROUP_INET6_CONNECT**.
+ *
  *             This helper actually implements a subset of **setsockopt()**.
  *             It supports the following *level*\ s:
  *
  *     Return
  *             0 on success, or a negative error in case of failure.
  *
- * int bpf_getsockopt(struct bpf_sock_ops *bpf_socket, int level, int optname, void *optval, int optlen)
+ * int bpf_getsockopt(void *bpf_socket, int level, int optname, void *optval, int optlen)
  *     Description
  *             Emulate a call to **getsockopt()** on the socket associated to
  *             *bpf_socket*, which must be a full socket. The *level* at
  *             The retrieved value is stored in the structure pointed by
  *             *opval* and of length *optlen*.
  *
+ *             *bpf_socket* should be one of the following:
+ *             * **struct bpf_sock_ops** for **BPF_PROG_TYPE_SOCK_OPS**.
+ *             * **struct bpf_sock_addr** for **BPF_CGROUP_INET4_CONNECT**
+ *               and **BPF_CGROUP_INET6_CONNECT**.
+ *
  *             This helper actually implements a subset of **getsockopt()**.
  *             It supports the following *level*\ s:
  *
 
 #include <linux/in.h>
 #include <linux/in6.h>
 #include <sys/socket.h>
+#include <netinet/tcp.h>
 
 #include <bpf/bpf_helpers.h>
 #include <bpf/bpf_endian.h>
 #define DST_REWRITE_IP4                0x7f000001U
 #define DST_REWRITE_PORT4      4444
 
+#ifndef TCP_CA_NAME_MAX
+#define TCP_CA_NAME_MAX 16
+#endif
+
 int _version SEC("version") = 1;
 
 __attribute__ ((noinline))
        return 1;
 }
 
+static __inline int verify_cc(struct bpf_sock_addr *ctx,
+                             char expected[TCP_CA_NAME_MAX])
+{
+       char buf[TCP_CA_NAME_MAX];
+       int i;
+
+       if (bpf_getsockopt(ctx, SOL_TCP, TCP_CONGESTION, &buf, sizeof(buf)))
+               return 1;
+
+       for (i = 0; i < TCP_CA_NAME_MAX; i++) {
+               if (buf[i] != expected[i])
+                       return 1;
+               if (buf[i] == 0)
+                       break;
+       }
+
+       return 0;
+}
+
+static __inline int set_cc(struct bpf_sock_addr *ctx)
+{
+       char dctcp[TCP_CA_NAME_MAX] = "dctcp";
+       char cubic[TCP_CA_NAME_MAX] = "cubic";
+
+       if (bpf_setsockopt(ctx, SOL_TCP, TCP_CONGESTION, &dctcp, sizeof(dctcp)))
+               return 1;
+       if (verify_cc(ctx, dctcp))
+               return 1;
+
+       if (bpf_setsockopt(ctx, SOL_TCP, TCP_CONGESTION, &cubic, sizeof(cubic)))
+               return 1;
+       if (verify_cc(ctx, cubic))
+               return 1;
+
+       return 0;
+}
+
 SEC("cgroup/connect4")
 int connect_v4_prog(struct bpf_sock_addr *ctx)
 {
 
        bpf_sk_release(sk);
 
+       /* Rewrite congestion control. */
+       if (ctx->type == SOCK_STREAM && set_cc(ctx))
+               return 0;
+
        /* Rewrite destination. */
        ctx->user_ip4 = bpf_htonl(DST_REWRITE_IP4);
        ctx->user_port = bpf_htons(DST_REWRITE_PORT4);