1use std::{any::Any, io, str, sync::Arc};
2
3#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
4use aws_lc_rs::aead;
5use bytes::BytesMut;
6#[cfg(feature = "ring")]
7use ring::aead;
8pub use rustls::Error;
9#[cfg(feature = "__rustls-post-quantum-test")]
10use rustls::NamedGroup;
11use rustls::{
12 self, CipherSuite,
13 client::danger::ServerCertVerifier,
14 pki_types::{CertificateDer, PrivateKeyDer, ServerName},
15 quic::{Connection, HeaderProtectionKey, KeyChange, PacketKey, Secrets, Suite, Version},
16};
17#[cfg(feature = "platform-verifier")]
18use rustls_platform_verifier::BuilderVerifierExt;
19
20use crate::{
21 ConnectError, ConnectionId, PathId, Side, TransportError, TransportErrorCode,
22 crypto::{
23 self, CryptoError, ExportKeyingMaterialError, HeaderKey, KeyPair, Keys, UnsupportedVersion,
24 },
25 transport_parameters::TransportParameters,
26};
27
28impl From<Side> for rustls::Side {
29 fn from(s: Side) -> Self {
30 match s {
31 Side::Client => Self::Client,
32 Side::Server => Self::Server,
33 }
34 }
35}
36
37pub struct TlsSession {
39 version: Version,
40 got_handshake_data: bool,
41 next_secrets: Option<Secrets>,
42 inner: Connection,
43 suite: Suite,
44}
45
46impl TlsSession {
47 fn side(&self) -> Side {
48 match self.inner {
49 Connection::Client(_) => Side::Client,
50 Connection::Server(_) => Side::Server,
51 }
52 }
53}
54
55impl crypto::Session for TlsSession {
56 fn initial_keys(&self, dst_cid: ConnectionId, side: Side) -> Keys {
57 initial_keys(self.version, dst_cid, side, &self.suite)
58 }
59
60 fn handshake_data(&self) -> Option<Box<dyn Any>> {
61 if !self.got_handshake_data {
62 return None;
63 }
64 Some(Box::new(HandshakeData {
65 protocol: self.inner.alpn_protocol().map(|x| x.into()),
66 server_name: match self.inner {
67 Connection::Client(_) => None,
68 Connection::Server(ref session) => session.server_name().map(|x| x.into()),
69 },
70 #[cfg(feature = "__rustls-post-quantum-test")]
71 negotiated_key_exchange_group: self
72 .inner
73 .negotiated_key_exchange_group()
74 .expect("key exchange group is negotiated")
75 .name(),
76 }))
77 }
78
79 fn peer_identity(&self) -> Option<Box<dyn Any>> {
81 self.inner.peer_certificates().map(|v| -> Box<dyn Any> {
82 Box::new(
83 v.iter()
84 .map(|v| v.clone().into_owned())
85 .collect::<Vec<CertificateDer<'static>>>(),
86 )
87 })
88 }
89
90 fn early_crypto(&self) -> Option<(Box<dyn HeaderKey>, Box<dyn crypto::PacketKey>)> {
91 let keys = self.inner.zero_rtt_keys()?;
92 Some((Box::new(keys.header), Box::new(keys.packet)))
93 }
94
95 fn early_data_accepted(&self) -> Option<bool> {
96 match self.inner {
97 Connection::Client(ref session) => Some(session.is_early_data_accepted()),
98 _ => None,
99 }
100 }
101
102 fn is_handshaking(&self) -> bool {
103 self.inner.is_handshaking()
104 }
105
106 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, TransportError> {
107 self.inner.read_hs(buf).map_err(|e| {
108 if let Some(alert) = self.inner.alert() {
109 TransportError {
110 code: TransportErrorCode::crypto(alert.into()),
111 frame: None,
112 reason: e.to_string(),
113 crypto: Some(Arc::new(e)),
114 }
115 } else {
116 TransportError::PROTOCOL_VIOLATION(format!("TLS error: {e}"))
117 }
118 })?;
119 if !self.got_handshake_data {
120 let have_server_name = match self.inner {
124 Connection::Client(_) => false,
125 Connection::Server(ref session) => session.server_name().is_some(),
126 };
127 if self.inner.alpn_protocol().is_some() || have_server_name || !self.is_handshaking() {
128 self.got_handshake_data = true;
129 return Ok(true);
130 }
131 }
132 Ok(false)
133 }
134
135 fn transport_parameters(&self) -> Result<Option<TransportParameters>, TransportError> {
136 match self.inner.quic_transport_parameters() {
137 None => Ok(None),
138 Some(buf) => match TransportParameters::read(self.side(), &mut io::Cursor::new(buf)) {
139 Ok(params) => Ok(Some(params)),
140 Err(e) => Err(e.into()),
141 },
142 }
143 }
144
145 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<Keys> {
146 let keys = match self.inner.write_hs(buf)? {
147 KeyChange::Handshake { keys } => keys,
148 KeyChange::OneRtt { keys, next } => {
149 self.next_secrets = Some(next);
150 keys
151 }
152 };
153
154 Some(Keys {
155 header: KeyPair {
156 local: Box::new(keys.local.header),
157 remote: Box::new(keys.remote.header),
158 },
159 packet: KeyPair {
160 local: Box::new(keys.local.packet),
161 remote: Box::new(keys.remote.packet),
162 },
163 })
164 }
165
166 fn next_1rtt_keys(&mut self) -> Option<KeyPair<Box<dyn crypto::PacketKey>>> {
167 let secrets = self.next_secrets.as_mut()?;
168 let keys = secrets.next_packet_keys();
169 Some(KeyPair {
170 local: Box::new(keys.local),
171 remote: Box::new(keys.remote),
172 })
173 }
174
175 fn is_valid_retry(&self, orig_dst_cid: ConnectionId, header: &[u8], payload: &[u8]) -> bool {
176 let tag_start = match payload.len().checked_sub(16) {
177 Some(x) => x,
178 None => return false,
179 };
180
181 let mut pseudo_packet =
182 Vec::with_capacity(header.len() + payload.len() + orig_dst_cid.len() + 1);
183 pseudo_packet.push(orig_dst_cid.len() as u8);
184 pseudo_packet.extend_from_slice(&orig_dst_cid);
185 pseudo_packet.extend_from_slice(header);
186 let tag_start = tag_start + pseudo_packet.len();
187 pseudo_packet.extend_from_slice(payload);
188
189 let (nonce, key) = match self.version {
190 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
191 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
192 _ => unreachable!(),
193 };
194
195 let nonce = aead::Nonce::assume_unique_for_key(nonce);
196 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
197
198 let (aad, tag) = pseudo_packet.split_at_mut(tag_start);
199 key.open_in_place(nonce, aead::Aad::from(aad), tag).is_ok()
200 }
201
202 fn export_keying_material(
203 &self,
204 output: &mut [u8],
205 label: &[u8],
206 context: &[u8],
207 ) -> Result<(), ExportKeyingMaterialError> {
208 self.inner
209 .export_keying_material(output, label, Some(context))
210 .map_err(|_| ExportKeyingMaterialError)?;
211 Ok(())
212 }
213}
214
// AES-128-GCM key and nonce for Retry integrity tags under draft QUIC versions
// (draft-29 through draft-32).
const RETRY_INTEGRITY_KEY_DRAFT: [u8; 16] =
    *b"\xcc\xce\x18\x7e\xd0\x9a\x09\xd0\x57\x28\x15\x5a\x6c\xb9\x6b\xe1";
const RETRY_INTEGRITY_NONCE_DRAFT: [u8; 12] = *b"\xe5\x49\x30\xf9\x7f\x21\x36\xf0\x53\x0a\x8c\x1c";

// AES-128-GCM key and nonce for Retry integrity tags under QUIC v1 (RFC 9001, Section 5.8).
const RETRY_INTEGRITY_KEY_V1: [u8; 16] =
    *b"\xbe\x0c\x69\x0b\x9f\x66\x57\x5a\x1d\x76\x6b\x54\xe3\x68\xc8\x4e";
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = *b"\x46\x15\x99\xd3\x5d\x63\x2b\xf2\x23\x98\x25\xbb";
228
229impl crypto::HeaderKey for Box<dyn HeaderProtectionKey> {
230 fn decrypt(&self, pn_offset: usize, packet: &mut [u8]) {
231 let (header, sample) = packet.split_at_mut(pn_offset + 4);
232 let (first, rest) = header.split_at_mut(1);
233 let pn_end = Ord::min(pn_offset + 3, rest.len());
234 self.decrypt_in_place(
235 &sample[..self.sample_size()],
236 &mut first[0],
237 &mut rest[pn_offset - 1..pn_end],
238 )
239 .unwrap();
240 }
241
242 fn encrypt(&self, pn_offset: usize, packet: &mut [u8]) {
243 let (header, sample) = packet.split_at_mut(pn_offset + 4);
244 let (first, rest) = header.split_at_mut(1);
245 let pn_end = Ord::min(pn_offset + 3, rest.len());
246 self.encrypt_in_place(
247 &sample[..self.sample_size()],
248 &mut first[0],
249 &mut rest[pn_offset - 1..pn_end],
250 )
251 .unwrap();
252 }
253
254 fn sample_size(&self) -> usize {
255 self.sample_len()
256 }
257}
258
/// Handshake details exposed by a rustls-backed session once they become known.
pub struct HandshakeData {
    /// ALPN protocol agreed on during the handshake, if any.
    pub protocol: Option<Vec<u8>>,
    /// Server name requested by the client via SNI. Only populated on the server side.
    pub server_name: Option<String>,
    /// Key-exchange group chosen during the handshake (post-quantum test builds only).
    #[cfg(feature = "__rustls-post-quantum-test")]
    pub negotiated_key_exchange_group: NamedGroup,
}
273
274pub struct QuicClientConfig {
293 pub(crate) inner: Arc<rustls::ClientConfig>,
294 initial: Suite,
295}
296
297impl QuicClientConfig {
298 #[cfg(feature = "platform-verifier")]
299 pub(crate) fn with_platform_verifier() -> Result<Self, Error> {
300 let mut inner = rustls::ClientConfig::builder_with_provider(configured_provider())
302 .with_protocol_versions(&[&rustls::version::TLS13])
303 .unwrap() .with_platform_verifier()?
305 .with_no_client_auth();
306
307 inner.enable_early_data = true;
308 Ok(Self {
309 initial: initial_suite_from_provider(inner.crypto_provider())
311 .expect("no initial cipher suite found"),
312 inner: Arc::new(inner),
313 })
314 }
315
316 pub(crate) fn new(verifier: Arc<dyn ServerCertVerifier>) -> Self {
321 let inner = Self::inner(verifier);
322 Self {
323 initial: initial_suite_from_provider(inner.crypto_provider())
325 .expect("no initial cipher suite found"),
326 inner: Arc::new(inner),
327 }
328 }
329
330 pub fn with_initial(
334 inner: Arc<rustls::ClientConfig>,
335 initial: Suite,
336 ) -> Result<Self, NoInitialCipherSuite> {
337 match initial.suite.common.suite {
338 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
339 _ => Err(NoInitialCipherSuite { specific: true }),
340 }
341 }
342
343 pub(crate) fn inner(verifier: Arc<dyn ServerCertVerifier>) -> rustls::ClientConfig {
344 let mut config = rustls::ClientConfig::builder_with_provider(configured_provider())
346 .with_protocol_versions(&[&rustls::version::TLS13])
347 .unwrap() .dangerous()
349 .with_custom_certificate_verifier(verifier)
350 .with_no_client_auth();
351
352 config.enable_early_data = true;
353 config
354 }
355}
356
357impl crypto::ClientConfig for QuicClientConfig {
358 fn start_session(
359 self: Arc<Self>,
360 version: u32,
361 server_name: &str,
362 params: &TransportParameters,
363 ) -> Result<Box<dyn crypto::Session>, ConnectError> {
364 let version = interpret_version(version)?;
365 Ok(Box::new(TlsSession {
366 version,
367 got_handshake_data: false,
368 next_secrets: None,
369 inner: rustls::quic::Connection::Client(
370 rustls::quic::ClientConnection::new(
371 self.inner.clone(),
372 version,
373 ServerName::try_from(server_name)
374 .map_err(|_| ConnectError::InvalidServerName(server_name.into()))?
375 .to_owned(),
376 to_vec(params),
377 )
378 .unwrap(),
379 ),
380 suite: self.initial,
381 }))
382 }
383}
384
385impl TryFrom<rustls::ClientConfig> for QuicClientConfig {
386 type Error = NoInitialCipherSuite;
387
388 fn try_from(inner: rustls::ClientConfig) -> Result<Self, Self::Error> {
389 Arc::new(inner).try_into()
390 }
391}
392
393impl TryFrom<Arc<rustls::ClientConfig>> for QuicClientConfig {
394 type Error = NoInitialCipherSuite;
395
396 fn try_from(inner: Arc<rustls::ClientConfig>) -> Result<Self, Self::Error> {
397 Ok(Self {
398 initial: initial_suite_from_provider(inner.crypto_provider())
399 .ok_or(NoInitialCipherSuite { specific: false })?,
400 inner,
401 })
402 }
403}
404
/// Error raised when a rustls configuration cannot supply the cipher suite that QUIC needs
/// for Initial packets.
#[derive(Clone, Debug)]
pub struct NoInitialCipherSuite {
    /// `true` when the caller explicitly passed an unsuitable suite; `false` when the crypto
    /// provider had no suitable suite at all.
    specific: bool,
}

impl std::fmt::Display for NoInitialCipherSuite {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = if self.specific {
            "invalid cipher suite specified"
        } else {
            "no initial cipher suite found"
        };
        f.write_str(message)
    }
}

impl std::error::Error for NoInitialCipherSuite {}
428
429pub struct QuicServerConfig {
442 inner: Arc<rustls::ServerConfig>,
443 initial: Suite,
444}
445
446impl QuicServerConfig {
447 pub(crate) fn new(
448 cert_chain: Vec<CertificateDer<'static>>,
449 key: PrivateKeyDer<'static>,
450 ) -> Result<Self, rustls::Error> {
451 let inner = Self::inner(cert_chain, key)?;
452 Ok(Self {
453 initial: initial_suite_from_provider(inner.crypto_provider())
455 .expect("no initial cipher suite found"),
456 inner: Arc::new(inner),
457 })
458 }
459
460 pub fn with_initial(
464 inner: Arc<rustls::ServerConfig>,
465 initial: Suite,
466 ) -> Result<Self, NoInitialCipherSuite> {
467 match initial.suite.common.suite {
468 CipherSuite::TLS13_AES_128_GCM_SHA256 => Ok(Self { inner, initial }),
469 _ => Err(NoInitialCipherSuite { specific: true }),
470 }
471 }
472
473 pub(crate) fn inner(
479 cert_chain: Vec<CertificateDer<'static>>,
480 key: PrivateKeyDer<'static>,
481 ) -> Result<rustls::ServerConfig, rustls::Error> {
482 let mut inner = rustls::ServerConfig::builder_with_provider(configured_provider())
483 .with_protocol_versions(&[&rustls::version::TLS13])
484 .unwrap() .with_no_client_auth()
486 .with_single_cert(cert_chain, key)?;
487
488 inner.max_early_data_size = u32::MAX;
489 Ok(inner)
490 }
491}
492
493impl TryFrom<rustls::ServerConfig> for QuicServerConfig {
494 type Error = NoInitialCipherSuite;
495
496 fn try_from(inner: rustls::ServerConfig) -> Result<Self, Self::Error> {
497 Arc::new(inner).try_into()
498 }
499}
500
501impl TryFrom<Arc<rustls::ServerConfig>> for QuicServerConfig {
502 type Error = NoInitialCipherSuite;
503
504 fn try_from(inner: Arc<rustls::ServerConfig>) -> Result<Self, Self::Error> {
505 Ok(Self {
506 initial: initial_suite_from_provider(inner.crypto_provider())
507 .ok_or(NoInitialCipherSuite { specific: false })?,
508 inner,
509 })
510 }
511}
512
513impl crypto::ServerConfig for QuicServerConfig {
514 fn start_session(
515 self: Arc<Self>,
516 version: u32,
517 params: &TransportParameters,
518 ) -> Box<dyn crypto::Session> {
519 let version = interpret_version(version).unwrap();
521 Box::new(TlsSession {
522 version,
523 got_handshake_data: false,
524 next_secrets: None,
525 inner: rustls::quic::Connection::Server(
526 rustls::quic::ServerConnection::new(self.inner.clone(), version, to_vec(params))
527 .unwrap(),
528 ),
529 suite: self.initial,
530 })
531 }
532
533 fn initial_keys(
534 &self,
535 version: u32,
536 dst_cid: ConnectionId,
537 ) -> Result<Keys, UnsupportedVersion> {
538 let version = interpret_version(version)?;
539 Ok(initial_keys(version, dst_cid, Side::Server, &self.initial))
540 }
541
542 fn retry_tag(&self, version: u32, orig_dst_cid: ConnectionId, packet: &[u8]) -> [u8; 16] {
543 let version = interpret_version(version).unwrap();
545 let (nonce, key) = match version {
546 Version::V1 => (RETRY_INTEGRITY_NONCE_V1, RETRY_INTEGRITY_KEY_V1),
547 Version::V1Draft => (RETRY_INTEGRITY_NONCE_DRAFT, RETRY_INTEGRITY_KEY_DRAFT),
548 _ => unreachable!(),
549 };
550
551 let mut pseudo_packet = Vec::with_capacity(packet.len() + orig_dst_cid.len() + 1);
552 pseudo_packet.push(orig_dst_cid.len() as u8);
553 pseudo_packet.extend_from_slice(&orig_dst_cid);
554 pseudo_packet.extend_from_slice(packet);
555
556 let nonce = aead::Nonce::assume_unique_for_key(nonce);
557 let key = aead::LessSafeKey::new(aead::UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
558
559 let tag = key
560 .seal_in_place_separate_tag(nonce, aead::Aad::from(pseudo_packet), &mut [])
561 .unwrap();
562 let mut result = [0; 16];
563 result.copy_from_slice(tag.as_ref());
564 result
565 }
566}
567
568pub(crate) fn initial_suite_from_provider(
569 provider: &Arc<rustls::crypto::CryptoProvider>,
570) -> Option<Suite> {
571 provider
572 .cipher_suites
573 .iter()
574 .find_map(|cs| match (cs.suite(), cs.tls13()) {
575 (rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, Some(suite)) => {
576 Some(suite.quic_suite())
577 }
578 _ => None,
579 })
580 .flatten()
581}
582
583pub(crate) fn configured_provider() -> Arc<rustls::crypto::CryptoProvider> {
584 #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
585 let provider = rustls::crypto::aws_lc_rs::default_provider();
586 #[cfg(feature = "rustls-ring")]
587 let provider = rustls::crypto::ring::default_provider();
588 Arc::new(provider)
589}
590
591fn to_vec(params: &TransportParameters) -> Vec<u8> {
592 let mut bytes = Vec::new();
593 params.write(&mut bytes);
594 bytes
595}
596
597pub(crate) fn initial_keys(
598 version: Version,
599 dst_cid: ConnectionId,
600 side: Side,
601 suite: &Suite,
602) -> Keys {
603 let keys = suite.keys(&dst_cid, side.into(), version);
604 Keys {
605 header: KeyPair {
606 local: Box::new(keys.local.header),
607 remote: Box::new(keys.remote.header),
608 },
609 packet: KeyPair {
610 local: Box::new(keys.local.packet),
611 remote: Box::new(keys.remote.packet),
612 },
613 }
614}
615
616impl crypto::PacketKey for Box<dyn PacketKey> {
617 fn encrypt(&self, path_id: PathId, packet: u64, buf: &mut [u8], header_len: usize) {
618 let (header, payload_tag) = buf.split_at_mut(header_len);
619 let (payload, tag_storage) = payload_tag.split_at_mut(payload_tag.len() - self.tag_len());
620 let tag = self
621 .encrypt_in_place_for_path(path_id.as_u32(), packet, &*header, payload)
622 .unwrap();
623 tag_storage.copy_from_slice(tag.as_ref());
624 }
625
626 fn decrypt(
627 &self,
628 path_id: PathId,
629 packet: u64,
630 header: &[u8],
631 payload: &mut BytesMut,
632 ) -> Result<(), CryptoError> {
633 let plain = self
634 .decrypt_in_place_for_path(path_id.as_u32(), packet, header, payload.as_mut())
635 .map_err(|_| CryptoError)?;
636 let plain_len = plain.len();
637 payload.truncate(plain_len);
638 Ok(())
639 }
640
641 fn tag_len(&self) -> usize {
642 (**self).tag_len()
643 }
644
645 fn confidentiality_limit(&self) -> u64 {
646 (**self).confidentiality_limit()
647 }
648
649 fn integrity_limit(&self) -> u64 {
650 (**self).integrity_limit()
651 }
652}
653
654fn interpret_version(version: u32) -> Result<Version, UnsupportedVersion> {
655 match version {
656 0xff00_001d..=0xff00_0020 => Ok(Version::V1Draft),
657 0x0000_0001 | 0xff00_0021..=0xff00_0022 => Ok(Version::V1),
658 _ => Err(UnsupportedVersion),
659 }
660}