noq_proto/
n0_nat_traversal.rs

1//! n0's (<https://n0.computer>) NAT Traversal protocol implementation.
2
3use std::{
4    collections::hash_map::Entry,
5    fmt::Display,
6    net::{IpAddr, SocketAddr},
7    time::Duration,
8};
9
10use rustc_hash::{FxHashMap, FxHashSet};
11use tracing::{debug, trace};
12
13use crate::{
14    FourTuple, Side, VarInt,
15    connection::spaces::PendingReachOutFrames,
16    frame::{AddAddress, ReachOut, RemoveAddress},
17};
18
/// Upper bound on how many NAT probes are sent to a single remote address per round.
///
/// The value balances several competing concerns:
/// - Probe packets may get lost, so sending more than one allows recovery.
/// - Two probes may be needed before the NAT firewall lets traffic through.
/// - Probes might end up at innocent bystanders on the internet.
/// - A round never "finishes" by itself; probing a remote only stops once:
///   1. a new round is started,
///   2. a probe was successful, or
///   3. this number of attempts is exhausted.
///
/// Probes are spaced using the capped exponential backoff from [`State::retry_delay`].
/// With the default settings this means probes are sent for up to 4s.
pub(crate) const MAX_NAT_PROBE_ATTEMPTS: u8 = 9;
33
/// An IP address paired with a port.
///
/// Invariant: the address should always be in the IP family the local socket
/// operates in.
/// For an IPv4 local socket every `IpPort` should hold an IPv4 address. For a
/// socket that supports IPv6 every `IpPort` should hold either an IPv6 address
/// or an IPv6-mapped IPv4 address.
///
/// The conversion into this form is powered by [`map_to_local_socket_family`].
type IpPort = (IpAddr, u16);
44
/// An IP address and port, with the address in canonical form.
///
/// Canonical form means IPv6-mapped IPv4 addresses are not used.
/// This is the primary type for addresses that are sent to or received from the
/// remote, and the primary type used to canonicalize received addresses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct CanonicalIpPort {
    // NOTE: the derived ordering compares the fields in declaration order, so the
    // address is compared before the port.
    canonical_ip: IpAddr,
    port: u16,
}
55
56impl CanonicalIpPort {
57    pub(crate) fn ip(&self) -> IpAddr {
58        self.canonical_ip
59    }
60
61    pub(crate) fn port(&self) -> u16 {
62        self.port
63    }
64
65    /// Converts this into a local-socket-family-mapped IP & port.
66    ///
67    /// Instead of using ipv4 and ipv6 addresses, this tries to match `ipv6`, which
68    /// should indicate whether the local socket supports ipv6 or not.
69    ///
70    /// If ipv6 is supported, all ipv4 addresses are mapped using ipv6-mapped ipv4
71    /// addresses.
72    /// If ipv6 is not supported, then this returns `None` for ipv6 addresses.
73    ///
74    /// See also [`map_to_local_socket_family`].
75    pub(crate) fn as_local_socket_family(&self, ipv6: bool) -> Option<IpPort> {
76        Some((
77            map_to_local_socket_family(self.canonical_ip, ipv6)?,
78            self.port,
79        ))
80    }
81
82    /// Returns this address as-is with the canonical IP used in a `SocketAddr`.
83    pub(crate) fn as_canonical_addr(&self) -> SocketAddr {
84        SocketAddr::new(self.canonical_ip, self.port)
85    }
86}
87
88impl Display for CanonicalIpPort {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        self.as_canonical_addr().fmt(f)
91    }
92}
93
94impl From<SocketAddr> for CanonicalIpPort {
95    fn from(addr: SocketAddr) -> Self {
96        Self {
97            canonical_ip: addr.ip().to_canonical(),
98            port: addr.port(),
99        }
100    }
101}
102
103impl From<IpPort> for CanonicalIpPort {
104    fn from((ip, port): IpPort) -> Self {
105        Self {
106            canonical_ip: ip.to_canonical(),
107            port,
108        }
109    }
110}
111
112/// Errors that the nat traversal state might encounter.
113#[derive(Debug, thiserror::Error)]
114pub enum Error {
115    /// An endpoint (local or remote) tried to add too many addresses to their advertised set
116    #[error("Tried to add too many addresses to their advertised set")]
117    TooManyAddresses,
118    /// The operation is not allowed for this endpoint's connection side
119    #[error("Not allowed for this endpoint's connection side")]
120    WrongConnectionSide,
121    /// The extension was not negotiated
122    #[error("n0's nat traversal was not negotiated")]
123    ExtensionNotNegotiated,
124    /// Not enough addresses to complete the operation
125    #[error("Not enough addresses")]
126    NotEnoughAddresses,
127    /// Nat traversal attempt failed due to a multipath error
128    #[error("Failed to establish paths {0}")]
129    Multipath(super::PathError),
130    /// Attempted to initiate NAT traversal on a closed, or closing connection.
131    #[error("The connection is already closed")]
132    Closed,
133}
134
/// Events emitted on the client when ADD_ADDRESS or REMOVE_ADDRESS frames arrive.
#[derive(Debug, Clone)]
pub enum Event {
    /// The remote advertised an address in an ADD_ADDRESS frame.
    AddressAdded(SocketAddr),
    /// The remote withdrew an address with a REMOVE_ADDRESS frame.
    AddressRemoved(SocketAddr),
}
143
144/// State kept for n0's nat traversal
145#[derive(Debug, Default)]
146pub(crate) enum State {
147    #[default]
148    NotNegotiated,
149    ClientSide(ClientState),
150    ServerSide(ServerState),
151}
152
153impl State {
154    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
155        match side {
156            Side::Client => Self::ClientSide(ClientState::new(
157                max_remote_addresses.into(),
158                max_local_addresses.into(),
159            )),
160            Side::Server => Self::ServerSide(ServerState::new(
161                max_remote_addresses.into(),
162                max_local_addresses.into(),
163            )),
164        }
165    }
166
167    pub(crate) fn is_negotiated(&self) -> bool {
168        match self {
169            Self::NotNegotiated => false,
170            Self::ClientSide(_) | Self::ServerSide(_) => true,
171        }
172    }
173
174    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
175        match self {
176            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
177            Self::ClientSide(client_side) => Ok(client_side),
178            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
179        }
180    }
181
182    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
183        match self {
184            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
185            Self::ClientSide(client_side) => Ok(client_side),
186            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
187        }
188    }
189
190    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
191        match self {
192            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
193            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
194            Self::ServerSide(server_side) => Ok(server_side),
195        }
196    }
197
198    /// Adds a local address to use for nat traversal.
199    ///
200    /// When this endpoint is the server within the connection, these addresses will be sent to the
201    /// client in add address frames. For clients, these addresses will be sent in reach out frames
202    /// when nat traversal attempts are initiated.
203    ///
204    /// If a frame should be sent, it is returned.
205    pub(crate) fn add_local_address(
206        &mut self,
207        address: SocketAddr,
208    ) -> Result<Option<AddAddress>, Error> {
209        match self {
210            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
211            Self::ClientSide(client_state) => {
212                client_state.add_local_address(address)?;
213                Ok(None)
214            }
215            Self::ServerSide(server_state) => server_state.add_local_address(address),
216        }
217    }
218
219    /// Removes a local address from the advertised set for nat traversal.
220    ///
221    /// When this endpoint is the server, removed addresses must be reported with remove address
222    /// frames. Clients will simply stop reporting these addresses in reach out frames.
223    ///
224    /// If a frame should be sent, it is returned.
225    pub(crate) fn remove_local_address(
226        &mut self,
227        address: SocketAddr,
228    ) -> Result<Option<RemoveAddress>, Error> {
229        let address = IpPort::from((address.ip(), address.port()));
230        match self {
231            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
232            Self::ClientSide(client_state) => {
233                client_state.remove_local_address(&address);
234                Ok(None)
235            }
236            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(&address)),
237        }
238    }
239
240    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
241        match self {
242            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
243            Self::ClientSide(client_state) => Ok(client_state
244                .local_addresses
245                .iter()
246                .map(CanonicalIpPort::as_canonical_addr)
247                .collect()),
248            Self::ServerSide(server_state) => Ok(server_state
249                .local_addresses
250                .keys()
251                .map(CanonicalIpPort::as_canonical_addr)
252                .collect()),
253        }
254    }
255
256    /// Returns the next ready probe's address.
257    ///
258    /// If this is actually sent you must call [`Self::mark_probe_sent`].
259    pub(crate) fn next_probe_addr(&self) -> Option<IpPort> {
260        match self {
261            Self::NotNegotiated => None,
262            Self::ClientSide(state) => state.next_probe_addr(),
263            Self::ServerSide(state) => state.next_probe_addr(),
264        }
265    }
266
267    /// Marks a probe as sent to the address with the challenge.
268    pub(crate) fn mark_probe_sent(&mut self, remote: IpPort, challenge: u64) {
269        match self {
270            Self::NotNegotiated => (),
271            Self::ClientSide(state) => state.mark_probe_sent(remote, challenge),
272            Self::ServerSide(state) => state.mark_probe_sent(remote, challenge),
273        }
274    }
275
276    /// Re-queues probes that have not yet succeeded or reached 0 remaining retries.
277    ///
278    /// After calling [`Self::retry_delay`] must be checked.
279    pub(crate) fn queue_retries(&mut self, ipv6: bool) {
280        match self {
281            Self::NotNegotiated => (),
282            Self::ClientSide(state) => state.queue_retries(ipv6),
283            Self::ServerSide(state) => state.queue_retries(),
284        };
285    }
286
287    /// Marks a remote as successful if the response matches a sent probe.
288    ///
289    /// Returns true if it was a response to one of the NAT traversal probes and a path
290    /// needs to be opened. Note that the NAT probes are not padded to 1200 bytes so only
291    /// the address is validated, but not the entire path.
292    pub(crate) fn handle_path_response(&mut self, src: FourTuple, challenge: u64) -> bool {
293        match self {
294            Self::NotNegotiated => false,
295            Self::ClientSide(state) => state.handle_path_response(src, challenge),
296            Self::ServerSide(state) => state.handle_path_response(src, challenge),
297        }
298    }
299
300    /// Returns the delay to arm the `NatTraversalProbeRetry` timer.
301    ///
302    /// `initial_rtt` must be [`TransportConfig::initial_rtt`] so retries are scaled to this
303    /// value.
304    ///
305    /// [`TransportConfig::initial_rtt`]: crate::TransportConfig::initial_rtt
306    pub(crate) fn retry_delay(&self, initial_rtt: Duration) -> Option<Duration> {
307        match self {
308            Self::NotNegotiated => return None,
309            Self::ClientSide(state) => {
310                if !state
311                    .remote_addresses
312                    .values()
313                    .any(|(_, probes)| probes.remaining() > 0)
314                {
315                    return None;
316                }
317            }
318            Self::ServerSide(state) => {
319                if !state.remotes.values().any(|probes| probes.remaining() > 0) {
320                    return None;
321                }
322            }
323        }
324
325        let attempt = match self {
326            Self::NotNegotiated => return None,
327            Self::ClientSide(state) => state.attempt,
328            Self::ServerSide(state) => state.attempt,
329        };
330
331        // Retries follow at an exponential backoff, capped at max 2s interval. The base
332        // delay is initial_rtt/10, which for the default value means 33.3ms. Just under
333        // 10_000 km at the speed of light.
334        const MAX_BACKOFF_EXPONENT: u8 = 8;
335        const MAX_INTERVAL: Duration = Duration::from_secs(2);
336        let base = initial_rtt / 10;
337        let attempt = attempt.min(MAX_BACKOFF_EXPONENT) as u32;
338        let interval = match attempt {
339            0 => base * 2u32.pow(attempt),
340            _ => base * 2u32.pow(attempt) - base * 2u32.pow(attempt - 1),
341        };
342        Some(interval.min(MAX_INTERVAL))
343    }
344}
345
346#[derive(Debug)]
347pub(crate) struct ClientState {
348    /// Max number of remote addresses we allow
349    ///
350    /// This is set by the local endpoint.
351    max_remote_addresses: usize,
352    /// Max number of local addresses allowed
353    ///
354    /// This is set by the remote endpoint.
355    max_local_addresses: usize,
356    /// Candidate addresses the remote endpoint advertises.
357    ///
358    /// These are addresses on which the server is potentially reachable, to use for NAT
359    /// traversal attempts.
360    ///
361    /// They are indexed by their ADD_ADDRESS sequence id and stored in **canonical
362    /// form**. Not in the socket-native form as usual. This because we need to store them
363    /// so we have the correct sequence IDs.
364    remote_addresses: FxHashMap<VarInt, (CanonicalIpPort, ProbeState)>,
365    /// Candidate addresses for the local endpoint.
366    ///
367    /// These are addresses on which we are potentially reachable, to use for NAT traversal
368    /// attempts.
369    ///
370    /// They are stored in **canonical form**, not in socket-native form as usual. We may
371    /// nave a reflexive address that is IPv6 even if our local socket can only handle IPv4.
372    local_addresses: FxHashSet<CanonicalIpPort>,
373    /// Current nat traversal round.
374    round: VarInt,
375    /// The probing attempt in the round.
376    ///
377    /// Probes are sent to all remotes at the same time in a round, at intervals from
378    /// [`State::retry_delay`]. This is the number of times probes have been sent.
379    attempt: u8,
380    /// The data of PATH_CHALLENGE frames sent in probes.
381    ///
382    /// These are cleared when a new round starts, so any late-arriving PATH_RESPONSEs will
383    /// have no effect.
384    ///
385    /// They are stored in the usual socket-native form.
386    sent_challenges: FxHashMap<u64, IpPort>,
387    /// Queued probes to be sent in the next [`poll_transmit`] call.
388    ///
389    /// [`poll_transmit`]: crate::connection::Connection::poll_transmit
390    ///
391    /// They are stored in the usual socket-native form. Probes to address families not
392    /// addressable by the family are never inserted.
393    pending_probes: FxHashSet<IpPort>,
394    /// Network paths that were successfully probed but not yet opened.
395    ///
396    /// When we do not have enough CIDs or free path IDs we might not have been able to open
397    /// a new path. This allows us to try re-open the path when we get new CIDs or a new
398    /// MAX_PATH_ID.
399    // TODO(flub): perhaps there should be a time-limit on these?
400    paths_to_be_opened: Vec<FourTuple>,
401}
402
403impl ClientState {
404    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
405        Self {
406            max_remote_addresses,
407            max_local_addresses,
408            remote_addresses: Default::default(),
409            local_addresses: Default::default(),
410            round: Default::default(),
411            attempt: 0,
412            sent_challenges: Default::default(),
413            pending_probes: Default::default(),
414            paths_to_be_opened: Default::default(),
415        }
416    }
417
418    fn add_local_address(&mut self, address: SocketAddr) -> Result<(), Error> {
419        let address = CanonicalIpPort::from(address);
420        if self.local_addresses.len() < self.max_local_addresses {
421            self.local_addresses.insert(address);
422            Ok(())
423        } else if self.local_addresses.contains(&address) {
424            // at capacity, but the address is known, no issues here
425            Ok(())
426        } else {
427            // at capacity and the address is new
428            Err(Error::TooManyAddresses)
429        }
430    }
431
432    fn remove_local_address(&mut self, address: &IpPort) {
433        let address = CanonicalIpPort::from(*address);
434        self.local_addresses.remove(&address);
435    }
436
437    /// Initiates a new nat traversal round.
438    ///
439    /// A nat traversal round involves advertising the client's local addresses in
440    /// `REACH_OUT` frames, and initiating probing of the known remote addresses. When a new
441    /// round is initiated, the previous one is cancelled.
442    ///
443    /// `ipv6` indicates if the connection runs on a socket that supports IPv6. If so, then
444    /// all addresses returned [`PendingReachOutFrames`] will be IPv6 addresses (and
445    /// IPv4-mapped IPv6 addresses if necessary). Otherwise they're all IPv4 addresses.  See
446    /// also [`map_to_local_socket_family`].
447    ///
448    /// # Returns
449    ///
450    /// The REACH_OUT frames that need to be sent to the peer and the probed addresses. The
451    /// probed addresses are only informational, the pending probes are stored in
452    /// [`Self::pending_probes`].
453    ///
454    /// If the probed addresses are non-empty the `NatTraversalProbeRetry` timer needs to be
455    /// set.
456    pub(crate) fn initiate_nat_traversal_round(
457        &mut self,
458        ipv6: bool,
459    ) -> Result<(PendingReachOutFrames, Vec<SocketAddr>), Error> {
460        if self.local_addresses.is_empty() {
461            return Err(Error::NotEnoughAddresses);
462        }
463
464        self.round = self.round.saturating_add(1u8);
465        self.attempt = 0;
466        self.sent_challenges.clear();
467        self.pending_probes.clear();
468
469        // Enqueue the NAT probes to known remote addresses.
470        self.remote_addresses
471            .values_mut()
472            .for_each(|(ip_port, state)| {
473                if let Some(ip_port) = ip_port.as_local_socket_family(ipv6) {
474                    self.pending_probes.insert(ip_port);
475                    *state = ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS - 1);
476                } else {
477                    trace!(%ip_port, "not using IPv6 NAT candidate for IPv4 socket");
478                    *state = ProbeState::Active(0);
479                }
480            });
481        let probed_addrs: Vec<SocketAddr> = self
482            .pending_probes
483            .iter()
484            .copied()
485            .map(Into::into)
486            .collect();
487
488        // Build the REACH_OUT frames.
489        let reach_out_frames: PendingReachOutFrames = self
490            .local_addresses
491            .iter()
492            .map(|ip_port| ReachOut {
493                round: self.round,
494                ip: ip_port.ip(),
495                port: ip_port.port(),
496            })
497            .collect();
498
499        trace!(
500            round = %self.round,
501            reach_out = %reach_out_frames.len(),
502            to_probe = %self.pending_probes.len(),
503            "initiating NAT traversal round",
504        );
505        Ok((reach_out_frames, probed_addrs))
506    }
507
508    /// Re-queues probes that have not yet succeeded or reached 0 remaining retries.
509    ///
510    /// Returns whether any probes are now queued to send. In this case the
511    /// `NatTraversalProbeRetry` timer needs to be reset.
512    ///
513    /// `ipv6` as for [`Self::initiate_nat_traversal_round`].
514    pub(crate) fn queue_retries(&mut self, ipv6: bool) {
515        self.attempt += 1;
516        self.remote_addresses
517            .values_mut()
518            .for_each(|(ip_port, state)| match state {
519                ProbeState::Active(remaining) if *remaining > 0 => {
520                    *remaining -= 1;
521                    if let Some(ip_port) = ip_port.as_local_socket_family(ipv6) {
522                        self.pending_probes.insert(ip_port);
523                    } else {
524                        trace!(%ip_port, "skipping IPv6 NAT candidate for IPv4 socket");
525                        *remaining = 0;
526                    }
527                }
528                ProbeState::Active(_) | ProbeState::Succeeded => {}
529            });
530    }
531
532    /// Returns the next ready probe's address.
533    ///
534    /// If this is actually sent you must call [`Self::mark_probe_sent`].
535    fn next_probe_addr(&self) -> Option<IpPort> {
536        self.pending_probes.iter().next().copied()
537    }
538
539    /// Marks a probe as sent to the address with the challenge.
540    fn mark_probe_sent(&mut self, remote: IpPort, challenge: u64) {
541        self.pending_probes.remove(&remote);
542        self.sent_challenges.insert(challenge, remote);
543    }
544
545    /// Adds an address to the remote set.
546    ///
547    /// On success returns the address if it was new to the set. It will error when the set
548    /// has no capacity for the address.
549    ///
550    /// If this is called while a round is in progress this will effectively add the address
551    /// to the current round. There is no guarantee however that the current round is still
552    /// in progress however, if the last [`Self::queue_retries`] call returned `false` the
553    /// round has stopped.
554    // TODO(flub): probably should add an event to signal that the round is finished, so
555    //    that the application knows to start a new round.
556    pub(crate) fn add_remote_address(
557        &mut self,
558        add_addr: AddAddress,
559    ) -> Result<Option<SocketAddr>, Error> {
560        let AddAddress { seq_no, ip, port } = add_addr;
561        let address = CanonicalIpPort::from((ip, port));
562        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
563        match self.remote_addresses.entry(seq_no) {
564            Entry::Occupied(mut occupied_entry) => {
565                let is_update = occupied_entry.get().0 != address;
566                if is_update {
567                    occupied_entry.insert((address, ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS)));
568                }
569                // The value might be different. This should not happen, but we assume that the new
570                // address is more recent than the previous, and thus worth updating
571                Ok(is_update.then_some(address.as_canonical_addr()))
572            }
573            Entry::Vacant(vacant_entry) if allow_new => {
574                vacant_entry.insert((address, ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS)));
575                Ok(Some(address.as_canonical_addr()))
576            }
577            _ => Err(Error::TooManyAddresses),
578        }
579    }
580
581    /// Removes an address from the remote set.
582    ///
583    /// Returns whether the address was present.
584    pub(crate) fn remove_remote_address(
585        &mut self,
586        remove_addr: RemoveAddress,
587    ) -> Option<SocketAddr> {
588        self.remote_addresses
589            .remove(&remove_addr.seq_no)
590            .map(|(address, _)| address.as_canonical_addr())
591    }
592
593    /// Checks that a received remote address is valid.
594    ///
595    /// An address is valid as long as it does not change the value of a known address id.
596    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
597        match self.remote_addresses.get(&add_addr.seq_no) {
598            None => true,
599            Some((existing, _)) => *existing == CanonicalIpPort::from(add_addr.ip_port()),
600        }
601    }
602
603    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
604        self.remote_addresses
605            .values()
606            .map(|(address, _)| (*address).as_canonical_addr())
607            .collect()
608    }
609
610    /// Marks a remote as successful if the response matches a sent probe.
611    ///
612    /// Returns `true` if it was a response to one of the NAT traversal probes. In that case
613    /// [`Self::pop_pending_path_open`] should be called to open the next path.
614    fn handle_path_response(&mut self, network_path: FourTuple, challenge: u64) -> bool {
615        if let Entry::Occupied(entry) = self.sent_challenges.entry(challenge) {
616            let remote = (network_path.remote().ip(), network_path.remote().port());
617            if *entry.get() == remote {
618                entry.remove();
619                trace!(
620                    ?network_path,
621                    challenge = %display(format_args!("0x{challenge:x}")),
622                    "Received valid NAT traversal probe response",
623                );
624                self.paths_to_be_opened.push(network_path);
625
626                // TODO: linear search is sad.
627                let remote = CanonicalIpPort::from(remote);
628                if let Some(seq) = self
629                    .remote_addresses
630                    .iter()
631                    .filter_map(
632                        |(seq, (addr, _))| {
633                            if *addr == remote { Some(*seq) } else { None }
634                        },
635                    )
636                    .next()
637                {
638                    // Stop probing this remote address.
639                    self.remote_addresses
640                        .insert(seq, (remote, ProbeState::Succeeded));
641                } else {
642                    // Nothing to stop probing, the remote was only challenged because a
643                    // PATH_RESPONSE was being sent to it. These are not retried locally
644                    // since the peer is responsible for retrying the challenges until it
645                    // receives a response, at which time the local challenge is delivered.
646                    trace!("probe opened un-advertised address, peer likely behind DEDN");
647                }
648                return true;
649            } else {
650                debug!(
651                    ?network_path.remote,
652                    expected_remote = ?entry.get(),
653                    challenge = %display(format_args!("0x{challenge:x}")),
654                    "PATH_RESPONSE matched a NAT traversal probe but mismatching addr",
655                )
656            }
657        }
658        false
659    }
660
661    /// Returns a path that was NAT traversed and needs to be opened.
662    pub(crate) fn pop_pending_path_open(&mut self) -> Option<FourTuple> {
663        self.paths_to_be_opened.pop()
664    }
665
666    /// Put back a path that needs to be opened, e.g. for a temporary failure.
667    pub(crate) fn push_pending_path_open(&mut self, network_path: FourTuple) {
668        self.paths_to_be_opened.push(network_path)
669    }
670}
671
/// Progress of the off-path NAT traversal probing of one remote address.
#[derive(Debug)]
enum ProbeState {
    /// Probing of the remote is still in progress for this round.
    ///
    /// The `u8` holds the number of retries that remain.
    Active(u8),
    /// A probe response was received for this remote.
    Succeeded,
}
682
683impl ProbeState {
684    /// Returns the remaining number of probes to try for this remote.
685    fn remaining(&self) -> u8 {
686        match self {
687            Self::Active(remaining) => *remaining,
688            Self::Succeeded => 0,
689        }
690    }
691}
692
693#[derive(Debug)]
694pub(crate) struct ServerState {
695    /// Max number of remote addresses we allow.
696    ///
697    /// This is set by the local endpoint.
698    max_remote_addresses: usize,
699    /// Max number of local addresses allowed.
700    ///
701    /// This is set by the remote endpoint.
702    max_local_addresses: usize,
703    /// Candidate addresses the server reports as potentially reachable, to use for nat
704    /// traversal attempts.
705    ///
706    /// They are stored in **canonical form**, not in socket-native form as usual. We may
707    /// nave a reflexive address that is IPv6 even if our local socket can only handle IPv4.
708    local_addresses: FxHashMap<CanonicalIpPort, VarInt>,
709    /// The next id to use for local addresses sent to the client.
710    next_local_addr_id: VarInt,
711    /// Current nat traversal round
712    ///
713    /// Servers keep track of the client's most recent round and cancel probing related to previous
714    /// rounds.
715    round: VarInt,
716    /// The probing attempt in the round.
717    ///
718    /// Probes are sent to all remotes at the same time in a round, at intervals from
719    /// [`State::retry_delay`]. This is the number of times probes have been sent.
720    attempt: u8,
721    /// The remote addresses participating in this round.
722    ///
723    /// The set is cleared when a new round starts.
724    ///
725    /// These are stored in the usual local-socket native form.
726    remotes: FxHashMap<IpPort, ProbeState>,
727    /// The data of PATH_CHALLENGE frames sent in probes.
728    ///
729    /// These are cleared when a new round starts, so any late-arriving PATH_RESPONSEs will
730    /// have no effect.
731    sent_challenges: FxHashMap<u64, IpPort>,
732    /// Queued probes to be sent in the next [`poll_transmit`] call.
733    ///
734    /// At the beginning of a round this is populated from REACH_OUT frames and at every
735    /// retry this is populated from [`Self::remotes`].
736    ///
737    /// [`poll_transmit`]: crate::connection::Connection::poll_transmit
738    pending_probes: FxHashSet<IpPort>,
739}
740
741impl ServerState {
742    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
743        Self {
744            max_remote_addresses,
745            max_local_addresses,
746            local_addresses: Default::default(),
747            next_local_addr_id: Default::default(),
748            round: Default::default(),
749            attempt: 0,
750            remotes: Default::default(),
751            sent_challenges: Default::default(),
752            pending_probes: Default::default(),
753        }
754    }
755
756    fn add_local_address(&mut self, address: SocketAddr) -> Result<Option<AddAddress>, Error> {
757        let address = CanonicalIpPort::from(address);
758        let allow_new = self.local_addresses.len() < self.max_local_addresses;
759        match self.local_addresses.entry(address) {
760            Entry::Occupied(_) => Ok(None),
761            Entry::Vacant(vacant_entry) if allow_new => {
762                let id = self.next_local_addr_id;
763                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
764                vacant_entry.insert(id);
765                Ok(Some(AddAddress::new((address.ip(), address.port()), id)))
766            }
767            _ => Err(Error::TooManyAddresses),
768        }
769    }
770
771    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
772        let address = CanonicalIpPort::from(*address);
773        self.local_addresses
774            .remove(&address)
775            .map(RemoveAddress::new)
776    }
777
778    /// Returns the current NAT traversal round number.
779    pub(crate) fn current_round(&self) -> VarInt {
780        self.round
781    }
782
783    /// Handles a received REACH_OUT frame.
784    ///
785    /// This might ignore the reach out frame if it belongs to an older round or if the
786    /// frame contains an IPv6 address while the local socket is IPv4-only.
787    ///
788    /// If a new round was started, the `NatTraversalProbeRetry` timer needs to be reset.
789    pub(crate) fn handle_reach_out(
790        &mut self,
791        reach_out: ReachOut,
792        ipv6: bool,
793    ) -> Result<(), Error> {
794        let ReachOut { round, ip, port } = reach_out;
795
796        if round < self.round {
797            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
798            return Ok(());
799        }
800        let Some(ip) = map_to_local_socket_family(ip, ipv6) else {
801            trace!("Ignoring IPv6 REACH_OUT frame due to not supporting IPv6 locally");
802            return Ok(());
803        };
804
805        if round > self.round {
806            self.round = round;
807            self.attempt = 0;
808            self.remotes.clear();
809            self.sent_challenges.clear();
810            self.pending_probes.clear();
811        } else if self.remotes.contains_key(&(ip, port)) {
812            // Retransmitted frame.
813            return Ok(());
814        } else if self.remotes.len() >= self.max_remote_addresses {
815            return Err(Error::TooManyAddresses);
816        }
817        self.remotes
818            .entry((ip, port))
819            .or_insert(ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS - 1));
820        self.pending_probes.insert((ip, port));
821        Ok(())
822    }
823
824    /// Re-queues probes that have not yet succeeded or reached [`MAX_NAT_PROBE_ATTEMPTS`].
825    ///
826    /// Returns whether any probes are now queued to send. In this case the
827    /// `NatTraversalProbeRetry` timer needs to be reset.
828    pub(crate) fn queue_retries(&mut self) {
829        self.attempt += 1;
830        self.remotes
831            .iter_mut()
832            .for_each(|(remote, state)| match state {
833                ProbeState::Active(remaining) if *remaining > 0 => {
834                    *remaining -= 1;
835                    self.pending_probes.insert(*remote);
836                }
837                ProbeState::Active(_) | ProbeState::Succeeded => (),
838            });
839    }
840
841    /// Returns the next ready probe's address.
842    ///
843    /// If this is actually sent you must call [`Self::mark_probe_sent`].
844    fn next_probe_addr(&self) -> Option<IpPort> {
845        self.pending_probes.iter().next().cloned()
846    }
847
848    /// Marks a probe as sent to the address with the challenge.
849    fn mark_probe_sent(&mut self, remote: IpPort, challenge: u64) {
850        self.pending_probes.remove(&remote);
851        self.sent_challenges.insert(challenge, remote);
852    }
853
854    /// Marks a remote as successful if the response matches a sent probe.
855    ///
856    /// Returns `true` if it was a response to one of the NAT traversal probes.
857    fn handle_path_response(&mut self, src: FourTuple, challenge: u64) -> bool {
858        if let Entry::Occupied(entry) = self.sent_challenges.entry(challenge) {
859            let remote = (src.remote().ip(), src.remote().port());
860            if *entry.get() == remote {
861                entry.remove();
862                self.remotes.insert(remote, ProbeState::Succeeded);
863                return true;
864            } else {
865                debug!(
866                    ?challenge,
867                    ?src.remote,
868                    "PATH_RESPONSE matched a NAT traversal probe but mismatching addr",
869                )
870            }
871        }
872        false
873    }
874}
875
/// Maps an IP address into the address family the local socket operates in.
///
/// `ipv6` states whether the local socket supports IPv6:
/// - With IPv6 support, IPv4 addresses become IPv6-mapped IPv4 addresses and IPv6
///   addresses are returned unchanged.
/// - Without IPv6 support, IPv4 addresses are returned unchanged and IPv6-mapped IPv4
///   addresses are unwrapped to plain IPv4. Any other IPv6 address cannot be used with
///   the local socket, in which case this returns `None`.
pub(crate) fn map_to_local_socket_family(address: IpAddr, ipv6: bool) -> Option<IpAddr> {
    match (address, ipv6) {
        (IpAddr::V4(v4), true) => Some(IpAddr::V6(v4.to_ipv6_mapped())),
        // Already in the socket's native family.
        (IpAddr::V4(_), false) | (IpAddr::V6(_), true) => Some(address),
        (IpAddr::V6(v6), false) => v6.to_ipv4_mapped().map(IpAddr::V4),
    }
}
891
892#[cfg(test)]
893mod tests {
894    use testresult::TestResult;
895
896    use super::*;
897
898    #[test]
899    fn test_basic_server_state() {
900        let mut state = ServerState::new(2, 2);
901
902        state
903            .handle_reach_out(
904                ReachOut {
905                    round: 1u32.into(),
906                    ip: std::net::Ipv4Addr::LOCALHOST.into(),
907                    port: 1,
908                },
909                true,
910            )
911            .unwrap();
912
913        state
914            .handle_reach_out(
915                ReachOut {
916                    round: 1u32.into(),
917                    ip: "1.1.1.1".parse().unwrap(), //std::net::Ipv4Addr::LOCALHOST.into(),
918                    port: 2,
919                },
920                true,
921            )
922            .unwrap();
923
924        dbg!(&state);
925        assert_eq!(state.pending_probes.len(), 2);
926
927        // Helper: send next ready probe
928        let mut challenge = 0;
929        let mut send_probe = |state: &mut ServerState| {
930            let remote = state.next_probe_addr().unwrap();
931            challenge += 1;
932            state.mark_probe_sent(remote, challenge);
933        };
934
935        send_probe(&mut state);
936        send_probe(&mut state);
937
938        // After sending both probes, no ready probes remain but they're still tracked.
939        assert!(state.next_probe_addr().is_none());
940
941        // After queuing retries, probes become available again
942        state.queue_retries();
943        send_probe(&mut state);
944        send_probe(&mut state);
945
946        // After 2 attempts each, retries still available (max is 10)
947        state.queue_retries();
948        send_probe(&mut state);
949        send_probe(&mut state);
950
951        // Exhaust remaining attempts
952        for _ in 3..MAX_NAT_PROBE_ATTEMPTS {
953            state.queue_retries();
954            send_probe(&mut state);
955            send_probe(&mut state);
956        }
957
958        // After max attempts, probes are removed
959        state.queue_retries();
960        assert!(state.next_probe_addr().is_none());
961    }
962
963    #[test]
964    fn test_map_to_local_socket() {
965        assert_eq!(
966            map_to_local_socket_family("1.1.1.1".parse().unwrap(), false),
967            Some("1.1.1.1".parse().unwrap())
968        );
969        assert_eq!(
970            map_to_local_socket_family("1.1.1.1".parse().unwrap(), true),
971            Some("::ffff:1.1.1.1".parse().unwrap())
972        );
973        assert_eq!(
974            map_to_local_socket_family("::1".parse().unwrap(), true),
975            Some("::1".parse().unwrap())
976        );
977        assert_eq!(
978            map_to_local_socket_family("::1".parse().unwrap(), false),
979            None
980        );
981        assert_eq!(
982            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), false),
983            Some("1.1.1.1".parse().unwrap())
984        )
985    }
986
987    #[test]
988    fn test_retry_delay_server_ipv6() -> TestResult {
989        let initial_rtt = Duration::from_millis(333);
990        let ipv6 = true;
991        let remote = SocketAddr::from(("::2".parse::<IpAddr>()?, 2));
992        let remote_ipp = (remote.ip(), remote.port());
993
994        let mut nat = State::new(8, 8, Side::Server);
995
996        nat.server_side_mut()?.handle_reach_out(
997            ReachOut {
998                round: 1u8.into(),
999                ip: remote.ip(),
1000                port: remote.port(),
1001            },
1002            ipv6,
1003        )?;
1004
1005        let challenges = [1u64, 2, 3, 4, 5, 6, 7];
1006        let delays = [
1007            33_300u64, 66_600, 133_200, 266_400, 532_800, 1_065_600, 2_000_000,
1008        ];
1009        for (challenge, delay) in challenges.into_iter().zip(delays) {
1010            nat.queue_retries(ipv6);
1011            assert_eq!(nat.next_probe_addr(), Some(remote_ipp));
1012            nat.mark_probe_sent(remote_ipp, challenge);
1013            assert_eq!(
1014                nat.retry_delay(initial_rtt),
1015                Some(Duration::from_micros(delay)),
1016                "challenge: {challenge}"
1017            );
1018        }
1019
1020        assert!(nat.handle_path_response(
1021            FourTuple {
1022                remote,
1023                local_ip: Some("::3".parse::<IpAddr>()?),
1024            },
1025            challenges[6]
1026        ));
1027        assert_eq!(nat.retry_delay(initial_rtt), None);
1028
1029        Ok(())
1030    }
1031
1032    #[test]
1033    fn test_retry_delay_client_ipv6() -> TestResult {
1034        let initial_rtt = Duration::from_millis(333);
1035        let ipv6 = true;
1036        let remote = SocketAddr::from(("::2".parse::<IpAddr>()?, 2));
1037        let remote_ipp = (remote.ip(), remote.port());
1038        let local_addr = SocketAddr::from(("::3".parse::<IpAddr>()?, 3));
1039
1040        let mut nat = State::new(8, 8, Side::Client);
1041        nat.add_local_address(local_addr)?;
1042        nat.client_side_mut()?.add_remote_address(AddAddress {
1043            seq_no: 1u8.into(),
1044            ip: remote.ip(),
1045            port: remote.port(),
1046        })?;
1047        nat.client_side_mut()?.initiate_nat_traversal_round(ipv6)?;
1048
1049        let challenges = [1u64, 2, 3, 4, 5, 6, 7];
1050        let delays = [
1051            33_300u64, 66_600, 133_200, 266_400, 532_800, 1_065_600, 2_000_000,
1052        ];
1053        for (challenge, delay) in challenges.into_iter().zip(delays) {
1054            nat.queue_retries(ipv6);
1055            assert_eq!(nat.next_probe_addr(), Some(remote_ipp));
1056            nat.mark_probe_sent(remote_ipp, challenge);
1057            assert_eq!(
1058                nat.retry_delay(initial_rtt),
1059                Some(Duration::from_micros(delay)),
1060                "challenge: {challenge}"
1061            );
1062        }
1063
1064        assert!(nat.handle_path_response(
1065            FourTuple {
1066                remote,
1067                local_ip: Some(local_addr.ip()),
1068            },
1069            challenges[6]
1070        ));
1071        assert_eq!(nat.retry_delay(initial_rtt), None);
1072
1073        Ok(())
1074    }
1075}