noq_proto/
endpoint.rs

1use std::{
2    collections::{HashMap, hash_map},
3    convert::TryFrom,
4    fmt, mem,
5    net::{IpAddr, SocketAddr},
6    ops::{Index, IndexMut},
7    sync::Arc,
8};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use rand::{Rng, RngExt, SeedableRng, rngs::StdRng};
12use rustc_hash::FxHashMap;
13use slab::Slab;
14use thiserror::Error;
15use tracing::{debug, error, trace, warn};
16
17use crate::{
18    Duration, FourTuple, INITIAL_MTU, Instant, MAX_CID_SIZE, MIN_INITIAL_SIZE, PathId,
19    RESET_TOKEN_SIZE, ResetToken, Side, Transmit, TransportConfig, TransportError,
20    cid_generator::ConnectionIdGenerator,
21    coding::{BufMutExt, Decodable, Encodable, UnexpectedEnd},
22    config::{ClientConfig, EndpointConfig, ServerConfig},
23    connection::{Connection, ConnectionError, SideArgs},
24    crypto::{self, Keys, UnsupportedVersion},
25    frame,
26    packet::{
27        FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, PacketDecodeError,
28        PacketNumber, PartialDecode, ProtectedInitialHeader,
29    },
30    shared::{
31        ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
32        EndpointEvent, EndpointEventInner, IssuedCid,
33    },
34    token::{IncomingToken, InvalidRetryTokenError, Token, TokenPayload},
35    transport_parameters::{PreferredAddress, TransportParameters},
36};
37
/// The main entry point to the library
///
/// This object performs no I/O whatsoever. Instead, it consumes incoming packets and
/// connection-generated events via `handle` and `handle_event`.
pub struct Endpoint {
    /// Random number generator, seeded from `EndpointConfig::rng_seed` when one is configured
    rng: StdRng,
    /// Lookup tables used to route incoming datagrams to a connection or a pending incoming
    /// buffer
    index: ConnectionIndex,
    /// Endpoint-side state of every live connection, keyed by the integer wrapped in its
    /// `ConnectionHandle`
    connections: Slab<ConnectionMeta>,
    /// Generator for locally issued connection IDs
    local_cid_generator: Box<dyn ConnectionIdGenerator>,
    /// Endpoint-wide configuration
    config: Arc<EndpointConfig>,
    /// Server-side configuration
    ///
    /// When `None`, Initial packets for unknown connections never produce a new incoming
    /// connection and packets with an unsupported version are dropped without a reply.
    server_config: Option<Arc<ServerConfig>>,
    /// Whether the underlying UDP socket promises not to fragment packets
    allow_mtud: bool,
    /// Time at which a stateless reset was most recently sent
    last_stateless_reset: Option<Instant>,
    /// Buffered Initial and 0-RTT messages for pending incoming connections
    incoming_buffers: Slab<IncomingBuffer>,
    /// Sum of `total_bytes` over every entry currently in `incoming_buffers`
    all_incoming_buffers_total_bytes: u64,
}
57
58impl Endpoint {
59    /// Create a new endpoint
60    ///
61    /// `allow_mtud` enables path MTU detection when requested by `Connection` configuration for
62    /// better performance. This requires that outgoing packets are never fragmented, which can be
63    /// achieved via e.g. the `IPV6_DONTFRAG` socket option.
64    pub fn new(
65        config: Arc<EndpointConfig>,
66        server_config: Option<Arc<ServerConfig>>,
67        allow_mtud: bool,
68    ) -> Self {
69        Self {
70            rng: config
71                .rng_seed
72                .map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::from_seed),
73            index: ConnectionIndex::default(),
74            connections: Slab::new(),
75            local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
76            config,
77            server_config,
78            allow_mtud,
79            last_stateless_reset: None,
80            incoming_buffers: Slab::new(),
81            all_incoming_buffers_total_bytes: 0,
82        }
83    }
84
85    /// Replace the server configuration, affecting new incoming connections only
86    pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
87        self.server_config = server_config;
88    }
89
90    /// Process `EndpointEvent`s emitted from related `Connection`s
91    ///
92    /// In turn, processing this event may return a `ConnectionEvent` for the same `Connection`.
93    pub fn handle_event(
94        &mut self,
95        ch: ConnectionHandle,
96        event: EndpointEvent,
97    ) -> Option<ConnectionEvent> {
98        use EndpointEventInner::*;
99        match event.0 {
100            NeedIdentifiers(path_id, now, n) => {
101                return Some(self.send_new_identifiers(path_id, now, ch, n));
102            }
103            ResetToken(path_id, remote, token) => {
104                if let Some(old) = self.connections[ch]
105                    .reset_token
106                    .insert(path_id, (remote, token))
107                {
108                    self.index.connection_reset_tokens.remove(old.0, old.1);
109                }
110                if self.index.connection_reset_tokens.insert(remote, token, ch) {
111                    warn!("duplicate reset token");
112                }
113            }
114            RetireResetToken(path_id) => {
115                if let Some(old) = self.connections[ch].reset_token.remove(&path_id) {
116                    self.index.connection_reset_tokens.remove(old.0, old.1);
117                }
118            }
119            RetireConnectionId(now, path_id, seq, allow_more_cids) => {
120                if let Some(cid) = self.connections[ch]
121                    .local_cids
122                    .get_mut(&path_id)
123                    .and_then(|pcid| pcid.cids.remove(&seq))
124                {
125                    trace!(%path_id, "local CID retired {}: {}", seq, cid);
126                    self.index.retire(cid);
127                    if allow_more_cids {
128                        return Some(self.send_new_identifiers(path_id, now, ch, 1));
129                    }
130                }
131            }
132            Drained => {
133                if let Some(conn) = self.connections.try_remove(ch.0) {
134                    self.index.remove(&conn);
135                } else {
136                    // This indicates a bug in downstream code, which could cause spurious
137                    // connection loss instead of this error if the CID was (re)allocated prior to
138                    // the illegal call.
139                    error!(id = ch.0, "unknown connection drained");
140                }
141            }
142        }
143        None
144    }
145
    /// Process an incoming UDP datagram
    ///
    /// The datagram is partially decoded and then dispatched:
    ///
    /// - A packet with an unsupported version yields a Version Negotiation response when a
    ///   server configuration is present, and is dropped otherwise.
    /// - A packet that the connection index can route is either forwarded to its connection as a
    ///   `DatagramEvent::ConnectionEvent`, or buffered for a pending incoming connection.
    /// - An unroutable packet carrying an Initial header goes to `handle_first_packet`.
    /// - Any other unroutable packet is dropped or answered with a stateless reset.
    ///
    /// `now` is the receive time, `network_path` the addresses the datagram was received on,
    /// `ecn` its ECN codepoint if known, and `data` its payload.
    ///
    /// Response packets are written into `buf`; the returned `Transmit::size` is `buf.len()`.
    /// NOTE(review): the Version Negotiation path appends to `buf`, so callers presumably pass an
    /// empty buffer — confirm.
    ///
    /// Returns `None` when the datagram was dropped or buffered.
    pub fn handle(
        &mut self,
        now: Instant,
        network_path: FourTuple,
        ecn: Option<EcnCodepoint>,
        data: BytesMut,
        buf: &mut Vec<u8>,
    ) -> Option<DatagramEvent> {
        // Partially decode packet or short-circuit if unable
        // The length is captured up front because `data` is consumed by the decoder.
        let datagram_len = data.len();
        let mut event = match PartialDecode::new(
            data,
            &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
            &self.config.supported_versions,
            self.config.grease_quic_bit,
        ) {
            Ok((first_decode, remaining)) => DatagramConnectionEvent {
                now,
                network_path,
                path_id: PathId::ZERO, // Corrected later for existing paths
                ecn,
                first_decode,
                remaining,
            },
            Err(PacketDecodeError::UnsupportedVersion {
                src_cid,
                dst_cid,
                version,
            }) => {
                // Only an endpoint with a server configuration answers with Version Negotiation
                if self.server_config.is_none() {
                    debug!("dropping packet with unsupported version");
                    return None;
                }
                trace!("sending version negotiation");
                // Negotiate versions
                // The connection IDs are swapped so the reply is addressed to the sender; bit
                // 0x40 of the otherwise random first byte is always set.
                Header::VersionNegotiate {
                    random: self.rng.random::<u8>() | 0x40,
                    src_cid: dst_cid,
                    dst_cid: src_cid,
                }
                .encode(buf);
                // Grease with a reserved version
                // The value is chosen so that it never equals the version the peer sent.
                buf.write::<u32>(match version {
                    0x0a1a_2a3a => 0x0a1a_2a4a,
                    _ => 0x0a1a_2a3a,
                });
                for &version in &self.config.supported_versions {
                    buf.write(version);
                }
                return Some(DatagramEvent::Response(Transmit {
                    destination: network_path.remote,
                    ecn: None,
                    size: buf.len(),
                    segment_size: None,
                    src_ip: network_path.local_ip,
                }));
            }
            Err(e) => {
                trace!("malformed header: {}", e);
                return None;
            }
        };

        let dst_cid = event.first_decode.dst_cid();

        if let Some(route_to) = self.index.get(&network_path, &event.first_decode) {
            // The index knows which path an established connection's CID belongs to; pending
            // incoming connections always use path zero.
            event.path_id = match route_to {
                RouteDatagramTo::Incoming(_) => PathId::ZERO,
                RouteDatagramTo::Connection(_, path_id) => path_id,
            };
            match route_to {
                RouteDatagramTo::Incoming(incoming_idx) => {
                    let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
                    let config = &self.server_config.as_ref().unwrap();

                    // Buffer the datagram only while both the per-connection and the
                    // endpoint-wide byte limits hold; otherwise it is silently dropped.
                    // `checked_add` makes arithmetic overflow count as exceeding the limit.
                    if incoming_buffer
                        .total_bytes
                        .checked_add(datagram_len as u64)
                        .is_some_and(|n| n <= config.incoming_buffer_size)
                        && self
                            .all_incoming_buffers_total_bytes
                            .checked_add(datagram_len as u64)
                            .is_some_and(|n| n <= config.incoming_buffer_size_total)
                    {
                        incoming_buffer.datagrams.push(event);
                        incoming_buffer.total_bytes += datagram_len as u64;
                        self.all_incoming_buffers_total_bytes += datagram_len as u64;
                    }

                    None
                }
                RouteDatagramTo::Connection(ch, _path_id) => Some(DatagramEvent::ConnectionEvent(
                    ch,
                    ConnectionEvent(ConnectionEventInner::Datagram(event)),
                )),
            }
        } else if event.first_decode.initial_header().is_some() {
            // Potentially create a new connection

            self.handle_first_packet(datagram_len, event, network_path, buf)
        } else if event.first_decode.has_long_header() {
            debug!(
                "ignoring non-initial packet for unknown connection {}",
                dst_cid
            );
            None
        } else if !event.first_decode.is_initial()
            && self.local_cid_generator.validate(dst_cid).is_err()
        {
            // The CID generator does not recognise this destination CID as one of its own
            debug!("dropping packet with invalid CID");
            None
        } else if dst_cid.is_empty() {
            trace!("dropping unrecognized short packet without ID");
            None
        } else {
            // If we got this far, we're receiving a seemingly valid packet for an unknown
            // connection. Send a stateless reset if possible.
            self.stateless_reset(now, datagram_len, network_path, dst_cid, buf)
                .map(DatagramEvent::Response)
        }
    }
268
269    /// Builds a stateless reset packet to respond with
270    fn stateless_reset(
271        &mut self,
272        now: Instant,
273        inciting_dgram_len: usize,
274        network_path: FourTuple,
275        dst_cid: ConnectionId,
276        buf: &mut Vec<u8>,
277    ) -> Option<Transmit> {
278        if self
279            .last_stateless_reset
280            .is_some_and(|last| last + self.config.min_reset_interval > now)
281        {
282            debug!("ignoring unexpected packet within minimum stateless reset interval");
283            return None;
284        }
285
286        /// Minimum amount of padding for the stateless reset to look like a short-header packet
287        const MIN_PADDING_LEN: usize = 5;
288
289        // Prevent amplification attacks and reset loops by ensuring we pad to at most 1 byte
290        // smaller than the inciting packet.
291        let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) {
292            Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1,
293            _ => {
294                debug!(
295                    "ignoring unexpected {} byte packet: not larger than minimum stateless reset size",
296                    inciting_dgram_len
297                );
298                return None;
299            }
300        };
301
302        debug!(%dst_cid, %network_path.remote, "sending stateless reset");
303        self.last_stateless_reset = Some(now);
304        // Resets with at least this much padding can't possibly be distinguished from real packets
305        const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;
306        let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN {
307            max_padding_len
308        } else {
309            self.rng
310                .random_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
311        };
312        buf.reserve(padding_len + RESET_TOKEN_SIZE);
313        buf.resize(padding_len, 0);
314        self.rng.fill_bytes(&mut buf[0..padding_len]);
315        buf[0] = 0b0100_0000 | (buf[0] >> 2);
316        buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));
317
318        debug_assert!(buf.len() < inciting_dgram_len);
319
320        Some(Transmit {
321            destination: network_path.remote,
322            ecn: None,
323            size: buf.len(),
324            segment_size: None,
325            src_ip: network_path.local_ip,
326        })
327    }
328
329    /// Initiate a connection
330    pub fn connect(
331        &mut self,
332        now: Instant,
333        config: ClientConfig,
334        remote: SocketAddr,
335        server_name: &str,
336    ) -> Result<(ConnectionHandle, Connection), ConnectError> {
337        if self.cids_exhausted() {
338            return Err(ConnectError::CidsExhausted);
339        }
340        if remote.port() == 0 || remote.ip().is_unspecified() {
341            return Err(ConnectError::InvalidRemoteAddress(remote));
342        }
343        if !self.config.supported_versions.contains(&config.version) {
344            return Err(ConnectError::UnsupportedVersion);
345        }
346
347        let remote_id = (config.initial_dst_cid_provider)();
348        trace!(initial_dcid = %remote_id);
349
350        let ch = ConnectionHandle(self.connections.vacant_key());
351        let local_cid = self.new_cid(ch, PathId::ZERO);
352        let params = TransportParameters::new(
353            &config.transport,
354            &self.config,
355            self.local_cid_generator.as_ref(),
356            local_cid,
357            None,
358            &mut self.rng,
359        );
360        let tls = config
361            .crypto
362            .start_session(config.version, server_name, &params)?;
363
364        let conn = self.add_connection(
365            ch,
366            config.version,
367            remote_id,
368            local_cid,
369            remote_id,
370            FourTuple {
371                remote,
372                local_ip: None,
373            },
374            now,
375            tls,
376            config.transport,
377            SideArgs::Client {
378                token_store: config.token_store,
379                server_name: server_name.into(),
380            },
381            &params,
382        );
383        Ok((ch, conn))
384    }
385
386    /// Generates new CIDs and creates message to send to the connection state
387    fn send_new_identifiers(
388        &mut self,
389        path_id: PathId,
390        now: Instant,
391        ch: ConnectionHandle,
392        num: u64,
393    ) -> ConnectionEvent {
394        let mut ids = vec![];
395        for _ in 0..num {
396            let id = self.new_cid(ch, path_id);
397            let cid_meta = self.connections[ch].local_cids.entry(path_id).or_default();
398            let sequence = cid_meta.issued;
399            cid_meta.issued += 1;
400            cid_meta.cids.insert(sequence, id);
401            ids.push(IssuedCid {
402                path_id,
403                sequence,
404                id,
405                reset_token: ResetToken::new(&*self.config.reset_key, id),
406            });
407        }
408        ConnectionEvent(ConnectionEventInner::NewIdentifiers(
409            ids,
410            now,
411            self.local_cid_generator.cid_len(),
412            self.local_cid_generator.cid_lifetime(),
413        ))
414    }
415
416    /// Generate a connection ID for `ch`
417    fn new_cid(&mut self, ch: ConnectionHandle, path_id: PathId) -> ConnectionId {
418        loop {
419            let cid = self.local_cid_generator.generate_cid();
420            if cid.is_empty() {
421                // Zero-length CID; nothing to track
422                debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
423                return cid;
424            }
425            if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
426                e.insert((ch, path_id));
427                break cid;
428            }
429        }
430    }
431
    /// Handle an Initial packet that could not be routed to any known connection
    ///
    /// `handle` calls this once the partially decoded packet is known to carry an Initial
    /// header. The outcome is one of:
    ///
    /// - `DatagramEvent::NewConnection` with an `Incoming` describing the attempt, after
    ///   reserving an incoming buffer and registering the packet's destination CID for it;
    /// - `DatagramEvent::Response` with a packet written to `buf` (a stateless reset when there
    ///   is no server configuration, or a close when the attempt is rejected);
    /// - `None` when the packet is dropped.
    ///
    /// # Panics
    ///
    /// Panics if `event.first_decode` has no Initial header, or if the fully decoded packet
    /// turns out not to be an Initial.
    fn handle_first_packet(
        &mut self,
        datagram_len: usize,
        event: DatagramConnectionEvent,
        network_path: FourTuple,
        buf: &mut Vec<u8>,
    ) -> Option<DatagramEvent> {
        let dst_cid = event.first_decode.dst_cid();
        // The caller only dispatches here when an Initial header is present
        let header = event.first_decode.initial_header().unwrap();

        // Without a server configuration no connection can be accepted; treat the packet like
        // any other unroutable one.
        let Some(server_config) = &self.server_config else {
            debug!("packet for unrecognized connection {}", dst_cid);
            return self
                .stateless_reset(event.now, datagram_len, network_path, dst_cid, buf)
                .map(DatagramEvent::Response);
        };

        if datagram_len < MIN_INITIAL_SIZE as usize {
            debug!("ignoring short initial for connection {}", dst_cid);
            return None;
        }

        // Initial keys are derived from the version and the client-chosen destination CID
        let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
            Ok(keys) => keys,
            Err(UnsupportedVersion) => {
                // This probably indicates that the user set supported_versions incorrectly in
                // `EndpointConfig`.
                debug!(
                    "ignoring initial packet version {:#x} unsupported by cryptographic layer",
                    header.version
                );
                return None;
            }
        };

        // Checks that need only the still-protected header; a failure is reported to the peer
        if let Err(reason) = self.early_validate_first_packet(header) {
            return Some(DatagramEvent::Response(self.initial_close(
                header.version,
                network_path,
                &crypto,
                header.src_cid,
                reason,
                buf,
            )));
        }

        // Finish decoding the header using the remote header key
        let packet = match event.first_decode.finish(Some(&*crypto.header.remote)) {
            Ok(packet) => packet,
            Err(e) => {
                trace!("unable to decode initial packet: {}", e);
                return None;
            }
        };

        if !packet.reserved_bits_valid() {
            debug!("dropping connection attempt with invalid reserved bits");
            return None;
        }

        let Header::Initial(header) = packet.header else {
            panic!("non-initial packet in handle_first_packet()");
        };

        // Take an owned handle to the configuration; the borrow obtained at the top of the
        // function is no longer used past the `&mut self` calls above.
        let server_config = self.server_config.as_ref().unwrap().clone();

        let token = match IncomingToken::from_header(&header, &server_config, network_path.remote) {
            Ok(token) => token,
            Err(InvalidRetryTokenError) => {
                debug!("rejecting invalid retry token");
                return Some(DatagramEvent::Response(self.initial_close(
                    header.version,
                    network_path,
                    &crypto,
                    header.src_cid,
                    TransportError::INVALID_TOKEN(""),
                    buf,
                )));
            }
        };

        // Reserve a buffer for further datagrams of this attempt and route its destination CID
        // there until the application decides what to do with the `Incoming`.
        let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
        self.index
            .insert_initial_incoming(header.dst_cid, incoming_idx);

        Some(DatagramEvent::NewConnection(Incoming {
            received_at: event.now,
            network_path,
            ecn: event.ecn,
            packet: InitialPacket {
                header,
                header_data: packet.header_data,
                payload: packet.payload,
            },
            rest: event.remaining,
            crypto,
            token,
            incoming_idx,
            improper_drop_warner: IncomingImproperDropWarner,
        }))
    }
532
    /// Attempt to accept this incoming connection (an error may still occur)
    ///
    /// `now` is the current time, used to detect attempts that have gone stale since
    /// `incoming` was received. `server_config`, when `Some`, is used for this connection
    /// instead of the endpoint's own server configuration. Any packet to send back on failure
    /// is written into `buf` and described by `AcceptError::response`.
    ///
    /// The incoming buffer reserved for the attempt is released on every path; datagrams
    /// buffered in it are replayed into the new connection on success.
    ///
    /// # Errors
    ///
    /// - `ConnectionError::TimedOut` if the attempt is older than the configured
    ///   `max_idle_timeout`.
    /// - `ConnectionError::CidsExhausted` if no local connection IDs are available; a
    ///   `CONNECTION_REFUSED` close is produced as the response.
    /// - A `PROTOCOL_VIOLATION` transport error if the Initial packet fails to decrypt.
    /// - Whatever error the new connection reports while handling its first packet; transport
    ///   errors additionally produce a close as the response.
    ///
    /// # Panics
    ///
    /// Panics if `server_config` is `None` and the endpoint has no server configuration.
    // box err to avoid clippy::result_large_err
    pub fn accept(
        &mut self,
        mut incoming: Incoming,
        now: Instant,
        buf: &mut Vec<u8>,
        server_config: Option<Arc<ServerConfig>>,
    ) -> Result<(ConnectionHandle, Connection), Box<AcceptError>> {
        let remote_address_validated = incoming.remote_address_validated();
        // From here on the `Incoming` is being disposed of properly
        incoming.improper_drop_warner.dismiss();
        // Release the buffer and its byte accounting up front so that every return path below
        // leaves the endpoint's bookkeeping consistent.
        let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
        self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;

        let packet_number = incoming.packet.header.number.expand(0);
        let InitialHeader {
            src_cid,
            dst_cid,
            version,
            ..
        } = incoming.packet.header;
        // A per-call configuration takes precedence over the endpoint's own
        let server_config =
            server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());

        // Give up on attempts that already exceed the idle timeout
        if server_config
            .transport
            .max_idle_timeout
            .is_some_and(|timeout| {
                incoming.received_at + Duration::from_millis(timeout.into()) <= now
            })
        {
            debug!("abandoning accept of stale initial");
            self.index.remove_initial(dst_cid);
            return Err(Box::new(AcceptError {
                cause: ConnectionError::TimedOut,
                response: None,
            }));
        }

        if self.cids_exhausted() {
            debug!("refusing connection");
            self.index.remove_initial(dst_cid);
            return Err(Box::new(AcceptError {
                cause: ConnectionError::CidsExhausted,
                response: Some(self.initial_close(
                    version,
                    incoming.network_path,
                    &incoming.crypto,
                    src_cid,
                    TransportError::CONNECTION_REFUSED(""),
                    buf,
                )),
            }));
        }

        // Decrypt the Initial payload in place; a failure means the packet is not authentic
        if incoming
            .crypto
            .packet
            .remote
            .decrypt(
                PathId::ZERO,
                packet_number,
                &incoming.packet.header_data,
                &mut incoming.packet.payload,
            )
            .is_err()
        {
            debug!(packet_number, "failed to authenticate initial packet");
            self.index.remove_initial(dst_cid);
            return Err(Box::new(AcceptError {
                cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
                response: None,
            }));
        };

        // The handle is only reserved here; the slab entry is filled in by `add_connection`
        let ch = ConnectionHandle(self.connections.vacant_key());
        let local_cid = self.new_cid(ch, PathId::ZERO);
        let mut params = TransportParameters::new(
            &server_config.transport,
            &self.config,
            self.local_cid_generator.as_ref(),
            local_cid,
            Some(&server_config),
            &mut self.rng,
        );
        params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, local_cid));
        params.original_dst_cid = Some(incoming.token.orig_dst_cid);
        params.retry_src_cid = incoming.token.retry_src_cid;
        // A preferred address is advertised with its own connection ID and reset token
        let mut pref_addr_cid = None;
        if server_config.has_preferred_address() {
            let cid = self.new_cid(ch, PathId::ZERO);
            pref_addr_cid = Some(cid);
            params.preferred_address = Some(PreferredAddress {
                address_v4: server_config.preferred_address_v4,
                address_v6: server_config.preferred_address_v6,
                connection_id: cid,
                stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
            });
        }

        let tls = server_config.crypto.start_session(version, &params);
        let transport_config = server_config.transport.clone();
        let mut conn = self.add_connection(
            ch,
            version,
            dst_cid,
            local_cid,
            src_cid,
            incoming.network_path,
            incoming.received_at,
            tls,
            transport_config,
            SideArgs::Server {
                server_config,
                pref_addr_cid,
                path_validated: remote_address_validated,
            },
            &params,
        );
        // Route the client's original destination CID to the new connection from now on
        self.index.insert_initial(dst_cid, ch);

        match conn.handle_first_packet(
            incoming.received_at,
            incoming.network_path,
            incoming.ecn,
            packet_number,
            incoming.packet,
            incoming.rest,
        ) {
            Ok(()) => {
                trace!(
                    id = ch.0,
                    icid = %dst_cid,
                    network_path = %incoming.network_path,
                    "new connection",
                );

                // Replay the datagrams that arrived while the attempt was pending
                for event in incoming_buffer.datagrams {
                    conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
                }

                Ok((ch, conn))
            }
            Err(e) => {
                debug!("handshake failed: {}", e);
                // Remove the connection that was just registered from the endpoint's state
                self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
                let response = match e {
                    ConnectionError::TransportError(ref e) => Some(self.initial_close(
                        version,
                        incoming.network_path,
                        &incoming.crypto,
                        src_cid,
                        e.clone(),
                        buf,
                    )),
                    _ => None,
                };
                Err(Box::new(AcceptError { cause: e, response }))
            }
        }
    }
694
695    /// Check if we should refuse a connection attempt regardless of the packet's contents
696    fn early_validate_first_packet(
697        &mut self,
698        header: &ProtectedInitialHeader,
699    ) -> Result<(), TransportError> {
700        let config = &self.server_config.as_ref().unwrap();
701        if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming {
702            return Err(TransportError::CONNECTION_REFUSED(""));
703        }
704
705        // RFC9000 §7.2 dictates that initial (client-chosen) destination CIDs must be at least 8
706        // bytes. If this is a Retry packet, then the length must instead match our usual CID
707        // length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll
708        // also need to validate CID length for those after decoding the token.
709        if header.dst_cid.len() < 8
710            && (header.token_pos.is_empty()
711                || header.dst_cid.len() != self.local_cid_generator.cid_len())
712        {
713            debug!(
714                "rejecting connection due to invalid DCID length {}",
715                header.dst_cid.len()
716            );
717            return Err(TransportError::PROTOCOL_VIOLATION(
718                "invalid destination CID length",
719            ));
720        }
721
722        Ok(())
723    }
724
725    /// Reject this incoming connection attempt
726    pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
727        self.clean_up_incoming(&incoming);
728        incoming.improper_drop_warner.dismiss();
729
730        self.initial_close(
731            incoming.packet.header.version,
732            incoming.network_path,
733            &incoming.crypto,
734            incoming.packet.header.src_cid,
735            TransportError::CONNECTION_REFUSED(""),
736            buf,
737        )
738    }
739
740    /// Respond with a retry packet, requiring the client to retry with address validation
741    ///
742    /// Errors if `incoming.may_retry()` is false.
743    pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
744        if !incoming.may_retry() {
745            return Err(RetryError(Box::new(incoming)));
746        }
747
748        self.clean_up_incoming(&incoming);
749        incoming.improper_drop_warner.dismiss();
750
751        let server_config = self.server_config.as_ref().unwrap();
752
753        // First Initial
754        // The peer will use this as the DCID of its following Initials. Initial DCIDs are
755        // looked up separately from Handshake/Data DCIDs, so there is no risk of collision
756        // with established connections. In the unlikely event that a collision occurs
757        // between two connections in the initial phase, both will fail fast and may be
758        // retried by the application layer.
759        let local_cid = self.local_cid_generator.generate_cid();
760
761        let payload = TokenPayload::Retry {
762            address: incoming.network_path.remote,
763            orig_dst_cid: incoming.packet.header.dst_cid,
764            issued: server_config.time_source.now(),
765        };
766        let token = Token::new(payload, &mut self.rng).encode(&*server_config.token_key);
767
768        let header = Header::Retry {
769            src_cid: local_cid,
770            dst_cid: incoming.packet.header.src_cid,
771            version: incoming.packet.header.version,
772        };
773
774        let encode = header.encode(buf);
775        buf.put_slice(&token);
776        buf.extend_from_slice(&server_config.crypto.retry_tag(
777            incoming.packet.header.version,
778            incoming.packet.header.dst_cid,
779            buf,
780        ));
781        encode.finish(buf, &*incoming.crypto.header.local, None);
782
783        Ok(Transmit {
784            destination: incoming.network_path.remote,
785            ecn: None,
786            size: buf.len(),
787            segment_size: None,
788            src_ip: incoming.network_path.local_ip,
789        })
790    }
791
    /// Ignore this incoming connection attempt, not sending any packet in response
    ///
    /// Doing this actively, rather than merely dropping the [`Incoming`], is necessary to prevent
    /// memory leaks due to state within [`Endpoint`] tracking the incoming connection.
    ///
    /// Concretely, this removes the routing entry for the attempt's destination CID and frees
    /// its incoming buffer, then dismisses the improper-drop warner carried by `incoming`.
    pub fn ignore(&mut self, incoming: Incoming) {
        // Clean up first: it borrows `incoming` as a whole, which is no longer possible once
        // the warner has been moved out of it.
        self.clean_up_incoming(&incoming);
        incoming.improper_drop_warner.dismiss();
    }
800
801    /// Clean up endpoint data structures associated with an `Incoming`.
802    fn clean_up_incoming(&mut self, incoming: &Incoming) {
803        self.index.remove_initial(incoming.packet.header.dst_cid);
804        let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
805        self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
806    }
807
808    fn add_connection(
809        &mut self,
810        ch: ConnectionHandle,
811        version: u32,
812        init_cid: ConnectionId,
813        local_cid: ConnectionId,
814        remote_cid: ConnectionId,
815        network_path: FourTuple,
816        now: Instant,
817        tls: Box<dyn crypto::Session>,
818        transport_config: Arc<TransportConfig>,
819        side_args: SideArgs,
820        // Only used for qlog.
821        params: &TransportParameters,
822    ) -> Connection {
823        let mut rng_seed = [0; 32];
824        self.rng.fill_bytes(&mut rng_seed);
825        let side = side_args.side();
826        let pref_addr_cid = side_args.pref_addr_cid();
827
828        let qlog =
829            transport_config.create_qlog_sink(side_args.side(), network_path.remote, init_cid, now);
830
831        qlog.emit_connection_started(
832            now,
833            local_cid,
834            remote_cid,
835            network_path.remote,
836            network_path.local_ip,
837            params,
838        );
839
840        let conn = Connection::new(
841            self.config.clone(),
842            transport_config,
843            init_cid,
844            local_cid,
845            remote_cid,
846            network_path,
847            tls,
848            self.local_cid_generator.as_ref(),
849            now,
850            version,
851            self.allow_mtud,
852            rng_seed,
853            side_args,
854            qlog,
855        );
856
857        let mut path_cids = PathLocalCids::default();
858        path_cids.cids.insert(path_cids.issued, local_cid);
859        path_cids.issued += 1;
860
861        if let Some(cid) = pref_addr_cid {
862            debug_assert_eq!(path_cids.issued, 1, "preferred address cid seq must be 1");
863            path_cids.cids.insert(path_cids.issued, cid);
864            path_cids.issued += 1;
865        }
866
867        let id = self.connections.insert(ConnectionMeta {
868            init_cid,
869            local_cids: FxHashMap::from_iter([(PathId::ZERO, path_cids)]),
870            network_path,
871            side,
872            reset_token: Default::default(),
873        });
874        debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
875
876        self.index.insert_conn(network_path, local_cid, ch, side);
877
878        conn
879    }
880
881    fn initial_close(
882        &mut self,
883        version: u32,
884        network_path: FourTuple,
885        crypto: &Keys,
886        remote_id: ConnectionId,
887        reason: TransportError,
888        buf: &mut Vec<u8>,
889    ) -> Transmit {
890        // We don't need to worry about CID collisions in initial closes because the peer
891        // shouldn't respond, and if it does, and the CID collides, we'll just drop the
892        // unexpected response.
893        let local_id = self.local_cid_generator.generate_cid();
894        let number = PacketNumber::U8(0);
895        let header = Header::Initial(InitialHeader {
896            dst_cid: remote_id,
897            src_cid: local_id,
898            number,
899            token: Bytes::new(),
900            version,
901        });
902
903        let partial_encode = header.encode(buf);
904        let max_len =
905            INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len();
906        frame::Close::from(reason).encoder(max_len).encode(buf);
907        buf.resize(buf.len() + crypto.packet.local.tag_len(), 0);
908        partial_encode.finish(
909            buf,
910            &*crypto.header.local,
911            Some((0, Default::default(), &*crypto.packet.local)),
912        );
913        Transmit {
914            destination: network_path.remote,
915            ecn: None,
916            size: buf.len(),
917            segment_size: None,
918            src_ip: network_path.local_ip,
919        }
920    }
921
    /// Access the configuration used by this endpoint
    ///
    /// Returns a borrow of the shared [`EndpointConfig`]; nothing is cloned.
    pub fn config(&self) -> &EndpointConfig {
        &self.config
    }
926
    /// Number of connections that are currently open
    ///
    /// This is the number of entries in the endpoint's connection table.
    pub fn open_connections(&self) -> usize {
        self.connections.len()
    }
931
    /// Number of bytes currently held in the buffers for Initial and 0-RTT messages of pending
    /// incoming connections
    pub fn incoming_buffer_bytes(&self) -> u64 {
        self.all_incoming_buffers_total_bytes
    }
937
938    #[cfg(test)]
939    pub(crate) fn known_connections(&self) -> usize {
940        let x = self.connections.len();
941        debug_assert_eq!(x, self.index.connection_ids_initial.len());
942        // Not all connections have known reset tokens
943        debug_assert!(x >= self.index.connection_reset_tokens.0.len());
944        // Not all connections have unique remotes, and 0-length CIDs might not be in use.
945        debug_assert!(x >= self.index.incoming_connection_remotes.len());
946        debug_assert!(x >= self.index.outgoing_connection_remotes.len());
947        x
948    }
949
    /// Test helper: number of locally created CIDs currently registered for routing.
    #[cfg(test)]
    pub(crate) fn known_cids(&self) -> usize {
        self.index.connection_ids.len()
    }
954
955    /// Whether we've used up 3/4 of the available CID space
956    ///
957    /// We leave some space unused so that `new_cid` can be relied upon to finish quickly. We don't
958    /// bother to check when CID longer than 4 bytes are used because 2^40 connections is a lot.
959    fn cids_exhausted(&self) -> bool {
960        let cid_len = self.local_cid_generator.cid_len();
961        if cid_len == 0 || cid_len > 4 {
962            return false;
963        }
964
965        // Keep this architecture-independent: on 32-bit targets, 2usize.pow(32) overflows.
966        let bits = (cid_len * 8) as u32;
967        let space = 1u64 << bits;
968        let reserve = 1u64 << (bits - 2);
969        let len = self.index.connection_ids.len() as u64;
970
971        len > (space - reserve)
972    }
973}
974
975impl fmt::Debug for Endpoint {
976    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
977        fmt.debug_struct("Endpoint")
978            .field("rng", &self.rng)
979            .field("index", &self.index)
980            .field("connections", &self.connections)
981            .field("config", &self.config)
982            .field("server_config", &self.server_config)
983            // incoming_buffers too large
984            .field("incoming_buffers.len", &self.incoming_buffers.len())
985            .field(
986                "all_incoming_buffers_total_bytes",
987                &self.all_incoming_buffers_total_bytes,
988            )
989            .finish()
990    }
991}
992
/// Buffered Initial and 0-RTT messages for a pending incoming connection
#[derive(Default)]
struct IncomingBuffer {
    /// Datagrams held back until the pending connection is resolved
    datagrams: Vec<DatagramConnectionEvent>,
    /// Size accounted for this buffer; subtracted from the endpoint-wide byte counter when the
    /// buffer is discarded
    total_bytes: u64,
}
999
/// Part of protocol state incoming datagrams can be routed to
#[derive(Copy, Clone, Debug)]
enum RouteDatagramTo {
    /// A pending incoming connection, identified by the key it was registered under
    // NOTE(review): presumably the slab key of the matching `IncomingBuffer` — confirm.
    Incoming(usize),
    /// An established connection and the path the datagram belongs to
    Connection(ConnectionHandle, PathId),
}
1006
/// Maps packets to existing connections
///
/// Several lookup tables are kept side by side because a datagram may have to be matched by
/// its destination CID, by its addresses (for zero-length CIDs), or by a trailing stateless
/// reset token.
#[derive(Default, Debug)]
struct ConnectionIndex {
    /// Identifies connections based on the initial DCID the peer utilized
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    ///
    /// Used by the server, not the client.
    connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
    /// Identifies connections based on locally created CIDs
    ///
    /// Uses a cheaper hash function since keys are locally created
    connection_ids: FxHashMap<ConnectionId, (ConnectionHandle, PathId)>,
    /// Identifies incoming connections with zero-length CIDs
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
    /// Identifies outgoing connections with zero-length CIDs
    ///
    /// We don't yet support explicit source addresses for client connections, and zero-length CIDs
    /// require a unique 4-tuple, so at most one client connection with zero-length local CIDs
    /// may be established per remote. We must omit the local address from the key because we don't
    /// necessarily know what address we're sending from, and hence receiving at.
    ///
    /// Uses a standard `HashMap` to protect against hash collision attacks.
    // TODO(matheus23): It's possible this could be changed now that we track the full 4-tuple
    // on the client side, too.
    outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
    /// Reset tokens provided by the peer for the CID each connection is currently sending to
    ///
    /// Incoming stateless resets do not have correct CIDs, so we need this to identify the correct
    /// recipient, if any.
    connection_reset_tokens: ResetTokenTable,
}
1040
1041impl ConnectionIndex {
1042    /// Associate an incoming connection with its initial destination CID
1043    fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
1044        if dst_cid.is_empty() {
1045            return;
1046        }
1047        self.connection_ids_initial
1048            .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
1049    }
1050
1051    /// Remove an association with an initial destination CID
1052    fn remove_initial(&mut self, dst_cid: ConnectionId) {
1053        if dst_cid.is_empty() {
1054            return;
1055        }
1056        let removed = self.connection_ids_initial.remove(&dst_cid);
1057        debug_assert!(removed.is_some());
1058    }
1059
1060    /// Associate a connection with its initial destination CID
1061    fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
1062        if dst_cid.is_empty() {
1063            return;
1064        }
1065        self.connection_ids_initial.insert(
1066            dst_cid,
1067            RouteDatagramTo::Connection(connection, PathId::ZERO),
1068        );
1069    }
1070
1071    /// Associate a connection with its first locally-chosen destination CID if used, or otherwise
1072    /// its current 4-tuple
1073    fn insert_conn(
1074        &mut self,
1075        network_path: FourTuple,
1076        dst_cid: ConnectionId,
1077        connection: ConnectionHandle,
1078        side: Side,
1079    ) {
1080        match dst_cid.len() {
1081            0 => match side {
1082                Side::Server => {
1083                    self.incoming_connection_remotes
1084                        .insert(network_path, connection);
1085                }
1086                Side::Client => {
1087                    self.outgoing_connection_remotes
1088                        .insert(network_path.remote, connection);
1089                }
1090            },
1091            _ => {
1092                self.connection_ids
1093                    .insert(dst_cid, (connection, PathId::ZERO));
1094            }
1095        }
1096    }
1097
1098    /// Discard a connection ID
1099    fn retire(&mut self, dst_cid: ConnectionId) {
1100        self.connection_ids.remove(&dst_cid);
1101    }
1102
1103    /// Remove all references to a connection
1104    fn remove(&mut self, conn: &ConnectionMeta) {
1105        if conn.side.is_server() {
1106            self.remove_initial(conn.init_cid);
1107        }
1108        for cid in conn
1109            .local_cids
1110            .values()
1111            .flat_map(|pcids| pcids.cids.values())
1112        {
1113            self.connection_ids.remove(cid);
1114        }
1115        self.incoming_connection_remotes.remove(&conn.network_path);
1116        self.outgoing_connection_remotes
1117            .remove(&conn.network_path.remote);
1118        for (remote, token) in conn.reset_token.values() {
1119            self.connection_reset_tokens.remove(*remote, *token);
1120        }
1121    }
1122
1123    /// Find the existing connection that `datagram` should be routed to, if any
1124    fn get(&self, network_path: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
1125        if !datagram.dst_cid().is_empty()
1126            && let Some(&(ch, path_id)) = self.connection_ids.get(&datagram.dst_cid())
1127        {
1128            return Some(RouteDatagramTo::Connection(ch, path_id));
1129        }
1130        if (datagram.is_initial() || datagram.is_0rtt())
1131            && let Some(&ch) = self.connection_ids_initial.get(&datagram.dst_cid())
1132        {
1133            return Some(ch);
1134        }
1135        if datagram.dst_cid().is_empty() {
1136            if let Some(&ch) = self.incoming_connection_remotes.get(network_path) {
1137                // Never multipath because QUIC-MULTIPATH 1.1 mandates the use of non-zero
1138                // length CIDs.  So this is always PathId::ZERO.
1139                return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1140            }
1141            if let Some(&ch) = self.outgoing_connection_remotes.get(&network_path.remote) {
1142                // Like above, QUIC-MULTIPATH 1.1 mandates the use of non-zero length CIDs.
1143                return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1144            }
1145        }
1146        let data = datagram.data();
1147        if data.len() < RESET_TOKEN_SIZE {
1148            return None;
1149        }
1150        // For stateless resets the PathId is meaningless since it closes the entire
1151        // connection regardless of path.  So use PathId::ZERO.
1152        self.connection_reset_tokens
1153            .get(network_path.remote, &data[data.len() - RESET_TOKEN_SIZE..])
1154            .cloned()
1155            .map(|ch| RouteDatagramTo::Connection(ch, PathId::ZERO))
1156    }
1157}
1158
/// Per-connection bookkeeping the endpoint keeps so it can route datagrams and clean up its
/// index when the connection goes away
#[derive(Debug)]
pub(crate) struct ConnectionMeta {
    /// Initial destination CID this connection was registered under
    init_cid: ConnectionId,
    /// Locally issued CIDs for each path
    local_cids: FxHashMap<PathId, PathLocalCids>,
    /// Remote/local addresses the connection began with
    ///
    /// Only needed to support connections with zero-length CIDs, which cannot migrate, so we don't
    /// bother keeping it up to date.
    network_path: FourTuple,
    /// Whether we are the client or the server of this connection
    side: Side,
    /// Reset tokens provided by the peer for CIDs we're currently sending to
    ///
    /// Since each reset token is for a CID, it is also for a fixed remote address which is
    /// also stored. This allows us to look up which reset tokens we might expect from a
    /// given remote address, see [`ResetTokenTable`].
    ///
    /// Each path has its own active CID. We use the [`PathId`] as a unique index, allowing
    /// us to retire the reset token when a path is abandoned.
    // TODO(matheus23): Should be migrated to make reset tokens per 4-tuple instead of per remote addr
    reset_token: FxHashMap<PathId, (SocketAddr, ResetToken)>,
}
1181
/// Local connection IDs for a single path
#[derive(Debug, Default)]
struct PathLocalCids {
    /// Number of connection IDs that have been issued in (PATH_)NEW_CONNECTION_ID frames
    ///
    /// Another way of saying this is that this is the next sequence number to be issued.
    issued: u64,
    /// Issued CIDs indexed by their sequence number.
    cids: FxHashMap<u64, ConnectionId>,
}
1192
/// Internal identifier for a `Connection` currently associated with an endpoint
///
/// The wrapped value is the connection's key in the endpoint's connection table.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);

/// Unwraps the handle into its raw table key.
impl From<ConnectionHandle> for usize {
    fn from(ConnectionHandle(key): ConnectionHandle) -> Self {
        key
    }
}
1202
1203impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
1204    type Output = ConnectionMeta;
1205    fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
1206        &self[ch.0]
1207    }
1208}
1209
1210impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
1211    fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
1212        &mut self[ch.0]
1213    }
1214}
1215
/// Event resulting from processing a single datagram
pub enum DatagramEvent {
    /// The datagram is redirected to its `Connection`
    ///
    /// Carries the handle of the target connection and the event to deliver to it.
    ConnectionEvent(ConnectionHandle, ConnectionEvent),
    /// The datagram may result in starting a new `Connection`
    ///
    /// The [`Incoming`] must be handed back to the endpoint (accept, refuse, retry or ignore)
    /// rather than simply dropped.
    NewConnection(Incoming),
    /// Response generated directly by the endpoint
    Response(Transmit),
}
1225
/// An incoming connection for which the server has not yet begun its part of the handshake.
#[derive(derive_more::Debug)]
pub struct Incoming {
    // NOTE(review): presumably the time the first packet arrived — confirm at the construction site.
    #[debug(skip)]
    received_at: Instant,
    /// Peer address and local IP the attempt arrived on
    network_path: FourTuple,
    // NOTE(review): presumably the ECN codepoint of the received datagram — confirm.
    ecn: Option<EcnCodepoint>,
    /// The Initial packet that started this attempt; its payload is still encrypted
    #[debug(skip)]
    packet: InitialPacket,
    // NOTE(review): presumably the remainder of the datagram after the Initial packet — confirm.
    #[debug(skip)]
    rest: Option<BytesMut>,
    /// Keys for the Initial packet space; the remote packet key decrypts `packet`
    #[debug(skip)]
    crypto: Keys,
    /// Token information from the Initial packet (address validation state, original DCID)
    token: IncomingToken,
    /// Key of this attempt's entry in the endpoint's incoming buffers
    incoming_idx: usize,
    /// Logs a warning if this value is dropped without being handed back to the endpoint
    #[debug(skip)]
    improper_drop_warner: IncomingImproperDropWarner,
}
1244
1245impl Incoming {
1246    /// The local IP address which was used when the peer established the connection
1247    pub fn local_ip(&self) -> Option<IpAddr> {
1248        self.network_path.local_ip
1249    }
1250
1251    /// The peer's UDP address
1252    pub fn remote_address(&self) -> SocketAddr {
1253        self.network_path.remote
1254    }
1255
1256    /// Whether the socket address that is initiating this connection has been validated
1257    ///
1258    /// This means that the sender of the initial packet has proved that they can receive traffic
1259    /// sent to `self.remote_address()`.
1260    ///
1261    /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true.
1262    /// The inverse is not guaranteed.
1263    pub fn remote_address_validated(&self) -> bool {
1264        self.token.validated
1265    }
1266
1267    /// Whether it is legal to respond with a retry packet
1268    ///
1269    /// If `self.remote_address_validated()` is false, `self.may_retry()` is guaranteed to be true.
1270    /// The inverse is not guaranteed.
1271    pub fn may_retry(&self) -> bool {
1272        self.token.retry_src_cid.is_none()
1273    }
1274
1275    /// The original destination connection ID sent by the client
1276    pub fn orig_dst_cid(&self) -> ConnectionId {
1277        self.token.orig_dst_cid
1278    }
1279
1280    /// Decrypt the Initial packet payload
1281    ///
1282    /// This clones and decrypts the packet payload (~1200 bytes).
1283    /// Can be used to extract information from the TLS ClientHello without completing the handshake.
1284    pub fn decrypt(&self) -> Option<DecryptedInitial> {
1285        let packet_number = self.packet.header.number.expand(0);
1286        let mut payload = self.packet.payload.clone();
1287        self.crypto
1288            .packet
1289            .remote
1290            .decrypt(
1291                PathId::ZERO,
1292                packet_number,
1293                &self.packet.header_data,
1294                &mut payload,
1295            )
1296            .ok()?;
1297        Some(DecryptedInitial(payload.freeze()))
1298    }
1299}
1300
/// Decrypted payload of a QUIC Initial packet
///
/// Obtained via [`Incoming::decrypt`]. Can be used to extract information from
/// the TLS ClientHello without completing the handshake.
///
/// The wrapped bytes are the decrypted packet payload, i.e. a sequence of QUIC frames.
pub struct DecryptedInitial(Bytes);
1306
1307impl DecryptedInitial {
1308    /// Best-effort extraction of the ALPN protocols from the TLS ClientHello
1309    ///
1310    /// Parses the CRYPTO frames to extract the ALPN extension. This is intended
1311    /// for routing and filtering; it is not guaranteed to succeed if the
1312    /// ClientHello spans multiple packets. Returns `None` if parsing fails.
1313    pub fn alpns(&self) -> Option<IncomingAlpns> {
1314        let frames = frame::Iter::new(self.0.clone()).ok()?;
1315        let mut first = None;
1316        let mut rest = Vec::new();
1317        for frame in frames {
1318            match frame {
1319                Ok(frame::Frame::Crypto(crypto)) => match first {
1320                    None => first = Some(crypto),
1321                    Some(_) => rest.push(crypto),
1322                },
1323                Err(_) => return None,
1324                _ => {}
1325            }
1326        }
1327        let first = first?;
1328
1329        // Fast path: single CRYPTO frame at offset 0 (no extra allocation)
1330        if rest.is_empty() && first.offset == 0 {
1331            let data = find_alpn_data(&first.data).ok()?;
1332            return Some(IncomingAlpns { data, pos: 0 });
1333        }
1334
1335        // Slow path: reassemble multiple CRYPTO frames
1336        rest.push(first);
1337        let source = assemble_crypto_frames(&mut rest)?;
1338        let data = find_alpn_data(&source).ok()?;
1339        Some(IncomingAlpns { data, pos: 0 })
1340    }
1341}
1342
/// TLS handshake type for ClientHello messages
/// <https://www.rfc-editor.org/rfc/rfc8446#section-4.1.2>
const TLS_HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 0x01;
/// TLS extension type for Application-Layer Protocol Negotiation
/// <https://www.rfc-editor.org/rfc/rfc7301#section-3.1>
const TLS_EXTENSION_TYPE_ALPN: u16 = 0x0010;
/// Size of the fixed-length fields at the start of a ClientHello body:
/// 2 bytes of client version followed by 32 bytes of random
/// <https://www.rfc-editor.org/rfc/rfc8446#section-4.1.2>
const TLS_CLIENT_HELLO_FIXED_LEN: usize = 2 + 32;
1352
/// Iterator over ALPN protocol names from a TLS ClientHello
///
/// Yields protocol names as [`Bytes`] slices. On the common fast path (single
/// CRYPTO frame), the only allocation is the payload clone for decryption.
pub struct IncomingAlpns {
    /// Raw protocol name list: a sequence of u8-length-prefixed names
    data: Bytes,
    /// Offset into `data` of the next length prefix to read
    pos: usize,
}
1361
1362impl Iterator for IncomingAlpns {
1363    type Item = Result<Bytes, UnexpectedEnd>;
1364
1365    fn next(&mut self) -> Option<Self::Item> {
1366        if self.pos >= self.data.len() {
1367            return None;
1368        }
1369        let len = self.data[self.pos] as usize;
1370        self.pos += 1;
1371        if self.pos + len > self.data.len() {
1372            return Some(Err(UnexpectedEnd));
1373        }
1374        let proto = self.data.slice(self.pos..self.pos + len);
1375        self.pos += len;
1376        Some(Ok(proto))
1377    }
1378}
1379
1380/// Sort CRYPTO frames by offset and concatenate into a contiguous `Bytes`
1381///
1382/// Returns `None` if there are gaps in the stream.
1383fn assemble_crypto_frames(frames: &mut [frame::Crypto]) -> Option<Bytes> {
1384    frames.sort_by_key(|f| f.offset);
1385    let capacity = frames.iter().map(|f| f.data.len()).sum();
1386    let mut buf = Vec::with_capacity(capacity);
1387    for f in frames.iter() {
1388        let start = f.offset as usize;
1389        if start > buf.len() {
1390            return None;
1391        }
1392        let end = start + f.data.len();
1393        if end > buf.len() {
1394            buf.extend_from_slice(&f.data[buf.len() - start..]);
1395        }
1396    }
1397    Some(Bytes::from(buf))
1398}
1399
1400/// Locate the raw ALPN protocol list data within a TLS ClientHello message
1401///
1402/// Parses the ClientHello in `source` and returns a [`Bytes`] containing the
1403/// u8-length-prefixed protocol names (after the outer ProtocolNameList u16
1404/// length prefix). The returned `Bytes` is a zero-copy slice of `source`.
1405fn find_alpn_data(source: &Bytes) -> Result<Bytes, UnexpectedEnd> {
1406    let mut r = &**source;
1407
1408    if u8::decode(&mut r)? != TLS_HANDSHAKE_TYPE_CLIENT_HELLO {
1409        return Err(UnexpectedEnd);
1410    }
1411
1412    // Handshake message length (u24), scopes the remainder
1413    let len = decode_u24(&mut r)?;
1414    let mut body = take(&mut r, len)?;
1415
1416    // Client version + random
1417    skip(&mut body, TLS_CLIENT_HELLO_FIXED_LEN)?;
1418
1419    // Session ID, cipher suites, compression methods
1420    skip_u8_prefixed(&mut body)?;
1421    skip_u16_prefixed(&mut body)?;
1422    skip_u8_prefixed(&mut body)?;
1423
1424    // Extensions
1425    let mut exts = take_u16_prefixed(&mut body)?;
1426    while exts.has_remaining() {
1427        let ext_type = u16::decode(&mut exts)?;
1428        let ext_data = take_u16_prefixed(&mut exts)?;
1429        if ext_type == TLS_EXTENSION_TYPE_ALPN {
1430            let list = take_u16_prefixed(&mut &*ext_data)?;
1431            return Ok(source.slice_ref(list));
1432        }
1433    }
1434    Err(UnexpectedEnd)
1435}
1436
1437/// Decode a big-endian u24 as usize
1438fn decode_u24(r: &mut &[u8]) -> Result<usize, UnexpectedEnd> {
1439    let a = u8::decode(r)?;
1440    let b = u8::decode(r)?;
1441    let c = u8::decode(r)?;
1442    Ok(u32::from_be_bytes([0, a, b, c]) as usize)
1443}
1444
1445/// Take `len` bytes from the front and return them as a sub-slice
1446fn take<'a>(r: &mut &'a [u8], len: usize) -> Result<&'a [u8], UnexpectedEnd> {
1447    if r.remaining() < len {
1448        return Err(UnexpectedEnd);
1449    }
1450    let data = &r[..len];
1451    r.advance(len);
1452    Ok(data)
1453}
1454
1455/// Read a u16 length prefix and return the sub-slice it covers
1456fn take_u16_prefixed<'a>(r: &mut &'a [u8]) -> Result<&'a [u8], UnexpectedEnd> {
1457    let len = u16::decode(r)? as usize;
1458    take(r, len)
1459}
1460
1461/// Advance past `n` bytes
1462fn skip(r: &mut &[u8], len: usize) -> Result<(), UnexpectedEnd> {
1463    take(r, len)?;
1464    Ok(())
1465}
1466
1467/// Skip a u8-length-prefixed field
1468fn skip_u8_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1469    let len = u8::decode(r)? as usize;
1470    skip(r, len)
1471}
1472
1473/// Skip a u16-length-prefixed field
1474fn skip_u16_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1475    let len = u16::decode(r)? as usize;
1476    skip(r, len)
1477}
1478
/// Drop guard that logs a warning when it is dropped
///
/// Call `dismiss` to consume the guard without triggering the warning.
struct IncomingImproperDropWarner;
1480
impl IncomingImproperDropWarner {
    /// Consume the guard without running its `Drop` implementation
    fn dismiss(self) {
        // `mem::forget` skips the destructor, which is where the warning is logged.
        mem::forget(self);
    }
}
1486
/// Logs a warning whenever the guard is dropped without having been dismissed.
impl Drop for IncomingImproperDropWarner {
    fn drop(&mut self) {
        warn!(
            "noq_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
               (may cause memory leak and eventual inability to accept new connections)"
        );
    }
}
1495
/// Errors in the parameters being used to create a new connection
///
/// These arise before any I/O has been performed.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// The endpoint can no longer create new connections
    ///
    /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled.
    #[error("endpoint stopping")]
    EndpointStopping,
    /// The connection could not be created because not enough of the CID space is available
    ///
    /// Try using longer connection IDs.
    #[error("CIDs exhausted")]
    CidsExhausted,
    /// The given server name was malformed
    ///
    /// Carries the offending server name.
    #[error("invalid server name: {0}")]
    InvalidServerName(String),
    /// The remote [`SocketAddr`] supplied was malformed
    ///
    /// Examples include attempting to connect to port 0, or using an inappropriate address family.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    /// No default client configuration was set up
    ///
    /// Use `Endpoint::connect_with` to specify a client configuration.
    #[error("no default client config")]
    NoDefaultClientConfig,
    /// The local endpoint does not support the QUIC version specified in the client configuration
    #[error("unsupported QUIC version")]
    UnsupportedVersion,
}
1528
/// Error type for attempting to accept an [`Incoming`]
#[derive(Debug)]
pub struct AcceptError {
    /// Underlying error describing reason for failure
    pub cause: ConnectionError,
    /// Optional response to transmit back
    ///
    /// When present, the caller is expected to send this packet to the peer.
    pub response: Option<Transmit>,
}
1537
/// Error for attempting to retry an [`Incoming`] which already bears a token from a previous retry
///
/// The rejected [`Incoming`] is kept inside the error (boxed) and can be recovered with
/// [`RetryError::into_incoming`].
#[derive(Debug, Error)]
#[error("retry() with validated Incoming")]
pub struct RetryError(Box<Incoming>);
1542
impl RetryError {
    /// Get the [`Incoming`]
    ///
    /// Consumes the error and hands back the connection attempt, which still has to be
    /// passed to the endpoint in some other way.
    pub fn into_incoming(self) -> Incoming {
        *self.0
    }
}
1549
/// Reset tokens which are associated with peer socket addresses
///
/// The outer map is keyed by the peer's address, the inner map by the reset token, and the
/// value is the connection the token belongs to.
///
/// The standard `HashMap` is used since both `SocketAddr` and `ResetToken` are
/// peer generated and might be usable for hash collision attacks.
#[derive(Default, Debug)]
struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
1556
1557impl ResetTokenTable {
1558    fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
1559        self.0
1560            .entry(remote)
1561            .or_default()
1562            .insert(token, ch)
1563            .is_some()
1564    }
1565
1566    fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
1567        use std::collections::hash_map::Entry;
1568        match self.0.entry(remote) {
1569            Entry::Vacant(_) => {}
1570            Entry::Occupied(mut e) => {
1571                e.get_mut().remove(&token);
1572                if e.get().is_empty() {
1573                    e.remove_entry();
1574                }
1575            }
1576        }
1577    }
1578
1579    fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
1580        let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
1581        self.0.get(&remote)?.get(&token)
1582    }
1583}
1584
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference message that the CRYPTO fragments below are cut from.
    const MESSAGE: &[u8] = b"hello world";

    /// Run the reassembler over an owned list of frames.
    fn reassemble(mut frames: Vec<frame::Crypto>) -> Option<Bytes> {
        assemble_crypto_frames(&mut frames)
    }

    /// Two back-to-back fragments join into the full message.
    #[test]
    fn assemble_contiguous() {
        let frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&MESSAGE[..5]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&MESSAGE[5..]),
            },
        ];
        assert_eq!(reassemble(frames).as_deref(), Some(MESSAGE));
    }

    /// Fragments supplied in reverse order are sorted before joining.
    #[test]
    fn assemble_out_of_order() {
        let frames = vec![
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&MESSAGE[5..]),
            },
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&MESSAGE[..5]),
            },
        ];
        assert_eq!(reassemble(frames).as_deref(), Some(MESSAGE));
    }

    /// Bytes covered by two fragments appear only once in the output.
    #[test]
    fn assemble_with_overlap() {
        let frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&MESSAGE[..7]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&MESSAGE[5..]),
            },
        ];
        assert_eq!(reassemble(frames).as_deref(), Some(MESSAGE));
    }

    /// A fragment that does not start where the data so far ends is rejected.
    #[test]
    fn assemble_with_gap() {
        let frames = vec![frame::Crypto {
            offset: 10,
            data: Bytes::from_static(b"world"),
        }];
        assert!(reassemble(frames).is_none());
    }
}