1use std::{
4 collections::hash_map::Entry,
5 net::{IpAddr, SocketAddr},
6};
7
8use identity_hash::IntMap;
9use rustc_hash::{FxHashMap, FxHashSet};
10use tracing::trace;
11
12use crate::{
13 PathId, Side, VarInt,
14 frame::{AddAddress, ReachOut, RemoveAddress},
15};
16
/// An IP address and port pair.
///
/// This is the form in which all nat traversal addresses are stored in this module; it is
/// converted to a [`SocketAddr`] when handed out to callers.
type IpPort = (IpAddr, u16);
18
/// Errors returned by the nat traversal state machine.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An address limit (local, remote, or per-round probes) would be exceeded.
    #[error("Tried to add too many addresses to their advertised set")]
    TooManyAddresses,
    /// The operation is only valid on the other side (client vs. server) of the connection.
    #[error("Not allowed for this endpoint's connection side")]
    WrongConnectionSide,
    /// The nat traversal extension is not in use on this connection.
    #[error("Iroh's nat traversal was not negotiated")]
    ExtensionNotNegotiated,
    /// A round cannot be started because no local addresses are known.
    #[error("Not enough addresses")]
    NotEnoughAddresses,
    /// Wraps a path error hit while establishing paths (not produced in this module).
    #[error("Failed to establish paths {0}")]
    Multipath(super::PathError),
    /// The connection is already closed (not produced in this module).
    #[error("The connection is already closed")]
    Closed,
}
41
/// Everything the client needs to act on after starting a new nat traversal round.
///
/// Produced by `ClientState::initiate_nat_traversal_round`.
pub(crate) struct NatTraversalRound {
    /// The round number after it was incremented for this round.
    pub(crate) new_round: VarInt,
    /// Our local addresses at the time the round started.
    ///
    /// Presumably sent to the server as `REACH_OUT` frames; TODO confirm against the caller.
    pub(crate) reach_out_at: Vec<IpPort>,
    /// The remote addresses to probe, each with the `seq_no` it was advertised under.
    pub(crate) addresses_to_probe: Vec<(VarInt, IpPort)>,
    /// Path ids recorded for the previous round, taken out of the client state when this
    /// round started.
    pub(crate) prev_round_path_ids: Vec<PathId>,
}
55
/// Nat traversal address events.
///
/// NOTE(review): not constructed in this module; presumably emitted by the connection when the
/// peer's advertised address set changes — confirm against the caller.
#[derive(Debug, Clone)]
pub enum Event {
    /// An address was added.
    AddressAdded(SocketAddr),
    /// An address was removed.
    AddressRemoved(SocketAddr),
}
64
/// Per-connection nat traversal state, specialized by connection side.
#[derive(Debug, Default)]
pub(crate) enum State {
    /// The extension was not negotiated; every operation fails with
    /// [`Error::ExtensionNotNegotiated`].
    #[default]
    NotNegotiated,
    /// State for the client side of the connection.
    ClientSide(ClientState),
    /// State for the server side of the connection.
    ServerSide(ServerState),
}
73
/// Nat traversal state for the client side of a connection.
#[derive(Debug)]
pub(crate) struct ClientState {
    /// Maximum number of entries allowed in `remote_addresses`.
    max_remote_addresses: usize,
    /// Maximum number of entries allowed in `local_addresses`.
    max_local_addresses: usize,
    /// Addresses advertised by the server, keyed by the `seq_no` of their `ADD_ADDRESS` frame.
    ///
    /// The flag is set when a path to the address failed with `MaxPathIdReached` or
    /// `RemoteCidsExhausted`. It is cleared when the address is handed out again, either by
    /// `continue_nat_traversal_round` or by the start of a new round.
    remote_addresses: FxHashMap<VarInt, (IpPort, bool)>,
    /// Our own addresses, sent to the server at the start of each round.
    local_addresses: FxHashSet<IpPort>,
    /// The current round number; incremented (saturating) at the start of each round.
    round: VarInt,
    /// Path ids opened during the current round; taken when the next round starts.
    round_path_ids: Vec<PathId>,
}
98
99impl ClientState {
100 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
101 Self {
102 max_remote_addresses,
103 max_local_addresses,
104 remote_addresses: Default::default(),
105 local_addresses: Default::default(),
106 round: Default::default(),
107 round_path_ids: Default::default(),
108 }
109 }
110
111 fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
112 if self.local_addresses.len() < self.max_local_addresses {
113 self.local_addresses.insert(address);
114 Ok(())
115 } else if self.local_addresses.contains(&address) {
116 Ok(())
118 } else {
119 Err(Error::TooManyAddresses)
121 }
122 }
123
124 fn remove_local_address(&mut self, address: &IpPort) {
125 self.local_addresses.remove(address);
126 }
127
128 pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result<NatTraversalRound, Error> {
135 if self.local_addresses.is_empty() {
136 return Err(Error::NotEnoughAddresses);
137 }
138
139 let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
140 self.round = self.round.saturating_add(1u8);
141 let mut addresses_to_probe = Vec::with_capacity(self.remote_addresses.len());
142 for (id, (address, report_in_continuation)) in self.remote_addresses.iter_mut() {
143 addresses_to_probe.push((*id, *address));
144 *report_in_continuation = false;
145 }
146
147 Ok(NatTraversalRound {
148 new_round: self.round,
149 reach_out_at: self.local_addresses.iter().copied().collect(),
150 addresses_to_probe,
151 prev_round_path_ids,
152 })
153 }
154
155 pub(crate) fn report_in_continuation(&mut self, id: VarInt, e: crate::PathError) {
160 match e {
161 crate::PathError::MaxPathIdReached | crate::PathError::RemoteCidsExhausted => {
162 if let Some((_address, report_in_continuation)) = self.remote_addresses.get_mut(&id)
163 {
164 *report_in_continuation = true;
165 }
166 }
167 _ => {}
168 }
169 }
170
171 pub(crate) fn continue_nat_traversal_round(&mut self) -> Option<(VarInt, IpPort)> {
176 let (id, (address, report_in_continuation)) = self
178 .remote_addresses
179 .iter_mut()
180 .find(|(_id, (_addr, report))| *report)?;
181 *report_in_continuation = false;
182 Some((*id, *address))
183 }
184
185 pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
188 self.round_path_ids = path_ids;
189 }
190
191 pub(crate) fn add_round_path_id(&mut self, path_id: PathId) {
194 self.round_path_ids.push(path_id);
195 }
196
197 pub(crate) fn add_remote_address(
202 &mut self,
203 add_addr: AddAddress,
204 ) -> Result<Option<SocketAddr>, Error> {
205 let AddAddress { seq_no, ip, port } = add_addr;
206 let address = (ip, port);
207 let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
208 match self.remote_addresses.entry(seq_no) {
209 Entry::Occupied(mut occupied_entry) => {
210 let is_update = occupied_entry.get().0 != address;
211 if is_update {
212 occupied_entry.insert((address, false));
213 }
214 Ok(is_update.then_some(address.into()))
217 }
218 Entry::Vacant(vacant_entry) if allow_new => {
219 vacant_entry.insert((address, false));
220 Ok(Some(address.into()))
221 }
222 _ => Err(Error::TooManyAddresses),
223 }
224 }
225
226 pub(crate) fn remove_remote_address(
230 &mut self,
231 remove_addr: RemoveAddress,
232 ) -> Option<SocketAddr> {
233 self.remote_addresses
234 .remove(&remove_addr.seq_no)
235 .map(|(address, _report_in_continuation)| address.into())
236 }
237
238 pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
242 match self.remote_addresses.get(&add_addr.seq_no) {
243 None => true,
244 Some((existing, _)) => existing == &add_addr.ip_port(),
245 }
246 }
247
248 pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
249 self.remote_addresses
250 .values()
251 .map(|(address, _report_in_continuation)| (*address).into())
252 .collect()
253 }
254}
255
/// Nat traversal state for the server side of a connection.
#[derive(Debug)]
pub(crate) struct ServerState {
    /// Maximum number of distinct addresses accepted from `REACH_OUT` frames per round.
    max_remote_addresses: usize,
    /// Maximum number of entries allowed in `local_addresses`.
    max_local_addresses: usize,
    /// Our own addresses, mapped to the id they were advertised under.
    local_addresses: FxHashMap<IpPort, VarInt>,
    /// Id assigned to the next newly registered local address (incremented, saturating).
    next_local_addr_id: VarInt,
    /// The most recent round seen in a `REACH_OUT` frame.
    round: VarInt,
    /// Addresses from `REACH_OUT` frames of the current round that have not been probed yet.
    pending_probes: FxHashSet<IpPort>,
    /// Addresses probed in the current round, keyed by the token passed to
    /// `ServerProbing::finish`. Cleared when a newer round starts.
    active_probes: IntMap<u64, IpPort>,
}
283
284impl ServerState {
285 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
286 Self {
287 max_remote_addresses,
288 max_local_addresses,
289 local_addresses: Default::default(),
290 next_local_addr_id: Default::default(),
291 round: Default::default(),
292 pending_probes: Default::default(),
293 active_probes: Default::default(),
294 }
295 }
296
297 fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
298 let allow_new = self.local_addresses.len() < self.max_local_addresses;
299 match self.local_addresses.entry(address) {
300 Entry::Occupied(_) => Ok(None),
301 Entry::Vacant(vacant_entry) if allow_new => {
302 let id = self.next_local_addr_id;
303 self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
304 vacant_entry.insert(id);
305 Ok(Some(AddAddress::new(address, id)))
306 }
307 _ => Err(Error::TooManyAddresses),
308 }
309 }
310
311 fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
312 self.local_addresses.remove(address).map(RemoveAddress::new)
313 }
314
315 pub(crate) fn handle_reach_out(&mut self, reach_out: ReachOut) -> Result<(), Error> {
322 let ReachOut { round, ip, port } = reach_out;
323
324 if round < self.round {
325 trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
326 return Ok(());
327 }
328
329 if round > self.round {
330 self.round = round;
331 self.pending_probes.clear();
332 for (token, remote) in self.active_probes.drain() {
335 let remote: SocketAddr = remote.into();
336 trace!(token=format!("{:08x}", token), %remote, "dropping nat traversal challenge pending response");
337 }
338 } else if self.pending_probes.len() >= self.max_remote_addresses {
339 return Err(Error::TooManyAddresses);
340 }
341 self.pending_probes.insert((ip, port));
342 Ok(())
343 }
344
345 pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
346 self.pending_probes
347 .iter()
348 .next()
349 .copied()
350 .map(|remote| ServerProbing {
351 remote,
352 pending_probes: &mut self.pending_probes,
353 active_probes: &mut self.active_probes,
354 })
355 }
356}
357
/// A queued probe handed out by `ServerState::next_probe`.
///
/// Holds mutable borrows of the server's probe bookkeeping so that `finish` can move the
/// address from the pending set to the active map.
pub(crate) struct ServerProbing<'a> {
    /// The address to probe.
    remote: IpPort,
    /// The server's queue of addresses not yet probed.
    pending_probes: &'a mut FxHashSet<IpPort>,
    /// The server's map of probe token to probed address.
    active_probes: &'a mut IntMap<u64, IpPort>,
}
363
364impl<'a> ServerProbing<'a> {
365 pub(crate) fn finish(self, token: u64) {
366 self.pending_probes.remove(&self.remote);
367 self.active_probes.insert(token, self.remote);
368 }
369
370 pub(crate) fn remote(&self) -> SocketAddr {
371 self.remote.into()
372 }
373}
374
375impl State {
376 pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
377 match side {
378 Side::Client => Self::ClientSide(ClientState::new(
379 max_remote_addresses.into(),
380 max_local_addresses.into(),
381 )),
382 Side::Server => Self::ServerSide(ServerState::new(
383 max_remote_addresses.into(),
384 max_local_addresses.into(),
385 )),
386 }
387 }
388
389 pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
390 match self {
391 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
392 Self::ClientSide(client_side) => Ok(client_side),
393 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
394 }
395 }
396
397 pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
398 match self {
399 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
400 Self::ClientSide(client_side) => Ok(client_side),
401 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
402 }
403 }
404
405 pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
406 match self {
407 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
408 Self::ClientSide(_) => Err(Error::WrongConnectionSide),
409 Self::ServerSide(server_side) => Ok(server_side),
410 }
411 }
412
413 pub(crate) fn add_local_address(
421 &mut self,
422 address: SocketAddr,
423 ) -> Result<Option<AddAddress>, Error> {
424 match self {
425 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
426 Self::ClientSide(client_state) => {
427 client_state.add_local_address((address.ip(), address.port()))?;
428 Ok(None)
429 }
430 Self::ServerSide(server_state) => {
431 server_state.add_local_address((address.ip(), address.port()))
432 }
433 }
434 }
435
436 pub(crate) fn remove_local_address(
443 &mut self,
444 address: SocketAddr,
445 ) -> Result<Option<RemoveAddress>, Error> {
446 let address = &(address.ip(), address.port());
447 match self {
448 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
449 Self::ClientSide(client_state) => {
450 client_state.remove_local_address(address);
451 Ok(None)
452 }
453 Self::ServerSide(server_state) => Ok(server_state.remove_local_address(address)),
454 }
455 }
456
457 pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
458 match self {
459 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
460 Self::ClientSide(client_state) => Ok(client_state
461 .local_addresses
462 .iter()
463 .copied()
464 .map(Into::into)
465 .collect()),
466 Self::ServerSide(server_state) => Ok(server_state
467 .local_addresses
468 .keys()
469 .copied()
470 .map(Into::into)
471 .collect()),
472 }
473 }
474}
475
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use super::*;

    /// Builds a `REACH_OUT` frame for `round` targeting `ip:port`.
    fn reach_out(round: u32, ip: impl Into<IpAddr>, port: u16) -> ReachOut {
        ReachOut {
            round: round.into(),
            ip: ip.into(),
            port,
        }
    }

    /// Builds an `ADD_ADDRESS` frame advertising `ip:port` under `seq_no`.
    fn add_address(seq_no: u32, ip: impl Into<IpAddr>, port: u16) -> AddAddress {
        AddAddress {
            seq_no: seq_no.into(),
            ip: ip.into(),
            port,
        }
    }

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, Ipv4Addr::LOCALHOST, 1))
            .unwrap();
        state
            .handle_reach_out(reach_out(1, Ipv4Addr::new(1, 1, 1, 1), 2))
            .unwrap();

        assert_eq!(state.pending_probes.len(), 2);
        assert_eq!(state.active_probes.len(), 0);

        let probe = state.next_probe().unwrap();
        probe.finish(1);
        let probe = state.next_probe().unwrap();
        probe.finish(2);

        assert!(state.next_probe().is_none());
        assert_eq!(state.pending_probes.len(), 0);
        assert_eq!(state.active_probes.len(), 2);
    }

    #[test]
    fn test_server_state_rejects_too_many_addresses_in_round() {
        let mut state = ServerState::new(1, 1);

        state
            .handle_reach_out(reach_out(1, Ipv4Addr::LOCALHOST, 1))
            .unwrap();
        let res = state.handle_reach_out(reach_out(1, Ipv4Addr::new(1, 1, 1, 1), 2));

        assert!(matches!(res, Err(Error::TooManyAddresses)));
        assert_eq!(state.pending_probes.len(), 1);
    }

    #[test]
    fn test_server_state_rounds() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(2, Ipv4Addr::LOCALHOST, 1))
            .unwrap();
        state.next_probe().unwrap().finish(7);
        assert_eq!(state.active_probes.len(), 1);

        // A frame from an older round is ignored.
        state
            .handle_reach_out(reach_out(1, Ipv4Addr::new(1, 1, 1, 1), 2))
            .unwrap();
        assert!(state.pending_probes.is_empty());
        assert_eq!(state.active_probes.len(), 1);

        // A frame from a newer round drops everything from the previous round.
        state
            .handle_reach_out(reach_out(3, Ipv4Addr::new(1, 1, 1, 1), 2))
            .unwrap();
        assert_eq!(state.pending_probes.len(), 1);
        assert!(state.active_probes.is_empty());
        let expected: SocketAddr = "1.1.1.1:2".parse().unwrap();
        assert_eq!(state.next_probe().unwrap().remote(), expected);
    }

    #[test]
    fn test_client_local_address_limit() {
        let mut state = ClientState::new(1, 1);
        let first: IpPort = (IpAddr::from(Ipv4Addr::LOCALHOST), 1);
        let second: IpPort = (IpAddr::from(Ipv4Addr::new(1, 1, 1, 1)), 2);

        assert!(matches!(
            state.initiate_nat_traversal_round(),
            Err(Error::NotEnoughAddresses)
        ));

        state.add_local_address(first).unwrap();
        // Re-adding a known address succeeds even at the limit.
        state.add_local_address(first).unwrap();
        assert!(matches!(
            state.add_local_address(second),
            Err(Error::TooManyAddresses)
        ));
    }

    #[test]
    fn test_client_remote_addresses() {
        let mut state = ClientState::new(1, 1);
        let first: SocketAddr = "1.1.1.1:1".parse().unwrap();
        let updated: SocketAddr = "2.2.2.2:2".parse().unwrap();

        assert_eq!(
            state
                .add_remote_address(add_address(0, Ipv4Addr::new(1, 1, 1, 1), 1))
                .unwrap(),
            Some(first)
        );
        // The exact same advertisement is not reported again.
        assert_eq!(
            state
                .add_remote_address(add_address(0, Ipv4Addr::new(1, 1, 1, 1), 1))
                .unwrap(),
            None
        );
        // A new sequence number beyond the limit is rejected.
        assert!(matches!(
            state.add_remote_address(add_address(1, Ipv4Addr::new(2, 2, 2, 2), 2)),
            Err(Error::TooManyAddresses)
        ));
        // Replacing the address under an existing sequence number is reported.
        assert_eq!(
            state
                .add_remote_address(add_address(0, Ipv4Addr::new(2, 2, 2, 2), 2))
                .unwrap(),
            Some(updated)
        );
        assert_eq!(state.get_remote_nat_traversal_addresses(), vec![updated]);
        assert_eq!(
            state.remove_remote_address(RemoveAddress::new(VarInt::from(0u32))),
            Some(updated)
        );
        assert!(state.get_remote_nat_traversal_addresses().is_empty());
    }

    #[test]
    fn test_client_round_and_continuation() {
        let mut state = ClientState::new(2, 2);
        let local: IpPort = (IpAddr::from(Ipv4Addr::LOCALHOST), 5);
        let remote: IpPort = (IpAddr::from(Ipv4Addr::new(1, 1, 1, 1)), 1);

        state.add_local_address(local).unwrap();
        state
            .add_remote_address(add_address(3, Ipv4Addr::new(1, 1, 1, 1), 1))
            .unwrap();

        let round = state.initiate_nat_traversal_round().unwrap();
        assert_eq!(round.new_round, VarInt::from(1u32));
        assert_eq!(round.reach_out_at, vec![local]);
        assert_eq!(round.addresses_to_probe, vec![(VarInt::from(3u32), remote)]);
        assert!(round.prev_round_path_ids.is_empty());

        assert!(state.continue_nat_traversal_round().is_none());
        state.report_in_continuation(VarInt::from(3u32), crate::PathError::MaxPathIdReached);
        assert_eq!(
            state.continue_nat_traversal_round(),
            Some((VarInt::from(3u32), remote))
        );
        assert!(state.continue_nat_traversal_round().is_none());
    }
}