1use std::{
4 collections::hash_map::Entry,
5 net::{IpAddr, SocketAddr},
6};
7
8use rustc_hash::{FxHashMap, FxHashSet};
9use tracing::trace;
10
11use crate::{
12 PathId, Side, VarInt,
13 frame::{AddAddress, ReachOut, RemoveAddress},
14};
15
/// An IP address together with a port, kept as a plain tuple so it can be hashed and
/// converted into a [`SocketAddr`] via `Into`.
type IpPort = (IpAddr, u16);
17
18#[derive(Debug, thiserror::Error)]
20pub enum Error {
21 #[error("Tried to add too many addresses to their advertised set")]
23 TooManyAddresses,
24 #[error("Not allowed for this endpoint's connection side")]
26 WrongConnectionSide,
27 #[error("n0's nat traversal was not negotiated")]
29 ExtensionNotNegotiated,
30 #[error("Not enough addresses")]
32 NotEnoughAddresses,
33 #[error("Failed to establish paths {0}")]
35 Multipath(super::PathError),
36 #[error("The connection is already closed")]
38 Closed,
39}
40
41pub(crate) struct NatTraversalRound {
42 pub(crate) new_round: VarInt,
44 pub(crate) reach_out_at: FxHashSet<IpPort>,
46 pub(crate) addresses_to_probe: Vec<(VarInt, IpPort)>,
53 pub(crate) prev_round_path_ids: Vec<PathId>,
55}
56
/// Notifications about changes to the set of NAT traversal addresses.
#[derive(Debug, Clone)]
pub enum Event {
    /// An address was added to (or updated in) the set.
    AddressAdded(SocketAddr),
    /// An address was removed from the set.
    AddressRemoved(SocketAddr),
}
65
66#[derive(Debug, Default)]
68pub(crate) enum State {
69 #[default]
70 NotNegotiated,
71 ClientSide(ClientState),
72 ServerSide(ServerState),
73}
74
75#[derive(Debug)]
76pub(crate) struct ClientState {
77 max_remote_addresses: usize,
81 max_local_addresses: usize,
85 remote_addresses: FxHashMap<VarInt, (IpPort, bool)>,
91 local_addresses: FxHashSet<IpPort>,
94 round: VarInt,
96 round_path_ids: Vec<PathId>,
98}
99
100impl ClientState {
101 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
102 Self {
103 max_remote_addresses,
104 max_local_addresses,
105 remote_addresses: Default::default(),
106 local_addresses: Default::default(),
107 round: Default::default(),
108 round_path_ids: Default::default(),
109 }
110 }
111
112 fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
113 if self.local_addresses.len() < self.max_local_addresses {
114 self.local_addresses.insert(address);
115 Ok(())
116 } else if self.local_addresses.contains(&address) {
117 Ok(())
119 } else {
120 Err(Error::TooManyAddresses)
122 }
123 }
124
125 fn remove_local_address(&mut self, address: &IpPort) {
126 self.local_addresses.remove(address);
127 }
128
129 pub(crate) fn initiate_nat_traversal_round(
141 &mut self,
142 ipv6: bool,
143 ) -> Result<NatTraversalRound, Error> {
144 if self.local_addresses.is_empty() {
145 return Err(Error::NotEnoughAddresses);
146 }
147
148 let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
149 self.round = self.round.saturating_add(1u8);
150 let mut addresses_to_probe = Vec::with_capacity(self.remote_addresses.len());
151 for (id, ((ip, port), report_in_continuation)) in self.remote_addresses.iter_mut() {
152 *report_in_continuation = false;
153
154 if let Some(ip) = map_to_local_socket_family(*ip, ipv6) {
155 addresses_to_probe.push((*id, (ip, *port)));
156 } else {
157 trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
158 }
159 }
160
161 Ok(NatTraversalRound {
162 new_round: self.round,
163 reach_out_at: self.local_addresses.iter().copied().collect(),
164 addresses_to_probe,
165 prev_round_path_ids,
166 })
167 }
168
169 pub(crate) fn report_in_continuation(&mut self, id: VarInt, e: crate::PathError) {
174 match e {
175 crate::PathError::MaxPathIdReached | crate::PathError::RemoteCidsExhausted => {
176 if let Some((_address, report_in_continuation)) = self.remote_addresses.get_mut(&id)
177 {
178 *report_in_continuation = true;
179 }
180 }
181 _ => {}
182 }
183 }
184
185 pub(crate) fn continue_nat_traversal_round(&mut self, ipv6: bool) -> Option<(VarInt, IpPort)> {
195 let (id, (address, report_in_continuation)) = self
197 .remote_addresses
198 .iter_mut()
199 .filter(|(_id, (_addr, report))| *report)
200 .filter_map(|(id, ((ip, port), report))| {
201 let Some(ip) = map_to_local_socket_family(*ip, ipv6) else {
203 trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
204 return None;
205 };
206 Some((*id, ((ip, *port), report)))
207 })
208 .next()?;
209 *report_in_continuation = false;
210 Some((id, address))
211 }
212
213 pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
216 self.round_path_ids = path_ids;
217 }
218
219 pub(crate) fn add_round_path_id(&mut self, path_id: PathId) {
222 self.round_path_ids.push(path_id);
223 }
224
225 pub(crate) fn add_remote_address(
230 &mut self,
231 add_addr: AddAddress,
232 ) -> Result<Option<SocketAddr>, Error> {
233 let AddAddress { seq_no, ip, port } = add_addr;
234 let address = (ip, port);
235 let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
236 match self.remote_addresses.entry(seq_no) {
237 Entry::Occupied(mut occupied_entry) => {
238 let is_update = occupied_entry.get().0 != address;
239 if is_update {
240 occupied_entry.insert((address, false));
241 }
242 Ok(is_update.then_some(address.into()))
245 }
246 Entry::Vacant(vacant_entry) if allow_new => {
247 vacant_entry.insert((address, false));
248 Ok(Some(address.into()))
249 }
250 _ => Err(Error::TooManyAddresses),
251 }
252 }
253
254 pub(crate) fn remove_remote_address(
258 &mut self,
259 remove_addr: RemoveAddress,
260 ) -> Option<SocketAddr> {
261 self.remote_addresses
262 .remove(&remove_addr.seq_no)
263 .map(|(address, _report_in_continuation)| address.into())
264 }
265
266 pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
270 match self.remote_addresses.get(&add_addr.seq_no) {
271 None => true,
272 Some((existing, _)) => existing == &add_addr.ip_port(),
273 }
274 }
275
276 pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
277 self.remote_addresses
278 .values()
279 .map(|(address, _report_in_continuation)| (*address).into())
280 .collect()
281 }
282}
283
/// Maximum number of times the server sends an off-path probe to a single target before the
/// target is dropped from its pending set.
pub(crate) const MAX_OFF_PATH_PROBE_ATTEMPTS: u8 = 10;
286
/// Retry bookkeeping for one server-side off-path probe target.
#[derive(Debug)]
pub(crate) struct ProbeState {
    /// How many probes have been sent to this target so far.
    pub(crate) attempts: u8,
    /// Whether a probe should be sent to this target at the next opportunity.
    pub(crate) ready_to_send: bool,
}
295
296#[derive(Debug)]
297pub(crate) struct ServerState {
298 max_remote_addresses: usize,
302 max_local_addresses: usize,
306 local_addresses: FxHashMap<IpPort, VarInt>,
309 next_local_addr_id: VarInt,
311 round: VarInt,
316 pending_probes: FxHashMap<IpPort, ProbeState>,
320}
321
322impl ServerState {
323 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
324 Self {
325 max_remote_addresses,
326 max_local_addresses,
327 local_addresses: Default::default(),
328 next_local_addr_id: Default::default(),
329 round: Default::default(),
330 pending_probes: Default::default(),
331 }
332 }
333
334 fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
335 let allow_new = self.local_addresses.len() < self.max_local_addresses;
336 match self.local_addresses.entry(address) {
337 Entry::Occupied(_) => Ok(None),
338 Entry::Vacant(vacant_entry) if allow_new => {
339 let id = self.next_local_addr_id;
340 self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
341 vacant_entry.insert(id);
342 Ok(Some(AddAddress::new(address, id)))
343 }
344 _ => Err(Error::TooManyAddresses),
345 }
346 }
347
348 fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
349 self.local_addresses.remove(address).map(RemoveAddress::new)
350 }
351
352 pub(crate) fn current_round(&self) -> VarInt {
354 self.round
355 }
356
357 pub(crate) fn handle_reach_out(
362 &mut self,
363 reach_out: ReachOut,
364 ipv6: bool,
365 ) -> Result<(), Error> {
366 let ReachOut { round, ip, port } = reach_out;
367
368 if round < self.round {
369 trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
370 return Ok(());
371 }
372 let Some(ip) = map_to_local_socket_family(ip, ipv6) else {
373 trace!("Ignoring IPv6 REACH_OUT frame due to not supporting IPv6 locally");
374 return Ok(());
375 };
376
377 if round > self.round {
378 self.round = round;
379 self.pending_probes.clear();
380 } else if self.pending_probes.len() >= self.max_remote_addresses
381 && !self.pending_probes.contains_key(&(ip, port))
382 {
383 return Err(Error::TooManyAddresses);
384 }
385 self.pending_probes.entry((ip, port)).or_insert(ProbeState {
386 attempts: 0,
387 ready_to_send: true,
388 });
389 Ok(())
390 }
391
392 pub(crate) fn queue_retries(&mut self) -> bool {
397 let mut any_requeued = false;
398 self.pending_probes.retain(|_, state| {
399 if state.attempts > 0 && state.attempts < MAX_OFF_PATH_PROBE_ATTEMPTS {
400 state.ready_to_send = true;
401 any_requeued = true;
402 true
403 } else {
404 state.attempts < MAX_OFF_PATH_PROBE_ATTEMPTS
405 }
406 });
407 any_requeued
408 }
409
410 pub(crate) fn has_pending_retries(&self) -> bool {
413 self.pending_probes
414 .values()
415 .any(|state| state.attempts > 0 && state.attempts < MAX_OFF_PATH_PROBE_ATTEMPTS)
416 }
417
418 pub(crate) fn next_probe_addr(&self) -> Option<SocketAddr> {
420 self.pending_probes
421 .iter()
422 .find(|(_, state)| state.ready_to_send)
423 .map(|(addr, _)| (*addr).into())
424 }
425
426 pub(crate) fn mark_probe_sent(&mut self, remote: IpPort) {
428 if let Some(state) = self.pending_probes.get_mut(&remote) {
429 state.attempts += 1;
430 state.ready_to_send = false;
431 }
432 }
433}
434
435impl State {
436 pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
437 match side {
438 Side::Client => Self::ClientSide(ClientState::new(
439 max_remote_addresses.into(),
440 max_local_addresses.into(),
441 )),
442 Side::Server => Self::ServerSide(ServerState::new(
443 max_remote_addresses.into(),
444 max_local_addresses.into(),
445 )),
446 }
447 }
448
449 pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
450 match self {
451 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
452 Self::ClientSide(client_side) => Ok(client_side),
453 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
454 }
455 }
456
457 pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
458 match self {
459 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
460 Self::ClientSide(client_side) => Ok(client_side),
461 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
462 }
463 }
464
465 pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
466 match self {
467 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
468 Self::ClientSide(_) => Err(Error::WrongConnectionSide),
469 Self::ServerSide(server_side) => Ok(server_side),
470 }
471 }
472
473 pub(crate) fn add_local_address(
481 &mut self,
482 address: SocketAddr,
483 ) -> Result<Option<AddAddress>, Error> {
484 let ip_port = IpPort::from((address.ip(), address.port()));
485 match self {
486 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
487 Self::ClientSide(client_state) => {
488 client_state.add_local_address(ip_port)?;
489 Ok(None)
490 }
491 Self::ServerSide(server_state) => server_state.add_local_address(ip_port),
492 }
493 }
494
495 pub(crate) fn remove_local_address(
502 &mut self,
503 address: SocketAddr,
504 ) -> Result<Option<RemoveAddress>, Error> {
505 let address = IpPort::from((address.ip(), address.port()));
506 match self {
507 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
508 Self::ClientSide(client_state) => {
509 client_state.remove_local_address(&address);
510 Ok(None)
511 }
512 Self::ServerSide(server_state) => Ok(server_state.remove_local_address(&address)),
513 }
514 }
515
516 pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
517 match self {
518 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
519 Self::ClientSide(client_state) => Ok(client_state
520 .local_addresses
521 .iter()
522 .copied()
523 .map(Into::into)
524 .collect()),
525 Self::ServerSide(server_state) => Ok(server_state
526 .local_addresses
527 .keys()
528 .copied()
529 .map(Into::into)
530 .collect()),
531 }
532 }
533}
534
/// Converts `address` into the family of the local socket.
///
/// `ipv6` tells whether the local socket is IPv6. On an IPv6 socket, IPv4 addresses become
/// IPv4-mapped IPv6 addresses and IPv6 addresses pass through unchanged. On an IPv4 socket,
/// IPv4 addresses pass through, IPv4-mapped IPv6 addresses are unwrapped, and any other IPv6
/// address yields `None` because it cannot be reached.
pub(crate) fn map_to_local_socket_family(address: IpAddr, ipv6: bool) -> Option<IpAddr> {
    match (address, ipv6) {
        (IpAddr::V4(v4), true) => Some(IpAddr::V6(v4.to_ipv6_mapped())),
        (IpAddr::V6(v6), false) => v6.to_ipv4_mapped().map(IpAddr::V4),
        (unchanged, _) => Some(unchanged),
    }
}
550
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(
                ReachOut {
                    round: 1u32.into(),
                    ip: std::net::Ipv4Addr::LOCALHOST.into(),
                    port: 1,
                },
                true,
            )
            .unwrap();

        state
            .handle_reach_out(
                ReachOut {
                    round: 1u32.into(),
                    ip: "1.1.1.1".parse().unwrap(),
                    port: 2,
                },
                true,
            )
            .unwrap();

        assert_eq!(state.pending_probes.len(), 2);

        // Sends the next ready probe and records it as sent.
        let send_probe = |state: &mut ServerState| {
            let remote = state.next_probe_addr().unwrap();
            state.mark_probe_sent((remote.ip(), remote.port()));
        };

        send_probe(&mut state);
        send_probe(&mut state);

        // Both targets were probed once: none is ready, but both are awaiting a retry.
        assert!(state.next_probe_addr().is_none());
        assert_eq!(state.pending_probes.len(), 2);
        assert!(state.has_pending_retries());

        // Second attempt.
        assert!(state.queue_retries());
        send_probe(&mut state);
        send_probe(&mut state);

        // Third attempt.
        assert!(state.queue_retries());
        send_probe(&mut state);
        send_probe(&mut state);

        // Remaining attempts up to the limit.
        for _ in 3..MAX_OFF_PATH_PROBE_ATTEMPTS {
            assert!(state.queue_retries());
            send_probe(&mut state);
            send_probe(&mut state);
        }

        // All attempts are exhausted: targets are dropped and nothing is re-queued.
        assert!(!state.queue_retries());
        assert!(state.next_probe_addr().is_none());
        assert_eq!(state.pending_probes.len(), 0);
    }

    #[test]
    fn test_server_reach_out_limits() {
        let mut state = ServerState::new(1, 1);
        let reach_out = |round: u32, port: u16| ReachOut {
            round: round.into(),
            ip: std::net::Ipv4Addr::LOCALHOST.into(),
            port,
        };

        state.handle_reach_out(reach_out(1, 1), false).unwrap();
        // Repeating a known target within the round is accepted.
        state.handle_reach_out(reach_out(1, 1), false).unwrap();
        // A second target exceeds the per-round limit.
        assert!(matches!(
            state.handle_reach_out(reach_out(1, 2), false),
            Err(Error::TooManyAddresses)
        ));
        // Frames for older rounds are ignored.
        state.handle_reach_out(reach_out(0, 3), false).unwrap();
        assert_eq!(state.pending_probes.len(), 1);

        // A newer round replaces the previous round's targets.
        state.handle_reach_out(reach_out(2, 2), false).unwrap();
        assert_eq!(state.current_round(), VarInt::from(2u32));
        assert_eq!(state.pending_probes.len(), 1);
        let expected: IpPort = (std::net::Ipv4Addr::LOCALHOST.into(), 2);
        assert!(state.pending_probes.contains_key(&expected));
    }

    #[test]
    fn test_map_to_local_socket() {
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), true),
            Some("::ffff:1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), true),
            Some("::1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), false),
            None
        );
        assert_eq!(
            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
    }
}