noq_proto/
n0_nat_traversal.rs

1//! n0's (<https://n0.computer>) NAT Traversal protocol implementation.
2
3use std::{
4    collections::hash_map::Entry,
5    fmt::Display,
6    net::{IpAddr, SocketAddr},
7    time::Duration,
8};
9
10use rustc_hash::{FxHashMap, FxHashSet};
11use tracing::{debug, trace};
12
13use crate::{
14    FourTuple, Side, VarInt,
15    connection::spaces::PendingReachOutFrames,
16    frame::{AddAddress, ReachOut, RemoveAddress},
17};
18
/// Maximum number of times we send a NAT probe to the same remote address in a round.
///
/// One probe is queued when the round starts (or when the remote is first probed) and the
/// remaining ones are retries queued by `queue_retries`.
///
/// This is a trade-off between several factors:
/// - Probe packets could be lost. This allows recovery.
/// - The NAT firewall may only let a probe through on a second attempt.
/// - We may be sending probes to innocent bystanders on the internet.
/// - A round never "finishes": probing of remotes only stops when:
///   1. A new round is started.
///   2. A probe was successful.
///   3. This number of attempts is exhausted.
///
/// See [`State::retry_delay`] for the capped exponential backoff used. With the default
/// `initial_rtt` this spreads the probes over roughly 4s.
pub(crate) const MAX_NAT_PROBE_ATTEMPTS: u8 = 9;
33
/// An IP & port.
///
/// Invariant: This value should always be in the IP family that the local
/// socket operates in.
/// E.g. if the local socket is IPv4-only, then all `IpPort`s should only have
/// IPv4 addresses, and if the socket supports IPv6, then all `IpPort`s
/// should be IPv6 addresses or IPv4-mapped IPv6 addresses.
///
/// The invariant is not enforced by the type: this is a plain tuple alias, so producers
/// must go through the conversion below.
///
/// See also [`map_to_local_socket_family`], which powers this conversion.
type IpPort = (IpAddr, u16);
44
/// An IP & port in canonical form.
///
/// The IP is canonicalized with [`IpAddr::to_canonical`], so IPv4-mapped IPv6 addresses
/// are stored as plain IPv4 addresses.
/// This is the primary type used to send IP addresses to the remote
/// and the primary type used to canonicalize received addresses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct CanonicalIpPort {
    /// The IP address, in canonical form.
    canonical_ip: IpAddr,
    /// The port.
    port: u16,
}
55
56impl CanonicalIpPort {
57    pub(crate) fn ip(&self) -> IpAddr {
58        self.canonical_ip
59    }
60
61    pub(crate) fn port(&self) -> u16 {
62        self.port
63    }
64
65    /// Converts this into a local-socket-family-mapped IP & port.
66    ///
67    /// Instead of using ipv4 and ipv6 addresses, this tries to match `ipv6`, which
68    /// should indicate whether the local socket supports ipv6 or not.
69    ///
70    /// If ipv6 is supported, all ipv4 addresses are mapped using ipv6-mapped ipv4
71    /// addresses.
72    /// If ipv6 is not supported, then this returns `None` for ipv6 addresses.
73    ///
74    /// See also [`map_to_local_socket_family`].
75    pub(crate) fn as_local_socket_family(&self, ipv6: bool) -> Option<IpPort> {
76        Some((
77            map_to_local_socket_family(self.canonical_ip, ipv6)?,
78            self.port,
79        ))
80    }
81
82    /// Returns this address as-is with the canonical IP used in a `SocketAddr`.
83    pub(crate) fn as_canonical_addr(&self) -> SocketAddr {
84        SocketAddr::new(self.canonical_ip, self.port)
85    }
86}
87
88impl Display for CanonicalIpPort {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        self.as_canonical_addr().fmt(f)
91    }
92}
93
94impl From<SocketAddr> for CanonicalIpPort {
95    fn from(addr: SocketAddr) -> Self {
96        Self {
97            canonical_ip: addr.ip().to_canonical(),
98            port: addr.port(),
99        }
100    }
101}
102
103impl From<IpPort> for CanonicalIpPort {
104    fn from((ip, port): IpPort) -> Self {
105        Self {
106            canonical_ip: ip.to_canonical(),
107            port,
108        }
109    }
110}
111
/// Errors that the nat traversal state might encounter.
///
/// Returned by the NAT traversal state operations of a connection.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An endpoint (local or remote) tried to add too many addresses to their advertised
    /// set, i.e. more than the limit negotiated for that direction.
    #[error("Tried to add too many addresses to their advertised set")]
    TooManyAddresses,
    /// The operation is not allowed for this endpoint's connection side
    /// (e.g. a client-only operation on the server).
    #[error("Not allowed for this endpoint's connection side")]
    WrongConnectionSide,
    /// The extension was not negotiated
    #[error("n0's nat traversal was not negotiated")]
    ExtensionNotNegotiated,
    /// Not enough addresses to complete the operation, e.g. starting a round without any
    /// local addresses to advertise.
    #[error("Not enough addresses")]
    NotEnoughAddresses,
    /// Nat traversal attempt failed due to a multipath error; wraps the underlying error.
    #[error("Failed to establish paths {0}")]
    Multipath(super::PathError),
    /// Attempted to initiate NAT traversal on a closed, or closing connection.
    #[error("The connection is already closed")]
    Closed,
}
134
/// Event emitted when the client receives ADD_ADDRESS or REMOVE_ADDRESS frames.
// NOTE(review): these events are built outside this section; presumably the address is the
// canonical form returned by `ClientState::add_remote_address` / `remove_remote_address` —
// confirm at the construction site.
#[derive(Debug, Clone)]
pub enum Event {
    /// An ADD_ADDRESS frame was received, advertising this address.
    AddressAdded(SocketAddr),
    /// A REMOVE_ADDRESS frame was received for this previously advertised address.
    AddressRemoved(SocketAddr),
}
143
/// State kept for n0's nat traversal.
///
/// The active variant depends on whether the extension was negotiated and, if so, on the
/// connection's [`Side`] (see [`State::new`]).
#[derive(Debug, Default)]
pub(crate) enum State {
    /// The extension was not negotiated: operations either report
    /// [`Error::ExtensionNotNegotiated`] or do nothing.
    #[default]
    NotNegotiated,
    /// This endpoint is the client of the connection.
    ClientSide(ClientState),
    /// This endpoint is the server of the connection.
    ServerSide(ServerState),
}
152
153impl State {
154    pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self {
155        match side {
156            Side::Client => Self::ClientSide(ClientState::new(
157                max_remote_addresses.into(),
158                max_local_addresses.into(),
159            )),
160            Side::Server => Self::ServerSide(ServerState::new(
161                max_remote_addresses.into(),
162                max_local_addresses.into(),
163            )),
164        }
165    }
166
167    pub(crate) fn is_negotiated(&self) -> bool {
168        match self {
169            Self::NotNegotiated => false,
170            Self::ClientSide(_) | Self::ServerSide(_) => true,
171        }
172    }
173
174    pub(crate) fn client_side(&self) -> Result<&ClientState, Error> {
175        match self {
176            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
177            Self::ClientSide(client_side) => Ok(client_side),
178            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
179        }
180    }
181
182    pub(crate) fn client_side_mut(&mut self) -> Result<&mut ClientState, Error> {
183        match self {
184            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
185            Self::ClientSide(client_side) => Ok(client_side),
186            Self::ServerSide(_) => Err(Error::WrongConnectionSide),
187        }
188    }
189
190    pub(crate) fn server_side_mut(&mut self) -> Result<&mut ServerState, Error> {
191        match self {
192            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
193            Self::ClientSide(_) => Err(Error::WrongConnectionSide),
194            Self::ServerSide(server_side) => Ok(server_side),
195        }
196    }
197
198    /// Adds a local address to use for nat traversal.
199    ///
200    /// When this endpoint is the server within the connection, these addresses will be sent to the
201    /// client in add address frames. For clients, these addresses will be sent in reach out frames
202    /// when nat traversal attempts are initiated.
203    ///
204    /// If a frame should be sent, it is returned.
205    pub(crate) fn add_local_address(
206        &mut self,
207        address: SocketAddr,
208    ) -> Result<Option<AddAddress>, Error> {
209        match self {
210            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
211            Self::ClientSide(client_state) => {
212                client_state.add_local_address(address)?;
213                Ok(None)
214            }
215            Self::ServerSide(server_state) => server_state.add_local_address(address),
216        }
217    }
218
219    /// Removes a local address from the advertised set for nat traversal.
220    ///
221    /// When this endpoint is the server, removed addresses must be reported with remove address
222    /// frames. Clients will simply stop reporting these addresses in reach out frames.
223    ///
224    /// If a frame should be sent, it is returned.
225    pub(crate) fn remove_local_address(
226        &mut self,
227        address: SocketAddr,
228    ) -> Result<Option<RemoveAddress>, Error> {
229        let address = IpPort::from((address.ip(), address.port()));
230        match self {
231            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
232            Self::ClientSide(client_state) => {
233                client_state.remove_local_address(&address);
234                Ok(None)
235            }
236            Self::ServerSide(server_state) => Ok(server_state.remove_local_address(&address)),
237        }
238    }
239
240    pub(crate) fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, Error> {
241        match self {
242            Self::NotNegotiated => Err(Error::ExtensionNotNegotiated),
243            Self::ClientSide(client_state) => Ok(client_state
244                .local_addresses
245                .iter()
246                .map(CanonicalIpPort::as_canonical_addr)
247                .collect()),
248            Self::ServerSide(server_state) => Ok(server_state
249                .local_addresses
250                .keys()
251                .map(CanonicalIpPort::as_canonical_addr)
252                .collect()),
253        }
254    }
255
256    /// Returns the next ready probe's address.
257    ///
258    /// If this is actually sent you must call [`Self::mark_probe_sent`].
259    pub(crate) fn next_probe_addr(&self) -> Option<IpPort> {
260        match self {
261            Self::NotNegotiated => None,
262            Self::ClientSide(state) => state.next_probe_addr(),
263            Self::ServerSide(state) => state.next_probe_addr(),
264        }
265    }
266
267    /// Marks a probe as sent to the address with the challenge.
268    pub(crate) fn mark_probe_sent(&mut self, remote: IpPort, challenge: u64) {
269        match self {
270            Self::NotNegotiated => (),
271            Self::ClientSide(state) => state.mark_probe_sent(remote, challenge),
272            Self::ServerSide(state) => state.mark_probe_sent(remote, challenge),
273        }
274    }
275
276    /// Re-queues probes that have not yet succeeded or reached 0 remaining retries.
277    ///
278    /// After calling [`Self::retry_delay`] must be checked.
279    pub(crate) fn queue_retries(&mut self, ipv6: bool) {
280        match self {
281            Self::NotNegotiated => (),
282            Self::ClientSide(state) => state.queue_retries(ipv6),
283            Self::ServerSide(state) => state.queue_retries(),
284        };
285    }
286
287    /// Marks a remote as successful if the response matches a sent probe.
288    ///
289    /// Returns true if it was a response to one of the NAT traversal probes and a path
290    /// needs to be opened. Note that the NAT probes are not padded to 1200 bytes so only
291    /// the address is validated, but not the entire path.
292    pub(crate) fn handle_path_response(&mut self, src: FourTuple, challenge: u64) -> bool {
293        match self {
294            Self::NotNegotiated => false,
295            Self::ClientSide(state) => state.handle_path_response(src, challenge),
296            Self::ServerSide(state) => state.handle_path_response(src, challenge),
297        }
298    }
299
300    /// Returns the delay to arm the `NatTraversalProbeRetry` timer.
301    ///
302    /// `initial_rtt` must be [`TransportConfig::initial_rtt`] so retries are scaled to this
303    /// value.
304    ///
305    /// [`TransportConfig::initial_rtt`]: crate::TransportConfig::initial_rtt
306    pub(crate) fn retry_delay(&self, initial_rtt: Duration) -> Option<Duration> {
307        match self {
308            Self::NotNegotiated => return None,
309            Self::ClientSide(state) => {
310                if !state
311                    .remote_addresses
312                    .values()
313                    .any(|(_, probes)| probes.remaining() > 0)
314                {
315                    return None;
316                }
317            }
318            Self::ServerSide(state) => {
319                if !state.remotes.values().any(|probes| probes.remaining() > 0) {
320                    return None;
321                }
322            }
323        }
324
325        let attempt = match self {
326            Self::NotNegotiated => return None,
327            Self::ClientSide(state) => state.attempt,
328            Self::ServerSide(state) => state.attempt,
329        };
330
331        // Retries follow at an exponential backoff, capped at max 2s interval. The base
332        // delay is initial_rtt/10, which for the default value means 33.3ms. Just under
333        // 10_000 km at the speed of light.
334        const MAX_BACKOFF_EXPONENT: u8 = 8;
335        const MAX_INTERVAL: Duration = Duration::from_secs(2);
336        let base = initial_rtt / 10;
337        let attempt = attempt.min(MAX_BACKOFF_EXPONENT) as u32;
338        let interval = match attempt {
339            0 => base * 2u32.pow(attempt),
340            _ => base * 2u32.pow(attempt) - base * 2u32.pow(attempt - 1),
341        };
342        Some(interval.min(MAX_INTERVAL))
343    }
344}
345
/// Client-side state of n0's NAT traversal.
///
/// The client learns the server's candidate addresses from ADD_ADDRESS / REMOVE_ADDRESS
/// frames, advertises its own candidates in REACH_OUT frames, and probes the server's
/// candidates once per round.
#[derive(Debug)]
pub(crate) struct ClientState {
    /// Max number of remote addresses we allow
    ///
    /// This is set by the local endpoint.
    max_remote_addresses: usize,
    /// Max number of local addresses allowed
    ///
    /// This is set by the remote endpoint.
    max_local_addresses: usize,
    /// Candidate addresses the remote endpoint advertises, with their probing progress.
    ///
    /// These are addresses on which the server is potentially reachable, to use for NAT
    /// traversal attempts.
    ///
    /// They are indexed by their ADD_ADDRESS sequence number and stored in **canonical
    /// form** rather than the usual socket-native form. They are kept even when the local
    /// socket cannot reach them (e.g. IPv6 candidates on an IPv4-only socket) so that the
    /// sequence numbers stay correct.
    remote_addresses: FxHashMap<VarInt, (CanonicalIpPort, ProbeState)>,
    /// Candidate addresses for the local endpoint.
    ///
    /// These are addresses on which we are potentially reachable, to use for NAT traversal
    /// attempts.
    ///
    /// They are stored in **canonical form**, not in socket-native form as usual. We may
    /// have a reflexive address that is IPv6 even if our local socket can only handle IPv4.
    local_addresses: FxHashSet<CanonicalIpPort>,
    /// Current nat traversal round.
    round: VarInt,
    /// The probing attempt in the round.
    ///
    /// Probes are sent to all remotes at the same time in a round, at intervals from
    /// [`State::retry_delay`]. This is the number of times probes have been sent.
    attempt: u8,
    /// The data of PATH_CHALLENGE frames sent in probes, mapped to the probed address.
    ///
    /// These are cleared when a new round starts, so any late-arriving PATH_RESPONSEs will
    /// have no effect.
    ///
    /// They are stored in the usual socket-native form.
    sent_challenges: FxHashMap<u64, IpPort>,
    /// Queued probes to be sent in the next [`poll_transmit`] call.
    ///
    /// [`poll_transmit`]: crate::connection::Connection::poll_transmit
    ///
    /// They are stored in the usual socket-native form. Probes to addresses that the local
    /// socket's address family cannot reach are never inserted.
    pending_probes: FxHashSet<IpPort>,
    /// Network paths that were successfully probed but not yet opened.
    ///
    /// When we do not have enough CIDs or free path IDs we might not have been able to open
    /// a new path. This allows us to try re-open the path when we get new CIDs or a new
    /// MAX_PATH_ID.
    // TODO(flub): perhaps there should be a time-limit on these?
    paths_to_be_opened: Vec<FourTuple>,
}
402
403impl ClientState {
404    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
405        Self {
406            max_remote_addresses,
407            max_local_addresses,
408            remote_addresses: Default::default(),
409            local_addresses: Default::default(),
410            round: Default::default(),
411            attempt: 0,
412            sent_challenges: Default::default(),
413            pending_probes: Default::default(),
414            paths_to_be_opened: Default::default(),
415        }
416    }
417
418    fn add_local_address(&mut self, address: SocketAddr) -> Result<(), Error> {
419        let address = CanonicalIpPort::from(address);
420        if self.local_addresses.len() < self.max_local_addresses {
421            self.local_addresses.insert(address);
422            Ok(())
423        } else if self.local_addresses.contains(&address) {
424            // at capacity, but the address is known, no issues here
425            Ok(())
426        } else {
427            // at capacity and the address is new
428            Err(Error::TooManyAddresses)
429        }
430    }
431
432    fn remove_local_address(&mut self, address: &IpPort) {
433        let address = CanonicalIpPort::from(*address);
434        self.local_addresses.remove(&address);
435    }
436
437    /// Initiates a new nat traversal round.
438    ///
439    /// A nat traversal round involves advertising the client's local addresses in
440    /// `REACH_OUT` frames, and initiating probing of the known remote addresses. When a new
441    /// round is initiated, the previous one is cancelled.
442    ///
443    /// `ipv6` indicates if the connection runs on a socket that supports IPv6. If so, then
444    /// all addresses returned [`PendingReachOutFrames`] will be IPv6 addresses (and
445    /// IPv4-mapped IPv6 addresses if necessary). Otherwise they're all IPv4 addresses.  See
446    /// also [`map_to_local_socket_family`].
447    ///
448    /// # Returns
449    ///
450    /// The REACH_OUT frames that need to be sent to the peer and the probed addresses. The
451    /// probed addresses are only informational, the pending probes are stored in
452    /// [`Self::pending_probes`].
453    ///
454    /// If the probed addresses are non-empty the `NatTraversalProbeRetry` timer needs to be
455    /// set.
456    pub(crate) fn initiate_nat_traversal_round(
457        &mut self,
458        ipv6: bool,
459    ) -> Result<(PendingReachOutFrames, Vec<SocketAddr>), Error> {
460        if self.local_addresses.is_empty() {
461            return Err(Error::NotEnoughAddresses);
462        }
463
464        self.round = self.round.saturating_add(1u8);
465        self.attempt = 0;
466        self.sent_challenges.clear();
467        self.pending_probes.clear();
468
469        // Enqueue the NAT probes to known remote addresses.
470        self.remote_addresses
471            .values_mut()
472            .for_each(|(ip_port, state)| {
473                if let Some(ip_port) = ip_port.as_local_socket_family(ipv6) {
474                    self.pending_probes.insert(ip_port);
475                    *state = ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS - 1);
476                } else {
477                    trace!(%ip_port, "not using IPv6 NAT candidate for IPv4 socket");
478                    *state = ProbeState::Active(0);
479                }
480            });
481        let probed_addrs: Vec<SocketAddr> = self
482            .pending_probes
483            .iter()
484            .copied()
485            .map(Into::into)
486            .collect();
487
488        // Build the REACH_OUT frames.
489        let reach_out_frames: PendingReachOutFrames = self
490            .local_addresses
491            .iter()
492            .map(|ip_port| ReachOut {
493                round: self.round,
494                ip: ip_port.ip(),
495                port: ip_port.port(),
496            })
497            .collect();
498
499        trace!(
500            round = %self.round,
501            reach_out = %reach_out_frames.len(),
502            to_probe = %self.pending_probes.len(),
503            "initiating NAT traversal round",
504        );
505        Ok((reach_out_frames, probed_addrs))
506    }
507
508    /// Re-queues probes that have not yet succeeded or reached 0 remaining retries.
509    ///
510    /// Returns whether any probes are now queued to send. In this case the
511    /// `NatTraversalProbeRetry` timer needs to be reset.
512    ///
513    /// `ipv6` as for [`Self::initiate_nat_traversal_round`].
514    pub(crate) fn queue_retries(&mut self, ipv6: bool) {
515        self.attempt += 1;
516        self.remote_addresses
517            .values_mut()
518            .for_each(|(ip_port, state)| match state {
519                ProbeState::Active(remaining) if *remaining > 0 => {
520                    *remaining -= 1;
521                    if let Some(ip_port) = ip_port.as_local_socket_family(ipv6) {
522                        self.pending_probes.insert(ip_port);
523                    } else {
524                        trace!(%ip_port, "skipping IPv6 NAT candidate for IPv4 socket");
525                        *remaining = 0;
526                    }
527                }
528                ProbeState::Active(_) | ProbeState::Succeeded => {}
529            });
530    }
531
532    /// Returns the next ready probe's address.
533    ///
534    /// If this is actually sent you must call [`Self::mark_probe_sent`].
535    fn next_probe_addr(&self) -> Option<IpPort> {
536        self.pending_probes.iter().next().copied()
537    }
538
539    /// Marks a probe as sent to the address with the challenge.
540    fn mark_probe_sent(&mut self, remote: IpPort, challenge: u64) {
541        self.pending_probes.remove(&remote);
542        self.sent_challenges.insert(challenge, remote);
543    }
544
545    /// Adds an address to the remote set.
546    ///
547    /// On success returns the address if it was new to the set. It will error when the set
548    /// has no capacity for the address.
549    ///
550    /// If this is called while a round is in progress this will effectively add the address
551    /// to the current round. There is no guarantee however that the current round is still
552    /// in progress however, if the last [`Self::queue_retries`] call returned `false` the
553    /// round has stopped.
554    // TODO(flub): probably should add an event to signal that the round is finished, so
555    //    that the application knows to start a new round.
556    pub(crate) fn add_remote_address(
557        &mut self,
558        add_addr: AddAddress,
559    ) -> Result<Option<SocketAddr>, Error> {
560        let AddAddress { seq_no, ip, port } = add_addr;
561        let address = CanonicalIpPort::from((ip, port));
562        let allow_new = self.remote_addresses.len() < self.max_remote_addresses;
563        match self.remote_addresses.entry(seq_no) {
564            Entry::Occupied(mut occupied_entry) => {
565                let is_update = occupied_entry.get().0 != address;
566                if is_update {
567                    occupied_entry.insert((address, ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS)));
568                }
569                // The value might be different. This should not happen, but we assume that the new
570                // address is more recent than the previous, and thus worth updating
571                Ok(is_update.then_some(address.as_canonical_addr()))
572            }
573            Entry::Vacant(vacant_entry) if allow_new => {
574                vacant_entry.insert((address, ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS)));
575                Ok(Some(address.as_canonical_addr()))
576            }
577            _ => Err(Error::TooManyAddresses),
578        }
579    }
580
581    /// Removes an address from the remote set.
582    ///
583    /// Returns whether the address was present.
584    pub(crate) fn remove_remote_address(
585        &mut self,
586        remove_addr: RemoveAddress,
587    ) -> Option<SocketAddr> {
588        self.remote_addresses
589            .remove(&remove_addr.seq_no)
590            .map(|(address, _)| address.as_canonical_addr())
591    }
592
593    /// Checks that a received remote address is valid.
594    ///
595    /// An address is valid as long as it does not change the value of a known address id.
596    pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool {
597        match self.remote_addresses.get(&add_addr.seq_no) {
598            None => true,
599            Some((existing, _)) => *existing == CanonicalIpPort::from(add_addr.ip_port()),
600        }
601    }
602
603    pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Vec<SocketAddr> {
604        self.remote_addresses
605            .values()
606            .map(|(address, _)| (*address).as_canonical_addr())
607            .collect()
608    }
609
610    /// Marks a remote as successful if the response matches a sent probe.
611    ///
612    /// Returns `true` if it was a response to one of the NAT traversal probes. In that case
613    /// [`Self::pop_pending_path_open`] should be called to open the next path.
614    fn handle_path_response(&mut self, network_path: FourTuple, challenge: u64) -> bool {
615        if let Entry::Occupied(entry) = self.sent_challenges.entry(challenge) {
616            let remote = (network_path.remote().ip(), network_path.remote().port());
617            if *entry.get() == remote {
618                entry.remove();
619
620                // self.remote_addresses is stored in canonical form.
621                let remote = CanonicalIpPort::from(remote);
622                // TODO: linear search is sad.
623                if let Some(seq) = self
624                    .remote_addresses
625                    .iter()
626                    .filter_map(
627                        |(seq, (addr, _))| {
628                            if *addr == remote { Some(*seq) } else { None }
629                        },
630                    )
631                    .next()
632                {
633                    trace!(
634                        ?network_path,
635                        challenge = %display(format_args!("0x{challenge:x}")),
636                        "Received valid NAT traversal probe response",
637                    );
638                    self.remote_addresses
639                        .insert(seq, (remote, ProbeState::Succeeded));
640                    self.paths_to_be_opened.push(network_path);
641                    return true;
642                } else {
643                    debug!("inconsistent remote addrs and seq");
644                }
645            } else {
646                debug!(
647                    ?network_path.remote,
648                    expected_remote = ?entry.get(),
649                    challenge = %display(format_args!("0x{challenge:x}")),
650                    "PATH_RESPONSE matched a NAT traversal probe but mismatching addr",
651                )
652            }
653        }
654        false
655    }
656
657    /// Returns a path that was NAT traversed and needs to be opened.
658    pub(crate) fn pop_pending_path_open(&mut self) -> Option<FourTuple> {
659        self.paths_to_be_opened.pop()
660    }
661
662    /// Put back a path that needs to be opened, e.g. for a temporary failure.
663    pub(crate) fn push_pending_path_open(&mut self, network_path: FourTuple) {
664        self.paths_to_be_opened.push(network_path)
665    }
666}
667
/// Progress of off-path NAT traversal probing towards a single remote address.
#[derive(Debug)]
enum ProbeState {
    /// The remote still needs to be probed in this round.
    ///
    /// Holds how many more retries may still be queued for this remote.
    Active(u8),
    /// A probe to this remote was answered with a matching response.
    Succeeded,
}

impl ProbeState {
    /// Returns how many more probes may be sent to this remote.
    ///
    /// A remote that already [`Succeeded`](Self::Succeeded) needs none, so this is `0`.
    fn remaining(&self) -> u8 {
        if let Self::Active(left) = self { *left } else { 0 }
    }
}
688
/// Server-side state of n0's NAT traversal.
///
/// The server advertises its candidate addresses with ADD_ADDRESS / REMOVE_ADDRESS frames
/// and probes the addresses the client sends in REACH_OUT frames.
#[derive(Debug)]
pub(crate) struct ServerState {
    /// Max number of remote addresses we allow.
    ///
    /// This is set by the local endpoint. It bounds the number of REACH_OUT addresses
    /// accepted within a single round.
    max_remote_addresses: usize,
    /// Max number of local addresses allowed.
    ///
    /// This is set by the remote endpoint.
    max_local_addresses: usize,
    /// Candidate addresses the server reports as potentially reachable, to use for nat
    /// traversal attempts, mapped to the sequence number they were advertised with.
    ///
    /// They are stored in **canonical form**, not in socket-native form as usual. We may
    /// have a reflexive address that is IPv6 even if our local socket can only handle IPv4.
    local_addresses: FxHashMap<CanonicalIpPort, VarInt>,
    /// The sequence number to use for the next local address advertised to the client.
    next_local_addr_id: VarInt,
    /// Current nat traversal round
    ///
    /// Servers keep track of the client's most recent round and cancel probing related to previous
    /// rounds.
    round: VarInt,
    /// The probing attempt in the round.
    ///
    /// Probes are sent to all remotes at the same time in a round, at intervals from
    /// [`State::retry_delay`]. This is the number of times probes have been sent.
    attempt: u8,
    /// The remote addresses participating in this round, with their probing progress.
    ///
    /// The map is cleared when a new round starts.
    ///
    /// These are stored in the usual local-socket native form.
    remotes: FxHashMap<IpPort, ProbeState>,
    /// The data of PATH_CHALLENGE frames sent in probes, mapped to the probed address.
    ///
    /// These are cleared when a new round starts, so any late-arriving PATH_RESPONSEs will
    /// have no effect.
    sent_challenges: FxHashMap<u64, IpPort>,
    /// Queued probes to be sent in the next [`poll_transmit`] call.
    ///
    /// At the beginning of a round this is populated from REACH_OUT frames and at every
    /// retry this is populated from [`Self::remotes`].
    ///
    /// [`poll_transmit`]: crate::connection::Connection::poll_transmit
    pending_probes: FxHashSet<IpPort>,
}
736
737impl ServerState {
738    fn new(max_remote_addresses: usize, max_local_addresses: usize) -> Self {
739        Self {
740            max_remote_addresses,
741            max_local_addresses,
742            local_addresses: Default::default(),
743            next_local_addr_id: Default::default(),
744            round: Default::default(),
745            attempt: 0,
746            remotes: Default::default(),
747            sent_challenges: Default::default(),
748            pending_probes: Default::default(),
749        }
750    }
751
752    fn add_local_address(&mut self, address: SocketAddr) -> Result<Option<AddAddress>, Error> {
753        let address = CanonicalIpPort::from(address);
754        let allow_new = self.local_addresses.len() < self.max_local_addresses;
755        match self.local_addresses.entry(address) {
756            Entry::Occupied(_) => Ok(None),
757            Entry::Vacant(vacant_entry) if allow_new => {
758                let id = self.next_local_addr_id;
759                self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8);
760                vacant_entry.insert(id);
761                Ok(Some(AddAddress::new((address.ip(), address.port()), id)))
762            }
763            _ => Err(Error::TooManyAddresses),
764        }
765    }
766
767    fn remove_local_address(&mut self, address: &IpPort) -> Option<RemoveAddress> {
768        let address = CanonicalIpPort::from(*address);
769        self.local_addresses
770            .remove(&address)
771            .map(RemoveAddress::new)
772    }
773
774    /// Returns the current NAT traversal round number.
775    pub(crate) fn current_round(&self) -> VarInt {
776        self.round
777    }
778
779    /// Handles a received REACH_OUT frame.
780    ///
781    /// This might ignore the reach out frame if it belongs to an older round or if the
782    /// frame contains an IPv6 address while the local socket is IPv4-only.
783    ///
784    /// If a new round was started, the `NatTraversalProbeRetry` timer needs to be reset.
785    pub(crate) fn handle_reach_out(
786        &mut self,
787        reach_out: ReachOut,
788        ipv6: bool,
789    ) -> Result<(), Error> {
790        let ReachOut { round, ip, port } = reach_out;
791
792        if round < self.round {
793            trace!(current_round=%self.round, "ignoring REACH_OUT for previous round");
794            return Ok(());
795        }
796        let Some(ip) = map_to_local_socket_family(ip, ipv6) else {
797            trace!("Ignoring IPv6 REACH_OUT frame due to not supporting IPv6 locally");
798            return Ok(());
799        };
800
801        if round > self.round {
802            self.round = round;
803            self.attempt = 0;
804            self.remotes.clear();
805            self.sent_challenges.clear();
806            self.pending_probes.clear();
807        } else if self.remotes.contains_key(&(ip, port)) {
808            // Retransmitted frame.
809            return Ok(());
810        } else if self.remotes.len() >= self.max_remote_addresses {
811            return Err(Error::TooManyAddresses);
812        }
813        self.remotes
814            .entry((ip, port))
815            .or_insert(ProbeState::Active(MAX_NAT_PROBE_ATTEMPTS - 1));
816        self.pending_probes.insert((ip, port));
817        Ok(())
818    }
819
820    /// Re-queues probes that have not yet succeeded or reached [`MAX_NAT_PROBE_ATTEMPTS`].
821    ///
822    /// Returns whether any probes are now queued to send. In this case the
823    /// `NatTraversalProbeRetry` timer needs to be reset.
824    pub(crate) fn queue_retries(&mut self) {
825        self.attempt += 1;
826        self.remotes
827            .iter_mut()
828            .for_each(|(remote, state)| match state {
829                ProbeState::Active(remaining) if *remaining > 0 => {
830                    *remaining -= 1;
831                    self.pending_probes.insert(*remote);
832                }
833                ProbeState::Active(_) | ProbeState::Succeeded => (),
834            });
835    }
836
837    /// Returns the next ready probe's address.
838    ///
839    /// If this is actually sent you must call [`Self::mark_probe_sent`].
840    fn next_probe_addr(&self) -> Option<IpPort> {
841        self.pending_probes.iter().next().cloned()
842    }
843
844    /// Marks a probe as sent to the address with the challenge.
845    fn mark_probe_sent(&mut self, remote: IpPort, challenge: u64) {
846        self.pending_probes.remove(&remote);
847        self.sent_challenges.insert(challenge, remote);
848    }
849
850    /// Marks a remote as successful if the response matches a sent probe.
851    ///
852    /// Returns `true` if it was a response to one of the NAT traversal probes.
853    fn handle_path_response(&mut self, src: FourTuple, challenge: u64) -> bool {
854        if let Entry::Occupied(entry) = self.sent_challenges.entry(challenge) {
855            let remote = (src.remote().ip(), src.remote().port());
856            if *entry.get() == remote {
857                entry.remove();
858                self.remotes.insert(remote, ProbeState::Succeeded);
859                return true;
860            } else {
861                debug!(
862                    ?challenge,
863                    ?src.remote,
864                    "PATH_RESPONSE matched a NAT traversal probe but mismatching addr",
865                )
866            }
867        }
868        false
869    }
870}
871
/// Converts `address` into the IP family the local socket operates in.
///
/// - When the local socket supports IPv6 (`ipv6 == true`), IPv4 addresses become
///   IPv6-mapped IPv4 addresses (`::ffff:a.b.c.d`) and IPv6 addresses are returned as-is.
/// - When the local socket is IPv4-only, IPv4 addresses are returned as-is and IPv6-mapped
///   IPv4 addresses are unmapped back to plain IPv4.
///
/// Returns `None` for an IPv6 address that is not an IPv6-mapped IPv4 address when the
/// local socket does not support IPv6, since it cannot be expressed in that family.
///
/// Note this is the local-socket form described on [`IpPort`], not the canonical form of
/// [`CanonicalIpPort`].
pub(crate) fn map_to_local_socket_family(address: IpAddr, ipv6: bool) -> Option<IpAddr> {
    match (address, ipv6) {
        (IpAddr::V4(v4), true) => Some(IpAddr::V6(v4.to_ipv6_mapped())),
        (IpAddr::V4(_), false) | (IpAddr::V6(_), true) => Some(address),
        (IpAddr::V6(v6), false) => v6.to_ipv4_mapped().map(IpAddr::V4),
    }
}
887
#[cfg(test)]
mod tests {
    use testresult::TestResult;

    use super::*;

    #[test]
    fn test_basic_server_state() {
        let mut state = ServerState::new(2, 2);

        state
            .handle_reach_out(
                ReachOut {
                    round: 1u32.into(),
                    ip: std::net::Ipv4Addr::LOCALHOST.into(),
                    port: 1,
                },
                true,
            )
            .unwrap();

        state
            .handle_reach_out(
                ReachOut {
                    round: 1u32.into(),
                    ip: "1.1.1.1".parse().unwrap(),
                    port: 2,
                },
                true,
            )
            .unwrap();

        assert_eq!(state.pending_probes.len(), 2);

        // Helper: send next ready probe
        let mut challenge = 0;
        let mut send_probe = |state: &mut ServerState| {
            let remote = state.next_probe_addr().unwrap();
            challenge += 1;
            state.mark_probe_sent(remote, challenge);
        };

        // Attempt 1 for both remotes, queued directly by the REACH_OUT frames.
        send_probe(&mut state);
        send_probe(&mut state);

        // After sending both probes, no ready probes remain but they're still tracked.
        assert!(state.next_probe_addr().is_none());

        // After queuing retries, probes become available again (attempt 2).
        state.queue_retries();
        send_probe(&mut state);
        send_probe(&mut state);

        // After 2 attempts each, retries are still available (attempt 3).
        state.queue_retries();
        send_probe(&mut state);
        send_probe(&mut state);

        // Exhaust the remaining attempts (attempts 4..=MAX_NAT_PROBE_ATTEMPTS).
        for _ in 3..MAX_NAT_PROBE_ATTEMPTS {
            state.queue_retries();
            send_probe(&mut state);
            send_probe(&mut state);
        }

        // After MAX_NAT_PROBE_ATTEMPTS attempts, no further probes are queued.
        state.queue_retries();
        assert!(state.next_probe_addr().is_none());
    }

    #[test]
    fn test_map_to_local_socket() {
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("1.1.1.1".parse().unwrap(), true),
            Some("::ffff:1.1.1.1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), true),
            Some("::1".parse().unwrap())
        );
        assert_eq!(
            map_to_local_socket_family("::1".parse().unwrap(), false),
            None
        );
        assert_eq!(
            map_to_local_socket_family("::ffff:1.1.1.1".parse().unwrap(), false),
            Some("1.1.1.1".parse().unwrap())
        )
    }

    #[test]
    fn test_retry_delay_server_ipv6() -> TestResult {
        let initial_rtt = Duration::from_millis(333);
        let ipv6 = true;
        let remote = SocketAddr::from(("::2".parse::<IpAddr>()?, 2));
        let remote_ipp = (remote.ip(), remote.port());

        let mut nat = State::new(8, 8, Side::Server);

        nat.server_side_mut()?.handle_reach_out(
            ReachOut {
                round: 1u8.into(),
                ip: remote.ip(),
                port: remote.port(),
            },
            ipv6,
        )?;

        let challenges = [1u64, 2, 3, 4, 5, 6, 7];
        let delays = [
            33_300u64, 66_600, 133_200, 266_400, 532_800, 1_065_600, 2_000_000,
        ];
        for (challenge, delay) in challenges.into_iter().zip(delays) {
            nat.queue_retries(ipv6);
            assert_eq!(nat.next_probe_addr(), Some(remote_ipp));
            nat.mark_probe_sent(remote_ipp, challenge);
            assert_eq!(
                nat.retry_delay(initial_rtt),
                Some(Duration::from_micros(delay)),
                "challenge: {challenge}"
            );
        }

        assert!(nat.handle_path_response(
            FourTuple {
                remote,
                local_ip: Some("::3".parse::<IpAddr>()?),
            },
            challenges[6]
        ));
        assert_eq!(nat.retry_delay(initial_rtt), None);

        Ok(())
    }

    #[test]
    fn test_retry_delay_client_ipv6() -> TestResult {
        let initial_rtt = Duration::from_millis(333);
        let ipv6 = true;
        let remote = SocketAddr::from(("::2".parse::<IpAddr>()?, 2));
        let remote_ipp = (remote.ip(), remote.port());
        let local_addr = SocketAddr::from(("::3".parse::<IpAddr>()?, 3));

        let mut nat = State::new(8, 8, Side::Client);
        nat.add_local_address(local_addr)?;
        nat.client_side_mut()?.add_remote_address(AddAddress {
            seq_no: 1u8.into(),
            ip: remote.ip(),
            port: remote.port(),
        })?;
        nat.client_side_mut()?.initiate_nat_traversal_round(ipv6)?;

        let challenges = [1u64, 2, 3, 4, 5, 6, 7];
        let delays = [
            33_300u64, 66_600, 133_200, 266_400, 532_800, 1_065_600, 2_000_000,
        ];
        for (challenge, delay) in challenges.into_iter().zip(delays) {
            nat.queue_retries(ipv6);
            assert_eq!(nat.next_probe_addr(), Some(remote_ipp));
            nat.mark_probe_sent(remote_ipp, challenge);
            assert_eq!(
                nat.retry_delay(initial_rtt),
                Some(Duration::from_micros(delay)),
                "challenge: {challenge}"
            );
        }

        assert!(nat.handle_path_response(
            FourTuple {
                remote,
                local_ip: Some(local_addr.ip()),
            },
            challenges[6]
        ));
        assert_eq!(nat.retry_delay(initial_rtt), None);

        Ok(())
    }
}