1use std::{
4 collections::hash_map::Entry,
5 net::{IpAddr, SocketAddr},
6};
7
8use identity_hash::IntMap;
9use rustc_hash::{FxHashMap, FxHashSet};
10use tracing::trace;
11
12use crate::{
13 PathId, Side, VarInt,
14 frame::{AddAddress, ReachOut, RemoveAddress},
15};
16
/// An IP address and port pair; converted into a [`SocketAddr`] where needed in this module.
type IpPort = (IpAddr, u16);
18
/// Errors produced by the NAT traversal state machine.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Adding the address would exceed the configured maximum number of tracked addresses.
    #[error("Tried to add too many addresses to their advertised set")]
    TooManyAddresses,
    /// The operation is only available on the other side of the connection
    /// (client-only operation on a server, or vice versa).
    #[error("Not allowed for this endpoint's connection side")]
    WrongConnectionSide,
    /// The NAT traversal extension was not negotiated for this connection.
    #[error("Iroh's nat traversal was not negotiated")]
    ExtensionNotNegotiated,
    /// A NAT traversal round was requested without any local address registered.
    #[error("Not enough addresses")]
    NotEnoughAddresses,
    /// Wraps a path error. NOTE(review): not constructed in this module; presumably raised
    /// while opening the paths of a round — confirm at the call site.
    #[error("Failed to establish paths {0}")]
    Multipath(super::PathError),
}
38
/// Snapshot returned by `ClientState::initiate_nat_traversal_round` describing the round that
/// was just started.
pub(crate) struct NatTraversalRound {
    /// Number of the round that was just started.
    pub(crate) new_round: VarInt,
    /// The client's registered local addresses at the time the round was started.
    pub(crate) reach_out_at: Vec<(IpAddr, u16)>,
    /// The remote addresses known at the time the round was started.
    pub(crate) addresses_to_probe: Vec<(IpAddr, u16)>,
    /// Path ids recorded for the previous round via `ClientState::set_round_path_ids`.
    /// NOTE(review): presumably for the caller to close paths that did not succeed — confirm.
    pub(crate) prev_round_path_ids: Vec<PathId>,
}
49
/// Notification about a change to the set of NAT traversal addresses.
///
/// NOTE(review): not emitted in this module; presumably reports remote addresses added or
/// removed by the peer — confirm against the code that emits it.
#[derive(Debug, Clone)]
pub enum Event {
    /// An address was added.
    AddressAdded(SocketAddr),
    /// An address was removed.
    AddressRemoved(SocketAddr),
}
58
/// Per-connection NAT traversal state, split by which side of the connection this endpoint is.
#[derive(Debug, Default)]
pub(crate) enum State {
    /// The extension was not negotiated; this is the default.
    #[default]
    NotNegotiated,
    /// State kept when this endpoint is the client (`Side::Client`).
    ClientSide(ClientState),
    /// State kept when this endpoint is the server (`Side::Server`).
    ServerSide(ServerState),
}
67
/// Client-side NAT traversal state.
#[derive(Debug)]
pub(crate) struct ClientState {
    /// Maximum number of entries allowed in `remote_addresses`.
    max_remote_addresses: usize,
    /// Maximum number of entries allowed in `local_addresses`.
    max_local_addresses: usize,
    /// Remote addresses received through ADD_ADDRESS frames, keyed by their sequence number.
    remote_addresses: FxHashMap<VarInt, (IpAddr, u16)>,
    /// Local addresses registered on this endpoint; copied into each new round's `reach_out_at`.
    local_addresses: FxHashSet<(IpAddr, u16)>,
    /// Current round number; incremented (saturating) each time a round is initiated.
    round: VarInt,
    /// Path ids recorded for the current round; handed back as `prev_round_path_ids` when the
    /// next round is initiated.
    round_path_ids: Vec<PathId>,
}
89
90impl ClientState {
91 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
92 Self {
93 max_remote_addresses,
94 max_local_addresses,
95 remote_addresses: Default::default(),
96 local_addresses: Default::default(),
97 round: Default::default(),
98 round_path_ids: Default::default(),
99 }
100 }
101
102 fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
103 if self.local_addresses.len() < self.max_local_addresses {
104 self.local_addresses.insert(address);
105 Ok(())
106 } else if self.local_addresses.contains(&address) {
107 Ok(())
109 } else {
110 Err(Error::TooManyAddresses)
112 }
113 }
114
115 fn remove_local_address(&mut self, address: &IpPort) {
116 self.local_addresses.remove(address);
117 }
118
119 pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result<NatTraversalRound, Error> {
126 if self.local_addresses.is_empty() {
127 return Err(Error::NotEnoughAddresses);
128 }
129
130 let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
131 self.round = self.round.saturating_add(1u8);
132
133 Ok(NatTraversalRound {
134 new_round: self.round,
135 reach_out_at: self.local_addresses.iter().copied().collect(),
136 addresses_to_probe: self.remote_addresses.values().copied().collect(),
137 prev_round_path_ids,
138 })
139 }
140
141 pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
144 self.round_path_ids = path_ids;
145 }
146
147 pub(crate) fn add_remote_address(
152 &mut self,
153 add_addr: AddAddress,
154 ) -> Result<Option<SocketAddr>, Error> {
155 let AddAddress { seq_no, ip, port } = add_addr;
156 let address = (ip, port);
157 let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
158 match self.remote_addresses.entry(seq_no) {
159 Entry::Occupied(mut occupied_entry) => {
160 let old_value = occupied_entry.insert(address);
161 Ok((address != old_value).then_some(address.into()))
164 }
165 Entry::Vacant(vacant_entry) if allow_new => {
166 vacant_entry.insert(address);
167 Ok(Some(address.into()))
168 }
169 _ => Err(Error::TooManyAddresses),
170 }
171 }
172
173 pub(crate) fn remove_remote_address(
177 &mut self,
178 remove_addr: RemoveAddress,
179 ) -> Option<SocketAddr> {
180 self.remote_addresses
181 .remove(&remove_addr.seq_no)
182 .map(Into::into)
183 }
184
185 pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
189 let existing = self.remote_addresses.get(&add_addr.seq_no);
190 existing.is_none() || existing == Some(&add_addr.ip_port())
191 }
192
193 pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
194 self.remote_addresses
195 .values()
196 .copied()
197 .map(Into::into)
198 .collect()
199 }
200}
201
/// Server-side NAT traversal state.
#[derive(Debug)]
pub(crate) struct ServerState {
    /// Maximum number of pending probes accepted within a single round.
    max_remote_addresses: usize,
    /// Maximum number of entries allowed in `local_addresses`.
    max_local_addresses: usize,
    /// Registered local addresses mapped to the sequence id used in their ADD_ADDRESS /
    /// REMOVE_ADDRESS frames.
    local_addresses: FxHashMap<IpPort, VarInt>,
    /// Sequence id assigned to the next newly registered local address (saturating).
    next_local_addr_id: VarInt,
    /// Round number that incoming REACH_OUT frames are compared against: lower rounds are
    /// ignored, higher rounds reset the probing state.
    round: VarInt,
    /// Remote addresses from REACH_OUT frames that have not been probed yet.
    pending_probes: FxHashSet<IpPort>,
    /// Probes already sent, keyed by their token, awaiting a response.
    active_probes: IntMap<u64, IpPort>,
}
229
230impl ServerState {
231 fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
232 Self {
233 max_remote_addresses,
234 max_local_addresses,
235 local_addresses: Default::default(),
236 next_local_addr_id: Default::default(),
237 round: Default::default(),
238 pending_probes: Default::default(),
239 active_probes: Default::default(),
240 }
241 }
242
243 fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
244 let allow_new = self.local_addresses.len() < self.max_local_addresses;
245 match self.local_addresses.entry(address) {
246 Entry::Occupied(_) => Ok(None),
247 Entry::Vacant(vacant_entry) if allow_new => {
248 let id = self.next_local_addr_id;
249 self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
250 vacant_entry.insert(id);
251 Ok(Some(AddAddress::new(address, id)))
252 }
253 _ => Err(Error::TooManyAddresses),
254 }
255 }
256
257 fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
258 self.local_addresses.remove(address).map(RemoveAddress::new)
259 }
260
261 pub(crate) fn handle_reach_out(&mut self, reach_out: ReachOut) -> Result<(), Error> {
268 let ReachOut { round, ip, port } = reach_out;
269
270 if round < self.round {
271 trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
272 return Ok(());
273 }
274
275 if round > self.round {
276 self.pending_probes.clear();
277 for (token, remote) in self.active_probes.drain() {
280 let remote: SocketAddr = remote.into();
281 trace!(token=format!("{:08x}", token), %remote, "dropping nat traversal challenge pending response");
282 }
283 } else if self.pending_probes.len() >= self.max_remote_addresses {
284 return Err(Error::TooManyAddresses);
285 }
286 self.pending_probes.insert((ip, port));
287 Ok(())
288 }
289
290 pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
291 self.pending_probes
292 .iter()
293 .next()
294 .copied()
295 .map(|remote| ServerProbing {
296 remote,
297 pending_probes: &mut self.pending_probes,
298 active_probes: &mut self.active_probes,
299 })
300 }
301}
302
/// Handle to one pending probe, obtained from `ServerState::next_probe`.
///
/// Holds mutable borrows of the server's probe collections so that `finish` can move the probe
/// from the pending set to the active map.
pub(crate) struct ServerProbing<'a> {
    /// Remote address this probe targets.
    remote: IpPort,
    /// The server's set of not-yet-probed addresses.
    pending_probes: &'a mut FxHashSet<IpPort>,
    /// The server's map of sent probes, keyed by token.
    active_probes: &'a mut IntMap<u64, IpPort>,
}
308
309impl<'a> ServerProbing<'a> {
310 pub(crate) fn finish(self, token: u64) {
311 self.pending_probes.remove(&self.remote);
312 self.active_probes.insert(token, self.remote);
313 }
314
315 pub(crate) fn remote(&self) -> SocketAddr {
316 self.remote.into()
317 }
318}
319
320impl State {
321 pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
322 match side {
323 Side::Client => Self::ClientSide(ClientState::new(
324 max_remote_addresses.into(),
325 max_local_addresses.into(),
326 )),
327 Side::Server => Self::ServerSide(ServerState::new(
328 max_remote_addresses.into(),
329 max_local_addresses.into(),
330 )),
331 }
332 }
333
334 pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
335 match self {
336 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
337 Self::ClientSide(client_side) => Ok(client_side),
338 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
339 }
340 }
341
342 pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
343 match self {
344 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
345 Self::ClientSide(client_side) => Ok(client_side),
346 Self::ServerSide(_) => Err(Error::WrongConnectionSide),
347 }
348 }
349
350 pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
351 match self {
352 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
353 Self::ClientSide(_) => Err(Error::WrongConnectionSide),
354 Self::ServerSide(server_side) => Ok(server_side),
355 }
356 }
357
358 pub(crate) fn add_local_address(
366 &mut self,
367 address: SocketAddr,
368 ) -> Result<Option<AddAddress>, Error> {
369 match self {
370 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
371 Self::ClientSide(client_state) => {
372 client_state.add_local_address((address.ip(), address.port()))?;
373 Ok(None)
374 }
375 Self::ServerSide(server_state) => {
376 server_state.add_local_address((address.ip(), address.port()))
377 }
378 }
379 }
380
381 pub(crate) fn remove_local_address(
388 &mut self,
389 address: SocketAddr,
390 ) -> Result<Option<RemoveAddress>, Error> {
391 let address = &(address.ip(), address.port());
392 match self {
393 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
394 Self::ClientSide(client_state) => {
395 client_state.remove_local_address(address);
396 Ok(None)
397 }
398 Self::ServerSide(server_state) => Ok(server_state.remove_local_address(address)),
399 }
400 }
401
402 pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
403 match self {
404 Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
405 Self::ClientSide(client_state) => Ok(client_state
406 .local_addresses
407 .iter()
408 .copied()
409 .map(Into::into)
410 .collect()),
411 Self::ServerSide(server_state) => Ok(server_state
412 .local_addresses
413 .keys()
414 .copied()
415 .map(Into::into)
416 .collect()),
417 }
418 }
419}