diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index e9b4b53ab53e..d824d548447e 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -45,17 +45,27 @@ MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); MODULE_LICENSE("Dual BSD/GPL"); +enum { + TLSV4, + TLSV6, + TLS_NUM_PROTS, +}; + enum { TLS_BASE_TX, TLS_SW_TX, TLS_NUM_CONFIG, }; -static struct proto tls_prots[TLS_NUM_CONFIG]; +static struct proto *saved_tcpv6_prot; +static DEFINE_MUTEX(tcpv6_prot_mutex); +static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) { - sk->sk_prot = &tls_prots[ctx->tx_conf]; + int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; + + sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf]; } int wait_on_pending_writer(struct sock *sk, long *timeo) @@ -453,8 +463,21 @@ static int tls_setsockopt(struct sock *sk, int level, int optname, return do_tls_setsockopt(sk, optname, optval, optlen); } +static void build_protos(struct proto *prot, struct proto *base) +{ + prot[TLS_BASE_TX] = *base; + prot[TLS_BASE_TX].setsockopt = tls_setsockopt; + prot[TLS_BASE_TX].getsockopt = tls_getsockopt; + prot[TLS_BASE_TX].close = tls_sk_proto_close; + + prot[TLS_SW_TX] = prot[TLS_BASE_TX]; + prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; + prot[TLS_SW_TX].sendpage = tls_sw_sendpage; +} + static int tls_init(struct sock *sk) { + int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; struct inet_connection_sock *icsk = inet_csk(sk); struct tls_context *ctx; int rc = 0; @@ -479,6 +502,17 @@ static int tls_init(struct sock *sk) ctx->getsockopt = sk->sk_prot->getsockopt; ctx->sk_proto_close = sk->sk_prot->close; + /* Build IPv6 TLS whenever the address of tcpv6_prot changes */ + if (ip_ver == TLSV6 && + unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) { + mutex_lock(&tcpv6_prot_mutex); + if (likely(sk->sk_prot != saved_tcpv6_prot)) { + build_protos(tls_prots[TLSV6], sk->sk_prot); + smp_store_release(&saved_tcpv6_prot, sk->sk_prot); + } + mutex_unlock(&tcpv6_prot_mutex); + } + ctx->tx_conf = TLS_BASE_TX; update_sk_prot(sk, ctx); out: @@ -493,21 +527,9 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .init = tls_init, }; -static void build_protos(struct proto *prot, struct proto *base) -{ - prot[TLS_BASE_TX] = *base; - prot[TLS_BASE_TX].setsockopt = tls_setsockopt; - prot[TLS_BASE_TX].getsockopt = tls_getsockopt; - prot[TLS_BASE_TX].close = tls_sk_proto_close; - - prot[TLS_SW_TX] = prot[TLS_BASE_TX]; - prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; - prot[TLS_SW_TX].sendpage = tls_sw_sendpage; -} - static int __init tls_register(void) { - build_protos(tls_prots, &tcp_prot); + build_protos(tls_prots[TLSV4], &tcp_prot); tcp_register_ulp(&tcp_tls_ulp_ops);