// iroh_quinn/endpoint.rs

1use std::{
2    collections::VecDeque,
3    fmt,
4    future::Future,
5    io::{self, IoSliceMut},
6    mem,
7    net::{SocketAddr, SocketAddrV6},
8    num::NonZeroUsize,
9    pin::Pin,
10    str,
11    sync::{Arc, Mutex},
12    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
13};
14
15#[cfg(all(
16    not(wasm_browser),
17    any(feature = "runtime-tokio", feature = "runtime-smol"),
18    any(feature = "aws-lc-rs", feature = "ring"),
19))]
20use crate::runtime::default_runtime;
21use crate::{
22    Instant,
23    runtime::{AsyncUdpSocket, Runtime, UdpSender},
24    udp_transmit,
25};
26use bytes::{Bytes, BytesMut};
27use pin_project_lite::pin_project;
28use proto::{
29    self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
30    EndpointEvent, FourTuple, ServerConfig,
31};
32use rustc_hash::FxHashMap;
33#[cfg(all(
34    not(wasm_browser),
35    any(feature = "runtime-tokio", feature = "runtime-smol"),
36    any(feature = "aws-lc-rs", feature = "ring"),
37))]
38use socket2::{Domain, Protocol, Socket, Type};
39use tokio::sync::{Notify, futures::Notified, mpsc};
40use tracing::{Instrument, Span};
41use udp::{BATCH_SIZE, RecvMeta};
42
43use crate::{
44    ConnectionEvent, EndpointConfig, IO_LOOP_BOUND, RECV_TIME_BOUND, VarInt,
45    connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter,
46};
47
/// A QUIC endpoint.
///
/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both
/// client and server for different connections.
///
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Debug, Clone)]
pub struct Endpoint {
    /// Shared, manually reference-counted endpoint state (see `EndpointRef`'s
    /// `Clone`/`Drop` impls, which maintain `State::ref_count`)
    pub(crate) inner: EndpointRef,
    /// Async runtime used for timestamps, spawning the driver, and wrapping sockets
    runtime: Arc<dyn Runtime>,
}
59
60impl Endpoint {
    /// Helper to construct an endpoint for use with outgoing connections only
    ///
    /// Note that `addr` is the *local* address to bind to, which should usually be a wildcard
    /// address like `0.0.0.0:0` or `[::]:0`, which allow communication with any reachable IPv4 or
    /// IPv6 address respectively from an OS-assigned port.
    ///
    /// If an IPv6 address is provided, attempts to make the socket dual-stack so as to allow
    /// communication with both IPv4 and IPv6 addresses. As such, calling `Endpoint::client` with
    /// the address `[::]:0` is a reasonable default to maximize the ability to connect to other
    /// address. For example:
    ///
    /// ```
    /// iroh_quinn::Endpoint::client((std::net::Ipv6Addr::UNSPECIFIED, 0).into());
    /// ```
    ///
    /// Some environments may not allow creation of dual-stack sockets, in which case an IPv6
    /// client will only be able to connect to IPv6 servers. An IPv4 client is never dual-stack.
    #[cfg(all(
        not(wasm_browser),
        any(feature = "runtime-tokio", feature = "runtime-smol"),
        any(feature = "aws-lc-rs", feature = "ring"), // `EndpointConfig::default()` is only available with these
    ))]
    pub fn client(addr: SocketAddr) -> io::Result<Self> {
        let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
        if addr.is_ipv6() {
            // Best-effort: clearing IPV6_V6ONLY makes the socket dual-stack. Some
            // platforms forbid this, so failure is non-fatal and only logged.
            if let Err(e) = socket.set_only_v6(false) {
                tracing::debug!(%e, "unable to make socket dual-stack");
            }
        }
        socket.bind(&addr.into())?;
        // Detect the enabled runtime (tokio/smol); error out if none is available.
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            None, // no server config: this endpoint only initiates connections
            runtime.wrap_udp_socket(socket.into())?,
            runtime,
        )
    }
100
101    /// Returns relevant stats from this Endpoint
102    pub fn stats(&self) -> EndpointStats {
103        self.inner.state.lock().unwrap().stats
104    }
105
    /// Helper to construct an endpoint for use with both incoming and outgoing connections
    ///
    /// Platform defaults for dual-stack sockets vary. For example, any socket bound to a wildcard
    /// IPv6 address on Windows will not by default be able to communicate with IPv4
    /// addresses. Portable applications should bind an address that matches the family they wish to
    /// communicate within.
    #[cfg(all(
        not(wasm_browser),
        any(feature = "runtime-tokio", feature = "runtime-smol"),
        any(feature = "aws-lc-rs", feature = "ring"), // `EndpointConfig::default()` is only available with these
    ))]
    pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
        // Unlike `client`, no dual-stack tweaking: a plain std socket is bound as-is.
        let socket = std::net::UdpSocket::bind(addr)?;
        // Detect the enabled runtime (tokio/smol); error out if none is available.
        let runtime =
            default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
        Self::new_with_abstract_socket(
            EndpointConfig::default(),
            Some(config),
            runtime.wrap_udp_socket(socket)?,
            runtime,
        )
    }
128
129    /// Construct an endpoint with arbitrary configuration and socket
130    #[cfg(not(wasm_browser))]
131    pub fn new(
132        config: EndpointConfig,
133        server_config: Option<ServerConfig>,
134        socket: std::net::UdpSocket,
135        runtime: Arc<dyn Runtime>,
136    ) -> io::Result<Self> {
137        let socket = runtime.wrap_udp_socket(socket)?;
138        Self::new_with_abstract_socket(config, server_config, socket, runtime)
139    }
140
141    /// Construct an endpoint with arbitrary configuration and pre-constructed abstract socket
142    ///
143    /// Useful when `socket` has additional state (e.g. sidechannels) attached for which shared
144    /// ownership is needed.
145    pub fn new_with_abstract_socket(
146        config: EndpointConfig,
147        server_config: Option<ServerConfig>,
148        socket: Box<dyn AsyncUdpSocket>,
149        runtime: Arc<dyn Runtime>,
150    ) -> io::Result<Self> {
151        let addr = socket.local_addr()?;
152        let allow_mtud = !socket.may_fragment();
153        let rc = EndpointRef::new(
154            socket,
155            proto::Endpoint::new(Arc::new(config), server_config.map(Arc::new), allow_mtud),
156            addr.is_ipv6(),
157            runtime.clone(),
158        );
159        let driver = EndpointDriver(rc.clone());
160        runtime.spawn(Box::pin(
161            async {
162                if let Err(e) = driver.await {
163                    tracing::error!("I/O error: {}", e);
164                }
165            }
166            .instrument(Span::current()),
167        ));
168        Ok(Self { inner: rc, runtime })
169    }
170
171    /// Get the next incoming connection attempt from a client
172    ///
173    /// Yields [`Incoming`]s, or `None` if the endpoint is [`close`](Self::close)d. [`Incoming`]
174    /// can be `await`ed to obtain the final [`Connection`](crate::Connection), or used to e.g.
175    /// filter connection attempts or force address validation, or converted into an intermediate
176    /// `Connecting` future which can be used to e.g. send 0.5-RTT data.
177    pub fn accept(&self) -> Accept<'_> {
178        Accept {
179            endpoint: self,
180            notify: self.inner.shared.incoming.notified(),
181        }
182    }
183
184    /// Set the client configuration used by `connect`
185    pub fn set_default_client_config(&self, config: ClientConfig) {
186        self.inner.0.state.lock().unwrap().default_client_config = Some(config);
187    }
188
189    /// Connect to a remote endpoint
190    ///
191    /// `server_name` must be covered by the certificate presented by the server. This prevents a
192    /// connection from being intercepted by an attacker with a valid certificate for some other
193    /// server.
194    ///
195    /// May fail immediately due to configuration errors, or in the future if the connection could
196    /// not be established.
197    pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
198        let Some(config) = self
199            .inner
200            .0
201            .state
202            .lock()
203            .unwrap()
204            .default_client_config
205            .clone()
206        else {
207            return Err(ConnectError::NoDefaultClientConfig);
208        };
209
210        self.connect_with(config, addr, server_name)
211    }
212
    /// Connect to a remote endpoint using a custom configuration.
    ///
    /// See [`connect()`] for details.
    ///
    /// [`connect()`]: Endpoint::connect
    pub fn connect_with(
        &self,
        config: ClientConfig,
        addr: SocketAddr,
        server_name: &str,
    ) -> Result<Connecting, ConnectError> {
        let mut endpoint = self.inner.state.lock().unwrap();
        // Refuse new connections once the driver is gone or `close()` was called
        if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
            return Err(ConnectError::EndpointStopping);
        }
        // An IPv4-bound socket cannot reach IPv6 peers at all
        if addr.is_ipv6() && !endpoint.ipv6 {
            return Err(ConnectError::InvalidRemoteAddress(addr));
        }
        // On an IPv6 socket, express IPv4 targets as IPv4-mapped IPv6 addresses
        let addr = if endpoint.ipv6 {
            SocketAddr::V6(ensure_ipv6(addr))
        } else {
            addr
        };

        // Create the protocol-level connection state
        let (ch, conn) = endpoint
            .inner
            .connect(self.runtime.now(), config, addr, server_name)?;

        let sender = endpoint.socket.create_sender();
        endpoint.stats.outgoing_handshakes += 1;
        // Register the connection with the endpoint and return its handshake future
        Ok(endpoint
            .recv_state
            .connections
            .insert(ch, conn, sender, self.runtime.clone()))
    }
248
249    /// Switch to a new UDP socket
250    ///
251    /// See [`Endpoint::rebind_abstract()`] for details.
252    #[cfg(not(wasm_browser))]
253    pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
254        self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
255    }
256
    /// Switch to a new UDP socket
    ///
    /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming
    /// connections and connections to servers unreachable from the new address will be lost.
    ///
    /// On error, the old UDP socket is retained.
    pub fn rebind_abstract(&self, socket: Box<dyn AsyncUdpSocket>) -> io::Result<()> {
        // Query the new address first so a failure leaves the endpoint untouched
        let addr = socket.local_addr()?;
        let mut inner = self.inner.state.lock().unwrap();
        // Keep the old socket alive as `prev_socket`: it is drained until the first
        // packet arrives on the new socket (see `State::drive_recv`)
        inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
        inner.ipv6 = addr.is_ipv6();

        // Update connection socket references
        for sender in inner.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Rebind(inner.socket.create_sender()));
        }
        if let Some(driver) = inner.driver.take() {
            // Ensure the driver can register for wake-ups from the new socket
            driver.wake();
        }

        Ok(())
    }
281
282    /// Replace the server configuration, affecting new incoming connections only
283    ///
284    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
285    pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
286        self.inner
287            .state
288            .lock()
289            .unwrap()
290            .inner
291            .set_server_config(server_config.map(Arc::new))
292    }
293
294    /// Get the local `SocketAddr` the underlying socket is bound to
295    pub fn local_addr(&self) -> io::Result<SocketAddr> {
296        self.inner.state.lock().unwrap().socket.local_addr()
297    }
298
299    /// Get the number of connections that are currently open
300    pub fn open_connections(&self) -> usize {
301        self.inner.state.lock().unwrap().inner.open_connections()
302    }
303
    /// Close all of this endpoint's connections immediately and cease accepting new connections.
    ///
    /// See [`Connection::close()`] for details.
    ///
    /// [`Connection::close()`]: crate::Connection::close
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let reason = Bytes::copy_from_slice(reason);
        let mut endpoint = self.inner.state.lock().unwrap();
        // Recording the close reason makes later `connect` calls fail with
        // `EndpointStopping`, and connections inserted afterwards are closed
        // immediately (see `ConnectionSet::insert`)
        endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
        for sender in endpoint.recv_state.connections.senders.values() {
            // Ignoring errors from dropped connections
            let _ = sender.send(ConnectionEvent::Close {
                error_code,
                reason: reason.clone(),
            });
        }
        // Wake pending `accept()` calls so they observe the close and return `None`
        self.inner.shared.incoming.notify_waiters();
    }
322
    /// Wait for all connections on the endpoint to be cleanly shut down
    ///
    /// Waiting for this condition before exiting ensures that a good-faith effort is made to notify
    /// peers of recent connection closes, whereas exiting immediately could force them to wait out
    /// the idle timeout period.
    ///
    /// Does not proactively close existing connections or cause incoming connections to be
    /// rejected. Consider calling [`close()`] if that is desired.
    ///
    /// [`close()`]: Endpoint::close
    pub async fn wait_idle(&self) {
        loop {
            {
                // The guard is confined to this block, so the std mutex is released
                // before the `.await` below.
                let endpoint = &mut *self.inner.state.lock().unwrap();
                if endpoint.recv_state.connections.is_empty() {
                    break;
                }
                // Construct future while lock is held to avoid race
                self.inner.shared.idle.notified()
            }
            .await;
        }
    }
346}
347
/// Statistics on [Endpoint] activity
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of QUIC handshakes accepted by this [Endpoint]
    pub accepted_handshakes: u64,
    /// Cumulative number of QUIC handshakes sent from this [Endpoint]
    pub outgoing_handshakes: u64,
    /// Cumulative number of QUIC handshakes refused on this [Endpoint]
    pub refused_handshakes: u64,
    /// Cumulative number of QUIC handshakes ignored on this [Endpoint]
    pub ignored_handshakes: u64,
}
361
/// A future that drives IO on an endpoint
///
/// This task functions as the switch point between the UDP socket object and the
/// `Endpoint` responsible for routing datagrams to their owning `Connection`.
/// In order to do so, it also facilitates the exchange of different types of events
/// flowing between the `Endpoint` and the tasks managing `Connection`s. As such,
/// running this task is necessary to keep the endpoint's connections running.
///
/// `EndpointDriver` futures terminate when all clones of the `Endpoint` have been dropped, or when
/// an I/O error occurs.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
// Newtype over `EndpointRef`: the driver holds its own handle to the shared state.
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
375
impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        // Store our waker so handle drops and rebinds can reschedule this task
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        // Pull datagrams from the socket(s) and route them; `?` propagates fatal
        // I/O errors, terminating the driver
        keep_going |= endpoint.drive_recv(cx, now)?;
        // Process events sent back from connection tasks
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        // Wake tasks blocked in `Endpoint::accept` for queued connection attempts
        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        // Terminate once all user handles are dropped and no connections remain
        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // If there is more work to do schedule the endpoint task again.
            // `wake_by_ref()` is called outside the lock to minimize
            // lock contention on a multithreaded runtime.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}
408
409impl Drop for EndpointDriver {
410    fn drop(&mut self) {
411        let mut endpoint = self.0.state.lock().unwrap();
412        endpoint.driver_lost = true;
413        self.0.shared.incoming.notify_waiters();
414        // Drop all outgoing channels, signaling the termination of the endpoint to the associated
415        // connections.
416        endpoint.recv_state.connections.senders.clear();
417    }
418}
419
/// State shared between all handles to an endpoint and its driver task
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state, guarded by a mutex
    pub(crate) state: Mutex<State>,
    /// Notification primitives usable without holding the `state` lock
    pub(crate) shared: Shared,
}
425
426impl EndpointInner {
    /// Drive `incoming` through handshake acceptance.
    ///
    /// On success, registers the new connection with the endpoint and returns its
    /// `Connecting` future. On failure, transmits any protocol-level response
    /// packet best-effort and returns the underlying cause.
    pub(crate) fn accept(
        &self,
        incoming: proto::Incoming,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<Connecting, ConnectionError> {
        let mut state = self.state.lock().unwrap();
        let mut response_buffer = Vec::new();
        let now = state.runtime.now();
        match state
            .inner
            .accept(incoming, now, &mut response_buffer, server_config)
        {
            Ok((handle, conn)) => {
                state.stats.accepted_handshakes += 1;
                let sender = state.socket.create_sender();
                let runtime = state.runtime.clone();
                Ok(state
                    .recv_state
                    .connections
                    .insert(handle, conn, sender, runtime))
            }
            Err(error) => {
                // Some rejections carry a packet to send back to the peer
                if let Some(transmit) = error.response {
                    respond(transmit, &response_buffer, &mut state.sender);
                }
                Err(error.cause)
            }
        }
    }
456
457    pub(crate) fn refuse(&self, incoming: proto::Incoming) {
458        let mut state = self.state.lock().unwrap();
459        state.stats.refused_handshakes += 1;
460        let mut response_buffer = Vec::new();
461        let transmit = state.inner.refuse(incoming, &mut response_buffer);
462        respond(transmit, &response_buffer, &mut state.sender);
463    }
464
465    pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
466        let mut state = self.state.lock().unwrap();
467        let mut response_buffer = Vec::new();
468        let transmit = state.inner.retry(incoming, &mut response_buffer)?;
469        respond(transmit, &response_buffer, &mut state.sender);
470        Ok(())
471    }
472
473    pub(crate) fn ignore(&self, incoming: proto::Incoming) {
474        let mut state = self.state.lock().unwrap();
475        state.stats.ignored_handshakes += 1;
476        state.inner.ignore(incoming);
477    }
478}
479
#[derive(Debug)]
pub(crate) struct State {
    /// The endpoint's current UDP socket
    socket: Box<dyn AsyncUdpSocket>,
    /// Sender used for endpoint-generated (stateless) response packets
    sender: Pin<Box<dyn UdpSender>>,
    /// During an active migration, abandoned_socket receives traffic
    /// until the first packet arrives on the new socket.
    prev_socket: Option<Box<dyn AsyncUdpSocket>>,
    /// Protocol-level endpoint state machine
    inner: proto::Endpoint,
    /// Receive buffers, queued incoming attempts, and connection channels
    recv_state: RecvState,
    /// Waker for the driver task, registered on its first poll
    driver: Option<Waker>,
    /// Whether the current socket is bound to an IPv6 address
    ipv6: bool,
    /// Events flowing from connection tasks back to the endpoint
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    /// Set when the driver task is dropped; `accept()` then yields `None`
    driver_lost: bool,
    /// Runtime used for timestamps
    runtime: Arc<dyn Runtime>,
    /// Cumulative endpoint statistics
    stats: EndpointStats,
    /// Config used by `Endpoint::connect` when no explicit config is given
    default_client_config: Option<ClientConfig>,
}
499
/// Notification primitives shared between the driver and `Endpoint` handles
#[derive(Debug)]
pub(crate) struct Shared {
    /// Signaled when incoming connection attempts are queued, or on close/driver loss
    incoming: Notify,
    /// Signaled when the last connection is removed, waking `wait_idle()`
    idle: Notify,
}
505
506impl State {
    /// Poll the endpoint's socket(s) for received datagrams and dispatch them.
    ///
    /// Returns the `keep_going` flag from the socket poll: `true` means more
    /// receive work may be pending and the driver should reschedule itself.
    fn drive_recv(&mut self, cx: &mut Context<'_>, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        // After a rebind, keep draining the abandoned socket as well
        if let Some(socket) = &mut self.prev_socket {
            // We don't care about the `PollProgress` from old sockets.
            let poll_res = self.recv_state.poll_socket(
                cx,
                &mut self.inner,
                &mut **socket,
                &mut self.sender,
                &*self.runtime,
                now,
            );
            // An I/O error on the abandoned socket just retires it early
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res = self.recv_state.poll_socket(
            cx,
            &mut self.inner,
            &mut *self.socket,
            &mut self.sender,
            &*self.runtime,
            now,
        );
        // Finish the work-limiter cycle before propagating any error
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Traffic has arrived on self.socket, therefore there is no need for the abandoned
            // one anymore. TODO: Account for multiple outgoing connections.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }
541
    /// Process events flowing from connection tasks back to the endpoint.
    ///
    /// Handles at most `IO_LOOP_BOUND` events per call so a busy channel cannot
    /// starve the rest of the driver; returns `true` if the bound was hit and
    /// more events may be pending, `false` once the channel is drained.
    fn handle_events(&mut self, cx: &mut Context<'_>, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                // A sender clone lives in `ConnectionSet::sender`, so the channel
                // can never close while this state exists
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                // Connection fully shut down: drop our side of its channel
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            // NOTE(review): this assumes `handle_event` yields no response for a
            // drained connection; otherwise the `unwrap()` below would panic after
            // the sender was removed above — confirm against `proto::Endpoint`.
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // Ignoring errors from dropped connections that haven't yet been cleaned up
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
573}
574
575impl Drop for State {
576    fn drop(&mut self) {
577        for incoming in self.recv_state.incoming.drain(..) {
578            self.inner.ignore(incoming);
579        }
580    }
581}
582
/// Transmit an endpoint-generated, stateless response packet best-effort.
fn respond(
    transmit: proto::Transmit,
    response_buffer: &[u8],
    sender: &mut Pin<Box<dyn UdpSender>>,
) {
    // Send if there's kernel buffer space; otherwise, drop it
    //
    // As an endpoint-generated packet, we know this is an
    // immediate, stateless response to an unconnected peer,
    // one of:
    //
    // - A version negotiation response due to an unknown version
    // - A `CLOSE` due to a malformed or unwanted connection attempt
    // - A stateless reset due to an unrecognized connection
    // - A `Retry` packet due to a connection attempt when
    //   `use_retry` is set
    //
    // In each case, a well-behaved peer can be trusted to retry a
    // few times, which is guaranteed to produce the same response
    // from us. Repeated failures might at worst cause a peer's new
    // connection attempt to time out, which is acceptable if we're
    // under such heavy load that there's never room for this code
    // to transmit. This is morally equivalent to the packet getting
    // lost due to congestion further along the link, which
    // similarly relies on peer retries for recovery.

    // Copied from rust 1.85's std::task::Waker::noop() implementation for backwards compatibility
    const NOOP: RawWaker = {
        const VTABLE: RawWakerVTable = RawWakerVTable::new(
            // Cloning just returns a new no-op raw waker
            |_| NOOP,
            // `wake` does nothing
            |_| {},
            // `wake_by_ref` does nothing
            |_| {},
            // Dropping does nothing as we don't allocate anything
            |_| {},
        );
        RawWaker::new(std::ptr::null(), &VTABLE)
    };
    // SAFETY: Copied from rust stdlib, the NOOP waker is thread-safe and doesn't violate the RawWakerVTable contract,
    // it doesn't access the data pointer at all.
    let waker = unsafe { Waker::from_raw(NOOP) };
    let mut cx = Context::from_waker(&waker);
    // Poll the send exactly once with the no-op waker: if it isn't immediately
    // possible, the packet is dropped (result intentionally ignored).
    _ = sender.as_mut().poll_send(
        &udp_transmit(&transmit, &response_buffer[..transmit.size]),
        &mut cx,
    );
}
632
633#[inline]
634fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
635    match ecn {
636        udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
637        udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
638        udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
639    }
640}
641
/// Tracks the endpoint's live connections and the channels used to reach them
#[derive(Debug)]
struct ConnectionSet {
    /// Senders for communicating with the endpoint's connections
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Stored to give out clones to new ConnectionInners
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Set if the endpoint has been manually closed
    close: Option<(VarInt, Bytes)>,
}
651
652impl ConnectionSet {
653    fn insert(
654        &mut self,
655        handle: ConnectionHandle,
656        conn: proto::Connection,
657        sender: Pin<Box<dyn UdpSender>>,
658        runtime: Arc<dyn Runtime>,
659    ) -> Connecting {
660        let (send, recv) = mpsc::unbounded_channel();
661        if let Some((error_code, ref reason)) = self.close {
662            send.send(ConnectionEvent::Close {
663                error_code,
664                reason: reason.clone(),
665            })
666            .unwrap();
667        }
668        self.senders.insert(handle, send);
669        Connecting::new(handle, conn, self.sender.clone(), recv, sender, runtime)
670    }
671
672    fn is_empty(&self) -> bool {
673        self.senders.is_empty()
674    }
675}
676
/// Normalize `x` to an IPv6 socket address.
///
/// IPv6 addresses pass through unchanged; IPv4 addresses are converted to their
/// IPv4-mapped IPv6 form (`::ffff:a.b.c.d`) with zero flowinfo and scope id.
pub(crate) fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    match x {
        SocketAddr::V4(v4) => SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0),
        SocketAddr::V6(v6) => v6,
    }
}
683
pin_project! {
    /// Future produced by [`Endpoint::accept`]
    pub struct Accept<'a> {
        // Endpoint handle: used to lock state and mint fresh `notified()` futures
        endpoint: &'a Endpoint,
        // Pinned notification future; re-armed on spurious wakeups in `poll`
        #[pin]
        notify: Notified<'a>,
    }
}
692
impl Future for Accept<'_> {
    type Output = Option<Incoming>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        // The driver task is gone; no further connections can ever arrive
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the mutex lock on endpoint so cloning it doesn't deadlock
            drop(endpoint);
            let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        // The endpoint was closed explicitly; nothing more will be accepted
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // `state` lock ensures we didn't race with readiness
                Poll::Pending => return Poll::Pending,
                // Spurious wakeup, get a new future
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}
722
/// Handle to the shared endpoint state
///
/// The number of live handles is tracked manually in `State::ref_count` by the
/// `Clone`/`Drop` impls below, rather than via `Arc::strong_count`.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
725
726impl EndpointRef {
727    pub(crate) fn new(
728        socket: Box<dyn AsyncUdpSocket>,
729        inner: proto::Endpoint,
730        ipv6: bool,
731        runtime: Arc<dyn Runtime>,
732    ) -> Self {
733        let (sender, events) = mpsc::unbounded_channel();
734        let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
735        let sender = socket.create_sender();
736        Self(Arc::new(EndpointInner {
737            shared: Shared {
738                incoming: Notify::new(),
739                idle: Notify::new(),
740            },
741            state: Mutex::new(State {
742                socket,
743                sender,
744                prev_socket: None,
745                inner,
746                ipv6,
747                events,
748                driver: None,
749                ref_count: 0,
750                driver_lost: false,
751                recv_state,
752                runtime,
753                stats: EndpointStats::default(),
754                default_client_config: None,
755            }),
756        }))
757    }
758}
759
760impl Clone for EndpointRef {
761    fn clone(&self) -> Self {
762        self.0.state.lock().unwrap().ref_count += 1;
763        Self(self.0.clone())
764    }
765}
766
767impl Drop for EndpointRef {
768    fn drop(&mut self) {
769        let endpoint = &mut *self.0.state.lock().unwrap();
770        if let Some(x) = endpoint.ref_count.checked_sub(1) {
771            endpoint.ref_count = x;
772            if x == 0 {
773                // If the driver is about to be on its own, ensure it can shut down if the last
774                // connection is gone.
775                if let Some(task) = endpoint.driver.take() {
776                    task.wake();
777                }
778            }
779        }
780    }
781}
782
impl std::ops::Deref for EndpointRef {
    type Target = EndpointInner;
    // Lets callers write `ref.state` / `ref.shared` without naming the tuple field
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
789
/// State directly involved in handling incoming packets
struct RecvState {
    /// Connection attempts not yet surfaced via `Endpoint::accept`
    incoming: VecDeque<proto::Incoming>,
    /// Live connections and their communication channels
    connections: ConnectionSet,
    /// Scratch buffer the socket fills with received datagrams
    recv_buf: Box<[u8]>,
    /// Bounds how much time each receive cycle may consume
    recv_limiter: WorkLimiter,
}
797
impl RecvState {
    /// Creates a new `RecvState`.
    ///
    /// The receive buffer is sized so a single `poll_recv` batch can hold
    /// `BATCH_SIZE` messages, each of up to `max_receive_segments` segments of
    /// the endpoint's maximum UDP payload size (capped at 64 KiB).
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: NonZeroUsize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments.get()
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Receives datagrams from `socket` and dispatches them through the proto
    /// `endpoint` until the socket would block, an unrecoverable I/O error
    /// occurs, or the work limiter decides this pass has run long enough.
    ///
    /// Returns a [`PollProgress`] reporting whether any datagram reached an
    /// existing connection and whether the caller should poll again soon
    /// (`keep_going`) because the pass was cut short for fairness.
    fn poll_socket(
        &mut self,
        cx: &mut Context<'_>,
        endpoint: &mut proto::Endpoint,
        socket: &mut dyn AsyncUdpSocket,
        sender: &mut Pin<Box<dyn UdpSender>>,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        // Split the receive buffer into BATCH_SIZE equal scratch areas, one
        // I/O slice per message slot of the batched receive.
        let mut iovs: [IoSliceMut<'_>; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // expect() safe as self.recv_buf is chunked into BATCH_SIZE items
            // and iovs will be of size BATCH_SIZE, thus from_fn is called
            // exactly BATCH_SIZE times.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    // Tell the limiter how many messages this batch handled so
                    // it can estimate per-item cost.
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // One received message may contain several coalesced
                        // UDP datagrams; `meta.stride` is the length of each
                        // segment, so peel them off one at a time.
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            let addresses = FourTuple {
                                remote: meta.addr,
                                local_ip: meta.dst_ip,
                            };
                            match endpoint.handle(
                                now,
                                addresses,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    // While the endpoint is closing (`close`
                                    // set), refuse new connections instead of
                                    // queueing them.
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, sender);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    // Ignoring errors from dropped connections that haven't yet been cleaned up
                                    received_connection_packet = true;
                                    // NOTE(review): the unwrap assumes a
                                    // senders entry exists for every handle the
                                    // proto endpoint hands back — presumably
                                    // entries are only removed once the proto
                                    // endpoint is done with the handle; confirm
                                    // against the cleanup path.
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, sender);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    // Socket drained; `poll_recv` has registered the waker in
                    // `cx`, so report that no further polling is needed now.
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Ignore ECONNRESET as it's undefined in QUIC and may be injected by an
                // attacker
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            // Yield back to the driver if this pass has used up its time
            // budget; `keep_going` asks the caller to schedule another pass.
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
915
916impl fmt::Debug for RecvState {
917    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
918        f.debug_struct("RecvState")
919            .field("incoming", &self.incoming)
920            .field("connections", &self.connections)
921            // recv_buf too large
922            .field("recv_limiter", &self.recv_limiter)
923            .finish_non_exhaustive()
924    }
925}
926
/// Outcome of a single `poll_socket` pass
#[derive(Default)]
struct PollProgress {
    /// Set when the work limiter cut the receive pass short for fairness,
    /// meaning the caller should schedule another pass soon
    keep_going: bool,
    /// Set when at least one datagram was dispatched to an existing connection
    received_connection_packet: bool,
}