1use std::{
4 collections::hash_map::Entry,
5 net::{IpAddr, SocketAddr},
6};
7
8use identity_hash::IntMap;
9use rustc_hash::{FxHashMap, FxHashSet};
10use tracing::trace;
11
12use crate::{
13 PathId, Side, VarInt,
14 frame::{AddAddress, ReachOut, RemoveAddress},
15};
16
/// An IP address and port pair, stored as a tuple.
///
/// Converts into a [`SocketAddr`] via `.into()`, which is how addresses are handed out of this
/// module.
type IpPort = (IpAddr, u16);
18
/// Errors returned by the nat traversal operations in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Adding a new address would exceed the configured limit for that address set.
    #[error("Tried to add too many addresses to their advertised set")]
    TooManyAddresses,
    /// The operation belongs to the other connection side (e.g. a client-only operation was
    /// requested on server-side state).
    #[error("Not allowed for this endpoint's connection side")]
    WrongConnectionSide,
    /// The nat traversal extension was not negotiated for this connection.
    #[error("Iroh's nat traversal was not negotiated")]
    ExtensionNotNegotiated,
    /// A nat traversal round was requested without any local address registered.
    #[error("Not enough addresses")]
    NotEnoughAddresses,
    /// Wraps a `super::PathError` raised while establishing paths.
    /// NOTE(review): not constructed in this section — confirm against the callers.
    #[error("Failed to establish paths {0}")]
    Multipath(super::PathError),
    /// The connection is already closed.
    /// NOTE(review): not constructed in this section — confirm against the callers.
    #[error("The connection is already closed")]
    Closed,
}
41
/// Snapshot describing a freshly started nat traversal round, produced by
/// `ClientState::initiate_nat_traversal_round`.
pub(crate) struct NatTraversalRound {
    /// The number of the round that was just started.
    pub(crate) new_round: VarInt,
    /// The client's registered local addresses.
    /// NOTE(review): presumably advertised to the server in `REACH_OUT` frames — confirm with
    /// the caller.
    pub(crate) reach_out_at: Vec<(IpAddr, u16)>,
    /// The remote addresses currently known from the server's `ADD_ADDRESS` frames.
    pub(crate) addresses_to_probe: Vec<(IpAddr, u16)>,
    /// The path ids recorded via `ClientState::set_round_path_ids` before this round started.
    pub(crate) prev_round_path_ids: Vec<PathId>,
}
52
/// Notification that a nat traversal address was added or removed.
///
/// NOTE(review): not constructed in this section; presumably reports changes to the addresses
/// advertised by the remote peer — confirm against the code that emits it.
#[derive(Debug, Clone)]
pub enum Event {
    /// An address was added.
    AddressAdded(SocketAddr),
    /// An address was removed.
    AddressRemoved(SocketAddr),
}
61
/// Per-connection nat traversal state, specialised by connection side once negotiated.
#[derive(Debug, Default)]
pub(crate) enum State {
    /// The extension was not negotiated; the `State` accessors and address helpers fail with
    /// [`Error::ExtensionNotNegotiated`]. This is the default.
    #[default]
    NotNegotiated,
    /// State kept by the client side of the connection.
    ClientSide(ClientState),
    /// State kept by the server side of the connection.
    ServerSide(ServerState),
}
70
/// Client-side nat traversal state.
#[derive(Debug)]
pub(crate) struct ClientState {
    /// Maximum number of distinct sequence numbers kept in `remote_addresses`.
    max_remote_addresses: usize,
    /// Maximum number of entries kept in `local_addresses`.
    max_local_addresses: usize,
    /// Remote addresses learned from `ADD_ADDRESS` frames, keyed by the frame's sequence number.
    remote_addresses: FxHashMap<VarInt, (IpAddr, u16)>,
    /// Local addresses handed out as `reach_out_at` when a round starts.
    local_addresses: FxHashSet<(IpAddr, u16)>,
    /// Current round number; incremented (saturating) for every new round.
    round: VarInt,
    /// Path ids recorded for the current round, returned as `prev_round_path_ids` when the next
    /// round starts.
    round_path_ids: Vec<PathId>,
}
92
93impl ClientState {
94 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
95 Self {
96 max_remote_addresses,
97 max_local_addresses,
98 remote_addresses: Default::default(),
99 local_addresses: Default::default(),
100 round: Default::default(),
101 round_path_ids: Default::default(),
102 }
103 }
104
105 fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
106 if self.local_addresses.len() < self.max_local_addresses {
107 self.local_addresses.insert(address);
108 Ok(())
109 } else if self.local_addresses.contains(&address) {
110 Ok(())
112 } else {
113 Err(Error::TooManyAddresses)
115 }
116 }
117
118 fn remove_local_address(&mut self, address: &IpPort) {
119 self.local_addresses.remove(address);
120 }
121
122 pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result<NatTraversalRound, Error> {
129 if self.local_addresses.is_empty() {
130 return Err(Error::NotEnoughAddresses);
131 }
132
133 let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
134 self.round = self.round.saturating_add(1u8);
135
136 Ok(NatTraversalRound {
137 new_round: self.round,
138 reach_out_at: self.local_addresses.iter().copied().collect(),
139 addresses_to_probe: self.remote_addresses.values().copied().collect(),
140 prev_round_path_ids,
141 })
142 }
143
144 pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
147 self.round_path_ids = path_ids;
148 }
149
150 pub(crate) fn add_remote_address(
155 &mut self,
156 add_addr: AddAddress,
157 ) -> Result<Option<SocketAddr>, Error> {
158 let AddAddress { seq_no, ip, port } = add_addr;
159 let address = (ip, port);
160 let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
161 match self.remote_addresses.entry(seq_no) {
162 Entry::Occupied(mut occupied_entry) => {
163 let old_value = occupied_entry.insert(address);
164 Ok((address != old_value).then_some(address.into()))
167 }
168 Entry::Vacant(vacant_entry) if allow_new => {
169 vacant_entry.insert(address);
170 Ok(Some(address.into()))
171 }
172 _ => Err(Error::TooManyAddresses),
173 }
174 }
175
176 pub(crate) fn remove_remote_address(
180 &mut self,
181 remove_addr: RemoveAddress,
182 ) -> Option<SocketAddr> {
183 self.remote_addresses
184 .remove(&remove_addr.seq_no)
185 .map(Into::into)
186 }
187
188 pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
192 let existing = self.remote_addresses.get(&add_addr.seq_no);
193 existing.is_none() || existing == Some(&add_addr.ip_port())
194 }
195
196 pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
197 self.remote_addresses
198 .values()
199 .copied()
200 .map(Into::into)
201 .collect()
202 }
203}
204
/// Server-side nat traversal state.
#[derive(Debug)]
pub(crate) struct ServerState {
    /// Limit on the number of addresses accepted from `REACH_OUT` frames in one round.
    max_remote_addresses: usize,
    /// Maximum number of entries kept in `local_addresses`.
    max_local_addresses: usize,
    /// Local addresses advertised to the client, mapped to the id used in their `ADD_ADDRESS` /
    /// `REMOVE_ADDRESS` frames.
    local_addresses: FxHashMap<IpPort, VarInt>,
    /// Id assigned to the next newly registered local address (saturating increment).
    next_local_addr_id: VarInt,
    /// Highest round number seen in a `REACH_OUT` frame.
    round: VarInt,
    /// Addresses from `REACH_OUT` frames of the current round that have not been probed yet.
    pending_probes: FxHashSet<IpPort>,
    /// Probed addresses of the current round, keyed by the token passed to
    /// `ServerProbing::finish`; cleared when a newer round starts.
    active_probes: IntMap<u64, IpPort>,
}
232
233impl ServerState {
234 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
235 Self {
236 max_remote_addresses,
237 max_local_addresses,
238 local_addresses: Default::default(),
239 next_local_addr_id: Default::default(),
240 round: Default::default(),
241 pending_probes: Default::default(),
242 active_probes: Default::default(),
243 }
244 }
245
246 fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
247 let allow_new = self.local_addresses.len() < self.max_local_addresses;
248 match self.local_addresses.entry(address) {
249 Entry::Occupied(_) => Ok(None),
250 Entry::Vacant(vacant_entry) if allow_new => {
251 let id = self.next_local_addr_id;
252 self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
253 vacant_entry.insert(id);
254 Ok(Some(AddAddress::new(address, id)))
255 }
256 _ => Err(Error::TooManyAddresses),
257 }
258 }
259
260 fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
261 self.local_addresses.remove(address).map(RemoveAddress::new)
262 }
263
264 pub(crate) fn handle_reach_out(&mut self, reach_out: ReachOut) -> Result<(), Error> {
271 let ReachOut { round, ip, port } = reach_out;
272
273 if round < self.round {
274 trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
275 return Ok(());
276 }
277
278 if round > self.round {
279 self.round = round;
280 self.pending_probes.clear();
281 for (token, remote) in self.active_probes.drain() {
284 let remote: SocketAddr = remote.into();
285 trace!(token=format!("{:08x}", token), %remote, "dropping nat traversal challenge pending response");
286 }
287 } else if self.pending_probes.len() >= self.max_remote_addresses {
288 return Err(Error::TooManyAddresses);
289 }
290 self.pending_probes.insert((ip, port));
291 Ok(())
292 }
293
294 pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
295 self.pending_probes
296 .iter()
297 .next()
298 .copied()
299 .map(|remote| ServerProbing {
300 remote,
301 pending_probes: &mut self.pending_probes,
302 active_probes: &mut self.active_probes,
303 })
304 }
305}
306
/// Handle to one queued probe of a [`ServerState`], returned by `ServerState::next_probe`.
///
/// It mutably borrows the state's probe collections so that `finish` can update them.
pub(crate) struct ServerProbing<'a> {
    /// The address to probe.
    remote: IpPort,
    /// The owning state's queued probes; `remote` is removed from here by `finish`.
    pending_probes: &'a mut FxHashSet<IpPort>,
    /// The owning state's probed addresses; `remote` is inserted here by `finish`.
    active_probes: &'a mut IntMap<u64, IpPort>,
}
312
313impl<'a> ServerProbing<'a> {
314 pub(crate) fn finish(self, token: u64) {
315 self.pending_probes.remove(&self.remote);
316 self.active_probes.insert(token, self.remote);
317 }
318
319 pub(crate) fn remote(&self) -> SocketAddr {
320 self.remote.into()
321 }
322}
323
324impl State {
325 pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
326 match side {
327 Side::Client => Self::ClientSide(ClientState::new(
328 max_remote_addresses.into(),
329 max_local_addresses.into(),
330 )),
331 Side::Server => Self::ServerSide(ServerState::new(
332 max_remote_addresses.into(),
333 max_local_addresses.into(),
334 )),
335 }
336 }
337
338 pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
339 match self {
340 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
341 Self::ClientSide(client_side) => Ok(client_side),
342 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
343 }
344 }
345
346 pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
347 match self {
348 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
349 Self::ClientSide(client_side) => Ok(client_side),
350 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
351 }
352 }
353
354 pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
355 match self {
356 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
357 Self::ClientSide(_) => Err(Error::WrongConnectionSide),
358 Self::ServerSide(server_side) => Ok(server_side),
359 }
360 }
361
362 pub(crate) fn add_local_address(
370 &mut self,
371 address: SocketAddr,
372 ) -> Result<Option<AddAddress>, Error> {
373 match self {
374 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
375 Self::ClientSide(client_state) => {
376 client_state.add_local_address((address.ip(), address.port()))?;
377 Ok(None)
378 }
379 Self::ServerSide(server_state) => {
380 server_state.add_local_address((address.ip(), address.port()))
381 }
382 }
383 }
384
385 pub(crate) fn remove_local_address(
392 &mut self,
393 address: SocketAddr,
394 ) -> Result<Option<RemoveAddress>, Error> {
395 let address = &(address.ip(), address.port());
396 match self {
397 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
398 Self::ClientSide(client_state) => {
399 client_state.remove_local_address(address);
400 Ok(None)
401 }
402 Self::ServerSide(server_state) => Ok(server_state.remove_local_address(address)),
403 }
404 }
405
406 pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
407 match self {
408 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
409 Self::ClientSide(client_state) => Ok(client_state
410 .local_addresses
411 .iter()
412 .copied()
413 .map(Into::into)
414 .collect()),
415 Self::ServerSide(server_state) => Ok(server_state
416 .local_addresses
417 .keys()
418 .copied()
419 .map(Into::into)
420 .collect()),
421 }
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `REACH_OUT` frame for `round` pointing at `ip:port`.
    fn reach_out(round: u32, ip: &str, port: u16) -> ReachOut {
        ReachOut {
            round: round.into(),
            ip: ip.parse().unwrap(),
            port,
        }
    }

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(1, "127.0.0.1", 1)).unwrap();
        state.handle_reach_out(reach_out(1, "1.1.1.1", 2)).unwrap();

        assert_eq!(state.pending_probes.len(), 2);
        assert_eq!(state.active_probes.len(), 0);

        let probe = state.next_probe().unwrap();
        probe.finish(1);
        let probe = state.next_probe().unwrap();
        probe.finish(2);

        assert!(state.next_probe().is_none());
        assert_eq!(state.pending_probes.len(), 0);
        assert_eq!(state.active_probes.len(), 2);
    }

    #[test]
    fn test_server_state_rejects_excess_reach_outs() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(1, "127.0.0.1", 1)).unwrap();
        state.handle_reach_out(reach_out(1, "1.1.1.1", 2)).unwrap();
        assert!(matches!(
            state.handle_reach_out(reach_out(1, "2.2.2.2", 3)),
            Err(Error::TooManyAddresses)
        ));
        assert_eq!(state.pending_probes.len(), 2);
    }

    #[test]
    fn test_server_state_round_transitions() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(1, "127.0.0.1", 1)).unwrap();
        state.next_probe().unwrap().finish(1);
        assert_eq!(state.active_probes.len(), 1);

        // A newer round discards everything from the previous one.
        state.handle_reach_out(reach_out(2, "1.1.1.1", 2)).unwrap();
        assert_eq!(state.pending_probes.len(), 1);
        assert_eq!(state.active_probes.len(), 0);

        // A frame for an older round is ignored.
        state.handle_reach_out(reach_out(1, "2.2.2.2", 3)).unwrap();
        assert_eq!(state.pending_probes.len(), 1);
        assert_eq!(state.active_probes.len(), 0);
    }
}
462}