1use std::{
2 fmt,
3 mem::size_of,
4 net::{IpAddr, SocketAddr},
5};
6
7use bytes::{Buf, BufMut, Bytes};
8use rand::Rng;
9
10use crate::{
11 Duration, RESET_TOKEN_SIZE, ServerConfig, SystemTime, UNIX_EPOCH,
12 coding::{BufExt, BufMutExt},
13 crypto::{HandshakeTokenKey, HmacKey},
14 packet::InitialHeader,
15 shared::ConnectionId,
16};
17
/// Reuse check for address-validation tokens.
///
/// Implementations must be `Send + Sync`.
pub trait TokenLog: Send + Sync {
    /// Records a validation token and reports whether it should be rejected.
    ///
    /// The server calls this for a validation token only after the token's IP and lifetime
    /// checks have passed. It passes the token's `nonce`, its `issued` time and the configured
    /// validation-token `lifetime`. Returning `Err(TokenReuseError)` makes the server ignore the
    /// token and treat the client as unvalidated.
    fn check_and_insert(
        &self,
        nonce: u128,
        issued: SystemTime,
        lifetime: Duration,
    ) -> Result<(), TokenReuseError>;
}

/// Rejection returned by [`TokenLog::check_and_insert`] (by its name: the token may have been
/// used before).
#[derive(Debug)]
pub struct TokenReuseError;

/// [`TokenLog`] that rejects every token, so validation tokens never mark a client as validated.
#[derive(Debug)]
pub struct NoneTokenLog;

impl TokenLog for NoneTokenLog {
    fn check_and_insert(&self, _: u128, _: SystemTime, _: Duration) -> Result<(), TokenReuseError> {
        Err(TokenReuseError)
    }
}
77
78pub trait TokenStore: Send + Sync {
81 fn insert(&self, server_name: &str, token: Bytes);
85
86 fn take(&self, server_name: &str) -> Option<Bytes>;
93}
94
95pub struct NoneTokenStore;
97
98impl TokenStore for NoneTokenStore {
99 fn insert(&self, _: &str, _: Bytes) {}
100 fn take(&self, _: &str) -> Option<Bytes> {
101 None
102 }
103}
104
105#[derive(Debug)]
107pub(crate) struct IncomingToken {
108 pub(crate) retry_src_cid: Option<ConnectionId>,
109 pub(crate) orig_dst_cid: ConnectionId,
110 pub(crate) validated: bool,
111}
112
113impl IncomingToken {
114 pub(crate) fn from_header(
117 header: &InitialHeader,
118 server_config: &ServerConfig,
119 remote_address: SocketAddr,
120 ) -> Result<Self, InvalidRetryTokenError> {
121 let unvalidated = Self {
122 retry_src_cid: None,
123 orig_dst_cid: header.dst_cid,
124 validated: false,
125 };
126
127 if header.token.is_empty() {
129 return Ok(unvalidated);
130 }
131
132 let Some(retry) = Token::decode(&*server_config.token_key, &header.token) else {
142 return Ok(unvalidated);
143 };
144
145 match retry.payload {
147 TokenPayload::Retry {
148 address,
149 orig_dst_cid,
150 issued,
151 } => {
152 if address != remote_address {
153 return Err(InvalidRetryTokenError);
154 }
155 if issued + server_config.retry_token_lifetime < server_config.time_source.now() {
156 return Err(InvalidRetryTokenError);
157 }
158
159 Ok(Self {
160 retry_src_cid: Some(header.dst_cid),
161 orig_dst_cid,
162 validated: true,
163 })
164 }
165 TokenPayload::Validation { ip, issued } => {
166 if ip != remote_address.ip() {
167 return Ok(unvalidated);
168 }
169 if issued + server_config.validation_token.lifetime
170 < server_config.time_source.now()
171 {
172 return Ok(unvalidated);
173 }
174 if server_config
175 .validation_token
176 .log
177 .check_and_insert(retry.nonce, issued, server_config.validation_token.lifetime)
178 .is_err()
179 {
180 return Ok(unvalidated);
181 }
182
183 Ok(Self {
184 retry_src_cid: None,
185 orig_dst_cid: header.dst_cid,
186 validated: true,
187 })
188 }
189 }
190 }
191}
192
/// Returned by `IncomingToken::from_header` for a decodable Retry token that fails a check.
/// The check is either the client address or the token lifetime.
pub(crate) struct InvalidRetryTokenError;
197
198pub(crate) struct Token {
200 pub(crate) payload: TokenPayload,
202 nonce: u128,
204}
205
206impl Token {
207 pub(crate) fn new(payload: TokenPayload, rng: &mut impl Rng) -> Self {
209 Self {
210 nonce: rng.random(),
211 payload,
212 }
213 }
214
215 pub(crate) fn encode(&self, key: &dyn HandshakeTokenKey) -> Vec<u8> {
217 let mut buf = Vec::new();
218
219 match self.payload {
221 TokenPayload::Retry {
222 address,
223 orig_dst_cid,
224 issued,
225 } => {
226 buf.put_u8(TokenType::Retry as u8);
227 encode_addr(&mut buf, address);
228 orig_dst_cid.encode_long(&mut buf);
229 encode_unix_secs(&mut buf, issued);
230 }
231 TokenPayload::Validation { ip, issued } => {
232 buf.put_u8(TokenType::Validation as u8);
233 encode_ip(&mut buf, ip);
234 encode_unix_secs(&mut buf, issued);
235 }
236 }
237
238 let aead_key = key.aead_from_hkdf(&self.nonce.to_le_bytes());
240 aead_key.seal(&mut buf, &[]).unwrap();
241 buf.extend(&self.nonce.to_le_bytes());
242
243 buf
244 }
245
246 fn decode(key: &dyn HandshakeTokenKey, raw_token_bytes: &[u8]) -> Option<Self> {
248 let nonce_slice_start = raw_token_bytes.len().checked_sub(size_of::<u128>())?;
251 let (sealed_token, nonce_bytes) = raw_token_bytes.split_at_checked(nonce_slice_start)?;
252
253 let nonce = u128::from_le_bytes(nonce_bytes.try_into().unwrap());
254
255 let aead_key = key.aead_from_hkdf(nonce_bytes);
256 let mut sealed_token = sealed_token.to_vec();
257 let data = aead_key.open(&mut sealed_token, &[]).ok()?;
258
259 let mut reader = &data[..];
261 let payload = match TokenType::from_byte((&mut reader).get::<u8>().ok()?)? {
262 TokenType::Retry => TokenPayload::Retry {
263 address: decode_addr(&mut reader)?,
264 orig_dst_cid: ConnectionId::decode_long(&mut reader)?,
265 issued: decode_unix_secs(&mut reader)?,
266 },
267 TokenType::Validation => TokenPayload::Validation {
268 ip: decode_ip(&mut reader)?,
269 issued: decode_unix_secs(&mut reader)?,
270 },
271 };
272
273 if !reader.is_empty() {
274 return None;
276 }
277
278 Some(Self { nonce, payload })
279 }
280}
281
282pub(crate) enum TokenPayload {
284 Retry {
286 address: SocketAddr,
288 orig_dst_cid: ConnectionId,
290 issued: SystemTime,
292 },
293 Validation {
295 ip: IpAddr,
297 issued: SystemTime,
299 },
300}
301
/// Discriminant byte written at the start of every encoded token payload.
#[derive(Copy, Clone)]
#[repr(u8)]
enum TokenType {
    Retry = 0,
    Validation = 1,
}

impl TokenType {
    /// Maps a wire byte back to its token type; `None` for bytes that match no variant.
    fn from_byte(n: u8) -> Option<Self> {
        // Guards compare against the declared discriminants, so this stays in sync with the
        // `#[repr(u8)]` values above.
        match n {
            b if b == Self::Retry as u8 => Some(Self::Retry),
            b if b == Self::Validation as u8 => Some(Self::Validation),
            _ => None,
        }
    }
}
316
317fn encode_addr(buf: &mut Vec<u8>, address: SocketAddr) {
318 encode_ip(buf, address.ip());
319 buf.put_u16(address.port());
320}
321
322fn decode_addr<B: Buf>(buf: &mut B) -> Option<SocketAddr> {
323 let ip = decode_ip(buf)?;
324 let port = buf.get().ok()?;
325 Some(SocketAddr::new(ip, port))
326}
327
328fn encode_ip(buf: &mut Vec<u8>, ip: IpAddr) {
329 match ip {
330 IpAddr::V4(x) => {
331 buf.put_u8(0);
332 buf.put_slice(&x.octets());
333 }
334 IpAddr::V6(x) => {
335 buf.put_u8(1);
336 buf.put_slice(&x.octets());
337 }
338 }
339}
340
341fn decode_ip<B: Buf>(buf: &mut B) -> Option<IpAddr> {
342 match buf.get::<u8>().ok()? {
343 0 => buf.get().ok().map(IpAddr::V4),
344 1 => buf.get().ok().map(IpAddr::V6),
345 _ => None,
346 }
347}
348
349fn encode_unix_secs(buf: &mut Vec<u8>, time: SystemTime) {
350 buf.write::<u64>(
351 time.duration_since(UNIX_EPOCH)
352 .unwrap_or_default()
353 .as_secs(),
354 );
355}
356
357fn decode_unix_secs<B: Buf>(buf: &mut B) -> Option<SystemTime> {
358 Some(UNIX_EPOCH + Duration::from_secs(buf.get::<u64>().ok()?))
359}
360
361#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Copy, Clone, Hash)]
366pub(crate) struct ResetToken([u8; RESET_TOKEN_SIZE]);
367
368impl ResetToken {
369 pub(crate) fn new(key: &dyn HmacKey, id: ConnectionId) -> Self {
370 let mut signature = vec![0; key.signature_len()];
371 key.sign(&id, &mut signature);
372 let mut result = [0; RESET_TOKEN_SIZE];
374 result.copy_from_slice(&signature[..RESET_TOKEN_SIZE]);
375 result.into()
376 }
377}
378
379impl PartialEq for ResetToken {
380 fn eq(&self, other: &Self) -> bool {
381 crate::constant_time::eq(&self.0, &other.0)
382 }
383}
384
385impl Eq for ResetToken {}
386
387impl From<[u8; RESET_TOKEN_SIZE]> for ResetToken {
388 fn from(x: [u8; RESET_TOKEN_SIZE]) -> Self {
389 Self(x)
390 }
391}
392
393impl std::ops::Deref for ResetToken {
394 type Target = [u8];
395 fn deref(&self) -> &[u8] {
396 &self.0
397 }
398}
399
400impl fmt::Display for ResetToken {
401 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402 for byte in self.iter() {
403 write!(f, "{byte:02x}")?;
404 }
405 Ok(())
406 }
407}
408
#[cfg(all(test, any(feature = "aws-lc-rs", feature = "ring")))]
mod test {
    use super::*;
    #[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
    use aws_lc_rs::hkdf;
    use rand::prelude::*;
    #[cfg(feature = "ring")]
    use ring::hkdf;

    /// Seals `payload` under a fresh random key, opens it again, and returns what was decoded.
    fn token_round_trip(payload: TokenPayload) -> TokenPayload {
        let rng = &mut rand::rng();
        let token = Token::new(payload, rng);
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);
        let encoded = token.encode(&prk);
        let decoded = Token::decode(&prk, &encoded).expect("token didn't decrypt / decode");
        assert_eq!(token.nonce, decoded.nonce);
        decoded.payload
    }

    #[test]
    fn retry_token_sanity() {
        use crate::MAX_CID_SIZE;
        use crate::cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator};
        use crate::{Duration, UNIX_EPOCH};

        use std::net::Ipv6Addr;

        let address_1 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
        let orig_dst_cid_1 = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
        // Whole seconds only: `encode_unix_secs` drops any sub-second part.
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);
        let payload_1 = TokenPayload::Retry {
            address: address_1,
            orig_dst_cid: orig_dst_cid_1,
            issued: issued_1,
        };
        let TokenPayload::Retry {
            address: address_2,
            orig_dst_cid: orig_dst_cid_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };

        assert_eq!(address_1, address_2);
        assert_eq!(orig_dst_cid_1, orig_dst_cid_2);
        assert_eq!(issued_1, issued_2);
    }

    #[test]
    fn validation_token_sanity() {
        use crate::{Duration, UNIX_EPOCH};

        use std::net::Ipv6Addr;

        let ip_1 = Ipv6Addr::LOCALHOST.into();
        // Whole seconds only: `encode_unix_secs` drops any sub-second part.
        let issued_1 = UNIX_EPOCH + Duration::from_secs(42);
        let payload_1 = TokenPayload::Validation {
            ip: ip_1,
            issued: issued_1,
        };
        let TokenPayload::Validation {
            ip: ip_2,
            issued: issued_2,
        } = token_round_trip(payload_1)
        else {
            panic!("token decoded as wrong variant");
        };

        assert_eq!(ip_1, ip_2);
        assert_eq!(issued_1, issued_2);
    }

    #[test]
    fn invalid_token_returns_err() {
        let rng = &mut rand::rng();

        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);

        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);

        let mut invalid_token = Vec::new();

        let mut random_data = [0; 32];
        rng.fill_bytes(&mut random_data);
        invalid_token.put_slice(&random_data);

        // Random bytes must fail authentication.
        assert!(Token::decode(&prk, &invalid_token).is_none());
    }

    #[test]
    fn short_token_returns_none() {
        let rng = &mut rand::rng();
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);

        // Too short to even hold the trailing 16-byte nonce.
        assert!(Token::decode(&prk, &[]).is_none());
        assert!(Token::decode(&prk, &[0u8; 15]).is_none());
        // A bare nonce with nothing sealed in front of it.
        assert!(Token::decode(&prk, &[0u8; 16]).is_none());
    }

    #[test]
    fn trailing_bytes_return_none() {
        use std::net::Ipv6Addr;

        let rng = &mut rand::rng();
        let mut master_key = [0; 64];
        rng.fill_bytes(&mut master_key);
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(&master_key);

        // Assemble a validation token by hand, mirroring `Token::encode`, so that extra bytes
        // can be placed in the plaintext before it is sealed.
        let seal_with_trailing = |trailing: &[u8]| -> Vec<u8> {
            let nonce_bytes = 7u128.to_le_bytes();
            let mut buf: Vec<u8> = Vec::new();
            buf.put_u8(TokenType::Validation as u8);
            encode_ip(&mut buf, Ipv6Addr::LOCALHOST.into());
            encode_unix_secs(&mut buf, UNIX_EPOCH + Duration::from_secs(42));
            buf.put_slice(trailing);
            let aead_key = prk.aead_from_hkdf(&nonce_bytes);
            aead_key.seal(&mut buf, &[]).unwrap();
            buf.extend(&nonce_bytes);
            buf
        };

        // Control: the hand-built token decodes when nothing is appended...
        assert!(Token::decode(&prk, &seal_with_trailing(&[])).is_some());
        // ...and is rejected once the plaintext carries an extra byte.
        assert!(Token::decode(&prk, &seal_with_trailing(&[0xff])).is_none());
    }
}
506}