summaryrefslogtreecommitdiff
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/core/skmsg.c4
-rw-r--r--net/ipv4/tcp_ulp.c13
-rw-r--r--net/tls/tls_main.c33
3 files changed, 43 insertions, 7 deletions
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 93bffaad2135..6832eeb4b785 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -585,12 +585,12 @@ EXPORT_SYMBOL_GPL(sk_psock_destroy);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
- rcu_assign_sk_user_data(sk, NULL);
sk_psock_cork_free(psock);
sk_psock_zap_ingress(psock);
- sk_psock_restore_proto(sk, psock);
write_lock_bh(&sk->sk_callback_lock);
+ sk_psock_restore_proto(sk, psock);
+ rcu_assign_sk_user_data(sk, NULL);
if (psock->progs.skb_parser)
sk_psock_stop_strp(sk, psock);
write_unlock_bh(&sk->sk_callback_lock);
diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c
index 3d8a1d835471..4849edb62d52 100644
--- a/net/ipv4/tcp_ulp.c
+++ b/net/ipv4/tcp_ulp.c
@@ -96,6 +96,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen)
rcu_read_unlock();
}
+void tcp_update_ulp(struct sock *sk, struct proto *proto)
+{
+ struct inet_connection_sock *icsk = inet_csk(sk);
+
+ if (!icsk->icsk_ulp_ops) {
+ sk->sk_prot = proto;
+ return;
+ }
+
+ if (icsk->icsk_ulp_ops->update)
+ icsk->icsk_ulp_ops->update(sk, proto);
+}
+
void tcp_cleanup_ulp(struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 48f1c26459d0..f208f8455ef2 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -328,7 +328,10 @@ static void tls_sk_proto_unhash(struct sock *sk)
ctx = tls_get_ctx(sk);
tls_sk_proto_cleanup(sk, ctx, timeo);
+ write_lock_bh(&sk->sk_callback_lock);
icsk->icsk_ulp_data = NULL;
+ sk->sk_prot = ctx->sk_proto;
+ write_unlock_bh(&sk->sk_callback_lock);
if (ctx->sk_proto->unhash)
ctx->sk_proto->unhash(sk);
@@ -337,7 +340,7 @@ static void tls_sk_proto_unhash(struct sock *sk)
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
- void (*sk_proto_close)(struct sock *sk, long timeout);
+ struct inet_connection_sock *icsk = inet_csk(sk);
struct tls_context *ctx = tls_get_ctx(sk);
long timeo = sock_sndtimeo(sk, 0);
bool free_ctx;
@@ -347,12 +350,15 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
lock_sock(sk);
free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
- sk_proto_close = ctx->sk_proto_close;
if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
tls_sk_proto_cleanup(sk, ctx, timeo);
+ write_lock_bh(&sk->sk_callback_lock);
+ if (free_ctx)
+ icsk->icsk_ulp_data = NULL;
sk->sk_prot = ctx->sk_proto;
+ write_unlock_bh(&sk->sk_callback_lock);
release_sock(sk);
if (ctx->tx_conf == TLS_SW)
tls_sw_free_ctx_tx(ctx);
@@ -360,7 +366,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
tls_sw_strparser_done(ctx);
if (ctx->rx_conf == TLS_SW)
tls_sw_free_ctx_rx(ctx);
- sk_proto_close(sk, timeout);
+ ctx->sk_proto_close(sk, timeout);
if (free_ctx)
tls_ctx_free(ctx);
@@ -827,7 +833,7 @@ static int tls_init(struct sock *sk)
int rc = 0;
if (tls_hw_prot(sk))
- goto out;
+ return 0;
/* The TLS ulp is currently supported only for TCP sockets
* in ESTABLISHED state.
@@ -838,22 +844,38 @@ static int tls_init(struct sock *sk)
if (sk->sk_state != TCP_ESTABLISHED)
return -ENOTSUPP;
+ tls_build_proto(sk);
+
/* allocate tls context */
+ write_lock_bh(&sk->sk_callback_lock);
ctx = create_ctx(sk);
if (!ctx) {
rc = -ENOMEM;
goto out;
}
- tls_build_proto(sk);
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx);
out:
+ write_unlock_bh(&sk->sk_callback_lock);
return rc;
}
+static void tls_update(struct sock *sk, struct proto *p)
+{
+ struct tls_context *ctx;
+
+ ctx = tls_get_ctx(sk);
+ if (likely(ctx)) {
+ ctx->sk_proto_close = p->close;
+ ctx->sk_proto = p;
+ } else {
+ sk->sk_prot = p;
+ }
+}
+
void tls_register_device(struct tls_device *device)
{
spin_lock_bh(&device_spinlock);
@@ -874,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
.name = "tls",
.owner = THIS_MODULE,
.init = tls_init,
+ .update = tls_update,
};
static int __init tls_register(void)