1use std::{net::SocketAddr, sync::Arc};
4
5use n0_error::stack_error;
6use n0_future::time::Duration;
7use noq::{VarInt, crypto::rustls::QuicClientConfig};
8use tokio::sync::watch;
9
/// ALPN protocol identifier used by QUIC address discovery (QAD) clients and servers.
pub const ALPN_QUIC_ADDR_DISC: &[u8] = b"/iroh-qad/0";
/// QUIC application close code used to end a QAD connection normally.
///
/// The server's connection handler treats a connection closed with this code as a
/// successful exchange; the server also uses it when closing its endpoint on shutdown.
pub const QUIC_ADDR_DISC_CLOSE_CODE: VarInt = VarInt::from_u32(1);
/// Close reason sent together with [`QUIC_ADDR_DISC_CLOSE_CODE`].
pub const QUIC_ADDR_DISC_CLOSE_REASON: &[u8] = b"finished";
16
#[cfg(feature = "server")]
pub(crate) mod server {
    use n0_error::e;
    use noq::{
        ApplicationClose, ConnectionError,
        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
    };
    use tokio::task::JoinSet;
    use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
    use tracing::{Instrument, debug, info, info_span};

    use super::*;
    pub use crate::server::QuicConfig;

    /// A running QUIC address discovery (QAD) server.
    ///
    /// The accept loop runs in a spawned task held by an [`AbortOnDropHandle`], so dropping
    /// this value aborts the server. Use [`QuicServer::shutdown`] to stop it gracefully.
    pub struct QuicServer {
        /// Local address the endpoint is actually bound to.
        bind_addr: SocketAddr,
        /// Cancelling this token makes the accept loop exit.
        cancel: CancellationToken,
        /// Handle of the accept-loop task.
        handle: AbortOnDropHandle<()>,
    }

    /// Errors that can occur while spawning a [`QuicServer`].
    #[allow(missing_docs)]
    #[stack_error(derive, add_meta)]
    #[non_exhaustive]
    pub enum QuicSpawnError {
        // The rustls server config cannot be turned into a QUIC server config.
        #[error(transparent)]
        NoInitialCipherSuite {
            #[error(std_err, from)]
            source: NoInitialCipherSuite,
        },
        // Creating/binding the QUIC server endpoint failed.
        #[error("Unable to spawn a QUIC endpoint server")]
        EndpointServer {
            #[error(std_err)]
            source: std::io::Error,
        },
        // The endpoint's bound address could not be queried.
        #[error("Unable to get the local address from the endpoint")]
        LocalAddr {
            #[error(std_err)]
            source: std::io::Error,
        },
    }

    impl QuicServer {
        /// Returns a cloneable [`ServerHandle`] that can request shutdown of this server.
        pub fn handle(&self) -> ServerHandle {
            ServerHandle {
                cancel_token: self.cancel.clone(),
            }
        }

        /// Returns a mutable reference to the handle of the accept-loop task.
        ///
        /// Awaiting it waits for the server task to finish; dropping it aborts the task.
        pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
            &mut self.handle
        }

        /// Returns the local address the QUIC endpoint is bound to.
        ///
        /// This is read back from the endpoint, so when the configured bind address used
        /// port `0` it contains the port that was actually assigned.
        pub fn bind_addr(&self) -> SocketAddr {
            self.bind_addr
        }

        /// Spawns a QAD server listening on `quic_config.bind_addr`.
        ///
        /// The ALPN list of `quic_config.server_config` is replaced with
        /// [`ALPN_QUIC_ADDR_DISC`]. Peer-initiated streams are disabled and the transport is
        /// configured to send observed-address reports, which is the whole QAD payload.
        ///
        /// # Errors
        ///
        /// - [`QuicSpawnError::NoInitialCipherSuite`] if the TLS config is not usable for QUIC.
        /// - [`QuicSpawnError::EndpointServer`] if the server endpoint cannot be created.
        /// - [`QuicSpawnError::LocalAddr`] if the bound address cannot be read back.
        ///
        /// # Panics
        ///
        /// Panics if not called from within a Tokio runtime (the accept loop is spawned with
        /// [`tokio::task::spawn`]).
        pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result<Self, QuicSpawnError> {
            quic_config.server_config.alpn_protocols =
                vec![crate::quic::ALPN_QUIC_ADDR_DISC.to_vec()];
            let server_config = QuicServerConfig::try_from(quic_config.server_config)?;
            let mut server_config = noq::ServerConfig::with_crypto(Arc::new(server_config));
            // The config was just created, so this is the only reference to the transport.
            let transport_config =
                Arc::get_mut(&mut server_config.transport).expect("not used yet");
            transport_config
                .max_concurrent_uni_streams(0_u8.into())
                .max_concurrent_bidi_streams(0_u8.into())
                .send_observed_address_reports(true);

            let endpoint = noq::Endpoint::server(server_config, quic_config.bind_addr)
                .map_err(|err| e!(QuicSpawnError::EndpointServer, err))?;
            let bind_addr = endpoint
                .local_addr()
                .map_err(|err| e!(QuicSpawnError::LocalAddr, err))?;

            info!(?bind_addr, "QUIC server listening on");

            let cancel = CancellationToken::new();
            let task = tokio::task::spawn(
                accept_loop(endpoint, cancel.clone()).instrument(info_span!("quic-endpoint")),
            );
            Ok(Self {
                bind_addr,
                cancel,
                handle: AbortOnDropHandle::new(task),
            })
        }

        /// Stops the server and waits for the accept-loop task to finish.
        ///
        /// The task closes the endpoint and waits for it to become idle before exiting.
        /// The task's join result (including a propagated panic) is discarded.
        pub async fn shutdown(mut self) {
            self.cancel.cancel();
            if !self.task_handle().is_finished() {
                _ = self.task_handle().await;
            }
        }
    }

    /// Cloneable handle that can request shutdown of a [`QuicServer`].
    #[derive(Debug, Clone)]
    pub struct ServerHandle {
        cancel_token: CancellationToken,
    }

    impl ServerHandle {
        /// Requests the server to stop accepting connections and shut down.
        ///
        /// Returns immediately; it does not wait for the shutdown to complete.
        pub fn shutdown(&self) {
            self.cancel_token.cancel()
        }
    }

    /// Accepts connections on `endpoint` until `cancel` fires or the endpoint closes.
    ///
    /// Each incoming connection is driven by [`handle_connection`] in its own task. On exit
    /// the endpoint is closed with [`QUIC_ADDR_DISC_CLOSE_CODE`], drained, and any remaining
    /// connection tasks are aborted and awaited.
    async fn accept_loop(endpoint: noq::Endpoint, cancel: CancellationToken) {
        let mut set: JoinSet<Result<(), ConnectionError>> = JoinSet::new();
        debug!("waiting for connections...");
        loop {
            tokio::select! {
                biased;
                _ = cancel.cancelled() => {
                    break;
                }
                Some(res) = set.join_next() => match res {
                    Ok(Ok(())) => {}
                    // Report abnormal connection endings instead of discarding the error.
                    Ok(Err(err)) => debug!("connection error: {err:#}"),
                    Err(err) if err.is_panic() => panic!("task panicked: {err:#?}"),
                    Err(err) => debug!("error accepting incoming connection: {err:#?}"),
                },
                res = endpoint.accept() => match res {
                    Some(incoming) => {
                        debug!("accepting connection");
                        let remote_addr = incoming.remote_address();
                        set.spawn(
                            handle_connection(incoming)
                                .instrument(info_span!("qad-conn", %remote_addr)),
                        );
                    }
                    None => {
                        debug!("endpoint closed");
                        break;
                    }
                },
            }
        }
        // Close every connection with the "finished" code and wait for them to drain.
        endpoint.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
        endpoint.wait_idle().await;

        // Abort whatever connection tasks are left and wait until all of them have exited.
        set.abort_all();
        while set.join_next().await.is_some() {}

        debug!("quic endpoint has been shutdown.");
    }

    /// Drives a single QAD connection until it is closed.
    ///
    /// This handler exchanges no application data; the observed-address reports enabled in
    /// the transport config are the only payload. A close with [`QUIC_ADDR_DISC_CLOSE_CODE`]
    /// is the expected outcome and maps to `Ok(())`. A failed handshake or any other close
    /// reason is returned as the error.
    async fn handle_connection(incoming: noq::Incoming) -> Result<(), ConnectionError> {
        let connection = incoming.await?;
        debug!("established");
        match connection.closed().await {
            ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. })
                if error_code == QUIC_ADDR_DISC_CLOSE_CODE =>
            {
                Ok(())
            }
            err => Err(err),
        }
    }
}
224
/// Errors returned by [`QuicClient`] operations.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum Error {
    // Starting the connection attempt failed (`connect_with` returned an error).
    #[error(transparent)]
    Connect {
        #[error(std_err)]
        source: noq::ConnectError,
    },
    // The connection attempt or an established connection failed.
    #[error(transparent)]
    Connection {
        #[error(std_err)]
        source: noq::ConnectionError,
    },
    // The observed-address watch channel closed before an address was reported.
    #[error(transparent)]
    WatchRecv {
        #[error(std_err)]
        source: watch::error::RecvError,
    },
}
246
/// Client for QUIC address discovery (QAD).
///
/// Bundles an endpoint with a client config prepared by [`QuicClient::new`] (QAD ALPN and
/// observed-address transport settings) for connecting to QAD servers.
#[derive(Debug, Clone)]
pub struct QuicClient {
    /// Endpoint used to open outgoing connections.
    ep: noq::Endpoint,
    /// Client config with the QAD ALPN and transport settings applied.
    client_config: noq::ClientConfig,
}
255
256impl QuicClient {
257 pub fn new(ep: noq::Endpoint, mut client_config: rustls::ClientConfig) -> Self {
260 client_config.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.into()];
262 let mut client_config = noq::ClientConfig::new(Arc::new(
265 QuicClientConfig::try_from(client_config).expect("known ciphersuite"),
266 ));
267
268 let mut transport = noq_proto::TransportConfig::default();
270 transport.initial_rtt(Duration::from_millis(111));
281 transport.receive_observed_address_reports(true);
282
283 transport.keep_alive_interval(Some(Duration::from_secs(25)));
285 transport.max_idle_timeout(Some(
286 Duration::from_secs(35).try_into().expect("known value"),
287 ));
288 client_config.transport_config(Arc::new(transport));
289
290 Self { ep, client_config }
291 }
292
293 #[cfg(all(test, feature = "server"))]
300 async fn get_addr_and_latency(
301 &self,
302 server_addr: SocketAddr,
303 host: &str,
304 ) -> Result<(SocketAddr, std::time::Duration), Error> {
305 use noq_proto::PathId;
306
307 let connecting = self
308 .ep
309 .connect_with(self.client_config.clone(), server_addr, host);
310 let conn = connecting?.await?;
311 let mut external_addresses = conn.observed_external_addr();
312 let res = match external_addresses.wait_for(|addr| addr.is_some()).await {
328 Ok(res) => res,
329 Err(err) => {
330 conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
332 return Err(err.into());
333 }
334 };
335 let mut observed_addr = res.expect("checked");
336 observed_addr = SocketAddr::new(observed_addr.ip().to_canonical(), observed_addr.port());
339 let latency = conn.rtt(PathId::ZERO).unwrap_or_default();
340 conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
342 Ok((observed_addr, latency))
343 }
344
345 pub async fn create_conn(
347 &self,
348 server_addr: SocketAddr,
349 host: &str,
350 ) -> Result<noq::Connection, Error> {
351 let config = self.client_config.clone();
352 let connecting = self.ep.connect_with(config, server_addr, host);
353 let conn = connecting?.await?;
354 Ok(conn)
355 }
356}
357
#[cfg(all(test, feature = "server"))]
mod tests {
    use std::net::Ipv4Addr;

    use n0_error::{Result, StdResultExt};
    use n0_future::{
        task::AbortOnDropHandle,
        time::{self, Instant},
    };
    use n0_tracing_test::traced_test;
    use noq::crypto::rustls::QuicServerConfig;
    use tracing::{Instrument, debug, info, info_span};
    use webpki_types::PrivatePkcs8KeyDer;

    use super::*;

    /// End-to-end QAD on localhost: the address the server observes for the client must be
    /// the client endpoint's own local address.
    #[tokio::test]
    #[traced_test]
    #[cfg(feature = "test-utils")]
    async fn quic_endpoint_basic() -> Result {
        use super::server::{QuicConfig, QuicServer};

        let host: Ipv4Addr = "127.0.0.1".parse().unwrap();
        let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config();
        // Port 0: the server's real port is read back through `bind_addr()` below.
        let bind_addr = SocketAddr::new(host.into(), 0);
        let quic_server = QuicServer::spawn(QuicConfig {
            server_config,
            bind_addr,
        })?;

        let client_endpoint =
            noq::Endpoint::client(SocketAddr::new(host.into(), 0)).std_context("client")?;
        let client_addr = client_endpoint.local_addr().std_context("local addr")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let (addr, _latency) = quic_client
            .get_addr_and_latency(quic_server.bind_addr(), &host.to_string())
            .await?;

        // Let the client's close finish before tearing down the server.
        client_endpoint.wait_idle().await;
        quic_server.shutdown().await;

        assert_eq!(client_addr, addr);
        Ok(())
    }

    /// Closing a client endpoint whose QAD request never got an answer must finish quickly.
    ///
    /// Runs with paused (virtual) time, so the measured durations are deterministic.
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_qad_client_closes_unresponsive_fast() -> Result {
        let client_endpoint = noq::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
            .std_context("client")?;

        // A plain UDP socket that is never read from: the "server" never answers, but its
        // port stays bound while `server_socket` is alive.
        let server_socket =
            tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .await
                .std_context("bind")?;
        let server_addr = server_socket.local_addr().std_context("local addr")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let task = AbortOnDropHandle::new(tokio::spawn({
            async move {
                quic_client
                    .get_addr_and_latency(server_addr, "localhost")
                    .await
            }
        }));

        // After 1s of virtual time the request must still be pending.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert!(!task.is_finished());

        let before = Instant::now();
        client_endpoint.close(0u32.into(), b"byeeeee");
        client_endpoint.wait_idle().await;
        let time = Instant::now().duration_since(before);

        // NOTE(review): presumably determined by the 111ms initial RTT configured in
        // `QuicClient::new` — update together with that value.
        assert_eq!(time, Duration::from_millis(999));

        Ok(())
    }

    /// QAD must succeed even if the server drops all packets for the first 2 seconds.
    #[tokio::test]
    async fn test_qad_connect_delayed() -> Result {
        tracing_subscriber::fmt::try_init().ok();
        let socket = tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
            .await
            .std_context("bind")?;
        let server_addr = socket.local_addr().std_context("local addr")?;
        info!(addr = ?server_addr, "server socket bound");

        // Self-signed server TLS config speaking the QAD ALPN.
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .std_context("self signed")?;
        let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let mut server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.cert.into()], key.into())
            .std_context("tls")?;
        // Debugging aid: rustls writes TLS secrets to the file named by SSLKEYLOGFILE, if set.
        server_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
        server_crypto.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.to_vec()];
        let mut server_config = noq::ServerConfig::with_crypto(Arc::new(
            QuicServerConfig::try_from(server_crypto).std_context("config")?,
        ));
        let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
        transport_config.send_observed_address_reports(true);

        let start = Instant::now();
        let server_task = tokio::spawn(
            async move {
                // Phase 1: read and discard every datagram for 2 seconds.
                info!("Dropping all packets");
                time::timeout(Duration::from_secs(2), async {
                    let mut buf = [0u8; 1500];
                    loop {
                        let (len, src) = socket.recv_from(&mut buf).await.unwrap();
                        debug!(%len, ?src, "Dropped a packet");
                    }
                })
                .await
                .ok();
                // Phase 2: turn the same socket into a real server endpoint.
                info!("starting server");
                let server = noq::Endpoint::new(
                    Default::default(),
                    Some(server_config),
                    socket.into_std().unwrap(),
                    Arc::new(noq::TokioRuntime),
                )
                .std_context("endpoint new")?;
                info!("accepting conn");
                let incoming = server.accept().await.expect("missing conn");
                info!("incoming!");
                let conn = incoming.await.std_context("incoming")?;
                conn.closed().await;
                server.wait_idle().await;
                n0_error::Ok(())
            }
            .instrument(info_span!("server")),
        );
        let server_task = AbortOnDropHandle::new(server_task);

        info!("starting client");
        let client_endpoint = noq::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
            .std_context("client")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        // The client has to retransmit until the server starts answering.
        info!("making QAD request");
        let (addr, latency) = time::timeout(
            Duration::from_secs(10),
            quic_client.get_addr_and_latency(server_addr, "localhost"),
        )
        .await
        .std_context("timeout")??;
        let duration = start.elapsed();
        info!(?duration, ?addr, ?latency, "QAD succeeded");
        // NOTE(review): the server cannot answer before its 2s drop window ends, so this
        // bound could be tightened to 2s.
        assert!(duration >= Duration::from_secs(1));

        time::timeout(Duration::from_secs(10), server_task)
            .await
            .std_context("timeout")?
            .std_context("server task")??;

        Ok(())
    }
}