iroh_relay/
server.rs

1//! A fully-fledged iroh-relay server over HTTP or HTTPS.
2//!
3//! This module provides an API to run a full fledged iroh-relay server.  It is primarily
4//! used by the `iroh-relay` binary in this crate.  It can be used to run a relay server in
5//! other locations however.
6//!
7//! This code is fully written in a form of structured-concurrency: every spawned task is
8//! always attached to a handle and when the handle is dropped the tasks abort.  So tasks
9//! can not outlive their handle.  It is also always possible to await for completion of a
10//! task.  Some tasks additionally have a method to do graceful shutdown.
11//!
12//! The relay server hosts the following services:
13//!
//! - HTTPS `/relay`: The main URL endpoint to which clients connect and send traffic over.
//! - HTTPS `/ping`: Used for net_report probes.
//! - HTTP `/generate_204`: Used for captive portal detection, always served without TLS.
17
18use std::{fmt, future::Future, net::SocketAddr, num::NonZeroU32, pin::Pin, sync::Arc};
19
20use derive_more::Debug;
21use http::{
22    HeaderMap, HeaderValue, Method, Request, Response, StatusCode, header::InvalidHeaderValue,
23    response::Builder as ResponseBuilder,
24};
25use hyper::body::Incoming;
26use iroh_base::EndpointId;
27#[cfg(feature = "test-utils")]
28use iroh_base::RelayUrl;
29use n0_error::{e, stack_error};
30use n0_future::{StreamExt, future::Boxed};
31use serde::Serialize;
32use tokio::{
33    net::TcpListener,
34    task::{JoinError, JoinSet},
35};
36use tokio_util::task::AbortOnDropHandle;
37use tracing::{Instrument, debug, error, info, info_span, instrument};
38
39use crate::{
40    defaults::DEFAULT_KEY_CACHE_CAPACITY,
41    http::RELAY_PROBE_PATH,
42    quic::server::{QuicServer, QuicSpawnError, ServerHandle as QuicServerHandle},
43};
44
45mod client;
46mod clients;
47mod http_server;
48mod metrics;
49pub(crate) mod resolver;
50pub(crate) mod streams;
51#[cfg(feature = "test-utils")]
52pub mod testing;
53
54pub use self::{
55    metrics::{Metrics, RelayMetrics},
56    resolver::{DEFAULT_CERT_RELOAD_INTERVAL, ReloadingResolver},
57};
58
/// Request header carrying the challenge for the `/generate_204` captive portal check.
const NO_CONTENT_CHALLENGE_HEADER: &str = "X-Iroh-Challenge";
/// Response header in which a valid challenge is echoed back, prefixed with `response `.
const NO_CONTENT_RESPONSE_HEADER: &str = "X-Iroh-Response";
/// Body of the 404 response sent by the captive portal service for unknown routes.
const NOTFOUND: &[u8] = b"Not Found";
/// Body served for `/robots.txt`; asks crawlers not to index anything.
const ROBOTS_TXT: &[u8] = b"User-agent: *\nDisallow: /\n";
/// Static landing page served for `/` and `/index.html`.
const INDEX: &[u8] = br#"<html><body>
<h1>Iroh Relay</h1>
<p>
  This is an <a href="https://iroh.computer/">Iroh</a> Relay server.
</p>
"#;
/// Security headers handed to the relay HTTP server builder as `(name, value)` pairs.
///
/// NOTE(review): the name suggests these are only attached to responses served over
/// TLS, but that decision is made inside `http_server` -- confirm there.
const TLS_HEADERS: [(&str, &str); 2] = [
    (
        "Strict-Transport-Security",
        "max-age=63072000; includeSubDomains",
    ),
    (
        "Content-Security-Policy",
        "default-src 'none'; frame-ancestors 'none'; form-action 'none'; base-uri 'self'; block-all-mixed-content; plugin-types 'none'",
    ),
];
79
/// Response body type used by all handlers in this module: a single, fully buffered chunk.
type BytesBody = http_body_util::Full<hyper::body::Bytes>;
/// Boxed, thread-safe error type returned by the HTTP handlers.
type HyperError = Box<dyn std::error::Error + Send + Sync>;
/// Result alias for HTTP handlers, with [`HyperError`] as the error type.
type HyperResult<T> = std::result::Result<T, HyperError>;
83
84/// Creates a new [`BytesBody`] with no content.
85fn body_empty() -> BytesBody {
86    http_body_util::Full::new(hyper::body::Bytes::new())
87}
88
/// Configuration for the full Relay.
///
/// Every service is optional: the relay HTTP(S) server, the QUIC server and (with the
/// `metrics` feature) the metrics server are each only started when configured.
///
/// Be aware the generic parameters are for when using the Let's Encrypt TLS configuration.
/// If not used dummy ones need to be provided, e.g. `ServerConfig::<(), ()>::default()`.
#[derive(Debug, Default)]
pub struct ServerConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// Configuration for the Relay server, disabled if `None`.
    pub relay: Option<RelayConfig<EC, EA>>,
    /// Configuration for the QUIC server, disabled if `None`.
    pub quic: Option<QuicConfig>,
    /// Socket to serve metrics on, no metrics server is started if `None`.
    #[cfg(feature = "metrics")]
    pub metrics_addr: Option<SocketAddr>,
}
103
/// Configuration for the Relay HTTP and HTTPS server.
///
/// This includes the HTTP services hosted by the Relay server, the Relay `/relay` HTTP
/// endpoint is only one of the services served.
#[derive(Debug)]
pub struct RelayConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// The socket address on which the Relay HTTP server should bind.
    ///
    /// Normally you'd choose port `80`.  The bind address for the HTTPS server is
    /// configured in [`RelayConfig::tls`].
    ///
    /// If [`RelayConfig::tls`] is `None` then this serves all the HTTP services without
    /// TLS.
    pub http_bind_addr: SocketAddr,
    /// TLS configuration for the HTTPS server.
    ///
    /// If `None` all the HTTP services that would be served here are served from
    /// [`RelayConfig::http_bind_addr`].
    pub tls: Option<TlsConfig<EC, EA>>,
    /// Rate limits.
    pub limits: Limits,
    /// Key cache capacity.
    ///
    /// If `None`, [`DEFAULT_KEY_CACHE_CAPACITY`] is used.
    pub key_cache_capacity: Option<usize>,
    /// Access configuration, controls which endpoints may use the relay.
    pub access: AccessConfig,
}
130
/// Controls which endpoints are allowed to use the relay.
#[derive(derive_more::Debug)]
pub enum AccessConfig {
    /// Every endpoint is allowed, no check is performed.
    Everyone,
    /// Only endpoints for which the function returns `Access::Allow`.
    ///
    /// The function is called with the [`EndpointId`] to check and returns a boxed
    /// future resolving to the [`Access`] decision.
    #[debug("restricted")]
    Restricted(Box<dyn Fn(EndpointId) -> Boxed<Access> + Send + Sync + 'static>),
}
140
141impl AccessConfig {
142    /// Is this endpoint allowed?
143    pub async fn is_allowed(&self, endpoint: EndpointId) -> bool {
144        match self {
145            Self::Everyone => true,
146            Self::Restricted(check) => {
147                let res = check(endpoint).await;
148                matches!(res, Access::Allow)
149            }
150        }
151    }
152}
153
/// Access restriction for an endpoint.
///
/// This is the verdict produced by the callback in [`AccessConfig::Restricted`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Access {
    /// Access is allowed.
    Allow,
    /// Access is denied.
    Deny,
}
162
/// Configuration for the QUIC server.
#[derive(Debug)]
pub struct QuicConfig {
    /// The socket address on which the QUIC server should bind.
    ///
    /// Normally you'd choose port `7842`, see [`crate::defaults::DEFAULT_RELAY_QUIC_PORT`].
    pub bind_addr: SocketAddr,
    /// The TLS server configuration for the QUIC server.
    ///
    /// If this [`rustls::ServerConfig`] does not support TLS 1.3, the QUIC server will fail
    /// to spawn.
    pub server_config: rustls::ServerConfig,
}
176
/// TLS configuration for Relay server.
///
/// Normally the Relay server accepts connections on both HTTPS and HTTP.
#[derive(Debug)]
pub struct TlsConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// The socket address on which to serve the HTTPS server.
    ///
    /// Since the captive portal probe has to run over plain text HTTP and TLS is used for
    /// the main relay server this has to be on a different port than
    /// [`RelayConfig::http_bind_addr`].  When TLS is not enabled the relay is served on
    /// the [`RelayConfig::http_bind_addr`] socket address instead.
    ///
    /// Normally you'd choose port `443`, the standard HTTPS port.
    pub https_bind_addr: SocketAddr,
    /// The socket address on which to serve the QUIC server, if QUIC is enabled.
    pub quic_bind_addr: SocketAddr,
    /// Mode for getting a cert.
    pub cert: CertConfig<EC, EA>,
    /// The server configuration.
    pub server_config: rustls::ServerConfig,
}
197
/// Rate limits.
///
/// The default has every limit unset.
// TODO: accept_conn_limit and accept_conn_burst are not currently implemented.
#[derive(Debug, Default)]
pub struct Limits {
    /// Rate limit for accepting new connection. Unlimited if not set.
    pub accept_conn_limit: Option<f64>,
    /// Burst limit for accepting new connection. Unlimited if not set.
    pub accept_conn_burst: Option<usize>,
    /// Rate limits for incoming traffic from a client connection.
    ///
    /// No client receive rate limit is configured on the server if not set.
    pub client_rx: Option<ClientRateLimit>,
}
209
/// Per-client rate limit configuration.
///
/// Used as [`Limits::client_rx`] to limit the traffic read from a client connection.
#[derive(Debug, Copy, Clone)]
pub struct ClientRateLimit {
    /// Max number of bytes per second to read from the client connection.
    pub bytes_per_second: NonZeroU32,
    /// Max number of bytes to read in a single burst.
    pub max_burst_bytes: Option<NonZeroU32>,
}
218
/// TLS certificate configuration.
///
/// Selects how the HTTPS server obtains its certificate.
#[derive(derive_more::Debug)]
pub enum CertConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
    /// Use Let's Encrypt.
    LetsEncrypt {
        /// State for Let's Encrypt certificates.
        ///
        /// The server takes its TLS acceptor from this state and drives the state's
        /// event stream in a background task.
        #[debug("AcmeConfig")]
        state: tokio_rustls_acme::AcmeState<EC, EA>,
    },
    /// Use a static TLS key and certificate chain.
    Manual {
        /// The TLS certificate chain.
        ///
        /// A copy is exposed through [`Server::certificates`].
        certs: Vec<rustls::pki_types::CertificateDer<'static>>,
    },
    /// Use a TLS key and certificate chain that can be reloaded.
    Reloading,
}
236
/// A running Relay + QAD server.
///
/// This is a full Relay server, including QAD, Relay and various associated HTTP services.
///
/// Dropping this will stop the server.
#[derive(Debug)]
pub struct Server {
    /// The address of the HTTP server, if configured.
    http_addr: Option<SocketAddr>,
    /// The address of the HTTPS server, if the relay server is using TLS.
    ///
    /// If the Relay server is not using TLS then it is served from the
    /// [`Server::http_addr`].
    https_addr: Option<SocketAddr>,
    /// The address of the QUIC server, if configured.
    quic_addr: Option<SocketAddr>,
    /// Handle to the relay server.
    relay_handle: Option<http_server::ServerHandle>,
    /// Handle to the quic server.
    quic_handle: Option<QuicServerHandle>,
    /// The main task running the server.
    ///
    /// The task is aborted when this handle is dropped.
    supervisor: AbortOnDropHandle<Result<(), SupervisorError>>,
    /// The certificate for the server.
    ///
    /// If the server has manual certificates configured the certificate chain will be
    /// available here, this can be used by a client to authenticate the server.
    certificates: Option<Vec<rustls::pki_types::CertificateDer<'static>>>,
    /// The metrics collected by the relay server.
    metrics: RelayMetrics,
}
266
/// Server spawn errors
///
/// Returned by [`Server::spawn`] when one of the configured services fails to start.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, std_sources)]
#[non_exhaustive]
pub enum SpawnError {
    // Not constructed in this file; presumably produced by the HTTP server builder.
    #[error("Unable to get local address")]
    LocalAddr { source: std::io::Error },
    // Spawning the QUIC server failed.
    #[error("Failed to bind QAD listener")]
    QuicSpawn { source: QuicSpawnError },
    // One of the static `TLS_HEADERS` values is not a valid header value.
    #[error("Failed to parse TLS header")]
    TlsHeaderParse { source: InvalidHeaderValue },
    // Despite the name, this is raised when binding the plain HTTP listener that runs
    // next to the TLS-enabled relay server.
    #[error("Failed to bind TcpListener")]
    BindTlsListener { source: std::io::Error },
    // The bound plain HTTP listener could not report its local address.
    #[error("No local address")]
    NoLocalAddr { source: std::io::Error },
    // Not constructed in this file; presumably produced by the HTTP server builder.
    #[error("Failed to bind server socket to {addr}")]
    BindTcpListener {
        source: std::io::Error,
        addr: SocketAddr,
    },
}
288
/// Server task errors
///
/// Produced by the supervisor task of a running [`Server`].
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum SupervisorError {
    // The metrics server task returned an I/O error.
    #[error("Error starting metrics server")]
    Metrics {
        #[error(std_err)]
        source: std::io::Error,
    },
    // The Let's Encrypt event stream ended, after which the ACME task exits.
    #[error("Acme event stream finished")]
    AcmeEventStreamFinished {},
    // Awaiting the supervisor task itself failed.
    #[error(transparent)]
    JoinError {
        #[error(from, std_err)]
        source: JoinError,
    },
    // The supervisor had neither a relay server, a QUIC server nor any other task to run.
    #[error("No relay services are enabled")]
    NoRelayServicesEnabled {},
    // A supervised task ended with a `JoinError` that was not a panic.
    #[error("Task cancelled")]
    TaskCancelled {},
}
311
312impl Server {
    /// Starts the server.
    ///
    /// Depending on `config` this starts, in order:
    ///
    /// - the metrics server (only with the `metrics` feature and a configured address),
    /// - the QUIC server,
    /// - the relay HTTP(S) server.  When TLS is configured an additional plain HTTP
    ///   listener is bound on [`RelayConfig::http_bind_addr`] to serve the captive
    ///   portal check; without TLS the `/generate_204` route is added to the relay
    ///   server itself.
    ///
    /// All background tasks are handed to a supervisor task whose handle is stored in
    /// the returned [`Server`].
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] if the QUIC server fails to spawn, a TLS header fails to
    /// parse, the plain HTTP listener cannot be bound or queried for its address, or
    /// the relay server builder fails to spawn.
    pub async fn spawn<EC, EA>(config: ServerConfig<EC, EA>) -> Result<Self, SpawnError>
    where
        EC: fmt::Debug + 'static,
        EA: fmt::Debug + 'static,
    {
        // Auxiliary background tasks (metrics, ACME events, captive portal service).
        // They are handed to the supervisor once everything is set up.
        let mut tasks = JoinSet::new();

        let metrics = RelayMetrics::default();

        #[cfg(feature = "metrics")]
        if let Some(addr) = config.metrics_addr {
            debug!("Starting metrics server");
            let mut registry = iroh_metrics::Registry::default();
            registry.register_all(&metrics);
            tasks.spawn(
                async move {
                    iroh_metrics::service::start_metrics_server(addr, Arc::new(registry))
                        .await
                        .map_err(|err| e!(SupervisorError::Metrics, err))
                }
                .instrument(info_span!("metrics-server")),
            );
        }

        // Start the Relay server, but first clone the certs out.
        // Only manually configured certificates are exposed via `Server::certificates`.
        let certificates = config.relay.as_ref().and_then(|relay| {
            relay.tls.as_ref().and_then(|tls| match tls.cert {
                CertConfig::LetsEncrypt { .. } => None,
                CertConfig::Manual { ref certs, .. } => Some(certs.clone()),
                CertConfig::Reloading => None,
            })
        });

        let quic_server = match config.quic {
            Some(quic_config) => {
                debug!("Starting QUIC server {}", quic_config.bind_addr);
                Some(QuicServer::spawn(quic_config).map_err(|err| e!(SpawnError::QuicSpawn, err))?)
            }
            None => None,
        };
        let quic_addr = quic_server.as_ref().map(|srv| srv.bind_addr());
        let quic_handle = quic_server.as_ref().map(|srv| srv.handle());

        // `http_addr` is only `Some` when TLS is enabled and a separate plain HTTP
        // listener was bound.
        let (relay_server, http_addr) = match config.relay {
            Some(relay_config) => {
                debug!("Starting Relay server");
                let mut headers = HeaderMap::new();
                for (name, value) in TLS_HEADERS.iter() {
                    headers.insert(
                        *name,
                        value
                            .parse()
                            .map_err(|err| e!(SpawnError::TlsHeaderParse, err))?,
                    );
                }
                // With TLS the relay binds to the HTTPS address, otherwise to the HTTP one.
                let relay_bind_addr = match relay_config.tls {
                    Some(ref tls) => tls.https_bind_addr,
                    None => relay_config.http_bind_addr,
                };
                let key_cache_capacity = relay_config
                    .key_cache_capacity
                    .unwrap_or(DEFAULT_KEY_CACHE_CAPACITY);
                let mut builder = http_server::ServerBuilder::new(relay_bind_addr)
                    .metrics(metrics.server.clone())
                    .headers(headers)
                    .key_cache_capacity(key_cache_capacity)
                    .access(relay_config.access)
                    .request_handler(Method::GET, "/", Box::new(root_handler))
                    .request_handler(Method::GET, "/index.html", Box::new(root_handler))
                    .request_handler(Method::GET, RELAY_PROBE_PATH, Box::new(probe_handler))
                    .request_handler(Method::GET, "/robots.txt", Box::new(robots_handler))
                    .request_handler(Method::GET, "/healthz", Box::new(healthz_handler));
                if let Some(cfg) = relay_config.limits.client_rx {
                    builder = builder.client_rx_ratelimit(cfg);
                }
                let http_addr = match relay_config.tls {
                    Some(tls_config) => {
                        let server_tls_config = match tls_config.cert {
                            CertConfig::LetsEncrypt { mut state } => {
                                let acceptor =
                                    http_server::TlsAcceptor::LetsEncrypt(state.acceptor());
                                // Drive the ACME state by polling its event stream.  The
                                // stream ending is reported as an error to the supervisor.
                                tasks.spawn(
                                    async move {
                                        while let Some(event) = state.next().await {
                                            match event {
                                                Ok(ok) => debug!("acme event: {ok:?}"),
                                                Err(err) => error!("error: {err:?}"),
                                            }
                                        }
                                        Err(e!(SupervisorError::AcmeEventStreamFinished))
                                    }
                                    .instrument(info_span!("acme")),
                                );
                                Some(http_server::TlsConfig {
                                    config: Arc::new(tls_config.server_config),
                                    acceptor,
                                })
                            }
                            // Both variants build the acceptor straight from the provided
                            // rustls server configuration.
                            CertConfig::Manual { .. } | CertConfig::Reloading => {
                                let server_config = Arc::new(tls_config.server_config);
                                let acceptor =
                                    tokio_rustls::TlsAcceptor::from(server_config.clone());
                                let acceptor = http_server::TlsAcceptor::Manual(acceptor);
                                Some(http_server::TlsConfig {
                                    config: server_config,
                                    acceptor,
                                })
                            }
                        };
                        builder = builder.tls_config(server_tls_config);

                        // Some services always need to be served over HTTP without TLS.  Run
                        // these standalone.
                        let http_listener = TcpListener::bind(&relay_config.http_bind_addr)
                            .await
                            .map_err(|err| e!(SpawnError::BindTlsListener, err))?;
                        let http_addr = http_listener
                            .local_addr()
                            .map_err(|err| e!(SpawnError::NoLocalAddr, err))?;
                        tasks.spawn(
                            async move {
                                run_captive_portal_service(http_listener).await;
                                Ok(())
                            }
                            .instrument(info_span!("http-service", addr = %http_addr)),
                        );
                        Some(http_addr)
                    }
                    None => {
                        // If running Relay without TLS add the plain HTTP server directly
                        // to the Relay server.
                        builder = builder.request_handler(
                            Method::GET,
                            "/generate_204",
                            Box::new(serve_no_content_handler),
                        );
                        None
                    }
                };
                let relay_server = builder.spawn().await?;
                (Some(relay_server), http_addr)
            }
            None => (None, None),
        };
        // If http_addr is Some then relay_server is serving HTTPS.  If http_addr is None
        // relay_server is serving HTTP, including the /generate_204 service.
        let relay_addr = relay_server.as_ref().map(|srv| srv.addr());
        let relay_handle = relay_server.as_ref().map(|srv| srv.handle());
        let task = tokio::spawn(relay_supervisor(tasks, relay_server, quic_server));

        Ok(Self {
            // With TLS: the standalone HTTP listener.  Without TLS: the relay server.
            http_addr: http_addr.or(relay_addr),
            // Only set when the standalone HTTP listener exists, i.e. TLS is enabled.
            https_addr: http_addr.and(relay_addr),
            quic_addr,
            relay_handle,
            quic_handle,
            supervisor: AbortOnDropHandle::new(task),
            certificates,
            metrics,
        })
    }
475
476    /// Requests graceful shutdown.
477    ///
478    /// Returns once all server tasks have stopped.
479    pub async fn shutdown(self) -> Result<(), SupervisorError> {
480        // Only the Relay server and QUIC server need shutting down, the supervisor will abort the tasks in
481        // the JoinSet when the server terminates.
482        if let Some(handle) = self.relay_handle {
483            handle.shutdown();
484        }
485        if let Some(handle) = self.quic_handle {
486            handle.shutdown();
487        }
488        self.supervisor.await?
489    }
490
    /// Returns the handle for the task.
    ///
    /// This allows waiting for the server's supervisor task to finish.  Can be useful in
    /// case there is an error in the server before it is shut down.
    ///
    /// The handle is borrowed mutably, so it can be awaited in place without giving up
    /// ownership of the [`Server`].
    pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<Result<(), SupervisorError>> {
        &mut self.supervisor
    }
498
    /// The socket address the HTTPS server is listening on.
    ///
    /// `None` if the relay server is not configured or runs without TLS.
    pub fn https_addr(&self) -> Option<SocketAddr> {
        self.https_addr
    }
503
    /// The socket address the HTTP server is listening on.
    ///
    /// With TLS enabled this is the standalone plain HTTP listener, without TLS it is
    /// the relay server itself.  `None` if the relay server is not configured.
    pub fn http_addr(&self) -> Option<SocketAddr> {
        self.http_addr
    }
508
    /// The socket address the QUIC server is listening on.
    ///
    /// `None` if the QUIC server is not configured.
    pub fn quic_addr(&self) -> Option<SocketAddr> {
        self.quic_addr
    }
513
    /// The certificates chain if configured with manual TLS certificates.
    ///
    /// Returns a clone of the stored chain, `None` for any other certificate mode.
    pub fn certificates(&self) -> Option<Vec<rustls::pki_types::CertificateDer<'static>>> {
        self.certificates.clone()
    }
518
519    /// Get the server's https [`RelayUrl`].
520    ///
521    /// This uses [`Self::https_addr`] so it's mostly useful for local development.
522    #[cfg(feature = "test-utils")]
523    pub fn https_url(&self) -> Option<RelayUrl> {
524        self.https_addr.map(|addr| {
525            url::Url::parse(&format!("https://{addr}"))
526                .expect("valid url")
527                .into()
528        })
529    }
530
531    /// Get the server's http [`RelayUrl`].
532    ///
533    /// This uses [`Self::http_addr`] so it's mostly useful for local development.
534    #[cfg(feature = "test-utils")]
535    pub fn http_url(&self) -> Option<RelayUrl> {
536        self.http_addr.map(|addr| {
537            url::Url::parse(&format!("http://{addr}"))
538                .expect("valid url")
539                .into()
540        })
541    }
542
    /// Returns the metrics collected in the relay server.
    ///
    /// These are the same metrics that get registered with the metrics server, if one
    /// is configured.
    pub fn metrics(&self) -> &RelayMetrics {
        &self.metrics
    }
547}
548
/// Supervisor for the relay server tasks.
///
/// As soon as one of the tasks exits, all other tasks are stopped and the server stops.
/// The supervisor finishes once all tasks are finished.
///
/// Returns the result of the first task that exited.  A panic in a supervised task is
/// resumed on the supervisor, any other join failure is reported as
/// [`SupervisorError::TaskCancelled`].  If there is nothing at all to supervise the
/// result is [`SupervisorError::NoRelayServicesEnabled`].
#[instrument(skip_all)]
async fn relay_supervisor(
    mut tasks: JoinSet<Result<(), SupervisorError>>,
    mut relay_http_server: Option<http_server::Server>,
    mut quic_server: Option<QuicServer>,
) -> Result<(), SupervisorError> {
    // A missing server is replaced by a future that never completes, so both select
    // arms below always have something to poll.
    let quic_enabled = quic_server.is_some();
    let mut quic_fut = match quic_server {
        Some(ref mut server) => n0_future::Either::Left(server.task_handle()),
        None => n0_future::Either::Right(n0_future::future::pending()),
    };
    let relay_enabled = relay_http_server.is_some();
    let mut relay_fut = match relay_http_server {
        Some(ref mut server) => n0_future::Either::Left(server.task_handle()),
        None => n0_future::Either::Right(n0_future::future::pending()),
    };
    // Wait for the first task or server to finish.  The `else` arm only runs when every
    // other arm is disabled: the JoinSet is empty and neither server is enabled.
    let res = tokio::select! {
        biased;
        Some(ret) = tasks.join_next() => ret,
        ret = &mut quic_fut, if quic_enabled => ret.map(Ok),
        ret = &mut relay_fut, if relay_enabled => ret.map(Ok),
        else => Ok(Err(e!(SupervisorError::NoRelayServicesEnabled))),
    };
    let ret = match res {
        Ok(Ok(())) => {
            debug!("Task exited");
            Ok(())
        }
        Ok(Err(err)) => {
            error!(%err, "Task failed");
            Err(err)
        }
        Err(err) => {
            // Propagate panics of supervised tasks instead of swallowing them.
            if let Ok(panic) = err.try_into_panic() {
                error!("Task panicked");
                std::panic::resume_unwind(panic);
            }
            debug!("Task cancelled");
            Err(e!(SupervisorError::TaskCancelled))
        }
    };

    // Ensure the HTTP server terminated, there is no harm in calling this after it is
    // already shut down.
    if let Some(server) = relay_http_server {
        server.shutdown();
    }

    // Ensure the QUIC server is closed
    if let Some(server) = quic_server {
        server.shutdown().await;
    }

    // Stop all remaining tasks
    tasks.shutdown().await;

    ret
}
611
612fn root_handler(
613    _r: Request<Incoming>,
614    response: ResponseBuilder,
615) -> HyperResult<Response<BytesBody>> {
616    response
617        .status(StatusCode::OK)
618        .header("Content-Type", "text/html; charset=utf-8")
619        .body(INDEX.into())
620        .map_err(|err| Box::new(err) as HyperError)
621}
622
623/// HTTP latency queries
624fn probe_handler(
625    _r: Request<Incoming>,
626    response: ResponseBuilder,
627) -> HyperResult<Response<BytesBody>> {
628    response
629        .status(StatusCode::OK)
630        .header("Access-Control-Allow-Origin", "*")
631        .body(body_empty())
632        .map_err(|err| Box::new(err) as HyperError)
633}
634
635fn robots_handler(
636    _r: Request<Incoming>,
637    response: ResponseBuilder,
638) -> HyperResult<Response<BytesBody>> {
639    response
640        .status(StatusCode::OK)
641        .body(ROBOTS_TXT.into())
642        .map_err(|err| Box::new(err) as HyperError)
643}
644
645/// For captive portal detection.
646fn serve_no_content_handler<B: hyper::body::Body>(
647    r: Request<B>,
648    mut response: ResponseBuilder,
649) -> HyperResult<Response<BytesBody>> {
650    let check = |c: &HeaderValue| {
651        !c.is_empty() && c.len() < 64 && c.as_bytes().iter().all(|c| is_challenge_char(*c as char))
652    };
653
654    if let Some(challenge) = r.headers().get(NO_CONTENT_CHALLENGE_HEADER)
655        && check(challenge)
656    {
657        response = response.header(
658            NO_CONTENT_RESPONSE_HEADER,
659            format!("response {}", challenge.to_str()?),
660        );
661    }
662
663    response
664        .status(StatusCode::NO_CONTENT)
665        .body(body_empty())
666        .map_err(|err| Box::new(err) as HyperError)
667}
668
/// Reports whether `c` may appear in a captive portal challenge.
///
/// Accepted are ASCII letters, ASCII digits and the characters `.`, `-` and `_`.
fn is_challenge_char(c: char) -> bool {
    // Semi-randomly chosen as a limited set of valid characters
    c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_')
}
678
/// Health check response
///
/// Serialized to JSON as the body of the `/healthz` response.
#[derive(Serialize)]
struct Health {
    /// Health status, the handler always reports `"ok"`.
    status: &'static str,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    version: &'static str,
    /// Git hash from `VERGEN_GIT_SHA` at compile time, `"unknown"` if it was not set.
    git_hash: &'static str,
}
686
687fn healthz_handler(
688    _r: Request<Incoming>,
689    response: ResponseBuilder,
690) -> HyperResult<Response<BytesBody>> {
691    let health = Health {
692        status: "ok",
693        version: env!("CARGO_PKG_VERSION"),
694        git_hash: option_env!("VERGEN_GIT_SHA").unwrap_or("unknown"),
695    };
696    let body = serde_json::to_string(&health).unwrap_or_else(|_| r#"{"status":"error"}"#.into());
697    response
698        .status(StatusCode::OK)
699        .header("Content-Type", "application/json")
700        .body(body.into())
701        .map_err(|err| Box::new(err) as HyperError)
702}
703
/// Serves the captive portal check over plain HTTP on `http_listener`.
///
/// Every accepted connection is served by [`CaptivePortalService`] in its own task.
/// Accept errors are logged and the loop continues.
///
/// This is a future that never returns, drop it to cancel/abort.
///
/// # Panics
///
/// Panics if one of the connection tasks panicked.
async fn run_captive_portal_service(http_listener: TcpListener) {
    info!("serving");

    // If this future is cancelled, this is dropped and all tasks are aborted.
    let mut tasks = JoinSet::new();

    loop {
        tokio::select! {
            biased;

            // Reap finished connection tasks first so the JoinSet does not grow
            // unboundedly.  The arm is disabled while the set is empty.
            Some(res) = tasks.join_next() => {
                if let Err(err) = res
                    && err.is_panic()
                {
                    panic!("task panicked: {err:#?}");
                }
            }

            res = http_listener.accept() => {
                match res {
                    Ok((stream, peer_addr)) => {
                        debug!(%peer_addr, "Connection opened",);
                        let handler = CaptivePortalService;

                        tasks.spawn(async move {
                            // This listener never uses TLS, so the stream is always plain.
                            let stream = crate::server::streams::MaybeTlsStream::Plain(stream);
                            let stream = hyper_util::rt::TokioIo::new(stream);
                            if let Err(err) = hyper::server::conn::http1::Builder::new()
                                .serve_connection(stream, handler)
                                .with_upgrades()
                                .await
                            {
                                error!("Failed to serve connection: {err:?}");
                            }
                        });
                    }
                    Err(err) => {
                        error!(
                            "[CaptivePortalService] failed to accept connection: {:#?}",
                            err
                        );
                    }
                }
            }
        }
    }
}
752
753#[derive(Clone)]
754struct CaptivePortalService;
755
756impl hyper::service::Service<Request<Incoming>> for CaptivePortalService {
757    type Response = Response<BytesBody>;
758    type Error = HyperError;
759    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
760
761    fn call(&self, req: Request<Incoming>) -> Self::Future {
762        match (req.method(), req.uri().path()) {
763            // Captive Portal checker
764            (&Method::GET, "/generate_204") => {
765                Box::pin(async move { serve_no_content_handler(req, Response::builder()) })
766            }
767            _ => {
768                // Return 404 not found response.
769                let r = Response::builder()
770                    .status(StatusCode::NOT_FOUND)
771                    .body(NOTFOUND.into())
772                    .map_err(|err| Box::new(err) as HyperError);
773                Box::pin(async move { r })
774            }
775        }
776    }
777}
778
779#[cfg(test)]
780mod tests {
781    use std::{net::Ipv4Addr, time::Duration};
782
783    use http::StatusCode;
784    use iroh_base::{EndpointId, RelayUrl, SecretKey};
785    use n0_error::Result;
786    use n0_future::{FutureExt, SinkExt, StreamExt};
787    use n0_tracing_test::traced_test;
788    use rand::SeedableRng;
789    use tracing::{info, instrument};
790
791    use super::{
792        Access, AccessConfig, NO_CONTENT_CHALLENGE_HEADER, NO_CONTENT_RESPONSE_HEADER, RelayConfig,
793        Server, ServerConfig, SpawnError,
794    };
795    use crate::{
796        client::{ClientBuilder, ConnectError},
797        dns::DnsResolver,
798        protos::{
799            handshake,
800            relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg},
801        },
802    };
803
    /// Spawns a relay server for tests.
    ///
    /// The server runs plain HTTP without TLS on a localhost port picked by the OS,
    /// allows every endpoint and has neither the QUIC nor the metrics server enabled.
    async fn spawn_local_relay() -> std::result::Result<Server, SpawnError> {
        Server::spawn(ServerConfig::<(), ()> {
            relay: Some(RelayConfig::<(), ()> {
                // Port 0 lets the OS pick a free port.
                http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
                tls: None,
                limits: Default::default(),
                key_cache_capacity: Some(1024),
                access: AccessConfig::Everyone,
            }),
            quic: None,
            metrics_addr: None,
        })
        .await
    }
818
819    #[instrument]
820    async fn try_send_recv(
821        client_a: &mut crate::client::Client,
822        client_b: &mut crate::client::Client,
823        b_key: EndpointId,
824        msg: Datagrams,
825    ) -> Result<RelayToClientMsg> {
826        // try resend 10 times
827        for _ in 0..10 {
828            client_a
829                .send(ClientToRelayMsg::Datagrams {
830                    dst_endpoint_id: b_key,
831                    datagrams: msg.clone(),
832                })
833                .await?;
834            let Ok(res) = tokio::time::timeout(Duration::from_millis(500), client_b.next()).await
835            else {
836                continue;
837            };
838            let res = res.expect("stream finished")?;
839            return Ok(res);
840        }
841        panic!("failed to send and recv message");
842    }
843
844    fn dns_resolver() -> DnsResolver {
845        DnsResolver::new()
846    }
847
848    #[tokio::test]
849    #[traced_test]
850    async fn test_no_services() {
851        let mut server = Server::spawn(ServerConfig::<(), ()>::default())
852            .await
853            .unwrap();
854        let res = tokio::time::timeout(Duration::from_secs(5), server.task_handle())
855            .await
856            .expect("timeout, server not finished")
857            .expect("server task JoinError");
858        assert!(res.is_err());
859    }
860
861    #[tokio::test]
862    #[traced_test]
863    async fn test_conflicting_bind() {
864        let mut server = Server::spawn(ServerConfig::<(), ()> {
865            relay: Some(RelayConfig {
866                http_bind_addr: (Ipv4Addr::LOCALHOST, 1234).into(),
867                tls: None,
868                limits: Default::default(),
869                key_cache_capacity: Some(1024),
870                access: AccessConfig::Everyone,
871            }),
872            quic: None,
873            metrics_addr: Some((Ipv4Addr::LOCALHOST, 1234).into()),
874        })
875        .await
876        .unwrap();
877        let res = tokio::time::timeout(Duration::from_secs(5), server.task_handle())
878            .await
879            .expect("timeout, server not finished")
880            .expect("server task JoinError");
881        assert!(res.is_err()); // AddrInUse
882    }
883
884    #[tokio::test]
885    #[traced_test]
886    async fn test_root_handler() {
887        let server = spawn_local_relay().await.unwrap();
888        let url = format!("http://{}", server.http_addr().unwrap());
889
890        let client = reqwest::Client::builder().use_rustls_tls().build().unwrap();
891        let response = client.get(&url).send().await.unwrap();
892        assert_eq!(response.status(), 200);
893        let body = response.text().await.unwrap();
894        assert!(body.contains("iroh.computer"));
895    }
896
897    #[tokio::test]
898    #[traced_test]
899    async fn test_captive_portal_service() {
900        let server = spawn_local_relay().await.unwrap();
901        let url = format!("http://{}/generate_204", server.http_addr().unwrap());
902        let challenge = "123az__.";
903
904        let client = reqwest::Client::builder().use_rustls_tls().build().unwrap();
905        let response = client
906            .get(&url)
907            .header(NO_CONTENT_CHALLENGE_HEADER, challenge)
908            .send()
909            .await
910            .unwrap();
911        assert_eq!(response.status(), StatusCode::NO_CONTENT);
912        let header = response.headers().get(NO_CONTENT_RESPONSE_HEADER).unwrap();
913        assert_eq!(header.to_str().unwrap(), format!("response {challenge}"));
914        let body = response.text().await.unwrap();
915        assert!(body.is_empty());
916    }
917
918    #[tokio::test]
919    #[traced_test]
920    async fn test_relay_clients() -> Result<()> {
921        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
922        let server = spawn_local_relay().await?;
923
924        let relay_url = format!("http://{}", server.http_addr().unwrap());
925        let relay_url: RelayUrl = relay_url.parse()?;
926
927        // set up client a
928        let a_secret_key = SecretKey::generate(&mut rng);
929        let a_key = a_secret_key.public();
930        let resolver = dns_resolver();
931        info!("client a build & connect");
932        let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver.clone())
933            .connect()
934            .await?;
935
936        // set up client b
937        let b_secret_key = SecretKey::generate(&mut rng);
938        let b_key = b_secret_key.public();
939        info!("client b build & connect");
940        let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver.clone())
941            .connect()
942            .await?;
943
944        info!("sending a -> b");
945
946        // send message from a to b
947        let msg = Datagrams::from("hello, b");
948        let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?;
949        let RelayToClientMsg::Datagrams {
950            remote_endpoint_id,
951            datagrams,
952        } = res
953        else {
954            panic!("client_b received unexpected message {res:?}");
955        };
956
957        assert_eq!(a_key, remote_endpoint_id);
958        assert_eq!(msg, datagrams);
959
960        info!("sending b -> a");
961        // send message from b to a
962        let msg = Datagrams::from("howdy, a");
963        let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?;
964
965        let RelayToClientMsg::Datagrams {
966            remote_endpoint_id,
967            datagrams,
968        } = res
969        else {
970            panic!("client_a received unexpected message {res:?}");
971        };
972
973        assert_eq!(b_key, remote_endpoint_id);
974        assert_eq!(msg, datagrams);
975
976        Ok(())
977    }
978
979    #[tokio::test]
980    #[traced_test]
981    async fn test_relay_access_control() -> Result<()> {
982        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
983        let current_span = tracing::info_span!("this is a test");
984        let _guard = current_span.enter();
985
986        let a_secret_key = SecretKey::generate(&mut rng);
987        let a_key = a_secret_key.public();
988
989        let server = Server::spawn(ServerConfig::<(), ()> {
990            relay: Some(RelayConfig::<(), ()> {
991                http_bind_addr: (Ipv4Addr::LOCALHOST, 0).into(),
992                tls: None,
993                limits: Default::default(),
994                key_cache_capacity: Some(1024),
995                access: AccessConfig::Restricted(Box::new(move |endpoint_id| {
996                    async move {
997                        info!("checking {}", endpoint_id);
998                        // reject endpoint a
999                        if endpoint_id == a_key {
1000                            Access::Deny
1001                        } else {
1002                            Access::Allow
1003                        }
1004                    }
1005                    .boxed()
1006                })),
1007            }),
1008            quic: None,
1009            metrics_addr: None,
1010        })
1011        .await?;
1012
1013        let relay_url = format!("http://{}", server.http_addr().unwrap());
1014        let relay_url: RelayUrl = relay_url.parse()?;
1015
1016        // set up client a
1017        let resolver = dns_resolver();
1018        let result = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver)
1019            .connect()
1020            .await;
1021
1022        assert!(
1023            matches!(result, Err(ConnectError::Handshake { source: handshake::Error::ServerDeniedAuth { reason, .. }, .. }) if reason == "not authorized")
1024        );
1025
1026        // test that another client has access
1027
1028        // set up client b
1029        let b_secret_key = SecretKey::generate(&mut rng);
1030        let b_key = b_secret_key.public();
1031
1032        let resolver = dns_resolver();
1033        let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver)
1034            .connect()
1035            .await?;
1036
1037        // set up client c
1038        let c_secret_key = SecretKey::generate(&mut rng);
1039        let c_key = c_secret_key.public();
1040
1041        let resolver = dns_resolver();
1042        let mut client_c = ClientBuilder::new(relay_url.clone(), c_secret_key, resolver)
1043            .connect()
1044            .await?;
1045
1046        // send message from b to c
1047        let msg = Datagrams::from("hello, c");
1048        let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?;
1049
1050        if let RelayToClientMsg::Datagrams {
1051            remote_endpoint_id,
1052            datagrams,
1053        } = res
1054        {
1055            assert_eq!(b_key, remote_endpoint_id);
1056            assert_eq!(msg, datagrams);
1057        } else {
1058            panic!("client_c received unexpected message {res:?}");
1059        }
1060
1061        Ok(())
1062    }
1063
1064    #[tokio::test]
1065    #[traced_test]
1066    async fn test_relay_clients_full() -> Result<()> {
1067        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
1068        let server = spawn_local_relay().await.unwrap();
1069        let relay_url = format!("http://{}", server.http_addr().unwrap());
1070        let relay_url: RelayUrl = relay_url.parse().unwrap();
1071
1072        // set up client a
1073        let a_secret_key = SecretKey::generate(&mut rng);
1074        let resolver = dns_resolver();
1075        let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver.clone())
1076            .connect()
1077            .await?;
1078
1079        // set up client b
1080        let b_secret_key = SecretKey::generate(&mut rng);
1081        let b_key = b_secret_key.public();
1082        let _client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver.clone())
1083            .connect()
1084            .await?;
1085
1086        // send messages from a to b, without b receiving anything.
1087        // we should still keep succeeding to send, even if the packet won't be forwarded
1088        // by the relay server because the server's send queue for b fills up.
1089        let msg = Datagrams::from("hello, b");
1090        for _i in 0..1000 {
1091            client_a
1092                .send(ClientToRelayMsg::Datagrams {
1093                    dst_endpoint_id: b_key,
1094                    datagrams: msg.clone(),
1095                })
1096                .await?;
1097        }
1098        Ok(())
1099    }
1100}