Mon Aug 31 20:33:58 2020 UTC ()
wg: Simplify locking.

Summary: Access to a stable established session is still allowed via
psref; all other access to peer and session state is now serialized
by struct wg_peer::wgp_lock, with no dancing around a per-session
lock.  This way, the handshake paths are locked, while the data
transmission paths are pserialized.

- Eliminate struct wg_session::wgs_lock.

- Eliminate wg_get_unstable_session -- access to the unstable session
  is allowed only with struct wg_peer::wgp_lock held.

- Push INIT_PASSIVE->ESTABLISHED transition down into a thread task.

- Push rekey down into a thread task.

- Allocate session indices only on transition from UNKNOWN and free
  them only on transition back to UNKNOWN.

- Be a little more explicit about allowed state transitions, and
  reject some nonsensical ones.

- Sprinkle assertions and comments.

- Reduce atomic r/m/w swap operations that can just as well be
  store-release.


(riastradh)
diff -r1.48 -r1.49 src/sys/net/if_wg.c

cvs diff -r1.48 -r1.49 src/sys/net/if_wg.c (expand / switch to context diff)
--- src/sys/net/if_wg.c 2020/08/31 20:31:43 1.48
+++ src/sys/net/if_wg.c 2020/08/31 20:33:58 1.49
@@ -1,4 +1,4 @@
-/*	$NetBSD: if_wg.c,v 1.48 2020/08/31 20:31:43 riastradh Exp $	*/
+/*	$NetBSD: if_wg.c,v 1.49 2020/08/31 20:33:58 riastradh Exp $	*/
 
 /*
  * Copyright (C) Ryota Ozaki <ozaki.ryota@gmail.com>
@@ -41,7 +41,7 @@
  */
 
 #include <sys/cdefs.h>
-__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.48 2020/08/31 20:31:43 riastradh Exp $");
+__KERNEL_RCSID(0, "$NetBSD: if_wg.c,v 1.49 2020/08/31 20:33:58 riastradh Exp $");
 
 #ifdef _KERNEL_OPT
 #include "opt_inet.h"
@@ -137,7 +137,7 @@
  * - struct wg_session represents a session of a secure tunnel with a peer
  *   - Two instances of sessions belong to a peer; a stable session and a
  *     unstable session
- *   - A handshake process of a session always starts with a unstable instace
+ *   - A handshake process of a session always starts with an unstable instance
  *   - Once a session is established, its instance becomes stable and the
  *     other becomes unstable instead
  *   - Data messages are always sent via a stable session
@@ -145,25 +145,22 @@
  * Locking notes:
  * - wg interfaces (struct wg_softc, wg) is listed in wg_softcs.list and
  *   protected by wg_softcs.lock
- * - Each wg has a mutex(9) and a rwlock(9)
- *   - The mutex (wg_lock) protects its peer list (wg_peers)
- *   - A peer on the list is also protected by pserialize(9) or psref(9)
+ * - Each wg has a mutex(9) wg_lock, and a rwlock(9) wg_rwlock
+ *   - Changes to the peer list are serialized by wg_lock
+ *   - The peer list may be read with pserialize(9) and psref(9)
  *   - The rwlock (wg_rwlock) protects the routing tables (wg_rtable_ipv[46])
- * - Each peer (struct wg_peer, wgp) has a mutex
- *   - The mutex (wgp_lock) protects wgp_session_unstable and wgp_state
- * - Each session (struct wg_session, wgs) has a mutex
- *   - The mutex (wgs_lock) protects its state (wgs_state) and its handshake
- *     states
- *   - wgs_state of a unstable session can be changed while it never be
- *     changed on a stable session, so once get a session instace via
- *     wgp_session_stable we can safely access wgs_state without
- *     holding wgs_lock
- *   - A session is protected by pserialize or psref like wgp
+ *     => XXX replace by pserialize when routing table is psz-safe
+ * - Each peer (struct wg_peer, wgp) has a mutex wgp_lock, which can be taken
+ *   only in thread context and serializes:
+ *   - the stable and unstable session pointers
+ *   - all unstable session state
+ * - Packet processing may be done in softint context:
+ *   - The stable session can be read under pserialize(9) or psref(9)
+ *     - The stable session is always ESTABLISHED
  *     - On a session swap, we must wait for all readers to release a
  *       reference to a stable session before changing wgs_state and
  *       session states
- *
- * Lock order: wg_lock -> wgp_lock -> wgs_lock
+ * - Lock order: wg_lock -> wgp_lock
  */
 
 
@@ -444,7 +441,6 @@
 	struct wg_peer	*wgs_peer;
 	struct psref_target
 			wgs_psref;
-	kmutex_t	*wgs_lock;
 
 	int		wgs_state;
 #define WGS_STATE_UNKNOWN	0
@@ -457,8 +453,8 @@
 	time_t		wgs_time_last_data_sent;
 	bool		wgs_is_initiator;
 
-	uint32_t	wgs_sender_index;
-	uint32_t	wgs_receiver_index;
+	uint32_t	wgs_local_index;
+	uint32_t	wgs_remote_index;
 #ifdef __HAVE_ATOMIC64_LOADSTORE
 	volatile uint64_t
 			wgs_send_counter;
@@ -537,18 +533,12 @@
 	uint8_t	wgp_pubkey[WG_STATIC_KEY_LEN];
 	struct wg_sockaddr	*wgp_endpoint;
 	struct wg_sockaddr	*wgp_endpoint0;
-	bool			wgp_endpoint_changing;
+	volatile unsigned	wgp_endpoint_changing;
 	bool			wgp_endpoint_available;
 
 			/* The preshared key (optional) */
 	uint8_t		wgp_psk[WG_PRESHARED_KEY_LEN];
 
-	int wgp_state;
-#define WGP_STATE_INIT		0
-#define WGP_STATE_ESTABLISHED	1
-#define WGP_STATE_GIVEUP	2
-#define WGP_STATE_DESTROYING	3
-
 	void		*wgp_si;
 	pcq_t		*wgp_q;
 
@@ -585,9 +575,11 @@
 
 	volatile unsigned int	wgp_tasks;
 #define WGP_TASK_SEND_INIT_MESSAGE		__BIT(0)
-#define WGP_TASK_ENDPOINT_CHANGED		__BIT(1)
-#define WGP_TASK_SEND_KEEPALIVE_MESSAGE		__BIT(2)
-#define WGP_TASK_DESTROY_PREV_SESSION		__BIT(3)
+#define WGP_TASK_RETRY_HANDSHAKE		__BIT(1)
+#define WGP_TASK_ESTABLISH_SESSION		__BIT(2)
+#define WGP_TASK_ENDPOINT_CHANGED		__BIT(3)
+#define WGP_TASK_SEND_KEEPALIVE_MESSAGE		__BIT(4)
+#define WGP_TASK_DESTROY_PREV_SESSION		__BIT(5)
 };
 
 struct wg_ops;
@@ -652,8 +644,8 @@
 		    struct mbuf *);
 static int	wg_send_cookie_msg(struct wg_softc *, struct wg_peer *,
 		    const uint32_t, const uint8_t [], const struct sockaddr *);
-static int	wg_send_handshake_msg_resp(struct wg_softc *,
-		    struct wg_peer *, const struct wg_msg_init *);
+static int	wg_send_handshake_msg_resp(struct wg_softc *, struct wg_peer *,
+		    struct wg_session *, const struct wg_msg_init *);
 static void	wg_send_keepalive_msg(struct wg_peer *, struct wg_session *);
 
 static struct wg_peer *
@@ -691,6 +683,8 @@
 static int	wg_init(struct ifnet *);
 static void	wg_stop(struct ifnet *, int);
 
+static void	wg_purge_pending_packets(struct wg_peer *);
+
 static int	wg_clone_create(struct if_clone *, int);
 static int	wg_clone_destroy(struct ifnet *);
 
@@ -1100,111 +1094,122 @@
 	be32enc(timestamp + 8, ts.tv_nsec);
 }
 
+/*
+ * wg_get_stable_session(wgp, psref)
+ *
+ *	Get a passive reference to the current stable session, or
+ *	return NULL if there is no current stable session.
+ *
+ *	The pointer is always there but the session is not necessarily
+ *	ESTABLISHED; if it is not ESTABLISHED, return NULL.  However,
+ *	the session may transition from ESTABLISHED to DESTROYING while
+ *	holding the passive reference.
+ */
 static struct wg_session *
-wg_get_unstable_session(struct wg_peer *wgp, struct psref *psref)
-{
-	int s;
-	struct wg_session *wgs;
-
-	s = pserialize_read_enter();
-	wgs = wgp->wgp_session_unstable;
-	psref_acquire(psref, &wgs->wgs_psref, wg_psref_class);
-	pserialize_read_exit(s);
-	return wgs;
-}
-
-static struct wg_session *
 wg_get_stable_session(struct wg_peer *wgp, struct psref *psref)
 {
 	int s;
 	struct wg_session *wgs;
 
 	s = pserialize_read_enter();
-	wgs = wgp->wgp_session_stable;
-	psref_acquire(psref, &wgs->wgs_psref, wg_psref_class);
+	wgs = atomic_load_consume(&wgp->wgp_session_stable);
+	if (__predict_false(wgs->wgs_state != WGS_STATE_ESTABLISHED))
+		wgs = NULL;
+	else
+		psref_acquire(psref, &wgs->wgs_psref, wg_psref_class);
 	pserialize_read_exit(s);
+
 	return wgs;
 }
 
 static void
-wg_get_session(struct wg_session *wgs, struct psref *psref)
+wg_put_session(struct wg_session *wgs, struct psref *psref)
 {
 
-	psref_acquire(psref, &wgs->wgs_psref, wg_psref_class);
+	psref_release(psref, &wgs->wgs_psref, wg_psref_class);
 }
 
 static void
-wg_put_session(struct wg_session *wgs, struct psref *psref)
+wg_destroy_session(struct wg_softc *wg, struct wg_session *wgs)
 {
+	struct wg_peer *wgp = wgs->wgs_peer;
+	struct wg_session *wgs0 __diagused;
+	void *garbage;
 
-	psref_release(psref, &wgs->wgs_psref, wg_psref_class);
-}
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgs->wgs_state != WGS_STATE_UNKNOWN);
 
-static struct wg_session *
-wg_lock_unstable_session(struct wg_peer *wgp)
-{
-	struct wg_session *wgs;
+	/* Remove the session from the table.  */
+	wgs0 = thmap_del(wg->wg_sessions_byindex,
+	    &wgs->wgs_local_index, sizeof(wgs->wgs_local_index));
+	KASSERT(wgs0 == wgs);
+	garbage = thmap_stage_gc(wg->wg_sessions_byindex);
 
-	mutex_enter(wgp->wgp_lock);
-	wgs = wgp->wgp_session_unstable;
-	mutex_enter(wgs->wgs_lock);
-	mutex_exit(wgp->wgp_lock);
-	return wgs;
+	/* Wait for passive references to drain.  */
+	pserialize_perform(wgp->wgp_psz);
+	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
+
+	/* Free memory, zero state, and transition to UNKNOWN.  */
+	thmap_gc(wg->wg_sessions_byindex, garbage);
+	wg_clear_states(wgs);
+	wgs->wgs_state = WGS_STATE_UNKNOWN;
 }
 
-#if 0
+/*
+ * wg_get_session_index(wg, wgs)
+ *
+ *	Choose a session index for wgs->wgs_local_index, and store it
+ *	in wg's table of sessions by index.
+ *
+ *	wgs must be the unstable session of its peer, and must be
+ *	transitioning out of the UNKNOWN state.
+ */
 static void
-wg_unlock_session(struct wg_peer *wgp, struct wg_session *wgs)
+wg_get_session_index(struct wg_softc *wg, struct wg_session *wgs)
 {
-
-	mutex_exit(wgs->wgs_lock);
-}
-#endif
-
-static uint32_t
-wg_assign_sender_index(struct wg_softc *wg, struct wg_session *wgs)
-{
-	struct wg_peer *wgp = wgs->wgs_peer;
+	struct wg_peer *wgp __diagused = wgs->wgs_peer;
 	struct wg_session *wgs0;
 	uint32_t index;
-	void *garbage;
 
-	mutex_enter(wgs->wgs_lock);
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgs == wgp->wgp_session_unstable);
+	KASSERT(wgs->wgs_state == WGS_STATE_UNKNOWN);
 
-	/* Release the current index, if there is one.  */
-	while ((index = wgs->wgs_sender_index) != 0) {
-		/* Remove the session by index.  */
-		thmap_del(wg->wg_sessions_byindex, &index, sizeof index);
-		wgs->wgs_sender_index = 0;
-		mutex_exit(wgs->wgs_lock);
+	do {
+		/* Pick a uniform random index.  */
+		index = cprng_strong32();
 
-		/* Wait for all thmap_gets to complete, and GC.  */
-		garbage = thmap_stage_gc(wg->wg_sessions_byindex);
-		mutex_enter(wgs->wgs_peer->wgp_lock);
-		pserialize_perform(wgp->wgp_psz);
-		mutex_exit(wgs->wgs_peer->wgp_lock);
-		thmap_gc(wg->wg_sessions_byindex, garbage);
+		/* Try to take it.  */
+		wgs->wgs_local_index = index;
+		wgs0 = thmap_put(wg->wg_sessions_byindex,
+		    &wgs->wgs_local_index, sizeof wgs->wgs_local_index, wgs);
 
-		mutex_enter(wgs->wgs_lock);
-	}
+		/* If someone else beat us, start over.  */
+	} while (__predict_false(wgs0 != wgs));
+}
 
-restart:
-	/* Pick a uniform random nonzero index.  */
-	while (__predict_false((index = cprng_strong32()) == 0))
-		continue;
+/*
+ * wg_put_session_index(wg, wgs)
+ *
+ *	Remove wgs from the table of sessions by index, wait for any
+ *	passive references to drain, and transition the session to the
+ *	UNKNOWN state.
+ *
+ *	wgs must be the unstable session of its peer, and must not be
+ *	UNKNOWN or ESTABLISHED.
+ */
+static void
+wg_put_session_index(struct wg_softc *wg, struct wg_session *wgs)
+{
+	struct wg_peer *wgp = wgs->wgs_peer;
 
-	/* Try to take it.  */
-	wgs->wgs_sender_index = index;
-	wgs0 = thmap_put(wg->wg_sessions_byindex,
-	    &wgs->wgs_sender_index, sizeof wgs->wgs_sender_index, wgs);
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgs == wgp->wgp_session_unstable);
+	KASSERT(wgs->wgs_state != WGS_STATE_UNKNOWN);
+	KASSERT(wgs->wgs_state != WGS_STATE_ESTABLISHED);
 
-	/* If someone else beat us, start over.  */
-	if (__predict_false(wgs0 != wgs))
-		goto restart;
-
-	mutex_exit(wgs->wgs_lock);
-
-	return index;
+	wg_destroy_session(wg, wgs);
+	psref_target_init(&wgs->wgs_psref, wg_psref_class);
 }
 
 /*
@@ -1239,8 +1244,12 @@
 	uint8_t pubkey[WG_EPHEMERAL_KEY_LEN];
 	uint8_t privkey[WG_EPHEMERAL_KEY_LEN];
 
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgs == wgp->wgp_session_unstable);
+	KASSERT(wgs->wgs_state == WGS_STATE_INIT_ACTIVE);
+
 	wgmi->wgmi_type = htole32(WG_MSG_TYPE_INIT);
-	wgmi->wgmi_sender = wg_assign_sender_index(wg, wgs);
+	wgmi->wgmi_sender = wgs->wgs_local_index;
 
 	/* [W] 5.4.2: First Message: Initiator to Responder */
 
@@ -1315,7 +1324,7 @@
 	memcpy(wgs->wgs_ephemeral_key_priv, privkey, sizeof(privkey));
 	memcpy(wgs->wgs_handshake_hash, hash, sizeof(hash));
 	memcpy(wgs->wgs_chaining_key, ckey, sizeof(ckey));
-	WG_DLOG("%s: sender=%x\n", __func__, wgs->wgs_sender_index);
+	WG_DLOG("%s: sender=%x\n", __func__, wgs->wgs_local_index);
 }
 
 static void
@@ -1330,7 +1339,6 @@
 	struct wg_session *wgs;
 	int error, ret;
 	struct psref psref_peer;
-	struct psref psref_session;
 	uint8_t mac1[WG_MAC_LEN];
 
 	WG_TRACE("init msg received");
@@ -1398,6 +1406,14 @@
 		return;
 	}
 
+	/*
+	 * Lock the peer to serialize access to cookie state.
+	 *
+	 * XXX Can we safely avoid holding the lock across DH?  Take it
+	 * just to verify mac2 and then unlock/DH/lock?
+	 */
+	mutex_enter(wgp->wgp_lock);
+
 	if (__predict_false(wg_is_underload(wg, wgp, WG_MSG_TYPE_INIT))) {
 		WG_TRACE("under load");
 		/*
@@ -1413,13 +1429,13 @@
 			WG_TRACE("sending a cookie message: no cookie included");
 			(void)wg_send_cookie_msg(wg, wgp, wgmi->wgmi_sender,
 			    wgmi->wgmi_mac1, src);
-			goto out_wgp;
+			goto out;
 		}
 		if (!wgp->wgp_last_sent_cookie_valid) {
 			WG_TRACE("sending a cookie message: no cookie sent ever");
 			(void)wg_send_cookie_msg(wg, wgp, wgmi->wgmi_sender,
 			    wgmi->wgmi_mac1, src);
-			goto out_wgp;
+			goto out;
 		}
 		uint8_t mac2[WG_MAC_LEN];
 		wg_algo_mac(mac2, sizeof(mac2), wgp->wgp_last_sent_cookie,
@@ -1427,7 +1443,7 @@
 		    offsetof(struct wg_msg_init, wgmi_mac2), NULL, 0);
 		if (!consttime_memequal(mac2, wgmi->wgmi_mac2, sizeof(mac2))) {
 			WG_DLOG("mac2 is invalid\n");
-			goto out_wgp;
+			goto out;
 		}
 		WG_TRACE("under load, but continue to sending");
 	}
@@ -1444,37 +1460,11 @@
 	if (error != 0) {
 		WG_LOG_RATECHECK(&wgp->wgp_ppsratecheck, LOG_DEBUG,
 		    "wg_algo_aead_dec for timestamp failed\n");
-		goto out_wgp;
+		goto out;
 	}
 	/* Hi := HASH(Hi || msg.timestamp) */
 	wg_algo_hash(hash, wgmi->wgmi_timestamp, sizeof(wgmi->wgmi_timestamp));
 
-	wgs = wg_lock_unstable_session(wgp);
-	if (wgs->wgs_state == WGS_STATE_DESTROYING) {
-		/*
-		 * We can assume that the peer doesn't have an
-		 * established session, so clear it now.  If the timer
-		 * fired, tough -- it won't have any effect unless we
-		 * manage to transition back to WGS_STATE_DESTROYING.
-		 */
-		WG_TRACE("Session destroying, but force to clear");
-		callout_stop(&wgp->wgp_session_dtor_timer);
-		wg_clear_states(wgs);
-		wgs->wgs_state = WGS_STATE_UNKNOWN;
-	}
-	if (wgs->wgs_state == WGS_STATE_INIT_ACTIVE) {
-		WG_TRACE("Sesssion already initializing, ignoring the message");
-		mutex_exit(wgs->wgs_lock);
-		goto out_wgp;
-	}
-	if (wgs->wgs_state == WGS_STATE_INIT_PASSIVE) {
-		WG_TRACE("Sesssion already initializing, destroying old states");
-		wg_clear_states(wgs);
-	}
-	wgs->wgs_state = WGS_STATE_INIT_PASSIVE;
-	wg_get_session(wgs, &psref_session);
-	mutex_exit(wgs->wgs_lock);
-
 	/*
 	 * [W] 5.1 "The responder keeps track of the greatest timestamp
 	 *      received per peer and discards packets containing
@@ -1489,6 +1479,37 @@
 	}
 	memcpy(wgp->wgp_timestamp_latest_init, timestamp, sizeof(timestamp));
 
+	/*
+	 * Message is good -- we're committing to handle it now, unless
+	 * we were already initiating a session.
+	 */
+	wgs = wgp->wgp_session_unstable;
+	switch (wgs->wgs_state) {
+	case WGS_STATE_UNKNOWN:		/* new session initiated by peer */
+		wg_get_session_index(wg, wgs);
+		break;
+	case WGS_STATE_INIT_ACTIVE:	/* we're already initiating, drop */
+		WG_TRACE("Session already initializing, ignoring the message");
+		goto out;
+	case WGS_STATE_INIT_PASSIVE:	/* peer is retrying, start over */
+		WG_TRACE("Session already initializing, destroying old states");
+		wg_clear_states(wgs);
+		/* keep session index */
+		break;
+	case WGS_STATE_ESTABLISHED:	/* can't happen */
+		panic("unstable session can't be established");
+		break;
+	case WGS_STATE_DESTROYING:	/* rekey initiated by peer */
+		WG_TRACE("Session destroying, but force to clear");
+		callout_stop(&wgp->wgp_session_dtor_timer);
+		wg_clear_states(wgs);
+		/* keep session index */
+		break;
+	default:
+		panic("invalid session state: %d", wgs->wgs_state);
+	}
+	wgs->wgs_state = WGS_STATE_INIT_PASSIVE;
+
 	memcpy(wgs->wgs_handshake_hash, hash, sizeof(hash));
 	memcpy(wgs->wgs_chaining_key, ckey, sizeof(ckey));
 	memcpy(wgs->wgs_ephemeral_key_peer, wgmi->wgmi_ephemeral,
@@ -1496,37 +1517,16 @@
 
 	wg_update_endpoint_if_necessary(wgp, src);
 
-	(void)wg_send_handshake_msg_resp(wg, wgp, wgmi);
+	(void)wg_send_handshake_msg_resp(wg, wgp, wgs, wgmi);
 
 	wg_calculate_keys(wgs, false);
 	wg_clear_states(wgs);
 
-	wg_put_session(wgs, &psref_session);
-	wg_put_peer(wgp, &psref_peer);
-	return;
-
 out:
-	mutex_enter(wgs->wgs_lock);
-	KASSERT(wgs->wgs_state == WGS_STATE_INIT_PASSIVE);
-	wgs->wgs_state = WGS_STATE_UNKNOWN;
-	mutex_exit(wgs->wgs_lock);
-	wg_put_session(wgs, &psref_session);
-out_wgp:
+	mutex_exit(wgp->wgp_lock);
 	wg_put_peer(wgp, &psref_peer);
 }
 
-static void
-wg_schedule_handshake_timeout_timer(struct wg_peer *wgp)
-{
-
-	mutex_enter(wgp->wgp_lock);
-	if (__predict_true(wgp->wgp_state != WGP_STATE_DESTROYING)) {
-		callout_schedule(&wgp->wgp_handshake_timeout_timer,
-		    MIN(wg_rekey_timeout, INT_MAX/hz) * hz);
-	}
-	mutex_exit(wgp->wgp_lock);
-}
-
 static struct socket *
 wg_get_so_by_af(struct wg_worker *wgw, const int af)
 {
@@ -1585,27 +1585,32 @@
 	struct mbuf *m;
 	struct wg_msg_init *wgmi;
 	struct wg_session *wgs;
-	struct psref psref;
 
-	wgs = wg_lock_unstable_session(wgp);
-	if (wgs->wgs_state == WGS_STATE_DESTROYING) {
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
+	wgs = wgp->wgp_session_unstable;
+	/* XXX pull dispatch out into wg_task_send_init_message */
+	switch (wgs->wgs_state) {
+	case WGS_STATE_UNKNOWN:		/* new session initiated by us */
+		wg_get_session_index(wg, wgs);
+		break;
+	case WGS_STATE_INIT_ACTIVE:	/* we're already initiating, stop */
+		WG_TRACE("Session already initializing, skip starting new one");
+		return EBUSY;
+	case WGS_STATE_INIT_PASSIVE:	/* peer was trying -- XXX what now? */
+		WG_TRACE("Session already initializing, destroying old states");
+		wg_clear_states(wgs);
+		/* keep session index */
+		break;
+	case WGS_STATE_ESTABLISHED:	/* can't happen */
+		panic("unstable session can't be established");
+		break;
+	case WGS_STATE_DESTROYING:	/* rekey initiated by us too early */
 		WG_TRACE("Session destroying");
-		mutex_exit(wgs->wgs_lock);
 		/* XXX should wait? */
 		return EBUSY;
 	}
-	if (wgs->wgs_state == WGS_STATE_INIT_ACTIVE) {
-		WG_TRACE("Sesssion already initializing, skip starting a new one");
-		mutex_exit(wgs->wgs_lock);
-		return EBUSY;
-	}
-	if (wgs->wgs_state == WGS_STATE_INIT_PASSIVE) {
-		WG_TRACE("Sesssion already initializing, destroying old states");
-		wg_clear_states(wgs);
-	}
 	wgs->wgs_state = WGS_STATE_INIT_ACTIVE;
-	wg_get_session(wgs, &psref);
-	mutex_exit(wgs->wgs_lock);
 
 	m = m_gethdr(M_WAIT, MT_DATA);
 	m->m_pkthdr.len = m->m_len = sizeof(*wgmi);
@@ -1618,36 +1623,35 @@
 
 		if (wgp->wgp_handshake_start_time == 0)
 			wgp->wgp_handshake_start_time = time_uptime;
-		wg_schedule_handshake_timeout_timer(wgp);
+		callout_schedule(&wgp->wgp_handshake_timeout_timer,
+		    MIN(wg_rekey_timeout, INT_MAX/hz) * hz);
 	} else {
-		mutex_enter(wgs->wgs_lock);
-		KASSERT(wgs->wgs_state == WGS_STATE_INIT_ACTIVE);
-		wgs->wgs_state = WGS_STATE_UNKNOWN;
-		mutex_exit(wgs->wgs_lock);
+		wg_put_session_index(wg, wgs);
 	}
-	wg_put_session(wgs, &psref);
 
 	return error;
 }
 
 static void
 wg_fill_msg_resp(struct wg_softc *wg, struct wg_peer *wgp,
-    struct wg_msg_resp *wgmr, const struct wg_msg_init *wgmi)
+    struct wg_session *wgs, struct wg_msg_resp *wgmr,
+    const struct wg_msg_init *wgmi)
 {
 	uint8_t ckey[WG_CHAINING_KEY_LEN]; /* [W] 5.4.3: Cr */
 	uint8_t hash[WG_HASH_LEN]; /* [W] 5.4.3: Hr */
 	uint8_t cipher_key[WG_KDF_OUTPUT_LEN];
 	uint8_t pubkey[WG_EPHEMERAL_KEY_LEN];
 	uint8_t privkey[WG_EPHEMERAL_KEY_LEN];
-	struct wg_session *wgs;
-	struct psref psref;
 
-	wgs = wg_get_unstable_session(wgp, &psref);
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgs == wgp->wgp_session_unstable);
+	KASSERT(wgs->wgs_state == WGS_STATE_INIT_PASSIVE);
+
 	memcpy(hash, wgs->wgs_handshake_hash, sizeof(hash));
 	memcpy(ckey, wgs->wgs_chaining_key, sizeof(ckey));
 
 	wgmr->wgmr_type = htole32(WG_MSG_TYPE_RESP);
-	wgmr->wgmr_sender = wg_assign_sender_index(wg, wgs);
+	wgmr->wgmr_sender = wgs->wgs_local_index;
 	wgmr->wgmr_receiver = wgmi->wgmi_sender;
 
 	/* [W] 5.4.3 Second Message: Responder to Initiator */
@@ -1718,21 +1722,26 @@
 	memcpy(wgs->wgs_chaining_key, ckey, sizeof(ckey));
 	memcpy(wgs->wgs_ephemeral_key_pub, pubkey, sizeof(pubkey));
 	memcpy(wgs->wgs_ephemeral_key_priv, privkey, sizeof(privkey));
-	wgs->wgs_receiver_index = wgmi->wgmi_sender;
-	WG_DLOG("sender=%x\n", wgs->wgs_sender_index);
-	WG_DLOG("receiver=%x\n", wgs->wgs_receiver_index);
-	wg_put_session(wgs, &psref);
+	wgs->wgs_remote_index = wgmi->wgmi_sender;
+	WG_DLOG("sender=%x\n", wgs->wgs_local_index);
+	WG_DLOG("receiver=%x\n", wgs->wgs_remote_index);
 }
 
 static void
 wg_swap_sessions(struct wg_peer *wgp)
 {
+	struct wg_session *wgs, *wgs_prev;
 
 	KASSERT(mutex_owned(wgp->wgp_lock));
 
-	wgp->wgp_session_unstable = atomic_swap_ptr(&wgp->wgp_session_stable,
-	    wgp->wgp_session_unstable);
-	KASSERT(wgp->wgp_session_stable->wgs_state == WGS_STATE_ESTABLISHED);
+	wgs = wgp->wgp_session_unstable;
+	KASSERT(wgs->wgs_state == WGS_STATE_ESTABLISHED);
+
+	wgs_prev = wgp->wgp_session_stable;
+	KASSERT(wgs_prev->wgs_state == WGS_STATE_ESTABLISHED ||
+	    wgs_prev->wgs_state == WGS_STATE_UNKNOWN);
+	atomic_store_release(&wgp->wgp_session_stable, wgs);
+	wgp->wgp_session_unstable = wgs_prev;
 }
 
 static void
@@ -1772,6 +1781,14 @@
 
 	wgp = wgs->wgs_peer;
 
+	mutex_enter(wgp->wgp_lock);
+
+	/* If we weren't waiting for a handshake response, drop it.  */
+	if (wgs->wgs_state != WGS_STATE_INIT_ACTIVE) {
+		WG_TRACE("peer sent spurious handshake response, ignoring");
+		goto out;
+	}
+
 	if (__predict_false(wg_is_underload(wg, wgp, WG_MSG_TYPE_RESP))) {
 		WG_TRACE("under load");
 		/*
@@ -1865,9 +1882,10 @@
 
 	memcpy(wgs->wgs_handshake_hash, hash, sizeof(wgs->wgs_handshake_hash));
 	memcpy(wgs->wgs_chaining_key, ckey, sizeof(wgs->wgs_chaining_key));
-	wgs->wgs_receiver_index = wgmr->wgmr_sender;
-	WG_DLOG("receiver=%x\n", wgs->wgs_receiver_index);
+	wgs->wgs_remote_index = wgmr->wgmr_sender;
+	WG_DLOG("receiver=%x\n", wgs->wgs_remote_index);
 
+	KASSERT(wgs->wgs_state == WGS_STATE_INIT_ACTIVE);
 	wgs->wgs_state = WGS_STATE_ESTABLISHED;
 	wgs->wgs_time_established = time_uptime;
 	wgs->wgs_time_last_data_sent = 0;
@@ -1876,18 +1894,15 @@
 	wg_clear_states(wgs);
 	WG_TRACE("WGS_STATE_ESTABLISHED");
 
-	callout_halt(&wgp->wgp_handshake_timeout_timer, NULL);
+	callout_stop(&wgp->wgp_handshake_timeout_timer);
 
-	mutex_enter(wgp->wgp_lock);
 	wg_swap_sessions(wgp);
+	KASSERT(wgs == wgp->wgp_session_stable);
 	wgs_prev = wgp->wgp_session_unstable;
-	mutex_enter(wgs_prev->wgs_lock);
-
 	getnanotime(&wgp->wgp_last_handshake_time);
 	wgp->wgp_handshake_start_time = 0;
 	wgp->wgp_last_sent_mac1_valid = false;
 	wgp->wgp_last_sent_cookie_valid = false;
-	mutex_exit(wgp->wgp_lock);
 
 	wg_schedule_rekey_timer(wgp);
 
@@ -1907,28 +1922,40 @@
 	WG_TRACE("softint scheduled");
 
 	if (wgs_prev->wgs_state == WGS_STATE_ESTABLISHED) {
+		/* Wait for wg_get_stable_session to drain.  */
+		pserialize_perform(wgp->wgp_psz);
+
+		/* Transition ESTABLISHED->DESTROYING.  */
 		wgs_prev->wgs_state = WGS_STATE_DESTROYING;
+
 		/* We can't destroy the old session immediately */
 		wg_schedule_session_dtor_timer(wgp);
+	} else {
+		KASSERTMSG(wgs_prev->wgs_state == WGS_STATE_UNKNOWN,
+		    "state=%d", wgs_prev->wgs_state);
 	}
-	mutex_exit(wgs_prev->wgs_lock);
 
 out:
+	mutex_exit(wgp->wgp_lock);
 	wg_put_session(wgs, &psref);
 }
 
 static int
 wg_send_handshake_msg_resp(struct wg_softc *wg, struct wg_peer *wgp,
-    const struct wg_msg_init *wgmi)
+    struct wg_session *wgs, const struct wg_msg_init *wgmi)
 {
 	int error;
 	struct mbuf *m;
 	struct wg_msg_resp *wgmr;
 
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgs == wgp->wgp_session_unstable);
+	KASSERT(wgs->wgs_state == WGS_STATE_INIT_PASSIVE);
+
 	m = m_gethdr(M_WAIT, MT_DATA);
 	m->m_pkthdr.len = m->m_len = sizeof(*wgmr);
 	wgmr = mtod(m, struct wg_msg_resp *);
-	wg_fill_msg_resp(wg, wgp, wgmr, wgmi);
+	wg_fill_msg_resp(wg, wgp, wgs, wgmr, wgmi);
 
 	error = wg->wg_ops->send_hs_msg(wgp, m);
 	if (error == 0)
@@ -1962,6 +1989,8 @@
 	size_t addrlen;
 	uint16_t uh_sport; /* be */
 
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
 	wgmc->wgmc_type = htole32(WG_MSG_TYPE_COOKIE);
 	wgmc->wgmc_receiver = sender;
 	cprng_fast(wgmc->wgmc_salt, sizeof(wgmc->wgmc_salt));
@@ -2019,6 +2048,8 @@
 	struct mbuf *m;
 	struct wg_msg_cookie *wgmc;
 
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
 	m = m_gethdr(M_WAIT, MT_DATA);
 	m->m_pkthdr.len = m->m_len = sizeof(*wgmc);
 	wgmc = mtod(m, struct wg_msg_cookie *);
@@ -2053,6 +2084,8 @@
 wg_calculate_keys(struct wg_session *wgs, const bool initiator)
 {
 
+	KASSERT(mutex_owned(wgs->wgs_peer->wgp_lock));
+
 	/*
 	 * [W] 5.4.5: Ti^send = Tr^recv, Ti^recv = Tr^send := KDF2(Ci = Cr, e)
 	 */
@@ -2103,6 +2136,8 @@
 wg_clear_states(struct wg_session *wgs)
 {
 
+	KASSERT(mutex_owned(wgs->wgs_peer->wgp_lock));
+
 	wgs->wgs_send_counter = 0;
 	sliwin_reset(&wgs->wgs_recvwin->window);
 
@@ -2123,8 +2158,11 @@
 
 	int s = pserialize_read_enter();
 	wgs = thmap_get(wg->wg_sessions_byindex, &index, sizeof index);
-	if (wgs != NULL)
+	if (wgs != NULL) {
+		KASSERT(atomic_load_relaxed(&wgs->wgs_state) !=
+		    WGS_STATE_UNKNOWN);
 		psref_acquire(psref, &wgs->wgs_psref, wg_psref_class);
+	}
 	pserialize_read_exit(s);
 
 	return wgs;
@@ -2181,19 +2219,16 @@
 static void
 wg_change_endpoint(struct wg_peer *wgp, const struct sockaddr *new)
 {
+	struct wg_sockaddr *wgsa_prev;
 
-	KASSERT(mutex_owned(wgp->wgp_lock));
-
 	WG_TRACE("Changing endpoint");
 
 	memcpy(wgp->wgp_endpoint0, new, new->sa_len);
-#ifndef __HAVE_ATOMIC_AS_MEMBAR	/* store-release */
-	membar_exit();
-#endif
-	wgp->wgp_endpoint0 = atomic_swap_ptr(&wgp->wgp_endpoint,
-	    wgp->wgp_endpoint0);
-	wgp->wgp_endpoint_available = true;
-	wgp->wgp_endpoint_changing = true;
+	wgsa_prev = wgp->wgp_endpoint;
+	atomic_store_release(&wgp->wgp_endpoint, wgp->wgp_endpoint0);
+	wgp->wgp_endpoint0 = wgsa_prev;
+	atomic_store_release(&wgp->wgp_endpoint_available, true);
+
 	wg_schedule_peer_task(wgp, WGP_TASK_ENDPOINT_CHANGED);
 }
 
@@ -2281,13 +2316,6 @@
 
 	WG_TRACE("enter");
 
-	mutex_enter(wgp->wgp_lock);
-	if (__predict_false(wgp->wgp_state == WGP_STATE_DESTROYING)) {
-		mutex_exit(wgp->wgp_lock);
-		return;
-	}
-	mutex_exit(wgp->wgp_lock);
-
 	wg_schedule_peer_task(wgp, WGP_TASK_DESTROY_PREV_SESSION);
 }
 
@@ -2337,12 +2365,10 @@
 	 */
 	if (__predict_false(sockaddr_cmp(src, wgsatosa(wgsa)) != 0 ||
 		!sockaddr_port_match(src, wgsatosa(wgsa)))) {
-		mutex_enter(wgp->wgp_lock);
 		/* XXX We can't change the endpoint twice in a short period */
-		if (!wgp->wgp_endpoint_changing) {
+		if (atomic_swap_uint(&wgp->wgp_endpoint_changing, 1) == 0) {
 			wg_change_endpoint(wgp, src);
 		}
-		mutex_exit(wgp->wgp_lock);
 	}
 
 	wg_put_sa(wgp, wgsa, &psref);
@@ -2357,6 +2383,7 @@
 	size_t encrypted_len, decrypted_len;
 	struct wg_session *wgs;
 	struct wg_peer *wgp;
+	int state;
 	size_t mlen;
 	struct psref psref;
 	int error, af;
@@ -2369,14 +2396,43 @@
 	KASSERT(wgmd->wgmd_type == htole32(WG_MSG_TYPE_DATA));
 	WG_TRACE("data");
 
+	/* Find the putative session, or drop.  */
 	wgs = wg_lookup_session_by_index(wg, wgmd->wgmd_receiver, &psref);
 	if (wgs == NULL) {
 		WG_TRACE("No session found");
 		m_freem(m);
 		return;
 	}
+
+	/*
+	 * We are only ready to handle data when in INIT_PASSIVE,
+	 * ESTABLISHED, or DESTROYING.  All transitions out of that
+	 * state dissociate the session index and drain psrefs.
+	 */
+	state = atomic_load_relaxed(&wgs->wgs_state);
+	switch (state) {
+	case WGS_STATE_UNKNOWN:
+		panic("wg session %p in unknown state has session index %u",
+		    wgs, wgmd->wgmd_receiver);
+	case WGS_STATE_INIT_ACTIVE:
+		WG_TRACE("not yet ready for data");
+		goto out;
+	case WGS_STATE_INIT_PASSIVE:
+	case WGS_STATE_ESTABLISHED:
+	case WGS_STATE_DESTROYING:
+		break;
+	}
+
+	/*
+	 * Get the peer, for rate-limited logs (XXX MPSAFE, dtrace) and
+	 * to update the endpoint if authentication succeeds.
+	 */
 	wgp = wgs->wgs_peer;
 
+	/*
+	 * Reject outrageously wrong sequence numbers before doing any
+	 * crypto work or taking any locks.
+	 */
 	error = sliwin_check_fast(&wgs->wgs_recvwin->window,
 	    le64toh(wgmd->wgmd_counter));
 	if (error) {
@@ -2386,14 +2442,13 @@
 		goto out;
 	}
 
+	/* Ensure the payload and authenticator are contiguous.  */
 	mlen = m_length(m);
 	encrypted_len = mlen - sizeof(*wgmd);
-
 	if (encrypted_len < WG_AUTHTAG_LEN) {
 		WG_DLOG("Short encrypted_len: %lu\n", encrypted_len);
 		goto out;
 	}
-
 	success = m_ensure_contig(&m, sizeof(*wgmd) + encrypted_len);
 	if (success) {
 		encrypted_buf = mtod(m, char *) + sizeof(*wgmd);
@@ -2410,14 +2465,17 @@
 	KASSERT(m->m_len >= sizeof(*wgmd));
 	wgmd = mtod(m, struct wg_msg_data *);
 
+	/*
+	 * Get a buffer for the plaintext.  Add WG_AUTHTAG_LEN to avoid
+	 * a zero-length buffer (XXX).  Drop if plaintext is longer
+	 * than MCLBYTES (XXX).
+	 */
 	decrypted_len = encrypted_len - WG_AUTHTAG_LEN;
 	if (decrypted_len > MCLBYTES) {
 		/* FIXME handle larger data than MCLBYTES */
 		WG_DLOG("couldn't handle larger data than MCLBYTES\n");
 		goto out;
 	}
-
-	/* To avoid zero length */
 	n = wg_get_mbuf(0, decrypted_len + WG_AUTHTAG_LEN);
 	if (n == NULL) {
 		WG_DLOG("wg_get_mbuf failed\n");
@@ -2425,6 +2483,7 @@
 	}
 	decrypted_buf = mtod(n, char *);
 
+	/* Decrypt and verify the packet.  */
 	WG_DLOG("mlen=%lu, encrypted_len=%lu\n", mlen, encrypted_len);
 	error = wg_algo_aead_dec(decrypted_buf,
 	    encrypted_len - WG_AUTHTAG_LEN /* can be 0 */,
@@ -2438,6 +2497,7 @@
 	}
 	WG_DLOG("outsize=%u\n", (u_int)decrypted_len);
 
+	/* Packet is genuine.  Reject it if a replay or just too old.  */
 	mutex_enter(&wgs->wgs_recvwin->lock);
 	error = sliwin_update(&wgs->wgs_recvwin->window,
 	    le64toh(wgmd->wgmd_counter));
@@ -2450,19 +2510,31 @@
 		goto out;
 	}
 
+	/* We're done with m now; free it and chuck the pointers.  */
 	m_freem(m);
 	m = NULL;
 	wgmd = NULL;
 
+	/*
+	 * Validate the encapsulated packet header and get the address
+	 * family, or drop.
+	 */
 	ok = wg_validate_inner_packet(decrypted_buf, decrypted_len, &af);
 	if (!ok) {
-		/* something wrong... */
 		m_freem(n);
 		goto out;
 	}
 
+	/*
+	 * The packet is genuine.  Update the peer's endpoint if the
+	 * source address changed.
+	 *
+	 * XXX How to prevent DoS by replaying genuine packets from the
+	 * wrong source address?
+	 */
 	wg_update_endpoint_if_necessary(wgp, src);
 
+	/* Submit it into our network stack if routable.  */
 	ok = wg_validate_route(wg, wgp, af, decrypted_buf);
 	if (ok) {
 		wg->wg_ops->input(&wg->wg_if, n, af);
@@ -2477,40 +2549,14 @@
 	}
 	n = NULL;
 
-	if (wgs->wgs_state == WGS_STATE_INIT_PASSIVE) {
-		struct wg_session *wgs_prev;
-
-		KASSERT(wgs == wgp->wgp_session_unstable);
-		wgs->wgs_state = WGS_STATE_ESTABLISHED;
-		wgs->wgs_time_established = time_uptime;
-		wgs->wgs_time_last_data_sent = 0;
-		wgs->wgs_is_initiator = false;
-		WG_TRACE("WGS_STATE_ESTABLISHED");
-
-		mutex_enter(wgp->wgp_lock);
-		wg_swap_sessions(wgp);
-		wgs_prev = wgp->wgp_session_unstable;
-		mutex_enter(wgs_prev->wgs_lock);
-		getnanotime(&wgp->wgp_last_handshake_time);
-		wgp->wgp_handshake_start_time = 0;
-		wgp->wgp_last_sent_mac1_valid = false;
-		wgp->wgp_last_sent_cookie_valid = false;
-		mutex_exit(wgp->wgp_lock);
-
-		if (wgs_prev->wgs_state == WGS_STATE_ESTABLISHED) {
-			wgs_prev->wgs_state = WGS_STATE_DESTROYING;
-			/* We can't destroy the old session immediately */
-			wg_schedule_session_dtor_timer(wgp);
-		} else {
-			wg_clear_states(wgs_prev);
-			wgs_prev->wgs_state = WGS_STATE_UNKNOWN;
-		}
-		mutex_exit(wgs_prev->wgs_lock);
-
-		/* Anyway run a softint to flush pending packets */
-		kpreempt_disable();
-		softint_schedule(wgp->wgp_si);
-		kpreempt_enable();
+	/* Update the state machine if necessary.  */
+	if (__predict_false(state == WGS_STATE_INIT_PASSIVE)) {
+		/*
+		 * We were waiting for the initiator to send their
+		 * first data transport message, and that has happened.
+		 * Schedule a task to establish this session.
+		 */
+		wg_schedule_peer_task(wgp, WGP_TASK_ESTABLISH_SESSION);
 	} else {
 		if (__predict_false(wg_need_to_send_init_message(wgs))) {
 			wg_schedule_peer_task(wgp, WGP_TASK_SEND_INIT_MESSAGE);
@@ -2555,18 +2601,24 @@
 	uint8_t cookie[WG_COOKIE_LEN];
 
 	WG_TRACE("cookie msg received");
+
+	/* Find the putative session.  */
 	wgs = wg_lookup_session_by_index(wg, wgmc->wgmc_receiver, &psref);
 	if (wgs == NULL) {
 		WG_TRACE("No session found");
 		return;
 	}
+
+	/* Lock the peer so we can update the cookie state.  */
 	wgp = wgs->wgs_peer;
+	mutex_enter(wgp->wgp_lock);
 
 	if (!wgp->wgp_last_sent_mac1_valid) {
 		WG_TRACE("No valid mac1 sent (or expired)");
 		goto out;
 	}
 
+	/* Decrypt the cookie and store it for later handshake retry.  */
 	wg_algo_mac_cookie(key, sizeof(key), wgp->wgp_pubkey,
 	    sizeof(wgp->wgp_pubkey));
 	error = wg_algo_xaead_dec(cookie, sizeof(cookie), key,
@@ -2587,6 +2639,7 @@
 	wgp->wgp_latest_cookie_time = time_uptime;
 	memcpy(wgp->wgp_latest_cookie, cookie, sizeof(wgp->wgp_latest_cookie));
 out:
+	mutex_exit(wgp->wgp_lock);
 	wg_put_session(wgs, &psref);
 }
 
@@ -2730,60 +2783,149 @@
 static void
 wg_task_send_init_message(struct wg_softc *wg, struct wg_peer *wgp)
 {
-	struct psref psref;
 	struct wg_session *wgs;
 
 	WG_TRACE("WGP_TASK_SEND_INIT_MESSAGE");
 
-	if (!wgp->wgp_endpoint_available) {
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
+	if (!atomic_load_acquire(&wgp->wgp_endpoint_available)) {
 		WGLOG(LOG_DEBUG, "No endpoint available\n");
 		/* XXX should do something? */
 		return;
 	}
 
-	wgs = wg_get_stable_session(wgp, &psref);
+	wgs = wgp->wgp_session_stable;
 	if (wgs->wgs_state == WGS_STATE_UNKNOWN) {
-		wg_put_session(wgs, &psref);
+		/* XXX What if the unstable session is already INIT_ACTIVE?  */
 		wg_send_handshake_msg_init(wg, wgp);
 	} else {
-		wg_put_session(wgs, &psref);
 		/* rekey */
-		wgs = wg_get_unstable_session(wgp, &psref);
+		wgs = wgp->wgp_session_unstable;
 		if (wgs->wgs_state != WGS_STATE_INIT_ACTIVE)
 			wg_send_handshake_msg_init(wg, wgp);
-		wg_put_session(wgs, &psref);
 	}
 }
 
 static void
+wg_task_retry_handshake(struct wg_softc *wg, struct wg_peer *wgp)
+{
+	struct wg_session *wgs;
+
+	WG_TRACE("WGP_TASK_RETRY_HANDSHAKE");
+
+	KASSERT(mutex_owned(wgp->wgp_lock));
+	KASSERT(wgp->wgp_handshake_start_time != 0);
+
+	wgs = wgp->wgp_session_unstable;
+	if (wgs->wgs_state != WGS_STATE_INIT_ACTIVE)
+		return;
+
+	/*
+	 * XXX There is no real need to assign a new index here, but
+	 * we do need to transition to UNKNOWN temporarily.
+	 */
+	wg_put_session_index(wg, wgs);
+
+	/* [W] 6.4 Handshake Initiation Retransmission */
+	if ((time_uptime - wgp->wgp_handshake_start_time) >
+	    wg_rekey_attempt_time) {
+		/* Give up handshaking */
+		wgp->wgp_handshake_start_time = 0;
+		WG_TRACE("give up");
+
+		/*
+		 * If a new data packet comes, handshaking will be retried
+	 * and a new session will be established at that time;
+	 * however, we don't want to send the pending packets then.
+		 */
+		wg_purge_pending_packets(wgp);
+		return;
+	}
+
+	wg_task_send_init_message(wg, wgp);
+}
+
+static void
+wg_task_establish_session(struct wg_softc *wg, struct wg_peer *wgp)
+{
+	struct wg_session *wgs, *wgs_prev;
+
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
+	wgs = wgp->wgp_session_unstable;
+	if (wgs->wgs_state != WGS_STATE_INIT_PASSIVE)
+		/* XXX Can this happen?  */
+		return;
+
+	wgs->wgs_state = WGS_STATE_ESTABLISHED;
+	wgs->wgs_time_established = time_uptime;
+	wgs->wgs_time_last_data_sent = 0;
+	wgs->wgs_is_initiator = false;
+	WG_TRACE("WGS_STATE_ESTABLISHED");
+
+	wg_swap_sessions(wgp);
+	KASSERT(wgs == wgp->wgp_session_stable);
+	wgs_prev = wgp->wgp_session_unstable;
+	getnanotime(&wgp->wgp_last_handshake_time);
+	wgp->wgp_handshake_start_time = 0;
+	wgp->wgp_last_sent_mac1_valid = false;
+	wgp->wgp_last_sent_cookie_valid = false;
+
+	if (wgs_prev->wgs_state == WGS_STATE_ESTABLISHED) {
+		/* Wait for wg_get_stable_session to drain.  */
+		pserialize_perform(wgp->wgp_psz);
+
+		/* Transition ESTABLISHED->DESTROYING.  */
+		wgs_prev->wgs_state = WGS_STATE_DESTROYING;
+
+		/* We can't destroy the old session immediately */
+		wg_schedule_session_dtor_timer(wgp);
+	} else {
+		KASSERTMSG(wgs_prev->wgs_state == WGS_STATE_UNKNOWN,
+		    "state=%d", wgs_prev->wgs_state);
+		wg_clear_states(wgs_prev);
+		wgs_prev->wgs_state = WGS_STATE_UNKNOWN;
+	}
+
+	/* Anyway run a softint to flush pending packets */
+	kpreempt_disable();
+	softint_schedule(wgp->wgp_si);
+	kpreempt_enable();
+}
+
+static void
 wg_task_endpoint_changed(struct wg_softc *wg, struct wg_peer *wgp)
 {
 
 	WG_TRACE("WGP_TASK_ENDPOINT_CHANGED");
 
-	mutex_enter(wgp->wgp_lock);
-	if (wgp->wgp_endpoint_changing) {
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
+	if (atomic_load_relaxed(&wgp->wgp_endpoint_changing)) {
 		pserialize_perform(wgp->wgp_psz);
 		psref_target_destroy(&wgp->wgp_endpoint0->wgsa_psref,
 		    wg_psref_class);
 		psref_target_init(&wgp->wgp_endpoint0->wgsa_psref,
 		    wg_psref_class);
-		wgp->wgp_endpoint_changing = false;
+		atomic_store_release(&wgp->wgp_endpoint_changing, 0);
 	}
-	mutex_exit(wgp->wgp_lock);
 }
 
 static void
 wg_task_send_keepalive_message(struct wg_softc *wg, struct wg_peer *wgp)
 {
-	struct psref psref;
 	struct wg_session *wgs;
 
 	WG_TRACE("WGP_TASK_SEND_KEEPALIVE_MESSAGE");
 
-	wgs = wg_get_stable_session(wgp, &psref);
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
+	wgs = wgp->wgp_session_stable;
+	if (wgs->wgs_state != WGS_STATE_ESTABLISHED)
+		return;
+
 	wg_send_keepalive_msg(wgp, wgs);
-	wg_put_session(wgs, &psref);
 }
 
 static void
@@ -2793,18 +2935,12 @@
 
 	WG_TRACE("WGP_TASK_DESTROY_PREV_SESSION");
 
-	mutex_enter(wgp->wgp_lock);
+	KASSERT(mutex_owned(wgp->wgp_lock));
+
 	wgs = wgp->wgp_session_unstable;
-	mutex_enter(wgs->wgs_lock);
 	if (wgs->wgs_state == WGS_STATE_DESTROYING) {
-		pserialize_perform(wgp->wgp_psz);
-		psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
-		psref_target_init(&wgs->wgs_psref, wg_psref_class);
-		wg_clear_states(wgs);
-		wgs->wgs_state = WGS_STATE_UNKNOWN;
+		wg_put_session_index(wg, wgs);
 	}
-	mutex_exit(wgs->wgs_lock);
-	mutex_exit(wgp->wgp_lock);
 }
 
 static void
@@ -2831,14 +2967,20 @@
 
 		WG_DLOG("tasks=%x\n", tasks);
 
+		mutex_enter(wgp->wgp_lock);
 		if (ISSET(tasks, WGP_TASK_SEND_INIT_MESSAGE))
 			wg_task_send_init_message(wg, wgp);
+		if (ISSET(tasks, WGP_TASK_RETRY_HANDSHAKE))
+			wg_task_retry_handshake(wg, wgp);
+		if (ISSET(tasks, WGP_TASK_ESTABLISH_SESSION))
+			wg_task_establish_session(wg, wgp);
 		if (ISSET(tasks, WGP_TASK_ENDPOINT_CHANGED))
 			wg_task_endpoint_changed(wg, wgp);
 		if (ISSET(tasks, WGP_TASK_SEND_KEEPALIVE_MESSAGE))
 			wg_task_send_keepalive_message(wg, wgp);
 		if (ISSET(tasks, WGP_TASK_DESTROY_PREV_SESSION))
 			wg_task_destroy_prev_session(wg, wgp);
+		mutex_exit(wgp->wgp_lock);
 
 		/* New tasks may be scheduled during processing tasks */
 		WG_DLOG("wgp_tasks=%d\n", wgp->wgp_tasks);
@@ -3036,7 +3178,7 @@
 
 	wgw = kmem_zalloc(sizeof(struct wg_worker), KM_SLEEP);
 
-	mutex_init(&wgw->wgw_lock, MUTEX_DEFAULT, IPL_NONE);
+	mutex_init(&wgw->wgw_lock, MUTEX_DEFAULT, IPL_SOFTNET);
 	cv_init(&wgw->wgw_cv, ifname);
 	wgw->wgw_todie = false;
 	wgw->wgw_wakeup_reasons = 0;
@@ -3129,11 +3271,10 @@
 	struct mbuf *m;
 	struct psref psref;
 
-	wgs = wg_get_stable_session(wgp, &psref);
-	if (wgs->wgs_state != WGS_STATE_ESTABLISHED) {
+	if ((wgs = wg_get_stable_session(wgp, &psref)) == NULL) {
 		/* XXX how to treat? */
 		WG_TRACE("skipped");
-		goto out;
+		return;
 	}
 	if (wg_session_hit_limits(wgs)) {
 		wg_schedule_peer_task(wgp, WGP_TASK_SEND_INIT_MESSAGE);
@@ -3153,11 +3294,7 @@
 {
 	struct wg_peer *wgp = arg;
 
-	mutex_enter(wgp->wgp_lock);
-	if (__predict_true(wgp->wgp_state != WGP_STATE_DESTROYING)) {
-		wg_schedule_peer_task(wgp, WGP_TASK_SEND_INIT_MESSAGE);
-	}
-	mutex_exit(wgp->wgp_lock);
+	wg_schedule_peer_task(wgp, WGP_TASK_SEND_INIT_MESSAGE);
 }
 
 static void
@@ -3174,47 +3311,10 @@
 wg_handshake_timeout_timer(void *arg)
 {
 	struct wg_peer *wgp = arg;
-	struct wg_session *wgs;
-	struct psref psref;
 
 	WG_TRACE("enter");
 
-	mutex_enter(wgp->wgp_lock);
-	if (__predict_false(wgp->wgp_state == WGP_STATE_DESTROYING)) {
-		mutex_exit(wgp->wgp_lock);
-		return;
-	}
-	mutex_exit(wgp->wgp_lock);
-
-	KASSERT(wgp->wgp_handshake_start_time != 0);
-	wgs = wg_get_unstable_session(wgp, &psref);
-	KASSERT(wgs->wgs_state == WGS_STATE_INIT_ACTIVE);
-
-	/* [W] 6.4 Handshake Initiation Retransmission */
-	if ((time_uptime - wgp->wgp_handshake_start_time) >
-	    wg_rekey_attempt_time) {
-		/* Give up handshaking */
-		wgs->wgs_state = WGS_STATE_UNKNOWN;
-		wg_clear_states(wgs);
-		wgp->wgp_state = WGP_STATE_GIVEUP;
-		wgp->wgp_handshake_start_time = 0;
-		wg_put_session(wgs, &psref);
-		WG_TRACE("give up");
-		/*
-		 * If a new data packet comes, handshaking will be retried
-		 * and a new session would be established at that time,
-		 * however we don't want to send pending packets then.
-		 */
-		wg_purge_pending_packets(wgp);
-		return;
-	}
-
-	/* No response for an initiation message sent, retry handshaking */
-	wgs->wgs_state = WGS_STATE_UNKNOWN;
-	wg_clear_states(wgs);
-	wg_put_session(wgs, &psref);
-
-	wg_schedule_peer_task(wgp, WGP_TASK_SEND_INIT_MESSAGE);
+	wg_schedule_peer_task(wgp, WGP_TASK_RETRY_HANDSHAKE);
 }
 
 static struct wg_peer *
@@ -3225,7 +3325,6 @@
 	wgp = kmem_zalloc(sizeof(*wgp), KM_SLEEP);
 
 	wgp->wgp_sc = wg;
-	wgp->wgp_state = WGP_STATE_INIT;
 	wgp->wgp_q = pcq_create(1024, KM_SLEEP);
 	wgp->wgp_si = softint_establish(SOFTINT_NET, wg_peer_softint, wgp);
 	callout_init(&wgp->wgp_rekey_timer, CALLOUT_MPSAFE);
@@ -3257,23 +3356,21 @@
 	wgs->wgs_peer = wgp;
 	wgs->wgs_state = WGS_STATE_UNKNOWN;
 	psref_target_init(&wgs->wgs_psref, wg_psref_class);
-	wgs->wgs_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
 #ifndef __HAVE_ATOMIC64_LOADSTORE
 	mutex_init(&wgs->wgs_send_counter_lock, MUTEX_DEFAULT, IPL_SOFTNET);
 #endif
 	wgs->wgs_recvwin = kmem_zalloc(sizeof(*wgs->wgs_recvwin), KM_SLEEP);
-	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_NONE);
+	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_SOFTNET);
 
 	wgs = wgp->wgp_session_unstable;
 	wgs->wgs_peer = wgp;
 	wgs->wgs_state = WGS_STATE_UNKNOWN;
 	psref_target_init(&wgs->wgs_psref, wg_psref_class);
-	wgs->wgs_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE);
 #ifndef __HAVE_ATOMIC64_LOADSTORE
 	mutex_init(&wgs->wgs_send_counter_lock, MUTEX_DEFAULT, IPL_SOFTNET);
 #endif
 	wgs->wgs_recvwin = kmem_zalloc(sizeof(*wgs->wgs_recvwin), KM_SLEEP);
-	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_NONE);
+	mutex_init(&wgs->wgs_recvwin->lock, MUTEX_DEFAULT, IPL_SOFTNET);
 
 	return wgp;
 }
@@ -3283,8 +3380,6 @@
 {
 	struct wg_session *wgs;
 	struct wg_softc *wg = wgp->wgp_sc;
-	uint32_t index;
-	void *garbage;
 
 	/* Prevent new packets from this peer on any source address.  */
 	rw_enter(wg->wg_rwlock, RW_WRITER);
@@ -3314,26 +3409,12 @@
 	callout_halt(&wgp->wgp_handshake_timeout_timer, NULL);
 	callout_halt(&wgp->wgp_session_dtor_timer, NULL);
 
-	/* Remove the sessions by index.  */
-	if ((index = wgp->wgp_session_stable->wgs_sender_index) != 0) {
-		thmap_del(wg->wg_sessions_byindex, &index, sizeof index);
-		wgp->wgp_session_stable->wgs_sender_index = 0;
-	}
-	if ((index = wgp->wgp_session_unstable->wgs_sender_index) != 0) {
-		thmap_del(wg->wg_sessions_byindex, &index, sizeof index);
-		wgp->wgp_session_unstable->wgs_sender_index = 0;
-	}
-
-	/* Wait for all thmap_gets to complete, and GC.  */
-	garbage = thmap_stage_gc(wg->wg_sessions_byindex);
-	mutex_enter(wgp->wgp_lock);
-	pserialize_perform(wgp->wgp_psz);
-	mutex_exit(wgp->wgp_lock);
-	thmap_gc(wg->wg_sessions_byindex, garbage);
-
 	wgs = wgp->wgp_session_unstable;
-	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
-	mutex_obj_free(wgs->wgs_lock);
+	if (wgs->wgs_state != WGS_STATE_UNKNOWN) {
+		mutex_enter(wgp->wgp_lock);
+		wg_destroy_session(wg, wgs);
+		mutex_exit(wgp->wgp_lock);
+	}
 	mutex_destroy(&wgs->wgs_recvwin->lock);
 	kmem_free(wgs->wgs_recvwin, sizeof(*wgs->wgs_recvwin));
 #ifndef __HAVE_ATOMIC64_LOADSTORE
@@ -3342,8 +3423,11 @@
 	kmem_free(wgs, sizeof(*wgs));
 
 	wgs = wgp->wgp_session_stable;
-	psref_target_destroy(&wgs->wgs_psref, wg_psref_class);
-	mutex_obj_free(wgs->wgs_lock);
+	if (wgs->wgs_state != WGS_STATE_UNKNOWN) {
+		mutex_enter(wgp->wgp_lock);
+		wg_destroy_session(wg, wgs);
+		mutex_exit(wgp->wgp_lock);
+	}
 	mutex_destroy(&wgs->wgs_recvwin->lock);
 	kmem_free(wgs->wgs_recvwin, sizeof(*wgs->wgs_recvwin));
 #ifndef __HAVE_ATOMIC64_LOADSTORE
@@ -3386,7 +3470,6 @@
 		WG_PEER_WRITER_REMOVE(wgp);
 		wg->wg_npeers--;
 		mutex_enter(wgp->wgp_lock);
-		wgp->wgp_state = WGP_STATE_DESTROYING;
 		pserialize_perform(wgp->wgp_psz);
 		mutex_exit(wgp->wgp_lock);
 		PSLIST_ENTRY_DESTROY(wgp, wgp_peerlist_entry);
@@ -3423,7 +3506,6 @@
 		WG_PEER_WRITER_REMOVE(wgp);
 		wg->wg_npeers--;
 		mutex_enter(wgp->wgp_lock);
-		wgp->wgp_state = WGP_STATE_DESTROYING;
 		pserialize_perform(wgp->wgp_psz);
 		mutex_exit(wgp->wgp_lock);
 		PSLIST_ENTRY_DESTROY(wgp, wgp_peerlist_entry);
@@ -3607,7 +3689,7 @@
 
 	memset(wgmd, 0, sizeof(*wgmd));
 	wgmd->wgmd_type = htole32(WG_MSG_TYPE_DATA);
-	wgmd->wgmd_receiver = wgs->wgs_receiver_index;
+	wgmd->wgmd_receiver = wgs->wgs_remote_index;
 	/* [W] 5.4.6: msg.counter := Nm^send */
 	/* [W] 5.4.6: Nm^send := Nm^send + 1 */
 	wgmd->wgmd_counter = htole64(wg_session_inc_send_counter(wgs));
@@ -3619,31 +3701,32 @@
     const struct rtentry *rt)
 {
 	struct wg_softc *wg = ifp->if_softc;
-	int error = 0;
+	struct wg_peer *wgp = NULL;
+	struct wg_session *wgs = NULL;
+	struct psref wgp_psref, wgs_psref;
 	int bound;
-	struct psref psref;
+	int error;
 
+	bound = curlwp_bind();
+
 	/* TODO make the nest limit configurable via sysctl */
 	error = if_tunnel_check_nesting(ifp, m, 1);
-	if (error != 0) {
-		m_freem(m);
+	if (error) {
 		WGLOG(LOG_ERR, "tunneling loop detected and packet dropped\n");
-		return error;
+		goto out;
 	}
 
-	bound = curlwp_bind();
-
 	IFQ_CLASSIFY(&ifp->if_snd, m, dst->sa_family);
 
 	bpf_mtap_af(ifp, dst->sa_family, m, BPF_D_OUT);
 
 	m->m_flags &= ~(M_BCAST|M_MCAST);
 
-	struct wg_peer *wgp = wg_pick_peer_by_sa(wg, dst, &psref);
+	wgp = wg_pick_peer_by_sa(wg, dst, &wgp_psref);
 	if (wgp == NULL) {
 		WG_TRACE("peer not found");
 		error = EHOSTUNREACH;
-		goto error;
+		goto out;
 	}
 
 	/* Clear checksum-offload flags. */
@@ -3652,14 +3735,12 @@
 
 	if (!pcq_put(wgp->wgp_q, m)) {
 		error = ENOBUFS;
-		goto error;
+		goto out;
 	}
+	m = NULL;		/* consumed */
 
-	struct psref psref_wgs;
-	struct wg_session *wgs;
-	wgs = wg_get_stable_session(wgp, &psref_wgs);
-	if (wgs->wgs_state == WGS_STATE_ESTABLISHED &&
-	    !wg_session_hit_limits(wgs)) {
+	wgs = wg_get_stable_session(wgp, &wgs_psref);
+	if (wgs != NULL && !wg_session_hit_limits(wgs)) {
 		kpreempt_disable();
 		softint_schedule(wgp->wgp_si);
 		kpreempt_enable();
@@ -3668,14 +3749,13 @@
 		wg_schedule_peer_task(wgp, WGP_TASK_SEND_INIT_MESSAGE);
 		WG_TRACE("softint NOT scheduled");
 	}
-	wg_put_session(wgs, &psref_wgs);
-	wg_put_peer(wgp, &psref);
+	error = 0;
 
-	return 0;
-
-error:
+out:
+	if (wgs != NULL)
+		wg_put_session(wgs, &wgs_psref);
 	if (wgp != NULL)
-		wg_put_peer(wgp, &psref);
+		wg_put_peer(wgp, &wgp_psref);
 	if (m != NULL)
 		m_freem(m);
 	curlwp_bindx(bound);