1use std::{
2 collections::{HashMap, hash_map},
3 convert::TryFrom,
4 fmt, mem,
5 net::{IpAddr, SocketAddr},
6 ops::{Index, IndexMut},
7 sync::Arc,
8};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use rand::{Rng, RngExt, SeedableRng, rngs::StdRng};
12use rustc_hash::FxHashMap;
13use slab::Slab;
14use thiserror::Error;
15use tracing::{debug, error, trace, warn};
16
17use crate::{
18 Duration, FourTuple, INITIAL_MTU, Instant, MAX_CID_SIZE, MIN_INITIAL_SIZE, PathId,
19 RESET_TOKEN_SIZE, ResetToken, Side, Transmit, TransportConfig, TransportError,
20 cid_generator::ConnectionIdGenerator,
21 coding::{BufMutExt, Decodable, Encodable, UnexpectedEnd},
22 config::{ClientConfig, EndpointConfig, ServerConfig},
23 connection::{Connection, ConnectionError, SideArgs},
24 crypto::{self, Keys, UnsupportedVersion},
25 frame,
26 packet::{
27 FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, PacketDecodeError,
28 PacketNumber, PartialDecode, ProtectedInitialHeader,
29 },
30 shared::{
31 ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
32 EndpointEvent, EndpointEventInner, IssuedCid,
33 },
34 token::{IncomingToken, InvalidRetryTokenError, Token, TokenPayload},
35 transport_parameters::{PreferredAddress, TransportParameters},
36};
37
38pub struct Endpoint {
43 rng: StdRng,
44 index: ConnectionIndex,
45 connections: Slab<ConnectionMeta>,
46 local_cid_generator: Box<dyn ConnectionIdGenerator>,
47 config: Arc<EndpointConfig>,
48 server_config: Option<Arc<ServerConfig>>,
49 allow_mtud: bool,
51 last_stateless_reset: Option<Instant>,
53 incoming_buffers: Slab<IncomingBuffer>,
55 all_incoming_buffers_total_bytes: u64,
56}
57
58impl Endpoint {
59 pub fn new(
65 config: Arc<EndpointConfig>,
66 server_config: Option<Arc<ServerConfig>>,
67 allow_mtud: bool,
68 ) -> Self {
69 Self {
70 rng: config
71 .rng_seed
72 .map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::from_seed),
73 index: ConnectionIndex::default(),
74 connections: Slab::new(),
75 local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
76 config,
77 server_config,
78 allow_mtud,
79 last_stateless_reset: None,
80 incoming_buffers: Slab::new(),
81 all_incoming_buffers_total_bytes: 0,
82 }
83 }
84
85 pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
87 self.server_config = server_config;
88 }
89
90 pub fn handle_event(
94 &mut self,
95 ch: ConnectionHandle,
96 event: EndpointEvent,
97 ) -> Option<ConnectionEvent> {
98 use EndpointEventInner::*;
99 match event.0 {
100 NeedIdentifiers(path_id, now, n) => {
101 return Some(self.send_new_identifiers(path_id, now, ch, n));
102 }
103 ResetToken(path_id, remote, token) => {
104 if let Some(old) = self.connections[ch]
105 .reset_token
106 .insert(path_id, (remote, token))
107 {
108 self.index.connection_reset_tokens.remove(old.0, old.1);
109 }
110 if self.index.connection_reset_tokens.insert(remote, token, ch) {
111 warn!("duplicate reset token");
112 }
113 }
114 RetireResetToken(path_id) => {
115 if let Some(old) = self.connections[ch].reset_token.remove(&path_id) {
116 self.index.connection_reset_tokens.remove(old.0, old.1);
117 }
118 }
119 RetireConnectionId(now, path_id, seq, allow_more_cids) => {
120 if let Some(cid) = self.connections[ch]
121 .local_cids
122 .get_mut(&path_id)
123 .and_then(|pcid| pcid.cids.remove(&seq))
124 {
125 trace!(%path_id, "local CID retired {}: {}", seq, cid);
126 self.index.retire(cid);
127 if allow_more_cids {
128 return Some(self.send_new_identifiers(path_id, now, ch, 1));
129 }
130 }
131 }
132 Draining => {
133 }
135 Drained => {
136 if let Some(conn) = self.connections.try_remove(ch.0) {
137 self.index.remove(&conn);
138 } else {
139 error!(id = ch.0, "unknown connection drained");
143 }
144 }
145 }
146 None
147 }
148
149 pub fn handle(
151 &mut self,
152 now: Instant,
153 network_path: FourTuple,
154 ecn: Option<EcnCodepoint>,
155 data: BytesMut,
156 buf: &mut Vec<u8>,
157 ) -> Option<DatagramEvent> {
158 let datagram_len = data.len();
160 let mut event = match PartialDecode::new(
161 data,
162 &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
163 &self.config.supported_versions,
164 self.config.grease_quic_bit,
165 ) {
166 Ok((first_decode, remaining)) => DatagramConnectionEvent {
167 now,
168 network_path,
169 path_id: PathId::ZERO, ecn,
171 first_decode,
172 remaining,
173 },
174 Err(PacketDecodeError::UnsupportedVersion {
175 src_cid,
176 dst_cid,
177 version,
178 }) => {
179 if self.server_config.is_none() {
180 debug!("dropping packet with unsupported version");
181 return None;
182 }
183 trace!("sending version negotiation");
184 Header::VersionNegotiate {
186 random: self.rng.random::<u8>() | 0x40,
187 src_cid: dst_cid,
188 dst_cid: src_cid,
189 }
190 .encode(buf);
191 buf.write::<u32>(match version {
193 0x0a1a_2a3a => 0x0a1a_2a4a,
194 _ => 0x0a1a_2a3a,
195 });
196 for &version in &self.config.supported_versions {
197 buf.write(version);
198 }
199 return Some(DatagramEvent::Response(Transmit {
200 destination: network_path.remote,
201 ecn: None,
202 size: buf.len(),
203 segment_size: None,
204 src_ip: network_path.local_ip,
205 }));
206 }
207 Err(e) => {
208 trace!("malformed header: {}", e);
209 return None;
210 }
211 };
212
213 let dst_cid = event.first_decode.dst_cid();
214
215 if let Some(route_to) = self.index.get(&network_path, &event.first_decode) {
216 event.path_id = match route_to {
217 RouteDatagramTo::Incoming(_) => PathId::ZERO,
218 RouteDatagramTo::Connection(_, path_id) => path_id,
219 };
220 match route_to {
221 RouteDatagramTo::Incoming(incoming_idx) => {
222 let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
223 let config = &self.server_config.as_ref().unwrap();
224
225 if incoming_buffer
226 .total_bytes
227 .checked_add(datagram_len as u64)
228 .is_some_and(|n| n <= config.incoming_buffer_size)
229 && self
230 .all_incoming_buffers_total_bytes
231 .checked_add(datagram_len as u64)
232 .is_some_and(|n| n <= config.incoming_buffer_size_total)
233 {
234 incoming_buffer.datagrams.push(event);
235 incoming_buffer.total_bytes += datagram_len as u64;
236 self.all_incoming_buffers_total_bytes += datagram_len as u64;
237 }
238
239 None
240 }
241 RouteDatagramTo::Connection(ch, _path_id) => Some(DatagramEvent::ConnectionEvent(
242 ch,
243 ConnectionEvent(ConnectionEventInner::Datagram(event)),
244 )),
245 }
246 } else if event.first_decode.initial_header().is_some() {
247 self.handle_first_packet(datagram_len, event, network_path, buf)
250 } else if event.first_decode.has_long_header() {
251 debug!(
252 "ignoring non-initial packet for unknown connection {}",
253 dst_cid
254 );
255 None
256 } else if !event.first_decode.is_initial()
257 && self.local_cid_generator.validate(dst_cid).is_err()
258 {
259 debug!("dropping packet with invalid CID");
260 None
261 } else if dst_cid.is_empty() {
262 trace!("dropping unrecognized short packet without ID");
263 None
264 } else {
265 self.stateless_reset(now, datagram_len, network_path, dst_cid, buf)
268 .map(DatagramEvent::Response)
269 }
270 }
271
272 fn stateless_reset(
274 &mut self,
275 now: Instant,
276 inciting_dgram_len: usize,
277 network_path: FourTuple,
278 dst_cid: ConnectionId,
279 buf: &mut Vec<u8>,
280 ) -> Option<Transmit> {
281 if self
282 .last_stateless_reset
283 .is_some_and(|last| last + self.config.min_reset_interval > now)
284 {
285 debug!("ignoring unexpected packet within minimum stateless reset interval");
286 return None;
287 }
288
289 const MIN_PADDING_LEN: usize = 5;
291
292 let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) {
295 Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1,
296 _ => {
297 debug!(
298 "ignoring unexpected {} byte packet: not larger than minimum stateless reset size",
299 inciting_dgram_len
300 );
301 return None;
302 }
303 };
304
305 debug!(%dst_cid, %network_path.remote, "sending stateless reset");
306 self.last_stateless_reset = Some(now);
307 const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;
309 let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN {
310 max_padding_len
311 } else {
312 self.rng
313 .random_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
314 };
315 buf.reserve(padding_len + RESET_TOKEN_SIZE);
316 buf.resize(padding_len, 0);
317 self.rng.fill_bytes(&mut buf[0..padding_len]);
318 buf[0] = 0b0100_0000 | (buf[0] >> 2);
319 buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));
320
321 debug_assert!(buf.len() < inciting_dgram_len);
322
323 Some(Transmit {
324 destination: network_path.remote,
325 ecn: None,
326 size: buf.len(),
327 segment_size: None,
328 src_ip: network_path.local_ip,
329 })
330 }
331
332 pub fn connect(
334 &mut self,
335 now: Instant,
336 config: ClientConfig,
337 remote: SocketAddr,
338 server_name: &str,
339 ) -> Result<(ConnectionHandle, Connection), ConnectError> {
340 if self.cids_exhausted() {
341 return Err(ConnectError::CidsExhausted);
342 }
343 if remote.port() == 0 || remote.ip().is_unspecified() {
344 return Err(ConnectError::InvalidRemoteAddress(remote));
345 }
346 if !self.config.supported_versions.contains(&config.version) {
347 return Err(ConnectError::UnsupportedVersion);
348 }
349
350 let remote_id = (config.initial_dst_cid_provider)();
351 trace!(initial_dcid = %remote_id);
352
353 let ch = ConnectionHandle(self.connections.vacant_key());
354 let local_cid = self.new_cid(ch, PathId::ZERO);
355 let params = TransportParameters::new(
356 &config.transport,
357 &self.config,
358 self.local_cid_generator.as_ref(),
359 local_cid,
360 None,
361 &mut self.rng,
362 );
363 let tls = config
364 .crypto
365 .start_session(config.version, server_name, ¶ms)?;
366
367 let conn = self.add_connection(
368 ch,
369 config.version,
370 remote_id,
371 local_cid,
372 remote_id,
373 FourTuple {
374 remote,
375 local_ip: None,
376 },
377 now,
378 tls,
379 config.transport,
380 SideArgs::Client {
381 token_store: config.token_store,
382 server_name: server_name.into(),
383 },
384 ¶ms,
385 );
386 Ok((ch, conn))
387 }
388
389 fn send_new_identifiers(
391 &mut self,
392 path_id: PathId,
393 now: Instant,
394 ch: ConnectionHandle,
395 num: u64,
396 ) -> ConnectionEvent {
397 let mut ids = vec![];
398 for _ in 0..num {
399 let id = self.new_cid(ch, path_id);
400 let cid_meta = self.connections[ch].local_cids.entry(path_id).or_default();
401 let sequence = cid_meta.issued;
402 cid_meta.issued += 1;
403 cid_meta.cids.insert(sequence, id);
404 ids.push(IssuedCid {
405 path_id,
406 sequence,
407 id,
408 reset_token: ResetToken::new(&*self.config.reset_key, id),
409 });
410 }
411 ConnectionEvent(ConnectionEventInner::NewIdentifiers(
412 ids,
413 now,
414 self.local_cid_generator.cid_len(),
415 self.local_cid_generator.cid_lifetime(),
416 ))
417 }
418
419 fn new_cid(&mut self, ch: ConnectionHandle, path_id: PathId) -> ConnectionId {
421 loop {
422 let cid = self.local_cid_generator.generate_cid();
423 if cid.is_empty() {
424 debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
426 return cid;
427 }
428 if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
429 e.insert((ch, path_id));
430 break cid;
431 }
432 }
433 }
434
435 fn handle_first_packet(
436 &mut self,
437 datagram_len: usize,
438 event: DatagramConnectionEvent,
439 network_path: FourTuple,
440 buf: &mut Vec<u8>,
441 ) -> Option<DatagramEvent> {
442 let dst_cid = event.first_decode.dst_cid();
443 let header = event.first_decode.initial_header().unwrap();
444
445 let Some(server_config) = &self.server_config else {
446 debug!("packet for unrecognized connection {}", dst_cid);
447 return self
448 .stateless_reset(event.now, datagram_len, network_path, dst_cid, buf)
449 .map(DatagramEvent::Response);
450 };
451
452 if datagram_len < MIN_INITIAL_SIZE as usize {
453 debug!("ignoring short initial for connection {}", dst_cid);
454 return None;
455 }
456
457 let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
458 Ok(keys) => keys,
459 Err(UnsupportedVersion) => {
460 debug!(
463 "ignoring initial packet version {:#x} unsupported by cryptographic layer",
464 header.version
465 );
466 return None;
467 }
468 };
469
470 if let Err(reason) = self.early_validate_first_packet(header) {
471 return Some(DatagramEvent::Response(self.initial_close(
472 header.version,
473 network_path,
474 &crypto,
475 header.src_cid,
476 reason,
477 buf,
478 )));
479 }
480
481 let packet = match event.first_decode.finish(Some(&*crypto.header.remote)) {
482 Ok(packet) => packet,
483 Err(e) => {
484 trace!("unable to decode initial packet: {}", e);
485 return None;
486 }
487 };
488
489 if !packet.reserved_bits_valid() {
490 debug!("dropping connection attempt with invalid reserved bits");
491 return None;
492 }
493
494 let Header::Initial(header) = packet.header else {
495 panic!("non-initial packet in handle_first_packet()");
496 };
497
498 let server_config = self.server_config.as_ref().unwrap().clone();
499
500 let token = match IncomingToken::from_header(&header, &server_config, network_path.remote) {
501 Ok(token) => token,
502 Err(InvalidRetryTokenError) => {
503 debug!("rejecting invalid retry token");
504 return Some(DatagramEvent::Response(self.initial_close(
505 header.version,
506 network_path,
507 &crypto,
508 header.src_cid,
509 TransportError::INVALID_TOKEN(""),
510 buf,
511 )));
512 }
513 };
514
515 let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
516 self.index
517 .insert_initial_incoming(header.dst_cid, incoming_idx);
518
519 Some(DatagramEvent::NewConnection(Incoming {
520 received_at: event.now,
521 network_path,
522 ecn: event.ecn,
523 packet: InitialPacket {
524 header,
525 header_data: packet.header_data,
526 payload: packet.payload,
527 },
528 rest: event.remaining,
529 crypto,
530 token,
531 incoming_idx,
532 improper_drop_warner: IncomingImproperDropWarner,
533 }))
534 }
535
536 pub fn accept(
539 &mut self,
540 mut incoming: Incoming,
541 now: Instant,
542 buf: &mut Vec<u8>,
543 server_config: Option<Arc<ServerConfig>>,
544 ) -> Result<(ConnectionHandle, Connection), Box<AcceptError>> {
545 let remote_address_validated = incoming.remote_address_validated();
546 incoming.improper_drop_warner.dismiss();
547 let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
548 self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
549
550 let packet_number = incoming.packet.header.number.expand(0);
551 let InitialHeader {
552 src_cid,
553 dst_cid,
554 version,
555 ..
556 } = incoming.packet.header;
557 let server_config =
558 server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());
559
560 if server_config
561 .transport
562 .max_idle_timeout
563 .is_some_and(|timeout| {
564 incoming.received_at + Duration::from_millis(timeout.into()) <= now
565 })
566 {
567 debug!("abandoning accept of stale initial");
568 self.index.remove_initial(dst_cid);
569 return Err(Box::new(AcceptError {
570 cause: ConnectionError::TimedOut,
571 response: None,
572 }));
573 }
574
575 if self.cids_exhausted() {
576 debug!("refusing connection");
577 self.index.remove_initial(dst_cid);
578 return Err(Box::new(AcceptError {
579 cause: ConnectionError::CidsExhausted,
580 response: Some(self.initial_close(
581 version,
582 incoming.network_path,
583 &incoming.crypto,
584 src_cid,
585 TransportError::CONNECTION_REFUSED(""),
586 buf,
587 )),
588 }));
589 }
590
591 if incoming
592 .crypto
593 .packet
594 .remote
595 .decrypt(
596 PathId::ZERO,
597 packet_number,
598 &incoming.packet.header_data,
599 &mut incoming.packet.payload,
600 )
601 .is_err()
602 {
603 debug!(packet_number, "failed to authenticate initial packet");
604 self.index.remove_initial(dst_cid);
605 return Err(Box::new(AcceptError {
606 cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
607 response: None,
608 }));
609 };
610
611 let ch = ConnectionHandle(self.connections.vacant_key());
612 let local_cid = self.new_cid(ch, PathId::ZERO);
613 let mut params = TransportParameters::new(
614 &server_config.transport,
615 &self.config,
616 self.local_cid_generator.as_ref(),
617 local_cid,
618 Some(&server_config),
619 &mut self.rng,
620 );
621 params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, local_cid));
622 params.original_dst_cid = Some(incoming.token.orig_dst_cid);
623 params.retry_src_cid = incoming.token.retry_src_cid;
624 let mut pref_addr_cid = None;
625 if server_config.has_preferred_address() {
626 let cid = self.new_cid(ch, PathId::ZERO);
627 pref_addr_cid = Some(cid);
628 params.preferred_address = Some(PreferredAddress {
629 address_v4: server_config.preferred_address_v4,
630 address_v6: server_config.preferred_address_v6,
631 connection_id: cid,
632 stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
633 });
634 }
635
636 let tls = server_config.crypto.start_session(version, ¶ms);
637 let transport_config = server_config.transport.clone();
638 let mut conn = self.add_connection(
639 ch,
640 version,
641 dst_cid,
642 local_cid,
643 src_cid,
644 incoming.network_path,
645 incoming.received_at,
646 tls,
647 transport_config,
648 SideArgs::Server {
649 server_config,
650 pref_addr_cid,
651 path_validated: remote_address_validated,
652 },
653 ¶ms,
654 );
655 self.index.insert_initial(dst_cid, ch);
656
657 match conn.handle_first_packet(
658 incoming.received_at,
659 incoming.network_path,
660 incoming.ecn,
661 packet_number,
662 incoming.packet,
663 incoming.rest,
664 ) {
665 Ok(()) => {
666 trace!(
667 id = ch.0,
668 icid = %dst_cid,
669 network_path = %incoming.network_path,
670 "new connection",
671 );
672
673 for event in incoming_buffer.datagrams {
674 conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
675 }
676
677 Ok((ch, conn))
678 }
679 Err(e) => {
680 debug!("handshake failed: {}", e);
681 self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
682 let response = match e {
683 ConnectionError::TransportError(ref e) => Some(self.initial_close(
684 version,
685 incoming.network_path,
686 &incoming.crypto,
687 src_cid,
688 e.clone(),
689 buf,
690 )),
691 _ => None,
692 };
693 Err(Box::new(AcceptError { cause: e, response }))
694 }
695 }
696 }
697
698 fn early_validate_first_packet(
700 &mut self,
701 header: &ProtectedInitialHeader,
702 ) -> Result<(), TransportError> {
703 let config = &self.server_config.as_ref().unwrap();
704 if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming {
705 return Err(TransportError::CONNECTION_REFUSED(""));
706 }
707
708 if header.dst_cid.len() < 8
713 && (header.token_pos.is_empty()
714 || header.dst_cid.len() != self.local_cid_generator.cid_len())
715 {
716 debug!(
717 "rejecting connection due to invalid DCID length {}",
718 header.dst_cid.len()
719 );
720 return Err(TransportError::PROTOCOL_VIOLATION(
721 "invalid destination CID length",
722 ));
723 }
724
725 Ok(())
726 }
727
728 pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
730 self.clean_up_incoming(&incoming);
731 incoming.improper_drop_warner.dismiss();
732
733 self.initial_close(
734 incoming.packet.header.version,
735 incoming.network_path,
736 &incoming.crypto,
737 incoming.packet.header.src_cid,
738 TransportError::CONNECTION_REFUSED(""),
739 buf,
740 )
741 }
742
743 pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
747 if !incoming.may_retry() {
748 return Err(RetryError(Box::new(incoming)));
749 }
750
751 self.clean_up_incoming(&incoming);
752 incoming.improper_drop_warner.dismiss();
753
754 let server_config = self.server_config.as_ref().unwrap();
755
756 let local_cid = self.local_cid_generator.generate_cid();
763
764 let payload = TokenPayload::Retry {
765 address: incoming.network_path.remote,
766 orig_dst_cid: incoming.packet.header.dst_cid,
767 issued: server_config.time_source.now(),
768 };
769 let token = Token::new(payload, &mut self.rng).encode(&*server_config.token_key);
770
771 let header = Header::Retry {
772 src_cid: local_cid,
773 dst_cid: incoming.packet.header.src_cid,
774 version: incoming.packet.header.version,
775 };
776
777 let encode = header.encode(buf);
778 buf.put_slice(&token);
779 buf.extend_from_slice(&server_config.crypto.retry_tag(
780 incoming.packet.header.version,
781 incoming.packet.header.dst_cid,
782 buf,
783 ));
784 encode.finish(buf, &*incoming.crypto.header.local, None);
785
786 Ok(Transmit {
787 destination: incoming.network_path.remote,
788 ecn: None,
789 size: buf.len(),
790 segment_size: None,
791 src_ip: incoming.network_path.local_ip,
792 })
793 }
794
795 pub fn ignore(&mut self, incoming: Incoming) {
800 self.clean_up_incoming(&incoming);
801 incoming.improper_drop_warner.dismiss();
802 }
803
804 fn clean_up_incoming(&mut self, incoming: &Incoming) {
806 self.index.remove_initial(incoming.packet.header.dst_cid);
807 let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
808 self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
809 }
810
811 fn add_connection(
812 &mut self,
813 ch: ConnectionHandle,
814 version: u32,
815 init_cid: ConnectionId,
816 local_cid: ConnectionId,
817 remote_cid: ConnectionId,
818 network_path: FourTuple,
819 now: Instant,
820 tls: Box<dyn crypto::Session>,
821 transport_config: Arc<TransportConfig>,
822 side_args: SideArgs,
823 params: &TransportParameters,
825 ) -> Connection {
826 let mut rng_seed = [0; 32];
827 self.rng.fill_bytes(&mut rng_seed);
828 let side = side_args.side();
829 let pref_addr_cid = side_args.pref_addr_cid();
830
831 let qlog =
832 transport_config.create_qlog_sink(side_args.side(), network_path.remote, init_cid, now);
833
834 qlog.emit_connection_started(
835 now,
836 local_cid,
837 remote_cid,
838 network_path.remote,
839 network_path.local_ip,
840 params,
841 );
842
843 let conn = Connection::new(
844 self.config.clone(),
845 transport_config,
846 init_cid,
847 local_cid,
848 remote_cid,
849 network_path,
850 tls,
851 self.local_cid_generator.as_ref(),
852 now,
853 version,
854 self.allow_mtud,
855 rng_seed,
856 side_args,
857 qlog,
858 );
859
860 let mut path_cids = PathLocalCids::default();
861 path_cids.cids.insert(path_cids.issued, local_cid);
862 path_cids.issued += 1;
863
864 if let Some(cid) = pref_addr_cid {
865 debug_assert_eq!(path_cids.issued, 1, "preferred address cid seq must be 1");
866 path_cids.cids.insert(path_cids.issued, cid);
867 path_cids.issued += 1;
868 }
869
870 let id = self.connections.insert(ConnectionMeta {
871 init_cid,
872 local_cids: FxHashMap::from_iter([(PathId::ZERO, path_cids)]),
873 network_path,
874 side,
875 reset_token: Default::default(),
876 });
877 debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
878
879 self.index.insert_conn(network_path, local_cid, ch, side);
880
881 conn
882 }
883
884 fn initial_close(
885 &mut self,
886 version: u32,
887 network_path: FourTuple,
888 crypto: &Keys,
889 remote_id: ConnectionId,
890 reason: TransportError,
891 buf: &mut Vec<u8>,
892 ) -> Transmit {
893 let local_id = self.local_cid_generator.generate_cid();
897 let number = PacketNumber::U8(0);
898 let header = Header::Initial(InitialHeader {
899 dst_cid: remote_id,
900 src_cid: local_id,
901 number,
902 token: Bytes::new(),
903 version,
904 });
905
906 let partial_encode = header.encode(buf);
907 let max_len =
908 INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len();
909 frame::Close::from(reason).encoder(max_len).encode(buf);
910 buf.resize(buf.len() + crypto.packet.local.tag_len(), 0);
911 partial_encode.finish(
912 buf,
913 &*crypto.header.local,
914 Some((0, Default::default(), &*crypto.packet.local)),
915 );
916 Transmit {
917 destination: network_path.remote,
918 ecn: None,
919 size: buf.len(),
920 segment_size: None,
921 src_ip: network_path.local_ip,
922 }
923 }
924
925 pub fn config(&self) -> &EndpointConfig {
927 &self.config
928 }
929
930 pub fn open_connections(&self) -> usize {
932 self.connections.len()
933 }
934
935 pub fn incoming_buffer_bytes(&self) -> u64 {
938 self.all_incoming_buffers_total_bytes
939 }
940
941 #[cfg(test)]
942 pub(crate) fn known_connections(&self) -> usize {
943 let x = self.connections.len();
944 debug_assert_eq!(x, self.index.connection_ids_initial.len());
945 debug_assert!(x >= self.index.connection_reset_tokens.0.len());
947 debug_assert!(x >= self.index.incoming_connection_remotes.len());
949 debug_assert!(x >= self.index.outgoing_connection_remotes.len());
950 x
951 }
952
953 #[cfg(test)]
954 pub(crate) fn known_cids(&self) -> usize {
955 self.index.connection_ids.len()
956 }
957
958 fn cids_exhausted(&self) -> bool {
963 let cid_len = self.local_cid_generator.cid_len();
964 if cid_len == 0 || cid_len > 4 {
965 return false;
966 }
967
968 let bits = (cid_len * 8) as u32;
970 let space = 1u64 << bits;
971 let reserve = 1u64 << (bits - 2);
972 let len = self.index.connection_ids.len() as u64;
973
974 len > (space - reserve)
975 }
976}
977
978impl fmt::Debug for Endpoint {
979 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
980 fmt.debug_struct("Endpoint")
981 .field("rng", &self.rng)
982 .field("index", &self.index)
983 .field("connections", &self.connections)
984 .field("config", &self.config)
985 .field("server_config", &self.server_config)
986 .field("incoming_buffers.len", &self.incoming_buffers.len())
988 .field(
989 "all_incoming_buffers_total_bytes",
990 &self.all_incoming_buffers_total_bytes,
991 )
992 .finish()
993 }
994}
995
996#[derive(Default)]
998struct IncomingBuffer {
999 datagrams: Vec<DatagramConnectionEvent>,
1000 total_bytes: u64,
1001}
1002
1003#[derive(Copy, Clone, Debug)]
1005enum RouteDatagramTo {
1006 Incoming(usize),
1007 Connection(ConnectionHandle, PathId),
1008}
1009
1010#[derive(Default, Debug)]
1012struct ConnectionIndex {
1013 connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
1019 connection_ids: FxHashMap<ConnectionId, (ConnectionHandle, PathId)>,
1023 incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
1027 outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
1037 connection_reset_tokens: ResetTokenTable,
1042}
1043
1044impl ConnectionIndex {
1045 fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
1047 if dst_cid.is_empty() {
1048 return;
1049 }
1050 self.connection_ids_initial
1051 .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
1052 }
1053
1054 fn remove_initial(&mut self, dst_cid: ConnectionId) {
1056 if dst_cid.is_empty() {
1057 return;
1058 }
1059 let removed = self.connection_ids_initial.remove(&dst_cid);
1060 debug_assert!(removed.is_some());
1061 }
1062
1063 fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
1065 if dst_cid.is_empty() {
1066 return;
1067 }
1068 self.connection_ids_initial.insert(
1069 dst_cid,
1070 RouteDatagramTo::Connection(connection, PathId::ZERO),
1071 );
1072 }
1073
1074 fn insert_conn(
1077 &mut self,
1078 network_path: FourTuple,
1079 dst_cid: ConnectionId,
1080 connection: ConnectionHandle,
1081 side: Side,
1082 ) {
1083 match dst_cid.len() {
1084 0 => match side {
1085 Side::Server => {
1086 self.incoming_connection_remotes
1087 .insert(network_path, connection);
1088 }
1089 Side::Client => {
1090 self.outgoing_connection_remotes
1091 .insert(network_path.remote, connection);
1092 }
1093 },
1094 _ => {
1095 self.connection_ids
1096 .insert(dst_cid, (connection, PathId::ZERO));
1097 }
1098 }
1099 }
1100
1101 fn retire(&mut self, dst_cid: ConnectionId) {
1103 self.connection_ids.remove(&dst_cid);
1104 }
1105
1106 fn remove(&mut self, conn: &ConnectionMeta) {
1108 if conn.side.is_server() {
1109 self.remove_initial(conn.init_cid);
1110 }
1111 for cid in conn
1112 .local_cids
1113 .values()
1114 .flat_map(|pcids| pcids.cids.values())
1115 {
1116 self.connection_ids.remove(cid);
1117 }
1118 self.incoming_connection_remotes.remove(&conn.network_path);
1119 self.outgoing_connection_remotes
1120 .remove(&conn.network_path.remote);
1121 for (remote, token) in conn.reset_token.values() {
1122 self.connection_reset_tokens.remove(*remote, *token);
1123 }
1124 }
1125
1126 fn get(&self, network_path: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
1128 if !datagram.dst_cid().is_empty()
1129 && let Some(&(ch, path_id)) = self.connection_ids.get(&datagram.dst_cid())
1130 {
1131 return Some(RouteDatagramTo::Connection(ch, path_id));
1132 }
1133 if (datagram.is_initial() || datagram.is_0rtt())
1134 && let Some(&ch) = self.connection_ids_initial.get(&datagram.dst_cid())
1135 {
1136 return Some(ch);
1137 }
1138 if datagram.dst_cid().is_empty() {
1139 if let Some(&ch) = self.incoming_connection_remotes.get(network_path) {
1140 return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1143 }
1144 if let Some(&ch) = self.outgoing_connection_remotes.get(&network_path.remote) {
1145 return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1147 }
1148 }
1149 let data = datagram.data();
1150 if data.len() < RESET_TOKEN_SIZE {
1151 return None;
1152 }
1153 self.connection_reset_tokens
1156 .get(network_path.remote, &data[data.len() - RESET_TOKEN_SIZE..])
1157 .cloned()
1158 .map(|ch| RouteDatagramTo::Connection(ch, PathId::ZERO))
1159 }
1160}
1161
1162#[derive(Debug)]
1163pub(crate) struct ConnectionMeta {
1164 init_cid: ConnectionId,
1165 local_cids: FxHashMap<PathId, PathLocalCids>,
1167 network_path: FourTuple,
1172 side: Side,
1173 reset_token: FxHashMap<PathId, (SocketAddr, ResetToken)>,
1183}
1184
1185#[derive(Debug, Default)]
1187struct PathLocalCids {
1188 issued: u64,
1192 cids: FxHashMap<u64, ConnectionId>,
1194}
1195
/// Endpoint-assigned identifier of a connection (its key in the endpoint's slab).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);
1199
1200impl From<ConnectionHandle> for usize {
1201 fn from(x: ConnectionHandle) -> Self {
1202 x.0
1203 }
1204}
1205
1206impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
1207 type Output = ConnectionMeta;
1208 fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
1209 &self[ch.0]
1210 }
1211}
1212
1213impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
1214 fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
1215 &mut self[ch.0]
1216 }
1217}
1218
1219pub enum DatagramEvent {
1221 ConnectionEvent(ConnectionHandle, ConnectionEvent),
1223 NewConnection(Incoming),
1225 Response(Transmit),
1227}
1228
1229#[derive(derive_more::Debug)]
1231pub struct Incoming {
1232 #[debug(skip)]
1233 received_at: Instant,
1234 network_path: FourTuple,
1235 ecn: Option<EcnCodepoint>,
1236 #[debug(skip)]
1237 packet: InitialPacket,
1238 #[debug(skip)]
1239 rest: Option<BytesMut>,
1240 #[debug(skip)]
1241 crypto: Keys,
1242 token: IncomingToken,
1243 incoming_idx: usize,
1244 #[debug(skip)]
1245 improper_drop_warner: IncomingImproperDropWarner,
1246}
1247
1248impl Incoming {
1249 pub fn local_ip(&self) -> Option<IpAddr> {
1251 self.network_path.local_ip
1252 }
1253
1254 pub fn remote_address(&self) -> SocketAddr {
1256 self.network_path.remote
1257 }
1258
1259 pub fn remote_address_validated(&self) -> bool {
1267 self.token.validated
1268 }
1269
1270 pub fn may_retry(&self) -> bool {
1275 self.token.retry_src_cid.is_none()
1276 }
1277
1278 pub fn orig_dst_cid(&self) -> ConnectionId {
1280 self.token.orig_dst_cid
1281 }
1282
1283 pub fn decrypt(&self) -> Option<DecryptedInitial> {
1288 let packet_number = self.packet.header.number.expand(0);
1289 let mut payload = self.packet.payload.clone();
1290 self.crypto
1291 .packet
1292 .remote
1293 .decrypt(
1294 PathId::ZERO,
1295 packet_number,
1296 &self.packet.header_data,
1297 &mut payload,
1298 )
1299 .ok()?;
1300 Some(DecryptedInitial(payload.freeze()))
1301 }
1302}
1303
1304pub struct DecryptedInitial(Bytes);
1309
1310impl DecryptedInitial {
1311 pub fn alpns(&self) -> Option<IncomingAlpns> {
1317 let frames = frame::Iter::new(self.0.clone()).ok()?;
1318 let mut first = None;
1319 let mut rest = Vec::new();
1320 for frame in frames {
1321 match frame {
1322 Ok(frame::Frame::Crypto(crypto)) => match first {
1323 None => first = Some(crypto),
1324 Some(_) => rest.push(crypto),
1325 },
1326 Err(_) => return None,
1327 _ => {}
1328 }
1329 }
1330 let first = first?;
1331
1332 if rest.is_empty() && first.offset == 0 {
1334 let data = find_alpn_data(&first.data).ok()?;
1335 return Some(IncomingAlpns { data, pos: 0 });
1336 }
1337
1338 rest.push(first);
1340 let source = assemble_crypto_frames(&mut rest)?;
1341 let data = find_alpn_data(&source).ok()?;
1342 Some(IncomingAlpns { data, pos: 0 })
1343 }
1344}
1345
/// TLS `HandshakeType` value for `client_hello` (RFC 8446 §4).
const TLS_HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 0x01;
/// TLS `ExtensionType` value for `application_layer_protocol_negotiation` (RFC 7301 §3.1).
const TLS_EXTENSION_TYPE_ALPN: u16 = 0x0010;
/// Fixed-size prefix of a ClientHello body: `legacy_version` (2 bytes) + `random` (32 bytes).
const TLS_CLIENT_HELLO_FIXED_LEN: usize = 2 + 32;
1355
/// Iterator over the ALPN protocol names offered by a client.
///
/// Each item is one protocol name, or `Err(UnexpectedEnd)` if an entry is truncated.
pub struct IncomingAlpns {
    /// Body of the ALPN `ProtocolNameList`: a sequence of entries, each a one-byte length
    /// followed by that many bytes of protocol name.
    data: Bytes,
    /// Offset in `data` of the next entry's length byte.
    pos: usize,
}
1364
1365impl Iterator for IncomingAlpns {
1366 type Item = Result<Bytes, UnexpectedEnd>;
1367
1368 fn next(&mut self) -> Option<Self::Item> {
1369 if self.pos >= self.data.len() {
1370 return None;
1371 }
1372 let len = self.data[self.pos] as usize;
1373 self.pos += 1;
1374 if self.pos + len > self.data.len() {
1375 return Some(Err(UnexpectedEnd));
1376 }
1377 let proto = self.data.slice(self.pos..self.pos + len);
1378 self.pos += len;
1379 Some(Ok(proto))
1380 }
1381}
1382
1383fn assemble_crypto_frames(frames: &mut [frame::Crypto]) -> Option<Bytes> {
1387 frames.sort_by_key(|f| f.offset);
1388 let capacity = frames.iter().map(|f| f.data.len()).sum();
1389 let mut buf = Vec::with_capacity(capacity);
1390 for f in frames.iter() {
1391 let start = f.offset as usize;
1392 if start > buf.len() {
1393 return None;
1394 }
1395 let end = start + f.data.len();
1396 if end > buf.len() {
1397 buf.extend_from_slice(&f.data[buf.len() - start..]);
1398 }
1399 }
1400 Some(Bytes::from(buf))
1401}
1402
1403fn find_alpn_data(source: &Bytes) -> Result<Bytes, UnexpectedEnd> {
1409 let mut r = &**source;
1410
1411 if u8::decode(&mut r)? != TLS_HANDSHAKE_TYPE_CLIENT_HELLO {
1412 return Err(UnexpectedEnd);
1413 }
1414
1415 let len = decode_u24(&mut r)?;
1417 let mut body = take(&mut r, len)?;
1418
1419 skip(&mut body, TLS_CLIENT_HELLO_FIXED_LEN)?;
1421
1422 skip_u8_prefixed(&mut body)?;
1424 skip_u16_prefixed(&mut body)?;
1425 skip_u8_prefixed(&mut body)?;
1426
1427 let mut exts = take_u16_prefixed(&mut body)?;
1429 while exts.has_remaining() {
1430 let ext_type = u16::decode(&mut exts)?;
1431 let ext_data = take_u16_prefixed(&mut exts)?;
1432 if ext_type == TLS_EXTENSION_TYPE_ALPN {
1433 let list = take_u16_prefixed(&mut &*ext_data)?;
1434 return Ok(source.slice_ref(list));
1435 }
1436 }
1437 Err(UnexpectedEnd)
1438}
1439
1440fn decode_u24(r: &mut &[u8]) -> Result<usize, UnexpectedEnd> {
1442 let a = u8::decode(r)?;
1443 let b = u8::decode(r)?;
1444 let c = u8::decode(r)?;
1445 Ok(u32::from_be_bytes([0, a, b, c]) as usize)
1446}
1447
1448fn take<'a>(r: &mut &'a [u8], len: usize) -> Result<&'a [u8], UnexpectedEnd> {
1450 if r.remaining() < len {
1451 return Err(UnexpectedEnd);
1452 }
1453 let data = &r[..len];
1454 r.advance(len);
1455 Ok(data)
1456}
1457
1458fn take_u16_prefixed<'a>(r: &mut &'a [u8]) -> Result<&'a [u8], UnexpectedEnd> {
1460 let len = u16::decode(r)? as usize;
1461 take(r, len)
1462}
1463
1464fn skip(r: &mut &[u8], len: usize) -> Result<(), UnexpectedEnd> {
1466 take(r, len)?;
1467 Ok(())
1468}
1469
1470fn skip_u8_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1472 let len = u8::decode(r)? as usize;
1473 skip(r, len)
1474}
1475
1476fn skip_u16_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1478 let len = u16::decode(r)? as usize;
1479 skip(r, len)
1480}
1481
/// Drop guard held by [`Incoming`]: its `Drop` impl logs a warning unless it is consumed
/// with `dismiss()` first.
struct IncomingImproperDropWarner;
1483
1484impl IncomingImproperDropWarner {
1485 fn dismiss(self) {
1486 mem::forget(self);
1487 }
1488}
1489
1490impl Drop for IncomingImproperDropWarner {
1491 fn drop(&mut self) {
1492 warn!(
1493 "noq_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
1494 (may cause memory leak and eventual inability to accept new connections)"
1495 );
1496 }
1497}
1498
/// Reasons a new outgoing connection could not be created.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// The endpoint is stopping and can no longer start new connections.
    #[error("endpoint stopping")]
    EndpointStopping,
    /// No connection ID could be obtained for the new connection.
    #[error("CIDs exhausted")]
    CidsExhausted,
    /// The supplied server name could not be used; the offending name is included.
    #[error("invalid server name: {0}")]
    InvalidServerName(String),
    /// The supplied remote address is not acceptable; the offending address is included.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    /// No client configuration was given and no default one is configured.
    #[error("no default client config")]
    NoDefaultClientConfig,
    /// The QUIC version requested for the connection is not supported.
    #[error("unsupported QUIC version")]
    UnsupportedVersion,
}
1531
/// Error returned when an [`Incoming`] connection attempt could not be accepted.
#[derive(Debug)]
pub struct AcceptError {
    /// Why the connection could not be accepted.
    pub cause: ConnectionError,
    /// A datagram to send back to the peer, if one was produced.
    /// NOTE(review): presumably a close/refusal packet — confirm at the construction site.
    pub response: Option<Transmit>,
}
1540
/// Error for attempting `retry()` on an [`Incoming`] whose address is already validated.
///
/// The `Incoming` is handed back (see `into_incoming`) so it can still be handled another way.
#[derive(Debug, Error)]
#[error("retry() with validated Incoming")]
pub struct RetryError(Box<Incoming>);
1545
1546impl RetryError {
1547 pub fn into_incoming(self) -> Incoming {
1549 *self.0
1550 }
1551}
1552
/// Lookup table from (remote address, stateless reset token) to the owning connection.
///
/// Tokens are grouped per remote address, so a token only matches when it arrives from the
/// address it was registered for. Empty per-address maps are pruned on removal.
#[derive(Default, Debug)]
struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
1559
1560impl ResetTokenTable {
1561 fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
1562 self.0
1563 .entry(remote)
1564 .or_default()
1565 .insert(token, ch)
1566 .is_some()
1567 }
1568
1569 fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
1570 use std::collections::hash_map::Entry;
1571 match self.0.entry(remote) {
1572 Entry::Vacant(_) => {}
1573 Entry::Occupied(mut e) => {
1574 e.get_mut().remove(&token);
1575 if e.get().is_empty() {
1576 e.remove_entry();
1577 }
1578 }
1579 }
1580 }
1581
1582 fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
1583 let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
1584 self.0.get(&remote)?.get(&token)
1585 }
1586}
1587
#[cfg(test)]
mod tests {
    // Unit tests for CRYPTO frame reassembly and ClientHello ALPN extraction.
    use super::*;

    #[test]
    fn assemble_contiguous() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..5]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    #[test]
    fn assemble_out_of_order() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..5]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    #[test]
    fn assemble_with_overlap() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..7]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    #[test]
    fn assemble_with_gap() {
        let mut frames = vec![frame::Crypto {
            offset: 10,
            data: Bytes::from_static(b"world"),
        }];
        assert!(assemble_crypto_frames(&mut frames).is_none());
    }

    #[test]
    fn assemble_with_duplicate() {
        let data = b"hello world";
        let mut frames = vec![
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..5]),
            },
            frame::Crypto {
                offset: 5,
                data: Bytes::from_static(&data[5..]),
            },
            frame::Crypto {
                offset: 0,
                data: Bytes::from_static(&data[..5]),
            },
        ];
        assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
    }

    /// ALPN extension offering "h3" and "hq".
    const ALPN_EXTENSION: [u8; 12] = [
        0x00, 0x10, // extension type: ALPN
        0x00, 0x08, // extension_data length
        0x00, 0x06, // ProtocolNameList length
        0x02, b'h', b'3', // "h3"
        0x02, b'h', b'q', // "hq"
    ];

    /// Builds a minimal ClientHello handshake message carrying `extensions`.
    fn client_hello(extensions: &[u8]) -> Bytes {
        let mut body: Vec<u8> = Vec::new();
        body.extend_from_slice(&[0x03, 0x03]); // legacy_version
        body.extend_from_slice(&[0; 32]); // random
        body.push(0); // empty legacy_session_id
        body.extend_from_slice(&[0x00, 0x02, 0x13, 0x01]); // one cipher suite
        body.extend_from_slice(&[0x01, 0x00]); // null compression only
        body.extend_from_slice(&u16::try_from(extensions.len()).unwrap().to_be_bytes());
        body.extend_from_slice(extensions);

        let mut msg = vec![TLS_HANDSHAKE_TYPE_CLIENT_HELLO];
        msg.extend_from_slice(&u32::try_from(body.len()).unwrap().to_be_bytes()[1..]);
        msg.extend_from_slice(&body);
        Bytes::from(msg)
    }

    #[test]
    fn find_alpn_after_other_extension() {
        let mut exts = vec![0x00, 0x00, 0x00, 0x00]; // empty server_name extension
        exts.extend_from_slice(&ALPN_EXTENSION);
        let hello = client_hello(&exts);
        let list = find_alpn_data(&hello).unwrap();
        assert_eq!(&list[..], &ALPN_EXTENSION[6..]);
        let alpns = IncomingAlpns { data: list, pos: 0 }
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(alpns, [Bytes::from_static(b"h3"), Bytes::from_static(b"hq")]);
    }

    #[test]
    fn find_alpn_missing_extension() {
        let hello = client_hello(&[0x00, 0x00, 0x00, 0x00]);
        assert!(find_alpn_data(&hello).is_err());
    }

    #[test]
    fn find_alpn_rejects_other_handshake_types() {
        let mut msg = client_hello(&ALPN_EXTENSION).to_vec();
        msg[0] = 0x02; // server_hello
        assert!(find_alpn_data(&Bytes::from(msg)).is_err());
    }

    #[test]
    fn find_alpn_rejects_truncated_hello() {
        let hello = client_hello(&ALPN_EXTENSION);
        let truncated = hello.slice(..hello.len() - 1);
        assert!(find_alpn_data(&truncated).is_err());
    }

    #[test]
    fn incoming_alpns_truncated_entry() {
        let mut alpns = IncomingAlpns {
            data: Bytes::from_static(&[0x02, b'h', b'3', 0x05, b'x']),
            pos: 0,
        };
        assert_eq!(alpns.next().unwrap().unwrap(), Bytes::from_static(b"h3"));
        assert!(matches!(alpns.next(), Some(Err(UnexpectedEnd))));
    }
}
1648}