iroh_quinn_proto/
iroh_hp.rs

1//! iroh NAT Traversal
2
3use std::{
4    collections::hash_map::Entry,
5    net::{IpAddr, SocketAddr},
6};
7
8use identity_hash::IntMap;
9use rustc_hash::{FxHashMap, FxHashSet};
10use tracing::trace;
11
12use crate::{
13    PathId, Side, VarInt,
14    frame::{AddAddress, ReachOut, RemoveAddress},
15};
16
/// An IP address paired with a port.
///
/// This is the key type of the nat traversal address sets; it converts into a [`SocketAddr`]
/// with `Into`/`From`.
type IpPort = (IpAddr, u16);
18
19/// Errors that the nat traversal state might encounter.
20#[derive(Debug, thiserror::Error)]
21pub enum Error {
22    /// An endpoint (local or remote) tried to add too many addresses to their advertised set
23    #[error("Tried to add too many addresses to their advertised set")]
24    TooManyAddresses,
25    /// The operation is not allowed for this endpoint's connection side
26    #[error("Not allowed for this endpoint's connection side")]
27    WrongConnectionSide,
28    /// The extension was not negotiated
29    #[error("Iroh's nat traversal was not negotiated")]
30    ExtensionNotNegotiated,
31    /// Not enough addresses to complete the operation
32    #[error("Not enough addresses")]
33    NotEnoughAddresses,
34    /// Nat traversal attempt failed due to a multipath error
35    #[error("Failed to establish paths {0}")]
36    Multipath(super::PathError),
37}
38
39pub(crate) struct NatTraversalRound {
40    /// Sequence number to use for the new reach out frames
41    pub(crate) new_round: VarInt,
42    /// Addresses to use to send reach out frames
43    pub(crate) reach_out_at: Vec<(IpAddr, u16)>,
44    /// Remotes to probe by attempting to open new paths
45    pub(crate) addresses_to_probe: Vec<(IpAddr, u16)>,
46    /// [`PathId`]s of the cancelled round
47    pub(crate) prev_round_path_ids: Vec<PathId>,
48}
49
/// Notification raised on the client when the server changes its advertised address set.
///
/// One event is produced per received ADD_ADDRESS or REMOVE_ADDRESS frame.
#[derive(Clone, Debug)]
pub enum Event {
    /// An ADD_ADDRESS frame advertised this address.
    AddressAdded(SocketAddr),
    /// A REMOVE_ADDRESS frame withdrew this address.
    AddressRemoved(SocketAddr),
}
58
59/// State kept for Iroh's nat traversal
60#[derive(Debug, Default)]
61pub(crate) enum State {
62    #[default]
63    NotNegotiated,
64    ClientSide(ClientState),
65    ServerSide(ServerState),
66}
67
68#[derive(Debug)]
69pub(crate) struct ClientState {
70    /// Max number of remote addresses we allow
71    ///
72    /// This is set by the local endpoint.
73    max_remote_addresses: usize,
74    /// Max number of local addresses allowed
75    ///
76    /// This is set by the remote endpoint.
77    max_local_addresses: usize,
78    /// Candidate addresses the remote server reports as potentially reachable, to use for nat
79    /// traversal attempts.
80    remote_addresses: FxHashMap<VarInt, (IpAddr, u16)>,
81    /// Candidate addresses the local client reports as potentially reachable, to use for nat
82    /// traversal attempts. Always canonical.
83    local_addresses: FxHashSet<(IpAddr, u16)>,
84    /// Current nat holepunching round.
85    round: VarInt,
86    /// [`PathId`]s used to probe remotes assigned to this round.
87    round_path_ids: Vec<PathId>,
88}
89
90impl ClientState {
91    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
92        Self {
93            max_remote_addresses,
94            max_local_addresses,
95            remote_addresses: Default::default(),
96            local_addresses: Default::default(),
97            round: Default::default(),
98            round_path_ids: Default::default(),
99        }
100    }
101
102    fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
103        if self.local_addresses.len() < self.max_local_addresses {
104            self.local_addresses.insert(address);
105            Ok(())
106        } else if self.local_addresses.contains(&address) {
107            // at capacity, but the address is known, no issues here
108            Ok(())
109        } else {
110            // at capacity and the address is new
111            Err(Error::TooManyAddresses)
112        }
113    }
114
115    fn remove_local_address(&mut self, address: &IpPort) {
116        self.local_addresses.remove(address);
117    }
118
119    /// Initiates a new nat traversal round.
120    ///
121    /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT`
122    /// frames, and initiating probing of the known remote addresses. When a new round is
123    /// initiated, the previous one is cancelled, and paths that have not been opened should be
124    /// closed.
125    pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result<NatTraversalRound, Error> {
126        if self.local_addresses.is_empty() {
127            return Err(Error::NotEnoughAddresses);
128        }
129
130        let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
131        self.round = self.round.saturating_add(1u8);
132
133        Ok(NatTraversalRound {
134            new_round: self.round,
135            reach_out_at: self.local_addresses.iter().copied().collect(),
136            addresses_to_probe: self.remote_addresses.values().copied().collect(),
137            prev_round_path_ids,
138        })
139    }
140
141    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
142    /// advertised addresses.
143    pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
144        self.round_path_ids = path_ids;
145    }
146
147    /// Adds an address to the remote set
148    ///
149    /// On success returns the address if it was new to the set. It will error when the set has no
150    /// capacity for the address.
151    pub(crate) fn add_remote_address(
152        &mut self,
153        add_addr: AddAddress,
154    ) -> Result<Option<SocketAddr>, Error> {
155        let AddAddress { seq_no, ip, port } = add_addr;
156        let address = (ip, port);
157        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
158        match self.remote_addresses.entry(seq_no) {
159            Entry::Occupied(mut occupied_entry) => {
160                let old_value = occupied_entry.insert(address);
161                // The value might be different. This should not happen, but we assume that the new
162                // address is more recent than the previous, and thus worth updating
163                Ok((address != old_value).then_some(address.into()))
164            }
165            Entry::Vacant(vacant_entry) if allow_new => {
166                vacant_entry.insert(address);
167                Ok(Some(address.into()))
168            }
169            _ => Err(Error::TooManyAddresses),
170        }
171    }
172
173    /// Removes an address from the remote set
174    ///
175    /// Returns whether the address was present.
176    pub(crate) fn remove_remote_address(
177        &mut self,
178        remove_addr: RemoveAddress,
179    ) -> Option<SocketAddr> {
180        self.remote_addresses
181            .remove(&remove_addr.seq_no)
182            .map(Into::into)
183    }
184
185    /// Checks that a received remote address is valid
186    ///
187    /// An address is valid as long as it does not change the value of a known address id.
188    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
189        let existing = self.remote_addresses.get(&add_addr.seq_no);
190        existing.is_none() || existing == Some(&add_addr.ip_port())
191    }
192
193    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
194        self.remote_addresses
195            .values()
196            .copied()
197            .map(Into::into)
198            .collect()
199    }
200}
201
202#[derive(Debug)]
203pub(crate) struct ServerState {
204    /// Max number of remote addresses we allow.
205    ///
206    /// This is set by the local endpoint.
207    max_remote_addresses: usize,
208    /// Max number of local addresses allowed.
209    ///
210    /// This is set by the remote endpoint.
211    max_local_addresses: usize,
212    /// Candidate addresses the server reports as potentially reachable, to use for nat
213    /// traversal attempts.
214    local_addresses: FxHashMap<IpPort, VarInt>,
215    /// The next id to use for local addresses sent to the client.
216    next_local_addr_id: VarInt,
217    /// Current nat holepunching round
218    ///
219    /// Servers keep track of the client's most recent round and cancel probing related to previous
220    /// rounds.
221    round: VarInt,
222    /// Addresses to which PATH_CHALLENGES need to be sent.
223    pending_probes: FxHashSet<IpPort>,
224    /// Sent PATH_CHALLENGES for this round.
225    ///
226    /// This is used to validate the remotes assigned to each token.
227    active_probes: IntMap<u64, IpPort>,
228}
229
230impl ServerState {
231    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
232        Self {
233            max_remote_addresses,
234            max_local_addresses,
235            local_addresses: Default::default(),
236            next_local_addr_id: Default::default(),
237            round: Default::default(),
238            pending_probes: Default::default(),
239            active_probes: Default::default(),
240        }
241    }
242
243    fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
244        let allow_new = self.local_addresses.len() < self.max_local_addresses;
245        match self.local_addresses.entry(address) {
246            Entry::Occupied(_) => Ok(None),
247            Entry::Vacant(vacant_entry) if allow_new => {
248                let id = self.next_local_addr_id;
249                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
250                vacant_entry.insert(id);
251                Ok(Some(AddAddress::new(address, id)))
252            }
253            _ => Err(Error::TooManyAddresses),
254        }
255    }
256
257    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
258        self.local_addresses.remove(address).map(RemoveAddress::new)
259    }
260
261    /// Handles a received [`ReachOut`].
262    ///
263    /// It returns the token that should be sent in response to this frame as a challenge, and
264    /// whether this starts a new nat traversal round.
265    ///
266    /// If this frame was ignored, it returns `None`.
267    pub(crate) fn handle_reach_out(&mut self, reach_out: ReachOut) -> Result<(), Error> {
268        let ReachOut { round, ip, port } = reach_out;
269
270        if round < self.round {
271            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
272            return Ok(());
273        }
274
275        if round > self.round {
276            self.pending_probes.clear();
277            // TODO(@divma): This log is here because I'm not sure if dropping the challenges
278            // without further interaction with the connection is going to cause issues.
279            for (token, remote) in self.active_probes.drain() {
280                let remote: SocketAddr = remote.into();
281                trace!(token=format!("{:08x}", token), %remote, "dropping nat traversal challenge pending response");
282            }
283        } else if self.pending_probes.len() >= self.max_remote_addresses {
284            return Err(Error::TooManyAddresses);
285        }
286        self.pending_probes.insert((ip, port));
287        Ok(())
288    }
289
290    pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
291        self.pending_probes
292            .iter()
293            .next()
294            .copied()
295            .map(|remote| ServerProbing {
296                remote,
297                pending_probes: &mut self.pending_probes,
298                active_probes: &mut self.active_probes,
299            })
300    }
301}
302
303pub(crate) struct ServerProbing<'a> {
304    remote: IpPort,
305    pending_probes: &'a mut FxHashSet<IpPort>,
306    active_probes: &'a mut IntMap<u64, IpPort>,
307}
308
309impl<'a> ServerProbing<'a> {
310    pub(crate) fn finish(self, token: u64) {
311        self.pending_probes.remove(&self.remote);
312        self.active_probes.insert(token, self.remote);
313    }
314
315    pub(crate) fn remote(&self) -> SocketAddr {
316        self.remote.into()
317    }
318}
319
320impl State {
321    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
322        match side {
323            Side::Client => Self::ClientSide(ClientState::new(
324                max_remote_addresses.into(),
325                max_local_addresses.into(),
326            )),
327            Side::Server => Self::ServerSide(ServerState::new(
328                max_remote_addresses.into(),
329                max_local_addresses.into(),
330            )),
331        }
332    }
333
334    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
335        match self {
336            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
337            Self::ClientSide(client_side) => Ok(client_side),
338            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
339        }
340    }
341
342    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
343        match self {
344            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
345            Self::ClientSide(client_side) => Ok(client_side),
346            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
347        }
348    }
349
350    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
351        match self {
352            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
353            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
354            Self::ServerSide(server_side) => Ok(server_side),
355        }
356    }
357
358    /// Adds a local address to use for nat traversal.
359    ///
360    /// When this endpoint is the server within the connection, these addresses will be sent to the
361    /// client in add address frames. For clients, these addresses will be sent in reach out frames
362    /// when nat traversal attempts are initiated.
363    ///
364    /// If a frame should be sent, it is returned.
365    pub(crate) fn add_local_address(
366        &mut self,
367        address: SocketAddr,
368    ) -> Result<Option<AddAddress>, Error> {
369        match self {
370            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
371            Self::ClientSide(client_state) => {
372                client_state.add_local_address((address.ip(), address.port()))?;
373                Ok(None)
374            }
375            Self::ServerSide(server_state) => {
376                server_state.add_local_address((address.ip(), address.port()))
377            }
378        }
379    }
380
381    /// Removes a local address from the advertised set for nat traversal.
382    ///
383    /// When this endpoint is the server, removed addresses must be reported with remove address
384    /// frames. Clients will simply stop reporting these addresses in reach out frames.
385    ///
386    /// If a frame should be sent, it is returned.
387    pub(crate) fn remove_local_address(
388        &mut self,
389        address: SocketAddr,
390    ) -> Result<Option<RemoveAddress>, Error> {
391        let address = &(address.ip(), address.port());
392        match self {
393            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
394            Self::ClientSide(client_state) => {
395                client_state.remove_local_address(address);
396                Ok(None)
397            }
398            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(address)),
399        }
400    }
401
402    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
403        match self {
404            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
405            Self::ClientSide(client_state) => Ok(client_state
406                .local_addresses
407                .iter()
408                .copied()
409                .map(Into::into)
410                .collect()),
411            Self::ServerSide(server_state) => Ok(server_state
412                .local_addresses
413                .keys()
414                .copied()
415                .map(Into::into)
416                .collect()),
417        }
418    }
419}