iroh_quinn_proto/
iroh_hp.rs

1//! iroh NAT Traversal
2
3use std::{
4    collections::hash_map::Entry,
5    net::{IpAddr, SocketAddr},
6};
7
8use identity_hash::IntMap;
9use rustc_hash::{FxHashMap, FxHashSet};
10use tracing::trace;
11
12use crate::{
13    PathId, Side, VarInt,
14    frame::{AddAddress, ReachOut, RemoveAddress},
15};
16
/// An IP address paired with a port: the form in which nat traversal addresses are stored.
type IpPort = (IpAddr, u16);
18
19/// Errors that the nat traversal state might encounter.
20#[derive(Debug, thiserror::Error)]
21pub enum Error {
22    /// An endpoint (local or remote) tried to add too many addresses to their advertised set
23    #[error("Tried to add too many addresses to their advertised set")]
24    TooManyAddresses,
25    /// The operation is not allowed for this endpoint's connection side
26    #[error("Not allowed for this endpoint's connection side")]
27    WrongConnectionSide,
28    /// The extension was not negotiated
29    #[error("Iroh's nat traversal was not negotiated")]
30    ExtensionNotNegotiated,
31    /// Not enough addresses to complete the operation
32    #[error("Not enough addresses")]
33    NotEnoughAddresses,
34    /// Nat traversal attempt failed due to a multipath error
35    #[error("Failed to establish paths {0}")]
36    Multipath(super::PathError),
37    /// Attempted to initiate NAT traversal on a closed, or closing connection.
38    #[error("The connection is already closed")]
39    Closed,
40}
41
42pub(crate) struct NatTraversalRound {
43    /// Sequence number to use for the new reach out frames
44    pub(crate) new_round: VarInt,
45    /// Addresses to use to send reach out frames
46    pub(crate) reach_out_at: Vec<(IpAddr, u16)>,
47    /// Remotes to probe by attempting to open new paths
48    pub(crate) addresses_to_probe: Vec<(IpAddr, u16)>,
49    /// [`PathId`]s of the cancelled round
50    pub(crate) prev_round_path_ids: Vec<PathId>,
51}
52
/// Event emitted when the client receives `ADD_ADDRESS` or `REMOVE_ADDRESS` frames.
///
/// `PartialEq`/`Eq` are derived so events can be compared directly, e.g. in tests or when
/// deduplicating notifications.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// An `ADD_ADDRESS` frame was received for this address.
    AddressAdded(SocketAddr),
    /// A `REMOVE_ADDRESS` frame was received for this address.
    AddressRemoved(SocketAddr),
}
61
62/// State kept for Iroh's nat traversal
63#[derive(Debug, Default)]
64pub(crate) enum State {
65    #[default]
66    NotNegotiated,
67    ClientSide(ClientState),
68    ServerSide(ServerState),
69}
70
71#[derive(Debug)]
72pub(crate) struct ClientState {
73    /// Max number of remote addresses we allow
74    ///
75    /// This is set by the local endpoint.
76    max_remote_addresses: usize,
77    /// Max number of local addresses allowed
78    ///
79    /// This is set by the remote endpoint.
80    max_local_addresses: usize,
81    /// Candidate addresses the remote server reports as potentially reachable, to use for nat
82    /// traversal attempts.
83    remote_addresses: FxHashMap<VarInt, (IpAddr, u16)>,
84    /// Candidate addresses the local client reports as potentially reachable, to use for nat
85    /// traversal attempts. Always canonical.
86    local_addresses: FxHashSet<(IpAddr, u16)>,
87    /// Current nat holepunching round.
88    round: VarInt,
89    /// [`PathId`]s used to probe remotes assigned to this round.
90    round_path_ids: Vec<PathId>,
91}
92
93impl ClientState {
94    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
95        Self {
96            max_remote_addresses,
97            max_local_addresses,
98            remote_addresses: Default::default(),
99            local_addresses: Default::default(),
100            round: Default::default(),
101            round_path_ids: Default::default(),
102        }
103    }
104
105    fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
106        if self.local_addresses.len() < self.max_local_addresses {
107            self.local_addresses.insert(address);
108            Ok(())
109        } else if self.local_addresses.contains(&address) {
110            // at capacity, but the address is known, no issues here
111            Ok(())
112        } else {
113            // at capacity and the address is new
114            Err(Error::TooManyAddresses)
115        }
116    }
117
118    fn remove_local_address(&mut self, address: &IpPort) {
119        self.local_addresses.remove(address);
120    }
121
122    /// Initiates a new nat traversal round.
123    ///
124    /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT`
125    /// frames, and initiating probing of the known remote addresses. When a new round is
126    /// initiated, the previous one is cancelled, and paths that have not been opened should be
127    /// closed.
128    pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result<NatTraversalRound, Error> {
129        if self.local_addresses.is_empty() {
130            return Err(Error::NotEnoughAddresses);
131        }
132
133        let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
134        self.round = self.round.saturating_add(1u8);
135
136        Ok(NatTraversalRound {
137            new_round: self.round,
138            reach_out_at: self.local_addresses.iter().copied().collect(),
139            addresses_to_probe: self.remote_addresses.values().copied().collect(),
140            prev_round_path_ids,
141        })
142    }
143
144    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
145    /// advertised addresses.
146    pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
147        self.round_path_ids = path_ids;
148    }
149
150    /// Adds an address to the remote set
151    ///
152    /// On success returns the address if it was new to the set. It will error when the set has no
153    /// capacity for the address.
154    pub(crate) fn add_remote_address(
155        &mut self,
156        add_addr: AddAddress,
157    ) -> Result<Option<SocketAddr>, Error> {
158        let AddAddress { seq_no, ip, port } = add_addr;
159        let address = (ip, port);
160        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
161        match self.remote_addresses.entry(seq_no) {
162            Entry::Occupied(mut occupied_entry) => {
163                let old_value = occupied_entry.insert(address);
164                // The value might be different. This should not happen, but we assume that the new
165                // address is more recent than the previous, and thus worth updating
166                Ok((address != old_value).then_some(address.into()))
167            }
168            Entry::Vacant(vacant_entry) if allow_new => {
169                vacant_entry.insert(address);
170                Ok(Some(address.into()))
171            }
172            _ => Err(Error::TooManyAddresses),
173        }
174    }
175
176    /// Removes an address from the remote set
177    ///
178    /// Returns whether the address was present.
179    pub(crate) fn remove_remote_address(
180        &mut self,
181        remove_addr: RemoveAddress,
182    ) -> Option<SocketAddr> {
183        self.remote_addresses
184            .remove(&remove_addr.seq_no)
185            .map(Into::into)
186    }
187
188    /// Checks that a received remote address is valid
189    ///
190    /// An address is valid as long as it does not change the value of a known address id.
191    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
192        let existing = self.remote_addresses.get(&add_addr.seq_no);
193        existing.is_none() || existing == Some(&add_addr.ip_port())
194    }
195
196    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
197        self.remote_addresses
198            .values()
199            .copied()
200            .map(Into::into)
201            .collect()
202    }
203}
204
205#[derive(Debug)]
206pub(crate) struct ServerState {
207    /// Max number of remote addresses we allow.
208    ///
209    /// This is set by the local endpoint.
210    max_remote_addresses: usize,
211    /// Max number of local addresses allowed.
212    ///
213    /// This is set by the remote endpoint.
214    max_local_addresses: usize,
215    /// Candidate addresses the server reports as potentially reachable, to use for nat
216    /// traversal attempts.
217    local_addresses: FxHashMap<IpPort, VarInt>,
218    /// The next id to use for local addresses sent to the client.
219    next_local_addr_id: VarInt,
220    /// Current nat holepunching round
221    ///
222    /// Servers keep track of the client's most recent round and cancel probing related to previous
223    /// rounds.
224    round: VarInt,
225    /// Addresses to which PATH_CHALLENGES need to be sent.
226    pending_probes: FxHashSet<IpPort>,
227    /// Sent PATH_CHALLENGES for this round.
228    ///
229    /// This is used to validate the remotes assigned to each token.
230    active_probes: IntMap<u64, IpPort>,
231}
232
233impl ServerState {
234    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
235        Self {
236            max_remote_addresses,
237            max_local_addresses,
238            local_addresses: Default::default(),
239            next_local_addr_id: Default::default(),
240            round: Default::default(),
241            pending_probes: Default::default(),
242            active_probes: Default::default(),
243        }
244    }
245
246    fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
247        let allow_new = self.local_addresses.len() < self.max_local_addresses;
248        match self.local_addresses.entry(address) {
249            Entry::Occupied(_) => Ok(None),
250            Entry::Vacant(vacant_entry) if allow_new => {
251                let id = self.next_local_addr_id;
252                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
253                vacant_entry.insert(id);
254                Ok(Some(AddAddress::new(address, id)))
255            }
256            _ => Err(Error::TooManyAddresses),
257        }
258    }
259
260    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
261        self.local_addresses.remove(address).map(RemoveAddress::new)
262    }
263
264    /// Handles a received [`ReachOut`].
265    ///
266    /// It returns the token that should be sent in response to this frame as a challenge, and
267    /// whether this starts a new nat traversal round.
268    ///
269    /// If this frame was ignored, it returns `None`.
270    pub(crate) fn handle_reach_out(&mut self, reach_out: ReachOut) -> Result<(), Error> {
271        let ReachOut { round, ip, port } = reach_out;
272
273        if round < self.round {
274            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
275            return Ok(());
276        }
277
278        if round > self.round {
279            self.round = round;
280            self.pending_probes.clear();
281            // TODO(@divma): This log is here because I'm not sure if dropping the challenges
282            // without further interaction with the connection is going to cause issues.
283            for (token, remote) in self.active_probes.drain() {
284                let remote: SocketAddr = remote.into();
285                trace!(token=format!("{:08x}", token), %remote, "dropping nat traversal challenge pending response");
286            }
287        } else if self.pending_probes.len() >= self.max_remote_addresses {
288            return Err(Error::TooManyAddresses);
289        }
290        self.pending_probes.insert((ip, port));
291        Ok(())
292    }
293
294    pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
295        self.pending_probes
296            .iter()
297            .next()
298            .copied()
299            .map(|remote| ServerProbing {
300                remote,
301                pending_probes: &mut self.pending_probes,
302                active_probes: &mut self.active_probes,
303            })
304    }
305}
306
307pub(crate) struct ServerProbing<'a> {
308    remote: IpPort,
309    pending_probes: &'a mut FxHashSet<IpPort>,
310    active_probes: &'a mut IntMap<u64, IpPort>,
311}
312
313impl<'a> ServerProbing<'a> {
314    pub(crate) fn finish(self, token: u64) {
315        self.pending_probes.remove(&self.remote);
316        self.active_probes.insert(token, self.remote);
317    }
318
319    pub(crate) fn remote(&self) -> SocketAddr {
320        self.remote.into()
321    }
322}
323
324impl State {
325    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
326        match side {
327            Side::Client => Self::ClientSide(ClientState::new(
328                max_remote_addresses.into(),
329                max_local_addresses.into(),
330            )),
331            Side::Server => Self::ServerSide(ServerState::new(
332                max_remote_addresses.into(),
333                max_local_addresses.into(),
334            )),
335        }
336    }
337
338    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
339        match self {
340            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
341            Self::ClientSide(client_side) => Ok(client_side),
342            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
343        }
344    }
345
346    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
347        match self {
348            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
349            Self::ClientSide(client_side) => Ok(client_side),
350            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
351        }
352    }
353
354    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
355        match self {
356            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
357            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
358            Self::ServerSide(server_side) => Ok(server_side),
359        }
360    }
361
362    /// Adds a local address to use for nat traversal.
363    ///
364    /// When this endpoint is the server within the connection, these addresses will be sent to the
365    /// client in add address frames. For clients, these addresses will be sent in reach out frames
366    /// when nat traversal attempts are initiated.
367    ///
368    /// If a frame should be sent, it is returned.
369    pub(crate) fn add_local_address(
370        &mut self,
371        address: SocketAddr,
372    ) -> Result<Option<AddAddress>, Error> {
373        match self {
374            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
375            Self::ClientSide(client_state) => {
376                client_state.add_local_address((address.ip(), address.port()))?;
377                Ok(None)
378            }
379            Self::ServerSide(server_state) => {
380                server_state.add_local_address((address.ip(), address.port()))
381            }
382        }
383    }
384
385    /// Removes a local address from the advertised set for nat traversal.
386    ///
387    /// When this endpoint is the server, removed addresses must be reported with remove address
388    /// frames. Clients will simply stop reporting these addresses in reach out frames.
389    ///
390    /// If a frame should be sent, it is returned.
391    pub(crate) fn remove_local_address(
392        &mut self,
393        address: SocketAddr,
394    ) -> Result<Option<RemoveAddress>, Error> {
395        let address = &(address.ip(), address.port());
396        match self {
397            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
398            Self::ClientSide(client_state) => {
399                client_state.remove_local_address(address);
400                Ok(None)
401            }
402            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(address)),
403        }
404    }
405
406    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
407        match self {
408            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
409            Self::ClientSide(client_state) => Ok(client_state
410                .local_addresses
411                .iter()
412                .copied()
413                .map(Into::into)
414                .collect()),
415            Self::ServerSide(server_state) => Ok(server_state
416                .local_addresses
417                .keys()
418                .copied()
419                .map(Into::into)
420                .collect()),
421        }
422    }
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `REACH_OUT` frame for the loopback address with the given round and port.
    fn reach_out(round: u32, port: u16) -> ReachOut {
        ReachOut {
            round: round.into(),
            ip: std::net::Ipv4Addr::LOCALHOST.into(),
            port,
        }
    }

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(ReachOut {
                round: 1u32.into(),
                ip: std::net::Ipv4Addr::LOCALHOST.into(),
                port: 1,
            })
            .unwrap();

        state
            .handle_reach_out(ReachOut {
                round: 1u32.into(),
                ip: "1.1.1.1".parse().unwrap(),
                port: 2,
            })
            .unwrap();

        assert_eq!(state.pending_probes.len(), 2);
        assert_eq!(state.active_probes.len(), 0);

        let probe = state.next_probe().unwrap();
        probe.finish(1);
        let probe = state.next_probe().unwrap();
        probe.finish(2);

        assert!(state.next_probe().is_none());
        assert_eq!(state.pending_probes.len(), 0);
        assert_eq!(state.active_probes.len(), 2);
    }

    #[test]
    fn test_server_state_rounds() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(2, 1)).unwrap();
        state.next_probe().unwrap().finish(1);
        assert_eq!(state.pending_probes.len(), 0);
        assert_eq!(state.active_probes.len(), 1);

        // Frames of an older round are ignored.
        state.handle_reach_out(reach_out(1, 2)).unwrap();
        assert_eq!(state.pending_probes.len(), 0);
        assert_eq!(state.active_probes.len(), 1);

        // A newer round cancels everything from the previous one.
        state.handle_reach_out(reach_out(3, 3)).unwrap();
        assert_eq!(state.pending_probes.len(), 1);
        assert_eq!(state.active_probes.len(), 0);
    }

    #[test]
    fn test_server_state_capacity() {
        let mut state = ServerState::new(2, 2);

        state.handle_reach_out(reach_out(1, 1)).unwrap();
        state.handle_reach_out(reach_out(1, 2)).unwrap();
        assert!(matches!(
            state.handle_reach_out(reach_out(1, 3)),
            Err(Error::TooManyAddresses)
        ));
        assert_eq!(state.pending_probes.len(), 2);
    }
}