iroh_relay/server/
http_server.rs

1//! Low-level HTTP server components for embedding the relay service.
2//!
3//! This module provides [`RelayService`] which can be used to embed relay functionality
4//! into an existing HTTP server. It handles individual connections and provides
5//! the core relay protocol implementation.
6//!
7//! For a complete relay server implementation, see the parent [`server`](super) module.
8
9use std::{
10    collections::HashMap, future::Future, net::SocketAddr, pin::Pin, sync::Arc, time::Duration,
11};
12
13use bytes::Bytes;
14use derive_more::Debug;
15use http::{
16    header::{CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION},
17    response::Builder as ResponseBuilder,
18};
19use hyper::{
20    HeaderMap, Method, Request, Response, StatusCode,
21    body::Incoming,
22    header::{HeaderValue, SEC_WEBSOCKET_ACCEPT, UPGRADE},
23    service::Service,
24    upgrade::Upgraded,
25};
26use n0_error::{e, ensure, stack_error};
27use n0_future::time::Elapsed;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls_acme::AcmeAcceptor;
30use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
31use tracing::{Instrument, debug, error, info, info_span, trace, warn, warn_span};
32
33use super::{AccessConfig, SpawnError, clients::Clients, streams::InvalidBucketConfig};
34use crate::{
35    KeyCache,
36    defaults::{DEFAULT_KEY_CACHE_CAPACITY, timeouts::SERVER_WRITE_TIMEOUT},
37    http::{
38        CLIENT_AUTH_HEADER, RELAY_PATH, RELAY_PROTOCOL_VERSION, SUPPORTED_WEBSOCKET_VERSION,
39        WEBSOCKET_UPGRADE_PROTOCOL,
40    },
41    protos::{
42        handshake,
43        relay::{MAX_FRAME_SIZE, PER_CLIENT_SEND_QUEUE_DEPTH},
44        streams::WsBytesFramed,
45    },
46    server::{
47        ClientRateLimit,
48        client::Config,
49        metrics::Metrics,
50        streams::{MaybeTlsStream, RateLimited, RelayedStream},
51    },
52};
53
54type BytesBody = http_body_util::Full<hyper::body::Bytes>;
55type HyperError = Box<dyn std::error::Error + Send + Sync>;
56type HyperResult<T> = std::result::Result<T, HyperError>;
57type HyperHandler = Box<
58    dyn Fn(Request<Incoming>, ResponseBuilder) -> HyperResult<Response<BytesBody>>
59        + Send
60        + Sync
61        + 'static,
62>;
63
64/// WebSocket GUID needed for accepting websocket connections, see RFC 6455 (<https://www.rfc-editor.org/rfc/rfc6455>) section 1.3
65const SEC_WEBSOCKET_ACCEPT_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
66
67/// Derives the accept key for WebSocket handshake according to RFC 6455.
68/// Takes the client's Sec-WebSocket-Key value and returns the calculated accept key.
69fn derive_accept_key(client_key: &HeaderValue) -> String {
70    use sha1::Digest;
71
72    let mut sha1 = sha1::Sha1::new();
73    sha1.update(client_key.as_bytes());
74    sha1.update(SEC_WEBSOCKET_ACCEPT_GUID);
75    data_encoding::BASE64.encode(&sha1.finalize())
76}
77
78/// Creates a new [`BytesBody`] with given content.
79fn body_full(content: impl Into<hyper::body::Bytes>) -> BytesBody {
80    http_body_util::Full::new(content.into())
81}
82
83#[allow(clippy::result_large_err)]
84fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes), ConnectionHandlerError> {
85    match upgraded.downcast::<hyper_util::rt::TokioIo<MaybeTlsStream>>() {
86        Ok(parts) => Ok((parts.io.into_inner(), parts.read_buf)),
87        Err(_) => Err(e!(ConnectionHandlerError::DowncastUpgrade)),
88    }
89}
90
91/// The Relay HTTP server.
92///
93/// A running HTTP server serving the relay endpoint and optionally a number of additional
94/// HTTP services added with [`ServerBuilder::request_handler`].  If configured using
95/// [`ServerBuilder::tls_config`] the server will handle TLS as well.
96///
97/// Created using [`ServerBuilder::spawn`].
98#[derive(Debug)]
99pub(super) struct Server {
100    addr: SocketAddr,
101    http_server_task: AbortOnDropHandle<()>,
102    cancel_server_loop: CancellationToken,
103}
104
105impl Server {
106    /// Returns a handle for this server.
107    ///
108    /// The server runs in the background as several async tasks.  This allows controlling
109    /// the server, in particular it allows gracefully shutting down the server.
110    pub(super) fn handle(&self) -> ServerHandle {
111        ServerHandle {
112            cancel_token: self.cancel_server_loop.clone(),
113        }
114    }
115
116    /// Closes the underlying relay server and the HTTP(S) server tasks.
117    pub(super) fn shutdown(&self) {
118        self.cancel_server_loop.cancel();
119    }
120
121    /// Returns the [`AbortOnDropHandle`] for the supervisor task managing the server.
122    ///
123    /// This is the root of all the tasks for the server.  Aborting it will abort all the
124    /// other tasks for the server.  Awaiting it will complete when all the server tasks are
125    /// completed.
126    pub(super) fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
127        &mut self.http_server_task
128    }
129
130    /// Returns the local address of this server.
131    pub(super) fn addr(&self) -> SocketAddr {
132        self.addr
133    }
134}
135
136/// A handle for the [`Server`].
137///
138/// This does not allow access to the task but can communicate with it.
139#[derive(Debug, Clone)]
140pub(super) struct ServerHandle {
141    cancel_token: CancellationToken,
142}
143
144impl ServerHandle {
145    /// Gracefully shut down the server.
146    pub(super) fn shutdown(&self) {
147        self.cancel_token.cancel()
148    }
149}
150
151/// Configuration to use for the TLS connection
152///
153/// This struct wraps a rustls server configuration and TLS acceptor for use with
154/// [`RelayService::handle_connection`].
155///
156/// # Example
157///
158/// ```
159/// use std::sync::Arc;
160///
161/// use iroh_relay::server::http_server::TlsConfig;
162/// use rustls::ServerConfig;
163/// use webpki_types::{CertificateDer, PrivateKeyDer};
164///
165/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
166/// // Generate a self-signed certificate for testing
167/// let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
168/// let cert_der = cert.cert.der().to_vec();
169/// let private_key_der = cert.signing_key.serialize_der();
170///
171/// // Create rustls types
172/// let cert_chain = vec![CertificateDer::from(cert_der)];
173/// let private_key = PrivateKeyDer::try_from(private_key_der)?;
174///
175/// // Create a rustls ServerConfig
176/// let server_config = Arc::new(
177///     ServerConfig::builder()
178///         .with_no_client_auth()
179///         .with_single_cert(cert_chain, private_key)?,
180/// );
181///
182/// // Create TlsConfig for use with RelayService
183/// let tls_config = TlsConfig::new(server_config);
184/// # Ok(())
185/// # }
186/// ```
187#[derive(Debug, Clone)]
188pub struct TlsConfig {
189    /// The server config
190    pub(super) config: Arc<rustls::ServerConfig>,
191    /// The kind
192    pub(super) acceptor: TlsAcceptor,
193}
194
195impl TlsConfig {
196    /// Creates a new `TlsConfig` from a rustls `ServerConfig`.
197    ///
198    /// This creates a manual TLS acceptor using the provided server configuration.
199    /// The acceptor will handle TLS handshakes for incoming connections.
200    ///
201    /// # Example
202    ///
203    /// ```
204    /// use std::sync::Arc;
205    ///
206    /// use iroh_relay::server::http_server::TlsConfig;
207    /// use rustls::ServerConfig;
208    /// use webpki_types::{CertificateDer, PrivateKeyDer};
209    ///
210    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
211    /// // Generate a self-signed certificate for testing
212    /// let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
213    /// let cert_der = cert.cert.der().to_vec();
214    /// let private_key_der = cert.signing_key.serialize_der();
215    ///
216    /// // Create rustls types
217    /// let cert_chain = vec![CertificateDer::from(cert_der)];
218    /// let private_key = PrivateKeyDer::try_from(private_key_der)?;
219    ///
220    /// let server_config = Arc::new(
221    ///     ServerConfig::builder()
222    ///         .with_no_client_auth()
223    ///         .with_single_cert(cert_chain, private_key)?,
224    /// );
225    ///
226    /// let tls_config = TlsConfig::new(server_config);
227    /// # Ok(())
228    /// # }
229    /// ```
230    pub fn new(config: Arc<rustls::ServerConfig>) -> Self {
231        let acceptor = tokio_rustls::TlsAcceptor::from(config.clone());
232        Self {
233            config,
234            acceptor: TlsAcceptor::Manual(acceptor),
235        }
236    }
237}
238
239/// Errors when attempting to upgrade and
240#[allow(missing_docs)]
241#[stack_error(derive, add_meta)]
242#[non_exhaustive]
243pub enum ServeConnectionError {
244    #[error("TLS[acme] handshake")]
245    TlsHandshake {
246        #[error(std_err)]
247        source: std::io::Error,
248    },
249    #[error("TLS[acme] serve connection")]
250    ServeConnection {
251        #[error(std_err)]
252        source: hyper::Error,
253    },
254    #[error("TLS[manual] timeout")]
255    Timeout {
256        #[error(std_err)]
257        source: Elapsed,
258    },
259    #[error("TLS[manual] accept")]
260    ManualAccept {
261        #[error(std_err)]
262        source: std::io::Error,
263    },
264    #[error("TLS[acme] accept")]
265    LetsEncryptAccept {
266        #[error(std_err)]
267        source: std::io::Error,
268    },
269    #[error("HTTPS connection")]
270    Https {
271        #[error(std_err)]
272        source: hyper::Error,
273    },
274    #[error("HTTP connection")]
275    Http {
276        #[error(std_err)]
277        source: hyper::Error,
278    },
279}
280
281/// Server accept errors.
282#[allow(missing_docs)]
283#[stack_error(derive, add_meta, from_sources)]
284#[non_exhaustive]
285pub enum AcceptError {
286    #[error(transparent)]
287    Handshake { source: handshake::Error },
288    #[error("rate limiting misconfigured")]
289    RateLimitingMisconfigured { source: InvalidBucketConfig },
290}
291
292/// Server connection errors, includes errors that can happen on `accept`.
293#[allow(missing_docs)]
294#[stack_error(derive, add_meta, from_sources)]
295#[non_exhaustive]
296pub enum ConnectionHandlerError {
297    #[error(transparent)]
298    Accept { source: AcceptError },
299    #[error("Could not downcast the upgraded connection to MaybeTlsStream")]
300    DowncastUpgrade {},
301    #[error("Cannot deal with buffered data yet: {buf:?}")]
302    BufferNotEmpty { buf: Bytes },
303}
304
305/// Builder for the Relay HTTP Server.
306///
307/// Defaults to handling relay requests on the "/relay" (and "/derp" for backwards compatibility) endpoint.
308/// Other HTTP endpoints can be added using [`ServerBuilder::request_handler`].
309#[derive(derive_more::Debug)]
310pub(super) struct ServerBuilder {
311    /// The ip + port combination for this server.
312    addr: SocketAddr,
313    /// Optional tls configuration/TlsAcceptor combination.
314    ///
315    /// When `None`, the server will serve HTTP, otherwise it will serve HTTPS.
316    tls_config: Option<TlsConfig>,
317    /// A map of request handlers to routes.
318    ///
319    /// Used when certain routes in your server should be made available at the same port as
320    /// the relay server, and so must be handled along side requests to the relay endpoint.
321    handlers: Handlers,
322    /// Headers to use for HTTP responses.
323    headers: HeaderMap,
324    /// Rate-limiting configuration for an individual client connection.
325    ///
326    /// Rate-limiting is enforced on received traffic from individual clients.  This
327    /// configuration applies to a single client connection.
328    client_rx_ratelimit: Option<ClientRateLimit>,
329    /// The capacity of the key cache.
330    key_cache_capacity: usize,
331    /// Access config for endpoints.
332    access: AccessConfig,
333    metrics: Option<Arc<Metrics>>,
334}
335
336impl ServerBuilder {
337    /// Creates a new [ServerBuilder].
338    pub(super) fn new(addr: SocketAddr) -> Self {
339        Self {
340            addr,
341            tls_config: None,
342            handlers: Default::default(),
343            headers: HeaderMap::new(),
344            client_rx_ratelimit: None,
345            key_cache_capacity: DEFAULT_KEY_CACHE_CAPACITY,
346            access: AccessConfig::Everyone,
347            metrics: None,
348        }
349    }
350
351    /// Sets the metrics collector.
352    pub(super) fn metrics(mut self, metrics: Arc<Metrics>) -> Self {
353        self.metrics = Some(metrics);
354        self
355    }
356
357    /// Set the access configuration.
358    pub(super) fn access(mut self, access: AccessConfig) -> Self {
359        self.access = access;
360        self
361    }
362
363    /// Serves all requests content using TLS.
364    pub(super) fn tls_config(mut self, config: Option<TlsConfig>) -> Self {
365        self.tls_config = config;
366        self
367    }
368
369    /// Sets the per-client rate-limit configuration for incoming data.
370    ///
371    /// On each client connection the incoming data is rate-limited.  By default
372    /// no rate limit is enforced.
373    pub(super) fn client_rx_ratelimit(mut self, config: ClientRateLimit) -> Self {
374        self.client_rx_ratelimit = Some(config);
375        self
376    }
377
378    /// Adds a custom handler for a specific Method & URI.
379    pub(super) fn request_handler(
380        mut self,
381        method: Method,
382        uri_path: &'static str,
383        handler: HyperHandler,
384    ) -> Self {
385        self.handlers.insert((method, uri_path), handler);
386        self
387    }
388
389    /// Adds HTTP headers to responses.
390    pub(super) fn headers(mut self, headers: HeaderMap) -> Self {
391        for (k, v) in headers.iter() {
392            self.headers.insert(k.clone(), v.clone());
393        }
394        self
395    }
396
397    /// Set the capacity of the cache for public keys.
398    pub fn key_cache_capacity(mut self, capacity: usize) -> Self {
399        self.key_cache_capacity = capacity;
400        self
401    }
402
403    /// Builds and spawns an HTTP(S) Relay Server.
404    pub(super) async fn spawn(self) -> Result<Server, SpawnError> {
405        let cancel_token = CancellationToken::new();
406
407        let service = RelayService::new(
408            self.handlers,
409            self.headers,
410            self.client_rx_ratelimit,
411            KeyCache::new(self.key_cache_capacity),
412            self.access,
413            self.metrics.unwrap_or_default(),
414        );
415
416        let addr = self.addr;
417        let tls_config = self.tls_config;
418
419        // Bind a TCP listener on `addr` and handles content using HTTPS.
420
421        let listener = TcpListener::bind(&addr)
422            .await
423            .map_err(|err| e!(super::SpawnError::BindTcpListener { addr }, err))?;
424
425        let addr = listener
426            .local_addr()
427            .map_err(|err| e!(super::SpawnError::NoLocalAddr, err))?;
428        let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS");
429        info!("[{http_str}] relay: serving on {addr}");
430
431        let cancel = cancel_token.clone();
432        let task = tokio::task::spawn(
433            async move {
434                // create a join set to track all our connection tasks
435                let mut set = tokio::task::JoinSet::new();
436                loop {
437                    tokio::select! {
438                        biased;
439                        _ = cancel.cancelled() => {
440                            break;
441                        }
442                        Some(res) = set.join_next() => {
443                            if let Err(err) = res
444                                && err.is_panic()
445                            {
446                                panic!("task panicked: {err:#?}");
447                            }
448                        }
449                        res = listener.accept() => match res {
450                            Ok((stream, peer_addr)) => {
451                                debug!("connection opened from {peer_addr}");
452                                let tls_config = tls_config.clone();
453                                let service = service.clone();
454                                // spawn a task to handle the connection
455                                set.spawn(async move {
456                                    service
457                                        .handle_connection(stream, tls_config)
458                                        .await
459                                }.instrument(info_span!("conn", peer = %peer_addr)));
460                            }
461                            Err(err) => {
462                                error!("failed to accept connection: {err}");
463                            }
464                        }
465                    }
466                }
467                service.shutdown().await;
468                set.shutdown().await;
469                debug!("server has been shutdown.");
470            }
471            .instrument(info_span!("relay-http-serve")),
472        );
473
474        Ok(Server {
475            addr,
476            http_server_task: AbortOnDropHandle::new(task),
477            cancel_server_loop: cancel_token,
478        })
479    }
480}
481
482/// The hyper Service that serves the actual relay endpoints.
483///
484/// This service can be used standalone or embedded into an existing HTTP server.
485#[derive(Clone, Debug)]
486pub struct RelayService(Arc<Inner>);
487
488#[derive(Debug)]
489struct Inner {
490    handlers: Handlers,
491    headers: HeaderMap,
492    clients: Clients,
493    write_timeout: Duration,
494    rate_limit: Option<ClientRateLimit>,
495    key_cache: KeyCache,
496    access: AccessConfig,
497    metrics: Arc<Metrics>,
498}
499
500#[stack_error(derive, add_meta)]
501enum RelayUpgradeReqError {
502    #[error("missing header: {header}")]
503    MissingHeader { header: http::HeaderName },
504    #[error("invalid header value for {header}: {details}")]
505    InvalidHeader {
506        header: http::HeaderName,
507        details: String,
508    },
509    #[error(
510        "invalid header value for {SEC_WEBSOCKET_VERSION}: unsupported websocket version, only supporting {SUPPORTED_WEBSOCKET_VERSION}"
511    )]
512    UnsupportedWebsocketVersion,
513    #[error(
514        "invalid header value for {SEC_WEBSOCKET_PROTOCOL}: unsupported relay version: we support {we_support} but you only provide {you_support}"
515    )]
516    UnsupportedRelayVersion {
517        we_support: &'static str,
518        you_support: String,
519    },
520}
521
522impl RelayService {
523    fn build_response(&self) -> http::response::Builder {
524        let mut res = Response::builder();
525        for (key, value) in self.0.headers.iter() {
526            res = res.header(key, value);
527        }
528        res
529    }
530
531    /// Upgrades the HTTP connection to the relay protocol, runs relay client.
532    fn handle_relay_ws_upgrade(
533        &self,
534        mut req: Request<Incoming>,
535    ) -> Result<Response<BytesBody>, RelayUpgradeReqError> {
536        fn expect_header(
537            req: &Request<Incoming>,
538            header: http::HeaderName,
539        ) -> Result<&HeaderValue, RelayUpgradeReqError> {
540            req.headers()
541                .get(&header)
542                .ok_or_else(|| e!(RelayUpgradeReqError::MissingHeader { header }))
543        }
544
545        let upgrade_header = expect_header(&req, UPGRADE)?;
546        ensure!(
547            upgrade_header == HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL),
548            RelayUpgradeReqError::InvalidHeader {
549                header: UPGRADE,
550                details: format!("value must be {WEBSOCKET_UPGRADE_PROTOCOL}")
551            }
552        );
553
554        let key = expect_header(&req, SEC_WEBSOCKET_KEY)?.clone();
555        let version = expect_header(&req, SEC_WEBSOCKET_VERSION)?.clone();
556
557        ensure!(
558            version.as_bytes() == SUPPORTED_WEBSOCKET_VERSION.as_bytes(),
559            RelayUpgradeReqError::UnsupportedWebsocketVersion
560        );
561
562        let subprotocols = expect_header(&req, SEC_WEBSOCKET_PROTOCOL)?
563            .to_str()
564            .ok()
565            .ok_or_else(|| {
566                e!(RelayUpgradeReqError::InvalidHeader {
567                    header: SEC_WEBSOCKET_PROTOCOL,
568                    details: "header value is not ascii".to_string()
569                })
570            })?;
571        let supports_our_version = subprotocols
572            .split_whitespace()
573            .any(|p| p == RELAY_PROTOCOL_VERSION);
574        ensure!(
575            supports_our_version,
576            RelayUpgradeReqError::UnsupportedRelayVersion {
577                we_support: RELAY_PROTOCOL_VERSION,
578                you_support: subprotocols.to_string()
579            }
580        );
581
582        let client_auth_header = req.headers().get(CLIENT_AUTH_HEADER).cloned();
583
584        // Setup a future that will eventually receive the upgraded
585        // connection and talk a new protocol, and spawn the future
586        // into the runtime.
587        //
588        // Note: This can't possibly be fulfilled until the 101 response
589        // is returned below, so it's better to spawn this future instead
590        // waiting for it to complete to then return a response.
591        tokio::task::spawn({
592            let this = self.clone();
593            async move {
594                match hyper::upgrade::on(&mut req).await {
595                    Ok(upgraded) => {
596                        if let Err(err) = this
597                            .0
598                            .relay_connection_handler(upgraded, client_auth_header)
599                            .await
600                        {
601                            warn!("error accepting upgraded connection: {err:#}",);
602                        } else {
603                            debug!("upgraded connection completed");
604                        };
605                    }
606                    Err(err) => warn!("upgrade error: {err:#}"),
607                }
608            }
609            .instrument(warn_span!("handler"))
610        });
611
612        // Now return a 101 Response saying we agree to the upgrade to the
613        // websocket upgrade protocol
614        Ok(self
615            .build_response()
616            .status(StatusCode::SWITCHING_PROTOCOLS)
617            .header(
618                UPGRADE,
619                HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL),
620            )
621            .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key))
622            .header(
623                SEC_WEBSOCKET_PROTOCOL,
624                HeaderValue::from_static(RELAY_PROTOCOL_VERSION),
625            )
626            .header(CONNECTION, "upgrade")
627            .body(body_full("switching to websocket protocol"))
628            .expect("valid body"))
629    }
630}
631
632impl Service<Request<Incoming>> for RelayService {
633    type Response = Response<BytesBody>;
634    type Error = HyperError;
635    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
636
637    fn call(&self, req: Request<Incoming>) -> Self::Future {
638        // Create a client if the request hits the relay endpoint.
639        if matches!(
640            (req.method(), req.uri().path()),
641            (&hyper::Method::GET, RELAY_PATH)
642        ) {
643            let res = match self.handle_relay_ws_upgrade(req) {
644                Ok(response) => Ok(response),
645                // It's convention to send back the version(s) we *do* support
646                Err(e @ RelayUpgradeReqError::UnsupportedWebsocketVersion { .. }) => self
647                    .build_response()
648                    .status(StatusCode::BAD_REQUEST)
649                    .header(SEC_WEBSOCKET_VERSION, SUPPORTED_WEBSOCKET_VERSION)
650                    .body(body_full(e.to_string())),
651                Err(e) => self
652                    .build_response()
653                    .status(StatusCode::BAD_REQUEST)
654                    .body(body_full(e.to_string())),
655            }
656            .map_err(Into::into);
657            return Box::pin(async move { res });
658        }
659        // Otherwise handle the relay connection as normal.
660
661        // Check all other possible endpoints.
662        let uri = req.uri().clone();
663        if let Some(res) = self.0.handlers.get(&(req.method().clone(), uri.path())) {
664            let f = res(req, self.0.default_response());
665            return Box::pin(async move { f });
666        }
667        // Otherwise return 404
668        let res = self.0.not_found_fn(req, self.0.default_response());
669        Box::pin(async move { res })
670    }
671}
672
673impl Inner {
674    fn default_response(&self) -> ResponseBuilder {
675        let mut response = Response::builder();
676        for (key, value) in self.headers.iter() {
677            response = response.header(key.clone(), value.clone());
678        }
679        response
680    }
681
682    fn not_found_fn(
683        &self,
684        _req: Request<Incoming>,
685        mut res: ResponseBuilder,
686    ) -> HyperResult<Response<BytesBody>> {
687        for (k, v) in self.headers.iter() {
688            res = res.header(k.clone(), v.clone());
689        }
690        let body = body_full("Not Found");
691        let r = res.status(StatusCode::NOT_FOUND).body(body)?;
692        HyperResult::Ok(r)
693    }
694
695    /// The server HTTP handler to do HTTP upgrades.
696    ///
697    /// This handler runs while doing the connection upgrade handshake.  Once the connection
698    /// is upgraded it sends the stream to the relay server which takes it over.  After
699    /// having sent off the connection this handler returns.
700    async fn relay_connection_handler(
701        &self,
702        upgraded: Upgraded,
703        client_auth_header: Option<HeaderValue>,
704    ) -> Result<(), ConnectionHandlerError> {
705        debug!("relay_connection upgraded");
706        let (io, read_buf) = downcast_upgrade(upgraded)?;
707        if !read_buf.is_empty() {
708            return Err(e!(ConnectionHandlerError::BufferNotEmpty { buf: read_buf }));
709        }
710
711        self.accept(io, client_auth_header).await?;
712        Ok(())
713    }
714
715    /// Adds a new connection to the server and serves it.
716    ///
717    /// Will error if it takes too long (10 sec) to write or read to the connection, if there is
718    /// some read or write error to the connection,  if the server is meant to verify clients,
719    /// and is unable to verify this one, or if there is some issue communicating with the server.
720    ///
721    /// The provided [`AsyncRead`] and [`AsyncWrite`] must be already connected to the connection.
722    ///
723    /// [`AsyncRead`]: tokio::io::AsyncRead
724    /// [`AsyncWrite`]: tokio::io::AsyncWrite
725    async fn accept(
726        &self,
727        io: MaybeTlsStream,
728        client_auth_header: Option<HeaderValue>,
729    ) -> Result<(), AcceptError> {
730        trace!("accept: start");
731
732        // Set the socket to NO_DELAY.
733        io.disable_nagle();
734
735        let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone())
736            .map_err(|err| e!(AcceptError::RateLimitingMisconfigured, err))?;
737
738        // Create a server builder with default config
739        let websocket = tokio_websockets::ServerBuilder::new()
740            .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE)))
741            // Serve will create a WebSocketStream on an already upgraded connection
742            .serve(io);
743
744        let mut io = WsBytesFramed { io: websocket };
745
746        let authentication = handshake::serverside(&mut io, client_auth_header).await?;
747
748        trace!(?authentication.mechanism, "accept: verified authentication");
749
750        let is_authorized = self.access.is_allowed(authentication.client_key).await;
751        let client_key = authentication.authorize_if(is_authorized, &mut io).await?;
752
753        trace!("accept: verified authorization");
754
755        let io = RelayedStream {
756            inner: io,
757            key_cache: self.key_cache.clone(),
758        };
759
760        trace!("accept: build client conn");
761        let client_conn_builder = Config {
762            endpoint_id: client_key,
763            stream: io,
764            write_timeout: self.write_timeout,
765            channel_capacity: PER_CLIENT_SEND_QUEUE_DEPTH,
766        };
767        trace!("accept: create client");
768        let endpoint_id = client_conn_builder.endpoint_id;
769        trace!(endpoint_id = %endpoint_id.fmt_short(), "create client");
770
771        // build and register client, starting up read & write loops for the client
772        // connection
773        self.clients
774            .register(client_conn_builder, self.metrics.clone());
775        Ok(())
776    }
777}
778
779/// TLS Certificate Authority acceptor.
780#[derive(Clone, derive_more::Debug)]
781pub(super) enum TlsAcceptor {
782    /// Uses Let's Encrypt as the Certificate Authority. This is used in production.
783    LetsEncrypt(#[debug("tokio_rustls_acme::AcmeAcceptor")] AcmeAcceptor),
784    /// Manually added tls acceptor. Generally used for tests or for when we've passed in
785    /// a certificate via a file.
786    Manual(#[debug("tokio_rustls::TlsAcceptor")] tokio_rustls::TlsAcceptor),
787}
788
789impl RelayService {
790    /// Creates a new RelayService.
791    ///
792    /// This allows embedding the relay service into an existing HTTP server.
793    pub fn new(
794        handlers: Handlers,
795        headers: HeaderMap,
796        rate_limit: Option<ClientRateLimit>,
797        key_cache: KeyCache,
798        access: AccessConfig,
799        metrics: Arc<Metrics>,
800    ) -> Self {
801        Self(Arc::new(Inner {
802            handlers,
803            headers,
804            clients: Clients::default(),
805            write_timeout: SERVER_WRITE_TIMEOUT,
806            rate_limit,
807            key_cache,
808            access,
809            metrics,
810        }))
811    }
812
813    /// Shuts down the relay service, disconnecting all clients.
814    pub async fn shutdown(&self) {
815        self.0.clients.shutdown().await;
816    }
817
818    /// Handle the incoming connection.
819    ///
820    /// If a `tls_config` is given, will serve the connection using HTTPS, otherwise HTTP.
821    ///
822    /// # Example
823    ///
824    /// ```no_run
825    /// # use std::sync::Arc;
826    /// # use tokio::net::TcpStream;
827    /// # use http::HeaderMap;
828    /// # use iroh_relay::server::http_server::{Handlers, RelayService, TlsConfig};
829    /// # use iroh_relay::{KeyCache, server::{AccessConfig, Metrics}};
830    /// # use webpki_types::{CertificateDer, PrivateKeyDer};
831    /// # async fn example(stream: TcpStream) -> Result<(), Box<dyn std::error::Error>> {
832    /// // Create a relay service
833    /// let handlers = Handlers::default();
834    /// let headers = HeaderMap::new();
835    /// let key_cache = KeyCache::new(1024);
836    /// let metrics = Arc::new(Metrics::default());
837    /// let relay_service = RelayService::new(
838    ///     handlers,
839    ///     headers,
840    ///     None, // No rate limiting
841    ///     key_cache,
842    ///     AccessConfig::Everyone,
843    ///     metrics,
844    /// );
845    ///
846    /// // Generate a self-signed certificate for HTTPS
847    /// let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])?;
848    /// let cert_der = cert.cert.der().to_vec();
849    /// let private_key_der = cert.signing_key.serialize_der();
850    /// let cert_chain = vec![CertificateDer::from(cert_der)];
851    /// let private_key = PrivateKeyDer::try_from(private_key_der)?;
852    ///
853    /// // Serve with HTTPS
854    /// let server_config = Arc::new(
855    ///     rustls::ServerConfig::builder()
856    ///         .with_no_client_auth()
857    ///         .with_single_cert(cert_chain, private_key)?,
858    /// );
859    /// let tls_config = TlsConfig::new(server_config);
860    /// relay_service
861    ///     .clone()
862    ///     .handle_connection(stream, Some(tls_config))
863    ///     .await;
864    ///
865    /// // Or serve with plain HTTP
866    /// # let stream = TcpStream::connect("127.0.0.1:0").await?;
867    /// relay_service.handle_connection(stream, None).await;
868    /// # Ok(())
869    /// # }
870    /// ```
871    pub async fn handle_connection(self, stream: TcpStream, tls_config: Option<TlsConfig>) {
872        let res = match tls_config {
873            Some(tls_config) => {
874                debug!("HTTPS: serve connection");
875                self.tls_serve_connection(stream, tls_config).await
876            }
877            None => {
878                debug!("HTTP: serve connection");
879                self.serve_connection(MaybeTlsStream::Plain(stream))
880                    .await
881                    .map_err(|err| e!(ServeConnectionError::Http, err))
882            }
883        };
884        match res {
885            Ok(()) => {}
886            Err(error) => match error {
887                ServeConnectionError::ManualAccept { source, .. }
888                | ServeConnectionError::LetsEncryptAccept { source, .. }
889                    if source.kind() == std::io::ErrorKind::UnexpectedEof =>
890                {
891                    debug!(reason=?source, "peer disconnected");
892                }
893                // From hyper: <https://github.com/hyperium/hyper/commit/271bba16672ff54a44e043c5cc1ae6b9345bb172>
894                // `hyper::Error::IncompleteMessage` is hyper's equivalent of UnexpectedEof
895                ServeConnectionError::Https { source, .. }
896                | ServeConnectionError::Http { source, .. }
897                    if source.is_incomplete_message() =>
898                {
899                    debug!(reason=?source, "peer disconnected");
900                }
901                _ => {
902                    error!(?error, "failed to handle connection");
903                }
904            },
905        }
906    }
907
908    /// Serve the tls connection
909    async fn tls_serve_connection(
910        self,
911        stream: TcpStream,
912        tls_config: TlsConfig,
913    ) -> Result<(), ServeConnectionError> {
914        let TlsConfig { acceptor, config } = tls_config;
915        match acceptor {
916            TlsAcceptor::LetsEncrypt(a) => {
917                match a
918                    .accept(stream)
919                    .await
920                    .map_err(|err| e!(ServeConnectionError::LetsEncryptAccept, err))?
921                {
922                    None => {
923                        info!("TLS[acme]: received TLS-ALPN-01 validation request");
924                    }
925                    Some(start_handshake) => {
926                        debug!("TLS[acme]: start handshake");
927                        let tls_stream = start_handshake
928                            .into_stream(config)
929                            .await
930                            .map_err(|err| e!(ServeConnectionError::TlsHandshake, err))?;
931                        self.serve_connection(MaybeTlsStream::Tls(tls_stream))
932                            .await
933                            .map_err(|err| e!(ServeConnectionError::Https, err))?;
934                    }
935                }
936            }
937            TlsAcceptor::Manual(a) => {
938                debug!("TLS[manual]: accept");
939                let tls_stream = tokio::time::timeout(Duration::from_secs(30), a.accept(stream))
940                    .await
941                    .map_err(|err| e!(ServeConnectionError::Timeout, err))?
942                    .map_err(|err| e!(ServeConnectionError::ManualAccept, err))?;
943
944                self.serve_connection(MaybeTlsStream::Tls(tls_stream))
945                    .await
946                    .map_err(|err| e!(ServeConnectionError::ServeConnection, err))?;
947            }
948        }
949        Ok(())
950    }
951
952    /// Wrapper for the actual http connection (with upgrades)
953    async fn serve_connection<I>(self, io: I) -> Result<(), hyper::Error>
954    where
955        I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static,
956    {
957        hyper::server::conn::http1::Builder::new()
958            .serve_connection(hyper_util::rt::TokioIo::new(io), self)
959            .with_upgrades()
960            .await
961    }
962}
963
964/// A collection of HTTP request handlers for custom endpoints.
965#[derive(Default)]
966pub struct Handlers(HashMap<(Method, &'static str), HyperHandler>);
967
968impl std::fmt::Debug for Handlers {
969    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
970        let s = self.0.keys().fold(String::new(), |curr, next| {
971            let (method, uri) = next;
972            format!("{curr}\n({method},{uri}): Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>")
973        });
974        write!(f, "HashMap<{s}>")
975    }
976}
977
978impl std::ops::Deref for Handlers {
979    type Target = HashMap<(Method, &'static str), HyperHandler>;
980
981    fn deref(&self) -> &Self::Target {
982        &self.0
983    }
984}
985
986impl std::ops::DerefMut for Handlers {
987    fn deref_mut(&mut self) -> &mut Self::Target {
988        &mut self.0
989    }
990}
991
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use iroh_base::{PublicKey, SecretKey};
    use n0_error::{Result, StdResultExt, bail_any};
    use n0_future::{SinkExt, StreamExt};
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;
    use reqwest::Url;
    use tracing::info;

    use super::*;
    use crate::{
        client::{Client, ClientBuilder, ConnectError, conn::Conn},
        dns::DnsResolver,
        protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg},
        tls::{CaRootsConfig, default_provider},
    };

    /// Builds a `TlsConfig` from a freshly generated self-signed certificate for
    /// `localhost`, using the `ring` crypto provider.
    pub(crate) fn make_tls_config() -> TlsConfig {
        let subject_alt_names = vec!["localhost".to_string()];

        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let rustls_certificate = cert.cert.der().clone();
        let rustls_key =
            rustls::pki_types::PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let config = rustls::ServerConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_safe_default_protocol_versions()
        .expect("protocols supported by ring")
        .with_no_client_auth()
        .with_single_cert(vec![rustls_certificate], rustls_key.into())
        .expect("cert is right");

        TlsConfig::new(Arc::new(config))
    }

    #[tokio::test]
    #[traced_test]
    async fn test_http_clients_and_server() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let a_key = SecretKey::generate(&mut rng);
        let b_key = SecretKey::generate(&mut rng);

        // start server
        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .spawn()
            .await?;

        // get dial info
        let addr = server.addr();
        let port = addr.port();
        let std::net::IpAddr::V4(ipv4_addr) = addr.ip() else {
            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
        };

        info!("addr: {ipv4_addr}:{port}");
        let relay_addr: Url = format!("http://{ipv4_addr}:{port}").parse().unwrap();

        exchange_messages(a_key, b_key, relay_addr).await?;

        server.shutdown();
        server.task_handle().await.std_context("join")?;

        Ok(())
    }

    /// Connects a relay client for `key` to `server_url`, skipping certificate verification
    /// so self-signed test certificates are accepted.
    async fn create_test_client(
        key: SecretKey,
        server_url: Url,
    ) -> Result<(PublicKey, Client), ConnectError> {
        let public_key = key.public();
        let client = ClientBuilder::new(server_url, key, DnsResolver::new()).tls_client_config(
            CaRootsConfig::insecure_skip_verify()
                .client_config(default_provider())
                .expect("infallible"),
        );
        let client = client.connect().await?;

        Ok((public_key, client))
    }

    /// Extracts `(sender, datagrams)` from a received message, returning `None` (and
    /// logging) for errors, end of stream, or any non-datagram message.
    fn process_msg(
        msg: Option<Result<RelayToClientMsg, crate::client::RecvError>>,
    ) -> Option<(PublicKey, Datagrams)> {
        match msg {
            Some(Err(e)) => {
                info!("client `recv` error {e}");
                None
            }
            Some(Ok(msg)) => {
                info!("got message on: {msg:?}");
                if let RelayToClientMsg::Datagrams {
                    remote_endpoint_id: source,
                    datagrams,
                } = msg
                {
                    Some((source, datagrams))
                } else {
                    None
                }
            }
            None => {
                info!("client end of stream");
                None
            }
        }
    }

    /// Connects two clients to the relay at `url`, checks that each receives a pong for its
    /// ping, relays one datagram in each direction and finally closes both clients.
    async fn exchange_messages(a_key: SecretKey, b_key: SecretKey, url: Url) -> Result {
        // create clients
        let (a_key, mut client_a) = create_test_client(a_key, url.clone()).await?;
        info!("created client {a_key:?}");
        let (b_key, mut client_b) = create_test_client(b_key, url).await?;
        info!("created client {b_key:?}");

        info!("ping a");
        client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?;
        let pong = client_a.next().await.expect("eos")?;
        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));

        info!("ping b");
        client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?;
        let pong = client_b.next().await.expect("eos")?;
        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));

        info!("sending message from a to b");
        let msg = Datagrams::from(b"hi there, client b!");
        client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: b_key,
                datagrams: msg.clone(),
            })
            .await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) =
            process_msg(client_b.next().await).expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);

        info!("sending message from b to a");
        let msg = Datagrams::from(b"right back at ya, client b!");
        client_b
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: a_key,
                datagrams: msg.clone(),
            })
            .await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) =
            process_msg(client_a.next().await).expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);

        // Close before shutting down, otherwise we'll try to send close frames on broken pipes
        client_a.close().await?;
        client_b.close().await?;

        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_https_clients_and_server() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let a_key = SecretKey::generate(&mut rng);
        let b_key = SecretKey::generate(&mut rng);

        // create tls_config
        let tls_config = make_tls_config();

        // start server
        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .tls_config(Some(tls_config))
            .spawn()
            .await?;

        // get dial info
        let addr = server.addr();
        let port = addr.port();
        let std::net::IpAddr::V4(ipv4_addr) = addr.ip() else {
            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
        };

        info!("Relay listening on: {ipv4_addr}:{port}");

        // The certificate is issued for `localhost`, so dial by name rather than by IP.
        let url: Url = format!("https://localhost:{port}").parse().unwrap();

        exchange_messages(a_key, b_key, url).await?;

        server.shutdown();
        server.task_handle().await.std_context("join")?;

        Ok(())
    }

    /// Wraps one end of an in-memory duplex pipe as a websocket relay connection for `key`.
    async fn make_test_client(client: tokio::io::DuplexStream, key: &SecretKey) -> Result<Conn> {
        let client = crate::client::streams::MaybeTlsStream::Test(client);
        let client = tokio_websockets::ClientBuilder::new().take_over(client);
        let client = Conn::new(client, KeyCache::test(), key).await?;
        Ok(client)
    }

    /// Connects a client for `key` directly to `service` over an in-memory duplex pipe.
    ///
    /// The server side is accepted on a spawned task, and the client-side handshake runs
    /// concurrently with it. Both must succeed.
    async fn connect_to_service(service: &RelayService, key: &SecretKey) -> Result<Conn> {
        let (client_io, server_io) = tokio::io::duplex(10);
        let s = service.clone();
        let handler_task =
            tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(server_io), None).await });
        let client = make_test_client(client_io, key).await?;
        handler_task.await.std_context("join")??;
        Ok(client)
    }

    /// Sends `datagrams` from `conn` to the endpoint `dst`.
    async fn send_datagrams(conn: &mut Conn, dst: PublicKey, datagrams: Datagrams) -> Result {
        conn.send(ClientToRelayMsg::Datagrams {
            dst_endpoint_id: dst,
            datagrams,
        })
        .await?;
        Ok(())
    }

    /// Reads the next message on `conn` and asserts it is `expected`, sent by `from`.
    async fn expect_datagrams(conn: &mut Conn, from: PublicKey, expected: &Datagrams) -> Result {
        match conn.next().await.expect("eos")? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(from, remote_endpoint_id);
                assert_eq!(expected, &datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }
        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_server_basic() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        info!("Create the server.");
        let metrics = Arc::new(Metrics::default());
        let service = RelayService::new(
            Default::default(),
            Default::default(),
            None,
            KeyCache::test(),
            AccessConfig::Everyone,
            metrics.clone(),
        );

        info!("Create client A and connect it to the server.");
        let key_a = SecretKey::generate(&mut rng);
        let public_key_a = key_a.public();
        let mut client_a = connect_to_service(&service, &key_a).await?;

        info!("Create client B and connect it to the server.");
        let key_b = SecretKey::generate(&mut rng);
        let public_key_b = key_b.public();
        let mut client_b = connect_to_service(&service, &key_b).await?;

        info!("Send message from A to B.");
        let msg = Datagrams::from(b"hello client b!!");
        send_datagrams(&mut client_a, public_key_b, msg.clone()).await?;
        expect_datagrams(&mut client_b, public_key_a, &msg).await?;

        info!("Send message from B to A.");
        let msg = Datagrams::from(b"nice to meet you client a!!");
        send_datagrams(&mut client_b, public_key_a, msg.clone()).await?;
        expect_datagrams(&mut client_a, public_key_b, &msg).await?;

        info!("Close the server and clients");
        service.shutdown().await;
        // Allow the shutdown time to close the connections before probing them below.
        tokio::time::sleep(Duration::from_secs(1)).await;

        info!("Fail to send message from A to B.");
        let res =
            send_datagrams(&mut client_a, public_key_b, Datagrams::from(b"try to send")).await;
        assert!(res.is_err());
        assert!(client_b.next().await.is_none());

        drop(client_a);
        drop(client_b);

        service.shutdown().await;

        assert_eq!(metrics.accepts.get(), metrics.disconnects.get());

        Ok(())
    }

    #[tokio::test]
    #[traced_test]
    async fn test_server_replace_client() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        info!("Create the server.");
        let service = RelayService::new(
            Default::default(),
            Default::default(),
            None,
            KeyCache::test(),
            AccessConfig::Everyone,
            Default::default(),
        );

        info!("Create client A and connect it to the server.");
        let key_a = SecretKey::generate(&mut rng);
        let public_key_a = key_a.public();
        let mut client_a = connect_to_service(&service, &key_a).await?;

        info!("Create client B and connect it to the server.");
        let key_b = SecretKey::generate(&mut rng);
        let public_key_b = key_b.public();
        let mut client_b = connect_to_service(&service, &key_b).await?;

        info!("Send message from A to B.");
        let msg = Datagrams::from(b"hello client b!!");
        send_datagrams(&mut client_a, public_key_b, msg.clone()).await?;
        expect_datagrams(&mut client_b, public_key_a, &msg).await?;

        info!("Send message from B to A.");
        let msg = Datagrams::from(b"nice to meet you client a!!");
        send_datagrams(&mut client_b, public_key_a, msg.clone()).await?;
        expect_datagrams(&mut client_a, public_key_b, &msg).await?;

        // Connect a second client with B's key; it should take over B's identity while the
        // original `client_b` is still alive.
        info!("Create client B and connect it to the server");
        let mut new_client_b = connect_to_service(&service, &key_b).await?;

        info!("Send message from A to B.");
        let msg = Datagrams::from(b"are you still there, b?!");
        send_datagrams(&mut client_a, public_key_b, msg.clone()).await?;
        expect_datagrams(&mut new_client_b, public_key_a, &msg).await?;

        info!("Send message from B to A.");
        let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!");
        send_datagrams(&mut new_client_b, public_key_a, msg.clone()).await?;
        expect_datagrams(&mut client_a, public_key_b, &msg).await?;

        info!("Close the server and clients");
        service.shutdown().await;

        info!("Sending message from A to B fails");
        let res =
            send_datagrams(&mut client_a, public_key_b, Datagrams::from(b"try to send")).await;
        assert!(res.is_err());
        assert!(new_client_b.next().await.is_none());

        // Keep the replaced connection alive until the end of the test, as before.
        drop(client_b);
        Ok(())
    }
}