1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 ConnectionId, PathId,
8 coding::{self, BufExt, BufMutExt},
9 connection::{EncryptionLevel, SpaceKind},
10 crypto,
11};
12
/// A packet whose header has been parsed only as far as header protection allows.
///
/// Produced by [`PartialDecode::new`]. Removing header protection and decoding the
/// packet number happen later, in [`PartialDecode::finish`], once the caller has
/// selected the right header key.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    // Header fields readable without removing header protection.
    plain_header: ProtectedHeader,
    // The bytes of this packet only (any coalesced trailer is split off in `new`),
    // with the cursor positioned just after the fields in `plain_header`.
    buf: io::Cursor<BytesMut>,
}
31
32#[allow(clippy::len_without_is_empty)]
33impl PartialDecode {
34 pub fn new(
36 bytes: BytesMut,
37 cid_parser: &(impl ConnectionIdParser + ?Sized),
38 supported_versions: &[u32],
39 grease_quic_bit: bool,
40 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
41 let mut buf = io::Cursor::new(bytes);
42 let plain_header =
43 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
44 let dgram_len = buf.get_ref().len();
45 let packet_len = plain_header
46 .payload_len()
47 .map(|len| (buf.position() + len) as usize)
48 .unwrap_or(dgram_len);
49 match dgram_len.cmp(&packet_len) {
50 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
51 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
52 "packet too short to contain payload length",
53 )),
54 Ordering::Greater => {
55 let rest = Some(buf.get_mut().split_off(packet_len));
56 Ok((Self { plain_header, buf }, rest))
57 }
58 }
59 }
60
61 pub(crate) fn data(&self) -> &[u8] {
63 self.buf.get_ref()
64 }
65
66 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
67 self.plain_header.as_initial()
68 }
69
70 pub(crate) fn has_long_header(&self) -> bool {
71 !matches!(self.plain_header, ProtectedHeader::Short { .. })
72 }
73
74 pub(crate) fn is_initial(&self) -> bool {
75 self.encryption_level() == Some(EncryptionLevel::Initial)
76 }
77
78 pub(crate) fn encryption_level(&self) -> Option<EncryptionLevel> {
79 use ProtectedHeader::*;
80 match self.plain_header {
81 Initial { .. } => Some(EncryptionLevel::Initial),
82 Long { ty, .. } => Some(match ty {
83 LongType::Handshake => EncryptionLevel::Handshake,
84 LongType::ZeroRtt => EncryptionLevel::ZeroRtt,
85 }),
86 Short { .. } => Some(EncryptionLevel::OneRtt),
87 _ => None,
88 }
89 }
90
91 pub(crate) fn is_0rtt(&self) -> bool {
92 match self.plain_header {
93 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
94 _ => false,
95 }
96 }
97
98 pub fn dst_cid(&self) -> ConnectionId {
100 self.plain_header.dst_cid()
101 }
102
103 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
106 self.buf.get_ref().len()
107 }
108
109 pub(crate) fn finish(
110 self,
111 header_crypto: Option<&dyn crypto::HeaderKey>,
112 ) -> Result<Packet, PacketDecodeError> {
113 use ProtectedHeader::*;
114 let Self {
115 plain_header,
116 mut buf,
117 } = self;
118
119 if let Initial(ProtectedInitialHeader {
120 dst_cid,
121 src_cid,
122 token_pos,
123 version,
124 ..
125 }) = plain_header
126 {
127 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
128 let header_len = buf.position() as usize;
129 let mut bytes = buf.into_inner();
130
131 let header_data = bytes.split_to(header_len).freeze();
132 let token = header_data.slice(token_pos.start..token_pos.end);
133 return Ok(Packet {
134 header: Header::Initial(InitialHeader {
135 dst_cid,
136 src_cid,
137 token,
138 number,
139 version,
140 }),
141 header_data,
142 payload: bytes,
143 });
144 }
145
146 let header = match plain_header {
147 Long {
148 ty,
149 dst_cid,
150 src_cid,
151 version,
152 ..
153 } => Header::Long {
154 ty,
155 dst_cid,
156 src_cid,
157 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
158 version,
159 },
160 Retry {
161 dst_cid,
162 src_cid,
163 version,
164 } => Header::Retry {
165 dst_cid,
166 src_cid,
167 version,
168 },
169 Short { spin, dst_cid, .. } => {
170 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
171 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
172 Header::Short {
173 spin,
174 key_phase,
175 dst_cid,
176 number,
177 }
178 }
179 VersionNegotiate {
180 random,
181 dst_cid,
182 src_cid,
183 } => Header::VersionNegotiate {
184 random,
185 dst_cid,
186 src_cid,
187 },
188 Initial { .. } => unreachable!(),
189 };
190
191 let header_len = buf.position() as usize;
192 let mut bytes = buf.into_inner();
193 Ok(Packet {
194 header,
195 header_data: bytes.split_to(header_len).freeze(),
196 payload: bytes,
197 })
198 }
199
200 fn decrypt_header(
201 buf: &mut io::Cursor<BytesMut>,
202 header_crypto: &dyn crypto::HeaderKey,
203 ) -> Result<PacketNumber, PacketDecodeError> {
204 let packet_length = buf.get_ref().len();
205 let pn_offset = buf.position() as usize;
206 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
207 return Err(PacketDecodeError::InvalidHeader(
208 "packet too short to extract header protection sample",
209 ));
210 }
211
212 header_crypto.decrypt(pn_offset, buf.get_mut());
213
214 let len = PacketNumber::decode_len(buf.get_ref()[0]);
215 PacketNumber::decode(len, buf)
216 }
217}
218
/// Minimal "bytes written so far" abstraction that [`Header::encode`] uses to
/// measure header offsets.
pub(crate) trait BufLen {
    /// Number of bytes currently held by the buffer.
    fn len(&self) -> usize;
}

impl BufLen for Vec<u8> {
    fn len(&self) -> usize {
        // Go through the slice explicitly so it is obvious this does not recurse
        // into `BufLen::len`.
        self.as_slice().len()
    }
}
235
236pub(crate) struct Packet {
237 pub(crate) header: Header,
238 pub(crate) header_data: Bytes,
239 pub(crate) payload: BytesMut,
240}
241
242impl Packet {
243 pub(crate) fn reserved_bits_valid(&self) -> bool {
244 let mask = match self.header {
245 Header::Short { .. } => SHORT_RESERVED_BITS,
246 _ => LONG_RESERVED_BITS,
247 };
248 self.header_data[0] & mask == 0
249 }
250}
251
252pub(crate) struct InitialPacket {
253 pub(crate) header: InitialHeader,
254 pub(crate) header_data: Bytes,
255 pub(crate) payload: BytesMut,
256}
257
258impl From<InitialPacket> for Packet {
259 fn from(x: InitialPacket) -> Self {
260 Self {
261 header: Header::Initial(x.header),
262 header_data: x.header_data,
263 payload: x.payload,
264 }
265 }
266}
267
/// A fully decoded (or about to be encoded) packet header.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// Initial packet header.
    Initial(InitialHeader),
    /// 0-RTT or Handshake long header.
    Long {
        ty: LongType,
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        number: PacketNumber,
        version: u32,
    },
    /// Retry packet header; carries no packet number.
    Retry {
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        version: u32,
    },
    /// Short (1-RTT) header.
    Short {
        // Spin bit from the first byte.
        spin: bool,
        // Key-phase bit from the first byte.
        key_phase: bool,
        dst_cid: ConnectionId,
        number: PacketNumber,
    },
    /// Version Negotiation header (version field is zero).
    VersionNegotiate {
        // The first-byte bits other than the long-header form bit.
        random: u8,
        src_cid: ConnectionId,
        dst_cid: ConnectionId,
    },
}
296
impl Header {
    /// Writes this header to `w` and returns the bookkeeping needed to finish the
    /// packet once its payload has been appended.
    ///
    /// Initial and 0-RTT/Handshake headers get a two-byte zero placeholder where the
    /// Length field goes; [`PartialEncode::finish`] fills it in later. No header
    /// protection is applied here.
    pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
        use Header::*;
        // Offset of the first header byte; `header_len` below is measured from here.
        let start = w.len();
        match *self {
            Initial(InitialHeader {
                ref dst_cid,
                ref src_cid,
                ref token,
                number,
                version,
            }) => {
                // First byte: long-header type bits plus the packet-number length tag.
                w.write(u8::from(LongHeaderType::Initial) | number.tag());
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                // The token is prefixed by its length as a varint.
                w.write_var(token.len() as u64);
                w.put_slice(token);
                // Placeholder for the two-byte Length field, patched in `PartialEncode::finish`.
                w.write::<u16>(0);
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    // `true`: a Length field precedes the packet number.
                    pn: Some((number.len(), true)),
                }
            }
            Long {
                ty,
                ref dst_cid,
                ref src_cid,
                number,
                version,
            } => {
                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                // Placeholder for the two-byte Length field, patched in `PartialEncode::finish`.
                w.write::<u16>(0);
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: Some((number.len(), true)),
                }
            }
            Retry {
                ref dst_cid,
                ref src_cid,
                version,
            } => {
                w.write(u8::from(LongHeaderType::Retry));
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    // No packet number, so nothing to finish.
                    pn: None,
                }
            }
            Short {
                spin,
                key_phase,
                ref dst_cid,
                number,
            } => {
                w.write(
                    FIXED_BIT
                        | if key_phase { KEY_PHASE_BIT } else { 0 }
                        | if spin { SPIN_BIT } else { 0 }
                        | number.tag(),
                );
                // Short headers write the DCID raw, without a length prefix.
                w.put_slice(dst_cid);
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    // `false`: short headers have no Length field.
                    pn: Some((number.len(), false)),
                }
            }
            VersionNegotiate {
                ref random,
                ref dst_cid,
                ref src_cid,
            } => {
                // Long-header form bit plus the arbitrary remaining bits.
                w.write(0x80u8 | random);
                // Version Negotiation is identified by a zero version.
                w.write::<u32>(0);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: None,
                }
            }
        }
    }

    /// Whether this packet type carries packet protection (everything except Retry
    /// and Version Negotiation).
    pub(crate) fn is_protected(&self) -> bool {
        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
    }

    /// The packet number, for header types that have one.
    pub(crate) fn number(&self) -> Option<PacketNumber> {
        use Header::*;
        Some(match *self {
            Initial(InitialHeader { number, .. }) => number,
            Long { number, .. } => number,
            Short { number, .. } => number,
            _ => {
                return None;
            }
        })
    }

    /// Packet-number space this packet belongs to.
    ///
    /// 0-RTT and short packets share the Data space. Everything other than Handshake
    /// and Data (Initial, Retry, Version Negotiation) maps to Initial.
    pub(crate) fn space(&self) -> SpaceKind {
        use Header::*;
        match *self {
            Short { .. } => SpaceKind::Data,
            Long {
                ty: LongType::ZeroRtt,
                ..
            } => SpaceKind::Data,
            Long {
                ty: LongType::Handshake,
                ..
            } => SpaceKind::Handshake,
            _ => SpaceKind::Initial,
        }
    }

    /// The key-phase bit; always `false` for non-short headers.
    pub(crate) fn key_phase(&self) -> bool {
        match *self {
            Self::Short { key_phase, .. } => key_phase,
            _ => false,
        }
    }

    /// Whether this is a short header.
    pub(crate) fn is_short(&self) -> bool {
        matches!(*self, Self::Short { .. })
    }

    /// Whether this is a 1-RTT packet (same as [`Header::is_short`]).
    pub(crate) fn is_1rtt(&self) -> bool {
        self.is_short()
    }

    /// Whether this is a 0-RTT packet.
    pub(crate) fn is_0rtt(&self) -> bool {
        matches!(
            *self,
            Self::Long {
                ty: LongType::ZeroRtt,
                ..
            }
        )
    }

    /// The destination connection ID, present in every header type.
    pub(crate) fn dst_cid(&self) -> ConnectionId {
        use Header::*;
        match *self {
            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
            Long { dst_cid, .. } => dst_cid,
            Retry { dst_cid, .. } => dst_cid,
            Short { dst_cid, .. } => dst_cid,
            VersionNegotiate { dst_cid, .. } => dst_cid,
        }
    }

    /// Whether another packet may follow this one in the same datagram.
    ///
    /// True only for Initial and 0-RTT/Handshake headers, the types `encode` gives a
    /// Length field.
    pub(crate) fn can_coalesce(&self) -> bool {
        use Header::*;
        match *self {
            Initial(_) => true,
            Long { .. } => true,
            Retry { .. } => false,
            Short { .. } => false,
            VersionNegotiate { .. } => false,
        }
    }

    /// Whether the payload of this packet type consists of frames.
    pub(crate) fn has_frames(&self) -> bool {
        use Header::*;
        match *self {
            Initial(_) => true,
            Long { .. } => true,
            Retry { .. } => false,
            Short { .. } => true,
            VersionNegotiate { .. } => false,
        }
    }

    /// The source connection ID, absent from short headers.
    #[cfg(feature = "qlog")]
    pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
        match self {
            Self::Initial(initial_header) => Some(initial_header.src_cid),
            Self::Long { src_cid, .. } => Some(*src_cid),
            Self::Retry { src_cid, .. } => Some(*src_cid),
            Self::Short { .. } => None,
            Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
        }
    }
}
501
502pub(crate) struct PartialEncode {
503 pub(crate) start: usize,
504 pub(crate) header_len: usize,
505 pn: Option<(usize, bool)>,
507}
508
509impl PartialEncode {
510 pub(crate) fn finish(
511 self,
512 buf: &mut [u8],
513 header_crypto: &dyn crypto::HeaderKey,
514 crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
515 ) {
516 let Self { header_len, pn, .. } = self;
517 let Some((pn_len, write_len)) = pn else {
518 return;
519 };
520
521 let pn_pos = header_len - pn_len;
522 if write_len {
523 let len = buf.len() - header_len + pn_len;
524 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
526 slice.put_u16(len as u16 | (0b01 << 14));
527 }
528
529 if let Some((packet_number, path_id, crypto)) = crypto {
530 crypto.encrypt(path_id, packet_number, buf, header_len);
531 }
532
533 debug_assert!(
534 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
535 "packet must be padded to at least {} bytes for header protection sampling",
536 pn_pos + 4 + header_crypto.sample_size()
537 );
538 header_crypto.encrypt(pn_pos, buf);
539 }
540
541 #[cfg(test)]
545 pub(crate) fn no_header() -> Self {
546 Self {
547 start: 0,
548 header_len: 0,
549 pn: None,
550 }
551 }
552}
553
/// Packet header fields that can be read before header protection is removed.
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header.
    Initial(ProtectedInitialHeader),
    /// A 0-RTT or Handshake packet header.
    Long {
        /// Which long-header packet type this is.
        ty: LongType,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// Value of the Length field: the number of bytes after it, packet number included.
        len: u64,
        /// QUIC version.
        version: u32,
    },
    /// A Retry packet header.
    Retry {
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
        /// QUIC version.
        version: u32,
    },
    /// A short (1-RTT) header.
    Short {
        /// Spin bit from the first byte.
        spin: bool,
        /// Destination connection ID.
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation header.
    VersionNegotiate {
        /// First-byte bits other than the long-header form bit.
        random: u8,
        /// Destination connection ID.
        dst_cid: ConnectionId,
        /// Source connection ID.
        src_cid: ConnectionId,
    },
}
598
599impl ProtectedHeader {
600 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
601 match self {
602 Self::Initial(x) => Some(x),
603 _ => None,
604 }
605 }
606
607 pub fn dst_cid(&self) -> ConnectionId {
609 use ProtectedHeader::*;
610 match self {
611 Initial(header) => header.dst_cid,
612 &Long { dst_cid, .. } => dst_cid,
613 &Retry { dst_cid, .. } => dst_cid,
614 &Short { dst_cid, .. } => dst_cid,
615 &VersionNegotiate { dst_cid, .. } => dst_cid,
616 }
617 }
618
619 fn payload_len(&self) -> Option<u64> {
620 use ProtectedHeader::*;
621 match self {
622 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
623 _ => None,
624 }
625 }
626
627 pub fn decode(
629 buf: &mut io::Cursor<BytesMut>,
630 cid_parser: &(impl ConnectionIdParser + ?Sized),
631 supported_versions: &[u32],
632 grease_quic_bit: bool,
633 ) -> Result<Self, PacketDecodeError> {
634 let first = buf.get::<u8>()?;
635 if !grease_quic_bit && first & FIXED_BIT == 0 {
636 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
637 }
638 if first & LONG_HEADER_FORM == 0 {
639 let spin = first & SPIN_BIT != 0;
640
641 Ok(Self::Short {
642 spin,
643 dst_cid: cid_parser.parse(buf)?,
644 })
645 } else {
646 let version = buf.get::<u32>()?;
647
648 let dst_cid = ConnectionId::decode_long(buf)
649 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
650 let src_cid = ConnectionId::decode_long(buf)
651 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
652
653 if version == 0 {
655 let random = first & !LONG_HEADER_FORM;
656 return Ok(Self::VersionNegotiate {
657 random,
658 dst_cid,
659 src_cid,
660 });
661 }
662
663 if !supported_versions.contains(&version) {
664 return Err(PacketDecodeError::UnsupportedVersion {
665 src_cid,
666 dst_cid,
667 version,
668 });
669 }
670
671 match LongHeaderType::from_byte(first)? {
672 LongHeaderType::Initial => {
673 let token_len = buf.get_var()? as usize;
674 let token_start = buf.position() as usize;
675 if token_len > buf.remaining() {
676 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
677 }
678 buf.advance(token_len);
679
680 let len = buf.get_var()?;
681 Ok(Self::Initial(ProtectedInitialHeader {
682 dst_cid,
683 src_cid,
684 token_pos: token_start..token_start + token_len,
685 len,
686 version,
687 }))
688 }
689 LongHeaderType::Retry => Ok(Self::Retry {
690 dst_cid,
691 src_cid,
692 version,
693 }),
694 LongHeaderType::Standard(ty) => Ok(Self::Long {
695 ty,
696 dst_cid,
697 src_cid,
698 len: buf.get_var()?,
699 version,
700 }),
701 }
702 }
703 }
704}
705
/// Header fields of an Initial packet that can be read before header protection is removed.
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination connection ID.
    pub dst_cid: ConnectionId,
    /// Source connection ID.
    pub src_cid: ConnectionId,
    /// Byte range of the token, as offsets from the start of the packet.
    pub token_pos: Range<usize>,
    /// Value of the Length field: the number of bytes after it, packet number included.
    pub len: u64,
    /// QUIC version.
    pub version: u32,
}
720
/// Decoded (or to-be-encoded) Initial packet header.
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    /// Destination connection ID.
    pub(crate) dst_cid: ConnectionId,
    /// Source connection ID.
    pub(crate) src_cid: ConnectionId,
    /// Token bytes (empty when there is no token).
    pub(crate) token: Bytes,
    /// Truncated packet number.
    pub(crate) number: PacketNumber,
    /// QUIC version.
    pub(crate) version: u32,
}
729
730#[derive(Debug, Copy, Clone, Eq, PartialEq)]
732pub(crate) enum PacketNumber {
733 U8(u8),
734 U16(u16),
735 U24(u32),
736 U32(u32),
737}
738
739impl PacketNumber {
740 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
741 let range = (n - largest_acked) * 2;
742 if range < 1 << 8 {
743 Self::U8(n as u8)
744 } else if range < 1 << 16 {
745 Self::U16(n as u16)
746 } else if range < 1 << 24 {
747 Self::U24(n as u32)
748 } else if range < 1 << 32 {
749 Self::U32(n as u32)
750 } else {
751 panic!("packet number too large to encode")
752 }
753 }
754
755 pub(crate) fn len(self) -> usize {
756 use PacketNumber::*;
757 match self {
758 U8(_) => 1,
759 U16(_) => 2,
760 U24(_) => 3,
761 U32(_) => 4,
762 }
763 }
764
765 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
766 use PacketNumber::*;
767 match self {
768 U8(x) => w.write(x),
769 U16(x) => w.write(x),
770 U24(x) => w.put_uint(u64::from(x), 3),
771 U32(x) => w.write(x),
772 }
773 }
774
775 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
776 use PacketNumber::*;
777 let pn = match len {
778 1 => U8(r.get()?),
779 2 => U16(r.get()?),
780 3 => U24(r.get_uint(3) as u32),
781 4 => U32(r.get()?),
782 _ => unreachable!(),
783 };
784 Ok(pn)
785 }
786
787 pub(crate) fn decode_len(tag: u8) -> usize {
788 1 + (tag & 0x03) as usize
789 }
790
791 fn tag(self) -> u8 {
792 use PacketNumber::*;
793 match self {
794 U8(_) => 0b00,
795 U16(_) => 0b01,
796 U24(_) => 0b10,
797 U32(_) => 0b11,
798 }
799 }
800
801 pub(crate) fn expand(self, expected: u64) -> u64 {
802 use PacketNumber::*;
804 let truncated = match self {
805 U8(x) => u64::from(x),
806 U16(x) => u64::from(x),
807 U24(x) => u64::from(x),
808 U32(x) => u64::from(x),
809 };
810 let nbits = self.len() * 8;
811 let win = 1 << nbits;
812 let hwin = win / 2;
813 let mask = win - 1;
814 let candidate = (expected & !mask) | truncated;
823 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
824 candidate + win
825 } else if candidate > expected + hwin && candidate > win {
826 candidate - win
827 } else {
828 candidate
829 }
830 }
831}
832
833pub struct FixedLengthConnectionIdParser {
835 expected_len: usize,
836}
837
838impl FixedLengthConnectionIdParser {
839 pub fn new(expected_len: usize) -> Self {
841 Self { expected_len }
842 }
843}
844
845impl ConnectionIdParser for FixedLengthConnectionIdParser {
846 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
847 (buffer.remaining() >= self.expected_len)
848 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
849 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
850 }
851}
852
853pub trait ConnectionIdParser {
855 fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
857}
858
859#[derive(Clone, Copy, Debug, Eq, PartialEq)]
861pub(crate) enum LongHeaderType {
862 Initial,
863 Retry,
864 Standard(LongType),
865}
866
867impl LongHeaderType {
868 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
869 use {LongHeaderType::*, LongType::*};
870 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
871 Ok(match (b & 0x30) >> 4 {
872 0x0 => Initial,
873 0x1 => Standard(ZeroRtt),
874 0x2 => Standard(Handshake),
875 0x3 => Retry,
876 _ => unreachable!(),
877 })
878 }
879}
880
881impl From<LongHeaderType> for u8 {
882 fn from(ty: LongHeaderType) -> Self {
883 use {LongHeaderType::*, LongType::*};
884 match ty {
885 Initial => LONG_HEADER_FORM | FIXED_BIT,
886 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
887 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
888 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
889 }
890 }
891}
892
/// Long-header packet types that use the generic long-header layout. Initial and
/// Retry are distinguished separately by `LongHeaderType`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet.
    Handshake,
    /// 0-RTT packet.
    ZeroRtt,
}
901
/// Errors produced while decoding a packet header.
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum PacketDecodeError {
    /// The long header names a non-zero version that is not in the caller's
    /// supported list.
    ///
    /// The connection IDs are included so the caller can respond (presumably with
    /// Version Negotiation; confirm at call sites).
    #[error("unsupported version {version:x}")]
    UnsupportedVersion {
        /// Source connection ID from the offending header.
        src_cid: ConnectionId,
        /// Destination connection ID from the offending header.
        dst_cid: ConnectionId,
        /// The unsupported version number.
        version: u32,
    },
    /// The header is malformed; the message says what was wrong.
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
}

impl From<coding::UnexpectedEnd> for PacketDecodeError {
    /// Running out of input while reading a header field is reported as an invalid header.
    fn from(_: coding::UnexpectedEnd) -> Self {
        Self::InvalidHeader("unexpected end of packet")
    }
}
925
/// First-byte bit that is set for long headers and clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
/// First-byte fixed bit; decoding rejects packets without it unless QUIC-bit
/// greasing is enabled.
pub(crate) const FIXED_BIT: u8 = 0x40;
/// Spin bit in a short-header first byte.
pub(crate) const SPIN_BIT: u8 = 0x20;
/// Short-header first-byte bits that must be zero after header-protection removal.
const SHORT_RESERVED_BITS: u8 = 0x18;
/// Long-header first-byte bits that must be zero after header-protection removal.
const LONG_RESERVED_BITS: u8 = 0x0c;
/// Key-phase bit in a short-header first byte (readable only after header-protection
/// removal).
const KEY_PHASE_BIT: u8 = 0x04;
932
933#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
935pub(crate) enum SpaceId {
936 Initial = 0,
938 Handshake = 1,
939 Data = 2,
941}
942
943impl SpaceId {
944 pub(crate) fn iter() -> impl Iterator<Item = Self> {
945 [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
946 }
947
948 pub(crate) fn next(&self) -> Option<Self> {
952 match self {
953 Self::Initial => Some(Self::Handshake),
954 Self::Handshake => Some(Self::Data),
955 Self::Data => None,
956 }
957 }
958
959 pub(crate) fn encryption_level(self) -> EncryptionLevel {
961 match self {
962 Self::Initial => EncryptionLevel::Initial,
963 Self::Handshake => EncryptionLevel::Handshake,
964 Self::Data => EncryptionLevel::OneRtt,
965 }
966 }
967
968 pub(crate) fn kind(self) -> SpaceKind {
970 match self {
971 Self::Initial => SpaceKind::Initial,
972 Self::Handshake => SpaceKind::Handshake,
973 Self::Data => SpaceKind::Data,
974 }
975 }
976}
977
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    // Encodes `typed`, checks the bytes against `encoded`, then decodes them back and
    // checks the round trip.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    // Explicitly-sized packet numbers encode big-endian at their stated width.
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    // `PacketNumber::new` picks the narrowest width covering twice the distance
    // from `largest_acked`.
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    // Truncating with `new` and recovering with `expand` returns the original value.
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    // End-to-end check against a fixed vector: encode and protect a client Initial,
    // then decode and unprotect it with the server's keys.
    #[cfg(all(feature = "rustls", any(feature = "aws-lc-rs", feature = "ring")))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{configured_provider, initial_keys, initial_suite_from_provider};
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = configured_provider();

        let suite = initial_suite_from_provider(&provider).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // 16 zero bytes of payload plus room for the AEAD tag.
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // After header-protection removal the header bytes are back to plaintext.
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}