noq_proto/
n0_nat_traversal.rs

//! n0's (<https://n0.computer>) NAT Traversal protocol implementation.
2
3use std::{
4    collections::hash_map::Entry,
5    net::{IpAddr, SocketAddr},
6};
7
8use rustc_hash::{FxHashMap, FxHashSet};
9use tracing::trace;
10
11use crate::{
12    PathId, Side, VarInt,
13    frame::{AddAddress, ReachOut, RemoveAddress},
14};
15
/// An IP address together with a port, used as the key for address candidates.
type IpPort = (IpAddr, u16);
17
18/// Errors that the nat traversal state might encounter.
19#[derive(Debug, thiserror::Error)]
20pub enum Error {
21    /// An endpoint (local or remote) tried to add too many addresses to their advertised set
22    #[error("Tried to add too many addresses to their advertised set")]
23    TooManyAddresses,
24    /// The operation is not allowed for this endpoint's connection side
25    #[error("Not allowed for this endpoint's connection side")]
26    WrongConnectionSide,
27    /// The extension was not negotiated
28    #[error("n0's nat traversal was not negotiated")]
29    ExtensionNotNegotiated,
30    /// Not enough addresses to complete the operation
31    #[error("Not enough addresses")]
32    NotEnoughAddresses,
33    /// Nat traversal attempt failed due to a multipath error
34    #[error("Failed to establish paths {0}")]
35    Multipath(super::PathError),
36    /// Attempted to initiate NAT traversal on a closed, or closing connection.
37    #[error("The connection is already closed")]
38    Closed,
39}
40
41pub(crate) struct NatTraversalRound {
42    /// Sequence number to use for the new reach out frames.
43    pub(crate) new_round: VarInt,
44    /// Addresses to use to send reach out frames.
45    pub(crate) reach_out_at: FxHashSet<IpPort>,
46    /// Remotes to probe by attempting to open new paths.
47    ///
48    /// The addresses include their Id, so that it can be used to signal these should be returned
49    /// in a nat traversal continuation by calling [`ClientState::report_in_continuation`].
50    ///
51    /// These are filtered and mapped to the IP family the local socket supports.
52    pub(crate) addresses_to_probe: Vec<(VarInt, IpPort)>,
53    /// [`PathId`]s of the cancelled round.
54    pub(crate) prev_round_path_ids: Vec<PathId>,
55}
56
/// Event emitted when the client receives ADD_ADDRESS or REMOVE_ADDRESS frames.
#[derive(Debug, Clone)]
pub enum Event {
    /// An ADD_ADDRESS frame was received; carries the advertised address.
    AddressAdded(SocketAddr),
    /// A REMOVE_ADDRESS frame was received; carries the address that was removed.
    AddressRemoved(SocketAddr),
}
65
66/// State kept for n0's nat traversal
67#[derive(Debug, Default)]
68pub(crate) enum State {
69    #[default]
70    NotNegotiated,
71    ClientSide(ClientState),
72    ServerSide(ServerState),
73}
74
75#[derive(Debug)]
76pub(crate) struct ClientState {
77    /// Max number of remote addresses we allow
78    ///
79    /// This is set by the local endpoint.
80    max_remote_addresses: usize,
81    /// Max number of local addresses allowed
82    ///
83    /// This is set by the remote endpoint.
84    max_local_addresses: usize,
85    /// Candidate addresses the remote server reports as potentially reachable, to use for nat
86    /// traversal attempts.
87    ///
88    /// These are indexed by their advertised Id. For each address, whether the address should be
89    /// reported in nat traversal continuations is kept.
90    remote_addresses: FxHashMap<VarInt, (IpPort, bool)>,
91    /// Candidate addresses the local client reports as potentially reachable, to use for nat
92    /// traversal attempts.
93    local_addresses: FxHashSet<IpPort>,
94    /// Current nat traversal round.
95    round: VarInt,
96    /// [`PathId`]s used to probe remotes assigned to this round.
97    round_path_ids: Vec<PathId>,
98}
99
100impl ClientState {
101    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
102        Self {
103            max_remote_addresses,
104            max_local_addresses,
105            remote_addresses: Default::default(),
106            local_addresses: Default::default(),
107            round: Default::default(),
108            round_path_ids: Default::default(),
109        }
110    }
111
112    fn add_local_address(&mut self, address: IpPort) -> Result<(), Error> {
113        if self.local_addresses.len() < self.max_local_addresses {
114            self.local_addresses.insert(address);
115            Ok(())
116        } else if self.local_addresses.contains(&address) {
117            // at capacity, but the address is known, no issues here
118            Ok(())
119        } else {
120            // at capacity and the address is new
121            Err(Error::TooManyAddresses)
122        }
123    }
124
125    fn remove_local_address(&mut self, address: &IpPort) {
126        self.local_addresses.remove(address);
127    }
128
129    /// Initiates a new nat traversal round.
130    ///
131    /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT`
132    /// frames, and initiating probing of the known remote addresses. When a new round is
133    /// initiated, the previous one is cancelled, and paths that have not been opened should be
134    /// closed.
135    ///
136    /// `ipv6` indicates if the connection runs on a socket that supports IPv6. If so, then all
137    /// addresses returned in [`NatTraversalRound`] will be IPv6 addresses (and IPv4-mapped IPv6
138    /// addresses if necessary). Otherwise they're all IPv4 addresses.
139    /// See also [`map_to_local_socket_family`].
140    pub(crate) fn initiate_nat_traversal_round(
141        &mut self,
142        ipv6: bool,
143    ) -> Result<NatTraversalRound, Error> {
144        if self.local_addresses.is_empty() {
145            return Err(Error::NotEnoughAddresses);
146        }
147
148        let prev_round_path_ids = std::mem::take(&mut self.round_path_ids);
149        self.round = self.round.saturating_add(1u8);
150        let mut addresses_to_probe = Vec::with_capacity(self.remote_addresses.len());
151        for (id, ((ip, port), report_in_continuation)) in self.remote_addresses.iter_mut() {
152            *report_in_continuation = false;
153
154            if let Some(ip) = map_to_local_socket_family(*ip, ipv6) {
155                addresses_to_probe.push((*id, (ip, *port)));
156            } else {
157                trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
158            }
159        }
160
161        Ok(NatTraversalRound {
162            new_round: self.round,
163            reach_out_at: self.local_addresses.iter().copied().collect(),
164            addresses_to_probe,
165            prev_round_path_ids,
166        })
167    }
168
169    /// Mark a remote address to be reported back in a nat traversal continuation if the error is
170    /// considered spurious from a nat traversal point of view.
171    ///
172    /// Ids not present are silently ignored.
173    pub(crate) fn report_in_continuation(&mut self, id: VarInt, e: crate::PathError) {
174        match e {
175            crate::PathError::MaxPathIdReached | crate::PathError::RemoteCidsExhausted => {
176                if let Some((_address, report_in_continuation)) = self.remote_addresses.get_mut(&id)
177                {
178                    *report_in_continuation = true;
179                }
180            }
181            _ => {}
182        }
183    }
184
185    /// Returns an address that needs to be probed, if any.
186    ///
187    /// The address will not be returned twice unless marked as such again with
188    /// [`Self::report_in_continuation`].
189    ///
190    /// `ipv6` indicates if the connection runs on a socket that supports IPv6. If so, then all
191    /// addresses returned in [`NatTraversalRound`] will be IPv6 addresses (and IPv4-mapped IPv6
192    /// addresses if necessary). Otherwise they're all IPv4 addresses.
193    /// See also [`map_to_local_socket_family`].
194    pub(crate) fn continue_nat_traversal_round(&mut self, ipv6: bool) -> Option<(VarInt, IpPort)> {
195        // this being random depends on iteration not returning always on the same order
196        let (id, (address, report_in_continuation)) = self
197            .remote_addresses
198            .iter_mut()
199            .filter(|(_id, (_addr, report))| *report)
200            .filter_map(|(id, ((ip, port), report))| {
201                // only continue with addresses we can send on our local socket
202                let Some(ip) = map_to_local_socket_family(*ip, ipv6) else {
203                    trace!(?ip, "not using IPv6 nat candidate for IPv4 socket");
204                    return None;
205                };
206                Some((*id, ((ip, *port), report)))
207            })
208            .next()?;
209        *report_in_continuation = false;
210        Some((id, address))
211    }
212
213    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
214    /// advertised addresses.
215    pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec<PathId>) {
216        self.round_path_ids = path_ids;
217    }
218
219    /// Add a [`PathId`] as part of the current attempts to create paths based on the server's
220    /// advertised addresses.
221    pub(crate) fn add_round_path_id(&mut self, path_id: PathId) {
222        self.round_path_ids.push(path_id);
223    }
224
225    /// Adds an address to the remote set
226    ///
227    /// On success returns the address if it was new to the set. It will error when the set has no
228    /// capacity for the address.
229    pub(crate) fn add_remote_address(
230        &mut self,
231        add_addr: AddAddress,
232    ) -> Result<Option<SocketAddr>, Error> {
233        let AddAddress { seq_no, ip, port } = add_addr;
234        let address = (ip, port);
235        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
236        match self.remote_addresses.entry(seq_no) {
237            Entry::Occupied(mut occupied_entry) => {
238                let is_update = occupied_entry.get().0 != address;
239                if is_update {
240                    occupied_entry.insert((address, false));
241                }
242                // The value might be different. This should not happen, but we assume that the new
243                // address is more recent than the previous, and thus worth updating
244                Ok(is_update.then_some(address.into()))
245            }
246            Entry::Vacant(vacant_entry) if allow_new => {
247                vacant_entry.insert((address, false));
248                Ok(Some(address.into()))
249            }
250            _ => Err(Error::TooManyAddresses),
251        }
252    }
253
254    /// Removes an address from the remote set
255    ///
256    /// Returns whether the address was present.
257    pub(crate) fn remove_remote_address(
258        &mut self,
259        remove_addr: RemoveAddress,
260    ) -> Option<SocketAddr> {
261        self.remote_addresses
262            .remove(&remove_addr.seq_no)
263            .map(|(address, _report_in_continuation)| address.into())
264    }
265
266    /// Checks that a received remote address is valid
267    ///
268    /// An address is valid as long as it does not change the value of a known address id.
269    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
270        match self.remote_addresses.get(&add_addr.seq_no) {
271            None => true,
272            Some((existing, _)) => existing == &add_addr.ip_port(),
273        }
274    }
275
276    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
277        self.remote_addresses
278            .values()
279            .map(|(address, _report_in_continuation)| (*address).into())
280            .collect()
281    }
282}
283
/// Maximum number of times an off-path probe is sent before giving up.
pub(crate) const MAX_OFF_PATH_PROBE_ATTEMPTS: u8 = 10;
286
/// State of an off-path probe to a client address.
#[derive(Debug)]
pub(crate) struct ProbeState {
    /// Number of times this probe has been sent (0 = not yet sent).
    pub(crate) attempts: u8,
    /// Whether this probe is ready to be sent.
    pub(crate) ready_to_send: bool,
}
295
296#[derive(Debug)]
297pub(crate) struct ServerState {
298    /// Max number of remote addresses we allow.
299    ///
300    /// This is set by the local endpoint.
301    max_remote_addresses: usize,
302    /// Max number of local addresses allowed.
303    ///
304    /// This is set by the remote endpoint.
305    max_local_addresses: usize,
306    /// Candidate addresses the server reports as potentially reachable, to use for nat
307    /// traversal attempts.
308    local_addresses: FxHashMap<IpPort, VarInt>,
309    /// The next id to use for local addresses sent to the client.
310    next_local_addr_id: VarInt,
311    /// Current nat traversal round
312    ///
313    /// Servers keep track of the client's most recent round and cancel probing related to previous
314    /// rounds.
315    round: VarInt,
316    /// Addresses to which PATH_CHALLENGES need to be sent, with their probe state.
317    ///
318    /// Probes are retransmitted up to [`MAX_OFF_PATH_PROBE_ATTEMPTS`] times.
319    pending_probes: FxHashMap<IpPort, ProbeState>,
320}
321
322impl ServerState {
323    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
324        Self {
325            max_remote_addresses,
326            max_local_addresses,
327            local_addresses: Default::default(),
328            next_local_addr_id: Default::default(),
329            round: Default::default(),
330            pending_probes: Default::default(),
331        }
332    }
333
334    fn add_local_address(&mut self, address: IpPort) -> Result<Option<AddAddress>, Error> {
335        let allow_new = self.local_addresses.len() < self.max_local_addresses;
336        match self.local_addresses.entry(address) {
337            Entry::Occupied(_) => Ok(None),
338            Entry::Vacant(vacant_entry) if allow_new => {
339                let id = self.next_local_addr_id;
340                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
341                vacant_entry.insert(id);
342                Ok(Some(AddAddress::new(address, id)))
343            }
344            _ => Err(Error::TooManyAddresses),
345        }
346    }
347
348    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
349        self.local_addresses.remove(address).map(RemoveAddress::new)
350    }
351
352    /// Returns the current NAT traversal round number.
353    pub(crate) fn current_round(&self) -> VarInt {
354        self.round
355    }
356
357    /// Handles a received [`ReachOut`].
358    ///
359    /// This might ignore the reach out frame if it belongs to an older round or if
360    /// the reach out can't be handled by an ipv4-only local socket.
361    pub(crate) fn handle_reach_out(
362        &mut self,
363        reach_out: ReachOut,
364        ipv6: bool,
365    ) -> Result<(), Error> {
366        let ReachOut { round, ip, port } = reach_out;
367
368        if round < self.round {
369            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
370            return Ok(());
371        }
372        let Some(ip) = map_to_local_socket_family(ip, ipv6) else {
373            trace!("Ignoring IPv6 REACH_OUT frame due to not supporting IPv6 locally");
374            return Ok(());
375        };
376
377        if round > self.round {
378            self.round = round;
379            self.pending_probes.clear();
380        } else if self.pending_probes.len() >= self.max_remote_addresses
381            && !self.pending_probes.contains_key(&(ip, port))
382        {
383            return Err(Error::TooManyAddresses);
384        }
385        self.pending_probes.entry((ip, port)).or_insert(ProbeState {
386            attempts: 0,
387            ready_to_send: true,
388        });
389        Ok(())
390    }
391
392    /// Re-queue all sent probes that haven't exceeded [`MAX_OFF_PATH_PROBE_ATTEMPTS`]
393    /// for retransmission. Called when the off-path probe retry timer fires.
394    ///
395    /// Returns whether any probes were re-queued.
396    pub(crate) fn queue_retries(&mut self) -> bool {
397        let mut any_requeued = false;
398        self.pending_probes.retain(|_, state| {
399            if state.attempts > 0 && state.attempts < MAX_OFF_PATH_PROBE_ATTEMPTS {
400                state.ready_to_send = true;
401                any_requeued = true;
402                true
403            } else {
404                state.attempts < MAX_OFF_PATH_PROBE_ATTEMPTS
405            }
406        });
407        any_requeued
408    }
409
410    /// Returns whether there are any probes that have been sent but are waiting
411    /// for retry (i.e., sent at least once but under the max attempt limit).
412    pub(crate) fn has_pending_retries(&self) -> bool {
413        self.pending_probes
414            .values()
415            .any(|state| state.attempts > 0 && state.attempts < MAX_OFF_PATH_PROBE_ATTEMPTS)
416    }
417
418    /// Returns the next ready probe's address.
419    pub(crate) fn next_probe_addr(&self) -> Option<SocketAddr> {
420        self.pending_probes
421            .iter()
422            .find(|(_, state)| state.ready_to_send)
423            .map(|(addr, _)| (*addr).into())
424    }
425
426    /// Mark a probe as sent by address.
427    pub(crate) fn mark_probe_sent(&mut self, remote: IpPort) {
428        if let Some(state) = self.pending_probes.get_mut(&remote) {
429            state.attempts += 1;
430            state.ready_to_send = false;
431        }
432    }
433}
434
435impl State {
436    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
437        match side {
438            Side::Client => Self::ClientSide(ClientState::new(
439                max_remote_addresses.into(),
440                max_local_addresses.into(),
441            )),
442            Side::Server => Self::ServerSide(ServerState::new(
443                max_remote_addresses.into(),
444                max_local_addresses.into(),
445            )),
446        }
447    }
448
449    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
450        match self {
451            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
452            Self::ClientSide(client_side) => Ok(client_side),
453            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
454        }
455    }
456
457    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
458        match self {
459            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
460            Self::ClientSide(client_side) => Ok(client_side),
461            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
462        }
463    }
464
465    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
466        match self {
467            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
468            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
469            Self::ServerSide(server_side) => Ok(server_side),
470        }
471    }
472
473    /// Adds a local address to use for nat traversal.
474    ///
475    /// When this endpoint is the server within the connection, these addresses will be sent to the
476    /// client in add address frames. For clients, these addresses will be sent in reach out frames
477    /// when nat traversal attempts are initiated.
478    ///
479    /// If a frame should be sent, it is returned.
480    pub(crate) fn add_local_address(
481        &mut self,
482        address: SocketAddr,
483    ) -> Result<Option<AddAddress>, Error> {
484        let ip_port = IpPort::from((address.ip(), address.port()));
485        match self {
486            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
487            Self::ClientSide(client_state) => {
488                client_state.add_local_address(ip_port)?;
489                Ok(None)
490            }
491            Self::ServerSide(server_state) => server_state.add_local_address(ip_port),
492        }
493    }
494
495    /// Removes a local address from the advertised set for nat traversal.
496    ///
497    /// When this endpoint is the server, removed addresses must be reported with remove address
498    /// frames. Clients will simply stop reporting these addresses in reach out frames.
499    ///
500    /// If a frame should be sent, it is returned.
501    pub(crate) fn remove_local_address(
502        &mut self,
503        address: SocketAddr,
504    ) -> Result<Option<RemoveAddress>, Error> {
505        let address = IpPort::from((address.ip(), address.port()));
506        match self {
507            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
508            Self::ClientSide(client_state) => {
509                client_state.remove_local_address(&address);
510                Ok(None)
511            }
512            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(&address)),
513        }
514    }
515
516    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
517        match self {
518            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
519            Self::ClientSide(client_state) => Ok(client_state
520                .local_addresses
521                .iter()
522                .copied()
523                .map(Into::into)
524                .collect()),
525            Self::ServerSide(server_state) => Ok(server_state
526                .local_addresses
527                .keys()
528                .copied()
529                .map(Into::into)
530                .collect()),
531        }
532    }
533}
534
/// Maps the given address to the IP family used by our local socket.
///
/// `ipv6` indicates whether the local socket supports IPv6:
///
/// - If it does, IPv4 addresses are returned as IPv4-mapped IPv6 addresses and IPv6 addresses
///   are returned unchanged.
/// - If it does not, IPv4 addresses are returned unchanged and IPv4-mapped IPv6 addresses are
///   returned as their IPv4 address. Any other IPv6 address can't be used with the local
///   socket, in which case this returns `None`.
pub(crate) fn map_to_local_socket_family(address: IpAddr, ipv6: bool) -> Option<IpAddr> {
    match (address, ipv6) {
        (IpAddr::V4(v4), true) => Some(IpAddr::V6(v4.to_ipv6_mapped())),
        (IpAddr::V4(_), false) | (IpAddr::V6(_), true) => Some(address),
        (IpAddr::V6(v6), false) => v6.to_ipv4_mapped().map(IpAddr::V4),
    }
}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks two probes through their full lifecycle: queued by REACH_OUT frames, sent,
    /// retried until [`MAX_OFF_PATH_PROBE_ATTEMPTS`] is reached, and finally dropped.
    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(
                ReachOut {
                    round: 1u32.into(),
                    ip: std::net::Ipv4Addr::LOCALHOST.into(),
                    port: 1,
                },
                true,
            )
            .unwrap();

        state
            .handle_reach_out(
                ReachOut {
                    round: 1u32.into(),
                    ip: "1.1.1.1".parse().unwrap(),
                    port: 2,
                },
                true,
            )
            .unwrap();

        assert_eq!(state.pending_probes.len(), 2);

        // Helper: send next ready probe
        let send_probe = |state: &mut ServerState| {
            let remote = state.next_probe_addr().unwrap();
            state.mark_probe_sent((remote.ip(), remote.port()));
        };

        send_probe(&mut state);
        send_probe(&mut state);

        // After sending both probes, no ready probes remain but they're still tracked.
        assert!(state.next_probe_addr().is_none());
        assert_eq!(state.pending_probes.len(), 2);
        assert!(state.has_pending_retries());

        // After queuing retries, probes become available again
        assert!(state.queue_retries());
        send_probe(&mut state);
        send_probe(&mut state);

        // After 2 attempts each, retries still available (max is 10)
        assert!(state.queue_retries());
        send_probe(&mut state);
        send_probe(&mut state);

        // Exhaust remaining attempts: three sends happened above, so this brings each probe
        // to exactly `MAX_OFF_PATH_PROBE_ATTEMPTS`.
        for _ in 3..MAX_OFF_PATH_PROBE_ATTEMPTS {
            assert!(state.queue_retries());
            send_probe(&mut state);
            send_probe(&mut state);
        }

        // After max attempts, probes are removed
        assert!(!state.queue_retries());
        assert!(state.next_probe_addr().is_none());
        assert_eq!(state.pending_probes.len(), 0);
    }

    #[test]
    fn test_map_to_local_socket() {
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), true),
            Some("::ffff:1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), true),
            Some("::1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), false),
            None
        );
        assert_eq!(
            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
    }
}