noq_proto/
cid_generator.rs1use std::hash::Hasher;
2
3use rand::Rng;
4use rand::RngExt;
5
6use crate::Duration;
7use crate::MAX_CID_SIZE;
8use crate::shared::ConnectionId;
9
/// Source of connection IDs.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait ConnectionIdGenerator: Send + Sync {
    /// Produces a new connection ID.
    fn generate_cid(&mut self) -> ConnectionId;

    /// Checks whether `_cid` could have been produced by this generator.
    ///
    /// The default implementation accepts every connection ID.
    // NOTE(review): presumably used to cheaply drop packets addressed to unknown
    // CIDs — confirm against callers before relying on that.
    fn validate(&self, _cid: ConnectionId) -> Result<(), InvalidCid> {
        Ok(())
    }

    /// Length, in bytes, of the connection IDs returned by `generate_cid`.
    fn cid_len(&self) -> usize;

    /// Lifetime associated with generated connection IDs, if any.
    ///
    /// The generators in this module return `None` unless `set_lifetime` was called.
    fn cid_lifetime(&self) -> Option<Duration>;
}
35
/// Error returned by [`ConnectionIdGenerator::validate`] when a connection ID is rejected.
///
/// Implements [`std::error::Error`] so it composes with `?`, `Box<dyn Error>`, and
/// other error-handling code.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct InvalidCid;

impl std::fmt::Display for InvalidCid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid connection ID")
    }
}

impl std::error::Error for InvalidCid {}
39
/// Generates connection IDs made entirely of random bytes.
///
/// The `Default` impl uses 8-byte IDs and no lifetime.
#[derive(Debug, Clone, Copy)]
pub struct RandomConnectionIdGenerator {
    /// Number of random bytes in each generated connection ID.
    cid_len: usize,
    /// Value returned by `cid_lifetime`; set through `set_lifetime`.
    lifetime: Option<Duration>,
}
49
50impl Default for RandomConnectionIdGenerator {
51 fn default() -> Self {
52 Self {
53 cid_len: 8,
54 lifetime: None,
55 }
56 }
57}
58
59impl RandomConnectionIdGenerator {
60 pub fn new(cid_len: usize) -> Self {
64 debug_assert!(cid_len <= MAX_CID_SIZE);
65 Self {
66 cid_len,
67 ..Self::default()
68 }
69 }
70
71 pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
73 self.lifetime = Some(d);
74 self
75 }
76}
77
78impl ConnectionIdGenerator for RandomConnectionIdGenerator {
79 fn generate_cid(&mut self) -> ConnectionId {
80 let mut bytes_arr = [0; MAX_CID_SIZE];
81 rand::rng().fill_bytes(&mut bytes_arr[..self.cid_len]);
82
83 ConnectionId::new(&bytes_arr[..self.cid_len])
84 }
85
86 fn cid_len(&self) -> usize {
88 self.cid_len
89 }
90
91 fn cid_lifetime(&self) -> Option<Duration> {
92 self.lifetime
93 }
94}
95
/// Generates 8-byte connection IDs that can be cheaply checked with
/// [`ConnectionIdGenerator::validate`].
///
/// Each ID is a random nonce followed by a truncated keyed hash of that nonce. The hash
/// (`rustc_hash::FxHasher`) is not cryptographic, so IDs can be forged by an attacker.
/// The check is meant as a cheap filter, not an authentication mechanism.
pub struct HashedConnectionIdGenerator {
    /// Secret mixed into every signature; deliberately omitted from `Debug` output.
    key: u64,
    /// Value returned by `cid_lifetime`; set through `set_lifetime`.
    lifetime: Option<Duration>,
}

// Hand-written rather than derived so the hashing key never ends up in logs.
impl std::fmt::Debug for HashedConnectionIdGenerator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HashedConnectionIdGenerator")
            .field("lifetime", &self.lifetime)
            .finish_non_exhaustive()
    }
}
105
106impl HashedConnectionIdGenerator {
107 pub fn new() -> Self {
109 Self::from_key(rand::rng().random())
110 }
111
112 pub fn from_key(key: u64) -> Self {
117 Self {
118 key,
119 lifetime: None,
120 }
121 }
122
123 pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
125 self.lifetime = Some(d);
126 self
127 }
128}
129
130impl Default for HashedConnectionIdGenerator {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136impl ConnectionIdGenerator for HashedConnectionIdGenerator {
137 fn generate_cid(&mut self) -> ConnectionId {
138 let mut bytes_arr = [0; NONCE_LEN + SIGNATURE_LEN];
139 rand::rng().fill_bytes(&mut bytes_arr[..NONCE_LEN]);
140 let mut hasher = rustc_hash::FxHasher::default();
141 hasher.write_u64(self.key);
142 hasher.write(&bytes_arr[..NONCE_LEN]);
143 bytes_arr[NONCE_LEN..].copy_from_slice(&hasher.finish().to_le_bytes()[..SIGNATURE_LEN]);
144 ConnectionId::new(&bytes_arr)
145 }
146
147 fn validate(&self, cid: ConnectionId) -> Result<(), InvalidCid> {
148 let (nonce, signature) = cid.split_at(NONCE_LEN);
149 let mut hasher = rustc_hash::FxHasher::default();
150 hasher.write_u64(self.key);
151 hasher.write(nonce);
152 let expected = hasher.finish().to_le_bytes();
153 match expected[..SIGNATURE_LEN] == signature[..] {
154 true => Ok(()),
155 false => Err(InvalidCid),
156 }
157 }
158
159 fn cid_len(&self) -> usize {
160 NONCE_LEN + SIGNATURE_LEN
161 }
162
163 fn cid_lifetime(&self) -> Option<Duration> {
164 self.lifetime
165 }
166}
167
168const NONCE_LEN: usize = 3; const SIGNATURE_LEN: usize = 8 - NONCE_LEN; #[cfg(test)]
172mod tests {
173 use super::*;
174
175 #[test]
176 fn validate_keyed_cid() {
177 let mut generator = HashedConnectionIdGenerator::new();
178 let cid = generator.generate_cid();
179 generator.validate(cid).unwrap();
180 }
181}