1use std::{any::Any, io, str, sync::Arc};
2
3use aes_gcm::{KeyInit, aead::AeadMutInPlace};
4use bytes::BytesMut;
5pub use rustls::Error;
6use rustls::{
7 self, CipherSuite,
8 pki_types::{CertificateDer, ServerName},
9 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
10};
11#[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
12use rustls::{client::danger::ServerCertVerifier, pki_types::PrivateKeyDer};
13
14use crate::{
15 ConnectError, ConnectionId, PathId, Side, TransportError, TransportErrorCode,
16 crypto::{
17 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
18 },
19 transport_parameters::TransportParameters,
20};
21
22impl From<Side> for rustls::Side {
23 fn from(s: Side) -> Self {
24 match s {
25 Side::Client => Self::Client,
26 Side::Server => Self::Server,
27 }
28 }
29}
30
31pub struct TlsSession {
33 version: Version,
34 got_handshake_data: bool,
35 next_secrets: Option<Secrets>,
36 inner: Connection,
37 suite: Suite,
38}
39
40impl TlsSession {
41 fn side(&self) -> Side {
42 match self.inner {
43 Connection::Client(_) => Side::Client,
44 Connection::Server(_) => Side::Server,
45 }
46 }
47}
48
49impl crypto::Session for TlsSession {
50 fn initial_keys(&self, dst_cid: ConnectionId, side: Side) -> Keys {
51 initial_keys(self.version, dst_cid, side, &self.suite)
52 }
53
54 fn handshake_data(&self) -> Option<Box<dyn Any>> {
55 if !self.got_handshake_data {
56 return None;
57 }
58 Some(Box::new(HandshakeData {
59 protocol: self.inner.alpn_protocol().map(|x| x.into()),
60 server_name: match self.inner {
61 Connection::Client(_) => None,
62 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
63 },
64 #[cfg(feature = "__rustls-post-quantum-test")]
65 negotiated_key_exchange_group: self
66 .inner
67 .negotiated_key_exchange_group()
68 .expect("key exchange group is negotiated")
69 .name(),
70 }))
71 }
72
73 fn peer_identity(&self) -> Option<Box<dyn Any>> {
75 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
76 Box::new(
77 v.iter()
78 .map(|v| v.clone().into_owned())
79 .collect::<Vec<CertificateDer<'static>>>(),
80 )
81 })
82 }
83
84 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
85 let keys = self.inner.zero_rtt_keys()?;
86 Some((Box::new(keys.header), Box::new(keys.packet)))
87 }
88
89 fn early_data_accepted(&self) -> Option<bool> {
90 match self.inner {
91 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
92 _ => None,
93 }
94 }
95
96 fn is_handshaking(&self) -> bool {
97 self.inner.is_handshaking()
98 }
99
100 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
101 self.inner.read_hs(buf).map_err(|e| {
102 if let Some(alert) = self.inner.alert() {
103 TransportError {
104 code: TransportErrorCode::crypto(alert.into()),
105 frame: crate::frame::MaybeFrame::None,
106 reason: e.to_string(),
107 crypto: Some(Arc::new(e)),
108 }
109 } else {
110 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
111 }
112 })?;
113 if !self.got_handshake_data {
114 let have_server_name = match self.inner {
118 Connection::Client(_) => false,
119 Connection::Server(ref session) => session.server_name().is_some(),
120 };
121 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
122 self.got_handshake_data = true;
123 return Ok(true);
124 }
125 }
126 Ok(false)
127 }
128
129 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
130 match self.inner.quic_transport_parameters() {
131 None => Ok(None),
132 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
133 Ok(params) => Ok(Some(params)),
134 Err(e) => Err(e.into()),
135 },
136 }
137 }
138
139 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
140 let keys = match self.inner.write_hs(buf)? {
141 KeyChange::Handshake { keys } => keys,
142 KeyChange::OneRtt { keys, next } => {
143 self.next_secrets = Some(next);
144 keys
145 }
146 };
147
148 Some(Keys {
149 header: KeyPair {
150 local: Box::new(keys.local.header),
151 remote: Box::new(keys.remote.header),
152 },
153 packet: KeyPair {
154 local: Box::new(keys.local.packet),
155 remote: Box::new(keys.remote.packet),
156 },
157 })
158 }
159
160 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
161 let secrets = self.next_secrets.as_mut()?;
162 let keys = secrets.next_packet_keys();
163 Some(KeyPair {
164 local: Box::new(keys.local),
165 remote: Box::new(keys.remote),
166 })
167 }
168
169 fn is_valid_retry(&self, orig_dst_cid: ConnectionId, header: &[u8], payload: &[u8]) -> bool {
170 if payload.len() < 16 {
171 return false;
172 }
173
174 let mut pseudo_packet =
175 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
176 pseudo_packet.push(orig_dst_cid.len() as u8);
177 pseudo_packet.extend_from_slice(&orig_dst_cid);
178 pseudo_packet.extend_from_slice(header);
179 pseudo_packet.extend_from_slice(payload);
180
181 let (nonce, key) = match self.version {
182 Version::V1 => (&RETRY_INTEGRITY_NONCE_V1, &RETRY_INTEGRITY_KEY_V1),
183 Version::V1Draft => (&RETRY_INTEGRITY_NONCE_DRAFT, &RETRY_INTEGRITY_KEY_DRAFT),
184 _ => unreachable!(),
185 };
186
187 let Some((aad, tag)) = pseudo_packet.split_last_chunk::<16>() else {
188 return false; };
190
191 let key = aes_gcm::Key::<aes_gcm::Aes128Gcm>::from_slice(key);
193 let nonce = aes_gcm::Nonce::from_slice(nonce);
194 let tag = aes_gcm::Tag::from_slice(tag);
195 aes_gcm::Aes128Gcm::new(key)
196 .decrypt_in_place_detached(nonce, aad, &mut [], tag)
197 .is_ok()
198 }
199
200 fn export_keying_material(
201 &self,
202 output: &mut [u8],
203 label: &[u8],
204 context: &[u8],
205 ) -> Result<(), ExportKeyingMaterialError> {
206 self.inner
207 .export_keying_material(output, label, Some(context))
208 .map_err(|_| ExportKeyingMaterialError)?;
209 Ok(())
210 }
211}
212
/// AES-128-GCM key for the Retry Integrity Tag (RFC 9001 §5.8) used by the draft versions that
/// `interpret_version` maps to `Version::V1Draft` (0xff00_001d..=0xff00_0020).
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] = [
    0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1,
];
/// AES-128-GCM nonce paired with `RETRY_INTEGRITY_KEY_DRAFT`.
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = [
    0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c,
];

/// AES-128-GCM key for the Retry Integrity Tag of QUIC v1 (RFC 9001 §5.8). Also used for the
/// draft versions 0xff00_0021..=0xff00_0022, which `interpret_version` maps to `Version::V1`.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
/// AES-128-GCM nonce paired with `RETRY_INTEGRITY_KEY_V1`.
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
226
227impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
228 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
229 let (header, sample) = packet.split_at_mut(pn_offset + 4);
230 let (first, rest) = header.split_at_mut(1);
231 let pn_end = Ord::min(pn_offset + 3, rest.len());
232 self.decrypt_in_place(
233 &sample[..self.sample_size()],
234 &mut first[0],
235 &mut rest[pn_offset - 1..pn_end],
236 )
237 .unwrap();
238 }
239
240 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
241 let (header, sample) = packet.split_at_mut(pn_offset + 4);
242 let (first, rest) = header.split_at_mut(1);
243 let pn_end = Ord::min(pn_offset + 3, rest.len());
244 self.encrypt_in_place(
245 &sample[..self.sample_size()],
246 &mut first[0],
247 &mut rest[pn_offset - 1..pn_end],
248 )
249 .unwrap();
250 }
251
252 fn sample_size(&self) -> usize {
253 self.sample_len()
254 }
255}
256
/// Handshake information reported by a rustls-backed TLS session.
pub struct HandshakeData {
    /// The application protocol negotiated via ALPN, if any.
    pub protocol: Option<Vec<u8>>,
    /// The server name requested by the client, if any.
    ///
    /// Always `None` on the client side (outgoing connections).
    pub server_name: Option<String>,
    /// The key exchange group negotiated with the peer (test-only feature).
    #[cfg(feature = "__rustls-post-quantum-test")]
    pub negotiated_key_exchange_group: rustls::NamedGroup,
}
271
272pub struct QuicClientConfig {
291 pub(crate) inner: Arc<rustls::ClientConfig>,
292 initial: Suite,
293}
294
295impl QuicClientConfig {
296 #[cfg(all(
297 feature = "platform-verifier",
298 any(feature = "aws-lc-rs", feature = "ring")
299 ))]
300 pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
301 use rustls_platform_verifier::BuilderVerifierExt;
302
303 let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
305 .with_protocol_versions(&[&rustls::version::TLS13])
306 .unwrap() .with_platform_verifier()?
308 .with_no_client_auth();
309
310 inner.enable_early_data = true;
311 Ok(Self {
312 initial: initial_suite_from_provider(inner.crypto_provider())
314 .expect("no initial cipher suite found"),
315 inner: Arc::new(inner),
316 })
317 }
318
319 #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
324 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
325 let inner = Self::inner(verifier);
326 Self {
327 initial: initial_suite_from_provider(inner.crypto_provider())
329 .expect("no initial cipher suite found"),
330 inner: Arc::new(inner),
331 }
332 }
333
334 pub fn with_initial(
338 inner: Arc<rustls::ClientConfig>,
339 initial: Suite,
340 ) -> Result<Self, NoInitialCipherSuite> {
341 match initial.suite.common.suite {
342 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
343 _ => Err(NoInitialCipherSuite { specific: true }),
344 }
345 }
346
347 #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
348 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
349 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
351 .with_protocol_versions(&[&rustls::version::TLS13])
352 .unwrap() .dangerous()
354 .with_custom_certificate_verifier(verifier)
355 .with_no_client_auth();
356
357 config.enable_early_data = true;
358 config
359 }
360}
361
362impl crypto::ClientConfig for QuicClientConfig {
363 fn start_session(
364 self: Arc<Self>,
365 version: u32,
366 server_name: &str,
367 params: &TransportParameters,
368 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
369 let version = interpret_version(version)?;
370 Ok(Box::new(TlsSession {
371 version,
372 got_handshake_data: false,
373 next_secrets: None,
374 inner: rustls::quic::Connection::Client(
375 rustls::quic::ClientConnection::new(
376 self.inner.clone(),
377 version,
378 ServerName::try_from(server_name)
379 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
380 .to_owned(),
381 to_vec(params),
382 )
383 .unwrap(),
384 ),
385 suite: self.initial,
386 }))
387 }
388}
389
390impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
391 type Error = NoInitialCipherSuite;
392
393 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
394 Arc::new(inner).try_into()
395 }
396}
397
398impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
399 type Error = NoInitialCipherSuite;
400
401 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
402 Ok(Self {
403 initial: initial_suite_from_provider(inner.crypto_provider())
404 .ok_or(NoInitialCipherSuite { specific: false })?,
405 inner,
406 })
407 }
408}
409
/// The initial cipher suite (`TLS13_AES_128_GCM_SHA256`) is not available.
///
/// Returned when a suite passed to `with_initial()` is not that suite, or when a config's
/// crypto provider does not offer it.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` if the caller supplied a specific (wrong) suite, `false` if none was found.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
433
434pub struct QuicServerConfig {
447 inner: Arc<rustls::ServerConfig>,
448 initial: Suite,
449}
450
451impl QuicServerConfig {
452 #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
453 pub(crate) fn new(
454 cert_chain: Vec<CertificateDer<'static>>,
455 key: PrivateKeyDer<'static>,
456 ) -> Result<Self, rustls::Error> {
457 let inner = Self::inner(cert_chain, key)?;
458 Ok(Self {
459 initial: initial_suite_from_provider(inner.crypto_provider())
461 .expect("no initial cipher suite found"),
462 inner: Arc::new(inner),
463 })
464 }
465
466 pub fn with_initial(
470 inner: Arc<rustls::ServerConfig>,
471 initial: Suite,
472 ) -> Result<Self, NoInitialCipherSuite> {
473 match initial.suite.common.suite {
474 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
475 _ => Err(NoInitialCipherSuite { specific: true }),
476 }
477 }
478
479 #[cfg(any(feature = "aws-lc-rs", feature = "ring"))]
485 pub(crate) fn inner(
486 cert_chain: Vec<CertificateDer<'static>>,
487 key: PrivateKeyDer<'static>,
488 ) -> Result<rustls::ServerConfig, rustls::Error> {
489 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
490 .with_protocol_versions(&[&rustls::version::TLS13])
491 .unwrap() .with_no_client_auth()
493 .with_single_cert(cert_chain, key)?;
494
495 inner.max_early_data_size = u32::MAX;
496 Ok(inner)
497 }
498}
499
500impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
501 type Error = NoInitialCipherSuite;
502
503 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
504 Arc::new(inner).try_into()
505 }
506}
507
508impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
509 type Error = NoInitialCipherSuite;
510
511 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
512 Ok(Self {
513 initial: initial_suite_from_provider(inner.crypto_provider())
514 .ok_or(NoInitialCipherSuite { specific: false })?,
515 inner,
516 })
517 }
518}
519
520impl crypto::ServerConfig for QuicServerConfig {
521 fn start_session(
522 self: Arc<Self>,
523 version: u32,
524 params: &TransportParameters,
525 ) -> Box<dyn crypto::Session> {
526 let version = interpret_version(version).unwrap();
528 Box::new(TlsSession {
529 version,
530 got_handshake_data: false,
531 next_secrets: None,
532 inner: rustls::quic::Connection::Server(
533 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
534 .unwrap(),
535 ),
536 suite: self.initial,
537 })
538 }
539
540 fn initial_keys(
541 &self,
542 version: u32,
543 dst_cid: ConnectionId,
544 ) -> Result<Keys, UnsupportedVersion> {
545 let version = interpret_version(version)?;
546 Ok(initial_keys(version, dst_cid, Side::Server, &self.initial))
547 }
548
549 fn retry_tag(&self, version: u32, orig_dst_cid: ConnectionId, packet: &[u8]) -> [u8; 16] {
550 let version = interpret_version(version).unwrap();
552 let (nonce, key) = match version {
553 Version::V1 => (&RETRY_INTEGRITY_NONCE_V1, &RETRY_INTEGRITY_KEY_V1),
554 Version::V1Draft => (&RETRY_INTEGRITY_NONCE_DRAFT, &RETRY_INTEGRITY_KEY_DRAFT),
555 _ => unreachable!(),
556 };
557
558 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
559 pseudo_packet.push(orig_dst_cid.len() as u8);
560 pseudo_packet.extend_from_slice(&orig_dst_cid);
561 pseudo_packet.extend_from_slice(packet);
562
563 let nonce = aes_gcm::Nonce::from_slice(nonce);
564 let key = aes_gcm::Key::<aes_gcm::Aes128Gcm>::from_slice(key);
565 let tag = aes_gcm::Aes128Gcm::new(key)
566 .encrypt_in_place_detached(nonce, &pseudo_packet, &mut [])
567 .unwrap();
568 tag.into()
569 }
570}
571
572pub(crate) fn initial_suite_from_provider(
573 provider: &Arc<rustls::crypto::CryptoProvider>,
574) -> Option<Suite> {
575 provider
576 .cipher_suites
577 .iter()
578 .find_map(|cs| match (cs.suite(), cs.tls13()) {
579 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
580 Some(suite.quic_suite())
581 }
582 _ => None,
583 })
584 .flatten()
585}
586
/// The crate's default rustls crypto provider: aws-lc-rs, used when `ring` is not enabled.
#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
    let provider = rustls::crypto::aws_lc_rs::default_provider();
    Arc::from(provider)
}

/// The crate's default rustls crypto provider: ring, which takes precedence over aws-lc-rs.
#[cfg(feature = "ring")]
pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
    let provider = rustls::crypto::ring::default_provider();
    Arc::from(provider)
}
596
597fn to_vec(params: &TransportParameters) -> Vec<u8> {
598 let mut bytes = Vec::new();
599 params.write(&mut bytes);
600 bytes
601}
602
603pub(crate) fn initial_keys(
604 version: Version,
605 dst_cid: ConnectionId,
606 side: Side,
607 suite: &Suite,
608) -> Keys {
609 let keys = suite.keys(&dst_cid, side.into(), version);
610 Keys {
611 header: KeyPair {
612 local: Box::new(keys.local.header),
613 remote: Box::new(keys.remote.header),
614 },
615 packet: KeyPair {
616 local: Box::new(keys.local.packet),
617 remote: Box::new(keys.remote.packet),
618 },
619 }
620}
621
622impl crypto::PacketKey for Box<dyn PacketKey> {
623 fn encrypt(&self, path_id: PathId, packet: u64, buf: &mut [u8], header_len: usize) {
624 let (header, payload_tag) = buf.split_at_mut(header_len);
625 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
626 let tag = self
627 .encrypt_in_place_for_path(path_id.as_u32(), packet, &*header, payload)
628 .unwrap();
629 tag_storage.copy_from_slice(tag.as_ref());
630 }
631
632 fn decrypt(
633 &self,
634 path_id: PathId,
635 packet: u64,
636 header: &[u8],
637 payload: &mut BytesMut,
638 ) -> Result<(), CryptoError> {
639 let plain = self
640 .decrypt_in_place_for_path(path_id.as_u32(), packet, header, payload.as_mut())
641 .map_err(|_| CryptoError)?;
642 let plain_len = plain.len();
643 payload.truncate(plain_len);
644 Ok(())
645 }
646
647 fn tag_len(&self) -> usize {
648 (**self).tag_len()
649 }
650
651 fn confidentiality_limit(&self) -> u64 {
652 (**self).confidentiality_limit()
653 }
654
655 fn integrity_limit(&self) -> u64 {
656 (**self).integrity_limit()
657 }
658}
659
660fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
661 match version {
662 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
663 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
664 _ => Err(UnsupportedVersion),
665 }
666}