noq_proto/
n0_nat_traversal.rs

1//! n0's (<https://n0.computer>) NAT Traversal protocol implementation.
2
3use std::{
4    collections::hash_map::Entry,
5    net::{IpAddr, SocketAddr},
6};
7
8use rustc_hash::{FxHashMap, FxHashSet};
9use tracing::trace;
10
11use crate::{
12    PathId, Side, VarInt,
13    frame::{AddAddress, ReachOut, RemoveAddress},
14};
15
/// An IP address paired with a port; converts into a [`SocketAddr`] via `From`.
type IpPort = (IpAddr, u16);
17
18/// Errors that the nat traversal state might encounter.
19#[derive(Debug, thiserror::Error)]
20pub enum Error {
21    /// An endpoint (local or remote) tried to add too many addresses to their advertised set
22    #[error("Tried to add too many addresses to their advertised set")]
23    TooManyAddresses,
24    /// The operation is not allowed for this endpoint's connection side
25    #[error("Not allowed for this endpoint's connection side")]
26    WrongConnectionSide,
27    /// The extension was not negotiated
28    #[error("n0's nat traversal was not negotiated")]
29    ExtensionNotNegotiated,
30    /// Not enough addresses to complete the operation
31    #[error("Not enough addresses")]
32    NotEnoughAddresses,
33    /// Nat traversal attempt failed due to a multipath error
34    #[error("Failed to establish paths {0}")]
35    Multipath(super::PathError),
36    /// Attempted to initiate NAT traversal on a closed, or closing connection.
37    #[error("The connection is already closed")]
38    Closed,
39}
40
41pub(crate) struct NatTraversalRound {
42    /// Sequence number to use for the new reach out frames.
43    pub(crate) new_round: VarInt,
44    /// Addresses to use to send reach out frames.
45    pub(crate) reach_out_at: FxHashSet<IpPort>,
46    /// Remotes to probe by attempting to open new paths.
47    ///
48    /// The addresses include their Id, so that it can be used to signal these should be returned
49    /// in a nat traversal continuation by calling [`ClientState::report_in_continuation`].
50    ///
51    /// These are filtered and mapped to the IP family the local socket supports.
52    pub(crate) addresses_to_probe: Vec<(VarInt, IpPort)>,
53    /// [`PathId`]s of the cancelled round.
54    pub(crate) prev_round_path_ids: Vec<PathId>,
55}
56
/// Event emitted when the client receives ADD_ADDRESS or REMOVE_ADDRESS frames.
///
/// Events are plain values and can be compared, which makes them easy to assert on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// An ADD_ADDRESS frame was received, advertising this address.
    AddressAdded(SocketAddr),
    /// A REMOVE_ADDRESS frame was received, withdrawing this address.
    AddressRemoved(SocketAddr),
}
65
66/// State kept for n0's nat traversal
67#[derive(Debug, Default)]
68pub(crate) enum State {
69    #[default]
70    NotNegotiated,
71    ClientSide(ClientState),
72    ServerSide(ServerState),
73}
74
75#[derive(Debug)]
76pub(crate) struct ClientState {
77    /// Max number of remote addresses we allow
78    ///
79    /// This is set by the local endpoint.
80    max_remote_addresses: usize,
81    /// Max number of local addresses allowed
82    ///
83    /// This is set by the remote endpoint.
84    max_local_addresses: usize,
85    /// Candidate addresses the remote server reports as potentially reachable, to use for nat
86    /// traversal attempts.
87    ///
88    /// These are indexed by their advertised Id. For each address, whether the address should be
89    /// reported in nat traversal continuations is kept.
90    remote_addresses: FxHashMap<VarInt, (IpPort, bool)>,
91    /// Candidate addresses the local client reports as potentially reachable, to use for nat
92    /// traversal attempts.
93    local_addresses: FxHashSet<IpPort>,
94    /// Current nat traversal round.
95    round: VarInt,
96    /// [`PathId`]s used to probe remotes assigned to this round.
97    round_path_ids: Vec<PathId>,
98}
99
100impl ClientState {
101    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
102        Self {
103            max_remote_addresses,
104            max_local_addresses,
105            remote_addresses: Default::default(),
106            local_addresses: Default::default(),
107            round: Default::default(),
108            round_path_ids: Default::default(),
109        }
110    }
111
112    fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
113        if self.local_addresses.len() < self.max_local_addresses {
114            self.local_addresses.insert(address);
115            Ok(())
116        } else if self.local_addresses.contains(&address) {
117            // at capacity, but the address is known, no issues here
118            Ok(())
119        } else {
120            // at capacity and the address is new
121            Err(Error::TooManyAddresses)
122        }
123    }
124
125    fn remove_local_address(&mut self, address: &IpPort) {
126        self.local_addresses.remove(address);
127    }
128
129    /// Initiates a new nat traversal round.
130    ///
131    /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT`
132    /// frames, and initiating probing of the known remote addresses. When a new round is
133    /// initiated, the previous one is cancelled, and paths that have not been opened should be
134    /// closed.
135    ///
136    /// `ipv6` indicates if the connection runs on a socket that supports IPv6. If so, then all
137    /// addresses returned in [`NatTraversalRound`] will be IPv6 addresses (and IPv4-mapped IPv6
138    /// addresses if necessary). Otherwise they're all IPv4 addresses.
139    /// See also [`map_to_local_socket_family`].
140    pub(crate) fn initiate_nat_traversal_round(
141        &mut self,
142        ipv6: bool,
143    ) -> Result<NatTraversalRound, Error> {
144        if self.local_addresses.is_empty() {
145            return Err(Error::NotEnoughAddresses);
146        }
147
148        let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
149        self.round = self.round.saturating_add(1u8);
150        let mut addresses_to_probe = Vec::with_capacity(self.remote_addresses.len());
151        for (id, ((ip, port), report_in_continuation)) in self.remote_addresses.iter_mut() {
152            *report_in_continuation = false;
153
154            if let Some(ip) = map_to_local_socket_family(*ip, ipv6) {
155                addresses_to_probe.push((*id, (ip, *port)));
156            } else {
157                trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
158            }
159        }
160
161        Ok(NatTraversalRound {
162            new_round: self.round,
163            reach_out_at: self.local_addresses.iter().copied().collect(),
164            addresses_to_probe,
165            prev_round_path_ids,
166        })
167    }
168
169    /// Mark a remote address to be reported back in a nat traversal continuation if the error is
170    /// considered spurious from a nat traversal point of view.
171    ///
172    /// Ids not present are silently ignored.
173    pub(crate) fn report_in_continuation(&mut self, id: VarInt, e: crate::PathError) {
174        match e {
175            crate::PathError::MaxPathIdReached | crate::PathError::RemoteCidsExhausted => {
176                if let Some((_address, report_in_continuation)) = self.remote_addresses.get_mut(&id)
177                {
178                    *report_in_continuation = true;
179                }
180            }
181            _ => {}
182        }
183    }
184
185    /// Returns an address that needs to be probed, if any.
186    ///
187    /// The address will not be returned twice unless marked as such again with
188    /// [`Self::report_in_continuation`].
189    ///
190    /// `ipv6` indicates if the connection runs on a socket that supports IPv6. If so, then all
191    /// addresses returned in [`NatTraversalRound`] will be IPv6 addresses (and IPv4-mapped IPv6
192    /// addresses if necessary). Otherwise they're all IPv4 addresses.
193    /// See also [`map_to_local_socket_family`].
194    pub(crate) fn continue_nat_traversal_round(&mut self, ipv6: bool) -> Option<(VarInt, IpPort)> {
195        // this being random depends on iteration not returning always on the same order
196        let (id, (address, report_in_continuation)) = self
197            .remote_addresses
198            .iter_mut()
199            .filter(|(_id, (_addr, report))| *report)
200            .filter_map(|(id, ((ip, port), report))| {
201                // only continue with addresses we can send on our local socket
202                let Some(ip) = map_to_local_socket_family(*ip, ipv6) else {
203                    trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
204                    return None;
205                };
206                Some((*id, ((ip, *port), report)))
207            })
208            .next()?;
209        *report_in_continuation = false;
210        Some((id, address))
211    }
212
213    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
214    /// advertised addresses.
215    pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
216        self.round_path_ids = path_ids;
217    }
218
219    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
220    /// advertised addresses.
221    pub(crate) fn add_round_path_id(&mut self, path_id: PathId) {
222        self.round_path_ids.push(path_id);
223    }
224
225    /// Adds an address to the remote set
226    ///
227    /// On success returns the address if it was new to the set. It will error when the set has no
228    /// capacity for the address.
229    pub(crate) fn add_remote_address(
230        &mut self,
231        add_addr: AddAddress,
232    ) -> Result<Option<SocketAddr>, Error> {
233        let AddAddress { seq_no, ip, port } = add_addr;
234        let address = (ip, port);
235        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
236        match self.remote_addresses.entry(seq_no) {
237            Entry::Occupied(mut occupied_entry) => {
238                let is_update = occupied_entry.get().0 != address;
239                if is_update {
240                    occupied_entry.insert((address, false));
241                }
242                // The value might be different. This should not happen, but we assume that the new
243                // address is more recent than the previous, and thus worth updating
244                Ok(is_update.then_some(address.into()))
245            }
246            Entry::Vacant(vacant_entry) if allow_new => {
247                vacant_entry.insert((address, false));
248                Ok(Some(address.into()))
249            }
250            _ => Err(Error::TooManyAddresses),
251        }
252    }
253
254    /// Removes an address from the remote set
255    ///
256    /// Returns whether the address was present.
257    pub(crate) fn remove_remote_address(
258        &mut self,
259        remove_addr: RemoveAddress,
260    ) -> Option<SocketAddr> {
261        self.remote_addresses
262            .remove(&remove_addr.seq_no)
263            .map(|(address, _report_in_continuation)| address.into())
264    }
265
266    /// Checks that a received remote address is valid
267    ///
268    /// An address is valid as long as it does not change the value of a known address id.
269    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
270        match self.remote_addresses.get(&add_addr.seq_no) {
271            None => true,
272            Some((existing, _)) => existing == &add_addr.ip_port(),
273        }
274    }
275
276    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
277        self.remote_addresses
278            .values()
279            .map(|(address, _report_in_continuation)| (*address).into())
280            .collect()
281    }
282}
283
284#[derive(Debug)]
285pub(crate) struct ServerState {
286    /// Max number of remote addresses we allow.
287    ///
288    /// This is set by the local endpoint.
289    max_remote_addresses: usize,
290    /// Max number of local addresses allowed.
291    ///
292    /// This is set by the remote endpoint.
293    max_local_addresses: usize,
294    /// Candidate addresses the server reports as potentially reachable, to use for nat
295    /// traversal attempts.
296    local_addresses: FxHashMap<IpPort, VarInt>,
297    /// The next id to use for local addresses sent to the client.
298    next_local_addr_id: VarInt,
299    /// Current nat traversal round
300    ///
301    /// Servers keep track of the client's most recent round and cancel probing related to previous
302    /// rounds.
303    round: VarInt,
304    /// Addresses to which PATH_CHALLENGES need to be sent.
305    pending_probes: FxHashSet<IpPort>,
306}
307
308impl ServerState {
309    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
310        Self {
311            max_remote_addresses,
312            max_local_addresses,
313            local_addresses: Default::default(),
314            next_local_addr_id: Default::default(),
315            round: Default::default(),
316            pending_probes: Default::default(),
317        }
318    }
319
320    fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
321        let allow_new = self.local_addresses.len() < self.max_local_addresses;
322        match self.local_addresses.entry(address) {
323            Entry::Occupied(_) => Ok(None),
324            Entry::Vacant(vacant_entry) if allow_new => {
325                let id = self.next_local_addr_id;
326                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
327                vacant_entry.insert(id);
328                Ok(Some(AddAddress::new(address, id)))
329            }
330            _ => Err(Error::TooManyAddresses),
331        }
332    }
333
334    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
335        self.local_addresses.remove(address).map(RemoveAddress::new)
336    }
337
338    /// Handles a received [`ReachOut`].
339    ///
340    /// This might ignore the reach out frame if it belongs to an older round or if
341    /// the reach out can't be handled by an ipv4-only local socket.
342    pub(crate) fn handle_reach_out(
343        &mut self,
344        reach_out: ReachOut,
345        ipv6: bool,
346    ) -> Result<(), Error> {
347        let ReachOut { round, ip, port } = reach_out;
348
349        if round < self.round {
350            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
351            return Ok(());
352        }
353        let Some(ip) = map_to_local_socket_family(ip, ipv6) else {
354            trace!("Ignoring IPv6 REACH_OUT frame due to not supporting IPv6 locally");
355            return Ok(());
356        };
357
358        if round > self.round {
359            self.round = round;
360            self.pending_probes.clear();
361        } else if self.pending_probes.len() >= self.max_remote_addresses
362            && !self.pending_probes.contains(&(ip, port))
363        {
364            return Err(Error::TooManyAddresses);
365        }
366        self.pending_probes.insert((ip, port));
367        Ok(())
368    }
369
370    pub(crate) fn next_probe(&mut self) -> Option<ServerProbing<'_>> {
371        self.pending_probes
372            .iter()
373            .next()
374            .copied()
375            .map(|remote| ServerProbing {
376                remote,
377                pending_probes: &mut self.pending_probes,
378            })
379    }
380}
381
382pub(crate) struct ServerProbing<'a> {
383    remote: IpPort,
384    pending_probes: &'a mut FxHashSet<IpPort>,
385}
386
387impl<'a> ServerProbing<'a> {
388    pub(crate) fn mark_as_sent(self) {
389        self.pending_probes.remove(&self.remote);
390    }
391
392    pub(crate) fn remote(&self) -> SocketAddr {
393        self.remote.into()
394    }
395}
396
397impl State {
398    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
399        match side {
400            Side::Client => Self::ClientSide(ClientState::new(
401                max_remote_addresses.into(),
402                max_local_addresses.into(),
403            )),
404            Side::Server => Self::ServerSide(ServerState::new(
405                max_remote_addresses.into(),
406                max_local_addresses.into(),
407            )),
408        }
409    }
410
411    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
412        match self {
413            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
414            Self::ClientSide(client_side) => Ok(client_side),
415            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
416        }
417    }
418
419    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
420        match self {
421            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
422            Self::ClientSide(client_side) => Ok(client_side),
423            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
424        }
425    }
426
427    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
428        match self {
429            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
430            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
431            Self::ServerSide(server_side) => Ok(server_side),
432        }
433    }
434
435    /// Adds a local address to use for nat traversal.
436    ///
437    /// When this endpoint is the server within the connection, these addresses will be sent to the
438    /// client in add address frames. For clients, these addresses will be sent in reach out frames
439    /// when nat traversal attempts are initiated.
440    ///
441    /// If a frame should be sent, it is returned.
442    pub(crate) fn add_local_address(
443        &mut self,
444        address: SocketAddr,
445    ) -> Result<Option<AddAddress>, Error> {
446        let ip_port = IpPort::from((address.ip(), address.port()));
447        match self {
448            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
449            Self::ClientSide(client_state) => {
450                client_state.add_local_address(ip_port)?;
451                Ok(None)
452            }
453            Self::ServerSide(server_state) => server_state.add_local_address(ip_port),
454        }
455    }
456
457    /// Removes a local address from the advertised set for nat traversal.
458    ///
459    /// When this endpoint is the server, removed addresses must be reported with remove address
460    /// frames. Clients will simply stop reporting these addresses in reach out frames.
461    ///
462    /// If a frame should be sent, it is returned.
463    pub(crate) fn remove_local_address(
464        &mut self,
465        address: SocketAddr,
466    ) -> Result<Option<RemoveAddress>, Error> {
467        let address = IpPort::from((address.ip(), address.port()));
468        match self {
469            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
470            Self::ClientSide(client_state) => {
471                client_state.remove_local_address(&address);
472                Ok(None)
473            }
474            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(&address)),
475        }
476    }
477
478    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
479        match self {
480            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
481            Self::ClientSide(client_state) => Ok(client_state
482                .local_addresses
483                .iter()
484                .copied()
485                .map(Into::into)
486                .collect()),
487            Self::ServerSide(server_state) => Ok(server_state
488                .local_addresses
489                .keys()
490                .copied()
491                .map(Into::into)
492                .collect()),
493        }
494    }
495}
496
/// Expresses `address` in the IP family of the local socket.
///
/// With an IPv6-capable socket (`ipv6 == true`), IPv4 addresses become IPv4-mapped IPv6
/// addresses and IPv6 addresses pass through unchanged. With an IPv4-only socket, IPv4
/// addresses pass through, IPv4-mapped IPv6 addresses are unwrapped to IPv4, and any other
/// IPv6 address yields `None` because the socket cannot reach it.
pub(crate) fn map_to_local_socket_family(address: IpAddr, ipv6: bool) -> Option<IpAddr> {
    if ipv6 {
        let mapped = match address {
            IpAddr::V4(v4) => IpAddr::V6(v4.to_ipv6_mapped()),
            IpAddr::V6(_) => address,
        };
        return Some(mapped);
    }
    match address {
        IpAddr::V4(_) => Some(address),
        IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4),
    }
}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `REACH_OUT` frame for `round` advertising `ip:port`.
    fn reach_out(round: u32, ip: &str, port: u16) -> ReachOut {
        ReachOut {
            round: round.into(),
            ip: ip.parse().unwrap(),
            port,
        }
    }

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, "127.0.0.1", 1), true)
            .unwrap();
        state
            .handle_reach_out(reach_out(1, "1.1.1.1", 2), true)
            .unwrap();
        assert_eq!(state.pending_probes.len(), 2);

        let probe = state.next_probe().unwrap();
        probe.mark_as_sent();
        let probe = state.next_probe().unwrap();
        probe.mark_as_sent();

        assert!(state.next_probe().is_none());
        assert_eq!(state.pending_probes.len(), 0);
    }

    #[test]
    fn test_server_ignores_previous_round() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(2, "1.1.1.1", 1), true)
            .unwrap();
        state
            .handle_reach_out(reach_out(1, "2.2.2.2", 2), true)
            .unwrap();

        assert_eq!(state.pending_probes.len(), 1);
    }

    #[test]
    fn test_server_new_round_replaces_pending_probes() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, "1.1.1.1", 1), true)
            .unwrap();
        state
            .handle_reach_out(reach_out(1, "2.2.2.2", 2), true)
            .unwrap();
        state
            .handle_reach_out(reach_out(2, "3.3.3.3", 3), true)
            .unwrap();

        assert_eq!(state.pending_probes.len(), 1);
        let expected: SocketAddr = "[::ffff:3.3.3.3]:3".parse().unwrap();
        assert_eq!(state.next_probe().unwrap().remote(), expected);
    }

    #[test]
    fn test_server_limits_pending_probes() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, "1.1.1.1", 1), true)
            .unwrap();
        state
            .handle_reach_out(reach_out(1, "2.2.2.2", 2), true)
            .unwrap();
        // Re-sending an already pending address is accepted even at capacity.
        state
            .handle_reach_out(reach_out(1, "2.2.2.2", 2), true)
            .unwrap();

        assert!(matches!(
            state.handle_reach_out(reach_out(1, "3.3.3.3", 3), true),
            Err(Error::TooManyAddresses)
        ));
        assert_eq!(state.pending_probes.len(), 2);
    }

    #[test]
    fn test_server_ignores_ipv6_on_ipv4_socket() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(reach_out(1, "1.1.1.1", 1), false)
            .unwrap();
        state
            .handle_reach_out(reach_out(1, "::1", 2), false)
            .unwrap();

        assert_eq!(state.pending_probes.len(), 1);
    }

    #[test]
    fn test_client_local_address_limit() {
        let mut client = ClientState::new(2, 1);
        let first = (IpAddr::from(std::net::Ipv4Addr::LOCALHOST), 1);
        let second = (IpAddr::from(std::net::Ipv4Addr::LOCALHOST), 2);

        client.add_local_address(first).unwrap();
        // A known address is accepted again even though the set is full.
        client.add_local_address(first).unwrap();
        assert!(matches!(
            client.add_local_address(second),
            Err(Error::TooManyAddresses)
        ));
    }

    #[test]
    fn test_client_round_requires_local_addresses() {
        let mut client = ClientState::new(2, 2);

        assert!(matches!(
            client.initiate_nat_traversal_round(true),
            Err(Error::NotEnoughAddresses)
        ));
    }

    #[test]
    fn test_client_round_filters_remote_families() {
        let mut client = ClientState::new(4, 4);
        client
            .add_local_address((IpAddr::from(std::net::Ipv4Addr::LOCALHOST), 1))
            .unwrap();

        let remote_v4: IpAddr = "1.1.1.1".parse().unwrap();
        let remote_v6: IpAddr = "2001:db8::1".parse().unwrap();
        client
            .add_remote_address(AddAddress::new((remote_v4, 2), VarInt::from(0u32)))
            .unwrap();
        client
            .add_remote_address(AddAddress::new((remote_v6, 3), VarInt::from(1u32)))
            .unwrap();

        // An IPv4-only socket cannot reach the IPv6 candidate.
        let round = client.initiate_nat_traversal_round(false).unwrap();
        assert_eq!(round.addresses_to_probe.len(), 1);
        assert_eq!(round.reach_out_at.len(), 1);

        let round = client.initiate_nat_traversal_round(true).unwrap();
        assert_eq!(round.addresses_to_probe.len(), 2);
    }

    #[test]
    fn test_map_to_local_socket() {
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), true),
            Some("::ffff:1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), true),
            Some("::1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), false),
            None
        );
        assert_eq!(
            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
    }
}