1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 ConnectionId, PathId,
8 coding::{self, BufExt, BufMutExt},
9 connection::{EncryptionLevel, SpaceKind},
10 crypto,
11};
12
/// A packet whose header has been decoded only as far as possible without removing header
/// protection.
///
/// Call `PartialDecode::finish` with the matching header key to obtain a fully decoded
/// [`Packet`].
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    /// The header fields that are readable before header protection is removed.
    plain_header: ProtectedHeader,
    /// The bytes of this packet only; the cursor position marks the end of `plain_header`.
    buf: io::Cursor<BytesMut>,
}
31
32#[allow(clippy::len_without_is_empty)]
33impl PartialDecode {
34 pub fn new(
36 bytes: BytesMut,
37 cid_parser: &(impl ConnectionIdParser + ?Sized),
38 supported_versions: &[u32],
39 grease_quic_bit: bool,
40 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
41 let mut buf = io::Cursor::new(bytes);
42 let plain_header =
43 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
44 let dgram_len = buf.get_ref().len();
45 let packet_len = plain_header
46 .payload_len()
47 .map(|len| (buf.position() + len) as usize)
48 .unwrap_or(dgram_len);
49 match dgram_len.cmp(&packet_len) {
50 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
51 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
52 "packet too short to contain payload length",
53 )),
54 Ordering::Greater => {
55 let rest = Some(buf.get_mut().split_off(packet_len));
56 Ok((Self { plain_header, buf }, rest))
57 }
58 }
59 }
60
61 pub(crate) fn data(&self) -> &[u8] {
63 self.buf.get_ref()
64 }
65
66 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
67 self.plain_header.as_initial()
68 }
69
70 pub(crate) fn has_long_header(&self) -> bool {
71 !matches!(self.plain_header, ProtectedHeader::Short { .. })
72 }
73
74 pub(crate) fn is_initial(&self) -> bool {
75 self.encryption_level() == Some(EncryptionLevel::Initial)
76 }
77
78 pub(crate) fn encryption_level(&self) -> Option<EncryptionLevel> {
79 use ProtectedHeader::*;
80 match self.plain_header {
81 Initial { .. } => Some(EncryptionLevel::Initial),
82 Long { ty, .. } => Some(match ty {
83 LongType::Handshake => EncryptionLevel::Handshake,
84 LongType::ZeroRtt => EncryptionLevel::ZeroRtt,
85 }),
86 Short { .. } => Some(EncryptionLevel::OneRtt),
87 _ => None,
88 }
89 }
90
91 pub(crate) fn is_0rtt(&self) -> bool {
92 match self.plain_header {
93 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
94 _ => false,
95 }
96 }
97
98 pub fn dst_cid(&self) -> ConnectionId {
100 self.plain_header.dst_cid()
101 }
102
103 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
106 self.buf.get_ref().len()
107 }
108
109 pub(crate) fn finish(
110 self,
111 header_crypto: Option<&dyn crypto::HeaderKey>,
112 ) -> Result<Packet, PacketDecodeError> {
113 use ProtectedHeader::*;
114 let Self {
115 plain_header,
116 mut buf,
117 } = self;
118
119 if let Initial(ProtectedInitialHeader {
120 dst_cid,
121 src_cid,
122 token_pos,
123 version,
124 ..
125 }) = plain_header
126 {
127 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
128 let header_len = buf.position() as usize;
129 let mut bytes = buf.into_inner();
130
131 let header_data = bytes.split_to(header_len).freeze();
132 let token = header_data.slice(token_pos.start..token_pos.end);
133 return Ok(Packet {
134 header: Header::Initial(InitialHeader {
135 dst_cid,
136 src_cid,
137 token,
138 number,
139 version,
140 }),
141 header_data,
142 payload: bytes,
143 });
144 }
145
146 let header = match plain_header {
147 Long {
148 ty,
149 dst_cid,
150 src_cid,
151 version,
152 ..
153 } => Header::Long {
154 ty,
155 dst_cid,
156 src_cid,
157 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
158 version,
159 },
160 Retry {
161 dst_cid,
162 src_cid,
163 version,
164 } => Header::Retry {
165 dst_cid,
166 src_cid,
167 version,
168 },
169 Short { spin, dst_cid, .. } => {
170 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
171 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
172 Header::Short {
173 spin,
174 key_phase,
175 dst_cid,
176 number,
177 }
178 }
179 VersionNegotiate {
180 random,
181 dst_cid,
182 src_cid,
183 } => Header::VersionNegotiate {
184 random,
185 dst_cid,
186 src_cid,
187 },
188 Initial { .. } => unreachable!(),
189 };
190
191 let header_len = buf.position() as usize;
192 let mut bytes = buf.into_inner();
193 Ok(Packet {
194 header,
195 header_data: bytes.split_to(header_len).freeze(),
196 payload: bytes,
197 })
198 }
199
200 fn decrypt_header(
201 buf: &mut io::Cursor<BytesMut>,
202 header_crypto: &dyn crypto::HeaderKey,
203 ) -> Result<PacketNumber, PacketDecodeError> {
204 let packet_length = buf.get_ref().len();
205 let pn_offset = buf.position() as usize;
206 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
207 return Err(PacketDecodeError::InvalidHeader(
208 "packet too short to extract header protection sample",
209 ));
210 }
211
212 header_crypto.decrypt(pn_offset, buf.get_mut());
213
214 let len = PacketNumber::decode_len(buf.get_ref()[0]);
215 PacketNumber::decode(len, buf)
216 }
217}
218
/// A buffer whose current length in bytes can be queried.
///
/// [`Header::encode`] uses it to record where a header starts and how long it is.
pub(crate) trait BufLen {
    /// Returns the number of bytes currently in the buffer.
    fn len(&self) -> usize;
}
229
230impl BufLen for Vec<u8> {
231 fn len(&self) -> usize {
232 self.len()
233 }
234}
235
/// A packet with a fully decoded header, split into header bytes and payload bytes.
pub(crate) struct Packet {
    /// The decoded header.
    pub(crate) header: Header,
    /// The raw header bytes. When produced by `PartialDecode::finish`, header protection has
    /// already been removed from them.
    pub(crate) header_data: Bytes,
    /// The bytes that follow the header. When produced by `PartialDecode::finish`, these are
    /// still packet-protected.
    pub(crate) payload: BytesMut,
}
241
242impl Packet {
243 pub(crate) fn reserved_bits_valid(&self) -> bool {
244 let mask = match self.header {
245 Header::Short { .. } => SHORT_RESERVED_BITS,
246 _ => LONG_RESERVED_BITS,
247 };
248 self.header_data[0] & mask == 0
249 }
250}
251
/// Like [`Packet`], but statically known to carry an Initial header.
pub(crate) struct InitialPacket {
    /// The decoded Initial header.
    pub(crate) header: InitialHeader,
    /// The raw header bytes.
    pub(crate) header_data: Bytes,
    /// The bytes that follow the header.
    pub(crate) payload: BytesMut,
}
257
258impl From<InitialPacket> for Packet {
259 fn from(x: InitialPacket) -> Self {
260 Self {
261 header: Header::Initial(x.header),
262 header_data: x.header_data,
263 payload: x.payload,
264 }
265 }
266}
267
/// A fully decoded packet header, or a header about to be written with [`Header::encode`].
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// An Initial packet header.
    Initial(InitialHeader),
    /// A Handshake or 0-RTT long header.
    Long {
        /// Which long packet type this is.
        ty: LongType,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// The (truncated) packet number.
        number: PacketNumber,
        /// QUIC version.
        version: u32,
    },
    /// A Retry packet header; it carries no packet number.
    Retry {
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// QUIC version.
        version: u32,
    },
    /// A short (1-RTT) header.
    Short {
        /// Value of the spin bit.
        spin: bool,
        /// Value of the key phase bit.
        key_phase: bool,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// The (truncated) packet number.
        number: PacketNumber,
    },
    /// A Version Negotiation packet header; it carries no packet number.
    VersionNegotiate {
        /// The seven bits of the first byte other than the header form bit.
        random: u8,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// Destination connection ID.
        dst_cid: ConnectionId,
    },
}
296
297impl Header {
298 pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
299 use Header::*;
300 let start = w.len();
301 match *self {
302 Initial(InitialHeader {
303 ref dst_cid,
304 ref src_cid,
305 ref token,
306 number,
307 version,
308 }) => {
309 w.write(u8::from(LongHeaderType::Initial) | number.tag());
310 w.write(version);
311 dst_cid.encode_long(w);
312 src_cid.encode_long(w);
313 w.write_var(token.len() as u64);
314 w.put_slice(token);
315 w.write::<u16>(0); number.encode(w);
317 PartialEncode {
318 start,
319 header_len: w.len() - start,
320 pn: Some((number.len(), true)),
321 }
322 }
323 Long {
324 ty,
325 ref dst_cid,
326 ref src_cid,
327 number,
328 version,
329 } => {
330 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
331 w.write(version);
332 dst_cid.encode_long(w);
333 src_cid.encode_long(w);
334 w.write::<u16>(0); number.encode(w);
336 PartialEncode {
337 start,
338 header_len: w.len() - start,
339 pn: Some((number.len(), true)),
340 }
341 }
342 Retry {
343 ref dst_cid,
344 ref src_cid,
345 version,
346 } => {
347 w.write(u8::from(LongHeaderType::Retry));
348 w.write(version);
349 dst_cid.encode_long(w);
350 src_cid.encode_long(w);
351 PartialEncode {
352 start,
353 header_len: w.len() - start,
354 pn: None,
355 }
356 }
357 Short {
358 spin,
359 key_phase,
360 ref dst_cid,
361 number,
362 } => {
363 w.write(
364 FIXED_BIT
365 | if key_phase { KEY_PHASE_BIT } else { 0 }
366 | if spin { SPIN_BIT } else { 0 }
367 | number.tag(),
368 );
369 w.put_slice(dst_cid);
370 number.encode(w);
371 PartialEncode {
372 start,
373 header_len: w.len() - start,
374 pn: Some((number.len(), false)),
375 }
376 }
377 VersionNegotiate {
378 ref random,
379 ref dst_cid,
380 ref src_cid,
381 } => {
382 w.write(0x80u8 | random);
383 w.write::<u32>(0);
384 dst_cid.encode_long(w);
385 src_cid.encode_long(w);
386 PartialEncode {
387 start,
388 header_len: w.len() - start,
389 pn: None,
390 }
391 }
392 }
393 }
394
395 pub(crate) fn is_protected(&self) -> bool {
397 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
398 }
399
400 pub(crate) fn number(&self) -> Option<PacketNumber> {
401 use Header::*;
402 Some(match *self {
403 Initial(InitialHeader { number, .. }) => number,
404 Long { number, .. } => number,
405 Short { number, .. } => number,
406 _ => {
407 return None;
408 }
409 })
410 }
411
412 pub(crate) fn space(&self) -> SpaceKind {
413 use Header::*;
414 match *self {
415 Short { .. } => SpaceKind::Data,
416 Long {
417 ty: LongType::ZeroRtt,
418 ..
419 } => SpaceKind::Data,
420 Long {
421 ty: LongType::Handshake,
422 ..
423 } => SpaceKind::Handshake,
424 _ => SpaceKind::Initial,
425 }
426 }
427
428 pub(crate) fn key_phase(&self) -> bool {
429 match *self {
430 Self::Short { key_phase, .. } => key_phase,
431 _ => false,
432 }
433 }
434
435 pub(crate) fn is_short(&self) -> bool {
436 matches!(*self, Self::Short { .. })
437 }
438
439 pub(crate) fn is_1rtt(&self) -> bool {
440 self.is_short()
441 }
442
443 pub(crate) fn is_0rtt(&self) -> bool {
444 matches!(
445 *self,
446 Self::Long {
447 ty: LongType::ZeroRtt,
448 ..
449 }
450 )
451 }
452
453 pub(crate) fn dst_cid(&self) -> ConnectionId {
454 use Header::*;
455 match *self {
456 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
457 Long { dst_cid, .. } => dst_cid,
458 Retry { dst_cid, .. } => dst_cid,
459 Short { dst_cid, .. } => dst_cid,
460 VersionNegotiate { dst_cid, .. } => dst_cid,
461 }
462 }
463
464 pub(crate) fn can_coalesce(&self) -> bool {
468 use Header::*;
469 match *self {
470 Initial(_) => true,
471 Long { .. } => true,
472 Retry { .. } => false,
473 Short { .. } => false,
474 VersionNegotiate { .. } => false,
475 }
476 }
477
478 pub(crate) fn has_frames(&self) -> bool {
480 use Header::*;
481 match *self {
482 Initial(_) => true,
483 Long { .. } => true,
484 Retry { .. } => false,
485 Short { .. } => true,
486 VersionNegotiate { .. } => false,
487 }
488 }
489
490 #[cfg(feature = "qlog")]
491 pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
492 match self {
493 Self::Initial(initial_header) => Some(initial_header.src_cid),
494 Self::Long { src_cid, .. } => Some(*src_cid),
495 Self::Retry { src_cid, .. } => Some(*src_cid),
496 Self::Short { .. } => None,
497 Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
498 }
499 }
500}
501
/// State produced by [`Header::encode`] and used to complete a packet once its payload has been
/// written.
pub(crate) struct PartialEncode {
    /// Offset of the header's first byte within the output buffer.
    pub(crate) start: usize,
    /// Length in bytes of the encoded header, including the packet number.
    pub(crate) header_len: usize,
    /// Packet number length and whether a length placeholder must be patched. `None` for
    /// headers without a packet number (Retry, Version Negotiation).
    pn: Option<(usize, bool)>,
}
508
509impl PartialEncode {
510 pub(crate) fn finish(
511 self,
512 buf: &mut [u8],
513 header_crypto: &dyn crypto::HeaderKey,
514 crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
515 ) {
516 let Self { header_len, pn, .. } = self;
517 let (pn_len, write_len) = match pn {
518 Some((pn_len, write_len)) => (pn_len, write_len),
519 None => return,
520 };
521
522 let pn_pos = header_len - pn_len;
523 if write_len {
524 let len = buf.len() - header_len + pn_len;
525 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
527 slice.put_u16(len as u16 | (0b01 << 14));
528 }
529
530 if let Some((packet_number, path_id, crypto)) = crypto {
531 crypto.encrypt(path_id, packet_number, buf, header_len);
532 }
533
534 debug_assert!(
535 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
536 "packet must be padded to at least {} bytes for header protection sampling",
537 pn_pos + 4 + header_crypto.sample_size()
538 );
539 header_crypto.encrypt(pn_pos, buf);
540 }
541
542 #[cfg(test)]
546 pub(crate) fn no_header() -> Self {
547 Self {
548 start: 0,
549 header_len: 0,
550 pn: None,
551 }
552 }
553}
554
/// The fields of a packet header that can be read before header protection is removed.
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header.
    Initial(ProtectedInitialHeader),
    /// A Handshake or 0-RTT long header.
    Long {
        /// Which long packet type this is.
        ty: LongType,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// Declared number of bytes after the length field (packet number and payload).
        len: u64,
        /// QUIC version.
        version: u32,
    },
    /// A Retry packet header.
    Retry {
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// QUIC version.
        version: u32,
    },
    /// A short (1-RTT) header.
    Short {
        /// Value of the spin bit.
        spin: bool,
        /// Destination connection ID.
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation packet header.
    VersionNegotiate {
        /// The seven bits of the first byte other than the header form bit.
        random: u8,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
    },
}
599
600impl ProtectedHeader {
601 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
602 match self {
603 Self::Initial(x) => Some(x),
604 _ => None,
605 }
606 }
607
608 pub fn dst_cid(&self) -> ConnectionId {
610 use ProtectedHeader::*;
611 match self {
612 Initial(header) => header.dst_cid,
613 &Long { dst_cid, .. } => dst_cid,
614 &Retry { dst_cid, .. } => dst_cid,
615 &Short { dst_cid, .. } => dst_cid,
616 &VersionNegotiate { dst_cid, .. } => dst_cid,
617 }
618 }
619
620 fn payload_len(&self) -> Option<u64> {
621 use ProtectedHeader::*;
622 match self {
623 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
624 _ => None,
625 }
626 }
627
628 pub fn decode(
630 buf: &mut io::Cursor<BytesMut>,
631 cid_parser: &(impl ConnectionIdParser + ?Sized),
632 supported_versions: &[u32],
633 grease_quic_bit: bool,
634 ) -> Result<Self, PacketDecodeError> {
635 let first = buf.get::<u8>()?;
636 if !grease_quic_bit && first & FIXED_BIT == 0 {
637 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
638 }
639 if first & LONG_HEADER_FORM == 0 {
640 let spin = first & SPIN_BIT != 0;
641
642 Ok(Self::Short {
643 spin,
644 dst_cid: cid_parser.parse(buf)?,
645 })
646 } else {
647 let version = buf.get::<u32>()?;
648
649 let dst_cid = ConnectionId::decode_long(buf)
650 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
651 let src_cid = ConnectionId::decode_long(buf)
652 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
653
654 if version == 0 {
656 let random = first & !LONG_HEADER_FORM;
657 return Ok(Self::VersionNegotiate {
658 random,
659 dst_cid,
660 src_cid,
661 });
662 }
663
664 if !supported_versions.contains(&version) {
665 return Err(PacketDecodeError::UnsupportedVersion {
666 src_cid,
667 dst_cid,
668 version,
669 });
670 }
671
672 match LongHeaderType::from_byte(first)? {
673 LongHeaderType::Initial => {
674 let token_len = buf.get_var()? as usize;
675 let token_start = buf.position() as usize;
676 if token_len > buf.remaining() {
677 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
678 }
679 buf.advance(token_len);
680
681 let len = buf.get_var()?;
682 Ok(Self::Initial(ProtectedInitialHeader {
683 dst_cid,
684 src_cid,
685 token_pos: token_start..token_start + token_len,
686 len,
687 version,
688 }))
689 }
690 LongHeaderType::Retry => Ok(Self::Retry {
691 dst_cid,
692 src_cid,
693 version,
694 }),
695 LongHeaderType::Standard(ty) => Ok(Self::Long {
696 ty,
697 dst_cid,
698 src_cid,
699 len: buf.get_var()?,
700 version,
701 }),
702 }
703 }
704 }
705}
706
/// The fields of an Initial packet header that can be read before header protection is removed.
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination connection ID.
    pub dst_cid: ConnectionId,
    /// Source connection ID.
    pub src_cid: ConnectionId,
    /// Byte range of the token, relative to the start of the packet.
    pub token_pos: Range<usize>,
    /// Declared number of bytes after the length field (packet number and payload).
    pub len: u64,
    /// QUIC version.
    pub version: u32,
}
721
/// A fully decoded Initial packet header.
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    /// Destination connection ID.
    pub(crate) dst_cid: ConnectionId,
    /// Source connection ID.
    pub(crate) src_cid: ConnectionId,
    /// The token carried in the header; empty when there is none.
    pub(crate) token: Bytes,
    /// The (truncated) packet number.
    pub(crate) number: PacketNumber,
    /// QUIC version.
    pub(crate) version: u32,
}
730
/// A packet number truncated to the 1-4 byte form in which it is written on the wire.
///
/// The variant determines the encoded width. `U24` stores its three-byte value in a `u32`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding.
    U8(u8),
    /// Two-byte encoding.
    U16(u16),
    /// Three-byte encoding.
    U24(u32),
    /// Four-byte encoding.
    U32(u32),
}
739
740impl PacketNumber {
741 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
742 let range = (n - largest_acked) * 2;
743 if range < 1 << 8 {
744 Self::U8(n as u8)
745 } else if range < 1 << 16 {
746 Self::U16(n as u16)
747 } else if range < 1 << 24 {
748 Self::U24(n as u32)
749 } else if range < 1 << 32 {
750 Self::U32(n as u32)
751 } else {
752 panic!("packet number too large to encode")
753 }
754 }
755
756 pub(crate) fn len(self) -> usize {
757 use PacketNumber::*;
758 match self {
759 U8(_) => 1,
760 U16(_) => 2,
761 U24(_) => 3,
762 U32(_) => 4,
763 }
764 }
765
766 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
767 use PacketNumber::*;
768 match self {
769 U8(x) => w.write(x),
770 U16(x) => w.write(x),
771 U24(x) => w.put_uint(u64::from(x), 3),
772 U32(x) => w.write(x),
773 }
774 }
775
776 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
777 use PacketNumber::*;
778 let pn = match len {
779 1 => U8(r.get()?),
780 2 => U16(r.get()?),
781 3 => U24(r.get_uint(3) as u32),
782 4 => U32(r.get()?),
783 _ => unreachable!(),
784 };
785 Ok(pn)
786 }
787
788 pub(crate) fn decode_len(tag: u8) -> usize {
789 1 + (tag & 0x03) as usize
790 }
791
792 fn tag(self) -> u8 {
793 use PacketNumber::*;
794 match self {
795 U8(_) => 0b00,
796 U16(_) => 0b01,
797 U24(_) => 0b10,
798 U32(_) => 0b11,
799 }
800 }
801
802 pub(crate) fn expand(self, expected: u64) -> u64 {
803 use PacketNumber::*;
805 let truncated = match self {
806 U8(x) => u64::from(x),
807 U16(x) => u64::from(x),
808 U24(x) => u64::from(x),
809 U32(x) => u64::from(x),
810 };
811 let nbits = self.len() * 8;
812 let win = 1 << nbits;
813 let hwin = win / 2;
814 let mask = win - 1;
815 let candidate = (expected & !mask) | truncated;
824 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
825 candidate + win
826 } else if candidate > expected + hwin && candidate > win {
827 candidate - win
828 } else {
829 candidate
830 }
831 }
832}
833
/// A [`ConnectionIdParser`] that reads connection IDs of one fixed length.
pub struct FixedLengthConnectionIdParser {
    /// Number of bytes every parsed connection ID occupies.
    expected_len: usize,
}
838
839impl FixedLengthConnectionIdParser {
840 pub fn new(expected_len: usize) -> Self {
842 Self { expected_len }
843 }
844}
845
846impl ConnectionIdParser for FixedLengthConnectionIdParser {
847 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
848 (buffer.remaining() >= self.expected_len)
849 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
850 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
851 }
852}
853
/// Parses a connection ID from a buffer.
///
/// `ProtectedHeader::decode` uses it to read the destination connection ID of short-header
/// packets.
pub trait ConnectionIdParser {
    /// Reads a connection ID from the front of `buf`.
    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
}
859
/// The packet type encoded in bits 4-5 of a long header's first byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongHeaderType {
    /// Initial packet.
    Initial,
    /// Retry packet.
    Retry,
    /// Handshake or 0-RTT packet.
    Standard(LongType),
}
867
868impl LongHeaderType {
869 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
870 use {LongHeaderType::*, LongType::*};
871 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
872 Ok(match (b & 0x30) >> 4 {
873 0x0 => Initial,
874 0x1 => Standard(ZeroRtt),
875 0x2 => Standard(Handshake),
876 0x3 => Retry,
877 _ => unreachable!(),
878 })
879 }
880}
881
882impl From<LongHeaderType> for u8 {
883 fn from(ty: LongHeaderType) -> Self {
884 use {LongHeaderType::*, LongType::*};
885 match ty {
886 Initial => LONG_HEADER_FORM | FIXED_BIT,
887 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
888 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
889 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
890 }
891 }
892}
893
/// Long header packet types that carry a packet number and frames, other than Initial.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet.
    Handshake,
    /// 0-RTT packet.
    ZeroRtt,
}
902
/// Errors that can occur while decoding a packet header.
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum PacketDecodeError {
    /// A long header uses a nonzero version not in the supported list.
    #[error("unsupported version {version:x}")]
    UnsupportedVersion {
        /// Source connection ID of the offending packet.
        src_cid: ConnectionId,
        /// Destination connection ID of the offending packet.
        dst_cid: ConnectionId,
        /// The unsupported version.
        version: u32,
    },
    /// The header is truncated or malformed; the message says how.
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
}
920
921impl From<coding::UnexpectedEnd> for PacketDecodeError {
922 fn from(_: coding::UnexpectedEnd) -> Self {
923 Self::InvalidHeader("unexpected end of packet")
924 }
925}
926
/// Header form bit of the first byte: set for long headers, clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// Fixed bit of the first byte.
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Spin bit of a short header's first byte.
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Reserved bits of a short header's first byte; `Packet::reserved_bits_valid` requires them
/// to be zero.
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Reserved bits of a long header's first byte; `Packet::reserved_bits_valid` requires them to
/// be zero.
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Key phase bit of a short header's first byte.
const KEY_PHASE_BIT: u8 = 0b0000_0100;
933
/// A packet number space, ordered Initial < Handshake < Data.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) enum SpaceId {
    /// Space of Initial packets.
    Initial = 0,
    /// Space of Handshake packets.
    Handshake = 1,
    /// Space shared by 0-RTT and 1-RTT packets.
    Data = 2,
}
943
944impl SpaceId {
945 pub(crate) fn iter() -> impl Iterator<Item = Self> {
946 [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
947 }
948
949 pub(crate) fn next(&self) -> Option<Self> {
953 match self {
954 Self::Initial => Some(Self::Handshake),
955 Self::Handshake => Some(Self::Data),
956 Self::Data => None,
957 }
958 }
959
960 pub(crate) fn encryption_level(self) -> EncryptionLevel {
962 match self {
963 Self::Initial => EncryptionLevel::Initial,
964 Self::Handshake => EncryptionLevel::Handshake,
965 Self::Data => EncryptionLevel::OneRtt,
966 }
967 }
968
969 pub(crate) fn kind(self) -> SpaceKind {
971 match self {
972 Self::Initial => SpaceKind::Initial,
973 Self::Handshake => SpaceKind::Handshake,
974 Self::Data => SpaceKind::Data,
975 }
976 }
977}
978
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encodes `typed`, checks the bytes against `encoded`, then checks that decoding those
    /// bytes yields `typed` again.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    /// Fixed-width packet numbers round-trip through big-endian encoding.
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` widens the encoding as the distance from `largest_acked` grows.
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Every truncation chosen by `new` must expand back to the original packet number.
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Encodes and protects an Initial header with client keys, compares it against a fixed
    /// byte vector, then decodes and unprotects it with the matching server keys.
    #[cfg(all(feature = "rustls", any(feature = "aws-lc-rs", feature = "ring")))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{configured_provider, initial_keys, initial_suite_from_provider};
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = configured_provider();

        let suite = initial_suite_from_provider(&provider).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // Room for a 16-byte all-zero payload plus the AEAD tag.
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // After header protection is removed, the header bytes match the unprotected encoding.
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}