1use std::{
10 collections::HashMap, future::Future, net::SocketAddr, pin::Pin, sync::Arc, time::Duration,
11};
12
13use bytes::Bytes;
14use derive_more::Debug;
15use http::{
16 header::{CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION},
17 response::Builder as ResponseBuilder,
18};
19use hyper::{
20 HeaderMap, Method, Request, Response, StatusCode,
21 body::Incoming,
22 header::{HeaderValue, SEC_WEBSOCKET_ACCEPT, UPGRADE},
23 service::Service,
24 upgrade::Upgraded,
25};
26use n0_error::{e, ensure, stack_error};
27use n0_future::time::Elapsed;
28use tokio::net::{TcpListener, TcpStream};
29use tokio_rustls_acme::AcmeAcceptor;
30use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
31use tracing::{Instrument, debug, error, info, info_span, trace, warn, warn_span};
32
33use super::{AccessConfig, SpawnError, clients::Clients, streams::InvalidBucketConfig};
34use crate::{
35 KeyCache,
36 defaults::{DEFAULT_KEY_CACHE_CAPACITY, timeouts::SERVER_WRITE_TIMEOUT},
37 http::{
38 CLIENT_AUTH_HEADER, RELAY_PATH, RELAY_PROTOCOL_VERSION, SUPPORTED_WEBSOCKET_VERSION,
39 WEBSOCKET_UPGRADE_PROTOCOL,
40 },
41 protos::{
42 handshake,
43 relay::{MAX_FRAME_SIZE, PER_CLIENT_SEND_QUEUE_DEPTH},
44 streams::WsBytesFramed,
45 },
46 server::{
47 ClientRateLimit,
48 client::Config,
49 metrics::Metrics,
50 streams::{MaybeTlsStream, RateLimited, RelayedStream},
51 },
52};
53
/// Response body used by the server: a single, fully buffered chunk of bytes.
type BytesBody = http_body_util::Full<hyper::body::Bytes>;
/// Boxed error type returned from the hyper [`Service`] implementation.
type HyperError = Box<dyn std::error::Error + Send + Sync>;
/// Result alias whose error is [`HyperError`].
type HyperResult<T> = std::result::Result<T, HyperError>;
/// A custom request handler registered for a `(method, path)` pair.
///
/// It receives the request and a response builder that the service has already
/// populated with the server's default headers.
type HyperHandler = Box<
    dyn Fn(Request<Incoming>, ResponseBuilder) -> HyperResult<Response<BytesBody>>
        + Send
        + Sync
        + 'static,
>;

/// GUID that RFC 6455 appends to `Sec-WebSocket-Key` before hashing to build
/// `Sec-WebSocket-Accept`.
const SEC_WEBSOCKET_ACCEPT_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
66
67fn derive_accept_key(client_key: &HeaderValue) -> String {
70 use sha1::Digest;
71
72 let mut sha1 = sha1::Sha1::new();
73 sha1.update(client_key.as_bytes());
74 sha1.update(SEC_WEBSOCKET_ACCEPT_GUID);
75 data_encoding::BASE64.encode(&sha1.finalize())
76}
77
78fn body_full(content: impl Into<hyper::body::Bytes>) -> BytesBody {
80 http_body_util::Full::new(content.into())
81}
82
83#[allow(clippy::result_large_err)]
84fn downcast_upgrade(upgraded: Upgraded) -> Result<(MaybeTlsStream, Bytes), ConnectionHandlerError> {
85 match upgraded.downcast::<hyper_util::rt::TokioIo<MaybeTlsStream>>() {
86 Ok(parts) => Ok((parts.io.into_inner(), parts.read_buf)),
87 Err(_) => Err(e!(ConnectionHandlerError::DowncastUpgrade)),
88 }
89}
90
/// A running relay HTTP(S) server, created by [`ServerBuilder::spawn`].
///
/// The accept loop runs in a task held by an `AbortOnDropHandle`, so dropping this value
/// aborts it. For a graceful stop, cancel the loop via [`Server::shutdown`] or a
/// [`ServerHandle`].
#[derive(Debug)]
pub(super) struct Server {
    /// Local address the listener is actually bound to (resolved after binding).
    addr: SocketAddr,
    /// Task running the accept loop.
    http_server_task: AbortOnDropHandle<()>,
    /// Cancelling this token makes the accept loop exit and shut down its clients.
    cancel_server_loop: CancellationToken,
}
104
105impl Server {
106 pub(super) fn handle(&self) -> ServerHandle {
111 ServerHandle {
112 cancel_token: self.cancel_server_loop.clone(),
113 }
114 }
115
116 pub(super) fn shutdown(&self) {
118 self.cancel_server_loop.cancel();
119 }
120
121 pub(super) fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
127 &mut self.http_server_task
128 }
129
130 pub(super) fn addr(&self) -> SocketAddr {
132 self.addr
133 }
134}
135
/// Cloneable handle for requesting shutdown of a [`Server`]'s accept loop.
#[derive(Debug, Clone)]
pub(super) struct ServerHandle {
    /// Shares the server's cancellation token.
    cancel_token: CancellationToken,
}
143
144impl ServerHandle {
145 pub(super) fn shutdown(&self) {
147 self.cancel_token.cancel()
148 }
149}
150
/// TLS configuration for the relay server.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// The rustls server configuration.
    ///
    /// It is also passed to the ACME acceptor when completing handshakes.
    pub(super) config: Arc<rustls::ServerConfig>,
    /// How incoming TLS connections are accepted.
    pub(super) acceptor: TlsAcceptor,
}
194
195impl TlsConfig {
196 pub fn new(config: Arc<rustls::ServerConfig>) -> Self {
231 let acceptor = tokio_rustls::TlsAcceptor::from(config.clone());
232 Self {
233 config,
234 acceptor: TlsAcceptor::Manual(acceptor),
235 }
236 }
237}
238
/// Errors that can end the serving of a single incoming TCP connection.
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum ServeConnectionError {
    /// Completing the TLS handshake on an ACME-accepted connection failed.
    #[error("TLS[acme] handshake")]
    TlsHandshake {
        #[error(std_err)]
        source: std::io::Error,
    },
    /// hyper failed while serving HTTP over a TLS connection.
    #[error("TLS[acme] serve connection")]
    ServeConnection {
        #[error(std_err)]
        source: hyper::Error,
    },
    /// The manual TLS accept did not complete within its timeout.
    #[error("TLS[manual] timeout")]
    Timeout {
        #[error(std_err)]
        source: Elapsed,
    },
    /// The manual (fixed-config) TLS acceptor failed.
    #[error("TLS[manual] accept")]
    ManualAccept {
        #[error(std_err)]
        source: std::io::Error,
    },
    /// The ACME (Let's Encrypt) TLS acceptor failed.
    #[error("TLS[acme] accept")]
    LetsEncryptAccept {
        #[error(std_err)]
        source: std::io::Error,
    },
    /// hyper failed while serving HTTP over a TLS connection.
    #[error("HTTPS connection")]
    Https {
        #[error(std_err)]
        source: hyper::Error,
    },
    /// hyper failed while serving a plain-text HTTP connection.
    #[error("HTTP connection")]
    Http {
        #[error(std_err)]
        source: hyper::Error,
    },
}
280
/// Errors from accepting a relay client on an already-upgraded connection.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum AcceptError {
    /// The relay handshake (authentication or authorization exchange) failed.
    #[error(transparent)]
    Handshake { source: handshake::Error },
    /// The configured client rate limit could not be turned into a valid token bucket.
    #[error("rate limiting misconfigured")]
    RateLimitingMisconfigured { source: InvalidBucketConfig },
}
291
/// Errors from handling a connection after its websocket upgrade completed.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum ConnectionHandlerError {
    /// Accepting the relay client failed.
    #[error(transparent)]
    Accept { source: AcceptError },
    /// hyper's upgraded IO was not the expected `TokioIo<MaybeTlsStream>`.
    #[error("Could not downcast the upgraded connection to MaybeTlsStream")]
    DowncastUpgrade {},
    /// hyper had already buffered bytes past the HTTP request.
    ///
    /// Those bytes cannot currently be fed back into the websocket stream.
    #[error("Cannot deal with buffered data yet: {buf:?}")]
    BufferNotEmpty { buf: Bytes },
}
304
/// Builder for a relay HTTP(S) [`Server`].
#[derive(derive_more::Debug)]
pub(super) struct ServerBuilder {
    /// Address to bind the TCP listener to (port 0 lets the OS pick a port).
    addr: SocketAddr,
    /// TLS configuration; `None` serves plain HTTP/WS.
    tls_config: Option<TlsConfig>,
    /// Extra request handlers keyed by `(method, path)`.
    handlers: Handlers,
    /// Headers added to every response.
    headers: HeaderMap,
    /// Optional per-client receive rate limit.
    client_rx_ratelimit: Option<ClientRateLimit>,
    /// Capacity of the public-key cache used when decoding relayed frames.
    key_cache_capacity: usize,
    /// Which clients are allowed to connect.
    access: AccessConfig,
    /// Metrics sink; a default one is created at spawn time if unset.
    metrics: Option<Arc<Metrics>>,
}
335
impl ServerBuilder {
    /// Creates a builder for a server that will listen on `addr`.
    ///
    /// Defaults:
    /// - no TLS, no extra handlers or headers, and no client receive rate limit;
    /// - a key cache of `DEFAULT_KEY_CACHE_CAPACITY` entries;
    /// - `AccessConfig::Everyone`;
    /// - no metrics (a default `Metrics` is created in [`ServerBuilder::spawn`]).
    pub(super) fn new(addr: SocketAddr) -> Self {
        Self {
            addr,
            tls_config: None,
            handlers: Default::default(),
            headers: HeaderMap::new(),
            client_rx_ratelimit: None,
            key_cache_capacity: DEFAULT_KEY_CACHE_CAPACITY,
            access: AccessConfig::Everyone,
            metrics: None,
        }
    }

    /// Sets the metrics sink the server reports to.
    pub(super) fn metrics(mut self, metrics: Arc<Metrics>) -> Self {
        self.metrics = Some(metrics);
        self
    }

    /// Sets which clients are allowed to connect.
    pub(super) fn access(mut self, access: AccessConfig) -> Self {
        self.access = access;
        self
    }

    /// Sets (or clears, with `None`) the TLS configuration.
    pub(super) fn tls_config(mut self, config: Option<TlsConfig>) -> Self {
        self.tls_config = config;
        self
    }

    /// Sets the per-client receive rate limit.
    pub(super) fn client_rx_ratelimit(mut self, config: ClientRateLimit) -> Self {
        self.client_rx_ratelimit = Some(config);
        self
    }

    /// Registers `handler` for requests matching `method` and exactly `uri_path`.
    ///
    /// A second registration for the same pair replaces the first.
    pub(super) fn request_handler(
        mut self,
        method: Method,
        uri_path: &'static str,
        handler: HyperHandler,
    ) -> Self {
        self.handlers.insert((method, uri_path), handler);
        self
    }

    /// Adds `headers` to the set sent on every response.
    ///
    /// Uses `HeaderMap::insert`, so an existing header of the same name is replaced. When a
    /// name appears several times in `headers`, only its last value is kept.
    pub(super) fn headers(mut self, headers: HeaderMap) -> Self {
        for (k, v) in headers.iter() {
            self.headers.insert(k.clone(), v.clone());
        }
        self
    }

    /// Sets the capacity of the public-key cache.
    pub fn key_cache_capacity(mut self, capacity: usize) -> Self {
        self.key_cache_capacity = capacity;
        self
    }

    /// Binds the listener and spawns the accept loop.
    ///
    /// Each accepted connection is served in its own task inside a `JoinSet`. A panic in
    /// any connection task is re-raised in the accept loop. When the server's cancellation
    /// token fires, the loop exits, shuts down all relay clients and then all connection
    /// tasks.
    ///
    /// # Errors
    /// - `SpawnError::BindTcpListener` if binding `addr` fails.
    /// - `SpawnError::NoLocalAddr` if the bound address cannot be read back.
    pub(super) async fn spawn(self) -> Result<Server, SpawnError> {
        let cancel_token = CancellationToken::new();

        let service = RelayService::new(
            self.handlers,
            self.headers,
            self.client_rx_ratelimit,
            KeyCache::new(self.key_cache_capacity),
            self.access,
            self.metrics.unwrap_or_default(),
        );

        let addr = self.addr;
        let tls_config = self.tls_config;

        let listener = TcpListener::bind(&addr)
            .await
            .map_err(|err| e!(super::SpawnError::BindTcpListener { addr }, err))?;

        // Read the address back from the listener: with port 0 the OS picks the real port.
        let addr = listener
            .local_addr()
            .map_err(|err| e!(super::SpawnError::NoLocalAddr, err))?;
        let http_str = tls_config.as_ref().map_or("HTTP/WS", |_| "HTTPS/WSS");
        info!("[{http_str}] relay: serving on {addr}");

        let cancel = cancel_token.clone();
        let task = tokio::task::spawn(
            async move {
                // Per-connection tasks. They are also polled in the loop so finished tasks
                // are reaped and panics surface.
                let mut set = tokio::task::JoinSet::new();
                loop {
                    tokio::select! {
                        // Check cancellation first so shutdown wins over new work.
                        biased;
                        _ = cancel.cancelled() => {
                            break;
                        }
                        // Disabled while the set is empty (join_next yields None).
                        Some(res) = set.join_next() => {
                            if let Err(err) = res
                                && err.is_panic()
                            {
                                panic!("task panicked: {err:#?}");
                            }
                        }
                        res = listener.accept() => match res {
                            Ok((stream, peer_addr)) => {
                                debug!("connection opened from {peer_addr}");
                                let tls_config = tls_config.clone();
                                let service = service.clone();
                                set.spawn(async move {
                                    service
                                        .handle_connection(stream, tls_config)
                                        .await
                                }.instrument(info_span!("conn", peer = %peer_addr)));
                            }
                            Err(err) => {
                                error!("failed to accept connection: {err}");
                            }
                        }
                    }
                }
                // Stop relay clients first, then tear down any remaining connection tasks.
                service.shutdown().await;
                set.shutdown().await;
                debug!("server has been shutdown.");
            }
            .instrument(info_span!("relay-http-serve")),
        );

        Ok(Server {
            addr,
            http_server_task: AbortOnDropHandle::new(task),
            cancel_server_loop: cancel_token,
        })
    }
}
481
/// The hyper service for the relay server.
///
/// It routes relay websocket upgrades, registered custom handlers, and 404s. All state
/// lives behind an `Arc`, so clones are cheap and share state.
#[derive(Clone, Debug)]
pub struct RelayService(Arc<Inner>);
487
/// Shared state behind a [`RelayService`].
#[derive(Debug)]
struct Inner {
    /// Custom handlers keyed by `(method, path)`.
    handlers: Handlers,
    /// Headers added to every response.
    headers: HeaderMap,
    /// Registry of connected relay clients.
    clients: Clients,
    /// Write timeout handed to each client connection.
    write_timeout: Duration,
    /// Optional per-client receive rate limit.
    rate_limit: Option<ClientRateLimit>,
    /// Public-key cache shared by all relayed streams.
    key_cache: KeyCache,
    /// Which clients are allowed to connect.
    access: AccessConfig,
    /// Metrics sink.
    metrics: Arc<Metrics>,
}
499
/// Reasons a relay websocket upgrade request is rejected with `400 Bad Request`.
#[stack_error(derive, add_meta)]
enum RelayUpgradeReqError {
    /// A required header is absent.
    #[error("missing header: {header}")]
    MissingHeader { header: http::HeaderName },
    /// A header is present but its value is unacceptable.
    #[error("invalid header value for {header}: {details}")]
    InvalidHeader {
        header: http::HeaderName,
        details: String,
    },
    /// `Sec-WebSocket-Version` is not the supported version.
    ///
    /// The 400 response also echoes the supported version in that header.
    #[error(
        "invalid header value for {SEC_WEBSOCKET_VERSION}: unsupported websocket version, only supporting {SUPPORTED_WEBSOCKET_VERSION}"
    )]
    UnsupportedWebsocketVersion,
    /// `Sec-WebSocket-Protocol` does not offer our relay protocol version.
    #[error(
        "invalid header value for {SEC_WEBSOCKET_PROTOCOL}: unsupported relay version: we support {we_support} but you only provide {you_support}"
    )]
    UnsupportedRelayVersion {
        we_support: &'static str,
        you_support: String,
    },
}
521
522impl RelayService {
523 fn build_response(&self) -> http::response::Builder {
524 let mut res = Response::builder();
525 for (key, value) in self.0.headers.iter() {
526 res = res.header(key, value);
527 }
528 res
529 }
530
531 fn handle_relay_ws_upgrade(
533 &self,
534 mut req: Request<Incoming>,
535 ) -> Result<Response<BytesBody>, RelayUpgradeReqError> {
536 fn expect_header(
537 req: &Request<Incoming>,
538 header: http::HeaderName,
539 ) -> Result<&HeaderValue, RelayUpgradeReqError> {
540 req.headers()
541 .get(&header)
542 .ok_or_else(|| e!(RelayUpgradeReqError::MissingHeader { header }))
543 }
544
545 let upgrade_header = expect_header(&req, UPGRADE)?;
546 ensure!(
547 upgrade_header == HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL),
548 RelayUpgradeReqError::InvalidHeader {
549 header: UPGRADE,
550 details: format!("value must be {WEBSOCKET_UPGRADE_PROTOCOL}")
551 }
552 );
553
554 let key = expect_header(&req, SEC_WEBSOCKET_KEY)?.clone();
555 let version = expect_header(&req, SEC_WEBSOCKET_VERSION)?.clone();
556
557 ensure!(
558 version.as_bytes() == SUPPORTED_WEBSOCKET_VERSION.as_bytes(),
559 RelayUpgradeReqError::UnsupportedWebsocketVersion
560 );
561
562 let subprotocols = expect_header(&req, SEC_WEBSOCKET_PROTOCOL)?
563 .to_str()
564 .ok()
565 .ok_or_else(|| {
566 e!(RelayUpgradeReqError::InvalidHeader {
567 header: SEC_WEBSOCKET_PROTOCOL,
568 details: "header value is not ascii".to_string()
569 })
570 })?;
571 let supports_our_version = subprotocols
572 .split_whitespace()
573 .any(|p| p == RELAY_PROTOCOL_VERSION);
574 ensure!(
575 supports_our_version,
576 RelayUpgradeReqError::UnsupportedRelayVersion {
577 we_support: RELAY_PROTOCOL_VERSION,
578 you_support: subprotocols.to_string()
579 }
580 );
581
582 let client_auth_header = req.headers().get(CLIENT_AUTH_HEADER).cloned();
583
584 tokio::task::spawn({
592 let this = self.clone();
593 async move {
594 match hyper::upgrade::on(&mut req).await {
595 Ok(upgraded) => {
596 if let Err(err) = this
597 .0
598 .relay_connection_handler(upgraded, client_auth_header)
599 .await
600 {
601 warn!("error accepting upgraded connection: {err:#}",);
602 } else {
603 debug!("upgraded connection completed");
604 };
605 }
606 Err(err) => warn!("upgrade error: {err:#}"),
607 }
608 }
609 .instrument(warn_span!("handler"))
610 });
611
612 Ok(self
615 .build_response()
616 .status(StatusCode::SWITCHING_PROTOCOLS)
617 .header(
618 UPGRADE,
619 HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL),
620 )
621 .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key))
622 .header(
623 SEC_WEBSOCKET_PROTOCOL,
624 HeaderValue::from_static(RELAY_PROTOCOL_VERSION),
625 )
626 .header(CONNECTION, "upgrade")
627 .body(body_full("switching to websocket protocol"))
628 .expect("valid body"))
629 }
630}
631
632impl Service<Request<Incoming>> for RelayService {
633 type Response = Response<BytesBody>;
634 type Error = HyperError;
635 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
636
637 fn call(&self, req: Request<Incoming>) -> Self::Future {
638 if matches!(
640 (req.method(), req.uri().path()),
641 (&hyper::Method::GET, RELAY_PATH)
642 ) {
643 let res = match self.handle_relay_ws_upgrade(req) {
644 Ok(response) => Ok(response),
645 Err(e @ RelayUpgradeReqError::UnsupportedWebsocketVersion { .. }) => self
647 .build_response()
648 .status(StatusCode::BAD_REQUEST)
649 .header(SEC_WEBSOCKET_VERSION, SUPPORTED_WEBSOCKET_VERSION)
650 .body(body_full(e.to_string())),
651 Err(e) => self
652 .build_response()
653 .status(StatusCode::BAD_REQUEST)
654 .body(body_full(e.to_string())),
655 }
656 .map_err(Into::into);
657 return Box::pin(async move { res });
658 }
659 let uri = req.uri().clone();
663 if let Some(res) = self.0.handlers.get(&(req.method().clone(), uri.path())) {
664 let f = res(req, self.0.default_response());
665 return Box::pin(async move { f });
666 }
667 let res = self.0.not_found_fn(req, self.0.default_response());
669 Box::pin(async move { res })
670 }
671}
672
673impl Inner {
674 fn default_response(&self) -> ResponseBuilder {
675 let mut response = Response::builder();
676 for (key, value) in self.headers.iter() {
677 response = response.header(key.clone(), value.clone());
678 }
679 response
680 }
681
682 fn not_found_fn(
683 &self,
684 _req: Request<Incoming>,
685 mut res: ResponseBuilder,
686 ) -> HyperResult<Response<BytesBody>> {
687 for (k, v) in self.headers.iter() {
688 res = res.header(k.clone(), v.clone());
689 }
690 let body = body_full("Not Found");
691 let r = res.status(StatusCode::NOT_FOUND).body(body)?;
692 HyperResult::Ok(r)
693 }
694
695 async fn relay_connection_handler(
701 &self,
702 upgraded: Upgraded,
703 client_auth_header: Option<HeaderValue>,
704 ) -> Result<(), ConnectionHandlerError> {
705 debug!("relay_connection upgraded");
706 let (io, read_buf) = downcast_upgrade(upgraded)?;
707 if !read_buf.is_empty() {
708 return Err(e!(ConnectionHandlerError::BufferNotEmpty { buf: read_buf }));
709 }
710
711 self.accept(io, client_auth_header).await?;
712 Ok(())
713 }
714
715 async fn accept(
726 &self,
727 io: MaybeTlsStream,
728 client_auth_header: Option<HeaderValue>,
729 ) -> Result<(), AcceptError> {
730 trace!("accept: start");
731
732 io.disable_nagle();
734
735 let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone())
736 .map_err(|err| e!(AcceptError::RateLimitingMisconfigured, err))?;
737
738 let websocket = tokio_websockets::ServerBuilder::new()
740 .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE)))
741 .serve(io);
743
744 let mut io = WsBytesFramed { io: websocket };
745
746 let authentication = handshake::serverside(&mut io, client_auth_header).await?;
747
748 trace!(?authentication.mechanism, "accept: verified authentication");
749
750 let is_authorized = self.access.is_allowed(authentication.client_key).await;
751 let client_key = authentication.authorize_if(is_authorized, &mut io).await?;
752
753 trace!("accept: verified authorization");
754
755 let io = RelayedStream {
756 inner: io,
757 key_cache: self.key_cache.clone(),
758 };
759
760 trace!("accept: build client conn");
761 let client_conn_builder = Config {
762 endpoint_id: client_key,
763 stream: io,
764 write_timeout: self.write_timeout,
765 channel_capacity: PER_CLIENT_SEND_QUEUE_DEPTH,
766 };
767 trace!("accept: create client");
768 let endpoint_id = client_conn_builder.endpoint_id;
769 trace!(endpoint_id = %endpoint_id.fmt_short(), "create client");
770
771 self.clients
774 .register(client_conn_builder, self.metrics.clone());
775 Ok(())
776 }
777}
778
/// How incoming TLS connections are accepted.
#[derive(Clone, derive_more::Debug)]
pub(super) enum TlsAcceptor {
    /// Certificates obtained via ACME (Let's Encrypt).
    ///
    /// The acceptor may also receive TLS-ALPN-01 validation requests, which are answered
    /// without serving HTTP.
    LetsEncrypt(#[debug("tokio_rustls_acme::AcmeAcceptor")] AcmeAcceptor),
    /// A fixed rustls configuration supplied by the caller.
    Manual(#[debug("tokio_rustls::TlsAcceptor")] tokio_rustls::TlsAcceptor),
}
788
789impl RelayService {
790 pub fn new(
794 handlers: Handlers,
795 headers: HeaderMap,
796 rate_limit: Option<ClientRateLimit>,
797 key_cache: KeyCache,
798 access: AccessConfig,
799 metrics: Arc<Metrics>,
800 ) -> Self {
801 Self(Arc::new(Inner {
802 handlers,
803 headers,
804 clients: Clients::default(),
805 write_timeout: SERVER_WRITE_TIMEOUT,
806 rate_limit,
807 key_cache,
808 access,
809 metrics,
810 }))
811 }
812
813 pub async fn shutdown(&self) {
815 self.0.clients.shutdown().await;
816 }
817
818 pub async fn handle_connection(self, stream: TcpStream, tls_config: Option<TlsConfig>) {
872 let res = match tls_config {
873 Some(tls_config) => {
874 debug!("HTTPS: serve connection");
875 self.tls_serve_connection(stream, tls_config).await
876 }
877 None => {
878 debug!("HTTP: serve connection");
879 self.serve_connection(MaybeTlsStream::Plain(stream))
880 .await
881 .map_err(|err| e!(ServeConnectionError::Http, err))
882 }
883 };
884 match res {
885 Ok(()) => {}
886 Err(error) => match error {
887 ServeConnectionError::ManualAccept { source, .. }
888 | ServeConnectionError::LetsEncryptAccept { source, .. }
889 if source.kind() == std::io::ErrorKind::UnexpectedEof =>
890 {
891 debug!(reason=?source, "peer disconnected");
892 }
893 ServeConnectionError::Https { source, .. }
896 | ServeConnectionError::Http { source, .. }
897 if source.is_incomplete_message() =>
898 {
899 debug!(reason=?source, "peer disconnected");
900 }
901 _ => {
902 error!(?error, "failed to handle connection");
903 }
904 },
905 }
906 }
907
908 async fn tls_serve_connection(
910 self,
911 stream: TcpStream,
912 tls_config: TlsConfig,
913 ) -> Result<(), ServeConnectionError> {
914 let TlsConfig { acceptor, config } = tls_config;
915 match acceptor {
916 TlsAcceptor::LetsEncrypt(a) => {
917 match a
918 .accept(stream)
919 .await
920 .map_err(|err| e!(ServeConnectionError::LetsEncryptAccept, err))?
921 {
922 None => {
923 info!("TLS[acme]: received TLS-ALPN-01 validation request");
924 }
925 Some(start_handshake) => {
926 debug!("TLS[acme]: start handshake");
927 let tls_stream = start_handshake
928 .into_stream(config)
929 .await
930 .map_err(|err| e!(ServeConnectionError::TlsHandshake, err))?;
931 self.serve_connection(MaybeTlsStream::Tls(tls_stream))
932 .await
933 .map_err(|err| e!(ServeConnectionError::Https, err))?;
934 }
935 }
936 }
937 TlsAcceptor::Manual(a) => {
938 debug!("TLS[manual]: accept");
939 let tls_stream = tokio::time::timeout(Duration::from_secs(30), a.accept(stream))
940 .await
941 .map_err(|err| e!(ServeConnectionError::Timeout, err))?
942 .map_err(|err| e!(ServeConnectionError::ManualAccept, err))?;
943
944 self.serve_connection(MaybeTlsStream::Tls(tls_stream))
945 .await
946 .map_err(|err| e!(ServeConnectionError::ServeConnection, err))?;
947 }
948 }
949 Ok(())
950 }
951
952 async fn serve_connection<I>(self, io: I) -> Result<(), hyper::Error>
954 where
955 I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + Sync + 'static,
956 {
957 hyper::server::conn::http1::Builder::new()
958 .serve_connection(hyper_util::rt::TokioIo::new(io), self)
959 .with_upgrades()
960 .await
961 }
962}
963
/// Custom request handlers keyed by `(method, exact path)`.
///
/// Derefs to the inner `HashMap`.
#[derive(Default)]
pub struct Handlers(HashMap<(Method, &'static str), HyperHandler>);
967
968impl std::fmt::Debug for Handlers {
969 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
970 let s = self.0.keys().fold(String::new(), |curr, next| {
971 let (method, uri) = next;
972 format!("{curr}\n({method},{uri}): Box<Fn(ResponseBuilder) -> Result<Response<Body>> + Send + Sync + 'static>")
973 });
974 write!(f, "HashMap<{s}>")
975 }
976}
977
/// Exposes the underlying map for lookups such as `get`.
impl std::ops::Deref for Handlers {
    type Target = HashMap<(Method, &'static str), HyperHandler>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Exposes the underlying map for registration (`insert`).
impl std::ops::DerefMut for Handlers {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
991
/// Tests for the relay HTTP server.
///
/// Covers end-to-end runs over real HTTP and HTTPS listeners, plus in-memory tests that
/// drive `Inner::accept` directly over duplex pipes.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use iroh_base::{PublicKey, SecretKey};
    use n0_error::{Result, StdResultExt, bail_any};
    use n0_future::{SinkExt, StreamExt};
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;
    use reqwest::Url;
    use tracing::info;

    use super::*;
    use crate::{
        client::{Client, ClientBuilder, ConnectError, conn::Conn},
        dns::DnsResolver,
        protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg},
        tls::{CaRootsConfig, default_provider},
    };

    /// Builds a manual `TlsConfig` using a freshly generated self-signed certificate for
    /// `localhost`.
    pub(crate) fn make_tls_config() -> TlsConfig {
        let subject_alt_names = vec!["localhost".to_string()];

        let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap();
        let rustls_certificate = cert.cert.der().clone();
        let rustls_key =
            rustls::pki_types::PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let config = rustls::ServerConfig::builder_with_provider(Arc::new(
            rustls::crypto::ring::default_provider(),
        ))
        .with_safe_default_protocol_versions()
        .expect("protocols supported by ring")
        .with_no_client_auth()
        .with_single_cert(vec![(rustls_certificate)], rustls_key.into())
        .expect("cert is right");

        TlsConfig::new(Arc::new(config))
    }

    /// Two clients connect over plain HTTP, ping the relay, and exchange datagrams both ways.
    #[tokio::test]
    #[traced_test]
    async fn test_http_clients_and_server() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let a_key = SecretKey::generate(&mut rng);
        let b_key = SecretKey::generate(&mut rng);

        // Port 0: let the OS choose; the actual port is read back via `server.addr()`.
        let server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .spawn()
            .await?;

        let addr = server.addr();

        let port = addr.port();
        let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
            ipv4_addr
        } else {
            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
        };

        info!("addr: {addr}:{port}");
        let relay_addr: Url = format!("http://{addr}:{port}").parse().unwrap();

        let (a_key, mut client_a) = create_test_client(a_key, relay_addr.clone()).await?;
        info!("created client {a_key:?}");
        let (b_key, mut client_b) = create_test_client(b_key, relay_addr).await?;
        info!("created client {b_key:?}");

        info!("ping a");
        client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?;
        let pong = client_a.next().await.expect("eos")?;
        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));

        info!("ping b");
        client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?;
        let pong = client_b.next().await.expect("eos")?;
        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));

        info!("sending message from a to b");
        let msg = Datagrams::from(b"hi there, client b!");
        client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: b_key,
                datagrams: msg.clone(),
            })
            .await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) =
            process_msg(client_b.next().await).expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);

        info!("sending message from b to a");
        let msg = Datagrams::from(b"right back at ya, client b!");
        client_b
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: a_key,
                datagrams: msg.clone(),
            })
            .await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) =
            process_msg(client_a.next().await).expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);

        client_a.close().await?;
        client_b.close().await?;
        server.shutdown();

        Ok(())
    }

    /// Connects a relay client for `key` to `server_url`.
    ///
    /// TLS verification is skipped so the self-signed test certificate is accepted. Returns
    /// the client's public key alongside the client.
    async fn create_test_client(
        key: SecretKey,
        server_url: Url,
    ) -> Result<(PublicKey, Client), ConnectError> {
        let public_key = key.public();
        let client = ClientBuilder::new(server_url, key, DnsResolver::new()).tls_client_config(
            CaRootsConfig::insecure_skip_verify()
                .client_config(default_provider())
                .expect("infallible"),
        );
        let client = client.connect().await?;

        Ok((public_key, client))
    }

    /// Extracts `(sender, datagrams)` from a received message.
    ///
    /// Returns `None` (and logs) on error, end of stream, or any non-datagram message.
    fn process_msg(
        msg: Option<Result<RelayToClientMsg, crate::client::RecvError>>,
    ) -> Option<(PublicKey, Datagrams)> {
        match msg {
            Some(Err(e)) => {
                info!("client `recv` error {e}");
                None
            }
            Some(Ok(msg)) => {
                info!("got message on: {msg:?}");
                if let RelayToClientMsg::Datagrams {
                    remote_endpoint_id: source,
                    datagrams,
                } = msg
                {
                    Some((source, datagrams))
                } else {
                    None
                }
            }
            None => {
                info!("client end of stream");
                None
            }
        }
    }

    /// Same as the HTTP test but over HTTPS with a self-signed certificate.
    ///
    /// Also waits for the server task to finish after shutdown.
    #[tokio::test]
    #[traced_test]
    async fn test_https_clients_and_server() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let a_key = SecretKey::generate(&mut rng);
        let b_key = SecretKey::generate(&mut rng);

        let tls_config = make_tls_config();

        let mut server = ServerBuilder::new("127.0.0.1:0".parse().unwrap())
            .tls_config(Some(tls_config))
            .spawn()
            .await?;

        let addr = server.addr();

        let port = addr.port();
        let addr = if let std::net::IpAddr::V4(ipv4_addr) = addr.ip() {
            ipv4_addr
        } else {
            bail_any!("cannot get ipv4 addr from socket addr {addr:?}");
        };

        info!("Relay listening on: {addr}:{port}");

        // Use `localhost` so the URL host matches the certificate's subject alt name.
        let url: Url = format!("https://localhost:{port}").parse().unwrap();

        let (a_key, mut client_a) = create_test_client(a_key, url.clone()).await?;
        info!("created client {a_key:?}");
        let (b_key, mut client_b) = create_test_client(b_key, url).await?;
        info!("created client {b_key:?}");

        info!("ping a");
        client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?;
        let pong = client_a.next().await.expect("eos")?;
        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));

        info!("ping b");
        client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?;
        let pong = client_b.next().await.expect("eos")?;
        assert!(matches!(pong, RelayToClientMsg::Pong { .. }));

        info!("sending message from a to b");
        let msg = Datagrams::from(b"hi there, client b!");
        client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: b_key,
                datagrams: msg.clone(),
            })
            .await?;
        info!("waiting for message from a on b");
        let (got_key, got_msg) =
            process_msg(client_b.next().await).expect("expected message from client_a");
        assert_eq!(a_key, got_key);
        assert_eq!(msg, got_msg);

        info!("sending message from b to a");
        let msg = Datagrams::from(b"right back at ya, client b!");
        client_b
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: a_key,
                datagrams: msg.clone(),
            })
            .await?;
        info!("waiting for message b on a");
        let (got_key, got_msg) =
            process_msg(client_a.next().await).expect("expected message from client_b");
        assert_eq!(b_key, got_key);
        assert_eq!(msg, got_msg);

        client_a.close().await?;
        client_b.close().await?;
        server.shutdown();
        // Wait for the accept loop task to complete its graceful shutdown.
        server.task_handle().await.std_context("join")?;

        Ok(())
    }

    /// Wraps one end of an in-memory duplex pipe as a websocket relay client connection.
    ///
    /// The HTTP upgrade step is skipped.
    async fn make_test_client(client: tokio::io::DuplexStream, key: &SecretKey) -> Result<Conn> {
        let client = crate::client::streams::MaybeTlsStream::Test(client);
        let client = tokio_websockets::ClientBuilder::new().take_over(client);
        let client = Conn::new(client, KeyCache::test(), key).await?;
        Ok(client)
    }

    /// Drives `Inner::accept` directly over duplex pipes.
    ///
    /// Checks that datagrams are relayed both ways, that sends fail after
    /// `RelayService::shutdown`, and that accept/disconnect metrics balance.
    #[tokio::test]
    #[traced_test]
    async fn test_server_basic() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        info!("Create the server.");
        let metrics = Arc::new(Metrics::default());
        let service = RelayService::new(
            Default::default(),
            Default::default(),
            None,
            KeyCache::test(),
            AccessConfig::Everyone,
            metrics.clone(),
        );

        info!("Create client A and connect it to the server.");
        let key_a = SecretKey::generate(&mut rng);
        let public_key_a = key_a.public();
        let (client_a, rw_a) = tokio::io::duplex(10);
        let s = service.clone();
        // The server side of the handshake must run concurrently with the client side.
        let handler_task =
            tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a), None).await });
        let mut client_a = make_test_client(client_a, &key_a).await?;
        handler_task.await.std_context("join")??;

        info!("Create client B and connect it to the server.");
        let key_b = SecretKey::generate(&mut rng);
        let public_key_b = key_b.public();
        let (client_b, rw_b) = tokio::io::duplex(10);
        let s = service.clone();
        let handler_task =
            tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b), None).await });
        let mut client_b = make_test_client(client_b, &key_b).await?;
        handler_task.await.std_context("join")??;

        info!("Send message from A to B.");
        let msg = Datagrams::from(b"hello client b!!");
        client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_b,
                datagrams: msg.clone(),
            })
            .await?;
        match client_b.next().await.unwrap()? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(public_key_a, remote_endpoint_id);
                assert_eq!(msg, datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }

        info!("Send message from B to A.");
        let msg = Datagrams::from(b"nice to meet you client a!!");
        client_b
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_a,
                datagrams: msg.clone(),
            })
            .await?;
        match client_a.next().await.unwrap()? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(public_key_b, remote_endpoint_id);
                assert_eq!(msg, datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }

        info!("Close the server and clients");
        service.shutdown().await;
        // NOTE(review): presumably gives the client connections time to observe the server
        // closing before the send below; confirm whether this sleep is still required.
        tokio::time::sleep(Duration::from_secs(1)).await;

        info!("Fail to send message from A to B.");
        let res = client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_b,
                datagrams: Datagrams::from(b"try to send"),
            })
            .await;
        assert!(res.is_err());
        assert!(client_b.next().await.is_none());

        drop(client_a);
        drop(client_b);

        service.shutdown().await;

        // Every accepted client must have been accounted for as disconnected.
        assert_eq!(metrics.accepts.get(), metrics.disconnects.get());

        Ok(())
    }

    /// Reconnecting with the same key should route traffic to the newest connection.
    #[tokio::test]
    async fn test_server_replace_client() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        info!("Create the server.");
        let service = RelayService::new(
            Default::default(),
            Default::default(),
            None,
            KeyCache::test(),
            AccessConfig::Everyone,
            Default::default(),
        );

        info!("Create client A and connect it to the server.");
        let key_a = SecretKey::generate(&mut rng);
        let public_key_a = key_a.public();
        let (client_a, rw_a) = tokio::io::duplex(10);
        let s = service.clone();
        let handler_task =
            tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a), None).await });
        let mut client_a = make_test_client(client_a, &key_a).await?;
        handler_task.await.std_context("join")??;

        info!("Create client B and connect it to the server.");
        let key_b = SecretKey::generate(&mut rng);
        let public_key_b = key_b.public();
        let (client_b, rw_b) = tokio::io::duplex(10);
        let s = service.clone();
        let handler_task =
            tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b), None).await });
        let mut client_b = make_test_client(client_b, &key_b).await?;
        handler_task.await.std_context("join")??;

        info!("Send message from A to B.");
        let msg = Datagrams::from(b"hello client b!!");
        client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_b,
                datagrams: msg.clone(),
            })
            .await?;
        match client_b.next().await.expect("eos")? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(public_key_a, remote_endpoint_id);
                assert_eq!(msg, datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }

        info!("Send message from B to A.");
        let msg = Datagrams::from(b"nice to meet you client a!!");
        client_b
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_a,
                datagrams: msg.clone(),
            })
            .await?;
        match client_a.next().await.expect("eos")? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(public_key_b, remote_endpoint_id);
                assert_eq!(msg, datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }

        info!("Create client B and connect it to the server");
        // Reuse `key_b` so this connection registers under B's endpoint id again.
        let (new_client_b, new_rw_b) = tokio::io::duplex(10);
        let s = service.clone();
        let handler_task =
            tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(new_rw_b), None).await });
        let mut new_client_b = make_test_client(new_client_b, &key_b).await?;
        handler_task.await.std_context("join")??;

        info!("Send message from A to B.");
        let msg = Datagrams::from(b"are you still there, b?!");
        client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_b,
                datagrams: msg.clone(),
            })
            .await?;
        match new_client_b.next().await.expect("eos")? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(public_key_a, remote_endpoint_id);
                assert_eq!(msg, datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }

        info!("Send message from B to A.");
        let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!");
        new_client_b
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_a,
                datagrams: msg.clone(),
            })
            .await?;
        match client_a.next().await.expect("eos")? {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            } => {
                assert_eq!(public_key_b, remote_endpoint_id);
                assert_eq!(msg, datagrams);
            }
            msg => {
                bail_any!("expected ReceivedDatagrams msg, got {msg:?}");
            }
        }

        info!("Close the server and clients");
        service.shutdown().await;

        info!("Sending message from A to B fails");
        let res = client_a
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: public_key_b,
                datagrams: Datagrams::from(b"try to send"),
            })
            .await;
        assert!(res.is_err());
        assert!(new_client_b.next().await.is_none());
        Ok(())
    }
}