iroh_quinn_proto/
iroh_hp.rs

1//! iroh NAT Traversal
2
3use std::{
4    collections::hash_map::Entry,
5    net::{IpAddr, SocketAddr},
6};
7
8use identity_hash::IntMap;
9use rustc_hash::{FxHashMap, FxHashSet};
10use tracing::trace;
11
12use crate::{
13    PathId, Side, VarInt,
14    frame::{AddAddress, ReachOut, RemoveAddress},
15};
16
/// An IP address paired with a port.
///
/// This is the plain-tuple form in which the nat traversal state stores addresses; it converts
/// to a [`SocketAddr`] with `into()`.
type IpPort = (IpAddr, u16);
18
19/// Errors that the nat traversal state might encounter.
20#[derive(Debug, thiserror::Error)]
21pub enum Error {
22    /// An endpoint (local or remote) tried to add too many addresses to their advertised set
23    #[error("Tried to add too many addresses to their advertised set")]
24    TooManyAddresses,
25    /// The operation is not allowed for this endpoint's connection side
26    #[error("Not allowed for this endpoint's connection side")]
27    WrongConnectionSide,
28    /// The extension was not negotiated
29    #[error("Iroh's nat traversal was not negotiated")]
30    ExtensionNotNegotiated,
31    /// Not enough addresses to complete the operation
32    #[error("Not enough addresses")]
33    NotEnoughAddresses,
34    /// Nat traversal attempt failed due to a multipath error
35    #[error("Failed to establish paths {0}")]
36    Multipath(super::PathError),
37    /// Attempted to initiate NAT traversal on a closed, or closing connection.
38    #[error("The connection is already closed")]
39    Closed,
40}
41
42pub(crate) struct NatTraversalRound {
43    /// Sequence number to use for the new reach out frames.
44    pub(crate) new_round: VarInt,
45    /// Addresses to use to send reach out frames.
46    pub(crate) reach_out_at: Vec<IpPort>,
47    /// Remotes to probe by attempting to open new paths.
48    ///
49    /// The addresses include their Id, so that it can be used to signal these should be returned
50    /// in a nat traversal continuation by calling [`ClientState::report_in_continuation`].
51    pub(crate) addresses_to_probe: Vec<(VarInt, IpPort)>,
52    /// [`PathId`]s of the cancelled round.
53    pub(crate) prev_round_path_ids: Vec<PathId>,
54}
55
/// Notification produced when the client receives an ADD_ADDRESS or REMOVE_ADDRESS frame.
#[derive(Debug, Clone)]
pub enum Event {
    /// The remote advertised this address in an ADD_ADDRESS frame.
    AddressAdded(SocketAddr),
    /// The remote withdrew this address with a REMOVE_ADDRESS frame.
    AddressRemoved(SocketAddr),
}
64
65/// State kept for Iroh's nat traversal
66#[derive(Debug, Default)]
67pub(crate) enum State {
68    #[default]
69    NotNegotiated,
70    ClientSide(ClientState),
71    ServerSide(ServerState),
72}
73
74#[derive(Debug)]
75pub(crate) struct ClientState {
76    /// Max number of remote addresses we allow
77    ///
78    /// This is set by the local endpoint.
79    max_remote_addresses: usize,
80    /// Max number of local addresses allowed
81    ///
82    /// This is set by the remote endpoint.
83    max_local_addresses: usize,
84    /// Candidate addresses the remote server reports as potentially reachable, to use for nat
85    /// traversal attempts.
86    ///
87    /// These are indexed by their advertised Id. For each address, whether the address should be
88    /// reported in nat traversal continuations is kept.
89    remote_addresses: FxHashMap<VarInt, (IpPort, bool)>,
90    /// Candidate addresses the local client reports as potentially reachable, to use for nat
91    /// traversal attempts.
92    local_addresses: FxHashSet<IpPort>,
93    /// Current nat holepunching round.
94    round: VarInt,
95    /// [`PathId`]s used to probe remotes assigned to this round.
96    round_path_ids: Vec<PathId>,
97}
98
99impl ClientState {
100    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
101        Self {
102            max_remote_addresses,
103            max_local_addresses,
104            remote_addresses: Default::default(),
105            local_addresses: Default::default(),
106            round: Default::default(),
107            round_path_ids: Default::default(),
108        }
109    }
110
111    fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
112        if self.local_addresses.len() < self.max_local_addresses {
113            self.local_addresses.insert(address);
114            Ok(())
115        } else if self.local_addresses.contains(&address) {
116            // at capacity, but the address is known, no issues here
117            Ok(())
118        } else {
119            // at capacity and the address is new
120            Err(Error::TooManyAddresses)
121        }
122    }
123
124    fn remove_local_address(&mut self, address: &IpPort) {
125        self.local_addresses.remove(address);
126    }
127
128    /// Initiates a new nat traversal round.
129    ///
130    /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT`
131    /// frames, and initiating probing of the known remote addresses. When a new round is
132    /// initiated, the previous one is cancelled, and paths that have not been opened should be
133    /// closed.
134    pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result<NatTraversalRound, Error> {
135        if self.local_addresses.is_empty() {
136            return Err(Error::NotEnoughAddresses);
137        }
138
139        let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
140        self.round = self.round.saturating_add(1u8);
141        let mut addresses_to_probe = Vec::with_capacity(self.remote_addresses.len());
142        for (id, (address, report_in_continuation)) in self.remote_addresses.iter_mut() {
143            addresses_to_probe.push((*id, *address));
144            *report_in_continuation = false;
145        }
146
147        Ok(NatTraversalRound {
148            new_round: self.round,
149            reach_out_at: self.local_addresses.iter().copied().collect(),
150            addresses_to_probe,
151            prev_round_path_ids,
152        })
153    }
154
155    /// Mark a remote address to be reported back in a nat traversal continuation if the error is
156    /// considered spurious from a nat traversal point of view.
157    ///
158    /// Ids not present are silently ignored.
159    pub(crate) fn report_in_continuation(&mut self, id: VarInt, e: crate::PathError) {
160        match e {
161            crate::PathError::MaxPathIdReached | crate::PathError::RemoteCidsExhausted => {
162                if let Some((_address, report_in_continuation)) = self.remote_addresses.get_mut(&id)
163                {
164                    *report_in_continuation = true;
165                }
166            }
167            _ => {}
168        }
169    }
170
171    /// Returns an address that needs to be probed, if any.
172    ///
173    /// The address will not be returned twice unless marked as such again with
174    /// [`Self::report_in_continuation`].
175    pub(crate) fn continue_nat_traversal_round(&mut self) -> Option<(VarInt, IpPort)> {
176        // this being random depends on iteration not returning always on the same order
177        let (id, (address, report_in_continuation)) = self
178            .remote_addresses
179            .iter_mut()
180            .find(|(_id, (_addr, report))| *report)?;
181        *report_in_continuation = false;
182        Some((*id, *address))
183    }
184
185    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
186    /// advertised addresses.
187    pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
188        self.round_path_ids = path_ids;
189    }
190
191    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
192    /// advertised addresses.
193    pub(crate) fn add_round_path_id(&mut self, path_id: PathId) {
194        self.round_path_ids.push(path_id);
195    }
196
197    /// Adds an address to the remote set
198    ///
199    /// On success returns the address if it was new to the set. It will error when the set has no
200    /// capacity for the address.
201    pub(crate) fn add_remote_address(
202        &mut self,
203        add_addr: AddAddress,
204    ) -> Result<Option<SocketAddr>, Error> {
205        let AddAddress { seq_no, ip, port } = add_addr;
206        let address = (ip, port);
207        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
208        match self.remote_addresses.entry(seq_no) {
209            Entry::Occupied(mut occupied_entry) => {
210                let is_update = occupied_entry.get().0 != address;
211                if is_update {
212                    occupied_entry.insert((address, false));
213                }
214                // The value might be different. This should not happen, but we assume that the new
215                // address is more recent than the previous, and thus worth updating
216                Ok(is_update.then_some(address.into()))
217            }
218            Entry::Vacant(vacant_entry) if allow_new => {
219                vacant_entry.insert((address, false));
220                Ok(Some(address.into()))
221            }
222            _ => Err(Error::TooManyAddresses),
223        }
224    }
225
226    /// Removes an address from the remote set
227    ///
228    /// Returns whether the address was present.
229    pub(crate) fn remove_remote_address(
230        &mut self,
231        remove_addr: RemoveAddress,
232    ) -> Option<SocketAddr> {
233        self.remote_addresses
234            .remove(&remove_addr.seq_no)
235            .map(|(address, _report_in_continuation)| address.into())
236    }
237
238    /// Checks that a received remote address is valid
239    ///
240    /// An address is valid as long as it does not change the value of a known address id.
241    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
242        match self.remote_addresses.get(&add_addr.seq_no) {
243            None => true,
244            Some((existing, _)) => existing == &add_addr.ip_port(),
245        }
246    }
247
248    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
249        self.remote_addresses
250            .values()
251            .map(|(address, _report_in_continuation)| (*address).into())
252            .collect()
253    }
254}
255
256#[derive(Debug)]
257pub(crate) struct ServerState {
258    /// Max number of remote addresses we allow.
259    ///
260    /// This is set by the local endpoint.
261    max_remote_addresses: usize,
262    /// Max number of local addresses allowed.
263    ///
264    /// This is set by the remote endpoint.
265    max_local_addresses: usize,
266    /// Candidate addresses the server reports as potentially reachable, to use for nat
267    /// traversal attempts.
268    local_addresses: FxHashMap<IpPort, VarInt>,
269    /// The next id to use for local addresses sent to the client.
270    next_local_addr_id: VarInt,
271    /// Current nat holepunching round
272    ///
273    /// Servers keep track of the client's most recent round and cancel probing related to previous
274    /// rounds.
275    round: VarInt,
276    /// Addresses to which PATH_CHALLENGES need to be sent.
277    pending_probes: FxHashSet<IpPort>,
278    /// Sent PATH_CHALLENGES for this round.
279    ///
280    /// This is used to validate the remotes assigned to each token.
281    active_probes: IntMap<u64, IpPort>,
282}
283
284impl ServerState {
285    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
286        Self {
287            max_remote_addresses,
288            max_local_addresses,
289            local_addresses: Default::default(),
290            next_local_addr_id: Default::default(),
291            round: Default::default(),
292            pending_probes: Default::default(),
293            active_probes: Default::default(),
294        }
295    }
296
297    fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
298        let allow_new = self.local_addresses.len() < self.max_local_addresses;
299        match self.local_addresses.entry(address) {
300            Entry::Occupied(_) => Ok(None),
301            Entry::Vacant(vacant_entry) if allow_new => {
302                let id = self.next_local_addr_id;
303                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
304                vacant_entry.insert(id);
305                Ok(Some(AddAddress::new(address, id)))
306            }
307            _ => Err(Error::TooManyAddresses),
308        }
309    }
310
311    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
312        self.local_addresses.remove(address).map(RemoveAddress::new)
313    }
314
315    /// Handles a received [`ReachOut`].
316    ///
317    /// It returns the token that should be sent in response to this frame as a challenge, and
318    /// whether this starts a new nat traversal round.
319    ///
320    /// If this frame was ignored, it returns `None`.
321    pub(crate) fn handle_reach_out(&mut self, reach_out: ReachOut) -> Result<(), Error> {
322        let ReachOut { round, ip, port } = reach_out;
323
324        if round < self.round {
325            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
326            return Ok(());
327        }
328
329        if round > self.round {
330            self.round = round;
331            self.pending_probes.clear();
332            // TODO(@divma): This log is here because I'm not sure if dropping the challenges
333            // without further interaction with the connection is going to cause issues.
334            for (token, remote) in self.active_probes.drain() {
335                let remote: SocketAddr = remote.into();
336                trace!(token=format!("{:08x}", token), %remote, "dropping nat traversal challenge pending response");
337            }
338        } else if self.pending_probes.len() >= self.max_remote_addresses {
339            return Err(Error::TooManyAddresses);
340        }
341        self.pending_probes.insert((ip, port));
342        Ok(())
343    }
344
345    pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
346        self.pending_probes
347            .iter()
348            .next()
349            .copied()
350            .map(|remote| ServerProbing {
351                remote,
352                pending_probes: &mut self.pending_probes,
353                active_probes: &mut self.active_probes,
354            })
355    }
356}
357
358pub(crate) struct ServerProbing<'a> {
359    remote: IpPort,
360    pending_probes: &'a mut FxHashSet<IpPort>,
361    active_probes: &'a mut IntMap<u64, IpPort>,
362}
363
364impl<'a> ServerProbing<'a> {
365    pub(crate) fn finish(self, token: u64) {
366        self.pending_probes.remove(&self.remote);
367        self.active_probes.insert(token, self.remote);
368    }
369
370    pub(crate) fn remote(&self) -> SocketAddr {
371        self.remote.into()
372    }
373}
374
375impl State {
376    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
377        match side {
378            Side::Client => Self::ClientSide(ClientState::new(
379                max_remote_addresses.into(),
380                max_local_addresses.into(),
381            )),
382            Side::Server => Self::ServerSide(ServerState::new(
383                max_remote_addresses.into(),
384                max_local_addresses.into(),
385            )),
386        }
387    }
388
389    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
390        match self {
391            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
392            Self::ClientSide(client_side) => Ok(client_side),
393            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
394        }
395    }
396
397    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
398        match self {
399            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
400            Self::ClientSide(client_side) => Ok(client_side),
401            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
402        }
403    }
404
405    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
406        match self {
407            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
408            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
409            Self::ServerSide(server_side) => Ok(server_side),
410        }
411    }
412
413    /// Adds a local address to use for nat traversal.
414    ///
415    /// When this endpoint is the server within the connection, these addresses will be sent to the
416    /// client in add address frames. For clients, these addresses will be sent in reach out frames
417    /// when nat traversal attempts are initiated.
418    ///
419    /// If a frame should be sent, it is returned.
420    pub(crate) fn add_local_address(
421        &mut self,
422        address: SocketAddr,
423    ) -> Result<Option<AddAddress>, Error> {
424        match self {
425            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
426            Self::ClientSide(client_state) => {
427                client_state.add_local_address((address.ip(), address.port()))?;
428                Ok(None)
429            }
430            Self::ServerSide(server_state) => {
431                server_state.add_local_address((address.ip(), address.port()))
432            }
433        }
434    }
435
436    /// Removes a local address from the advertised set for nat traversal.
437    ///
438    /// When this endpoint is the server, removed addresses must be reported with remove address
439    /// frames. Clients will simply stop reporting these addresses in reach out frames.
440    ///
441    /// If a frame should be sent, it is returned.
442    pub(crate) fn remove_local_address(
443        &mut self,
444        address: SocketAddr,
445    ) -> Result<Option<RemoveAddress>, Error> {
446        let address = &(address.ip(), address.port());
447        match self {
448            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
449            Self::ClientSide(client_state) => {
450                client_state.remove_local_address(address);
451                Ok(None)
452            }
453            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(address)),
454        }
455    }
456
457    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
458        match self {
459            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
460            Self::ClientSide(client_state) => Ok(client_state
461                .local_addresses
462                .iter()
463                .copied()
464                .map(Into::into)
465                .collect()),
466            Self::ServerSide(server_state) => Ok(server_state
467                .local_addresses
468                .keys()
469                .copied()
470                .map(Into::into)
471                .collect()),
472        }
473    }
474}
475
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Builds a `ReachOut` frame advertising `ip:port` for the given round.
    fn reach_out(round: u32, ip: Ipv4Addr, port: u16) -> ReachOut {
        ReachOut {
            round: round.into(),
            ip: ip.into(),
            port,
        }
    }

    /// Two addresses reached out in the same round are queued and then moved to the active set
    /// one probe at a time.
    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, Ipv4Addr::LOCALHOST, 1))
            .unwrap();
        state
            .handle_reach_out(reach_out(1, Ipv4Addr::new(1, 1, 1, 1), 2))
            .unwrap();

        assert_eq!(state.pending_probes.len(), 2);
        assert_eq!(state.active_probes.len(), 0);

        let probe = state.next_probe().unwrap();
        probe.finish(1);
        let probe = state.next_probe().unwrap();
        probe.finish(2);

        assert!(state.next_probe().is_none());
        assert_eq!(state.pending_probes.len(), 0);
        assert_eq!(state.active_probes.len(), 2);
    }

    /// Frames for older rounds are ignored, a new address beyond the limit is rejected, and a
    /// newer round discards the probes of the previous one.
    #[test]
    fn test_reach_out_round_and_capacity() {
        let mut state = ServerState::new(1, 1);

        state
            .handle_reach_out(reach_out(2, Ipv4Addr::LOCALHOST, 1))
            .unwrap();
        assert_eq!(state.pending_probes.len(), 1);

        // A different address in the same round exceeds the limit of one.
        let overflow = state.handle_reach_out(reach_out(2, Ipv4Addr::new(1, 1, 1, 1), 2));
        assert!(matches!(overflow, Err(Error::TooManyAddresses)));
        assert_eq!(state.pending_probes.len(), 1);

        // A frame from an earlier round is accepted but changes nothing.
        state
            .handle_reach_out(reach_out(1, Ipv4Addr::new(2, 2, 2, 2), 3))
            .unwrap();
        assert_eq!(state.pending_probes.len(), 1);

        // Send the probe, then start a newer round: both sets are reset.
        state.next_probe().unwrap().finish(7);
        assert_eq!(state.active_probes.len(), 1);
        state
            .handle_reach_out(reach_out(3, Ipv4Addr::new(3, 3, 3, 3), 4))
            .unwrap();
        assert_eq!(state.active_probes.len(), 0);
        assert_eq!(state.pending_probes.len(), 1);
        let expected: IpPort = (Ipv4Addr::new(3, 3, 3, 3).into(), 4);
        assert!(state.pending_probes.contains(&expected));
    }
}
513}