noq_proto/
cid_generator.rs

1use std::hash::Hasher;
2
3use rand::Rng;
4use rand::RngExt;
5
6use crate::Duration;
7use crate::MAX_CID_SIZE;
8use crate::shared::ConnectionId;
9
10/// Generates connection IDs for incoming connections
11pub trait ConnectionIdGenerator: Send + Sync {
12    /// Generates a new CID
13    ///
14    /// Connection IDs MUST NOT contain any information that can be used by
15    /// an external observer (that is, one that does not cooperate with the
16    /// issuer) to correlate them with other connection IDs for the same
17    /// connection. They MUST have high entropy, e.g. due to encrypted data
18    /// or cryptographic-grade random data.
19    fn generate_cid(&mut self) -> ConnectionId;
20
21    /// Quickly determine whether `cid` could have been generated by this generator
22    ///
23    /// False positives are permitted, but increase the cost of handling invalid packets.
24    fn validate(&self, _cid: ConnectionId) -> Result<(), InvalidCid> {
25        Ok(())
26    }
27
28    /// Returns the length of a CID for connections created by this generator
29    fn cid_len(&self) -> usize;
30    /// Returns the lifetime of generated Connection IDs
31    ///
32    /// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
33    fn cid_lifetime(&self) -> Option<Duration>;
34}
35
/// Error signalling that a connection ID was not recognized by the [`ConnectionIdGenerator`]
#[derive(Clone, Copy, Debug)]
pub struct InvalidCid;
39
/// Connection ID generator that emits purely random bytes of a configurable length
///
/// Random CIDs may be shorter than the ones produced by [`HashedConnectionIdGenerator`], but there
/// is no useful way to [`validate`](ConnectionIdGenerator::validate) them.
#[derive(Clone, Copy, Debug)]
pub struct RandomConnectionIdGenerator {
    /// Number of bytes in every generated CID
    cid_len: usize,
    /// Lifetime reported for generated CIDs, if any
    lifetime: Option<Duration>,
}
49
50impl Default for RandomConnectionIdGenerator {
51    fn default() -> Self {
52        Self {
53            cid_len: 8,
54            lifetime: None,
55        }
56    }
57}
58
59impl RandomConnectionIdGenerator {
60    /// Initialize Random CID generator with a fixed CID length
61    ///
62    /// The given length must be less than or equal to MAX_CID_SIZE.
63    pub fn new(cid_len: usize) -> Self {
64        debug_assert!(cid_len <= MAX_CID_SIZE);
65        Self {
66            cid_len,
67            ..Self::default()
68        }
69    }
70
71    /// Set the lifetime of CIDs created by this generator
72    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
73        self.lifetime = Some(d);
74        self
75    }
76}
77
78impl ConnectionIdGenerator for RandomConnectionIdGenerator {
79    fn generate_cid(&mut self) -> ConnectionId {
80        let mut bytes_arr = [0; MAX_CID_SIZE];
81        rand::rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
82
83        ConnectionId::new(&bytes_arr[..self.cid_len])
84    }
85
86    /// Provide the length of dst_cid in short header packet
87    fn cid_len(&self) -> usize {
88        self.cid_len
89    }
90
91    fn cid_lifetime(&self) -> Option<Duration> {
92        self.lifetime
93    }
94}
95
/// Produces 8-byte connection IDs that can be cheaply
/// [`validate`](ConnectionIdGenerator::validate)d
///
/// Because the hash involved is non-cryptographic, these CIDs can still be spoofed. Even so, the
/// check helps prevent noq from responding to non-QUIC packets at very low cost.
pub struct HashedConnectionIdGenerator {
    /// Key that seeds the hash used to sign and validate CIDs
    key: u64,
    /// Lifetime reported for generated CIDs, if any
    lifetime: Option<Duration>,
}
105
106impl HashedConnectionIdGenerator {
107    /// Create a generator with a random key
108    pub fn new() -> Self {
109        Self::from_key(rand::rng().random())
110    }
111
112    /// Create a generator with a specific key
113    ///
114    /// Allows [`validate`](ConnectionIdGenerator::validate) to recognize a consistent set of
115    /// connection IDs across restarts
116    pub fn from_key(key: u64) -> Self {
117        Self {
118            key,
119            lifetime: None,
120        }
121    }
122
123    /// Set the lifetime of CIDs created by this generator
124    pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
125        self.lifetime = Some(d);
126        self
127    }
128}
129
130impl Default for HashedConnectionIdGenerator {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl ConnectionIdGenerator for HashedConnectionIdGenerator {
137    fn generate_cid(&mut self) -> ConnectionId {
138        let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
139        rand::rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
140        let mut hasher = rustc_hash::FxHasher::default();
141        hasher.write_u64(self.key);
142        hasher.write(&bytes_arr[..NONCE_LEN]);
143        bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
144        ConnectionId::new(&bytes_arr)
145    }
146
147    fn validate(&self, cid: ConnectionId) -> Result<(), InvalidCid> {
148        let (nonce, signature) = cid.split_at(NONCE_LEN);
149        let mut hasher = rustc_hash::FxHasher::default();
150        hasher.write_u64(self.key);
151        hasher.write(nonce);
152        let expected = hasher.finish().to_le_bytes();
153        match expected[..SIGNATURE_LEN] == signature[..] {
154            true => Ok(()),
155            false => Err(InvalidCid),
156        }
157    }
158
159    fn cid_len(&self) -> usize {
160        NONCE_LEN + SIGNATURE_LEN
161    }
162
163    fn cid_lifetime(&self) -> Option<Duration> {
164        self.lifetime
165    }
166}
167
/// Number of random nonce bytes at the start of a hashed CID
///
/// Good for more than 16 million connections.
const NONCE_LEN: usize = 3;
/// Number of signature bytes following the nonce, chosen so the total CID length is 8 bytes
const SIGNATURE_LEN: usize = 8 - NONCE_LEN;
170
#[cfg(test)]
mod tests {
    use super::*;

    /// A CID produced by a hashed generator must pass that same generator's validation.
    #[test]
    fn validate_keyed_cid() {
        let mut cid_generator = HashedConnectionIdGenerator::new();
        let generated = cid_generator.generate_cid();
        assert!(cid_generator.validate(generated).is_ok());
    }
}