1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 ConnectionId, PathId,
8 coding::{self, BufExt, BufMutExt},
9 crypto,
10};
11
/// A QUIC packet whose plaintext header fields have been parsed, but whose header protection
/// has not yet been removed.
///
/// Created by [`PartialDecode::new`]; call [`PartialDecode::finish`] with the matching header
/// key to decode the packet number and split the packet into header and payload.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    /// Header fields that can be read without removing header protection.
    plain_header: ProtectedHeader,
    /// The bytes of this packet only (bytes of any following packets are split off by `new`).
    /// The cursor position is the offset just past the plaintext header, i.e. where the
    /// protected packet number starts.
    buf: io::Cursor<BytesMut>,
}
30
31#[allow(clippy::len_without_is_empty)]
32impl PartialDecode {
33 pub fn new(
35 bytes: BytesMut,
36 cid_parser: &(impl ConnectionIdParser + ?Sized),
37 supported_versions: &[u32],
38 grease_quic_bit: bool,
39 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
40 let mut buf = io::Cursor::new(bytes);
41 let plain_header =
42 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
43 let dgram_len = buf.get_ref().len();
44 let packet_len = plain_header
45 .payload_len()
46 .map(|len| (buf.position() + len) as usize)
47 .unwrap_or(dgram_len);
48 match dgram_len.cmp(&packet_len) {
49 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
50 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
51 "packet too short to contain payload length",
52 )),
53 Ordering::Greater => {
54 let rest = Some(buf.get_mut().split_off(packet_len));
55 Ok((Self { plain_header, buf }, rest))
56 }
57 }
58 }
59
60 pub(crate) fn data(&self) -> &[u8] {
62 self.buf.get_ref()
63 }
64
65 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
66 self.plain_header.as_initial()
67 }
68
69 pub(crate) fn has_long_header(&self) -> bool {
70 !matches!(self.plain_header, ProtectedHeader::Short { .. })
71 }
72
73 pub(crate) fn is_initial(&self) -> bool {
74 self.space() == Some(SpaceId::Initial)
75 }
76
77 pub(crate) fn space(&self) -> Option<SpaceId> {
78 use ProtectedHeader::*;
79 match self.plain_header {
80 Initial { .. } => Some(SpaceId::Initial),
81 Long {
82 ty: LongType::Handshake,
83 ..
84 } => Some(SpaceId::Handshake),
85 Long {
86 ty: LongType::ZeroRtt,
87 ..
88 } => Some(SpaceId::Data),
89 Short { .. } => Some(SpaceId::Data),
90 _ => None,
91 }
92 }
93
94 pub(crate) fn is_0rtt(&self) -> bool {
95 match self.plain_header {
96 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
97 _ => false,
98 }
99 }
100
101 pub fn dst_cid(&self) -> ConnectionId {
103 self.plain_header.dst_cid()
104 }
105
106 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
109 self.buf.get_ref().len()
110 }
111
112 pub(crate) fn finish(
113 self,
114 header_crypto: Option<&dyn crypto::HeaderKey>,
115 ) -> Result<Packet, PacketDecodeError> {
116 use ProtectedHeader::*;
117 let Self {
118 plain_header,
119 mut buf,
120 } = self;
121
122 if let Initial(ProtectedInitialHeader {
123 dst_cid,
124 src_cid,
125 token_pos,
126 version,
127 ..
128 }) = plain_header
129 {
130 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
131 let header_len = buf.position() as usize;
132 let mut bytes = buf.into_inner();
133
134 let header_data = bytes.split_to(header_len).freeze();
135 let token = header_data.slice(token_pos.start..token_pos.end);
136 return Ok(Packet {
137 header: Header::Initial(InitialHeader {
138 dst_cid,
139 src_cid,
140 token,
141 number,
142 version,
143 }),
144 header_data,
145 payload: bytes,
146 });
147 }
148
149 let header = match plain_header {
150 Long {
151 ty,
152 dst_cid,
153 src_cid,
154 version,
155 ..
156 } => Header::Long {
157 ty,
158 dst_cid,
159 src_cid,
160 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
161 version,
162 },
163 Retry {
164 dst_cid,
165 src_cid,
166 version,
167 } => Header::Retry {
168 dst_cid,
169 src_cid,
170 version,
171 },
172 Short { spin, dst_cid, .. } => {
173 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
174 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
175 Header::Short {
176 spin,
177 key_phase,
178 dst_cid,
179 number,
180 }
181 }
182 VersionNegotiate {
183 random,
184 dst_cid,
185 src_cid,
186 } => Header::VersionNegotiate {
187 random,
188 dst_cid,
189 src_cid,
190 },
191 Initial { .. } => unreachable!(),
192 };
193
194 let header_len = buf.position() as usize;
195 let mut bytes = buf.into_inner();
196 Ok(Packet {
197 header,
198 header_data: bytes.split_to(header_len).freeze(),
199 payload: bytes,
200 })
201 }
202
203 fn decrypt_header(
204 buf: &mut io::Cursor<BytesMut>,
205 header_crypto: &dyn crypto::HeaderKey,
206 ) -> Result<PacketNumber, PacketDecodeError> {
207 let packet_length = buf.get_ref().len();
208 let pn_offset = buf.position() as usize;
209 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
210 return Err(PacketDecodeError::InvalidHeader(
211 "packet too short to extract header protection sample",
212 ));
213 }
214
215 header_crypto.decrypt(pn_offset, buf.get_mut());
216
217 let len = PacketNumber::decode_len(buf.get_ref()[0]);
218 PacketNumber::decode(len, buf)
219 }
220}
221
/// A buffer whose current length in bytes can be queried.
///
/// [`Header::encode`] uses this to measure the header it wrote relative to where it started.
pub(crate) trait BufLen {
    /// Returns the number of bytes currently in the buffer.
    fn len(&self) -> usize;
}
232
233impl BufLen for Vec<u8> {
234 fn len(&self) -> usize {
235 self.len()
236 }
237}
238
/// A packet whose header has been fully decoded, as produced by [`PartialDecode::finish`].
///
/// The payload is still packet-protected and must be decrypted separately.
pub(crate) struct Packet {
    /// The decoded header.
    pub(crate) header: Header,
    /// Raw header bytes, up to and including the packet number.
    pub(crate) header_data: Bytes,
    /// The bytes following the header.
    pub(crate) payload: BytesMut,
}
244
245impl Packet {
246 pub(crate) fn reserved_bits_valid(&self) -> bool {
247 let mask = match self.header {
248 Header::Short { .. } => SHORT_RESERVED_BITS,
249 _ => LONG_RESERVED_BITS,
250 };
251 self.header_data[0] & mask == 0
252 }
253}
254
/// An Initial packet split into its decoded header, raw header bytes and payload.
///
/// Converts into a general [`Packet`] via `From`.
pub(crate) struct InitialPacket {
    /// The decoded Initial header.
    pub(crate) header: InitialHeader,
    /// Raw header bytes, up to and including the packet number.
    pub(crate) header_data: Bytes,
    /// The bytes following the header.
    pub(crate) payload: BytesMut,
}
260
261impl From<InitialPacket> for Packet {
262 fn from(x: InitialPacket) -> Self {
263 Self {
264 header: Header::Initial(x.header),
265 header_data: x.header_data,
266 payload: x.payload,
267 }
268 }
269}
270
/// A fully decoded packet header, or one about to be encoded with [`Header::encode`].
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// Initial packet header.
    Initial(InitialHeader),
    /// Long header of a Handshake or 0-RTT packet.
    Long {
        /// Which long packet type this is.
        ty: LongType,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// Truncated packet number.
        number: PacketNumber,
        /// QUIC version.
        version: u32,
    },
    /// Retry packet header; carries no packet number and is not header-protected.
    Retry {
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// QUIC version.
        version: u32,
    },
    /// Short (1-RTT) header.
    Short {
        /// Spin bit from the first byte.
        spin: bool,
        /// Key phase bit from the first byte (only meaningful once header protection is removed).
        key_phase: bool,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Truncated packet number.
        number: PacketNumber,
    },
    /// Version Negotiation packet header; carries no packet number.
    VersionNegotiate {
        /// Arbitrary low bits of the first byte (everything except the header form bit).
        random: u8,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// Destination connection ID.
        dst_cid: ConnectionId,
    },
}
299
300impl Header {
301 pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
302 use Header::*;
303 let start = w.len();
304 match *self {
305 Initial(InitialHeader {
306 ref dst_cid,
307 ref src_cid,
308 ref token,
309 number,
310 version,
311 }) => {
312 w.write(u8::from(LongHeaderType::Initial) | number.tag());
313 w.write(version);
314 dst_cid.encode_long(w);
315 src_cid.encode_long(w);
316 w.write_var(token.len() as u64);
317 w.put_slice(token);
318 w.write::<u16>(0); number.encode(w);
320 PartialEncode {
321 start,
322 header_len: w.len() - start,
323 pn: Some((number.len(), true)),
324 }
325 }
326 Long {
327 ty,
328 ref dst_cid,
329 ref src_cid,
330 number,
331 version,
332 } => {
333 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
334 w.write(version);
335 dst_cid.encode_long(w);
336 src_cid.encode_long(w);
337 w.write::<u16>(0); number.encode(w);
339 PartialEncode {
340 start,
341 header_len: w.len() - start,
342 pn: Some((number.len(), true)),
343 }
344 }
345 Retry {
346 ref dst_cid,
347 ref src_cid,
348 version,
349 } => {
350 w.write(u8::from(LongHeaderType::Retry));
351 w.write(version);
352 dst_cid.encode_long(w);
353 src_cid.encode_long(w);
354 PartialEncode {
355 start,
356 header_len: w.len() - start,
357 pn: None,
358 }
359 }
360 Short {
361 spin,
362 key_phase,
363 ref dst_cid,
364 number,
365 } => {
366 w.write(
367 FIXED_BIT
368 | if key_phase { KEY_PHASE_BIT } else { 0 }
369 | if spin { SPIN_BIT } else { 0 }
370 | number.tag(),
371 );
372 w.put_slice(dst_cid);
373 number.encode(w);
374 PartialEncode {
375 start,
376 header_len: w.len() - start,
377 pn: Some((number.len(), false)),
378 }
379 }
380 VersionNegotiate {
381 ref random,
382 ref dst_cid,
383 ref src_cid,
384 } => {
385 w.write(0x80u8 | random);
386 w.write::<u32>(0);
387 dst_cid.encode_long(w);
388 src_cid.encode_long(w);
389 PartialEncode {
390 start,
391 header_len: w.len() - start,
392 pn: None,
393 }
394 }
395 }
396 }
397
398 pub(crate) fn is_protected(&self) -> bool {
400 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
401 }
402
403 pub(crate) fn number(&self) -> Option<PacketNumber> {
404 use Header::*;
405 Some(match *self {
406 Initial(InitialHeader { number, .. }) => number,
407 Long { number, .. } => number,
408 Short { number, .. } => number,
409 _ => {
410 return None;
411 }
412 })
413 }
414
415 pub(crate) fn space(&self) -> SpaceId {
416 use Header::*;
417 match *self {
418 Short { .. } => SpaceId::Data,
419 Long {
420 ty: LongType::ZeroRtt,
421 ..
422 } => SpaceId::Data,
423 Long {
424 ty: LongType::Handshake,
425 ..
426 } => SpaceId::Handshake,
427 _ => SpaceId::Initial,
428 }
429 }
430
431 pub(crate) fn key_phase(&self) -> bool {
432 match *self {
433 Self::Short { key_phase, .. } => key_phase,
434 _ => false,
435 }
436 }
437
438 pub(crate) fn is_short(&self) -> bool {
439 matches!(*self, Self::Short { .. })
440 }
441
442 pub(crate) fn is_1rtt(&self) -> bool {
443 self.is_short()
444 }
445
446 pub(crate) fn is_0rtt(&self) -> bool {
447 matches!(
448 *self,
449 Self::Long {
450 ty: LongType::ZeroRtt,
451 ..
452 }
453 )
454 }
455
456 pub(crate) fn dst_cid(&self) -> ConnectionId {
457 use Header::*;
458 match *self {
459 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
460 Long { dst_cid, .. } => dst_cid,
461 Retry { dst_cid, .. } => dst_cid,
462 Short { dst_cid, .. } => dst_cid,
463 VersionNegotiate { dst_cid, .. } => dst_cid,
464 }
465 }
466
467 pub(crate) fn has_frames(&self) -> bool {
469 use Header::*;
470 match *self {
471 Initial(_) => true,
472 Long { .. } => true,
473 Retry { .. } => false,
474 Short { .. } => true,
475 VersionNegotiate { .. } => false,
476 }
477 }
478
479 #[cfg(feature = "qlog")]
480 pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
481 match self {
482 Self::Initial(initial_header) => Some(initial_header.src_cid),
483 Self::Long { src_cid, .. } => Some(*src_cid),
484 Self::Retry { src_cid, .. } => Some(*src_cid),
485 Self::Short { .. } => None,
486 Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
487 }
488 }
489}
490
/// Bookkeeping returned by [`Header::encode`] for completing a packet once its payload has
/// been written.
pub(crate) struct PartialEncode {
    /// Offset in the output buffer at which the header starts.
    pub(crate) start: usize,
    /// Length of the encoded header in bytes, including the packet number.
    pub(crate) header_len: usize,
    /// Packet number length in bytes, and whether a payload length field must be written.
    /// `None` for headers without a packet number (Retry, Version Negotiation).
    pn: Option<(usize, bool)>,
}
497
498impl PartialEncode {
499 pub(crate) fn finish(
500 self,
501 buf: &mut [u8],
502 header_crypto: &dyn crypto::HeaderKey,
503 crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
504 ) {
505 let Self { header_len, pn, .. } = self;
506 let (pn_len, write_len) = match pn {
507 Some((pn_len, write_len)) => (pn_len, write_len),
508 None => return,
509 };
510
511 let pn_pos = header_len - pn_len;
512 if write_len {
513 let len = buf.len() - header_len + pn_len;
514 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
516 slice.put_u16(len as u16 | (0b01 << 14));
517 }
518
519 if let Some((packet_number, path_id, crypto)) = crypto {
520 crypto.encrypt(path_id, packet_number, buf, header_len);
521 }
522
523 debug_assert!(
524 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
525 "packet must be padded to at least {} bytes for header protection sampling",
526 pn_pos + 4 + header_crypto.sample_size()
527 );
528 header_crypto.encrypt(pn_pos, buf);
529 }
530}
531
/// The fields of a packet header that can be read before header protection is removed
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header
    Initial(ProtectedInitialHeader),
    /// A Handshake or 0-RTT long header
    Long {
        /// Type of the long header packet
        ty: LongType,
        /// Destination connection ID
        dst_cid: ConnectionId,
        /// Source connection ID
        src_cid: ConnectionId,
        /// Number of bytes following the length field (packet number plus payload)
        len: u64,
        /// QUIC version
        version: u32,
    },
    /// A Retry packet header
    Retry {
        /// Destination connection ID
        dst_cid: ConnectionId,
        /// Source connection ID
        src_cid: ConnectionId,
        /// QUIC version
        version: u32,
    },
    /// A short packet header, as used after the handshake
    Short {
        /// Spin bit
        spin: bool,
        /// Destination connection ID
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation packet header
    VersionNegotiate {
        /// Low bits of the first byte (everything except the header form bit)
        random: u8,
        /// Destination connection ID
        dst_cid: ConnectionId,
        /// Source connection ID
        src_cid: ConnectionId,
    },
}
576
577impl ProtectedHeader {
578 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
579 match self {
580 Self::Initial(x) => Some(x),
581 _ => None,
582 }
583 }
584
585 pub fn dst_cid(&self) -> ConnectionId {
587 use ProtectedHeader::*;
588 match self {
589 Initial(header) => header.dst_cid,
590 &Long { dst_cid, .. } => dst_cid,
591 &Retry { dst_cid, .. } => dst_cid,
592 &Short { dst_cid, .. } => dst_cid,
593 &VersionNegotiate { dst_cid, .. } => dst_cid,
594 }
595 }
596
597 fn payload_len(&self) -> Option<u64> {
598 use ProtectedHeader::*;
599 match self {
600 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
601 _ => None,
602 }
603 }
604
605 pub fn decode(
607 buf: &mut io::Cursor<BytesMut>,
608 cid_parser: &(impl ConnectionIdParser + ?Sized),
609 supported_versions: &[u32],
610 grease_quic_bit: bool,
611 ) -> Result<Self, PacketDecodeError> {
612 let first = buf.get::<u8>()?;
613 if !grease_quic_bit && first & FIXED_BIT == 0 {
614 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
615 }
616 if first & LONG_HEADER_FORM == 0 {
617 let spin = first & SPIN_BIT != 0;
618
619 Ok(Self::Short {
620 spin,
621 dst_cid: cid_parser.parse(buf)?,
622 })
623 } else {
624 let version = buf.get::<u32>()?;
625
626 let dst_cid = ConnectionId::decode_long(buf)
627 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
628 let src_cid = ConnectionId::decode_long(buf)
629 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
630
631 if version == 0 {
633 let random = first & !LONG_HEADER_FORM;
634 return Ok(Self::VersionNegotiate {
635 random,
636 dst_cid,
637 src_cid,
638 });
639 }
640
641 if !supported_versions.contains(&version) {
642 return Err(PacketDecodeError::UnsupportedVersion {
643 src_cid,
644 dst_cid,
645 version,
646 });
647 }
648
649 match LongHeaderType::from_byte(first)? {
650 LongHeaderType::Initial => {
651 let token_len = buf.get_var()? as usize;
652 let token_start = buf.position() as usize;
653 if token_len > buf.remaining() {
654 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
655 }
656 buf.advance(token_len);
657
658 let len = buf.get_var()?;
659 Ok(Self::Initial(ProtectedInitialHeader {
660 dst_cid,
661 src_cid,
662 token_pos: token_start..token_start + token_len,
663 len,
664 version,
665 }))
666 }
667 LongHeaderType::Retry => Ok(Self::Retry {
668 dst_cid,
669 src_cid,
670 version,
671 }),
672 LongHeaderType::Standard(ty) => Ok(Self::Long {
673 ty,
674 dst_cid,
675 src_cid,
676 len: buf.get_var()?,
677 version,
678 }),
679 }
680 }
681 }
682}
683
/// Header fields of an Initial packet that can be read before header protection is removed
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination connection ID
    pub dst_cid: ConnectionId,
    /// Source connection ID
    pub src_cid: ConnectionId,
    /// Byte range of the token, as offsets from the start of the packet
    pub token_pos: Range<usize>,
    /// Number of bytes following the length field (packet number plus payload)
    pub len: u64,
    /// QUIC version
    pub version: u32,
}
698
/// A fully decoded Initial packet header.
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    /// Destination connection ID.
    pub(crate) dst_cid: ConnectionId,
    /// Source connection ID.
    pub(crate) src_cid: ConnectionId,
    /// The token bytes carried by the header.
    pub(crate) token: Bytes,
    /// Truncated packet number.
    pub(crate) number: PacketNumber,
    /// QUIC version.
    pub(crate) version: u32,
}
707
/// A packet number truncated to the 1–4 least significant bytes sent on the wire.
///
/// Use [`PacketNumber::expand`] to recover the full value.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding.
    U8(u8),
    /// Two-byte encoding.
    U16(u16),
    /// Three-byte encoding, held in a `u32`.
    U24(u32),
    /// Four-byte encoding.
    U32(u32),
}
716
717impl PacketNumber {
718 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
719 let range = (n - largest_acked) * 2;
720 if range < 1 << 8 {
721 Self::U8(n as u8)
722 } else if range < 1 << 16 {
723 Self::U16(n as u16)
724 } else if range < 1 << 24 {
725 Self::U24(n as u32)
726 } else if range < 1 << 32 {
727 Self::U32(n as u32)
728 } else {
729 panic!("packet number too large to encode")
730 }
731 }
732
733 pub(crate) fn len(self) -> usize {
734 use PacketNumber::*;
735 match self {
736 U8(_) => 1,
737 U16(_) => 2,
738 U24(_) => 3,
739 U32(_) => 4,
740 }
741 }
742
743 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
744 use PacketNumber::*;
745 match self {
746 U8(x) => w.write(x),
747 U16(x) => w.write(x),
748 U24(x) => w.put_uint(u64::from(x), 3),
749 U32(x) => w.write(x),
750 }
751 }
752
753 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
754 use PacketNumber::*;
755 let pn = match len {
756 1 => U8(r.get()?),
757 2 => U16(r.get()?),
758 3 => U24(r.get_uint(3) as u32),
759 4 => U32(r.get()?),
760 _ => unreachable!(),
761 };
762 Ok(pn)
763 }
764
765 pub(crate) fn decode_len(tag: u8) -> usize {
766 1 + (tag & 0x03) as usize
767 }
768
769 fn tag(self) -> u8 {
770 use PacketNumber::*;
771 match self {
772 U8(_) => 0b00,
773 U16(_) => 0b01,
774 U24(_) => 0b10,
775 U32(_) => 0b11,
776 }
777 }
778
779 pub(crate) fn expand(self, expected: u64) -> u64 {
780 use PacketNumber::*;
782 let truncated = match self {
783 U8(x) => u64::from(x),
784 U16(x) => u64::from(x),
785 U24(x) => u64::from(x),
786 U32(x) => u64::from(x),
787 };
788 let nbits = self.len() * 8;
789 let win = 1 << nbits;
790 let hwin = win / 2;
791 let mask = win - 1;
792 let candidate = (expected & !mask) | truncated;
801 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
802 candidate + win
803 } else if candidate > expected + hwin && candidate > win {
804 candidate - win
805 } else {
806 candidate
807 }
808 }
809}
810
/// A [`ConnectionIdParser`] that reads connection IDs of one fixed length.
pub struct FixedLengthConnectionIdParser {
    /// Length in bytes of every connection ID this parser reads.
    expected_len: usize,
}
815
816impl FixedLengthConnectionIdParser {
817 pub fn new(expected_len: usize) -> Self {
819 Self { expected_len }
820 }
821}
822
823impl ConnectionIdParser for FixedLengthConnectionIdParser {
824 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
825 (buffer.remaining() >= self.expected_len)
826 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
827 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
828 }
829}
830
/// Parses connection IDs that are not prefixed by their length.
///
/// [`ProtectedHeader::decode`] uses it to read the destination connection ID of short-header
/// packets.
pub trait ConnectionIdParser {
    /// Parses a connection ID from the front of `buf`, advancing past it.
    ///
    /// # Errors
    ///
    /// Returns a [`PacketDecodeError`] if `buf` does not start with a valid connection ID.
    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
}
836
/// Long packet type, as encoded in bits 4–5 (mask `0x30`) of the first header byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongHeaderType {
    /// Initial packet.
    Initial,
    /// Retry packet.
    Retry,
    /// A long packet type represented by [`Header::Long`] / [`ProtectedHeader::Long`].
    Standard(LongType),
}
844
845impl LongHeaderType {
846 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
847 use {LongHeaderType::*, LongType::*};
848 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
849 Ok(match (b & 0x30) >> 4 {
850 0x0 => Initial,
851 0x1 => Standard(ZeroRtt),
852 0x2 => Standard(Handshake),
853 0x3 => Retry,
854 _ => unreachable!(),
855 })
856 }
857}
858
859impl From<LongHeaderType> for u8 {
860 fn from(ty: LongHeaderType) -> Self {
861 use {LongHeaderType::*, LongType::*};
862 match ty {
863 Initial => LONG_HEADER_FORM | FIXED_BIT,
864 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
865 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
866 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
867 }
868 }
869}
870
/// Long packet types that share the generic long header layout ([`ProtectedHeader::Long`])
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet
    Handshake,
    /// 0-RTT packet
    ZeroRtt,
}
879
/// Errors that can occur while decoding a packet header
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum PacketDecodeError {
    /// The packet uses a QUIC version that is not in the supported set
    #[error("unsupported version {version:x}")]
    UnsupportedVersion {
        /// Source connection ID of the packet
        src_cid: ConnectionId,
        /// Destination connection ID of the packet
        dst_cid: ConnectionId,
        /// The unsupported version
        version: u32,
    },
    /// The packet header is malformed or truncated
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
}
897
898impl From<coding::UnexpectedEnd> for PacketDecodeError {
899 fn from(_: coding::UnexpectedEnd) -> Self {
900 Self::InvalidHeader("unexpected end of packet")
901 }
902}
903
/// Header form bit of the first byte: set for long headers, clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
/// Fixed bit of the first byte; decoding rejects packets without it unless greasing is allowed.
pub(crate) const FIXED_BIT: u8 = 0x40;
/// Spin bit of a short header's first byte.
pub(crate) const SPIN_BIT: u8 = 0x20;
/// Reserved bits of a short header's first byte (see `Packet::reserved_bits_valid`).
const SHORT_RESERVED_BITS: u8 = 0x18;
/// Reserved bits of a long header's first byte (see `Packet::reserved_bits_valid`).
const LONG_RESERVED_BITS: u8 = 0x0c;
/// Key phase bit of a short header's first byte; readable only after header protection is
/// removed.
const KEY_PHASE_BIT: u8 = 0x04;
910
/// Packet number spaces, in the order they are used during a connection
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Space of Initial packets
    Initial = 0,
    /// Space of Handshake packets
    Handshake = 1,
    /// Application data space, shared by 0-RTT and 1-RTT packets
    Data = 2,
}

impl SpaceId {
    /// Iterates over every packet number space, from `Initial` to `Data`.
    pub fn iter() -> impl Iterator<Item = Self> {
        static ALL: [SpaceId; 3] = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data];
        ALL.iter().copied()
    }

    /// The space that follows this one; `Data` is its own successor.
    pub fn next(&self) -> Self {
        match *self {
            Self::Initial => Self::Handshake,
            Self::Handshake | Self::Data => Self::Data,
        }
    }
}
937
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Asserts that `typed` encodes to exactly `encoded` and decodes back to an equal value.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    /// Fixed-width packet numbers are written big-endian and round-trip.
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` widens the encoding as the distance from the acked number grows.
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Truncating relative to `expected` and expanding again recovers the original number.
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Encodes and protects a client Initial header, compares it with known wire bytes, then
    /// decodes it and decrypts the payload with the server's Initial keys.
    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // A 16-byte all-zero payload plus room for the packet key's tag.
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        // Print the protected packet to ease updating the expected bytes below.
        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // With header protection removed, the first byte reads 0xc0 and the trailing one-byte
        // packet number is 0.
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}