1use std::{
2 collections::{HashMap, hash_map},
3 convert::TryFrom,
4 fmt, mem,
5 net::{IpAddr, SocketAddr},
6 ops::{Index, IndexMut},
7 sync::Arc,
8};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use rand::{Rng, RngExt, SeedableRng, rngs::StdRng};
12use rustc_hash::FxHashMap;
13use slab::Slab;
14use thiserror::Error;
15use tracing::{debug, error, trace, warn};
16
17use crate::{
18 Duration, FourTuple, INITIAL_MTU, Instant, MAX_CID_SIZE, MIN_INITIAL_SIZE, PathId,
19 RESET_TOKEN_SIZE, ResetToken, Side, Transmit, TransportConfig, TransportError,
20 cid_generator::ConnectionIdGenerator,
21 coding::{BufMutExt, Decodable, Encodable, UnexpectedEnd},
22 config::{ClientConfig, EndpointConfig, ServerConfig},
23 connection::{Connection, ConnectionError, SideArgs},
24 crypto::{self, Keys, UnsupportedVersion},
25 frame,
26 packet::{
27 FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, PacketDecodeError,
28 PacketNumber, PartialDecode, ProtectedInitialHeader,
29 },
30 shared::{
31 ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
32 EndpointEvent, EndpointEventInner, IssuedCid,
33 },
34 token::{IncomingToken, InvalidRetryTokenError, Token, TokenPayload},
35 transport_parameters::{PreferredAddress, TransportParameters},
36};
37
38pub struct Endpoint {
43 rng: StdRng,
44 index: ConnectionIndex,
45 connections: Slab<ConnectionMeta>,
46 local_cid_generator: Box<dyn ConnectionIdGenerator>,
47 config: Arc<EndpointConfig>,
48 server_config: Option<Arc<ServerConfig>>,
49 allow_mtud: bool,
51 last_stateless_reset: Option<Instant>,
53 incoming_buffers: Slab<IncomingBuffer>,
55 all_incoming_buffers_total_bytes: u64,
56}
57
58impl Endpoint {
59 pub fn new(
65 config: Arc<EndpointConfig>,
66 server_config: Option<Arc<ServerConfig>>,
67 allow_mtud: bool,
68 ) -> Self {
69 Self {
70 rng: config
71 .rng_seed
72 .map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::from_seed),
73 index: ConnectionIndex::default(),
74 connections: Slab::new(),
75 local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
76 config,
77 server_config,
78 allow_mtud,
79 last_stateless_reset: None,
80 incoming_buffers: Slab::new(),
81 all_incoming_buffers_total_bytes: 0,
82 }
83 }
84
85 pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
87 self.server_config = server_config;
88 }
89
90 pub fn handle_event(
94 &mut self,
95 ch: ConnectionHandle,
96 event: EndpointEvent,
97 ) -> Option<ConnectionEvent> {
98 use EndpointEventInner::*;
99 match event.0 {
100 NeedIdentifiers(path_id, now, n) => {
101 return Some(self.send_new_identifiers(path_id, now, ch, n));
102 }
103 ResetToken(path_id, remote, token) => {
104 if let Some(old) = self.connections[ch]
105 .reset_token
106 .insert(path_id, (remote, token))
107 {
108 self.index.connection_reset_tokens.remove(old.0, old.1);
109 }
110 if self.index.connection_reset_tokens.insert(remote, token, ch) {
111 warn!("duplicate reset token");
112 }
113 }
114 RetireResetToken(path_id) => {
115 if let Some(old) = self.connections[ch].reset_token.remove(&path_id) {
116 self.index.connection_reset_tokens.remove(old.0, old.1);
117 }
118 }
119 RetireConnectionId(now, path_id, seq, allow_more_cids) => {
120 if let Some(cid) = self.connections[ch]
121 .local_cids
122 .get_mut(&path_id)
123 .and_then(|pcid| pcid.cids.remove(&seq))
124 {
125 trace!(%path_id, "local CID retired {}: {}", seq, cid);
126 self.index.retire(cid);
127 if allow_more_cids {
128 return Some(self.send_new_identifiers(path_id, now, ch, 1));
129 }
130 }
131 }
132 Drained => {
133 if let Some(conn) = self.connections.try_remove(ch.0) {
134 self.index.remove(&conn);
135 } else {
136 error!(id = ch.0, "unknown connection drained");
140 }
141 }
142 }
143 None
144 }
145
146 pub fn handle(
148 &mut self,
149 now: Instant,
150 network_path: FourTuple,
151 ecn: Option<EcnCodepoint>,
152 data: BytesMut,
153 buf: &mut Vec<u8>,
154 ) -> Option<DatagramEvent> {
155 let datagram_len = data.len();
157 let mut event = match PartialDecode::new(
158 data,
159 &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
160 &self.config.supported_versions,
161 self.config.grease_quic_bit,
162 ) {
163 Ok((first_decode, remaining)) => DatagramConnectionEvent {
164 now,
165 network_path,
166 path_id: PathId::ZERO, ecn,
168 first_decode,
169 remaining,
170 },
171 Err(PacketDecodeError::UnsupportedVersion {
172 src_cid,
173 dst_cid,
174 version,
175 }) => {
176 if self.server_config.is_none() {
177 debug!("dropping packet with unsupported version");
178 return None;
179 }
180 trace!("sending version negotiation");
181 Header::VersionNegotiate {
183 random: self.rng.random::<u8>() | 0x40,
184 src_cid: dst_cid,
185 dst_cid: src_cid,
186 }
187 .encode(buf);
188 buf.write::<u32>(match version {
190 0x0a1a_2a3a => 0x0a1a_2a4a,
191 _ => 0x0a1a_2a3a,
192 });
193 for &version in &self.config.supported_versions {
194 buf.write(version);
195 }
196 return Some(DatagramEvent::Response(Transmit {
197 destination: network_path.remote,
198 ecn: None,
199 size: buf.len(),
200 segment_size: None,
201 src_ip: network_path.local_ip,
202 }));
203 }
204 Err(e) => {
205 trace!("malformed header: {}", e);
206 return None;
207 }
208 };
209
210 let dst_cid = event.first_decode.dst_cid();
211
212 if let Some(route_to) = self.index.get(&network_path, &event.first_decode) {
213 event.path_id = match route_to {
214 RouteDatagramTo::Incoming(_) => PathId::ZERO,
215 RouteDatagramTo::Connection(_, path_id) => path_id,
216 };
217 match route_to {
218 RouteDatagramTo::Incoming(incoming_idx) => {
219 let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
220 let config = &self.server_config.as_ref().unwrap();
221
222 if incoming_buffer
223 .total_bytes
224 .checked_add(datagram_len as u64)
225 .is_some_and(|n| n <= config.incoming_buffer_size)
226 && self
227 .all_incoming_buffers_total_bytes
228 .checked_add(datagram_len as u64)
229 .is_some_and(|n| n <= config.incoming_buffer_size_total)
230 {
231 incoming_buffer.datagrams.push(event);
232 incoming_buffer.total_bytes += datagram_len as u64;
233 self.all_incoming_buffers_total_bytes += datagram_len as u64;
234 }
235
236 None
237 }
238 RouteDatagramTo::Connection(ch, _path_id) => Some(DatagramEvent::ConnectionEvent(
239 ch,
240 ConnectionEvent(ConnectionEventInner::Datagram(event)),
241 )),
242 }
243 } else if event.first_decode.initial_header().is_some() {
244 self.handle_first_packet(datagram_len, event, network_path, buf)
247 } else if event.first_decode.has_long_header() {
248 debug!(
249 "ignoring non-initial packet for unknown connection {}",
250 dst_cid
251 );
252 None
253 } else if !event.first_decode.is_initial()
254 && self.local_cid_generator.validate(dst_cid).is_err()
255 {
256 debug!("dropping packet with invalid CID");
257 None
258 } else if dst_cid.is_empty() {
259 trace!("dropping unrecognized short packet without ID");
260 None
261 } else {
262 self.stateless_reset(now, datagram_len, network_path, dst_cid, buf)
265 .map(DatagramEvent::Response)
266 }
267 }
268
269 fn stateless_reset(
271 &mut self,
272 now: Instant,
273 inciting_dgram_len: usize,
274 network_path: FourTuple,
275 dst_cid: ConnectionId,
276 buf: &mut Vec<u8>,
277 ) -> Option<Transmit> {
278 if self
279 .last_stateless_reset
280 .is_some_and(|last| last + self.config.min_reset_interval > now)
281 {
282 debug!("ignoring unexpected packet within minimum stateless reset interval");
283 return None;
284 }
285
286 const MIN_PADDING_LEN: usize = 5;
288
289 let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) {
292 Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1,
293 _ => {
294 debug!(
295 "ignoring unexpected {} byte packet: not larger than minimum stateless reset size",
296 inciting_dgram_len
297 );
298 return None;
299 }
300 };
301
302 debug!(%dst_cid, %network_path.remote, "sending stateless reset");
303 self.last_stateless_reset = Some(now);
304 const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;
306 let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN {
307 max_padding_len
308 } else {
309 self.rng
310 .random_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
311 };
312 buf.reserve(padding_len + RESET_TOKEN_SIZE);
313 buf.resize(padding_len, 0);
314 self.rng.fill_bytes(&mut buf[0..padding_len]);
315 buf[0] = 0b0100_0000 | (buf[0] >> 2);
316 buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));
317
318 debug_assert!(buf.len() < inciting_dgram_len);
319
320 Some(Transmit {
321 destination: network_path.remote,
322 ecn: None,
323 size: buf.len(),
324 segment_size: None,
325 src_ip: network_path.local_ip,
326 })
327 }
328
329 pub fn connect(
331 &mut self,
332 now: Instant,
333 config: ClientConfig,
334 remote: SocketAddr,
335 server_name: &str,
336 ) -> Result<(ConnectionHandle, Connection), ConnectError> {
337 if self.cids_exhausted() {
338 return Err(ConnectError::CidsExhausted);
339 }
340 if remote.port() == 0 || remote.ip().is_unspecified() {
341 return Err(ConnectError::InvalidRemoteAddress(remote));
342 }
343 if !self.config.supported_versions.contains(&config.version) {
344 return Err(ConnectError::UnsupportedVersion);
345 }
346
347 let remote_id = (config.initial_dst_cid_provider)();
348 trace!(initial_dcid = %remote_id);
349
350 let ch = ConnectionHandle(self.connections.vacant_key());
351 let local_cid = self.new_cid(ch, PathId::ZERO);
352 let params = TransportParameters::new(
353 &config.transport,
354 &self.config,
355 self.local_cid_generator.as_ref(),
356 local_cid,
357 None,
358 &mut self.rng,
359 );
360 let tls = config
361 .crypto
362 .start_session(config.version, server_name, ¶ms)?;
363
364 let conn = self.add_connection(
365 ch,
366 config.version,
367 remote_id,
368 local_cid,
369 remote_id,
370 FourTuple {
371 remote,
372 local_ip: None,
373 },
374 now,
375 tls,
376 config.transport,
377 SideArgs::Client {
378 token_store: config.token_store,
379 server_name: server_name.into(),
380 },
381 ¶ms,
382 );
383 Ok((ch, conn))
384 }
385
386 fn send_new_identifiers(
388 &mut self,
389 path_id: PathId,
390 now: Instant,
391 ch: ConnectionHandle,
392 num: u64,
393 ) -> ConnectionEvent {
394 let mut ids = vec![];
395 for _ in 0..num {
396 let id = self.new_cid(ch, path_id);
397 let cid_meta = self.connections[ch].local_cids.entry(path_id).or_default();
398 let sequence = cid_meta.issued;
399 cid_meta.issued += 1;
400 cid_meta.cids.insert(sequence, id);
401 ids.push(IssuedCid {
402 path_id,
403 sequence,
404 id,
405 reset_token: ResetToken::new(&*self.config.reset_key, id),
406 });
407 }
408 ConnectionEvent(ConnectionEventInner::NewIdentifiers(
409 ids,
410 now,
411 self.local_cid_generator.cid_len(),
412 self.local_cid_generator.cid_lifetime(),
413 ))
414 }
415
416 fn new_cid(&mut self, ch: ConnectionHandle, path_id: PathId) -> ConnectionId {
418 loop {
419 let cid = self.local_cid_generator.generate_cid();
420 if cid.is_empty() {
421 debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
423 return cid;
424 }
425 if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
426 e.insert((ch, path_id));
427 break cid;
428 }
429 }
430 }
431
432 fn handle_first_packet(
433 &mut self,
434 datagram_len: usize,
435 event: DatagramConnectionEvent,
436 network_path: FourTuple,
437 buf: &mut Vec<u8>,
438 ) -> Option<DatagramEvent> {
439 let dst_cid = event.first_decode.dst_cid();
440 let header = event.first_decode.initial_header().unwrap();
441
442 let Some(server_config) = &self.server_config else {
443 debug!("packet for unrecognized connection {}", dst_cid);
444 return self
445 .stateless_reset(event.now, datagram_len, network_path, dst_cid, buf)
446 .map(DatagramEvent::Response);
447 };
448
449 if datagram_len < MIN_INITIAL_SIZE as usize {
450 debug!("ignoring short initial for connection {}", dst_cid);
451 return None;
452 }
453
454 let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
455 Ok(keys) => keys,
456 Err(UnsupportedVersion) => {
457 debug!(
460 "ignoring initial packet version {:#x} unsupported by cryptographic layer",
461 header.version
462 );
463 return None;
464 }
465 };
466
467 if let Err(reason) = self.early_validate_first_packet(header) {
468 return Some(DatagramEvent::Response(self.initial_close(
469 header.version,
470 network_path,
471 &crypto,
472 header.src_cid,
473 reason,
474 buf,
475 )));
476 }
477
478 let packet = match event.first_decode.finish(Some(&*crypto.header.remote)) {
479 Ok(packet) => packet,
480 Err(e) => {
481 trace!("unable to decode initial packet: {}", e);
482 return None;
483 }
484 };
485
486 if !packet.reserved_bits_valid() {
487 debug!("dropping connection attempt with invalid reserved bits");
488 return None;
489 }
490
491 let Header::Initial(header) = packet.header else {
492 panic!("non-initial packet in handle_first_packet()");
493 };
494
495 let server_config = self.server_config.as_ref().unwrap().clone();
496
497 let token = match IncomingToken::from_header(&header, &server_config, network_path.remote) {
498 Ok(token) => token,
499 Err(InvalidRetryTokenError) => {
500 debug!("rejecting invalid retry token");
501 return Some(DatagramEvent::Response(self.initial_close(
502 header.version,
503 network_path,
504 &crypto,
505 header.src_cid,
506 TransportError::INVALID_TOKEN(""),
507 buf,
508 )));
509 }
510 };
511
512 let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
513 self.index
514 .insert_initial_incoming(header.dst_cid, incoming_idx);
515
516 Some(DatagramEvent::NewConnection(Incoming {
517 received_at: event.now,
518 network_path,
519 ecn: event.ecn,
520 packet: InitialPacket {
521 header,
522 header_data: packet.header_data,
523 payload: packet.payload,
524 },
525 rest: event.remaining,
526 crypto,
527 token,
528 incoming_idx,
529 improper_drop_warner: IncomingImproperDropWarner,
530 }))
531 }
532
533 pub fn accept(
536 &mut self,
537 mut incoming: Incoming,
538 now: Instant,
539 buf: &mut Vec<u8>,
540 server_config: Option<Arc<ServerConfig>>,
541 ) -> Result<(ConnectionHandle, Connection), Box<AcceptError>> {
542 let remote_address_validated = incoming.remote_address_validated();
543 incoming.improper_drop_warner.dismiss();
544 let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
545 self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
546
547 let packet_number = incoming.packet.header.number.expand(0);
548 let InitialHeader {
549 src_cid,
550 dst_cid,
551 version,
552 ..
553 } = incoming.packet.header;
554 let server_config =
555 server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());
556
557 if server_config
558 .transport
559 .max_idle_timeout
560 .is_some_and(|timeout| {
561 incoming.received_at + Duration::from_millis(timeout.into()) <= now
562 })
563 {
564 debug!("abandoning accept of stale initial");
565 self.index.remove_initial(dst_cid);
566 return Err(Box::new(AcceptError {
567 cause: ConnectionError::TimedOut,
568 response: None,
569 }));
570 }
571
572 if self.cids_exhausted() {
573 debug!("refusing connection");
574 self.index.remove_initial(dst_cid);
575 return Err(Box::new(AcceptError {
576 cause: ConnectionError::CidsExhausted,
577 response: Some(self.initial_close(
578 version,
579 incoming.network_path,
580 &incoming.crypto,
581 src_cid,
582 TransportError::CONNECTION_REFUSED(""),
583 buf,
584 )),
585 }));
586 }
587
588 if incoming
589 .crypto
590 .packet
591 .remote
592 .decrypt(
593 PathId::ZERO,
594 packet_number,
595 &incoming.packet.header_data,
596 &mut incoming.packet.payload,
597 )
598 .is_err()
599 {
600 debug!(packet_number, "failed to authenticate initial packet");
601 self.index.remove_initial(dst_cid);
602 return Err(Box::new(AcceptError {
603 cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
604 response: None,
605 }));
606 };
607
608 let ch = ConnectionHandle(self.connections.vacant_key());
609 let local_cid = self.new_cid(ch, PathId::ZERO);
610 let mut params = TransportParameters::new(
611 &server_config.transport,
612 &self.config,
613 self.local_cid_generator.as_ref(),
614 local_cid,
615 Some(&server_config),
616 &mut self.rng,
617 );
618 params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, local_cid));
619 params.original_dst_cid = Some(incoming.token.orig_dst_cid);
620 params.retry_src_cid = incoming.token.retry_src_cid;
621 let mut pref_addr_cid = None;
622 if server_config.has_preferred_address() {
623 let cid = self.new_cid(ch, PathId::ZERO);
624 pref_addr_cid = Some(cid);
625 params.preferred_address = Some(PreferredAddress {
626 address_v4: server_config.preferred_address_v4,
627 address_v6: server_config.preferred_address_v6,
628 connection_id: cid,
629 stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
630 });
631 }
632
633 let tls = server_config.crypto.start_session(version, ¶ms);
634 let transport_config = server_config.transport.clone();
635 let mut conn = self.add_connection(
636 ch,
637 version,
638 dst_cid,
639 local_cid,
640 src_cid,
641 incoming.network_path,
642 incoming.received_at,
643 tls,
644 transport_config,
645 SideArgs::Server {
646 server_config,
647 pref_addr_cid,
648 path_validated: remote_address_validated,
649 },
650 ¶ms,
651 );
652 self.index.insert_initial(dst_cid, ch);
653
654 match conn.handle_first_packet(
655 incoming.received_at,
656 incoming.network_path,
657 incoming.ecn,
658 packet_number,
659 incoming.packet,
660 incoming.rest,
661 ) {
662 Ok(()) => {
663 trace!(
664 id = ch.0,
665 icid = %dst_cid,
666 network_path = %incoming.network_path,
667 "new connection",
668 );
669
670 for event in incoming_buffer.datagrams {
671 conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
672 }
673
674 Ok((ch, conn))
675 }
676 Err(e) => {
677 debug!("handshake failed: {}", e);
678 self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
679 let response = match e {
680 ConnectionError::TransportError(ref e) => Some(self.initial_close(
681 version,
682 incoming.network_path,
683 &incoming.crypto,
684 src_cid,
685 e.clone(),
686 buf,
687 )),
688 _ => None,
689 };
690 Err(Box::new(AcceptError { cause: e, response }))
691 }
692 }
693 }
694
695 fn early_validate_first_packet(
697 &mut self,
698 header: &ProtectedInitialHeader,
699 ) -> Result<(), TransportError> {
700 let config = &self.server_config.as_ref().unwrap();
701 if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming {
702 return Err(TransportError::CONNECTION_REFUSED(""));
703 }
704
705 if header.dst_cid.len() < 8
710 && (header.token_pos.is_empty()
711 || header.dst_cid.len() != self.local_cid_generator.cid_len())
712 {
713 debug!(
714 "rejecting connection due to invalid DCID length {}",
715 header.dst_cid.len()
716 );
717 return Err(TransportError::PROTOCOL_VIOLATION(
718 "invalid destination CID length",
719 ));
720 }
721
722 Ok(())
723 }
724
725 pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
727 self.clean_up_incoming(&incoming);
728 incoming.improper_drop_warner.dismiss();
729
730 self.initial_close(
731 incoming.packet.header.version,
732 incoming.network_path,
733 &incoming.crypto,
734 incoming.packet.header.src_cid,
735 TransportError::CONNECTION_REFUSED(""),
736 buf,
737 )
738 }
739
740 pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
744 if !incoming.may_retry() {
745 return Err(RetryError(Box::new(incoming)));
746 }
747
748 self.clean_up_incoming(&incoming);
749 incoming.improper_drop_warner.dismiss();
750
751 let server_config = self.server_config.as_ref().unwrap();
752
753 let local_cid = self.local_cid_generator.generate_cid();
760
761 let payload = TokenPayload::Retry {
762 address: incoming.network_path.remote,
763 orig_dst_cid: incoming.packet.header.dst_cid,
764 issued: server_config.time_source.now(),
765 };
766 let token = Token::new(payload, &mut self.rng).encode(&*server_config.token_key);
767
768 let header = Header::Retry {
769 src_cid: local_cid,
770 dst_cid: incoming.packet.header.src_cid,
771 version: incoming.packet.header.version,
772 };
773
774 let encode = header.encode(buf);
775 buf.put_slice(&token);
776 buf.extend_from_slice(&server_config.crypto.retry_tag(
777 incoming.packet.header.version,
778 incoming.packet.header.dst_cid,
779 buf,
780 ));
781 encode.finish(buf, &*incoming.crypto.header.local, None);
782
783 Ok(Transmit {
784 destination: incoming.network_path.remote,
785 ecn: None,
786 size: buf.len(),
787 segment_size: None,
788 src_ip: incoming.network_path.local_ip,
789 })
790 }
791
792 pub fn ignore(&mut self, incoming: Incoming) {
797 self.clean_up_incoming(&incoming);
798 incoming.improper_drop_warner.dismiss();
799 }
800
801 fn clean_up_incoming(&mut self, incoming: &Incoming) {
803 self.index.remove_initial(incoming.packet.header.dst_cid);
804 let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
805 self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
806 }
807
808 fn add_connection(
809 &mut self,
810 ch: ConnectionHandle,
811 version: u32,
812 init_cid: ConnectionId,
813 local_cid: ConnectionId,
814 remote_cid: ConnectionId,
815 network_path: FourTuple,
816 now: Instant,
817 tls: Box<dyn crypto::Session>,
818 transport_config: Arc<TransportConfig>,
819 side_args: SideArgs,
820 params: &TransportParameters,
822 ) -> Connection {
823 let mut rng_seed = [0; 32];
824 self.rng.fill_bytes(&mut rng_seed);
825 let side = side_args.side();
826 let pref_addr_cid = side_args.pref_addr_cid();
827
828 let qlog =
829 transport_config.create_qlog_sink(side_args.side(), network_path.remote, init_cid, now);
830
831 qlog.emit_connection_started(
832 now,
833 local_cid,
834 remote_cid,
835 network_path.remote,
836 network_path.local_ip,
837 params,
838 );
839
840 let conn = Connection::new(
841 self.config.clone(),
842 transport_config,
843 init_cid,
844 local_cid,
845 remote_cid,
846 network_path,
847 tls,
848 self.local_cid_generator.as_ref(),
849 now,
850 version,
851 self.allow_mtud,
852 rng_seed,
853 side_args,
854 qlog,
855 );
856
857 let mut path_cids = PathLocalCids::default();
858 path_cids.cids.insert(path_cids.issued, local_cid);
859 path_cids.issued += 1;
860
861 if let Some(cid) = pref_addr_cid {
862 debug_assert_eq!(path_cids.issued, 1, "preferred address cid seq must be 1");
863 path_cids.cids.insert(path_cids.issued, cid);
864 path_cids.issued += 1;
865 }
866
867 let id = self.connections.insert(ConnectionMeta {
868 init_cid,
869 local_cids: FxHashMap::from_iter([(PathId::ZERO, path_cids)]),
870 network_path,
871 side,
872 reset_token: Default::default(),
873 });
874 debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
875
876 self.index.insert_conn(network_path, local_cid, ch, side);
877
878 conn
879 }
880
881 fn initial_close(
882 &mut self,
883 version: u32,
884 network_path: FourTuple,
885 crypto: &Keys,
886 remote_id: ConnectionId,
887 reason: TransportError,
888 buf: &mut Vec<u8>,
889 ) -> Transmit {
890 let local_id = self.local_cid_generator.generate_cid();
894 let number = PacketNumber::U8(0);
895 let header = Header::Initial(InitialHeader {
896 dst_cid: remote_id,
897 src_cid: local_id,
898 number,
899 token: Bytes::new(),
900 version,
901 });
902
903 let partial_encode = header.encode(buf);
904 let max_len =
905 INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len();
906 frame::Close::from(reason).encoder(max_len).encode(buf);
907 buf.resize(buf.len() + crypto.packet.local.tag_len(), 0);
908 partial_encode.finish(
909 buf,
910 &*crypto.header.local,
911 Some((0, Default::default(), &*crypto.packet.local)),
912 );
913 Transmit {
914 destination: network_path.remote,
915 ecn: None,
916 size: buf.len(),
917 segment_size: None,
918 src_ip: network_path.local_ip,
919 }
920 }
921
922 pub fn config(&self) -> &EndpointConfig {
924 &self.config
925 }
926
927 pub fn open_connections(&self) -> usize {
929 self.connections.len()
930 }
931
932 pub fn incoming_buffer_bytes(&self) -> u64 {
935 self.all_incoming_buffers_total_bytes
936 }
937
938 #[cfg(test)]
939 pub(crate) fn known_connections(&self) -> usize {
940 let x = self.connections.len();
941 debug_assert_eq!(x, self.index.connection_ids_initial.len());
942 debug_assert!(x >= self.index.connection_reset_tokens.0.len());
944 debug_assert!(x >= self.index.incoming_connection_remotes.len());
946 debug_assert!(x >= self.index.outgoing_connection_remotes.len());
947 x
948 }
949
950 #[cfg(test)]
951 pub(crate) fn known_cids(&self) -> usize {
952 self.index.connection_ids.len()
953 }
954
955 fn cids_exhausted(&self) -> bool {
960 let cid_len = self.local_cid_generator.cid_len();
961 if cid_len == 0 || cid_len > 4 {
962 return false;
963 }
964
965 let bits = (cid_len * 8) as u32;
967 let space = 1u64 << bits;
968 let reserve = 1u64 << (bits - 2);
969 let len = self.index.connection_ids.len() as u64;
970
971 len > (space - reserve)
972 }
973}
974
975impl fmt::Debug for Endpoint {
976 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
977 fmt.debug_struct("Endpoint")
978 .field("rng", &self.rng)
979 .field("index", &self.index)
980 .field("connections", &self.connections)
981 .field("config", &self.config)
982 .field("server_config", &self.server_config)
983 .field("incoming_buffers.len", &self.incoming_buffers.len())
985 .field(
986 "all_incoming_buffers_total_bytes",
987 &self.all_incoming_buffers_total_bytes,
988 )
989 .finish()
990 }
991}
992
993#[derive(Default)]
995struct IncomingBuffer {
996 datagrams: Vec<DatagramConnectionEvent>,
997 total_bytes: u64,
998}
999
1000#[derive(Copy, Clone, Debug)]
1002enum RouteDatagramTo {
1003 Incoming(usize),
1004 Connection(ConnectionHandle, PathId),
1005}
1006
1007#[derive(Default, Debug)]
1009struct ConnectionIndex {
1010 connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
1016 connection_ids: FxHashMap<ConnectionId, (ConnectionHandle, PathId)>,
1020 incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
1024 outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
1034 connection_reset_tokens: ResetTokenTable,
1039}
1040
1041impl ConnectionIndex {
1042 fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
1044 if dst_cid.is_empty() {
1045 return;
1046 }
1047 self.connection_ids_initial
1048 .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
1049 }
1050
1051 fn remove_initial(&mut self, dst_cid: ConnectionId) {
1053 if dst_cid.is_empty() {
1054 return;
1055 }
1056 let removed = self.connection_ids_initial.remove(&dst_cid);
1057 debug_assert!(removed.is_some());
1058 }
1059
1060 fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
1062 if dst_cid.is_empty() {
1063 return;
1064 }
1065 self.connection_ids_initial.insert(
1066 dst_cid,
1067 RouteDatagramTo::Connection(connection, PathId::ZERO),
1068 );
1069 }
1070
1071 fn insert_conn(
1074 &mut self,
1075 network_path: FourTuple,
1076 dst_cid: ConnectionId,
1077 connection: ConnectionHandle,
1078 side: Side,
1079 ) {
1080 match dst_cid.len() {
1081 0 => match side {
1082 Side::Server => {
1083 self.incoming_connection_remotes
1084 .insert(network_path, connection);
1085 }
1086 Side::Client => {
1087 self.outgoing_connection_remotes
1088 .insert(network_path.remote, connection);
1089 }
1090 },
1091 _ => {
1092 self.connection_ids
1093 .insert(dst_cid, (connection, PathId::ZERO));
1094 }
1095 }
1096 }
1097
1098 fn retire(&mut self, dst_cid: ConnectionId) {
1100 self.connection_ids.remove(&dst_cid);
1101 }
1102
1103 fn remove(&mut self, conn: &ConnectionMeta) {
1105 if conn.side.is_server() {
1106 self.remove_initial(conn.init_cid);
1107 }
1108 for cid in conn
1109 .local_cids
1110 .values()
1111 .flat_map(|pcids| pcids.cids.values())
1112 {
1113 self.connection_ids.remove(cid);
1114 }
1115 self.incoming_connection_remotes.remove(&conn.network_path);
1116 self.outgoing_connection_remotes
1117 .remove(&conn.network_path.remote);
1118 for (remote, token) in conn.reset_token.values() {
1119 self.connection_reset_tokens.remove(*remote, *token);
1120 }
1121 }
1122
1123 fn get(&self, network_path: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
1125 if !datagram.dst_cid().is_empty()
1126 && let Some(&(ch, path_id)) = self.connection_ids.get(&datagram.dst_cid())
1127 {
1128 return Some(RouteDatagramTo::Connection(ch, path_id));
1129 }
1130 if (datagram.is_initial() || datagram.is_0rtt())
1131 && let Some(&ch) = self.connection_ids_initial.get(&datagram.dst_cid())
1132 {
1133 return Some(ch);
1134 }
1135 if datagram.dst_cid().is_empty() {
1136 if let Some(&ch) = self.incoming_connection_remotes.get(network_path) {
1137 return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1140 }
1141 if let Some(&ch) = self.outgoing_connection_remotes.get(&network_path.remote) {
1142 return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
1144 }
1145 }
1146 let data = datagram.data();
1147 if data.len() < RESET_TOKEN_SIZE {
1148 return None;
1149 }
1150 self.connection_reset_tokens
1153 .get(network_path.remote, &data[data.len() - RESET_TOKEN_SIZE..])
1154 .cloned()
1155 .map(|ch| RouteDatagramTo::Connection(ch, PathId::ZERO))
1156 }
1157}
1158
1159#[derive(Debug)]
1160pub(crate) struct ConnectionMeta {
1161 init_cid: ConnectionId,
1162 local_cids: FxHashMap<PathId, PathLocalCids>,
1164 network_path: FourTuple,
1169 side: Side,
1170 reset_token: FxHashMap<PathId, (SocketAddr, ResetToken)>,
1180}
1181
1182#[derive(Debug, Default)]
1184struct PathLocalCids {
1185 issued: u64,
1189 cids: FxHashMap<u64, ConnectionId>,
1191}
1192
/// Endpoint-assigned identifier of a connection: its slot in the endpoint's connection table
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ConnectionHandle(pub usize);
1196
1197impl From<ConnectionHandle> for usize {
1198 fn from(x: ConnectionHandle) -> Self {
1199 x.0
1200 }
1201}
1202
1203impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
1204 type Output = ConnectionMeta;
1205 fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
1206 &self[ch.0]
1207 }
1208}
1209
1210impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
1211 fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
1212 &mut self[ch.0]
1213 }
1214}
1215
1216pub enum DatagramEvent {
1218 ConnectionEvent(ConnectionHandle, ConnectionEvent),
1220 NewConnection(Incoming),
1222 Response(Transmit),
1224}
1225
1226#[derive(derive_more::Debug)]
1228pub struct Incoming {
1229 #[debug(skip)]
1230 received_at: Instant,
1231 network_path: FourTuple,
1232 ecn: Option<EcnCodepoint>,
1233 #[debug(skip)]
1234 packet: InitialPacket,
1235 #[debug(skip)]
1236 rest: Option<BytesMut>,
1237 #[debug(skip)]
1238 crypto: Keys,
1239 token: IncomingToken,
1240 incoming_idx: usize,
1241 #[debug(skip)]
1242 improper_drop_warner: IncomingImproperDropWarner,
1243}
1244
1245impl Incoming {
1246 pub fn local_ip(&self) -> Option<IpAddr> {
1248 self.network_path.local_ip
1249 }
1250
1251 pub fn remote_address(&self) -> SocketAddr {
1253 self.network_path.remote
1254 }
1255
1256 pub fn remote_address_validated(&self) -> bool {
1264 self.token.validated
1265 }
1266
1267 pub fn may_retry(&self) -> bool {
1272 self.token.retry_src_cid.is_none()
1273 }
1274
1275 pub fn orig_dst_cid(&self) -> ConnectionId {
1277 self.token.orig_dst_cid
1278 }
1279
1280 pub fn decrypt(&self) -> Option<DecryptedInitial> {
1285 let packet_number = self.packet.header.number.expand(0);
1286 let mut payload = self.packet.payload.clone();
1287 self.crypto
1288 .packet
1289 .remote
1290 .decrypt(
1291 PathId::ZERO,
1292 packet_number,
1293 &self.packet.header_data,
1294 &mut payload,
1295 )
1296 .ok()?;
1297 Some(DecryptedInitial(payload.freeze()))
1298 }
1299}
1300
1301pub struct DecryptedInitial(Bytes);
1306
1307impl DecryptedInitial {
1308 pub fn alpns(&self) -> Option<IncomingAlpns> {
1314 let frames = frame::Iter::new(self.0.clone()).ok()?;
1315 let mut first = None;
1316 let mut rest = Vec::new();
1317 for frame in frames {
1318 match frame {
1319 Ok(frame::Frame::Crypto(crypto)) => match first {
1320 None => first = Some(crypto),
1321 Some(_) => rest.push(crypto),
1322 },
1323 Err(_) => return None,
1324 _ => {}
1325 }
1326 }
1327 let first = first?;
1328
1329 if rest.is_empty() && first.offset == 0 {
1331 let data = find_alpn_data(&first.data).ok()?;
1332 return Some(IncomingAlpns { data, pos: 0 });
1333 }
1334
1335 rest.push(first);
1337 let source = assemble_crypto_frames(&mut rest)?;
1338 let data = find_alpn_data(&source).ok()?;
1339 Some(IncomingAlpns { data, pos: 0 })
1340 }
1341}
1342
1343const TLS_HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 0x01;
1346const TLS_EXTENSION_TYPE_ALPN: u16 = 0x0010;
1349const TLS_CLIENT_HELLO_FIXED_LEN: usize = 2 + 32;
1352
1353pub struct IncomingAlpns {
1358 data: Bytes,
1359 pos: usize,
1360}
1361
1362impl Iterator for IncomingAlpns {
1363 type Item = Result<Bytes, UnexpectedEnd>;
1364
1365 fn next(&mut self) -> Option<Self::Item> {
1366 if self.pos >= self.data.len() {
1367 return None;
1368 }
1369 let len = self.data[self.pos] as usize;
1370 self.pos += 1;
1371 if self.pos + len > self.data.len() {
1372 return Some(Err(UnexpectedEnd));
1373 }
1374 let proto = self.data.slice(self.pos..self.pos + len);
1375 self.pos += len;
1376 Some(Ok(proto))
1377 }
1378}
1379
1380fn assemble_crypto_frames(frames: &mut [frame::Crypto]) -> Option<Bytes> {
1384 frames.sort_by_key(|f| f.offset);
1385 let capacity = frames.iter().map(|f| f.data.len()).sum();
1386 let mut buf = Vec::with_capacity(capacity);
1387 for f in frames.iter() {
1388 let start = f.offset as usize;
1389 if start > buf.len() {
1390 return None;
1391 }
1392 let end = start + f.data.len();
1393 if end > buf.len() {
1394 buf.extend_from_slice(&f.data[buf.len() - start..]);
1395 }
1396 }
1397 Some(Bytes::from(buf))
1398}
1399
1400fn find_alpn_data(source: &Bytes) -> Result<Bytes, UnexpectedEnd> {
1406 let mut r = &**source;
1407
1408 if u8::decode(&mut r)? != TLS_HANDSHAKE_TYPE_CLIENT_HELLO {
1409 return Err(UnexpectedEnd);
1410 }
1411
1412 let len = decode_u24(&mut r)?;
1414 let mut body = take(&mut r, len)?;
1415
1416 skip(&mut body, TLS_CLIENT_HELLO_FIXED_LEN)?;
1418
1419 skip_u8_prefixed(&mut body)?;
1421 skip_u16_prefixed(&mut body)?;
1422 skip_u8_prefixed(&mut body)?;
1423
1424 let mut exts = take_u16_prefixed(&mut body)?;
1426 while exts.has_remaining() {
1427 let ext_type = u16::decode(&mut exts)?;
1428 let ext_data = take_u16_prefixed(&mut exts)?;
1429 if ext_type == TLS_EXTENSION_TYPE_ALPN {
1430 let list = take_u16_prefixed(&mut &*ext_data)?;
1431 return Ok(source.slice_ref(list));
1432 }
1433 }
1434 Err(UnexpectedEnd)
1435}
1436
1437fn decode_u24(r: &mut &[u8]) -> Result<usize, UnexpectedEnd> {
1439 let a = u8::decode(r)?;
1440 let b = u8::decode(r)?;
1441 let c = u8::decode(r)?;
1442 Ok(u32::from_be_bytes([0, a, b, c]) as usize)
1443}
1444
1445fn take<'a>(r: &mut &'a [u8], len: usize) -> Result<&'a [u8], UnexpectedEnd> {
1447 if r.remaining() < len {
1448 return Err(UnexpectedEnd);
1449 }
1450 let data = &r[..len];
1451 r.advance(len);
1452 Ok(data)
1453}
1454
1455fn take_u16_prefixed<'a>(r: &mut &'a [u8]) -> Result<&'a [u8], UnexpectedEnd> {
1457 let len = u16::decode(r)? as usize;
1458 take(r, len)
1459}
1460
1461fn skip(r: &mut &[u8], len: usize) -> Result<(), UnexpectedEnd> {
1463 take(r, len)?;
1464 Ok(())
1465}
1466
1467fn skip_u8_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1469 let len = u8::decode(r)? as usize;
1470 skip(r, len)
1471}
1472
1473fn skip_u16_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
1475 let len = u16::decode(r)? as usize;
1476 skip(r, len)
1477}
1478
1479struct IncomingImproperDropWarner;
1480
1481impl IncomingImproperDropWarner {
1482 fn dismiss(self) {
1483 mem::forget(self);
1484 }
1485}
1486
1487impl Drop for IncomingImproperDropWarner {
1488 fn drop(&mut self) {
1489 warn!(
1490 "noq_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
1491 (may cause memory leak and eventual inability to accept new connections)"
1492 );
1493 }
1494}
1495
1496#[derive(Debug, Error, Clone, PartialEq, Eq)]
1500pub enum ConnectError {
1501 #[error("endpoint stopping")]
1505 EndpointStopping,
1506 #[error("CIDs exhausted")]
1510 CidsExhausted,
1511 #[error("invalid server name: {0}")]
1513 InvalidServerName(String),
1514 #[error("invalid remote address: {0}")]
1518 InvalidRemoteAddress(SocketAddr),
1519 #[error("no default client config")]
1523 NoDefaultClientConfig,
1524 #[error("unsupported QUIC version")]
1526 UnsupportedVersion,
1527}
1528
1529#[derive(Debug)]
1531pub struct AcceptError {
1532 pub cause: ConnectionError,
1534 pub response: Option<Transmit>,
1536}
1537
1538#[derive(Debug, Error)]
1540#[error("retry() with validated Incoming")]
1541pub struct RetryError(Box<Incoming>);
1542
1543impl RetryError {
1544 pub fn into_incoming(self) -> Incoming {
1546 *self.0
1547 }
1548}
1549
1550#[derive(Default, Debug)]
1555struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
1556
1557impl ResetTokenTable {
1558 fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
1559 self.0
1560 .entry(remote)
1561 .or_default()
1562 .insert(token, ch)
1563 .is_some()
1564 }
1565
1566 fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
1567 use std::collections::hash_map::Entry;
1568 match self.0.entry(remote) {
1569 Entry::Vacant(_) => {}
1570 Entry::Occupied(mut e) => {
1571 e.get_mut().remove(&token);
1572 if e.get().is_empty() {
1573 e.remove_entry();
1574 }
1575 }
1576 }
1577 }
1578
1579 fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
1580 let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
1581 self.0.get(&remote)?.get(&token)
1582 }
1583}
1584
1585#[cfg(test)]
1586mod tests {
1587 use super::*;
1588
1589 #[test]
1590 fn assemble_contiguous() {
1591 let data = b"hello world";
1592 let mut frames = vec![
1593 frame::Crypto {
1594 offset: 0,
1595 data: Bytes::from_static(&data[..5]),
1596 },
1597 frame::Crypto {
1598 offset: 5,
1599 data: Bytes::from_static(&data[5..]),
1600 },
1601 ];
1602 assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
1603 }
1604
1605 #[test]
1606 fn assemble_out_of_order() {
1607 let data = b"hello world";
1608 let mut frames = vec![
1609 frame::Crypto {
1610 offset: 5,
1611 data: Bytes::from_static(&data[5..]),
1612 },
1613 frame::Crypto {
1614 offset: 0,
1615 data: Bytes::from_static(&data[..5]),
1616 },
1617 ];
1618 assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
1619 }
1620
1621 #[test]
1622 fn assemble_with_overlap() {
1623 let data = b"hello world";
1624 let mut frames = vec![
1625 frame::Crypto {
1626 offset: 0,
1627 data: Bytes::from_static(&data[..7]),
1628 },
1629 frame::Crypto {
1630 offset: 5,
1631 data: Bytes::from_static(&data[5..]),
1632 },
1633 ];
1634 assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
1635 }
1636
1637 #[test]
1638 fn assemble_with_gap() {
1639 let mut frames = vec![frame::Crypto {
1640 offset: 10,
1641 data: Bytes::from_static(b"world"),
1642 }];
1643 assert!(assemble_crypto_frames(&mut frames).is_none());
1644 }
1645}