noq_proto/
endpoint.rs

1use std::{
2    collections::{HashMap, hash_map},
3    convert::TryFrom,
4    fmt, mem,
5    net::{IpAddr, SocketAddr},
6    ops::{Index, IndexMut},
7    sync::Arc,
8};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use rand::{Rng, RngExt, SeedableRng, rngs::StdRng};
12use rustc_hash::FxHashMap;
13use slab::Slab;
14use thiserror::Error;
15use tracing::{debug, error, trace, warn};
16
17use crate::{
18    Duration, FourTuple, INITIAL_MTU, Instant, MAX_CID_SIZE, MIN_INITIAL_SIZE, PathId,
19    RESET_TOKEN_SIZE, ResetToken, Side, Transmit, TransportConfig, TransportError,
20    cid_generator::ConnectionIdGenerator,
21    coding::{BufMutExt, Decodable, Encodable, UnexpectedEnd},
22    config::{ClientConfig, EndpointConfig, ServerConfig},
23    connection::{Connection, ConnectionError, SideArgs},
24    crypto::{self, Keys, UnsupportedVersion},
25    frame,
26    packet::{
27        FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, PacketDecodeError,
28        PacketNumber, PartialDecode, ProtectedInitialHeader,
29    },
30    shared::{
31        ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
32        EndpointEvent, EndpointEventInner, IssuedCid,
33    },
34    token::{IncomingToken, InvalidRetryTokenError, Token, TokenPayload},
35    transport_parameters::{PreferredAddress, TransportParameters},
36};
37
/// The main entry point to the library
///
/// This object performs no I/O whatsoever. Instead, it consumes incoming packets and
/// connection-generated events via `handle` and `handle_event`.
pub struct Endpoint {
    /// Endpoint-wide randomness: version negotiation bytes, stateless reset padding, transport
    /// parameters, retry tokens and the per-connection RNG seeds
    rng: StdRng,
    /// Routing tables (CIDs, reset tokens, pending initials) used to map datagrams to a
    /// connection or to a pending incoming connection attempt
    index: ConnectionIndex,
    /// Endpoint-side metadata of live connections, keyed by `ConnectionHandle`
    connections: Slab<ConnectionMeta>,
    /// Generates (and validates) the connection IDs this endpoint issues
    local_cid_generator: Box<dyn ConnectionIdGenerator>,
    config: Arc<EndpointConfig>,
    /// Server configuration; when `None`, incoming connection attempts are not accepted
    server_config: Option<Arc<ServerConfig>>,
    /// Whether the underlying UDP socket promises not to fragment packets
    allow_mtud: bool,
    /// Time at which a stateless reset was most recently sent
    last_stateless_reset: Option<Instant>,
    /// Buffered Initial and 0-RTT messages for pending incoming connections
    incoming_buffers: Slab<IncomingBuffer>,
    /// Sum of the bytes held across all `incoming_buffers`, capped by the server config's
    /// `incoming_buffer_size_total`
    all_incoming_buffers_total_bytes: u64,
}
57
58impl Endpoint {
59    /// Create a new endpoint
60    ///
61    /// `allow_mtud` enables path MTU detection when requested by `Connection` configuration for
62    /// better performance. This requires that outgoing packets are never fragmented, which can be
63    /// achieved via e.g. the `IPV6_DONTFRAG` socket option.
64    pub fn new(
65        config: Arc<EndpointConfig>,
66        server_config: Option<Arc<ServerConfig>>,
67        allow_mtud: bool,
68    ) -> Self {
69        Self {
70            rng: config
71                .rng_seed
72                .map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::from_seed),
73            index: ConnectionIndex::default(),
74            connections: Slab::new(),
75            local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
76            config,
77            server_config,
78            allow_mtud,
79            last_stateless_reset: None,
80            incoming_buffers: Slab::new(),
81            all_incoming_buffers_total_bytes: 0,
82        }
83    }
84
85    /// Replace the server configuration, affecting new incoming connections only
86    pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
87        self.server_config = server_config;
88    }
89
90    /// Process `EndpointEvent`s emitted from related `Connection`s
91    ///
92    /// In turn, processing this event may return a `ConnectionEvent` for the same `Connection`.
93    pub fn handle_event(
94        &mut self,
95        ch: ConnectionHandle,
96        event: EndpointEvent,
97    ) -> Option<ConnectionEvent> {
98        use EndpointEventInner::*;
99        match event.0 {
100            NeedIdentifiers(path_id, now, n) => {
101                return Some(self.send_new_identifiers(path_id, now, ch, n));
102            }
103            ResetToken(path_id, remote, token) => {
104                if let Some(old) = self.connections[ch]
105                    .reset_token
106                    .insert(path_id, (remote, token))
107                {
108                    self.index.connection_reset_tokens.remove(old.0, old.1);
109                }
110                if self.index.connection_reset_tokens.insert(remote, token, ch) {
111                    warn!("duplicate reset token");
112                }
113            }
114            RetireResetToken(path_id) => {
115                if let Some(old) = self.connections[ch].reset_token.remove(&path_id) {
116                    self.index.connection_reset_tokens.remove(old.0, old.1);
117                }
118            }
119            RetireConnectionId(now, path_id, seq, allow_more_cids) => {
120                if let Some(cid) = self.connections[ch]
121                    .local_cids
122                    .get_mut(&path_id)
123                    .and_then(|pcid| pcid.cids.remove(&seq))
124                {
125                    trace!(%path_id, "local CID retired {}: {}", seq, cid);
126                    self.index.retire(cid);
127                    if allow_more_cids {
128                        return Some(self.send_new_identifiers(path_id, now, ch, 1));
129                    }
130                }
131            }
132            Draining => {
133                // Nothing to do.
134            }
135            Drained => {
136                if let Some(conn) = self.connections.try_remove(ch.0) {
137                    self.index.remove(&conn);
138                } else {
139                    // This indicates a bug in downstream code, which could cause spurious
140                    // connection loss instead of this error if the CID was (re)allocated prior to
141                    // the illegal call.
142                    error!(id = ch.0, "unknown connection drained");
143                }
144            }
145        }
146        None
147    }
148
149    /// Process an incoming UDP datagram
150    pub fn handle(
151        &mut self,
152        now: Instant,
153        network_path: FourTuple,
154        ecn: Option<EcnCodepoint>,
155        data: BytesMut,
156        buf: &mut Vec<u8>,
157    ) -> Option<DatagramEvent> {
158        // Partially decode packet or short-circuit if unable
159        let datagram_len = data.len();
160        let mut event = match PartialDecode::new(
161            data,
162            &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
163            &self.config.supported_versions,
164            self.config.grease_quic_bit,
165        ) {
166            Ok((first_decode, remaining)) => DatagramConnectionEvent {
167                now,
168                network_path,
169                path_id: PathId::ZERO, // Corrected later for existing paths
170                ecn,
171                first_decode,
172                remaining,
173            },
174            Err(PacketDecodeError::UnsupportedVersion {
175                src_cid,
176                dst_cid,
177                version,
178            }) => {
179                if self.server_config.is_none() {
180                    debug!("dropping packet with unsupported version");
181                    return None;
182                }
183                trace!("sending version negotiation");
184                // Negotiate versions
185                Header::VersionNegotiate {
186                    random: self.rng.random::<u8>() | 0x40,
187                    src_cid: dst_cid,
188                    dst_cid: src_cid,
189                }
190                .encode(buf);
191                // Grease with a reserved version
192                buf.write::<u32>(match version {
193                    0x0a1a_2a3a => 0x0a1a_2a4a,
194                    _ => 0x0a1a_2a3a,
195                });
196                for &version in &self.config.supported_versions {
197                    buf.write(version);
198                }
199                return Some(DatagramEvent::Response(Transmit {
200                    destination: network_path.remote,
201                    ecn: None,
202                    size: buf.len(),
203                    segment_size: None,
204                    src_ip: network_path.local_ip,
205                }));
206            }
207            Err(e) => {
208                trace!("malformed header: {}", e);
209                return None;
210            }
211        };
212
213        let dst_cid = event.first_decode.dst_cid();
214
215        if let Some(route_to) = self.index.get(&network_path, &event.first_decode) {
216            event.path_id = match route_to {
217                RouteDatagramTo::Incoming(_) => PathId::ZERO,
218                RouteDatagramTo::Connection(_, path_id) => path_id,
219            };
220            match route_to {
221                RouteDatagramTo::Incoming(incoming_idx) => {
222                    let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
223                    let config = &self.server_config.as_ref().unwrap();
224
225                    if incoming_buffer
226                        .total_bytes
227                        .checked_add(datagram_len as u64)
228                        .is_some_and(|n| n <= config.incoming_buffer_size)
229                        && self
230                            .all_incoming_buffers_total_bytes
231                            .checked_add(datagram_len as u64)
232                            .is_some_and(|n| n <= config.incoming_buffer_size_total)
233                    {
234                        incoming_buffer.datagrams.push(event);
235                        incoming_buffer.total_bytes += datagram_len as u64;
236                        self.all_incoming_buffers_total_bytes += datagram_len as u64;
237                    }
238
239                    None
240                }
241                RouteDatagramTo::Connection(ch, _path_id) => Some(DatagramEvent::ConnectionEvent(
242                    ch,
243                    ConnectionEvent(ConnectionEventInner::Datagram(event)),
244                )),
245            }
246        } else if event.first_decode.initial_header().is_some() {
247            // Potentially create a new connection
248
249            self.handle_first_packet(datagram_len, event, network_path, buf)
250        } else if event.first_decode.has_long_header() {
251            debug!(
252                "ignoring non-initial packet for unknown connection {}",
253                dst_cid
254            );
255            None
256        } else if !event.first_decode.is_initial()
257            && self.local_cid_generator.validate(dst_cid).is_err()
258        {
259            debug!("dropping packet with invalid CID");
260            None
261        } else if dst_cid.is_empty() {
262            trace!("dropping unrecognized short packet without ID");
263            None
264        } else {
265            // If we got this far, we're receiving a seemingly valid packet for an unknown
266            // connection. Send a stateless reset if possible.
267            self.stateless_reset(now, datagram_len, network_path, dst_cid, buf)
268                .map(DatagramEvent::Response)
269        }
270    }
271
272    /// Builds a stateless reset packet to respond with
273    fn stateless_reset(
274        &mut self,
275        now: Instant,
276        inciting_dgram_len: usize,
277        network_path: FourTuple,
278        dst_cid: ConnectionId,
279        buf: &mut Vec<u8>,
280    ) -> Option<Transmit> {
281        if self
282            .last_stateless_reset
283            .is_some_and(|last| last + self.config.min_reset_interval > now)
284        {
285            debug!("ignoring unexpected packet within minimum stateless reset interval");
286            return None;
287        }
288
289        /// Minimum amount of padding for the stateless reset to look like a short-header packet
290        const MIN_PADDING_LEN: usize = 5;
291
292        // Prevent amplification attacks and reset loops by ensuring we pad to at most 1 byte
293        // smaller than the inciting packet.
294        let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) {
295            Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1,
296            _ => {
297                debug!(
298                    "ignoring unexpected {} byte packet: not larger than minimum stateless reset size",
299                    inciting_dgram_len
300                );
301                return None;
302            }
303        };
304
305        debug!(%dst_cid, %network_path.remote, "sending stateless reset");
306        self.last_stateless_reset = Some(now);
307        // Resets with at least this much padding can't possibly be distinguished from real packets
308        const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;
309        let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN {
310            max_padding_len
311        } else {
312            self.rng
313                .random_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
314        };
315        buf.reserve(padding_len + RESET_TOKEN_SIZE);
316        buf.resize(padding_len, 0);
317        self.rng.fill_bytes(&mut buf[0..padding_len]);
318        buf[0] = 0b0100_0000 | (buf[0] >> 2);
319        buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));
320
321        debug_assert!(buf.len() < inciting_dgram_len);
322
323        Some(Transmit {
324            destination: network_path.remote,
325            ecn: None,
326            size: buf.len(),
327            segment_size: None,
328            src_ip: network_path.local_ip,
329        })
330    }
331
332    /// Initiate a connection
333    pub fn connect(
334        &mut self,
335        now: Instant,
336        config: ClientConfig,
337        remote: SocketAddr,
338        server_name: &str,
339    ) -> Result<(ConnectionHandle, Connection), ConnectError> {
340        if self.cids_exhausted() {
341            return Err(ConnectError::CidsExhausted);
342        }
343        if remote.port() == 0 || remote.ip().is_unspecified() {
344            return Err(ConnectError::InvalidRemoteAddress(remote));
345        }
346        if !self.config.supported_versions.contains(&config.version) {
347            return Err(ConnectError::UnsupportedVersion);
348        }
349
350        let remote_id = (config.initial_dst_cid_provider)();
351        trace!(initial_dcid = %remote_id);
352
353        let ch = ConnectionHandle(self.connections.vacant_key());
354        let local_cid = self.new_cid(ch, PathId::ZERO);
355        let params = TransportParameters::new(
356            &config.transport,
357            &self.config,
358            self.local_cid_generator.as_ref(),
359            local_cid,
360            None,
361            &mut self.rng,
362        );
363        let tls = config
364            .crypto
365            .start_session(config.version, server_name, &params)?;
366
367        let conn = self.add_connection(
368            ch,
369            config.version,
370            remote_id,
371            local_cid,
372            remote_id,
373            FourTuple {
374                remote,
375                local_ip: None,
376            },
377            now,
378            tls,
379            config.transport,
380            SideArgs::Client {
381                token_store: config.token_store,
382                server_name: server_name.into(),
383            },
384            &params,
385        );
386        Ok((ch, conn))
387    }
388
389    /// Generates new CIDs and creates message to send to the connection state
390    fn send_new_identifiers(
391        &mut self,
392        path_id: PathId,
393        now: Instant,
394        ch: ConnectionHandle,
395        num: u64,
396    ) -> ConnectionEvent {
397        let mut ids = vec![];
398        for _ in 0..num {
399            let id = self.new_cid(ch, path_id);
400            let cid_meta = self.connections[ch].local_cids.entry(path_id).or_default();
401            let sequence = cid_meta.issued;
402            cid_meta.issued += 1;
403            cid_meta.cids.insert(sequence, id);
404            ids.push(IssuedCid {
405                path_id,
406                sequence,
407                id,
408                reset_token: ResetToken::new(&*self.config.reset_key, id),
409            });
410        }
411        ConnectionEvent(ConnectionEventInner::NewIdentifiers(
412            ids,
413            now,
414            self.local_cid_generator.cid_len(),
415            self.local_cid_generator.cid_lifetime(),
416        ))
417    }
418
419    /// Generate a connection ID for `ch`
420    fn new_cid(&mut self, ch: ConnectionHandle, path_id: PathId) -> ConnectionId {
421        loop {
422            let cid = self.local_cid_generator.generate_cid();
423            if cid.is_empty() {
424                // Zero-length CID; nothing to track
425                debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
426                return cid;
427            }
428            if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
429                e.insert((ch, path_id));
430                break cid;
431            }
432        }
433    }
434
    /// Handle an Initial packet for a destination CID the endpoint does not know yet
    ///
    /// Without a server config the endpoint cannot accept connections, so the packet is
    /// answered with a stateless reset when possible.
    ///
    /// Otherwise the packet goes through these steps:
    /// 1. reject datagrams shorter than `MIN_INITIAL_SIZE`;
    /// 2. derive the Initial keys (unsupported crypto versions are dropped);
    /// 3. apply the early refusal rules;
    /// 4. finish decoding the header and check the reserved bits;
    /// 5. validate the token.
    ///
    /// If all steps pass, a buffer for follow-up datagrams is allocated and an [`Incoming`] is
    /// returned. The application then accepts, refuses, retries or ignores it.
    ///
    /// On failure the packet is either dropped (`None`) or answered with an Initial close
    /// written into `buf`.
    ///
    /// The caller must pass an event whose first decode has an Initial header; it is unwrapped
    /// below.
    fn handle_first_packet(
        &mut self,
        datagram_len: usize,
        event: DatagramConnectionEvent,
        network_path: FourTuple,
        buf: &mut Vec<u8>,
    ) -> Option<DatagramEvent> {
        let dst_cid = event.first_decode.dst_cid();
        let header = event.first_decode.initial_header().unwrap();

        let Some(server_config) = &self.server_config else {
            debug!("packet for unrecognized connection {}", dst_cid);
            return self
                .stateless_reset(event.now, datagram_len, network_path, dst_cid, buf)
                .map(DatagramEvent::Response);
        };

        if datagram_len < MIN_INITIAL_SIZE as usize {
            debug!("ignoring short initial for connection {}", dst_cid);
            return None;
        }

        // Initial keys are derived from the packet's version and the client-chosen DCID
        let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
            Ok(keys) => keys,
            Err(UnsupportedVersion) => {
                // This probably indicates that the user set supported_versions incorrectly in
                // `EndpointConfig`.
                debug!(
                    "ignoring initial packet version {:#x} unsupported by cryptographic layer",
                    header.version
                );
                return None;
            }
        };

        // Refusals decided from the protected header alone are answered with an Initial close
        if let Err(reason) = self.early_validate_first_packet(header) {
            return Some(DatagramEvent::Response(self.initial_close(
                header.version,
                network_path,
                &crypto,
                header.src_cid,
                reason,
                buf,
            )));
        }

        // Finish decoding the header using the peer's Initial header key
        let packet = match event.first_decode.finish(Some(&*crypto.header.remote)) {
            Ok(packet) => packet,
            Err(e) => {
                trace!("unable to decode initial packet: {}", e);
                return None;
            }
        };

        if !packet.reserved_bits_valid() {
            debug!("dropping connection attempt with invalid reserved bits");
            return None;
        }

        let Header::Initial(header) = packet.header else {
            panic!("non-initial packet in handle_first_packet()");
        };

        // Take an owned handle to the config: the earlier borrow of `self.server_config` cannot
        // stay alive across the `&mut self` calls made since then.
        let server_config = self.server_config.as_ref().unwrap().clone();

        let token = match IncomingToken::from_header(&header, &server_config, network_path.remote) {
            Ok(token) => token,
            Err(InvalidRetryTokenError) => {
                debug!("rejecting invalid retry token");
                return Some(DatagramEvent::Response(self.initial_close(
                    header.version,
                    network_path,
                    &crypto,
                    header.src_cid,
                    TransportError::INVALID_TOKEN(""),
                    buf,
                )));
            }
        };

        // From now on, datagrams with this DCID are routed to this buffer (see `handle`) until
        // the `Incoming` is accepted, refused, retried or ignored.
        let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
        self.index
            .insert_initial_incoming(header.dst_cid, incoming_idx);

        Some(DatagramEvent::NewConnection(Incoming {
            received_at: event.now,
            network_path,
            ecn: event.ecn,
            packet: InitialPacket {
                header,
                header_data: packet.header_data,
                payload: packet.payload,
            },
            rest: event.remaining,
            crypto,
            token,
            incoming_idx,
            improper_drop_warner: IncomingImproperDropWarner,
        }))
    }
535
    /// Attempt to accept this incoming connection (an error may still occur)
    ///
    /// `server_config` overrides the endpoint's current server configuration for this connection
    /// only. When it is `None`, the endpoint's configuration is used.
    ///
    /// The pending incoming buffer is released up front, so it is freed on every path. On
    /// success, the datagrams buffered for the attempt are fed to the new connection before it
    /// is returned.
    ///
    /// # Errors
    ///
    /// Fails in these cases, sometimes with an Initial close written into `buf` as the response:
    /// - the Initial waited at least the configured `max_idle_timeout`
    ///   ([`ConnectionError::TimedOut`], no response);
    /// - the endpoint has run out of connection IDs ([`ConnectionError::CidsExhausted`],
    ///   answered with `CONNECTION_REFUSED`);
    /// - the Initial packet fails to decrypt (`PROTOCOL_VIOLATION`, no response);
    /// - the new connection rejects the first packet (a response is produced only when the
    ///   error is a transport error).
    ///
    /// # Panics
    ///
    /// If `server_config` is `None` and the endpoint has no server configuration either.
    // box err to avoid clippy::result_large_err
    pub fn accept(
        &mut self,
        mut incoming: Incoming,
        now: Instant,
        buf: &mut Vec<u8>,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<(ConnectionHandle, Connection), Box<AcceptError>> {
        let remote_address_validated = incoming.remote_address_validated();
        incoming.improper_drop_warner.dismiss();
        // Released unconditionally; on success its datagrams are replayed below
        let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
        self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;

        let packet_number = incoming.packet.header.number.expand(0);
        let InitialHeader {
            src_cid,
            dst_cid,
            version,
            ..
        } = incoming.packet.header;
        let server_config =
            server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());

        // The client has presumably given up on an Initial that waited a whole idle timeout
        if server_config
            .transport
            .max_idle_timeout
            .is_some_and(|timeout| {
                incoming.received_at + Duration::from_millis(timeout.into()) <= now
            })
        {
            debug!("abandoning accept of stale initial");
            self.index.remove_initial(dst_cid);
            return Err(Box::new(AcceptError {
                cause: ConnectionError::TimedOut,
                response: None,
            }));
        }

        if self.cids_exhausted() {
            debug!("refusing connection");
            self.index.remove_initial(dst_cid);
            return Err(Box::new(AcceptError {
                cause: ConnectionError::CidsExhausted,
                response: Some(self.initial_close(
                    version,
                    incoming.network_path,
                    &incoming.crypto,
                    src_cid,
                    TransportError::CONNECTION_REFUSED(""),
                    buf,
                )),
            }));
        }

        // Decrypt the Initial payload in place; failure means it was not authentic
        if incoming
            .crypto
            .packet
            .remote
            .decrypt(
                PathId::ZERO,
                packet_number,
                &incoming.packet.header_data,
                &mut incoming.packet.payload,
            )
            .is_err()
        {
            debug!(packet_number, "failed to authenticate initial packet");
            self.index.remove_initial(dst_cid);
            return Err(Box::new(AcceptError {
                cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
                response: None,
            }));
        };

        let ch = ConnectionHandle(self.connections.vacant_key());
        let local_cid = self.new_cid(ch, PathId::ZERO);
        let mut params = TransportParameters::new(
            &server_config.transport,
            &self.config,
            self.local_cid_generator.as_ref(),
            local_cid,
            Some(&server_config),
            &mut self.rng,
        );
        params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, local_cid));
        params.original_dst_cid = Some(incoming.token.orig_dst_cid);
        params.retry_src_cid = incoming.token.retry_src_cid;
        // A preferred address gets its own CID and reset token
        let mut pref_addr_cid = None;
        if server_config.has_preferred_address() {
            let cid = self.new_cid(ch, PathId::ZERO);
            pref_addr_cid = Some(cid);
            params.preferred_address = Some(PreferredAddress {
                address_v4: server_config.preferred_address_v4,
                address_v6: server_config.preferred_address_v6,
                connection_id: cid,
                stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
            });
        }

        let tls = server_config.crypto.start_session(version, &params);
        let transport_config = server_config.transport.clone();
        let mut conn = self.add_connection(
            ch,
            version,
            dst_cid,
            local_cid,
            src_cid,
            incoming.network_path,
            incoming.received_at,
            tls,
            transport_config,
            SideArgs::Server {
                server_config,
                pref_addr_cid,
                path_validated: remote_address_validated,
            },
            &params,
        );
        // Initial DCID now routes to the new connection instead of the incoming buffer
        self.index.insert_initial(dst_cid, ch);

        match conn.handle_first_packet(
            incoming.received_at,
            incoming.network_path,
            incoming.ecn,
            packet_number,
            incoming.packet,
            incoming.rest,
        ) {
            Ok(()) => {
                trace!(
                    id = ch.0,
                    icid = %dst_cid,
                    network_path = %incoming.network_path,
                    "new connection",
                );

                // Replay datagrams that arrived while the attempt was pending
                for event in incoming_buffer.datagrams {
                    conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
                }

                Ok((ch, conn))
            }
            Err(e) => {
                debug!("handshake failed: {}", e);
                // Tear down the endpoint state registered for the half-created connection
                self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
                let response = match e {
                    ConnectionError::TransportError(ref e) => Some(self.initial_close(
                        version,
                        incoming.network_path,
                        &incoming.crypto,
                        src_cid,
                        e.clone(),
                        buf,
                    )),
                    _ => None,
                };
                Err(Box::new(AcceptError { cause: e, response }))
            }
        }
    }
697
698    /// Check if we should refuse a connection attempt regardless of the packet's contents
699    fn early_validate_first_packet(
700        &mut self,
701        header: &ProtectedInitialHeader,
702    ) -> Result<(), TransportError> {
703        let config = &self.server_config.as_ref().unwrap();
704        if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming {
705            return Err(TransportError::CONNECTION_REFUSED(""));
706        }
707
708        // RFC9000 §7.2 dictates that initial (client-chosen) destination CIDs must be at least 8
709        // bytes. If this is a Retry packet, then the length must instead match our usual CID
710        // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll
711        // also need to validate CID length for those after decoding the token.
712        if header.dst_cid.len() < 8
713            && (header.token_pos.is_empty()
714                || header.dst_cid.len() != self.local_cid_generator.cid_len())
715        {
716            debug!(
717                "rejecting connection due to invalid DCID length {}",
718                header.dst_cid.len()
719            );
720            return Err(TransportError::PROTOCOL_VIOLATION(
721                "invalid destination CID length",
722            ));
723        }
724
725        Ok(())
726    }
727
728    /// Reject this incoming connection attempt
729    pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
730        self.clean_up_incoming(&incoming);
731        incoming.improper_drop_warner.dismiss();
732
733        self.initial_close(
734            incoming.packet.header.version,
735            incoming.network_path,
736            &incoming.crypto,
737            incoming.packet.header.src_cid,
738            TransportError::CONNECTION_REFUSED(""),
739            buf,
740        )
741    }
742
743    /// Respond with a retry packet, requiring the client to retry with address validation
744    ///
745    /// Errors if `incoming.may_retry()` is false.
746    pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
747        if !incoming.may_retry() {
748            return Err(RetryError(Box::new(incoming)));
749        }
750
751        self.clean_up_incoming(&incoming);
752        incoming.improper_drop_warner.dismiss();
753
754        let server_config = self.server_config.as_ref().unwrap();
755
756        // First Initial
757        // The peer will use this as the DCID of its following Initials. Initial DCIDs are
758        // looked up separately from Handshake/Data DCIDs, so there is no risk of collision
759        // with established connections. In the unlikely event that a collision occurs
760        // between two connections in the initial phase, both will fail fast and may be
761        // retried by the application layer.
762        let local_cid = self.local_cid_generator.generate_cid();
763
764        let payload = TokenPayload::Retry {
765            address: incoming.network_path.remote,
766            orig_dst_cid: incoming.packet.header.dst_cid,
767            issued: server_config.time_source.now(),
768        };
769        let token = Token::new(payload, &mut self.rng).encode(&*server_config.token_key);
770
771        let header = Header::Retry {
772            src_cid: local_cid,
773            dst_cid: incoming.packet.header.src_cid,
774            version: incoming.packet.header.version,
775        };
776
777        let encode = header.encode(buf);
778        buf.put_slice(&token);
779        buf.extend_from_slice(&server_config.crypto.retry_tag(
780            incoming.packet.header.version,
781            incoming.packet.header.dst_cid,
782            buf,
783        ));
784        encode.finish(buf, &*incoming.crypto.header.local, None);
785
786        Ok(Transmit {
787            destination: incoming.network_path.remote,
788            ecn: None,
789            size: buf.len(),
790            segment_size: None,
791            src_ip: incoming.network_path.local_ip,
792        })
793    }
794
795    /// Ignore this incoming connection attempt, not sending any packet in response
796    ///
797    /// Doing this actively, rather than merely dropping the [`Incoming`], is necessary to prevent
798    /// memory leaks due to state within [`Endpoint`] tracking the incoming connection.
799    pub fn ignore(&mut self, incoming: Incoming) {
800        self.clean_up_incoming(&incoming);
801        incoming.improper_drop_warner.dismiss();
802    }
803
804    /// Clean up endpoint data structures associated with an `Incoming`.
805    fn clean_up_incoming(&mut self, incoming: &Incoming) {
806        self.index.remove_initial(incoming.packet.header.dst_cid);
807        let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
808        self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
809    }
810
    /// Create a [`Connection`] and register it with this endpoint under the handle `ch`
    ///
    /// Besides constructing the connection, this:
    /// - seeds the connection's RNG with 32 bytes drawn from the endpoint RNG,
    /// - creates the connection's qlog sink and emits its "connection started" event,
    /// - records a [`ConnectionMeta`] holding `local_cid` as sequence number 0 on
    ///   [`PathId::ZERO`], plus the preferred-address CID from `side_args` (if any) as
    ///   sequence number 1,
    /// - routes `local_cid` to `ch` in the connection index (or the 4-tuple instead, when
    ///   `local_cid` is zero-length).
    ///
    /// `ch` must be the key the connection slab assigns next; debug builds assert this.
    fn add_connection(
        &mut self,
        ch: ConnectionHandle,
        version: u32,
        init_cid: ConnectionId,
        local_cid: ConnectionId,
        remote_cid: ConnectionId,
        network_path: FourTuple,
        now: Instant,
        tls: Box<dyn crypto::Session>,
        transport_config: Arc<TransportConfig>,
        side_args: SideArgs,
        // Only used for qlog.
        params: &TransportParameters,
    ) -> Connection {
        // Per-connection RNG seed, drawn from the endpoint RNG
        let mut rng_seed = [0; 32];
        self.rng.fill_bytes(&mut rng_seed);
        // Captured up front; `side_args` itself is passed to `Connection::new` below
        let side = side_args.side();
        let pref_addr_cid = side_args.pref_addr_cid();

        let qlog =
            transport_config.create_qlog_sink(side_args.side(), network_path.remote, init_cid, now);

        qlog.emit_connection_started(
            now,
            local_cid,
            remote_cid,
            network_path.remote,
            network_path.local_ip,
            params,
        );

        let conn = Connection::new(
            self.config.clone(),
            transport_config,
            init_cid,
            local_cid,
            remote_cid,
            network_path,
            tls,
            self.local_cid_generator.as_ref(),
            now,
            version,
            self.allow_mtud,
            rng_seed,
            side_args,
            qlog,
        );

        // `local_cid` takes sequence number 0 on the initial path
        let mut path_cids = PathLocalCids::default();
        path_cids.cids.insert(path_cids.issued, local_cid);
        path_cids.issued += 1;

        // A preferred-address CID, when present, must take sequence number 1
        if let Some(cid) = pref_addr_cid {
            debug_assert_eq!(path_cids.issued, 1, "preferred address cid seq must be 1");
            path_cids.cids.insert(path_cids.issued, cid);
            path_cids.issued += 1;
        }

        let id = self.connections.insert(ConnectionMeta {
            init_cid,
            local_cids: FxHashMap::from_iter([(PathId::ZERO, path_cids)]),
            network_path,
            side,
            reset_token: Default::default(),
        });
        // The caller picked `ch` in advance; it must match the slab key just assigned
        debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");

        // Route by `local_cid`, or by 4-tuple when `local_cid` is zero-length
        self.index.insert_conn(network_path, local_cid, ch, side);

        conn
    }
883
884    fn initial_close(
885        &mut self,
886        version: u32,
887        network_path: FourTuple,
888        crypto: &Keys,
889        remote_id: ConnectionId,
890        reason: TransportError,
891        buf: &mut Vec<u8>,
892    ) -> Transmit {
893        // We don't need to worry about CID collisions in initial closes because the peer
894        // shouldn't respond, and if it does, and the CID collides, we'll just drop the
895        // unexpected response.
896        let local_id = self.local_cid_generator.generate_cid();
897        let number = PacketNumber::U8(0);
898        let header = Header::Initial(InitialHeader {
899            dst_cid: remote_id,
900            src_cid: local_id,
901            number,
902            token: Bytes::new(),
903            version,
904        });
905
906        let partial_encode = header.encode(buf);
907        let max_len =
908            INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len();
909        frame::Close::from(reason).encoder(max_len).encode(buf);
910        buf.resize(buf.len() + crypto.packet.local.tag_len(), 0);
911        partial_encode.finish(
912            buf,
913            &*crypto.header.local,
914            Some((0, Default::default(), &*crypto.packet.local)),
915        );
916        Transmit {
917            destination: network_path.remote,
918            ecn: None,
919            size: buf.len(),
920            segment_size: None,
921            src_ip: network_path.local_ip,
922        }
923    }
924
925    /// Access the configuration used by this endpoint
926    pub fn config(&self) -> &EndpointConfig {
927        &self.config
928    }
929
    /// Number of connections that are currently open
    ///
    /// Counts the entries in the connection slab: one per connection added by
    /// `add_connection` and not yet removed from it.
    pub fn open_connections(&self) -> usize {
        self.connections.len()
    }
934
    /// Counter for the number of bytes currently used
    /// in the buffers for Initial and 0-RTT messages for pending incoming connections
    ///
    /// The total shrinks when a pending [`Incoming`] is cleaned up, for example via
    /// [`Endpoint::ignore`].
    pub fn incoming_buffer_bytes(&self) -> u64 {
        self.all_incoming_buffers_total_bytes
    }
940
    /// Test helper: number of registered connections, cross-checked against the index
    ///
    /// In debug builds it asserts that the index's tables are consistent with the slab size.
    #[cfg(test)]
    pub(crate) fn known_connections(&self) -> usize {
        let x = self.connections.len();
        // NOTE(review): expects exactly one initial-DCID route per connection. That holds only
        // while no `Incoming` is pending (those share this map) and no connection was
        // registered with a zero-length initial CID (those are skipped). Confirm tests respect this.
        debug_assert_eq!(x, self.index.connection_ids_initial.len());
        // Not all connections have known reset tokens
        debug_assert!(x >= self.index.connection_reset_tokens.0.len());
        // Not all connections have unique remotes, and 0-length CIDs might not be in use.
        debug_assert!(x >= self.index.incoming_connection_remotes.len());
        debug_assert!(x >= self.index.outgoing_connection_remotes.len());
        x
    }
952
    /// Test helper: number of locally issued CIDs currently routed by the index
    #[cfg(test)]
    pub(crate) fn known_cids(&self) -> usize {
        self.index.connection_ids.len()
    }
957
958    /// Whether we've used up 3/4 of the available CID space
959    ///
960    /// We leave some space unused so that `new_cid` can be relied upon to finish quickly. We don't
961    /// bother to check when CID longer than 4 bytes are used because 2^40 connections is a lot.
962    fn cids_exhausted(&self) -> bool {
963        let cid_len = self.local_cid_generator.cid_len();
964        if cid_len == 0 || cid_len > 4 {
965            return false;
966        }
967
968        // Keep this architecture-independent: on 32-bit targets, 2usize.pow(32) overflows.
969        let bits = (cid_len * 8) as u32;
970        let space = 1u64 << bits;
971        let reserve = 1u64 << (bits - 2);
972        let len = self.index.connection_ids.len() as u64;
973
974        len > (space - reserve)
975    }
976}
977
978impl fmt::Debug for Endpoint {
979    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
980        fmt.debug_struct("Endpoint")
981            .field("rng", &self.rng)
982            .field("index", &self.index)
983            .field("connections", &self.connections)
984            .field("config", &self.config)
985            .field("server_config", &self.server_config)
986            // incoming_buffers too large
987            .field("incoming_buffers.len", &self.incoming_buffers.len())
988            .field(
989                "all_incoming_buffers_total_bytes",
990                &self.all_incoming_buffers_total_bytes,
991            )
992            .finish()
993    }
994}
995
/// Buffered Initial and 0-RTT messages for a pending incoming connection
#[derive(Default)]
struct IncomingBuffer {
    /// Datagrams queued for the pending connection attempt
    datagrams: Vec<DatagramConnectionEvent>,
    /// Presumably the byte size of `datagrams`. It is also counted in the endpoint-wide
    /// `all_incoming_buffers_total_bytes`, which is reduced by this amount on cleanup.
    total_bytes: u64,
}
1002
/// Part of protocol state incoming datagrams can be routed to
#[derive(Copy, Clone, Debug)]
enum RouteDatagramTo {
    /// A pending incoming connection attempt. NOTE(review): presumably the slab key of its
    /// incoming buffer (`Incoming::incoming_idx`); confirm against the caller of
    /// `insert_initial_incoming`.
    Incoming(usize),
    /// An established connection, together with the path the destination CID belongs to
    Connection(ConnectionHandle, PathId),
}
1009
/// Maps packets to existing connections
///
/// Lookups (see `ConnectionIndex::get`) try, in order: the local CID, the initial DCID (for
/// Initial/0-RTT packets), the 4-tuple or remote address (for zero-length CIDs), and finally
/// a stateless-reset token.
#[derive(Default, Debug)]
struct ConnectionIndex {
    /// Identifies connections based on the initial DCID the peer utilized
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    ///
    /// Used by the server, not the client.
    connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
    /// Identifies connections based on locally created CIDs
    ///
    /// Uses a cheaper hash function since keys are locally created
    connection_ids: FxHashMap<ConnectionId, (ConnectionHandle, PathId)>,
    /// Identifies incoming connections with zero-length CIDs
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
    /// Identifies outgoing connections with zero-length CIDs
    ///
    /// We don't yet support explicit source addresses for client connections, and zero-length CIDs
    /// require a unique 4-tuple, so at most one client connection with zero-length local CIDs
    /// may be established per remote. We must omit the local address from the key because we don't
    /// necessarily know what address we're sending from, and hence receiving at.
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    // TODO(matheus23): It's possible this could be changed now that we track the full 4-tuple on the client side, too.
    outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
    /// Reset tokens provided by the peer for the CID each connection is currently sending to
    ///
    /// Incoming stateless resets do not have correct CIDs, so we need this to identify the correct
    /// recipient, if any.
    connection_reset_tokens: ResetTokenTable,
}
1043
1044impl ConnectionIndex {
1045    /// Associate an incoming connection with its initial destination CID
1046    fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
1047        if dst_cid.is_empty() {
1048            return;
1049        }
1050        self.connection_ids_initial
1051            .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
1052    }
1053
1054    /// Remove an association with an initial destination CID
1055    fn remove_initial(&mut self, dst_cid: ConnectionId) {
1056        if dst_cid.is_empty() {
1057            return;
1058        }
1059        let removed = self.connection_ids_initial.remove(&dst_cid);
1060        debug_assert!(removed.is_some());
1061    }
1062
1063    /// Associate a connection with its initial destination CID
1064    fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
1065        if dst_cid.is_empty() {
1066            return;
1067        }
1068        self.connection_ids_initial.insert(
1069            dst_cid,
1070            RouteDatagramTo::Connection(connection, PathId::ZERO),
1071        );
1072    }
1073
1074    /// Associate a connection with its first locally-chosen destination CID if used, or otherwise
1075    /// its current 4-tuple
1076    fn insert_conn(
1077        &mut self,
1078        network_path: FourTuple,
1079        dst_cid: ConnectionId,
1080        connection: ConnectionHandle,
1081        side: Side,
1082    ) {
1083        match dst_cid.len() {
1084            0 => match side {
1085                Side::Server => {
1086                    self.incoming_connection_remotes
1087                        .insert(network_path, connection);
1088                }
1089                Side::Client => {
1090                    self.outgoing_connection_remotes
1091                        .insert(network_path.remote, connection);
1092                }
1093            },
1094            _ => {
1095                self.connection_ids
1096                    .insert(dst_cid, (connection, PathId::ZERO));
1097            }
1098        }
1099    }
1100
1101    /// Discard a connection ID
1102    fn retire(&mut self, dst_cid: ConnectionId) {
1103        self.connection_ids.remove(&dst_cid);
1104    }
1105
1106    /// Remove all references to a connection
1107    fn remove(&mut self, conn: &ConnectionMeta) {
1108        if conn.side.is_server() {
1109            self.remove_initial(conn.init_cid);
1110        }
1111        for cid in conn
1112            .local_cids
1113            .values()
1114            .flat_map(|pcids| pcids.cids.values())
1115        {
1116            self.connection_ids.remove(cid);
1117        }
1118        self.incoming_connection_remotes.remove(&conn.network_path);
1119        self.outgoing_connection_remotes
1120            .remove(&conn.network_path.remote);
1121        for (remote, token) in conn.reset_token.values() {
1122            self.connection_reset_tokens.remove(*remote, *token);
1123        }
1124    }
1125
1126    /// Find the existing connection that `datagram` should be routed to, if any
1127    fn get(&self, network_path: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
1128        if !datagram.dst_cid().is_empty()
1129            && let Some(&(ch, path_id)) = self.connection_ids.get(&datagram.dst_cid())
1130        {
1131            return Some(RouteDatagramTo::Connection(ch, path_id));
1132        }
1133        if (datagram.is_initial() || datagram.is_0rtt())
1134            && let Some(&ch) = self.connection_ids_initial.get(&datagram.dst_cid())
1135        {
1136            return Some(ch);
1137        }
1138        if datagram.dst_cid().is_empty() {
1139            if let Some(&ch) = self.incoming_connection_remotes.get(network_path) {
1140                // Never multipath because QUIC-MULTIPATH 1.1 mandates the use of non-zero
1141                // length CIDs.  So this is always PathId::ZERO.
1142                return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1143            }
1144            if let Some(&ch) = self.outgoing_connection_remotes.get(&network_path.remote) {
1145                // Like above, QUIC-MULTIPATH 1.1 mandates the use of non-zero length CIDs.
1146                return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1147            }
1148        }
1149        let data = datagram.data();
1150        if data.len() < RESET_TOKEN_SIZE {
1151            return None;
1152        }
1153        // For stateless resets the PathId is meaningless since it closes the entire
1154        // connection regardless of path.  So use PathId::ZERO.
1155        self.connection_reset_tokens
1156            .get(network_path.remote, &data[data.len() - RESET_TOKEN_SIZE..])
1157            .cloned()
1158            .map(|ch| RouteDatagramTo::Connection(ch, PathId::ZERO))
1159    }
1160}
1161
/// Endpoint-side bookkeeping for one connection, used for routing datagrams and for removing
/// the connection's entries from the index
#[derive(Debug)]
pub(crate) struct ConnectionMeta {
    /// CID the connection was initially identified by; for servers it is also registered as
    /// an initial-DCID route and removed again when the connection goes away
    init_cid: ConnectionId,
    /// Locally issued CIDs for each path
    local_cids: FxHashMap<PathId, PathLocalCids>,
    /// Remote/local addresses the connection began with
    ///
    /// Only needed to support connections with zero-length CIDs, which cannot migrate, so we don't
    /// bother keeping it up to date.
    network_path: FourTuple,
    /// Whether this endpoint is the client or the server of the connection
    side: Side,
    /// Reset tokens provided by the peer for CIDs we're currently sending to
    ///
    /// Since each reset token is for a CID, it is also for a fixed remote address which is
    /// also stored. This allows us to look up which reset tokens we might expect from a
    /// given remote address, see [`ResetTokenTable`].
    ///
    /// Each path has its own active CID. We use the [`PathId`] as a unique index, allowing
    /// us to retire the reset token when a path is abandoned.
    // TODO(matheus23): Should be migrated to make reset tokens per 4-tuple instead of per remote addr
    reset_token: FxHashMap<PathId, (SocketAddr, ResetToken)>,
}
1184
/// Local connection IDs for a single path
#[derive(Debug, Default)]
struct PathLocalCids {
    /// Number of connection IDs that have been issued in (PATH_)NEW_CONNECTION_ID frames
    ///
    /// Another way of saying this is that this is the next sequence number to be issued.
    issued: u64,
    /// Issued CIDs, indexed by their sequence number
    cids: FxHashMap<u64, ConnectionId>,
}
1195
/// Internal identifier for a `Connection` currently associated with an endpoint
///
/// Wraps the connection's key in the endpoint's `Slab<ConnectionMeta>`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);
1199
1200impl From<ConnectionHandle> for usize {
1201    fn from(x: ConnectionHandle) -> Self {
1202        x.0
1203    }
1204}
1205
1206impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
1207    type Output = ConnectionMeta;
1208    fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
1209        &self[ch.0]
1210    }
1211}
1212
1213impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
1214    fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
1215        &mut self[ch.0]
1216    }
1217}
1218
/// Event resulting from processing a single datagram
pub enum DatagramEvent {
    /// The datagram is redirected to its `Connection`
    ConnectionEvent(ConnectionHandle, ConnectionEvent),
    /// The datagram may result in starting a new `Connection`
    ///
    /// The [`Incoming`] must be handed back to the endpoint (accept, refuse, retry or ignore)
    /// rather than dropped.
    NewConnection(Incoming),
    /// Response generated directly by the endpoint
    Response(Transmit),
}
1228
/// An incoming connection for which the server has not yet begun its part of the handshake.
///
/// Must be consumed by an [`Endpoint`] method (for example [`Endpoint::ignore`]) rather than
/// simply dropped. Otherwise the endpoint's bookkeeping for it is never released, and a
/// warning is logged.
#[derive(derive_more::Debug)]
pub struct Incoming {
    /// NOTE(review): presumably when the initial datagram was received; confirm at the
    /// construction site
    #[debug(skip)]
    received_at: Instant,
    /// Local IP (if known) and remote address of the connection attempt
    network_path: FourTuple,
    /// ECN codepoint associated with the received datagram, if any (presumably as reported
    /// by the socket; confirm)
    ecn: Option<EcnCodepoint>,
    /// The Initial packet that started this connection attempt
    #[debug(skip)]
    packet: InitialPacket,
    /// NOTE(review): presumably the remainder of the datagram after the Initial packet
    /// (coalesced packets); confirm
    #[debug(skip)]
    rest: Option<BytesMut>,
    /// Initial header and packet keys for this attempt
    #[debug(skip)]
    crypto: Keys,
    /// Token state: `validated`, `retry_src_cid` and `orig_dst_cid`
    token: IncomingToken,
    /// Key of this attempt's entry in the endpoint's incoming buffers
    incoming_idx: usize,
    /// Logs a warning if this `Incoming` is dropped without being handed back to the endpoint
    #[debug(skip)]
    improper_drop_warner: IncomingImproperDropWarner,
}
1247
1248impl Incoming {
1249    /// The local IP address which was used when the peer established the connection
1250    pub fn local_ip(&self) -> Option<IpAddr> {
1251        self.network_path.local_ip
1252    }
1253
1254    /// The peer's UDP address
1255    pub fn remote_address(&self) -> SocketAddr {
1256        self.network_path.remote
1257    }
1258
1259    /// Whether the socket address that is initiating this connection has been validated
1260    ///
1261    /// This means that the sender of the initial packet has proved that they can receive traffic
1262    /// sent to `self.remote_address()`.
1263    ///
1264    /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true.
1265    /// The inverse is not guaranteed.
1266    pub fn remote_address_validated(&self) -> bool {
1267        self.token.validated
1268    }
1269
1270    /// Whether it is legal to respond with a retry packet
1271    ///
1272    /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true.
1273    /// The inverse is not guaranteed.
1274    pub fn may_retry(&self) -> bool {
1275        self.token.retry_src_cid.is_none()
1276    }
1277
1278    /// The original destination connection ID sent by the client
1279    pub fn orig_dst_cid(&self) -> ConnectionId {
1280        self.token.orig_dst_cid
1281    }
1282
1283    /// Decrypt the Initial packet payload
1284    ///
1285    /// This clones and decrypts the packet payload (~1200 bytes).
1286    /// Can be used to extract information from the TLS ClientHello without completing the handshake.
1287    pub fn decrypt(&self) -> Option<DecryptedInitial> {
1288        let packet_number = self.packet.header.number.expand(0);
1289        let mut payload = self.packet.payload.clone();
1290        self.crypto
1291            .packet
1292            .remote
1293            .decrypt(
1294                PathId::ZERO,
1295                packet_number,
1296                &self.packet.header_data,
1297                &mut payload,
1298            )
1299            .ok()?;
1300        Some(DecryptedInitial(payload.freeze()))
1301    }
1302}
1303
1304/// Decrypted payload of a QUIC Initial packet
1305///
1306/// Obtained via [`Incoming::decrypt`]. Can be used to extract information from
1307/// the TLS ClientHello without completing the handshake.
1308pub struct DecryptedInitial(Bytes);
1309
1310impl DecryptedInitial {
1311    /// Best-effort extraction of the ALPN protocols from the TLS ClientHello
1312    ///
1313    /// Parses the CRYPTO frames to extract the ALPN extension. This is intended
1314    /// for routing and filtering; it is not guaranteed to succeed if the
1315    /// ClientHello spans multiple packets. Returns `None` if parsing fails.
1316    pub fn alpns(&self) -> Option<IncomingAlpns> {
1317        let frames = frame::Iter::new(self.0.clone()).ok()?;
1318        let mut first = None;
1319        let mut rest = Vec::new();
1320        for frame in frames {
1321            match frame {
1322                Ok(frame::Frame::Crypto(crypto)) => match first {
1323                    None => first = Some(crypto),
1324                    Some(_) => rest.push(crypto),
1325                },
1326                Err(_) => return None,
1327                _ => {}
1328            }
1329        }
1330        let first = first?;
1331
1332        // Fast path: single CRYPTO frame at offset 0 (no extra allocation)
1333        if rest.is_empty() && first.offset == 0 {
1334            let data = find_alpn_data(&first.data).ok()?;
1335            return Some(IncomingAlpns { data, pos: 0 });
1336        }
1337
1338        // Slow path: reassemble multiple CRYPTO frames
1339        rest.push(first);
1340        let source = assemble_crypto_frames(&mut rest)?;
1341        let data = find_alpn_data(&source).ok()?;
1342        Some(IncomingAlpns { data, pos: 0 })
1343    }
1344}
1345
/// TLS handshake type for ClientHello messages (the first byte of the handshake header)
/// <https://www.rfc-editor.org/rfc/rfc8446#section-4.1.2>
const TLS_HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 0x01;
/// TLS extension type for Application-Layer Protocol Negotiation
/// <https://www.rfc-editor.org/rfc/rfc7301#section-3.1>
const TLS_EXTENSION_TYPE_ALPN: u16 = 0x0010;
/// Size of the fixed-length fields at the start of a ClientHello body: the 2-byte
/// (legacy) client_version followed by the 32-byte random
/// <https://www.rfc-editor.org/rfc/rfc8446#section-4.1.2>
const TLS_CLIENT_HELLO_FIXED_LEN: usize = 2 + 32;
1355
1356/// Iterator over ALPN protocol names from a TLS ClientHello
1357///
1358/// Yields protocol names as [`Bytes`] slices. On the common fast path (single
1359/// CRYPTO frame), the only allocation is the payload clone for decryption.
1360pub struct IncomingAlpns {
1361    data: Bytes,
1362    pos: usize,
1363}
1364
1365impl Iterator for IncomingAlpns {
1366    type Item = Result<Bytes, UnexpectedEnd>;
1367
1368    fn next(&mut self) -> Option<Self::Item> {
1369        if self.pos >= self.data.len() {
1370            return None;
1371        }
1372        let len = self.data[self.pos] as usize;
1373        self.pos += 1;
1374        if self.pos + len > self.data.len() {
1375            return Some(Err(UnexpectedEnd));
1376        }
1377        let proto = self.data.slice(self.pos..self.pos + len);
1378        self.pos += len;
1379        Some(Ok(proto))
1380    }
1381}
1382
1383/// Sort CRYPTO frames by offset and concatenate into a contiguous `Bytes`
1384///
1385/// Returns `None` if there are gaps in the stream.
1386fn assemble_crypto_frames(frames: &mut [frame::Crypto]) -> Option<Bytes> {
1387    frames.sort_by_key(|f| f.offset);
1388    let capacity = frames.iter().map(|f| f.data.len()).sum();
1389    let mut buf = Vec::with_capacity(capacity);
1390    for f in frames.iter() {
1391        let start = f.offset as usize;
1392        if start > buf.len() {
1393            return None;
1394        }
1395        let end = start + f.data.len();
1396        if end > buf.len() {
1397            buf.extend_from_slice(&f.data[buf.len() - start..]);
1398        }
1399    }
1400    Some(Bytes::from(buf))
1401}
1402
1403/// Locate the raw ALPN protocol list data within a TLS ClientHello message
1404///
1405/// Parses the ClientHello in `source` and returns a [`Bytes`] containing the
1406/// u8-length-prefixed protocol names (after the outer ProtocolNameList u16
1407/// length prefix). The returned `Bytes` is a zero-copy slice of `source`.
1408fn find_alpn_data(source: &Bytes) -> Result<Bytes, UnexpectedEnd> {
1409    let mut r = &**source;
1410
1411    if u8::decode(&mut r)? != TLS_HANDSHAKE_TYPE_CLIENT_HELLO {
1412        return Err(UnexpectedEnd);
1413    }
1414
1415    // Handshake message length (u24), scopes the remainder
1416    let len = decode_u24(&mut r)?;
1417    let mut body = take(&mut r, len)?;
1418
1419    // Client version + random
1420    skip(&mut body, TLS_CLIENT_HELLO_FIXED_LEN)?;
1421
1422    // Session ID, cipher suites, compression methods
1423    skip_u8_prefixed(&mut body)?;
1424    skip_u16_prefixed(&mut body)?;
1425    skip_u8_prefixed(&mut body)?;
1426
1427    // Extensions
1428    let mut exts = take_u16_prefixed(&mut body)?;
1429    while exts.has_remaining() {
1430        let ext_type = u16::decode(&mut exts)?;
1431        let ext_data = take_u16_prefixed(&mut exts)?;
1432        if ext_type == TLS_EXTENSION_TYPE_ALPN {
1433            let list = take_u16_prefixed(&mut &*ext_data)?;
1434            return Ok(source.slice_ref(list));
1435        }
1436    }
1437    Err(UnexpectedEnd)
1438}
1439
1440/// Decode a big-endian u24 as usize
1441fn decode_u24(r: &mut &[u8]) -> Result<usize, UnexpectedEnd> {
1442    let a = u8::decode(r)?;
1443    let b = u8::decode(r)?;
1444    let c = u8::decode(r)?;
1445    Ok(u32::from_be_bytes([0, a, b, c]) as usize)
1446}
1447
1448/// Take `len` bytes from the front and return them as a sub-slice
1449fn take<'a>(r: &mut &'a [u8], len: usize) -> Result<&'a [u8], UnexpectedEnd> {
1450    if r.remaining() < len {
1451        return Err(UnexpectedEnd);
1452    }
1453    let data = &r[..len];
1454    r.advance(len);
1455    Ok(data)
1456}
1457
1458/// Read a u16 length prefix and return the sub-slice it covers
1459fn take_u16_prefixed<'a>(r: &mut &'a [u8]) -> Result<&'a [u8], UnexpectedEnd> {
1460    let len = u16::decode(r)? as usize;
1461    take(r, len)
1462}
1463
1464/// Advance past `n` bytes
1465fn skip(r: &mut &[u8], len: usize) -> Result<(), UnexpectedEnd> {
1466    take(r, len)?;
1467    Ok(())
1468}
1469
1470/// Skip a u8-length-prefixed field
1471fn skip_u8_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1472    let len = u8::decode(r)? as usize;
1473    skip(r, len)
1474}
1475
1476/// Skip a u16-length-prefixed field
1477fn skip_u16_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1478    let len = u16::decode(r)? as usize;
1479    skip(r, len)
1480}
1481
/// Zero-sized guard carried by every `Incoming` that logs a warning when dropped
///
/// Endpoint methods that properly consume an `Incoming` (e.g. `Endpoint::ignore`) disarm it
/// with `dismiss`.
struct IncomingImproperDropWarner;

impl IncomingImproperDropWarner {
    /// Disarm the guard: consume it without running its `Drop` impl
    fn dismiss(self) {
        // Forgetting a zero-sized value releases nothing, so nothing leaks
        std::mem::forget(self)
    }
}
1489
impl Drop for IncomingImproperDropWarner {
    /// Warn that an `Incoming` was discarded without being handed back to the endpoint
    ///
    /// Runs only if the guard was not disarmed with `dismiss`, which forgets it instead.
    fn drop(&mut self) {
        warn!(
            "noq_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
               (may cause memory leak and eventual inability to accept new connections)"
        );
    }
}
1498
/// Errors in the parameters being used to create a new connection
///
/// These arise before any I/O has been performed.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// The endpoint can no longer create new connections
    ///
    /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled.
    #[error("endpoint stopping")]
    EndpointStopping,
    /// The connection could not be created because not enough of the CID space is available
    ///
    /// Try using longer connection IDs
    #[error("CIDs exhausted")]
    CidsExhausted,
    /// The given server name was malformed
    ///
    /// Carries the offending name.
    #[error("invalid server name: {0}")]
    InvalidServerName(String),
    /// The remote [`SocketAddr`] supplied was malformed
    ///
    /// Examples include attempting to connect to port 0, or using an inappropriate address family.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    /// No default client configuration was set up
    ///
    /// Use `Endpoint::connect_with` to specify a client configuration.
    #[error("no default client config")]
    NoDefaultClientConfig,
    /// The local endpoint does not support the QUIC version specified in the client configuration
    #[error("unsupported QUIC version")]
    UnsupportedVersion,
}
1531
/// Error type for attempting to accept an [`Incoming`]
///
/// Besides the cause, it may carry a packet that should still be sent to the peer.
#[derive(Debug)]
pub struct AcceptError {
    /// Underlying error describing reason for failure
    pub cause: ConnectionError,
    /// Optional response to transmit back
    pub response: Option<Transmit>,
}
1540
1541/// Error for attempting to retry an [`Incoming`] which already bears a token from a previous retry
1542#[derive(Debug, Error)]
1543#[error("retry() with validated Incoming")]
1544pub struct RetryError(Box<Incoming>);
1545
1546impl RetryError {
1547    /// Get the [`Incoming`]
1548    pub fn into_incoming(self) -> Incoming {
1549        *self.0
1550    }
1551}
1552
1553/// Reset Tokens which are associated with peer socket addresses
1554///
1555/// The standard `HashMap` is used since both `SocketAddr` and `ResetToken` are
1556/// peer generated and might be usable for hash collision attacks.
1557#[derive(Default, Debug)]
1558struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
1559
1560impl ResetTokenTable {
1561    fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
1562        self.0
1563            .entry(remote)
1564            .or_default()
1565            .insert(token, ch)
1566            .is_some()
1567    }
1568
1569    fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
1570        use std::collections::hash_map::Entry;
1571        match self.0.entry(remote) {
1572            Entry::Vacant(_) => {}
1573            Entry::Occupied(mut e) => {
1574                e.get_mut().remove(&token);
1575                if e.get().is_empty() {
1576                    e.remove_entry();
1577                }
1578            }
1579        }
1580    }
1581
1582    fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
1583        let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
1584        self.0.get(&remote)?.get(&token)
1585    }
1586}
1587
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal TLS ClientHello with an empty unrelated extension followed by an ALPN
    /// extension listing `protocols`
    fn client_hello_with_alpn(protocols: &[&[u8]]) -> Bytes {
        let mut names: Vec<u8> = Vec::new();
        for proto in protocols {
            names.push(proto.len() as u8);
            names.extend_from_slice(proto);
        }
        let mut alpn: Vec<u8> = Vec::new();
        alpn.extend_from_slice(&(names.len() as u16).to_be_bytes());
        alpn.extend_from_slice(&names);

        let mut extensions: Vec<u8> = Vec::new();
        // server_name (type 0) with an empty body, so the parser must skip an extension first
        extensions.extend_from_slice(&0u16.to_be_bytes());
        extensions.extend_from_slice(&0u16.to_be_bytes());
        extensions.extend_from_slice(&TLS_EXTENSION_TYPE_ALPN.to_be_bytes());
        extensions.extend_from_slice(&(alpn.len() as u16).to_be_bytes());
        extensions.extend_from_slice(&alpn);

        let mut body: Vec<u8> = Vec::new();
        body.extend_from_slice(&[0x03, 0x03]); // legacy_version
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // empty session ID
        body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites: one suite
        body.extend_from_slice(&[0x13, 0x01]);
        body.extend_from_slice(&[1, 0]); // compression methods: null only
        body.extend_from_slice(&(extensions.len() as u16).to_be_bytes());
        body.extend_from_slice(&extensions);

        let mut msg: Vec<u8> = vec![TLS_HANDSHAKE_TYPE_CLIENT_HELLO];
        msg.extend_from_slice(&(body.len() as u32).to_be_bytes()[1..]);
        msg.extend_from_slice(&body);
        Bytes::from(msg)
    }

    #[test]
    fn alpn_from_client_hello() {
        let hello = client_hello_with_alpn(&[&b"h3"[..], &b"hq-29"[..]]);
        let alpns = IncomingAlpns {
            data: find_alpn_data(&hello).unwrap(),
            pos: 0,
        };
        let protocols: Vec<Vec<u8>> = alpns.map(|p| p.unwrap().to_vec()).collect();
        assert_eq!(protocols, vec![b"h3".to_vec(), b"hq-29".to_vec()]);
    }

    #[test]
    fn alpn_from_malformed_client_hello() {
        let hello = client_hello_with_alpn(&[&b"h3"[..]]);
        // Body shorter than its declared u24 length
        assert!(find_alpn_data(&hello.slice(..hello.len() - 1)).is_err());
        // Not a ClientHello
        let mut wrong_type = hello.to_vec();
        wrong_type[0] = 0x02;
        assert!(find_alpn_data(&Bytes::from(wrong_type)).is_err());
    }

    #[test]
    fn assemble_contiguous() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..5]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    #[test]
    fn assemble_out_of_order() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..5]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    #[test]
    fn assemble_with_overlap() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..7]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    #[test]
    fn assemble_with_gap() {
        let mut frames = vec![frame::Crypto {
            offset: 10,
            data: Bytes::from_static(b"world"),
        }];
        assert!(assemble_crypto_frames(&mut frames).is_none());
    }
}