iroh_relay/protos/
handshake.rs

1//! Implements the handshake protocol that authenticates and authorizes clients connecting to the relays.
2//!
3//! The purpose of the handshake is to
4//! 1. Inform the relay of the client's EndpointId
5//! 2. Check that the connecting client owns the secret key for its EndpointId ("is authentic"/"authentication")
6//! 3. Possibly check that the client has access to this relay, if the relay requires authorization.
7//!
//! Additional complexity comes from the fact that there are two ways that clients can authenticate
//! with relays.
10//!
11//! One way is via an explicitly sent challenge:
12//!
13//! 1. Once a websocket connection is opened, a client receives a challenge (the `ServerChallenge` frame)
14//! 2. The client sends back what is essentially a signature of that challenge with their secret key
15//!    that matches the EndpointId they have, as well as the EndpointId (the `ClientAuth` frame)
16//!
//! The second way is very similar to the [Concealed HTTP Auth RFC], and involves sending a header
//! that contains a signature of some shared keying material extracted from TLS ([RFC 5705]).
19//!
//! The second way can save a full round trip, because the challenge doesn't have to be sent to the
//! client first. However, it won't always work, as it relies on the keying material extraction
//! feature of TLS, which is not available in browsers (but might be in the future?) and might break
//! when there's an HTTPS proxy that doesn't properly deal with this TLS feature.
24//!
25//! [Concealed HTTP Auth RFC]: https://datatracker.ietf.org/doc/rfc9729/
26//! [RFC 5705]: https://datatracker.ietf.org/doc/html/rfc5705
27use bytes::{BufMut, Bytes, BytesMut};
28use data_encoding::BASE32HEX_NOPAD as HEX;
29#[cfg(not(wasm_browser))]
30use http::HeaderValue;
31#[cfg(feature = "server")]
32use iroh_base::Signature;
33use iroh_base::{PublicKey, SecretKey};
34use n0_error::{e, ensure, stack_error};
35use n0_future::{SinkExt, TryStreamExt};
36#[cfg(feature = "server")]
37use rand::CryptoRng;
38use tracing::trace;
39
40use super::{
41    common::{FrameType, FrameTypeError},
42    streams::BytesStreamSink,
43};
44use crate::ExportKeyingMaterial;
45
/// Context string for deriving the message a client signs in answer to a [`ServerChallenge`].
///
/// Client and relay both derive the signed message with it, so changing it breaks
/// interoperability with peers that use the old value.
const DOMAIN_SEP_CHALLENGE: &str = "iroh-relay handshake v1 challenge signature";

/// Label handed to [`ExportKeyingMaterial`] when deriving the keying material that a
/// [`KeyMaterialClientAuth`] signs.
#[cfg(not(wasm_browser))]
const DOMAIN_SEP_TLS_EXPORT_LABEL: &[u8] = "iroh-relay handshake v1".as_bytes();
52
/// Authentication the client sends upfront, without waiting for a challenge, by signing keying
/// material exported from the TLS session (with the client's public key as export context).
///
/// It is postcard-serialized and base64url (no padding) encoded into an HTTP header value (see
/// `into_header_value`). If the relay can't verify it, the handshake falls back to the
/// [`ServerChallenge`] / [`ClientAuth`] exchange.
///
/// NOTE: this is a wire format serialized with postcard; do not reorder fields or change their
/// serde attributes.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
#[cfg_attr(wasm_browser, allow(unused))]
pub(crate) struct KeyMaterialClientAuth {
    /// The client's public key, a.k.a. its `EndpointId`.
    pub(crate) public_key: PublicKey,
    /// The client's signature over the first 16 bytes of the exported keying material
    /// (signed directly, without further hashing).
    #[serde(with = "serde_bytes")]
    #[debug("{}", HEX.encode(signature))]
    pub(crate) signature: [u8; 64],
    /// The last 16 bytes of the exported keying material, passed through unsigned.
    ///
    /// Allows the relay to make sure both ends exported the same underlying key material, and to
    /// distinguish a key material mismatch from an invalid signature.
    #[debug("{}", HEX.encode(key_material_suffix))]
    pub(crate) key_material_suffix: [u8; 16],
}
69
/// A challenge for the client to sign with their secret key for EndpointId authentication.
///
/// The client doesn't sign these bytes directly, but a key derived from them
/// (see `ServerChallenge::message_to_sign`).
///
/// NOTE: this is a postcard wire format; do not change its fields.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
pub(crate) struct ServerChallenge {
    /// The 16 challenge bytes to sign.
    /// Must be randomly generated with an RNG that is safe to use for crypto.
    #[debug("{}", HEX.encode(challenge))]
    pub(crate) challenge: [u8; 16],
}
78
/// The client's answer to a [`ServerChallenge`].
///
/// Used when authentication via [`KeyMaterialClientAuth`] wasn't attempted or didn't work.
///
/// NOTE: this is a postcard wire format; do not reorder fields or change their serde attributes.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
pub(crate) struct ClientAuth {
    /// The client's public key, a.k.a. the `EndpointId`
    pub(crate) public_key: PublicKey,
    /// Signature, made with the secret key for `public_key`, over the key derived from the
    /// [`ServerChallenge`] (`ServerChallenge::message_to_sign`).
    ///
    /// This is what provides the authentication.
    #[serde(with = "serde_bytes")]
    #[debug("{}", HEX.encode(signature))]
    pub(crate) signature: [u8; 64],
}
93
/// Confirmation of successful connection.
///
/// Sent by the relay once the client is authenticated and authorized; it completes the handshake.
#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)]
pub(crate) struct ServerConfirmsAuth;
97
/// Denial of connection.
///
/// Sent by the relay when the client couldn't be verified as authentic (its challenge signature
/// is invalid), or when it is authentic but not authorized to use this relay.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub(crate) struct ServerDeniesAuth {
    /// Human-readable reason for the denial, surfaced to the client in its error.
    reason: String,
}
103
104/// Trait for getting the frame type tag for a frame.
105///
106/// Used only in the handshake, as the frame we expect next
107/// is fairly stateful.
108/// Not used in the send/recv protocol, as any frame is
109/// allowed to happen at any time there.
110trait Frame {
111    /// The frame type this frame is identified by and prefixed with
112    const TAG: FrameType;
113}
114
115impl<T: Frame> Frame for &T {
116    const TAG: FrameType = T::TAG;
117}
118
119impl Frame for ServerChallenge {
120    const TAG: FrameType = FrameType::ServerChallenge;
121}
122
123impl Frame for ClientAuth {
124    const TAG: FrameType = FrameType::ClientAuth;
125}
126
127impl Frame for ServerConfirmsAuth {
128    const TAG: FrameType = FrameType::ServerConfirmsAuth;
129}
130
131impl Frame for ServerDeniesAuth {
132    const TAG: FrameType = FrameType::ServerDeniesAuth;
133}
134
135#[stack_error(derive, add_meta)]
136#[allow(missing_docs)]
137#[non_exhaustive]
138pub enum Error {
139    #[error(transparent)]
140    Websocket {
141        #[cfg(not(wasm_browser))]
142        #[error(from, std_err)]
143        source: tokio_websockets::Error,
144        #[cfg(wasm_browser)]
145        #[error(from, std_err)]
146        source: ws_stream_wasm::WsErr,
147    },
148    #[error("Handshake stream ended prematurely")]
149    UnexpectedEnd {},
150    #[error(transparent)]
151    FrameTypeError {
152        #[error(from)]
153        source: FrameTypeError,
154    },
155    #[error("The relay denied our authentication ({reason})")]
156    ServerDeniedAuth { reason: String },
157    #[error("Unexpected tag, got {frame_type:?}, but expected one of {expected_types:?}")]
158    UnexpectedFrameType {
159        frame_type: FrameType,
160        expected_types: Vec<FrameType>,
161    },
162    #[error("Handshake failed while deserializing {frame_type:?} frame")]
163    DeserializationError {
164        frame_type: FrameType,
165        #[error(std_err)]
166        source: postcard::Error,
167    },
168    #[cfg(feature = "server")]
169    /// Failed to deserialize client auth header
170    ClientAuthHeaderInvalid { value: HeaderValue },
171}
172
/// Reasons why the relay rejected a client's proof of key ownership.
#[cfg(feature = "server")]
#[stack_error(derive, add_meta)]
pub(crate) enum VerificationError {
    /// The relay could not export TLS keying material from its own end of the session.
    #[error("Couldn't export TLS keying material on our end")]
    NoKeyingMaterial,
    /// Client and relay exported different keying material (perhaps a TLS proxy is in between).
    #[error(
        "Client didn't extract the same keying material, the suffix mismatched: expected {expected:X?} but got {actual:X?}"
    )]
    MismatchedSuffix {
        expected: [u8; 16],
        actual: [u8; 16],
    },
    /// The client's signature doesn't verify for the public key it claimed.
    #[error(
        "Client signature {signature:X?} for message {message:X?} invalid for public key {public_key}"
    )]
    SignatureInvalid {
        source: iroh_base::SignatureError,
        message: Vec<u8>,
        signature: [u8; 64],
        public_key: PublicKey,
    },
}
195
196impl ServerChallenge {
197    /// Generates a new challenge.
198    #[cfg(feature = "server")]
199    pub(crate) fn new<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
200        let mut challenge = [0u8; 16];
201        rng.fill_bytes(&mut challenge);
202        Self { challenge }
203    }
204
205    /// The actual message bytes to sign (and verify against) for this challenge.
206    fn message_to_sign(&self) -> [u8; 32] {
207        // We're signing a key instead of the direct challenge.
208        // This gives us domain separation protecting from multiple possible attacks,
209        // but especially this one:
210        // Assume a malicious relay. If the protocol required the client to sign the
211        // challenge directly, this would allow the relay to obtain an arbitrary 16-byte
212        // signature, if it maliciously choses the challenge instead of generating it
213        // randomly.
214        // Deriving a key to sign instead mitigates this attack.
215        blake3::derive_key(DOMAIN_SEP_CHALLENGE, &self.challenge)
216    }
217}
218
219impl ClientAuth {
220    /// Generates a signature for the given challenge from the server.
221    pub(crate) fn new(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self {
222        Self {
223            public_key: secret_key.public(),
224            signature: secret_key.sign(&challenge.message_to_sign()).to_bytes(),
225        }
226    }
227
228    /// Verifies this client's authentication given the challenge this was sent in response to.
229    #[cfg(feature = "server")]
230    pub(crate) fn verify(&self, challenge: &ServerChallenge) -> Result<(), Box<VerificationError>> {
231        let message = challenge.message_to_sign();
232        self.public_key
233            .verify(&message, &Signature::from_bytes(&self.signature))
234            .map_err(|err| {
235                e!(VerificationError::SignatureInvalid {
236                    source: err,
237                    message: message.to_vec(),
238                    signature: self.signature,
239                    public_key: self.public_key
240                })
241            })
242            .map_err(Box::new)
243    }
244}
245
246#[cfg(not(wasm_browser))]
247impl KeyMaterialClientAuth {
248    /// Generates a client's authentication, similar to [`ClientAuth`], but by using TLS keying material
249    /// instead of a received challenge.
250    pub(crate) fn new(secret_key: &SecretKey, io: &impl ExportKeyingMaterial) -> Option<Self> {
251        let public_key = secret_key.public();
252        let key_material = io.export_keying_material(
253            [0u8; 32],
254            DOMAIN_SEP_TLS_EXPORT_LABEL,
255            Some(secret_key.public().as_bytes()),
256        )?;
257        // We split the export and only sign the first 16 bytes, and
258        // pass through the last 16 bytes. See also the note in [Self::verify].
259        let (message, suffix) = key_material.split_at(16);
260        Some(Self {
261            public_key,
262            signature: secret_key.sign(message).to_bytes(),
263            key_material_suffix: suffix.try_into().expect("hardcoded length"),
264        })
265    }
266
267    /// Generate the base64url-nopad-encoded header value.
268    pub(crate) fn into_header_value(self) -> HeaderValue {
269        HeaderValue::from_str(
270            &data_encoding::BASE64URL_NOPAD
271                .encode(&postcard::to_allocvec(&self).expect("encoding never fails")),
272        )
273        .expect("BASE64URL_NOPAD encoding contained invisible ascii characters")
274    }
275
276    /// Verifies this client auth on the server side using the same key material.
277    ///
278    /// This might return false for a couple of reasons:
279    /// 1. The exported keying material might not be the same between both ends of the TLS session
280    ///    (e.g. there's an HTTPS proxy in between that doesn't think/care about the TLS keying material exporter).
281    ///    This situation is detected when the key material suffix mismatches.
282    /// 2. The signature itself doesn't verify.
283    #[cfg(feature = "server")]
284    pub(crate) fn verify(
285        &self,
286        io: &impl ExportKeyingMaterial,
287    ) -> Result<(), Box<VerificationError>> {
288        let key_material = io
289            .export_keying_material(
290                [0u8; 32],
291                DOMAIN_SEP_TLS_EXPORT_LABEL,
292                Some(self.public_key.as_bytes()),
293            )
294            .ok_or_else(|| e!(VerificationError::NoKeyingMaterial))?;
295        // We split the export and only sign the first 16 bytes, and
296        // pass through the last 16 bytes.
297        // Passing on the suffix helps the verifying end figure out what
298        // went wrong: If there's a suffix mismatch, then the exported keying
299        // material on both ends wasn't the same - so perhaps there was a
300        // TLS proxy in between or similar.
301        // If the suffix does match, but the signature doesn't verify, then
302        // there must be something wrong with the client's secret key or signature.
303        let (message, suffix) = key_material.split_at(16);
304        let suffix: [u8; 16] = suffix.try_into().expect("hardcoded length");
305        ensure!(
306            suffix == self.key_material_suffix,
307            VerificationError::MismatchedSuffix {
308                expected: self.key_material_suffix,
309                actual: suffix
310            }
311        );
312        // NOTE: We don't blake3-hash here as we do it in [`ServerChallenge::message_to_sign`],
313        // because we already have a domain separation string and keyed hashing step in
314        // the TLS export keying material above.
315        self.public_key
316            .verify(message, &Signature::from_bytes(&self.signature))
317            .map_err(|err| {
318                e!(VerificationError::SignatureInvalid {
319                    source: err,
320                    message: message.to_vec(),
321                    public_key: self.public_key,
322                    signature: self.signature
323                })
324            })
325            .map_err(Box::new)
326    }
327}
328
329/// Runs the client side of the handshake protocol.
330///
331/// See the module docs for details on the protocol.
332/// This is already after having potentially transferred a [`KeyMaterialClientAuth`],
333/// but before having received a response for whether that worked or not.
334///
335/// This requires access to the client's secret key to sign a challenge.
336pub(crate) async fn clientside(
337    io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
338    secret_key: &SecretKey,
339) -> Result<ServerConfirmsAuth, Error> {
340    let (tag, frame) = read_frame(io, &[ServerChallenge::TAG, ServerConfirmsAuth::TAG]).await?;
341
342    let (tag, frame) = if tag == ServerChallenge::TAG {
343        let challenge: ServerChallenge = deserialize_frame(frame)?;
344
345        let client_info = ClientAuth::new(secret_key, &challenge);
346        write_frame(io, client_info).await?;
347
348        read_frame(io, &[ServerConfirmsAuth::TAG, ServerDeniesAuth::TAG]).await?
349    } else {
350        (tag, frame)
351    };
352
353    match tag {
354        FrameType::ServerConfirmsAuth => {
355            let confirmation: ServerConfirmsAuth = deserialize_frame(frame)?;
356            Ok(confirmation)
357        }
358        FrameType::ServerDeniesAuth => {
359            let denial: ServerDeniesAuth = deserialize_frame(frame)?;
360            Err(e!(Error::ServerDeniedAuth {
361                reason: denial.reason
362            }))
363        }
364        _ => unreachable!(),
365    }
366}
367
/// Result of a successful authentication handshake.
///
/// The client has proven that it owns the secret key for `client_key`, using the authentication
/// [`Mechanism`] recorded in `mechanism`. It has *not* been authorized yet: you must call
/// [`Self::authorize_if`] to complete the protocol; otherwise the client is never notified of
/// success or failure.
#[cfg(feature = "server")]
#[derive(Debug)]
#[must_use = "the protocol is not finished unless `authorize_if` is called"]
pub struct SuccessfulAuthentication {
    /// The authenticated client's public key.
    pub client_key: PublicKey,
    /// The authentication mechanism that was used.
    pub mechanism: Mechanism,
}
386
/// How a client proved ownership of its secret key during the handshake.
#[cfg(feature = "server")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mechanism {
    /// The client signed a challenge the relay sent it, which costs an extra round trip.
    SignedChallenge,
    /// The client signed keying material exported from the shared TLS session and sent the
    /// signature upfront in a header.
    SignedKeyMaterial,
}
396
/// Runs the server side of the handshaking protocol.
///
/// See the module documentation for an overview of the handshaking protocol.
///
/// The authentication challenge, if one is needed, is generated from `rand::rng()`, a
/// `CryptoRng`.
///
/// This also takes the `client_auth_header`, if present, to perform authentication without
/// requiring sending a challenge, saving a round-trip, if possible.
///
/// If verifying that header fails, the protocol falls back to doing a normal extra round trip
/// with a challenge. A header that is present but can't be decoded fails the handshake with
/// [`Error::ClientAuthHeaderInvalid`] instead.
///
/// The return value [`SuccessfulAuthentication`] still needs to be resolved by calling
/// [`SuccessfulAuthentication::authorize_if`] to finish the whole authorization protocol
/// (otherwise the client won't be notified about auth success or failure).
#[cfg(feature = "server")]
pub async fn serverside(
    io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
    client_auth_header: Option<HeaderValue>,
) -> Result<SuccessfulAuthentication, Error> {
    if let Some(client_auth_header) = client_auth_header {
        let invalid_header = || {
            e!(Error::ClientAuthHeaderInvalid {
                value: client_auth_header.clone()
            })
        };
        let client_auth_bytes = data_encoding::BASE64URL_NOPAD
            .decode(client_auth_header.as_ref())
            .map_err(|_| invalid_header())?;
        let client_auth: KeyMaterialClientAuth =
            postcard::from_bytes(&client_auth_bytes).map_err(|_| invalid_header())?;

        match client_auth.verify(io) {
            Ok(()) => {
                trace!(?client_auth.public_key, "authentication succeeded via keying material");
                return Ok(SuccessfulAuthentication {
                    client_key: client_auth.public_key,
                    mechanism: Mechanism::SignedKeyMaterial,
                });
            }
            // Verification not succeeding is part of normal operation: The TLS exporter isn't
            // required to match. We'll fall back to verification that takes another round trip
            // more time, but keep the reason around for debugging.
            Err(err) => {
                trace!(
                    ?client_auth.public_key,
                    ?err,
                    "authentication via keying material failed, falling back to challenge"
                );
            }
        }
    }

    let challenge = ServerChallenge::new(&mut rand::rng());
    write_frame(io, &challenge).await?;

    let (_, frame) = read_frame(io, &[ClientAuth::TAG]).await?;
    let client_auth: ClientAuth = deserialize_frame(frame)?;

    if let Err(err) = client_auth.verify(&challenge) {
        trace!(?client_auth.public_key, ?err, "authentication failed");
        let denial = ServerDeniesAuth {
            reason: "signature invalid".into(),
        };
        // Written by reference (`Frame` is implemented for `&T`), so no clone is needed.
        write_frame(io, &denial).await?;
        Err(e!(Error::ServerDeniedAuth {
            reason: denial.reason
        }))
    } else {
        trace!(?client_auth.public_key, "authentication succeeded via challenge");
        Ok(SuccessfulAuthentication {
            client_key: client_auth.public_key,
            mechanism: Mechanism::SignedChallenge,
        })
    }
}
466
#[cfg(feature = "server")]
impl SuccessfulAuthentication {
    /// Completes the authorization protocol by notifying the client of success or failure.
    ///
    /// After a client has been successfully authenticated via [`serverside`], the server must
    /// decide whether to authorize the client (allow access) or deny it. This method sends
    /// the authorization decision to the client and completes the handshake protocol.
    ///
    /// # Arguments
    /// * `is_authorized` - Whether to grant access to the authenticated client
    /// * `io` - The stream to send the authorization response on
    ///
    /// # Returns
    /// * `Ok(PublicKey)` - The client's public key if authorization was granted
    /// * `Err(Error)` - If authorization was denied or communication failed
    pub async fn authorize_if(
        self,
        is_authorized: bool,
        io: &mut (impl BytesStreamSink + ExportKeyingMaterial),
    ) -> Result<PublicKey, Error> {
        if !is_authorized {
            trace!("denying client auth");
            let denial = ServerDeniesAuth {
                reason: "not authorized".into(),
            };
            // Written by reference (`Frame` is implemented for `&T`), so no clone is needed.
            write_frame(io, &denial).await?;
            return Err(e!(Error::ServerDeniedAuth {
                reason: denial.reason
            }));
        }

        trace!("authorizing client");
        write_frame(io, ServerConfirmsAuth).await?;
        Ok(self.client_key)
    }
}
503
504async fn write_frame<F: serde::Serialize + Frame>(
505    io: &mut impl BytesStreamSink,
506    frame: F,
507) -> Result<(), Error> {
508    let mut bytes = BytesMut::new();
509    trace!(frame_type = ?F::TAG, "Writing frame");
510    F::TAG.write_to(&mut bytes);
511    let bytes = postcard::to_io(&frame, bytes.writer())
512        .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization
513        .into_inner()
514        .freeze();
515    io.send(bytes).await?;
516    io.flush().await?;
517    Ok(())
518}
519
520async fn read_frame(
521    io: &mut impl BytesStreamSink,
522    expected_types: &[FrameType],
523) -> Result<(FrameType, Bytes), Error> {
524    let mut payload = io
525        .try_next()
526        .await?
527        .ok_or_else(|| e!(Error::UnexpectedEnd))?;
528
529    let frame_type = FrameType::from_bytes(&mut payload)?;
530    trace!(?frame_type, "Reading frame");
531    ensure!(
532        expected_types.contains(&frame_type),
533        Error::UnexpectedFrameType {
534            frame_type,
535            expected_types: expected_types.to_vec()
536        }
537    );
538
539    Ok((frame_type, payload))
540}
541
542fn deserialize_frame<F: Frame + serde::de::DeserializeOwned>(frame: Bytes) -> Result<F, Error> {
543    postcard::from_bytes(&frame).map_err(|err| {
544        e!(Error::DeserializationError {
545            frame_type: F::TAG,
546            source: err
547        })
548    })
549}
550
/// Handshake tests, run over an in-memory duplex pipe with simulated TLS keying material.
#[cfg(all(test, feature = "server"))]
mod tests {
    use bytes::BytesMut;
    use iroh_base::{PublicKey, SecretKey};
    use n0_error::{Result, StackResultExt, StdResultExt};
    use n0_future::{Sink, SinkExt, Stream, TryStreamExt};
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;
    use tokio_util::codec::{Framed, LengthDelimitedCodec};
    use tracing::{Instrument, info_span};

    use super::{
        ClientAuth, KeyMaterialClientAuth, Mechanism, ServerChallenge, ServerConfirmsAuth,
    };
    use crate::ExportKeyingMaterial;

    /// Wraps `inner` (a stream/sink, or `()`) and adds a simulated [`ExportKeyingMaterial`]
    /// driven by `shared_secret`.
    ///
    /// A `None` shared secret makes every export fail, modelling an endpoint that can't export
    /// keying material.
    struct TestKeyingMaterial<IO> {
        shared_secret: Option<u64>,
        inner: IO,
    }

    /// Extension trait to wrap any value into a [`TestKeyingMaterial`].
    trait WithTlsSharedSecret: Sized {
        fn with_shared_secret(self, shared_secret: Option<u64>) -> TestKeyingMaterial<Self>;
    }

    impl<T: Sized> WithTlsSharedSecret for T {
        fn with_shared_secret(self, shared_secret: Option<u64>) -> TestKeyingMaterial<Self> {
            TestKeyingMaterial {
                shared_secret,
                inner: self,
            }
        }
    }

    impl<IO> ExportKeyingMaterial for TestKeyingMaterial<IO> {
        fn export_keying_material<T: AsMut<[u8]>>(
            &self,
            mut output: T,
            label: &[u8],
            context: Option<&[u8]>,
        ) -> Option<T> {
            // we simulate something like exporting keying material using blake3
            // Equal (shared secret, label, context) inputs produce equal output, and a `None`
            // shared secret makes the export return `None` via `?`.

            let label_key = blake3::hash(label);
            let context_key = blake3::keyed_hash(label_key.as_bytes(), context.unwrap_or(&[]));
            let mut hasher = blake3::Hasher::new_keyed(context_key.as_bytes());
            hasher.update(&self.shared_secret?.to_le_bytes());
            hasher.finalize_xof().fill(output.as_mut());

            Some(output)
        }
    }

    // Forwards the stream side to `inner`.
    impl<V, IO: Stream<Item = V> + Unpin> Stream for TestKeyingMaterial<IO> {
        type Item = V;

        fn poll_next(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            std::pin::Pin::new(&mut self.inner).poll_next(cx)
        }
    }

    // Forwards the sink side to `inner`.
    impl<V, E, IO: Sink<V, Error = E> + Unpin> Sink<V> for TestKeyingMaterial<IO> {
        type Error = E;

        fn poll_ready(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_ready(cx)
        }

        fn start_send(mut self: std::pin::Pin<&mut Self>, item: V) -> Result<(), Self::Error> {
            std::pin::Pin::new(&mut self.inner).start_send(item)
        }

        fn poll_flush(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_close(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::pin::Pin::new(&mut self.inner).poll_close(cx)
        }
    }

    /// Runs client and server concurrently over a length-delimited in-memory duplex pipe.
    ///
    /// The client sends a key material auth header only if its shared secret allows exporting.
    /// If `restricted_to` is `Some`, the server authorizes only that key; otherwise everyone.
    /// Returns both sides' results; the server side also reports the mechanism used.
    async fn simulate_handshake(
        secret_key: &SecretKey,
        client_shared_secret: Option<u64>,
        server_shared_secret: Option<u64>,
        restricted_to: Option<PublicKey>,
    ) -> (Result<ServerConfirmsAuth>, Result<(PublicKey, Mechanism)>) {
        let (client, server) = tokio::io::duplex(1024);

        let mut client_io = Framed::new(client, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(client_shared_secret);
        let mut server_io = Framed::new(server, LengthDelimitedCodec::new())
            .map_ok(BytesMut::freeze)
            .map_err(tokio_websockets::Error::Io)
            .sink_map_err(tokio_websockets::Error::Io)
            .with_shared_secret(server_shared_secret);

        // `None` when the client can't export keying material, i.e. no header is sent.
        let client_auth_header = KeyMaterialClientAuth::new(secret_key, &client_io)
            .map(KeyMaterialClientAuth::into_header_value);

        n0_future::future::zip(
            async {
                super::clientside(&mut client_io, secret_key)
                    .await
                    .context("clientside")
            }
            .instrument(info_span!("clientside")),
            async {
                let auth_n = super::serverside(&mut server_io, client_auth_header)
                    .await
                    .context("serverside")?;
                let mechanism = auth_n.mechanism;
                let is_authorized = restricted_to.is_none_or(|key| key == auth_n.client_key);
                let key = auth_n.authorize_if(is_authorized, &mut server_io).await?;
                Ok((key, mechanism))
            }
            .instrument(info_span!("serverside")),
        )
        .await
    }

    /// Matching shared secrets: authentication succeeds without a challenge.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_shared_secrets() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::generate(&mut rng);
        let (client, server) = simulate_handshake(&secret_key, Some(42), Some(42), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedKeyMaterial); // it got verified via shared key material
        Ok(())
    }

    /// Neither side can export keying material: the challenge is used.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_challenge() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::generate(&mut rng);
        let (client, server) = simulate_handshake(&secret_key, None, None, None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    /// Differing shared secrets: the header fails to verify and the server falls back.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_mismatching_shared_secrets() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let secret_key = SecretKey::generate(&mut rng);
        // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret
        let (client, server) = simulate_handshake(&secret_key, Some(10), Some(99), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    /// Only the server can export keying material: no header is sent, so the challenge is used.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_challenge_fallback() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        // clients might not have access to shared secrets
        let (client, server) = simulate_handshake(&secret_key, None, Some(99), None).await;
        client?;
        let (public_key, auth) = server?;
        assert_eq!(public_key, secret_key.public());
        assert_eq!(auth, Mechanism::SignedChallenge);
        Ok(())
    }

    /// The server only authorizes the client's own key, so the handshake succeeds.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_with_auth_positive() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let public_key = secret_key.public();
        let (client, server) = simulate_handshake(&secret_key, None, None, Some(public_key)).await;
        client?;
        let (public_key, _) = server?;
        assert_eq!(public_key, secret_key.public());
        Ok(())
    }

    /// A different (authentic) key via the challenge path is not authorized: both sides fail.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_with_auth_negative() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let public_key = secret_key.public();
        let wrong_secret_key = SecretKey::generate(&mut rng);
        let (client, server) =
            simulate_handshake(&wrong_secret_key, None, None, Some(public_key)).await;
        assert!(client.is_err());
        assert!(server.is_err());
        Ok(())
    }

    /// A different (authentic) key via the key material path is not authorized: both sides fail.
    #[tokio::test]
    #[traced_test]
    async fn test_handshake_via_shared_secret_with_auth_negative() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let public_key = secret_key.public();
        let wrong_secret_key = SecretKey::generate(&mut rng);
        let (client, server) =
            simulate_handshake(&wrong_secret_key, Some(42), Some(42), Some(public_key)).await;
        assert!(client.is_err());
        assert!(server.is_err());
        Ok(())
    }

    /// `ClientAuth` survives a postcard round trip.
    #[test]
    fn test_client_auth_roundtrip() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let challenge = ServerChallenge::new(&mut rng);
        let client_auth = ClientAuth::new(&secret_key, &challenge);

        let bytes = postcard::to_allocvec(&client_auth).anyerr()?;
        let decoded: ClientAuth = postcard::from_bytes(&bytes).anyerr()?;

        assert_eq!(client_auth.public_key, decoded.public_key);
        assert_eq!(client_auth.signature, decoded.signature);

        Ok(())
    }

    /// `KeyMaterialClientAuth` survives a postcard round trip.
    #[test]
    fn test_km_client_auth_roundtrip() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let client_auth = KeyMaterialClientAuth::new(
            &secret_key,
            &TestKeyingMaterial {
                inner: (),
                shared_secret: Some(42),
            },
        )
        .anyerr()?;

        let bytes = postcard::to_allocvec(&client_auth).anyerr()?;
        let decoded: KeyMaterialClientAuth = postcard::from_bytes(&bytes).anyerr()?;

        assert_eq!(client_auth.public_key, decoded.public_key);
        assert_eq!(client_auth.signature, decoded.signature);

        Ok(())
    }

    /// A `ClientAuth` verifies against the challenge it answered.
    #[test]
    fn test_challenge_verification() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let challenge = ServerChallenge::new(&mut rng);
        let client_auth = ClientAuth::new(&secret_key, &challenge);
        assert!(client_auth.verify(&challenge).is_ok());

        Ok(())
    }

    /// A `KeyMaterialClientAuth` verifies against the same keying material source.
    #[test]
    fn test_key_material_verification() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let secret_key = SecretKey::generate(&mut rng);
        let io = TestKeyingMaterial {
            inner: (),
            shared_secret: Some(42),
        };
        let client_auth = KeyMaterialClientAuth::new(&secret_key, &io).anyerr()?;
        assert!(client_auth.verify(&io).is_ok());

        Ok(())
    }
}