1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 ConnectionId, PathId,
8 coding::{self, BufExt, BufMutExt},
9 connection::{EncryptionLevel, SpaceKind},
10 crypto,
11};
12
/// A packet whose unprotected header fields have been parsed, but whose header protection has
/// not yet been removed
///
/// Created by [`PartialDecode::new`]; `finish` turns it into a [`Packet`].
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    /// Header fields that can be read without header-protection keys
    plain_header: ProtectedHeader,
    /// The packet bytes; the cursor position sits just past the parsed header fields, i.e. at
    /// the (still protected) packet number for packet types that have one
    buf: io::Cursor<BytesMut>,
}
31
32#[allow(clippy::len_without_is_empty)]
33impl PartialDecode {
34 pub fn new(
36 bytes: BytesMut,
37 cid_parser: &(impl ConnectionIdParser + ?Sized),
38 supported_versions: &[u32],
39 grease_quic_bit: bool,
40 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
41 let mut buf = io::Cursor::new(bytes);
42 let plain_header =
43 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
44 let dgram_len = buf.get_ref().len();
45 let packet_len = plain_header
46 .payload_len()
47 .map(|len| (buf.position() + len) as usize)
48 .unwrap_or(dgram_len);
49 match dgram_len.cmp(&packet_len) {
50 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
51 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
52 "packet too short to contain payload length",
53 )),
54 Ordering::Greater => {
55 let rest = Some(buf.get_mut().split_off(packet_len));
56 Ok((Self { plain_header, buf }, rest))
57 }
58 }
59 }
60
61 pub(crate) fn data(&self) -> &[u8] {
63 self.buf.get_ref()
64 }
65
66 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
67 self.plain_header.as_initial()
68 }
69
70 pub(crate) fn has_long_header(&self) -> bool {
71 !matches!(self.plain_header, ProtectedHeader::Short { .. })
72 }
73
74 pub(crate) fn is_initial(&self) -> bool {
75 self.encryption_level() == Some(EncryptionLevel::Initial)
76 }
77
78 pub(crate) fn encryption_level(&self) -> Option<EncryptionLevel> {
79 use ProtectedHeader::*;
80 match self.plain_header {
81 Initial { .. } => Some(EncryptionLevel::Initial),
82 Long { ty, .. } => Some(match ty {
83 LongType::Handshake => EncryptionLevel::Handshake,
84 LongType::ZeroRtt => EncryptionLevel::ZeroRtt,
85 }),
86 Short { .. } => Some(EncryptionLevel::OneRtt),
87 _ => None,
88 }
89 }
90
91 pub(crate) fn is_0rtt(&self) -> bool {
92 match self.plain_header {
93 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
94 _ => false,
95 }
96 }
97
98 pub fn dst_cid(&self) -> ConnectionId {
100 self.plain_header.dst_cid()
101 }
102
103 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
106 self.buf.get_ref().len()
107 }
108
109 pub(crate) fn finish(
110 self,
111 header_crypto: Option<&dyn crypto::HeaderKey>,
112 ) -> Result<Packet, PacketDecodeError> {
113 use ProtectedHeader::*;
114 let Self {
115 plain_header,
116 mut buf,
117 } = self;
118
119 if let Initial(ProtectedInitialHeader {
120 dst_cid,
121 src_cid,
122 token_pos,
123 version,
124 ..
125 }) = plain_header
126 {
127 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
128 let header_len = buf.position() as usize;
129 let mut bytes = buf.into_inner();
130
131 let header_data = bytes.split_to(header_len).freeze();
132 let token = header_data.slice(token_pos.start..token_pos.end);
133 return Ok(Packet {
134 header: Header::Initial(InitialHeader {
135 dst_cid,
136 src_cid,
137 token,
138 number,
139 version,
140 }),
141 header_data,
142 payload: bytes,
143 });
144 }
145
146 let header = match plain_header {
147 Long {
148 ty,
149 dst_cid,
150 src_cid,
151 version,
152 ..
153 } => Header::Long {
154 ty,
155 dst_cid,
156 src_cid,
157 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
158 version,
159 },
160 Retry {
161 dst_cid,
162 src_cid,
163 version,
164 } => Header::Retry {
165 dst_cid,
166 src_cid,
167 version,
168 },
169 Short { spin, dst_cid, .. } => {
170 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
171 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
172 Header::Short {
173 spin,
174 key_phase,
175 dst_cid,
176 number,
177 }
178 }
179 VersionNegotiate {
180 random,
181 dst_cid,
182 src_cid,
183 } => Header::VersionNegotiate {
184 random,
185 dst_cid,
186 src_cid,
187 },
188 Initial { .. } => unreachable!(),
189 };
190
191 let header_len = buf.position() as usize;
192 let mut bytes = buf.into_inner();
193 Ok(Packet {
194 header,
195 header_data: bytes.split_to(header_len).freeze(),
196 payload: bytes,
197 })
198 }
199
200 fn decrypt_header(
201 buf: &mut io::Cursor<BytesMut>,
202 header_crypto: &dyn crypto::HeaderKey,
203 ) -> Result<PacketNumber, PacketDecodeError> {
204 let packet_length = buf.get_ref().len();
205 let pn_offset = buf.position() as usize;
206 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
207 return Err(PacketDecodeError::InvalidHeader(
208 "packet too short to extract header protection sample",
209 ));
210 }
211
212 header_crypto.decrypt(pn_offset, buf.get_mut());
213
214 let len = PacketNumber::decode_len(buf.get_ref()[0]);
215 PacketNumber::decode(len, buf)
216 }
217}
218
/// A write buffer that can report how many bytes it currently holds
///
/// Header encoding uses this to measure how many header bytes it wrote.
pub(crate) trait BufLen {
    /// Number of bytes currently in the buffer
    fn len(&self) -> usize;
}

impl BufLen for Vec<u8> {
    fn len(&self) -> usize {
        // Forward explicitly to the inherent `Vec::len`.
        Vec::len(self)
    }
}
235
/// A packet whose header has been fully decoded and whose header protection has been removed
pub(crate) struct Packet {
    /// The decoded header
    pub(crate) header: Header,
    /// The raw header bytes, with header protection already removed
    pub(crate) header_data: Bytes,
    /// The bytes following the header (for protected packet types these are presumably still
    /// under packet protection — the unit test decrypts them separately)
    pub(crate) payload: BytesMut,
}
247
248impl Packet {
249 pub(crate) fn reserved_bits_valid(&self) -> bool {
250 let mask = match self.header {
251 Header::Short { .. } => SHORT_RESERVED_BITS,
252 _ => LONG_RESERVED_BITS,
253 };
254 self.header_data[0] & mask == 0
255 }
256}
257
/// An Initial packet with its header already decoded
///
/// Converts into a generic [`Packet`] via `From`.
pub(crate) struct InitialPacket {
    /// The decoded Initial header
    pub(crate) header: InitialHeader,
    /// The raw header bytes
    pub(crate) header_data: Bytes,
    /// The bytes following the header
    pub(crate) payload: BytesMut,
}
263
264impl From<InitialPacket> for Packet {
265 fn from(x: InitialPacket) -> Self {
266 Self {
267 header: Header::Initial(x.header),
268 header_data: x.header_data,
269 payload: x.payload,
270 }
271 }
272}
273
/// A fully decoded packet header, or one about to be encoded
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// Initial packet header
    Initial(InitialHeader),
    /// Long header for the packet types described by [`LongType`] (Handshake, 0-RTT)
    Long {
        ty: LongType,
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        number: PacketNumber,
        version: u32,
    },
    /// Retry packet header; carries no packet number
    Retry {
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        version: u32,
    },
    /// Short (1-RTT) header
    Short {
        /// Latency spin bit
        spin: bool,
        /// Key phase bit
        key_phase: bool,
        dst_cid: ConnectionId,
        number: PacketNumber,
    },
    /// Version Negotiation packet header; carries no packet number
    VersionNegotiate {
        /// First-byte bits other than the long-header form bit
        random: u8,
        src_cid: ConnectionId,
        dst_cid: ConnectionId,
    },
}
302
303impl Header {
304 pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
305 use Header::*;
306 let start = w.len();
307 match *self {
308 Initial(InitialHeader {
309 ref dst_cid,
310 ref src_cid,
311 ref token,
312 number,
313 version,
314 }) => {
315 w.write(u8::from(LongHeaderType::Initial) | number.tag());
316 w.write(version);
317 dst_cid.encode_long(w);
318 src_cid.encode_long(w);
319 w.write_var(token.len() as u64);
320 w.put_slice(token);
321 w.write::<u16>(0); number.encode(w);
323 PartialEncode {
324 start,
325 header_len: w.len() - start,
326 pn: Some((number.len(), true)),
327 }
328 }
329 Long {
330 ty,
331 ref dst_cid,
332 ref src_cid,
333 number,
334 version,
335 } => {
336 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
337 w.write(version);
338 dst_cid.encode_long(w);
339 src_cid.encode_long(w);
340 w.write::<u16>(0); number.encode(w);
342 PartialEncode {
343 start,
344 header_len: w.len() - start,
345 pn: Some((number.len(), true)),
346 }
347 }
348 Retry {
349 ref dst_cid,
350 ref src_cid,
351 version,
352 } => {
353 w.write(u8::from(LongHeaderType::Retry));
354 w.write(version);
355 dst_cid.encode_long(w);
356 src_cid.encode_long(w);
357 PartialEncode {
358 start,
359 header_len: w.len() - start,
360 pn: None,
361 }
362 }
363 Short {
364 spin,
365 key_phase,
366 ref dst_cid,
367 number,
368 } => {
369 w.write(
370 FIXED_BIT
371 | if key_phase { KEY_PHASE_BIT } else { 0 }
372 | if spin { SPIN_BIT } else { 0 }
373 | number.tag(),
374 );
375 w.put_slice(dst_cid);
376 number.encode(w);
377 PartialEncode {
378 start,
379 header_len: w.len() - start,
380 pn: Some((number.len(), false)),
381 }
382 }
383 VersionNegotiate {
384 ref random,
385 ref dst_cid,
386 ref src_cid,
387 } => {
388 w.write(0x80u8 | random);
389 w.write::<u32>(0);
390 dst_cid.encode_long(w);
391 src_cid.encode_long(w);
392 PartialEncode {
393 start,
394 header_len: w.len() - start,
395 pn: None,
396 }
397 }
398 }
399 }
400
401 pub(crate) fn is_protected(&self) -> bool {
403 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
404 }
405
406 pub(crate) fn number(&self) -> Option<PacketNumber> {
407 use Header::*;
408 Some(match *self {
409 Initial(InitialHeader { number, .. }) => number,
410 Long { number, .. } => number,
411 Short { number, .. } => number,
412 _ => {
413 return None;
414 }
415 })
416 }
417
418 pub(crate) fn space(&self) -> SpaceKind {
419 use Header::*;
420 match *self {
421 Short { .. } => SpaceKind::Data,
422 Long {
423 ty: LongType::ZeroRtt,
424 ..
425 } => SpaceKind::Data,
426 Long {
427 ty: LongType::Handshake,
428 ..
429 } => SpaceKind::Handshake,
430 _ => SpaceKind::Initial,
431 }
432 }
433
434 pub(crate) fn key_phase(&self) -> bool {
435 match *self {
436 Self::Short { key_phase, .. } => key_phase,
437 _ => false,
438 }
439 }
440
441 pub(crate) fn is_short(&self) -> bool {
442 matches!(*self, Self::Short { .. })
443 }
444
445 pub(crate) fn is_1rtt(&self) -> bool {
446 self.is_short()
447 }
448
449 pub(crate) fn is_0rtt(&self) -> bool {
450 matches!(
451 *self,
452 Self::Long {
453 ty: LongType::ZeroRtt,
454 ..
455 }
456 )
457 }
458
459 pub(crate) fn dst_cid(&self) -> ConnectionId {
460 use Header::*;
461 match *self {
462 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
463 Long { dst_cid, .. } => dst_cid,
464 Retry { dst_cid, .. } => dst_cid,
465 Short { dst_cid, .. } => dst_cid,
466 VersionNegotiate { dst_cid, .. } => dst_cid,
467 }
468 }
469
470 pub(crate) fn can_coalesce(&self) -> bool {
474 use Header::*;
475 match *self {
476 Initial(_) => true,
477 Long { .. } => true,
478 Retry { .. } => false,
479 Short { .. } => false,
480 VersionNegotiate { .. } => false,
481 }
482 }
483
484 pub(crate) fn has_frames(&self) -> bool {
486 use Header::*;
487 match *self {
488 Initial(_) => true,
489 Long { .. } => true,
490 Retry { .. } => false,
491 Short { .. } => true,
492 VersionNegotiate { .. } => false,
493 }
494 }
495
496 #[cfg(feature = "qlog")]
497 pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
498 match self {
499 Self::Initial(initial_header) => Some(initial_header.src_cid),
500 Self::Long { src_cid, .. } => Some(*src_cid),
501 Self::Retry { src_cid, .. } => Some(*src_cid),
502 Self::Short { .. } => None,
503 Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
504 }
505 }
506}
507
/// Bookkeeping returned by [`Header::encode`] for completing a packet later
pub(crate) struct PartialEncode {
    /// Buffer offset at which the header starts
    pub(crate) start: usize,
    /// Length of the encoded header, including the packet number
    pub(crate) header_len: usize,
    // Packet number length, and whether a payload length field must be written in front of it;
    // `None` for headers without a packet number
    pn: Option<(usize, bool)>,
}
514
515impl PartialEncode {
516 pub(crate) fn finish(
517 self,
518 buf: &mut [u8],
519 header_crypto: &dyn crypto::HeaderKey,
520 crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
521 ) {
522 let Self { header_len, pn, .. } = self;
523 let Some((pn_len, write_len)) = pn else {
524 return;
525 };
526
527 let pn_pos = header_len - pn_len;
528 if write_len {
529 let len = buf.len() - header_len + pn_len;
530 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
532 slice.put_u16(len as u16 | (0b01 << 14));
533 }
534
535 if let Some((packet_number, path_id, crypto)) = crypto {
536 crypto.encrypt(path_id, packet_number, buf, header_len);
537 }
538
539 debug_assert!(
540 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
541 "packet must be padded to at least {} bytes for header protection sampling",
542 pn_pos + 4 + header_crypto.sample_size()
543 );
544 header_crypto.encrypt(pn_pos, buf);
545 }
546
547 #[cfg(test)]
551 pub(crate) fn no_header() -> Self {
552 Self {
553 start: 0,
554 header_len: 0,
555 pn: None,
556 }
557 }
558}
559
/// Header fields that can be read before header protection is removed
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header
    Initial(ProtectedInitialHeader),
    /// A long packet header, as used by Handshake and 0-RTT packets
    Long {
        /// Type of the long packet
        ty: LongType,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// Value of the length field: bytes from the packet number to the end of the packet
        len: u64,
        /// QUIC version
        version: u32,
    },
    /// A Retry packet header
    Retry {
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// QUIC version
        version: u32,
    },
    /// A short packet header, as used during the data phase
    Short {
        /// Spin bit
        spin: bool,
        /// Destination Connection ID
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation packet header
    VersionNegotiate {
        /// First-byte bits other than the long-header form bit
        random: u8,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
    },
}
604
605impl ProtectedHeader {
606 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
607 match self {
608 Self::Initial(x) => Some(x),
609 _ => None,
610 }
611 }
612
613 pub fn dst_cid(&self) -> ConnectionId {
615 use ProtectedHeader::*;
616 match self {
617 Initial(header) => header.dst_cid,
618 &Long { dst_cid, .. } => dst_cid,
619 &Retry { dst_cid, .. } => dst_cid,
620 &Short { dst_cid, .. } => dst_cid,
621 &VersionNegotiate { dst_cid, .. } => dst_cid,
622 }
623 }
624
625 fn payload_len(&self) -> Option<u64> {
626 use ProtectedHeader::*;
627 match self {
628 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
629 _ => None,
630 }
631 }
632
633 pub fn decode(
635 buf: &mut io::Cursor<BytesMut>,
636 cid_parser: &(impl ConnectionIdParser + ?Sized),
637 supported_versions: &[u32],
638 grease_quic_bit: bool,
639 ) -> Result<Self, PacketDecodeError> {
640 let first = buf.get::<u8>()?;
641 if !grease_quic_bit && first & FIXED_BIT == 0 {
642 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
643 }
644 if first & LONG_HEADER_FORM == 0 {
645 let spin = first & SPIN_BIT != 0;
646
647 Ok(Self::Short {
648 spin,
649 dst_cid: cid_parser.parse(buf)?,
650 })
651 } else {
652 let version = buf.get::<u32>()?;
653
654 let dst_cid = ConnectionId::decode_long(buf)
655 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
656 let src_cid = ConnectionId::decode_long(buf)
657 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
658
659 if version == 0 {
661 let random = first & !LONG_HEADER_FORM;
662 return Ok(Self::VersionNegotiate {
663 random,
664 dst_cid,
665 src_cid,
666 });
667 }
668
669 if !supported_versions.contains(&version) {
670 return Err(PacketDecodeError::UnsupportedVersion {
671 src_cid,
672 dst_cid,
673 version,
674 });
675 }
676
677 match LongHeaderType::from_byte(first)? {
678 LongHeaderType::Initial => {
679 let token_len = buf.get_var()? as usize;
680 let token_start = buf.position() as usize;
681 if token_len > buf.remaining() {
682 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
683 }
684 buf.advance(token_len);
685
686 let len = buf.get_var()?;
687 Ok(Self::Initial(ProtectedInitialHeader {
688 dst_cid,
689 src_cid,
690 token_pos: token_start..token_start + token_len,
691 len,
692 version,
693 }))
694 }
695 LongHeaderType::Retry => Ok(Self::Retry {
696 dst_cid,
697 src_cid,
698 version,
699 }),
700 LongHeaderType::Standard(ty) => Ok(Self::Long {
701 ty,
702 dst_cid,
703 src_cid,
704 len: buf.get_var()?,
705 version,
706 }),
707 }
708 }
709 }
710}
711
/// Header fields of an Initial packet that are readable before header protection is removed
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination connection ID
    pub dst_cid: ConnectionId,
    /// Source connection ID
    pub src_cid: ConnectionId,
    /// Byte range of the token, as offsets from the start of the packet
    pub token_pos: Range<usize>,
    /// Value of the length field: bytes from the packet number to the end of the packet
    pub len: u64,
    /// QUIC version
    pub version: u32,
}
726
/// Fully decoded header of an Initial packet
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    /// Destination connection ID
    pub(crate) dst_cid: ConnectionId,
    /// Source connection ID
    pub(crate) src_cid: ConnectionId,
    /// Token bytes (possibly empty)
    pub(crate) token: Bytes,
    /// Truncated packet number as it appeared on the wire
    pub(crate) number: PacketNumber,
    /// QUIC version
    pub(crate) version: u32,
}
735
/// A packet number truncated to one of the 1-4 byte wire encodings
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding
    U8(u8),
    /// Two-byte encoding
    U16(u16),
    /// Three-byte encoding, held in a `u32`
    U24(u32),
    /// Four-byte encoding
    U32(u32),
}
744
745impl PacketNumber {
746 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
747 let range = (n - largest_acked) * 2;
748 if range < 1 << 8 {
749 Self::U8(n as u8)
750 } else if range < 1 << 16 {
751 Self::U16(n as u16)
752 } else if range < 1 << 24 {
753 Self::U24(n as u32)
754 } else if range < 1 << 32 {
755 Self::U32(n as u32)
756 } else {
757 panic!("packet number too large to encode")
758 }
759 }
760
761 pub(crate) fn len(self) -> usize {
762 use PacketNumber::*;
763 match self {
764 U8(_) => 1,
765 U16(_) => 2,
766 U24(_) => 3,
767 U32(_) => 4,
768 }
769 }
770
771 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
772 use PacketNumber::*;
773 match self {
774 U8(x) => w.write(x),
775 U16(x) => w.write(x),
776 U24(x) => w.put_uint(u64::from(x), 3),
777 U32(x) => w.write(x),
778 }
779 }
780
781 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
782 use PacketNumber::*;
783 let pn = match len {
784 1 => U8(r.get()?),
785 2 => U16(r.get()?),
786 3 => U24(r.get_uint(3) as u32),
787 4 => U32(r.get()?),
788 _ => unreachable!(),
789 };
790 Ok(pn)
791 }
792
793 pub(crate) fn decode_len(tag: u8) -> usize {
794 1 + (tag & 0x03) as usize
795 }
796
797 fn tag(self) -> u8 {
798 use PacketNumber::*;
799 match self {
800 U8(_) => 0b00,
801 U16(_) => 0b01,
802 U24(_) => 0b10,
803 U32(_) => 0b11,
804 }
805 }
806
807 pub(crate) fn expand(self, expected: u64) -> u64 {
808 use PacketNumber::*;
810 let truncated = match self {
811 U8(x) => u64::from(x),
812 U16(x) => u64::from(x),
813 U24(x) => u64::from(x),
814 U32(x) => u64::from(x),
815 };
816 let nbits = self.len() * 8;
817 let win = 1 << nbits;
818 let hwin = win / 2;
819 let mask = win - 1;
820 let candidate = (expected & !mask) | truncated;
829 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
830 candidate + win
831 } else if candidate > expected + hwin && candidate > win {
832 candidate - win
833 } else {
834 candidate
835 }
836 }
837}
838
/// Parses short-header connection IDs that all have the same, known length
pub struct FixedLengthConnectionIdParser {
    expected_len: usize,
}

impl FixedLengthConnectionIdParser {
    /// Build a parser for connection IDs of exactly `expected_len` bytes
    pub fn new(expected_len: usize) -> Self {
        FixedLengthConnectionIdParser { expected_len }
    }
}
850
851impl ConnectionIdParser for FixedLengthConnectionIdParser {
852 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
853 (buffer.remaining() >= self.expected_len)
854 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
855 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
856 }
857}
858
/// Parses the destination connection ID of a short-header packet
///
/// Short headers do not encode the connection ID length, so the parser must know how to find
/// the end of the ID.
pub trait ConnectionIdParser {
    /// Parse a connection ID from the start of `buf`, advancing past it
    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
}
864
/// Packet type encoded in bits 4-5 of a long header's first byte
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongHeaderType {
    /// Initial packet
    Initial,
    /// Retry packet
    Retry,
    /// Handshake or 0-RTT packet
    Standard(LongType),
}
872
873impl LongHeaderType {
874 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
875 use {LongHeaderType::*, LongType::*};
876 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
877 Ok(match (b & 0x30) >> 4 {
878 0x0 => Initial,
879 0x1 => Standard(ZeroRtt),
880 0x2 => Standard(Handshake),
881 0x3 => Retry,
882 _ => unreachable!(),
883 })
884 }
885}
886
887impl From<LongHeaderType> for u8 {
888 fn from(ty: LongHeaderType) -> Self {
889 use {LongHeaderType::*, LongType::*};
890 match ty {
891 Initial => LONG_HEADER_FORM | FIXED_BIT,
892 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
893 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
894 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
895 }
896 }
897}
898
/// Long-header packet types, other than Initial and Retry, that carry a packet number
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet
    Handshake,
    /// 0-RTT packet
    ZeroRtt,
}
907
/// Errors produced while decoding a packet header
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum PacketDecodeError {
    /// The packet uses a nonzero QUIC version that is not in the supported list
    #[error("unsupported version {version:x}")]
    UnsupportedVersion {
        /// Source connection ID of the packet
        src_cid: ConnectionId,
        /// Destination connection ID of the packet
        dst_cid: ConnectionId,
        /// The version the packet announced
        version: u32,
    },
    /// The header is malformed; the message says how
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
}
925
926impl From<coding::UnexpectedEnd> for PacketDecodeError {
927 fn from(_: coding::UnexpectedEnd) -> Self {
928 Self::InvalidHeader("unexpected end of packet")
929 }
930}
931
/// Header form bit of the first byte: set for long headers, clear for short headers
pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
/// Fixed bit of the first byte; packets with it clear are rejected unless QUIC-bit greasing is
/// enabled for the decode
pub(crate) const FIXED_BIT: u8 = 0x40;
/// Latency spin bit of a short header
pub(crate) const SPIN_BIT: u8 = 0x20;
/// Reserved bits of a short header's first byte; must be zero once protection is removed
const SHORT_RESERVED_BITS: u8 = 0x18;
/// Reserved bits of a long header's first byte; must be zero once protection is removed
const LONG_RESERVED_BITS: u8 = 0x0c;
/// Key phase bit of a short header; only readable after header protection is removed
const KEY_PHASE_BIT: u8 = 0x04;
938
/// Identifies one of the three packet number spaces
///
/// Variants are ordered by when they are used during the handshake.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum SpaceId {
    /// Space of Initial packets
    Initial = 0,
    /// Space of Handshake packets
    Handshake = 1,
    /// Application data space (maps to `SpaceKind::Data`, which `Header::space` also assigns to
    /// 0-RTT packets)
    Data = 2,
}
948
949impl SpaceId {
950 pub(crate) fn iter() -> impl Iterator<Item = Self> {
951 [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
952 }
953
954 pub(crate) fn next(&self) -> Option<Self> {
958 match self {
959 Self::Initial => Some(Self::Handshake),
960 Self::Handshake => Some(Self::Data),
961 Self::Data => None,
962 }
963 }
964
965 pub(crate) fn encryption_level(self) -> EncryptionLevel {
967 match self {
968 Self::Initial => EncryptionLevel::Initial,
969 Self::Handshake => EncryptionLevel::Handshake,
970 Self::Data => EncryptionLevel::OneRtt,
971 }
972 }
973
974 pub(crate) fn kind(self) -> SpaceKind {
976 match self {
977 Self::Initial => SpaceKind::Initial,
978 Self::Handshake => SpaceKind::Handshake,
979 Self::Data => SpaceKind::Data,
980 }
981 }
982}
983
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encode `typed`, compare the wire bytes against `encoded`, and decode them back again
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut wire = Vec::new();
        typed.encode(&mut wire);
        assert_eq!(&wire[..], encoded);
        let mut reader = io::Cursor::new(&wire);
        let decoded = PacketNumber::decode(typed.len(), &mut reader).unwrap();
        assert_eq!(typed, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        let cases: [(PacketNumber, &[u8]); 5] = [
            (PacketNumber::U8(0x7f), &hex!("7f")[..]),
            (PacketNumber::U16(0x80), &hex!("0080")[..]),
            (PacketNumber::U16(0x3fff), &hex!("3fff")[..]),
            (PacketNumber::U32(0x0000_4000), &hex!("0000 4000")[..]),
            (PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff")[..]),
        ];
        for (typed, encoded) in cases {
            check_pn(typed, encoded);
        }
    }

    #[test]
    fn pn_encode() {
        // Each case picks the shortest encoding covering twice the distance to largest_acked.
        let cases: [(u64, &[u8]); 3] = [
            (0x10, &hex!("10")[..]),
            (0x100, &hex!("0100")[..]),
            (0x10000, &hex!("010000")[..]),
        ];
        for (number, encoded) in cases {
            check_pn(PacketNumber::new(number, 0), encoded);
        }
    }

    #[test]
    fn pn_expand_roundtrip() {
        let pairs = (0..1024u64)
            .flat_map(|expected| (expected..1024).map(move |actual| (expected, actual)));
        for (expected, actual) in pairs {
            assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
        }
    }

    #[cfg(all(feature = "rustls", any(feature = "aws-lc-rs", feature = "ring")))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{configured_provider, initial_keys, initial_suite_from_provider};
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = configured_provider();

        let suite = initial_suite_from_provider(&provider).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        let dump: String = buf.iter().map(|byte| format!("{byte:02x}")).collect();
        println!("{dump}");
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let (decode, _rest) = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap();
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        assert!(
            matches!(
                packet.header,
                Header::Initial(InitialHeader {
                    number: PacketNumber::U8(0),
                    ..
                })
            ),
            "unexpected header {:?}",
            packet.header
        );
    }
}