1use std::{
4 collections::hash_map::Entry,
5 net::{IpAddr, SocketAddr},
6};
7
8use rustc_hash::{FxHashMap, FxHashSet};
9use tracing::trace;
10
11use crate::{
12 PathId, Side, VarInt,
13 frame::{AddAddress, ReachOut, RemoveAddress},
14};
15
/// An IP address paired with a port number.
///
/// Serves as the hashable key type of the address sets and maps in this module and
/// converts into a `SocketAddr` via `.into()`.
type IpPort = (std::net::IpAddr, u16);
17
18#[derive(Debug, thiserror::Error)]
20pub enum Error {
21 #[error("Tried to add too many addresses to their advertised set")]
23 TooManyAddresses,
24 #[error("Not allowed for this endpoint's connection side")]
26 WrongConnectionSide,
27 #[error("n0's nat traversal was not negotiated")]
29 ExtensionNotNegotiated,
30 #[error("Not enough addresses")]
32 NotEnoughAddresses,
33 #[error("Failed to establish paths {0}")]
35 Multipath(super::PathError),
36 #[error("The connection is already closed")]
38 Closed,
39}
40
41pub(crate) struct NatTraversalRound {
42 pub(crate) new_round: VarInt,
44 pub(crate) reach_out_at: FxHashSet<IpPort>,
46 pub(crate) addresses_to_probe: Vec<(VarInt, IpPort)>,
53 pub(crate) prev_round_path_ids: Vec<PathId>,
55}
56
/// Notification that a NAT traversal address was added or removed.
#[derive(Clone, Debug)]
pub enum Event {
    /// The given address was added.
    AddressAdded(SocketAddr),
    /// The given address was removed.
    AddressRemoved(SocketAddr),
}
65
66#[derive(Debug, Default)]
68pub(crate) enum State {
69 #[default]
70 NotNegotiated,
71 ClientSide(ClientState),
72 ServerSide(ServerState),
73}
74
75#[derive(Debug)]
76pub(crate) struct ClientState {
77 max_remote_addresses: usize,
81 max_local_addresses: usize,
85 remote_addresses: FxHashMap<VarInt, (IpPort, bool)>,
91 local_addresses: FxHashSet<IpPort>,
94 round: VarInt,
96 round_path_ids: Vec<PathId>,
98}
99
100impl ClientState {
101 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
102 Self {
103 max_remote_addresses,
104 max_local_addresses,
105 remote_addresses: Default::default(),
106 local_addresses: Default::default(),
107 round: Default::default(),
108 round_path_ids: Default::default(),
109 }
110 }
111
112 fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
113 if self.local_addresses.len() < self.max_local_addresses {
114 self.local_addresses.insert(address);
115 Ok(())
116 } else if self.local_addresses.contains(&address) {
117 Ok(())
119 } else {
120 Err(Error::TooManyAddresses)
122 }
123 }
124
125 fn remove_local_address(&mut self, address: &IpPort) {
126 self.local_addresses.remove(address);
127 }
128
129 pub(crate) fn initiate_nat_traversal_round(
141 &mut self,
142 ipv6: bool,
143 ) -> Result<NatTraversalRound, Error> {
144 if self.local_addresses.is_empty() {
145 return Err(Error::NotEnoughAddresses);
146 }
147
148 let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
149 self.round = self.round.saturating_add(1u8);
150 let mut addresses_to_probe = Vec::with_capacity(self.remote_addresses.len());
151 for (id, ((ip, port), report_in_continuation)) in self.remote_addresses.iter_mut() {
152 *report_in_continuation = false;
153
154 if let Some(ip) = map_to_local_socket_family(*ip, ipv6) {
155 addresses_to_probe.push((*id, (ip, *port)));
156 } else {
157 trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
158 }
159 }
160
161 Ok(NatTraversalRound {
162 new_round: self.round,
163 reach_out_at: self.local_addresses.iter().copied().collect(),
164 addresses_to_probe,
165 prev_round_path_ids,
166 })
167 }
168
169 pub(crate) fn report_in_continuation(&mut self, id: VarInt, e: crate::PathError) {
174 match e {
175 crate::PathError::MaxPathIdReached | crate::PathError::RemoteCidsExhausted => {
176 if let Some((_address, report_in_continuation)) = self.remote_addresses.get_mut(&id)
177 {
178 *report_in_continuation = true;
179 }
180 }
181 _ => {}
182 }
183 }
184
185 pub(crate) fn continue_nat_traversal_round(&mut self, ipv6: bool) -> Option<(VarInt, IpPort)> {
195 let (id, (address, report_in_continuation)) = self
197 .remote_addresses
198 .iter_mut()
199 .filter(|(_id, (_addr, report))| *report)
200 .filter_map(|(id, ((ip, port), report))| {
201 let Some(ip) = map_to_local_socket_family(*ip, ipv6) else {
203 trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
204 return None;
205 };
206 Some((*id, ((ip, *port), report)))
207 })
208 .next()?;
209 *report_in_continuation = false;
210 Some((id, address))
211 }
212
213 pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
216 self.round_path_ids = path_ids;
217 }
218
219 pub(crate) fn add_round_path_id(&mut self, path_id: PathId) {
222 self.round_path_ids.push(path_id);
223 }
224
225 pub(crate) fn add_remote_address(
230 &mut self,
231 add_addr: AddAddress,
232 ) -> Result<Option<SocketAddr>, Error> {
233 let AddAddress { seq_no, ip, port } = add_addr;
234 let address = (ip, port);
235 let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
236 match self.remote_addresses.entry(seq_no) {
237 Entry::Occupied(mut occupied_entry) => {
238 let is_update = occupied_entry.get().0 != address;
239 if is_update {
240 occupied_entry.insert((address, false));
241 }
242 Ok(is_update.then_some(address.into()))
245 }
246 Entry::Vacant(vacant_entry) if allow_new => {
247 vacant_entry.insert((address, false));
248 Ok(Some(address.into()))
249 }
250 _ => Err(Error::TooManyAddresses),
251 }
252 }
253
254 pub(crate) fn remove_remote_address(
258 &mut self,
259 remove_addr: RemoveAddress,
260 ) -> Option<SocketAddr> {
261 self.remote_addresses
262 .remove(&remove_addr.seq_no)
263 .map(|(address, _report_in_continuation)| address.into())
264 }
265
266 pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
270 match self.remote_addresses.get(&add_addr.seq_no) {
271 None => true,
272 Some((existing, _)) => existing == &add_addr.ip_port(),
273 }
274 }
275
276 pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
277 self.remote_addresses
278 .values()
279 .map(|(address, _report_in_continuation)| (*address).into())
280 .collect()
281 }
282}
283
284#[derive(Debug)]
285pub(crate) struct ServerState {
286 max_remote_addresses: usize,
290 max_local_addresses: usize,
294 local_addresses: FxHashMap<IpPort, VarInt>,
297 next_local_addr_id: VarInt,
299 round: VarInt,
304 pending_probes: FxHashSet<IpPort>,
306}
307
308impl ServerState {
309 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
310 Self {
311 max_remote_addresses,
312 max_local_addresses,
313 local_addresses: Default::default(),
314 next_local_addr_id: Default::default(),
315 round: Default::default(),
316 pending_probes: Default::default(),
317 }
318 }
319
320 fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
321 let allow_new = self.local_addresses.len() < self.max_local_addresses;
322 match self.local_addresses.entry(address) {
323 Entry::Occupied(_) => Ok(None),
324 Entry::Vacant(vacant_entry) if allow_new => {
325 let id = self.next_local_addr_id;
326 self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
327 vacant_entry.insert(id);
328 Ok(Some(AddAddress::new(address, id)))
329 }
330 _ => Err(Error::TooManyAddresses),
331 }
332 }
333
334 fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
335 self.local_addresses.remove(address).map(RemoveAddress::new)
336 }
337
338 pub(crate) fn handle_reach_out(
343 &mut self,
344 reach_out: ReachOut,
345 ipv6: bool,
346 ) -> Result<(), Error> {
347 let ReachOut { round, ip, port } = reach_out;
348
349 if round < self.round {
350 trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
351 return Ok(());
352 }
353 let Some(ip) = map_to_local_socket_family(ip, ipv6) else {
354 trace!("Ignoring IPv6 REACH_OUT frame due to not supporting IPv6 locally");
355 return Ok(());
356 };
357
358 if round > self.round {
359 self.round = round;
360 self.pending_probes.clear();
361 } else if self.pending_probes.len() >= self.max_remote_addresses
362 && !self.pending_probes.contains(&(ip, port))
363 {
364 return Err(Error::TooManyAddresses);
365 }
366 self.pending_probes.insert((ip, port));
367 Ok(())
368 }
369
370 pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
371 self.pending_probes
372 .iter()
373 .next()
374 .copied()
375 .map(|remote| ServerProbing {
376 remote,
377 pending_probes: &mut self.pending_probes,
378 })
379 }
380}
381
382pub(crate) struct ServerProbing<'a> {
383 remote: IpPort,
384 pending_probes: &'a mut FxHashSet<IpPort>,
385}
386
387impl<'a> ServerProbing<'a> {
388 pub(crate) fn mark_as_sent(self) {
389 self.pending_probes.remove(&self.remote);
390 }
391
392 pub(crate) fn remote(&self) -> SocketAddr {
393 self.remote.into()
394 }
395}
396
397impl State {
398 pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
399 match side {
400 Side::Client => Self::ClientSide(ClientState::new(
401 max_remote_addresses.into(),
402 max_local_addresses.into(),
403 )),
404 Side::Server => Self::ServerSide(ServerState::new(
405 max_remote_addresses.into(),
406 max_local_addresses.into(),
407 )),
408 }
409 }
410
411 pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
412 match self {
413 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
414 Self::ClientSide(client_side) => Ok(client_side),
415 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
416 }
417 }
418
419 pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
420 match self {
421 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
422 Self::ClientSide(client_side) => Ok(client_side),
423 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
424 }
425 }
426
427 pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
428 match self {
429 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
430 Self::ClientSide(_) => Err(Error::WrongConnectionSide),
431 Self::ServerSide(server_side) => Ok(server_side),
432 }
433 }
434
435 pub(crate) fn add_local_address(
443 &mut self,
444 address: SocketAddr,
445 ) -> Result<Option<AddAddress>, Error> {
446 let ip_port = IpPort::from((address.ip(), address.port()));
447 match self {
448 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
449 Self::ClientSide(client_state) => {
450 client_state.add_local_address(ip_port)?;
451 Ok(None)
452 }
453 Self::ServerSide(server_state) => server_state.add_local_address(ip_port),
454 }
455 }
456
457 pub(crate) fn remove_local_address(
464 &mut self,
465 address: SocketAddr,
466 ) -> Result<Option<RemoveAddress>, Error> {
467 let address = IpPort::from((address.ip(), address.port()));
468 match self {
469 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
470 Self::ClientSide(client_state) => {
471 client_state.remove_local_address(&address);
472 Ok(None)
473 }
474 Self::ServerSide(server_state) => Ok(server_state.remove_local_address(&address)),
475 }
476 }
477
478 pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
479 match self {
480 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
481 Self::ClientSide(client_state) => Ok(client_state
482 .local_addresses
483 .iter()
484 .copied()
485 .map(Into::into)
486 .collect()),
487 Self::ServerSide(server_state) => Ok(server_state
488 .local_addresses
489 .keys()
490 .copied()
491 .map(Into::into)
492 .collect()),
493 }
494 }
495}
496
/// Maps `address` into the address family of the local socket.
///
/// - IPv6 socket (`ipv6 == true`): IPv4 addresses become IPv4-mapped IPv6 addresses
///   (`::ffff:a.b.c.d`); IPv6 addresses are returned unchanged.
/// - IPv4 socket (`ipv6 == false`): IPv4 addresses are returned unchanged; IPv4-mapped
///   IPv6 addresses are unwrapped to IPv4; any other IPv6 address yields `None`.
pub(crate) fn map_to_local_socket_family(address: IpAddr, ipv6: bool) -> Option<IpAddr> {
    match (address, ipv6) {
        (IpAddr::V4(v4), true) => Some(IpAddr::V6(v4.to_ipv6_mapped())),
        (IpAddr::V6(v6), false) => v6.to_ipv4_mapped().map(IpAddr::V4),
        (IpAddr::V4(_), false) | (IpAddr::V6(_), true) => Some(address),
    }
}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `REACH_OUT` frame for `round` targeting `ip:port`.
    fn reach_out(round: u32, ip: &str, port: u16) -> ReachOut {
        ReachOut {
            round: round.into(),
            ip: ip.parse().unwrap(),
            port,
        }
    }

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, "127.0.0.1", 1), true)
            .unwrap();
        state
            .handle_reach_out(reach_out(1, "1.1.1.1", 2), true)
            .unwrap();
        assert_eq!(state.pending_probes.len(), 2);

        state.next_probe().unwrap().mark_as_sent();
        state.next_probe().unwrap().mark_as_sent();

        assert!(state.next_probe().is_none());
        assert!(state.pending_probes.is_empty());
    }

    #[test]
    fn test_server_state_limits_pending_probes() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(1, "1.1.1.1", 1), false).unwrap();
        state.handle_reach_out(reach_out(1, "1.1.1.1", 2), false).unwrap();
        // Repeating an address that is already pending does not count against the limit.
        state.handle_reach_out(reach_out(1, "1.1.1.1", 1), false).unwrap();
        assert!(matches!(
            state.handle_reach_out(reach_out(1, "1.1.1.1", 3), false),
            Err(Error::TooManyAddresses)
        ));
        assert_eq!(state.pending_probes.len(), 2);
    }

    #[test]
    fn test_server_state_rounds() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(2, "1.1.1.1", 1), false).unwrap();
        // Frames from an older round are ignored without error.
        state.handle_reach_out(reach_out(1, "1.1.1.1", 2), false).unwrap();
        assert_eq!(state.pending_probes.len(), 1);

        // A newer round discards probes still pending from the previous round.
        state.handle_reach_out(reach_out(3, "2.2.2.2", 3), false).unwrap();
        let probe = state.next_probe().unwrap();
        assert_eq!(probe.remote(), "2.2.2.2:3".parse::<SocketAddr>().unwrap());
        probe.mark_as_sent();
        assert!(state.next_probe().is_none());
    }

    #[test]
    fn test_client_local_address_limit() {
        let mut client = ClientState::new(2, 1);
        let first: IpPort = ("1.1.1.1".parse().unwrap(), 1);
        let second: IpPort = ("2.2.2.2".parse().unwrap(), 2);

        assert!(matches!(
            client.initiate_nat_traversal_round(false),
            Err(Error::NotEnoughAddresses)
        ));

        client.add_local_address(first).unwrap();
        // Re-adding a known address succeeds even at the limit.
        client.add_local_address(first).unwrap();
        assert!(matches!(
            client.add_local_address(second),
            Err(Error::TooManyAddresses)
        ));

        let Ok(round) = client.initiate_nat_traversal_round(false) else {
            panic!("a local address is registered, so a round must start");
        };
        assert_eq!(round.reach_out_at.len(), 1);
        assert!(round.reach_out_at.contains(&first));
        assert!(round.addresses_to_probe.is_empty());
        assert!(round.prev_round_path_ids.is_empty());
    }

    #[test]
    fn test_client_remote_addresses() {
        let mut client = ClientState::new(1, 1);
        let ip: IpAddr = "1.1.1.1".parse().unwrap();
        let add = |seq_no: u32, port: u16| AddAddress {
            seq_no: seq_no.into(),
            ip,
            port,
        };

        assert_eq!(
            client.add_remote_address(add(0, 1)).unwrap(),
            Some(SocketAddr::new(ip, 1))
        );
        // The exact same address under the same sequence number is not an update.
        assert_eq!(client.add_remote_address(add(0, 1)).unwrap(), None);
        assert!(client.check_remote_address(&add(0, 1)));
        assert!(!client.check_remote_address(&add(0, 2)));
        // A different address under the same sequence number replaces the old one.
        assert_eq!(
            client.add_remote_address(add(0, 2)).unwrap(),
            Some(SocketAddr::new(ip, 2))
        );
        // A new sequence number beyond the limit is rejected.
        assert!(matches!(
            client.add_remote_address(add(1, 3)),
            Err(Error::TooManyAddresses)
        ));
        assert_eq!(
            client.get_remote_nat_traversal_addresses(),
            vec![SocketAddr::new(ip, 2)]
        );

        assert_eq!(
            client.remove_remote_address(RemoveAddress::new(VarInt::from(0u32))),
            Some(SocketAddr::new(ip, 2))
        );
        assert!(client.get_remote_nat_traversal_addresses().is_empty());
    }

    #[test]
    fn test_map_to_local_socket() {
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), true),
            Some("::ffff:1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), true),
            Some("::1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), false),
            None
        );
        assert_eq!(
            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
        // IPv4-mapped addresses are left as-is on an IPv6 socket.
        assert_eq!(
            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), true),
            Some("::ffff:1.1.1.1".parse().unwrap())
        );
    }
}
578}