diff options
Diffstat (limited to 'net/xfrm/espintcp.c')
-rw-r--r-- | net/xfrm/espintcp.c | 56 |
1 files changed, 45 insertions, 11 deletions
diff --git a/net/xfrm/espintcp.c b/net/xfrm/espintcp.c index 037ea156d2f9..2132a3b6df0f 100644 --- a/net/xfrm/espintcp.c +++ b/net/xfrm/espintcp.c @@ -6,6 +6,9 @@ #include <net/espintcp.h> #include <linux/skmsg.h> #include <net/inet_common.h> +#if IS_ENABLED(CONFIG_IPV6) +#include <net/ipv6_stubs.h> +#endif static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb, struct sock *sk) @@ -31,7 +34,12 @@ static void handle_esp(struct sk_buff *skb, struct sock *sk) rcu_read_lock(); skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif); local_bh_disable(); - xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); + else +#endif + xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); local_bh_enable(); rcu_read_unlock(); } @@ -347,6 +355,9 @@ unlock: static struct proto espintcp_prot __ro_after_init; static struct proto_ops espintcp_ops __ro_after_init; +static struct proto espintcp6_prot; +static struct proto_ops espintcp6_ops; +static DEFINE_MUTEX(tcpv6_prot_mutex); static void espintcp_data_ready(struct sock *sk) { @@ -384,10 +395,14 @@ static void espintcp_destruct(struct sock *sk) bool tcp_is_ulp_esp(struct sock *sk) { - return sk->sk_prot == &espintcp_prot; + return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot; } EXPORT_SYMBOL_GPL(tcp_is_ulp_esp); +static void build_protos(struct proto *espintcp_prot, + struct proto_ops *espintcp_ops, + const struct proto *orig_prot, + const struct proto_ops *orig_ops); static int espintcp_init_sk(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); @@ -415,8 +430,19 @@ static int espintcp_init_sk(struct sock *sk) strp_check_rcv(&ctx->strp); skb_queue_head_init(&ctx->ike_queue); skb_queue_head_init(&ctx->out_queue); - sk->sk_prot = &espintcp_prot; - sk->sk_socket->ops = &espintcp_ops; + + if (sk->sk_family == AF_INET) { + sk->sk_prot = &espintcp_prot; + sk->sk_socket->ops = &espintcp_ops; + } else { + mutex_lock(&tcpv6_prot_mutex); + if (!espintcp6_prot.recvmsg) + build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops); + mutex_unlock(&tcpv6_prot_mutex); + + sk->sk_prot = &espintcp6_prot; + sk->sk_socket->ops = &espintcp6_ops; + } ctx->saved_data_ready = sk->sk_data_ready; ctx->saved_write_space = sk->sk_write_space; sk->sk_data_ready = espintcp_data_ready; @@ -489,6 +515,20 @@ static __poll_t espintcp_poll(struct file *file, struct socket *sock, return mask; } +static void build_protos(struct proto *espintcp_prot, + struct proto_ops *espintcp_ops, + const struct proto *orig_prot, + const struct proto_ops *orig_ops) +{ + memcpy(espintcp_prot, orig_prot, sizeof(struct proto)); + memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops)); + espintcp_prot->sendmsg = espintcp_sendmsg; + espintcp_prot->recvmsg = espintcp_recvmsg; + espintcp_prot->close = espintcp_close; + espintcp_prot->release_cb = espintcp_release; + espintcp_ops->poll = espintcp_poll; +} + static struct tcp_ulp_ops espintcp_ulp __read_mostly = { .name = "espintcp", .owner = THIS_MODULE, @@ -497,13 +537,7 @@ static struct tcp_ulp_ops espintcp_ulp __read_mostly = { void __init espintcp_init(void) { - memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot)); - memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops)); - espintcp_prot.sendmsg = espintcp_sendmsg; - espintcp_prot.recvmsg = espintcp_recvmsg; - espintcp_prot.close = espintcp_close; - espintcp_prot.release_cb = espintcp_release; - espintcp_ops.poll = espintcp_poll; + build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops); tcp_register_ulp(&espintcp_ulp); } |