iroh_quinn/
endpoint.rs

1use std::{
2    collections::VecDeque,
3    fmt,
4    future::Future,
5    io::{self, IoSliceMut},
6    mem,
7    net::{SocketAddr, SocketAddrV6},
8    num::NonZeroUsize,
9    pin::Pin,
10    str,
11    sync::{Arc, Mutex},
12    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
13};
14
15#[cfg(all(
16    not(wasm_browser),
17    any(feature = "runtime-tokio", feature = "runtime-smol"),
18    any(feature = "aws-lc-rs", feature = "ring"),
19))]
20use crate::runtime::default_runtime;
21use crate::{
22    Instant,
23    runtime::{AsyncUdpSocket, Runtime, UdpSender},
24    udp_transmit,
25};
26use bytes::{Bytes, BytesMut};
27use pin_project_lite::pin_project;
28use proto::{
29    self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
30    EndpointEvent, ServerConfig,
31};
32use rustc_hash::FxHashMap;
33#[cfg(all(
34    not(wasm_browser),
35    any(feature = "runtime-tokio", feature = "runtime-smol"),
36    any(feature = "aws-lc-rs", feature = "ring"),
37))]
38use socket2::{Domain, Protocol, Socket, Type};
39use tokio::sync::{Notify, futures::Notified, mpsc};
40use tracing::{Instrument, Span};
41use udp::{BATCH_SIZE, RecvMeta};
42
43use crate::{
44    ConnectionEvent, EndpointConfig, IO_LOOP_BOUND, RECV_TIME_BOUND, VarInt,
45    connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter,
46};
47
/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Counted handle to the state shared with the driver task and with other clones
    pub(crate) inner: EndpointRef,
    /// Configuration used by [`Endpoint::connect`]; starts out as `None` and is stored per
    /// handle by [`Endpoint::set_default_client_config`]
    pub(crate) default_client_config: Option<ClientConfig>,
    /// Runtime used to spawn the driver, wrap sockets and read the current time
    runtime: Arc<dyn Runtime>,
}
60
61impl Endpoint {
62    /// Helper to construct an endpoint for use with outgoing connections only
63    ///
64    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
65    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
66    /// IPv6 address respectively from an OS-assigned port.
67    ///
68    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
69    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
70    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
71    /// address. For example:
72    ///
73    /// ```
74    /// iroh_quinn::Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into());
75    /// ```
76    ///
77    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
78    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
79    #[cfg(all(
80        not(wasm_browser),
81        any(feature = "runtime-tokio", feature = "runtime-smol"),
82        any(feature = "aws-lc-rs", feature = "ring"), // `EndpointConfig::default()` is only available with these
83    ))]
84    pub fn client(addr: SocketAddr) -> io::Result<Self> {
85        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
86        if addr.is_ipv6() {
87            if let Err(e) = socket.set_only_v6(false) {
88                tracing::debug!(%e, "unable to make socket dual-stack");
89            }
90        }
91        socket.bind(&addr.into())?;
92        let runtime =
93            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
94        Self::new_with_abstract_socket(
95            EndpointConfig::default(),
96            None,
97            runtime.wrap_udp_socket(socket.into())?,
98            runtime,
99        )
100    }
101
102    /// Returns relevant stats from this Endpoint
103    pub fn stats(&self) -> EndpointStats {
104        self.inner.state.lock().unwrap().stats
105    }
106
107    /// Helper to construct an endpoint for use with both incoming and outgoing connections
108    ///
109    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
110    /// IPv6 address on Windows will not by default be able to communicate with IPv4
111    /// addresses. Portable applications should bind an address that matches the family they wish to
112    /// communicate within.
113    #[cfg(all(
114        not(wasm_browser),
115        any(feature = "runtime-tokio", feature = "runtime-smol"),
116        any(feature = "aws-lc-rs", feature = "ring"), // `EndpointConfig::default()` is only available with these
117    ))]
118    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
119        let socket = std::net::UdpSocket::bind(addr)?;
120        let runtime =
121            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
122        Self::new_with_abstract_socket(
123            EndpointConfig::default(),
124            Some(config),
125            runtime.wrap_udp_socket(socket)?,
126            runtime,
127        )
128    }
129
130    /// Construct an endpoint with arbitrary configuration and socket
131    #[cfg(not(wasm_browser))]
132    pub fn new(
133        config: EndpointConfig,
134        server_config: Option<ServerConfig>,
135        socket: std::net::UdpSocket,
136        runtime: Arc<dyn Runtime>,
137    ) -> io::Result<Self> {
138        let socket = runtime.wrap_udp_socket(socket)?;
139        Self::new_with_abstract_socket(config, server_config, socket, runtime)
140    }
141
142    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
143    ///
144    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
145    /// ownership is needed.
146    pub fn new_with_abstract_socket(
147        config: EndpointConfig,
148        server_config: Option<ServerConfig>,
149        socket: Box<dyn AsyncUdpSocket>,
150        runtime: Arc<dyn Runtime>,
151    ) -> io::Result<Self> {
152        let addr = socket.local_addr()?;
153        let allow_mtud = !socket.may_fragment();
154        let rc = EndpointRef::new(
155            socket,
156            proto::Endpoint::new(
157                Arc::new(config),
158                server_config.map(Arc::new),
159                allow_mtud,
160                None,
161            ),
162            addr.is_ipv6(),
163            runtime.clone(),
164        );
165        let driver = EndpointDriver(rc.clone());
166        runtime.spawn(Box::pin(
167            async {
168                if let Err(e) = driver.await {
169                    tracing::error!("I/O error: {}", e);
170                }
171            }
172            .instrument(Span::current()),
173        ));
174        Ok(Self {
175            inner: rc,
176            default_client_config: None,
177            runtime,
178        })
179    }
180
181    /// Get the next incoming connection attempt from a client
182    ///
183    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
184    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
185    /// filter connection attempts or force address validation, or converted into an intermediate
186    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
187    pub fn accept(&self) -> Accept<'_> {
188        Accept {
189            endpoint: self,
190            notify: self.inner.shared.incoming.notified(),
191        }
192    }
193
194    /// Set the client configuration used by `connect`
195    pub fn set_default_client_config(&mut self, config: ClientConfig) {
196        self.default_client_config = Some(config);
197    }
198
199    /// Connect to a remote endpoint
200    ///
201    /// `server_name` must be covered by the certificate presented by the server. This prevents a
202    /// connection from being intercepted by an attacker with a valid certificate for some other
203    /// server.
204    ///
205    /// May fail immediately due to configuration errors, or in the future if the connection could
206    /// not be established.
207    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
208        let config = match &self.default_client_config {
209            Some(config) => config.clone(),
210            None => return Err(ConnectError::NoDefaultClientConfig),
211        };
212
213        self.connect_with(config, addr, server_name)
214    }
215
216    /// Connect to a remote endpoint using a custom configuration.
217    ///
218    /// See [`connect()`] for details.
219    ///
220    /// [`connect()`]: Endpoint::connect
221    pub fn connect_with(
222        &self,
223        config: ClientConfig,
224        addr: SocketAddr,
225        server_name: &str,
226    ) -> Result<Connecting, ConnectError> {
227        let mut endpoint = self.inner.state.lock().unwrap();
228        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
229            return Err(ConnectError::EndpointStopping);
230        }
231        if addr.is_ipv6() && !endpoint.ipv6 {
232            return Err(ConnectError::InvalidRemoteAddress(addr));
233        }
234        let addr = if endpoint.ipv6 {
235            SocketAddr::V6(ensure_ipv6(addr))
236        } else {
237            addr
238        };
239
240        let (ch, conn) = endpoint
241            .inner
242            .connect(self.runtime.now(), config, addr, server_name)?;
243
244        let sender = endpoint.socket.create_sender();
245        endpoint.stats.outgoing_handshakes += 1;
246        Ok(endpoint
247            .recv_state
248            .connections
249            .insert(ch, conn, sender, self.runtime.clone()))
250    }
251
252    /// Switch to a new UDP socket
253    ///
254    /// See [`Endpoint::rebind_abstract()`] for details.
255    #[cfg(not(wasm_browser))]
256    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
257        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
258    }
259
260    /// Switch to a new UDP socket
261    ///
262    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
263    /// connections and connections to servers unreachable from the new address will be lost.
264    ///
265    /// On error, the old UDP socket is retained.
266    pub fn rebind_abstract(&self, socket: Box<dyn AsyncUdpSocket>) -> io::Result<()> {
267        let addr = socket.local_addr()?;
268        let mut inner = self.inner.state.lock().unwrap();
269        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
270        inner.ipv6 = addr.is_ipv6();
271
272        // Update connection socket references
273        for sender in inner.recv_state.connections.senders.values() {
274            // Ignoring errors from dropped connections
275            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.create_sender()));
276        }
277        if let Some(driver) = inner.driver.take() {
278            // Ensure the driver can register for wake-ups from the new socket
279            driver.wake();
280        }
281
282        Ok(())
283    }
284
285    /// Replace the server configuration, affecting new incoming connections only
286    ///
287    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
288    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
289        self.inner
290            .state
291            .lock()
292            .unwrap()
293            .inner
294            .set_server_config(server_config.map(Arc::new))
295    }
296
297    /// Get the local `SocketAddr` the underlying socket is bound to
298    pub fn local_addr(&self) -> io::Result<SocketAddr> {
299        self.inner.state.lock().unwrap().socket.local_addr()
300    }
301
302    /// Get the number of connections that are currently open
303    pub fn open_connections(&self) -> usize {
304        self.inner.state.lock().unwrap().inner.open_connections()
305    }
306
307    /// Close all of this endpoint's connections immediately and cease accepting new connections.
308    ///
309    /// See [`Connection::close()`] for details.
310    ///
311    /// [`Connection::close()`]: crate::Connection::close
312    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
313        let reason = Bytes::copy_from_slice(reason);
314        let mut endpoint = self.inner.state.lock().unwrap();
315        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
316        for sender in endpoint.recv_state.connections.senders.values() {
317            // Ignoring errors from dropped connections
318            let _ = sender.send(ConnectionEvent::Close {
319                error_code,
320                reason: reason.clone(),
321            });
322        }
323        self.inner.shared.incoming.notify_waiters();
324    }
325
326    /// Wait for all connections on the endpoint to be cleanly shut down
327    ///
328    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
329    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
330    /// the idle timeout period.
331    ///
332    /// Does not proactively close existing connections or cause incoming connections to be
333    /// rejected. Consider calling [`close()`] if that is desired.
334    ///
335    /// [`close()`]: Endpoint::close
336    pub async fn wait_idle(&self) {
337        loop {
338            {
339                let endpoint = &mut *self.inner.state.lock().unwrap();
340                if endpoint.recv_state.connections.is_empty() {
341                    break;
342                }
343                // Construct future while lock is held to avoid race
344                self.inner.shared.idle.notified()
345            }
346            .await;
347        }
348    }
349}
350
/// Statistics on [Endpoint] activity
///
/// All counters start at zero and are only ever incremented.
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint], i.e. outgoing connection
    /// attempts that were successfully initiated
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}
364
/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped and no
/// connections remain, or when an I/O error occurs.
///
/// Dropping the driver marks the endpoint as having lost its driver, wakes pending `accept`
/// calls and drops the channels to all connections.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
378
379impl Future for EndpointDriver {
380    type Output = Result<(), io::Error>;
381
382    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
383        let mut endpoint = self.0.state.lock().unwrap();
384        if endpoint.driver.is_none() {
385            endpoint.driver = Some(cx.waker().clone());
386        }
387
388        let now = endpoint.runtime.now();
389        let mut keep_going = false;
390        keep_going |= endpoint.drive_recv(cx, now)?;
391        keep_going |= endpoint.handle_events(cx, &self.0.shared);
392
393        if !endpoint.recv_state.incoming.is_empty() {
394            self.0.shared.incoming.notify_waiters();
395        }
396
397        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
398            Poll::Ready(Ok(()))
399        } else {
400            drop(endpoint);
401            // If there is more work to do schedule the endpoint task again.
402            // `wake_by_ref()` is called outside the lock to minimize
403            // lock contention on a multithreaded runtime.
404            if keep_going {
405                cx.waker().wake_by_ref();
406            }
407            Poll::Pending
408        }
409    }
410}
411
412impl Drop for EndpointDriver {
413    fn drop(&mut self) {
414        let mut endpoint = self.0.state.lock().unwrap();
415        endpoint.driver_lost = true;
416        self.0.shared.incoming.notify_waiters();
417        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
418        // connections.
419        endpoint.recv_state.connections.senders.clear();
420    }
421}
422
/// State shared between all handles to an endpoint and its driver
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state, guarded by a mutex
    pub(crate) state: Mutex<State>,
    /// Notification primitives that can be used without holding the `state` lock
    pub(crate) shared: Shared,
}
428
429impl EndpointInner {
430    pub(crate) fn accept(
431        &self,
432        incoming: proto::Incoming,
433        server_config: Option<Arc<ServerConfig>>,
434    ) -> Result<Connecting, ConnectionError> {
435        let mut state = self.state.lock().unwrap();
436        let mut response_buffer = Vec::new();
437        let now = state.runtime.now();
438        match state
439            .inner
440            .accept(incoming, now, &mut response_buffer, server_config)
441        {
442            Ok((handle, conn)) => {
443                state.stats.accepted_handshakes += 1;
444                let sender = state.socket.create_sender();
445                let runtime = state.runtime.clone();
446                Ok(state
447                    .recv_state
448                    .connections
449                    .insert(handle, conn, sender, runtime))
450            }
451            Err(error) => {
452                if let Some(transmit) = error.response {
453                    respond(transmit, &response_buffer, &mut state.sender);
454                }
455                Err(error.cause)
456            }
457        }
458    }
459
460    pub(crate) fn refuse(&self, incoming: proto::Incoming) {
461        let mut state = self.state.lock().unwrap();
462        state.stats.refused_handshakes += 1;
463        let mut response_buffer = Vec::new();
464        let transmit = state.inner.refuse(incoming, &mut response_buffer);
465        respond(transmit, &response_buffer, &mut state.sender);
466    }
467
468    pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
469        let mut state = self.state.lock().unwrap();
470        let mut response_buffer = Vec::new();
471        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
472        respond(transmit, &response_buffer, &mut state.sender);
473        Ok(())
474    }
475
476    pub(crate) fn ignore(&self, incoming: proto::Incoming) {
477        let mut state = self.state.lock().unwrap();
478        state.stats.ignored_handshakes += 1;
479        state.inner.ignore(incoming);
480    }
481}
482
/// Mutable endpoint state, shared behind the mutex in `EndpointInner`
#[derive(Debug)]
pub(crate) struct State {
    /// Socket currently used for receiving and for creating per-connection senders
    socket: Box<dyn AsyncUdpSocket>,
    /// Sender used for endpoint-generated packets such as refusals and retries
    sender: Pin<Box<dyn UdpSender>>,
    /// During an active migration, `prev_socket` receives traffic
    /// until the first connection packet arrives on the new socket
    /// (or until polling it fails).
    prev_socket: Option<Box<dyn AsyncUdpSocket>>,
    /// Protocol-level endpoint state machine
    inner: proto::Endpoint,
    /// Receive-path state: queued incoming connections, connection channels and buffers
    recv_state: RecvState,
    /// Waker of the driver task; taken by whoever wakes the driver and stored again the next
    /// time the driver is polled
    driver: Option<Waker>,
    /// Whether the current socket is bound to an IPv6 address
    ipv6: bool,
    /// Events sent by connections to the endpoint
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    /// Set once the driver has been dropped
    driver_lost: bool,
    /// Runtime used as the time source and handed to new connections
    runtime: Arc<dyn Runtime>,
    /// Handshake counters reported by `Endpoint::stats`
    stats: EndpointStats,
}
501
/// Notification primitives shared by endpoint handles and the driver
#[derive(Debug)]
pub(crate) struct Shared {
    /// Signalled when incoming connections are queued, when the endpoint is closed and when
    /// the driver is dropped
    incoming: Notify,
    /// Signalled when the last connection has been removed from the endpoint
    idle: Notify,
}
507
508impl State {
509    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
510        let get_time = || self.runtime.now();
511        self.recv_state.recv_limiter.start_cycle(get_time);
512        if let Some(socket) = &mut self.prev_socket {
513            // We don't care about the `PollProgress` from old sockets.
514            let poll_res = self.recv_state.poll_socket(
515                cx,
516                &mut self.inner,
517                &mut **socket,
518                &mut self.sender,
519                &*self.runtime,
520                now,
521            );
522            if poll_res.is_err() {
523                self.prev_socket = None;
524            }
525        };
526        let poll_res = self.recv_state.poll_socket(
527            cx,
528            &mut self.inner,
529            &mut *self.socket,
530            &mut self.sender,
531            &*self.runtime,
532            now,
533        );
534        self.recv_state.recv_limiter.finish_cycle(get_time);
535        let poll_res = poll_res?;
536        if poll_res.received_connection_packet {
537            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
538            // one anymore. TODO: Account for multiple outgoing connections.
539            self.prev_socket = None;
540        }
541        Ok(poll_res.keep_going)
542    }
543
544    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
545        for _ in 0..IO_LOOP_BOUND {
546            let (ch, event) = match self.events.poll_recv(cx) {
547                Poll::Ready(Some(x)) => x,
548                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
549                Poll::Pending => {
550                    return false;
551                }
552            };
553
554            if event.is_drained() {
555                self.recv_state.connections.senders.remove(&ch);
556                if self.recv_state.connections.is_empty() {
557                    shared.idle.notify_waiters();
558                }
559            }
560            let Some(event) = self.inner.handle_event(ch, event) else {
561                continue;
562            };
563            // Ignoring errors from dropped connections that haven't yet been cleaned up
564            let _ = self
565                .recv_state
566                .connections
567                .senders
568                .get_mut(&ch)
569                .unwrap()
570                .send(ConnectionEvent::Proto(event));
571        }
572
573        true
574    }
575}
576
577impl Drop for State {
578    fn drop(&mut self) {
579        for incoming in self.recv_state.incoming.drain(..) {
580            self.inner.ignore(incoming);
581        }
582    }
583}
584
585fn respond(
586    transmit: proto::Transmit,
587    response_buffer: &[u8],
588    sender: &mut Pin<Box<dyn UdpSender>>,
589) {
590    // Send if there's kernel buffer space; otherwise, drop it
591    //
592    // As an endpoint-generated packet, we know this is an
593    // immediate, stateless response to an unconnected peer,
594    // one of:
595    //
596    // - A version negotiation response due to an unknown version
597    // - A `CLOSE` due to a malformed or unwanted connection attempt
598    // - A stateless reset due to an unrecognized connection
599    // - A `Retry` packet due to a connection attempt when
600    //   `use_retry` is set
601    //
602    // In each case, a well-behaved peer can be trusted to retry a
603    // few times, which is guaranteed to produce the same response
604    // from us. Repeated failures might at worst cause a peer's new
605    // connection attempt to time out, which is acceptable if we're
606    // under such heavy load that there's never room for this code
607    // to transmit. This is morally equivalent to the packet getting
608    // lost due to congestion further along the link, which
609    // similarly relies on peer retries for recovery.
610
611    // Copied from rust 1.85's std::task::Waker::noop() implementation for backwards compatibility
612    const NOOP: RawWaker = {
613        const VTABLE: RawWakerVTable = RawWakerVTable::new(
614            // Cloning just returns a new no-op raw waker
615            |_| NOOP,
616            // `wake` does nothing
617            |_| {},
618            // `wake_by_ref` does nothing
619            |_| {},
620            // Dropping does nothing as we don't allocate anything
621            |_| {},
622        );
623        RawWaker::new(std::ptr::null(), &VTABLE)
624    };
625    // SAFETY: Copied from rust stdlib, the NOOP waker is thread-safe and doesn't violate the RawWakerVTable contract,
626    // it doesn't access the data pointer at all.
627    let waker = unsafe { Waker::from_raw(NOOP) };
628    let mut cx = Context::from_waker(&waker);
629    _ = sender.as_mut().poll_send(
630        &udp_transmit(&transmit, &response_buffer[..transmit.size]),
631        &mut cx,
632    );
633}
634
635#[inline]
636fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
637    match ecn {
638        udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
639        udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
640        udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
641    }
642}
643
/// Bookkeeping for the channels between the endpoint and its connections
#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed; holds the close code and reason that are
    /// forwarded to every connection
    close: Option<(VarInt, Bytes)>,
}
653
654impl ConnectionSet {
655    fn insert(
656        &mut self,
657        handle: ConnectionHandle,
658        conn: proto::Connection,
659        sender: Pin<Box<dyn UdpSender>>,
660        runtime: Arc<dyn Runtime>,
661    ) -> Connecting {
662        let (send, recv) = mpsc::unbounded_channel();
663        if let Some((error_code, ref reason)) = self.close {
664            send.send(ConnectionEvent::Close {
665                error_code,
666                reason: reason.clone(),
667            })
668            .unwrap();
669        }
670        self.senders.insert(handle, send);
671        Connecting::new(handle, conn, self.sender.clone(), recv, sender, runtime)
672    }
673
674    fn is_empty(&self) -> bool {
675        self.senders.is_empty()
676    }
677}
678
/// Express `x` as an IPv6 socket address.
///
/// IPv6 addresses are returned unchanged. IPv4 addresses become IPv4-mapped IPv6 addresses with
/// the same port and with flow info and scope id set to zero.
pub(crate) fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(v6) => return v6,
        SocketAddr::V4(v4) => v4,
    };
    SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0)
}
685
pin_project! {
    /// Future produced by [`Endpoint::accept`]
    ///
    /// Resolves to the next queued incoming connection, or to `None` once the endpoint has
    /// been closed or its driver is gone.
    pub struct Accept<'a> {
        // Endpoint whose queue of incoming connections is polled
        endpoint: &'a Endpoint,
        // Pending notification for the endpoint's `incoming` signal; replaced after each wake-up
        #[pin]
        notify: Notified<'a>,
    }
}
694
695impl Future for Accept<'_> {
696    type Output = Option<Incoming>;
697    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
698        let mut this = self.project();
699        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
700        if endpoint.driver_lost {
701            return Poll::Ready(None);
702        }
703        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
704            // Release the mutex lock on endpoint so cloning it doesn't deadlock
705            drop(endpoint);
706            let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
707            return Poll::Ready(Some(incoming));
708        }
709        if endpoint.recv_state.connections.close.is_some() {
710            return Poll::Ready(None);
711        }
712        loop {
713            match this.notify.as_mut().poll(ctx) {
714                // `state` lock ensures we didn't race with readiness
715                Poll::Pending => return Poll::Pending,
716                // Spurious wakeup, get a new future
717                Poll::Ready(()) => this
718                    .notify
719                    .set(this.endpoint.inner.shared.incoming.notified()),
720            }
721        }
722    }
723}
724
/// Counted handle to the shared endpoint state
///
/// Cloning increments `State::ref_count` and dropping decrements it; when the count reaches
/// zero the driver is woken so that it can shut down.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
727
728impl EndpointRef {
729    pub(crate) fn new(
730        socket: Box<dyn AsyncUdpSocket>,
731        inner: proto::Endpoint,
732        ipv6: bool,
733        runtime: Arc<dyn Runtime>,
734    ) -> Self {
735        let (sender, events) = mpsc::unbounded_channel();
736        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
737        let sender = socket.create_sender();
738        Self(Arc::new(EndpointInner {
739            shared: Shared {
740                incoming: Notify::new(),
741                idle: Notify::new(),
742            },
743            state: Mutex::new(State {
744                socket,
745                sender,
746                prev_socket: None,
747                inner,
748                ipv6,
749                events,
750                driver: None,
751                ref_count: 0,
752                driver_lost: false,
753                recv_state,
754                runtime,
755                stats: EndpointStats::default(),
756            }),
757        }))
758    }
759}
760
761impl Clone for EndpointRef {
762    fn clone(&self) -> Self {
763        self.0.state.lock().unwrap().ref_count += 1;
764        Self(self.0.clone())
765    }
766}
767
768impl Drop for EndpointRef {
769    fn drop(&mut self) {
770        let endpoint = &mut *self.0.state.lock().unwrap();
771        if let Some(x) = endpoint.ref_count.checked_sub(1) {
772            endpoint.ref_count = x;
773            if x == 0 {
774                // If the driver is about to be on its own, ensure it can shut down if the last
775                // connection is gone.
776                if let Some(task) = endpoint.driver.take() {
777                    task.wake();
778                }
779            }
780        }
781    }
782}
783
784impl std::ops::Deref for EndpointRef {
785    type Target = EndpointInner;
786    fn deref(&self) -> &Self::Target {
787        &self.0
788    }
789}
790
/// State directly involved in handling incoming packets
struct RecvState {
    /// Incoming connections that have been received but not yet handed out by `accept`
    incoming: VecDeque<proto::Incoming>,
    /// Channels to the endpoint's connections, plus the endpoint's close state
    connections: ConnectionSet,
    /// Receive buffer; split into `BATCH_SIZE` equally sized chunks when polling a socket
    recv_buf: Box<[u8]>,
    /// Limits the amount of receive work done per driver cycle
    recv_limiter: WorkLimiter,
}
798
799impl RecvState {
800    fn new(
801        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
802        max_receive_segments: NonZeroUsize,
803        endpoint: &proto::Endpoint,
804    ) -> Self {
805        let recv_buf = vec![
806            0;
807            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
808                * max_receive_segments.get()
809                * BATCH_SIZE
810        ];
811        Self {
812            connections: ConnectionSet {
813                senders: FxHashMap::default(),
814                sender,
815                close: None,
816            },
817            incoming: VecDeque::new(),
818            recv_buf: recv_buf.into(),
819            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
820        }
821    }
822
823    fn poll_socket(
824        &mut self,
825        cx: &mut Context,
826        endpoint: &mut proto::Endpoint,
827        socket: &mut dyn AsyncUdpSocket,
828        sender: &mut Pin<Box<dyn UdpSender>>,
829        runtime: &dyn Runtime,
830        now: Instant,
831    ) -> Result<PollProgress, io::Error> {
832        let mut received_connection_packet = false;
833        let mut metas = [RecvMeta::default(); BATCH_SIZE];
834        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
835            let mut bufs = self
836                .recv_buf
837                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
838                .map(IoSliceMut::new);
839
840            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
841            // and iovs will be of size BATCH_SIZE, thus from_fn is called
842            // exactly BATCH_SIZE times.
843            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
844        };
845        loop {
846            match socket.poll_recv(cx, &mut iovs, &mut metas) {
847                Poll::Ready(Ok(msgs)) => {
848                    self.recv_limiter.record_work(msgs);
849                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
850                        let mut data: BytesMut = buf[0..meta.len].into();
851                        while !data.is_empty() {
852                            let buf = data.split_to(meta.stride.min(data.len()));
853                            let mut response_buffer = Vec::new();
854                            match endpoint.handle(
855                                now,
856                                meta.addr,
857                                meta.dst_ip,
858                                meta.ecn.map(proto_ecn),
859                                buf,
860                                &mut response_buffer,
861                            ) {
862                                Some(DatagramEvent::NewConnection(incoming)) => {
863                                    if self.connections.close.is_none() {
864                                        self.incoming.push_back(incoming);
865                                    } else {
866                                        let transmit =
867                                            endpoint.refuse(incoming, &mut response_buffer);
868                                        respond(transmit, &response_buffer, sender);
869                                    }
870                                }
871                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
872                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
873                                    received_connection_packet = true;
874                                    let _ = self
875                                        .connections
876                                        .senders
877                                        .get_mut(&handle)
878                                        .unwrap()
879                                        .send(ConnectionEvent::Proto(event));
880                                }
881                                Some(DatagramEvent::Response(transmit)) => {
882                                    respond(transmit, &response_buffer, sender);
883                                }
884                                None => {}
885                            }
886                        }
887                    }
888                }
889                Poll::Pending => {
890                    return Ok(PollProgress {
891                        received_connection_packet,
892                        keep_going: false,
893                    });
894                }
895                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
896                // attacker
897                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
898                    continue;
899                }
900                Poll::Ready(Err(e)) => {
901                    return Err(e);
902                }
903            }
904            if !self.recv_limiter.allow_work(|| runtime.now()) {
905                return Ok(PollProgress {
906                    received_connection_packet,
907                    keep_going: true,
908                });
909            }
910        }
911    }
912}
913
914impl fmt::Debug for RecvState {
915    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
916        f.debug_struct("RecvState")
917            .field("incoming", &self.incoming)
918            .field("connections", &self.connections)
919            // recv_buf too large
920            .field("recv_limiter", &self.recv_limiter)
921            .finish_non_exhaustive()
922    }
923}
924
/// Summary of a single `RecvState::poll_socket` call
#[derive(Default)]
struct PollProgress {
    /// Whether a datagram was routed to an existing connection
    received_connection_packet: bool,
    /// Whether datagram handling was interrupted early by the work limiter for fairness
    ///
    /// `false` when polling stopped because the socket reported `Poll::Pending`.
    keep_going: bool,
}