1use std::{net::SocketAddr, sync::Arc};
4
5use n0_error::stack_error;
6use n0_future::time::Duration;
7use quinn::{VarInt, crypto::rustls::QuicClientConfig};
8use tokio::sync::watch;
9
/// ALPN protocol identifier for QUIC address discovery (QAD).
///
/// Both the QAD client and server overwrite their TLS ALPN list with this value.
pub const ALPN_QUIC_ADDR_DISC: &[u8] = b"/iroh-qad/0";
/// Application close code used to end a QAD connection normally.
///
/// The QAD server (`server` feature) treats a peer close with this code as a
/// successful connection, and uses it when closing its endpoint on shutdown.
pub const QUIC_ADDR_DISC_CLOSE_CODE: VarInt = VarInt::from_u32(1);
/// Close reason sent together with [`QUIC_ADDR_DISC_CLOSE_CODE`].
pub const QUIC_ADDR_DISC_CLOSE_REASON: &[u8] = b"finished";
16
#[cfg(feature = "server")]
pub(crate) mod server {
    use n0_error::e;
    use quinn::{
        ApplicationClose, ConnectionError,
        crypto::rustls::{NoInitialCipherSuite, QuicServerConfig},
    };
    use tokio::task::JoinSet;
    use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
    use tracing::{Instrument, debug, info, info_span};

    use super::*;
    pub use crate::server::QuicConfig;

    /// A running QUIC address discovery (QAD) server.
    ///
    /// Created by [`QuicServer::spawn`]. The accept loop runs on a spawned tokio
    /// task held in an [`AbortOnDropHandle`], so dropping this value aborts it.
    /// Use [`QuicServer::shutdown`] to stop it gracefully.
    pub struct QuicServer {
        /// The address the endpoint actually bound to (as reported by the endpoint).
        bind_addr: SocketAddr,
        /// Cancelling this token makes the accept loop exit.
        cancel: CancellationToken,
        /// Handle to the task running the accept loop.
        handle: AbortOnDropHandle<()>,
    }

    #[allow(missing_docs)]
    /// Errors that can occur while spawning a [`QuicServer`].
    #[stack_error(derive, add_meta)]
    #[non_exhaustive]
    pub enum QuicSpawnError {
        #[error(transparent)]
        NoInitialCipherSuite {
            #[error(std_err, from)]
            source: NoInitialCipherSuite,
        },
        #[error("Unable to spawn a QUIC endpoint server")]
        EndpointServer {
            #[error(std_err)]
            source: std::io::Error,
        },
        #[error("Unable to get the local address from the endpoint")]
        LocalAddr {
            #[error(std_err)]
            source: std::io::Error,
        },
    }

    impl QuicServer {
        /// Returns a cloneable [`ServerHandle`] that can request shutdown of this server.
        pub fn handle(&self) -> ServerHandle {
            ServerHandle {
                cancel_token: self.cancel.clone(),
            }
        }

        /// Returns a mutable reference to the accept-loop task handle.
        ///
        /// Awaiting the handle waits for the accept loop to finish.
        pub fn task_handle(&mut self) -> &mut AbortOnDropHandle<()> {
            &mut self.handle
        }

        /// Returns the local address the QUIC endpoint is bound to.
        ///
        /// This is read back from the endpoint after binding, so it carries the
        /// real port even when the configured port was `0`.
        pub fn bind_addr(&self) -> SocketAddr {
            self.bind_addr
        }

        /// Binds a QUIC endpoint and spawns a task that accepts QAD connections.
        ///
        /// The ALPN list of the supplied TLS config is replaced with
        /// [`ALPN_QUIC_ADDR_DISC`]. Peers may not open any streams. The transport
        /// is configured to send observed-address reports, which is the only
        /// thing QAD needs.
        ///
        /// # Errors
        ///
        /// - [`QuicSpawnError::NoInitialCipherSuite`] if the TLS config cannot be used for QUIC.
        /// - [`QuicSpawnError::EndpointServer`] if the endpoint cannot be created/bound.
        /// - [`QuicSpawnError::LocalAddr`] if the bound address cannot be read back.
        pub(crate) fn spawn(mut quic_config: QuicConfig) -> Result<Self, QuicSpawnError> {
            quic_config.server_config.alpn_protocols =
                vec![crate::quic::ALPN_QUIC_ADDR_DISC.to_vec()];
            let server_config = QuicServerConfig::try_from(quic_config.server_config)?;
            let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
            // The config was created just above, so nothing else shares the Arc yet.
            let transport_config =
                Arc::get_mut(&mut server_config.transport).expect("not used yet");
            // No application streams are needed: the observed-address report is
            // sent by the transport itself.
            transport_config
                .max_concurrent_uni_streams(0_u8.into())
                .max_concurrent_bidi_streams(0_u8.into())
                .send_observed_address_reports(true);

            let endpoint = quinn::Endpoint::server(server_config, quic_config.bind_addr)
                .map_err(|err| e!(QuicSpawnError::EndpointServer, err))?;
            let bind_addr = endpoint
                .local_addr()
                .map_err(|err| e!(QuicSpawnError::LocalAddr, err))?;

            info!(?bind_addr, "QUIC server listening on");

            let cancel = CancellationToken::new();
            let task = tokio::task::spawn(
                accept_loop(endpoint, cancel.clone()).instrument(info_span!("quic-endpoint")),
            );
            Ok(Self {
                bind_addr,
                cancel,
                handle: AbortOnDropHandle::new(task),
            })
        }

        /// Gracefully shuts the server down.
        ///
        /// Cancels the accept loop and waits for its task to finish. The loop
        /// closes the endpoint, waits for it to go idle and reaps all connection
        /// tasks before returning. The task's own result (including a panic) is
        /// ignored.
        pub async fn shutdown(mut self) {
            self.cancel.cancel();
            if !self.task_handle().is_finished() {
                _ = self.task_handle().await;
            }
        }
    }

    /// Accepts QAD connections on `endpoint` until `cancel` fires or the endpoint
    /// stops yielding connections, then closes the endpoint and drains all
    /// per-connection tasks.
    ///
    /// # Panics
    ///
    /// Re-panics if a per-connection task panicked while the loop was running.
    async fn accept_loop(endpoint: quinn::Endpoint, cancel: CancellationToken) {
        let mut set: JoinSet<Result<(), ConnectionError>> = JoinSet::new();
        debug!("waiting for connections...");
        loop {
            tokio::select! {
                // Check cancellation first so shutdown is not starved by traffic.
                biased;
                _ = cancel.cancelled() => {
                    break;
                }
                // Reap finished connection tasks so the set does not grow unbounded.
                Some(res) = set.join_next() => {
                    match res {
                        Ok(Ok(())) => {}
                        // Formerly discarded silently; now surfaced for debugging.
                        Ok(Err(err)) => {
                            debug!("connection ended with error: {err:#}");
                        }
                        Err(err) if err.is_panic() => {
                            panic!("task panicked: {err:#?}");
                        }
                        // A non-panic JoinError means the task was cancelled.
                        Err(err) => {
                            debug!("connection task was cancelled: {err:#?}");
                        }
                    }
                }
                res = endpoint.accept() => match res {
                    Some(incoming) => {
                        debug!("accepting connection");
                        let remote_addr = incoming.remote_address();
                        set.spawn(
                            handle_connection(incoming)
                                .instrument(info_span!("qad-conn", %remote_addr)),
                        );
                    }
                    None => {
                        debug!("endpoint closed");
                        break;
                    }
                },
            }
        }

        // Closing the endpoint closes every open connection, which lets each
        // `handle_connection` task observe the close and return.
        endpoint.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
        endpoint.wait_idle().await;

        // Abort whatever is still running and wait until every task is reaped.
        set.abort_all();
        while set.join_next().await.is_some() {}

        debug!("quic endpoint has been shutdown.");
    }

    /// A cloneable handle used to request shutdown of a [`QuicServer`].
    #[derive(Debug, Clone)]
    pub struct ServerHandle {
        cancel_token: CancellationToken,
    }

    impl ServerHandle {
        /// Signals the server's accept loop to stop.
        ///
        /// This only requests shutdown; it does not wait for it to complete.
        pub fn shutdown(&self) {
            self.cancel_token.cancel()
        }
    }

    /// Drives a single QAD connection until it is closed.
    ///
    /// The server does not exchange any data itself; observed-address reports
    /// are sent by the transport. A close by the peer with
    /// [`QUIC_ADDR_DISC_CLOSE_CODE`] counts as success.
    ///
    /// # Errors
    ///
    /// Returns the [`ConnectionError`] if the handshake fails or the connection
    /// ends for any other reason.
    async fn handle_connection(incoming: quinn::Incoming) -> Result<(), ConnectionError> {
        let connection = incoming.await?;
        debug!("established");
        match connection.closed().await {
            ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. })
                if error_code == QUIC_ADDR_DISC_CLOSE_CODE =>
            {
                Ok(())
            }
            err => Err(err),
        }
    }
}
224
#[allow(missing_docs)]
/// Errors returned by [`QuicClient`] operations.
#[stack_error(derive, add_meta, from_sources, std_sources)]
#[non_exhaustive]
pub enum Error {
    // The connection could not be started (error from `quinn::Endpoint::connect_with`).
    #[error(transparent)]
    Connect {
        #[error(std_err)]
        source: quinn::ConnectError,
    },
    // The connection failed while being established (awaiting the `Connecting`).
    #[error(transparent)]
    Connection {
        #[error(std_err)]
        source: quinn::ConnectionError,
    },
    // The observed-address watch channel closed while waiting for an address.
    #[error(transparent)]
    WatchRecv {
        #[error(std_err)]
        source: watch::error::RecvError,
    },
}
246
#[derive(Debug, Clone)]
/// Client for QUIC address discovery (QAD).
///
/// Pairs an endpoint with a client config preconfigured for QAD by
/// [`QuicClient::new`].
pub struct QuicClient {
    /// Endpoint used to open outgoing QAD connections.
    ep: quinn::Endpoint,
    /// QAD client config (ALPN and transport settings) used for every connection.
    client_config: quinn::ClientConfig,
}
255
256impl QuicClient {
257 pub fn new(ep: quinn::Endpoint, mut client_config: rustls::ClientConfig) -> Self {
260 client_config.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.into()];
262 let mut client_config = quinn::ClientConfig::new(Arc::new(
265 QuicClientConfig::try_from(client_config).expect("known ciphersuite"),
266 ));
267
268 let mut transport = quinn_proto::TransportConfig::default();
270 transport.initial_rtt(Duration::from_millis(111));
281 transport.receive_observed_address_reports(true);
282
283 transport.keep_alive_interval(Some(Duration::from_secs(25)));
285 transport.max_idle_timeout(Some(
286 Duration::from_secs(35).try_into().expect("known value"),
287 ));
288 client_config.transport_config(Arc::new(transport));
289
290 Self { ep, client_config }
291 }
292
293 #[cfg(all(test, feature = "server"))]
300 async fn get_addr_and_latency(
301 &self,
302 server_addr: SocketAddr,
303 host: &str,
304 ) -> Result<(SocketAddr, std::time::Duration), Error> {
305 use quinn_proto::PathId;
306
307 let connecting = self
308 .ep
309 .connect_with(self.client_config.clone(), server_addr, host);
310 let conn = connecting?.await?;
311 let mut external_addresses = conn.observed_external_addr();
312 let res = match external_addresses.wait_for(|addr| addr.is_some()).await {
328 Ok(res) => res,
329 Err(err) => {
330 conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
332 return Err(err.into());
333 }
334 };
335 let mut observed_addr = res.expect("checked");
336 observed_addr = SocketAddr::new(observed_addr.ip().to_canonical(), observed_addr.port());
339 let latency = conn.rtt(PathId::ZERO).unwrap_or_default();
340 conn.close(QUIC_ADDR_DISC_CLOSE_CODE, QUIC_ADDR_DISC_CLOSE_REASON);
342 Ok((observed_addr, latency))
343 }
344
345 pub async fn create_conn(
347 &self,
348 server_addr: SocketAddr,
349 host: &str,
350 ) -> Result<quinn::Connection, Error> {
351 let config = self.client_config.clone();
352 let connecting = self.ep.connect_with(config, server_addr, host);
353 let conn = connecting?.await?;
354 Ok(conn)
355 }
356}
357
#[cfg(all(test, feature = "server"))]
mod tests {
    use std::net::Ipv4Addr;

    use n0_error::{Result, StdResultExt};
    use n0_future::{
        task::AbortOnDropHandle,
        time::{self, Instant},
    };
    use n0_tracing_test::traced_test;
    use quinn::crypto::rustls::QuicServerConfig;
    use tracing::{Instrument, debug, info, info_span};
    use webpki_types::PrivatePkcs8KeyDer;

    use super::*;

    /// End to end over loopback: a client learns its own address from a local
    /// QAD server.
    #[tokio::test]
    #[traced_test]
    #[cfg(feature = "test-utils")]
    async fn quic_endpoint_basic() -> Result {
        use super::server::{QuicConfig, QuicServer};

        let host: Ipv4Addr = "127.0.0.1".parse().unwrap();
        // Server TLS config comes from the crate's test helpers.
        let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config();
        // Port 0: let the OS pick a free port.
        let bind_addr = SocketAddr::new(host.into(), 0);
        let quic_server = QuicServer::spawn(QuicConfig {
            server_config,
            bind_addr,
        })?;

        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(host.into(), 0)).std_context("client")?;
        let client_addr = client_endpoint.local_addr().std_context("local addr")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let (addr, _latency) = quic_client
            .get_addr_and_latency(quic_server.bind_addr(), &host.to_string())
            .await?;

        // Let the client's close complete before stopping the server.
        client_endpoint.wait_idle().await;
        quic_server.shutdown().await;

        // Over loopback, the address the server observed must equal the
        // client's own local address.
        assert_eq!(client_addr, addr);
        Ok(())
    }

    /// Closing the client endpoint while a QAD request to a peer that never
    /// answers is still pending must finish in a bounded (virtual) time.
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_qad_client_closes_unresponsive_fast() -> Result {
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .std_context("client")?;

        // A plain UDP socket that is never read from stands in for an
        // unresponsive server; it is kept alive so the port stays bound.
        let server_socket =
            tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .await
                .std_context("bind")?;
        let server_addr = server_socket.local_addr().std_context("local addr")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        let task = AbortOnDropHandle::new(tokio::spawn({
            async move {
                quic_client
                    .get_addr_and_latency(server_addr, "localhost")
                    .await
            }
        }));

        // Time is paused, so this advances the virtual clock. The request must
        // still be pending afterwards.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        assert!(!task.is_finished());

        let before = Instant::now();
        client_endpoint.close(0u32.into(), b"byeeeee");
        client_endpoint.wait_idle().await;
        let time = Instant::now().duration_since(before);

        // NOTE(review): exact virtual-time value; it depends on quinn's
        // close/drain timing for a connection in this state. Re-check if quinn changes.
        assert_eq!(time, Duration::from_millis(999));

        Ok(())
    }

    /// QAD still succeeds when the server discards every packet for the first
    /// 2 s and only then starts a real QUIC endpoint on the same socket.
    #[tokio::test]
    async fn test_qad_connect_delayed() -> Result {
        tracing_subscriber::fmt::try_init().ok();
        let socket = tokio::net::UdpSocket::bind(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
            .await
            .std_context("bind")?;
        let server_addr = socket.local_addr().std_context("local addr")?;
        info!(addr = ?server_addr, "server socket bound");

        // Build a minimal QAD server config by hand from a self-signed cert.
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
            .std_context("self signed")?;
        let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
        let mut server_crypto = rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.cert.into()], key.into())
            .std_context("tls")?;
        // NOTE(review): per rustls docs, KeyLogFile writes TLS secrets to the
        // file named by $SSLKEYLOGFILE when set. This looks like a debugging aid.
        server_crypto.key_log = Arc::new(rustls::KeyLogFile::new());
        server_crypto.alpn_protocols = vec![ALPN_QUIC_ADDR_DISC.to_vec()];
        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
            QuicServerConfig::try_from(server_crypto).std_context("config")?,
        ));
        let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
        transport_config.send_observed_address_reports(true);

        let start = Instant::now();
        let server_task = tokio::spawn(
            async move {
                info!("Dropping all packets");
                // Read and discard everything for 2 s so the client's first
                // handshake packets are lost. The timeout error that ends this
                // loop is expected and ignored.
                time::timeout(Duration::from_secs(2), async {
                    let mut buf = [0u8; 1500];
                    loop {
                        let (len, src) = socket.recv_from(&mut buf).await.unwrap();
                        debug!(%len, ?src, "Dropped a packet");
                    }
                })
                .await
                .ok();
                info!("starting server");
                // Reuse the same socket so the client keeps sending to the same address.
                let server = quinn::Endpoint::new(
                    Default::default(),
                    Some(server_config),
                    socket.into_std().unwrap(),
                    Arc::new(quinn::TokioRuntime),
                )
                .std_context("endpoint new")?;
                info!("accepting conn");
                let incoming = server.accept().await.expect("missing conn");
                info!("incoming!");
                let conn = incoming.await.std_context("incoming")?;
                conn.closed().await;
                server.wait_idle().await;
                n0_error::Ok(())
            }
            .instrument(info_span!("server")),
        );
        let server_task = AbortOnDropHandle::new(server_task);

        info!("starting client");
        let client_endpoint =
            quinn::Endpoint::client(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0))
                .std_context("client")?;

        let client_config = crate::client::make_dangerous_client_config();
        let quic_client = QuicClient::new(client_endpoint.clone(), client_config);

        info!("making QAD request");
        let (addr, latency) = time::timeout(
            Duration::from_secs(10),
            quic_client.get_addr_and_latency(server_addr, "localhost"),
        )
        .await
        .std_context("timeout")??;
        let duration = start.elapsed();
        info!(?duration, ?addr, ?latency, "QAD succeeded");
        // Sanity check that the packet-dropping phase really delayed the request.
        assert!(duration >= Duration::from_secs(1));

        time::timeout(Duration::from_secs(10), server_task)
            .await
            .std_context("timeout")?
            .std_context("server task")??;

        Ok(())
    }
}