1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 ConnectionId, PathId,
8 coding::{self, BufExt, BufMutExt},
9 crypto,
10};
11
12#[cfg_attr(test, derive(Clone))]
25#[derive(Debug)]
26pub struct PartialDecode {
27 plain_header: ProtectedHeader,
28 buf: io::Cursor<BytesMut>,
29}
30
31#[allow(clippy::len_without_is_empty)]
32impl PartialDecode {
33 pub fn new(
35 bytes: BytesMut,
36 cid_parser: &(impl ConnectionIdParser + ?Sized),
37 supported_versions: &[u32],
38 grease_quic_bit: bool,
39 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
40 let mut buf = io::Cursor::new(bytes);
41 let plain_header =
42 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
43 let dgram_len = buf.get_ref().len();
44 let packet_len = plain_header
45 .payload_len()
46 .map(|len| (buf.position() + len) as usize)
47 .unwrap_or(dgram_len);
48 match dgram_len.cmp(&packet_len) {
49 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
50 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
51 "packet too short to contain payload length",
52 )),
53 Ordering::Greater => {
54 let rest = Some(buf.get_mut().split_off(packet_len));
55 Ok((Self { plain_header, buf }, rest))
56 }
57 }
58 }
59
60 pub(crate) fn data(&self) -> &[u8] {
62 self.buf.get_ref()
63 }
64
65 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
66 self.plain_header.as_initial()
67 }
68
69 pub(crate) fn has_long_header(&self) -> bool {
70 !matches!(self.plain_header, ProtectedHeader::Short { .. })
71 }
72
73 pub(crate) fn is_initial(&self) -> bool {
74 self.space() == Some(SpaceId::Initial)
75 }
76
77 pub(crate) fn space(&self) -> Option<SpaceId> {
78 use ProtectedHeader::*;
79 match self.plain_header {
80 Initial { .. } => Some(SpaceId::Initial),
81 Long {
82 ty: LongType::Handshake,
83 ..
84 } => Some(SpaceId::Handshake),
85 Long {
86 ty: LongType::ZeroRtt,
87 ..
88 } => Some(SpaceId::Data),
89 Short { .. } => Some(SpaceId::Data),
90 _ => None,
91 }
92 }
93
94 pub(crate) fn is_0rtt(&self) -> bool {
95 match self.plain_header {
96 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
97 _ => false,
98 }
99 }
100
101 pub fn dst_cid(&self) -> ConnectionId {
103 self.plain_header.dst_cid()
104 }
105
106 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
109 self.buf.get_ref().len()
110 }
111
112 pub(crate) fn finish(
113 self,
114 header_crypto: Option<&dyn crypto::HeaderKey>,
115 ) -> Result<Packet, PacketDecodeError> {
116 use ProtectedHeader::*;
117 let Self {
118 plain_header,
119 mut buf,
120 } = self;
121
122 if let Initial(ProtectedInitialHeader {
123 dst_cid,
124 src_cid,
125 token_pos,
126 version,
127 ..
128 }) = plain_header
129 {
130 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
131 let header_len = buf.position() as usize;
132 let mut bytes = buf.into_inner();
133
134 let header_data = bytes.split_to(header_len).freeze();
135 let token = header_data.slice(token_pos.start..token_pos.end);
136 return Ok(Packet {
137 header: Header::Initial(InitialHeader {
138 dst_cid,
139 src_cid,
140 token,
141 number,
142 version,
143 }),
144 header_data,
145 payload: bytes,
146 });
147 }
148
149 let header = match plain_header {
150 Long {
151 ty,
152 dst_cid,
153 src_cid,
154 version,
155 ..
156 } => Header::Long {
157 ty,
158 dst_cid,
159 src_cid,
160 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
161 version,
162 },
163 Retry {
164 dst_cid,
165 src_cid,
166 version,
167 } => Header::Retry {
168 dst_cid,
169 src_cid,
170 version,
171 },
172 Short { spin, dst_cid, .. } => {
173 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
174 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
175 Header::Short {
176 spin,
177 key_phase,
178 dst_cid,
179 number,
180 }
181 }
182 VersionNegotiate {
183 random,
184 dst_cid,
185 src_cid,
186 } => Header::VersionNegotiate {
187 random,
188 dst_cid,
189 src_cid,
190 },
191 Initial { .. } => unreachable!(),
192 };
193
194 let header_len = buf.position() as usize;
195 let mut bytes = buf.into_inner();
196 Ok(Packet {
197 header,
198 header_data: bytes.split_to(header_len).freeze(),
199 payload: bytes,
200 })
201 }
202
203 fn decrypt_header(
204 buf: &mut io::Cursor<BytesMut>,
205 header_crypto: &dyn crypto::HeaderKey,
206 ) -> Result<PacketNumber, PacketDecodeError> {
207 let packet_length = buf.get_ref().len();
208 let pn_offset = buf.position() as usize;
209 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
210 return Err(PacketDecodeError::InvalidHeader(
211 "packet too short to extract header protection sample",
212 ));
213 }
214
215 header_crypto.decrypt(pn_offset, buf.get_mut());
216
217 let len = PacketNumber::decode_len(buf.get_ref()[0]);
218 PacketNumber::decode(len, buf)
219 }
220}
221
/// Buffers that can report how many bytes they currently hold.
///
/// `Header::encode` uses this to measure how many bytes it has written.
pub(crate) trait BufLen {
    /// Number of bytes currently in the buffer.
    fn len(&self) -> usize;
}
232
233impl BufLen for Vec<u8> {
234 fn len(&self) -> usize {
235 self.len()
236 }
237}
238
239pub(crate) struct Packet {
240 pub(crate) header: Header,
241 pub(crate) header_data: Bytes,
242 pub(crate) payload: BytesMut,
243}
244
245impl Packet {
246 pub(crate) fn reserved_bits_valid(&self) -> bool {
247 let mask = match self.header {
248 Header::Short { .. } => SHORT_RESERVED_BITS,
249 _ => LONG_RESERVED_BITS,
250 };
251 self.header_data[0] & mask == 0
252 }
253}
254
255pub(crate) struct InitialPacket {
256 pub(crate) header: InitialHeader,
257 pub(crate) header_data: Bytes,
258 pub(crate) payload: BytesMut,
259}
260
261impl From<InitialPacket> for Packet {
262 fn from(x: InitialPacket) -> Self {
263 Self {
264 header: Header::Initial(x.header),
265 header_data: x.header_data,
266 payload: x.payload,
267 }
268 }
269}
270
271#[cfg_attr(test, derive(Clone))]
272#[derive(Debug)]
273pub(crate) enum Header {
274 Initial(InitialHeader),
275 Long {
276 ty: LongType,
277 dst_cid: ConnectionId,
278 src_cid: ConnectionId,
279 number: PacketNumber,
280 version: u32,
281 },
282 Retry {
283 dst_cid: ConnectionId,
284 src_cid: ConnectionId,
285 version: u32,
286 },
287 Short {
288 spin: bool,
289 key_phase: bool,
290 dst_cid: ConnectionId,
291 number: PacketNumber,
292 },
293 VersionNegotiate {
294 random: u8,
295 src_cid: ConnectionId,
296 dst_cid: ConnectionId,
297 },
298}
299
300impl Header {
301 pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
302 use Header::*;
303 let start = w.len();
304 match *self {
305 Initial(InitialHeader {
306 ref dst_cid,
307 ref src_cid,
308 ref token,
309 number,
310 version,
311 }) => {
312 w.write(u8::from(LongHeaderType::Initial) | number.tag());
313 w.write(version);
314 dst_cid.encode_long(w);
315 src_cid.encode_long(w);
316 w.write_var(token.len() as u64);
317 w.put_slice(token);
318 w.write::<u16>(0); number.encode(w);
320 PartialEncode {
321 start,
322 header_len: w.len() - start,
323 pn: Some((number.len(), true)),
324 }
325 }
326 Long {
327 ty,
328 ref dst_cid,
329 ref src_cid,
330 number,
331 version,
332 } => {
333 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
334 w.write(version);
335 dst_cid.encode_long(w);
336 src_cid.encode_long(w);
337 w.write::<u16>(0); number.encode(w);
339 PartialEncode {
340 start,
341 header_len: w.len() - start,
342 pn: Some((number.len(), true)),
343 }
344 }
345 Retry {
346 ref dst_cid,
347 ref src_cid,
348 version,
349 } => {
350 w.write(u8::from(LongHeaderType::Retry));
351 w.write(version);
352 dst_cid.encode_long(w);
353 src_cid.encode_long(w);
354 PartialEncode {
355 start,
356 header_len: w.len() - start,
357 pn: None,
358 }
359 }
360 Short {
361 spin,
362 key_phase,
363 ref dst_cid,
364 number,
365 } => {
366 w.write(
367 FIXED_BIT
368 | if key_phase { KEY_PHASE_BIT } else { 0 }
369 | if spin { SPIN_BIT } else { 0 }
370 | number.tag(),
371 );
372 w.put_slice(dst_cid);
373 number.encode(w);
374 PartialEncode {
375 start,
376 header_len: w.len() - start,
377 pn: Some((number.len(), false)),
378 }
379 }
380 VersionNegotiate {
381 ref random,
382 ref dst_cid,
383 ref src_cid,
384 } => {
385 w.write(0x80u8 | random);
386 w.write::<u32>(0);
387 dst_cid.encode_long(w);
388 src_cid.encode_long(w);
389 PartialEncode {
390 start,
391 header_len: w.len() - start,
392 pn: None,
393 }
394 }
395 }
396 }
397
398 pub(crate) fn is_protected(&self) -> bool {
400 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
401 }
402
403 pub(crate) fn number(&self) -> Option<PacketNumber> {
404 use Header::*;
405 Some(match *self {
406 Initial(InitialHeader { number, .. }) => number,
407 Long { number, .. } => number,
408 Short { number, .. } => number,
409 _ => {
410 return None;
411 }
412 })
413 }
414
415 pub(crate) fn space(&self) -> SpaceId {
416 use Header::*;
417 match *self {
418 Short { .. } => SpaceId::Data,
419 Long {
420 ty: LongType::ZeroRtt,
421 ..
422 } => SpaceId::Data,
423 Long {
424 ty: LongType::Handshake,
425 ..
426 } => SpaceId::Handshake,
427 _ => SpaceId::Initial,
428 }
429 }
430
431 pub(crate) fn key_phase(&self) -> bool {
432 match *self {
433 Self::Short { key_phase, .. } => key_phase,
434 _ => false,
435 }
436 }
437
438 pub(crate) fn is_short(&self) -> bool {
439 matches!(*self, Self::Short { .. })
440 }
441
442 pub(crate) fn is_1rtt(&self) -> bool {
443 self.is_short()
444 }
445
446 pub(crate) fn is_0rtt(&self) -> bool {
447 matches!(
448 *self,
449 Self::Long {
450 ty: LongType::ZeroRtt,
451 ..
452 }
453 )
454 }
455
456 pub(crate) fn dst_cid(&self) -> ConnectionId {
457 use Header::*;
458 match *self {
459 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
460 Long { dst_cid, .. } => dst_cid,
461 Retry { dst_cid, .. } => dst_cid,
462 Short { dst_cid, .. } => dst_cid,
463 VersionNegotiate { dst_cid, .. } => dst_cid,
464 }
465 }
466
467 pub(crate) fn has_frames(&self) -> bool {
469 use Header::*;
470 match *self {
471 Initial(_) => true,
472 Long { .. } => true,
473 Retry { .. } => false,
474 Short { .. } => true,
475 VersionNegotiate { .. } => false,
476 }
477 }
478
479 #[cfg(feature = "qlog")]
480 pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
481 match self {
482 Self::Initial(initial_header) => Some(initial_header.src_cid),
483 Self::Long { src_cid, .. } => Some(*src_cid),
484 Self::Retry { src_cid, .. } => Some(*src_cid),
485 Self::Short { .. } => None,
486 Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
487 }
488 }
489}
490
/// A header written by `Header::encode`, awaiting its Length field and protection.
pub(crate) struct PartialEncode {
    /// Offset in the output buffer at which the header starts.
    pub(crate) start: usize,
    /// Header length in bytes, including the packet number.
    pub(crate) header_len: usize,
    /// `(packet number length, whether a two-byte Length field precedes it)`, or `None`
    /// for headers without a packet number.
    pn: Option<(usize, bool)>,
}
497
498impl PartialEncode {
499 pub(crate) fn finish(
500 self,
501 buf: &mut [u8],
502 header_crypto: &dyn crypto::HeaderKey,
503 crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
504 ) {
505 let Self { header_len, pn, .. } = self;
506 let (pn_len, write_len) = match pn {
507 Some((pn_len, write_len)) => (pn_len, write_len),
508 None => return,
509 };
510
511 let pn_pos = header_len - pn_len;
512 if write_len {
513 let len = buf.len() - header_len + pn_len;
514 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
516 slice.put_u16(len as u16 | (0b01 << 14));
517 }
518
519 if let Some((packet_number, path_id, crypto)) = crypto {
520 crypto.encrypt(path_id, packet_number, buf, header_len);
521 }
522
523 debug_assert!(
524 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
525 "packet must be padded to at least {} bytes for header protection sampling",
526 pn_pos + 4 + header_crypto.sample_size()
527 );
528 header_crypto.encrypt(pn_pos, buf);
529 }
530
531 #[cfg(test)]
532 pub(crate) fn dummy() -> PartialEncode {
533 PartialEncode {
534 start: 0,
535 header_len: 2,
536 pn: Some((1, true)),
537 }
538 }
539}
540
541#[derive(Clone, Debug)]
543pub enum ProtectedHeader {
544 Initial(ProtectedInitialHeader),
546 Long {
548 ty: LongType,
550 dst_cid: ConnectionId,
552 src_cid: ConnectionId,
554 len: u64,
556 version: u32,
558 },
559 Retry {
561 dst_cid: ConnectionId,
563 src_cid: ConnectionId,
565 version: u32,
567 },
568 Short {
570 spin: bool,
572 dst_cid: ConnectionId,
574 },
575 VersionNegotiate {
577 random: u8,
579 dst_cid: ConnectionId,
581 src_cid: ConnectionId,
583 },
584}
585
586impl ProtectedHeader {
587 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
588 match self {
589 Self::Initial(x) => Some(x),
590 _ => None,
591 }
592 }
593
594 pub fn dst_cid(&self) -> ConnectionId {
596 use ProtectedHeader::*;
597 match self {
598 Initial(header) => header.dst_cid,
599 &Long { dst_cid, .. } => dst_cid,
600 &Retry { dst_cid, .. } => dst_cid,
601 &Short { dst_cid, .. } => dst_cid,
602 &VersionNegotiate { dst_cid, .. } => dst_cid,
603 }
604 }
605
606 fn payload_len(&self) -> Option<u64> {
607 use ProtectedHeader::*;
608 match self {
609 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
610 _ => None,
611 }
612 }
613
614 pub fn decode(
616 buf: &mut io::Cursor<BytesMut>,
617 cid_parser: &(impl ConnectionIdParser + ?Sized),
618 supported_versions: &[u32],
619 grease_quic_bit: bool,
620 ) -> Result<Self, PacketDecodeError> {
621 let first = buf.get::<u8>()?;
622 if !grease_quic_bit && first & FIXED_BIT == 0 {
623 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
624 }
625 if first & LONG_HEADER_FORM == 0 {
626 let spin = first & SPIN_BIT != 0;
627
628 Ok(Self::Short {
629 spin,
630 dst_cid: cid_parser.parse(buf)?,
631 })
632 } else {
633 let version = buf.get::<u32>()?;
634
635 let dst_cid = ConnectionId::decode_long(buf)
636 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
637 let src_cid = ConnectionId::decode_long(buf)
638 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
639
640 if version == 0 {
642 let random = first & !LONG_HEADER_FORM;
643 return Ok(Self::VersionNegotiate {
644 random,
645 dst_cid,
646 src_cid,
647 });
648 }
649
650 if !supported_versions.contains(&version) {
651 return Err(PacketDecodeError::UnsupportedVersion {
652 src_cid,
653 dst_cid,
654 version,
655 });
656 }
657
658 match LongHeaderType::from_byte(first)? {
659 LongHeaderType::Initial => {
660 let token_len = buf.get_var()? as usize;
661 let token_start = buf.position() as usize;
662 if token_len > buf.remaining() {
663 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
664 }
665 buf.advance(token_len);
666
667 let len = buf.get_var()?;
668 Ok(Self::Initial(ProtectedInitialHeader {
669 dst_cid,
670 src_cid,
671 token_pos: token_start..token_start + token_len,
672 len,
673 version,
674 }))
675 }
676 LongHeaderType::Retry => Ok(Self::Retry {
677 dst_cid,
678 src_cid,
679 version,
680 }),
681 LongHeaderType::Standard(ty) => Ok(Self::Long {
682 ty,
683 dst_cid,
684 src_cid,
685 len: buf.get_var()?,
686 version,
687 }),
688 }
689 }
690 }
691}
692
693#[derive(Clone, Debug)]
695pub struct ProtectedInitialHeader {
696 pub dst_cid: ConnectionId,
698 pub src_cid: ConnectionId,
700 pub token_pos: Range<usize>,
702 pub len: u64,
704 pub version: u32,
706}
707
708#[derive(Clone, Debug)]
709pub(crate) struct InitialHeader {
710 pub(crate) dst_cid: ConnectionId,
711 pub(crate) src_cid: ConnectionId,
712 pub(crate) token: Bytes,
713 pub(crate) number: PacketNumber,
714 pub(crate) version: u32,
715}
716
/// A packet number truncated to the 1-4 bytes it occupies on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding.
    U8(u8),
    /// Two-byte encoding.
    U16(u16),
    /// Three-byte encoding (held in a `u32`).
    U24(u32),
    /// Four-byte encoding.
    U32(u32),
}
725
726impl PacketNumber {
727 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
728 let range = (n - largest_acked) * 2;
729 if range < 1 << 8 {
730 Self::U8(n as u8)
731 } else if range < 1 << 16 {
732 Self::U16(n as u16)
733 } else if range < 1 << 24 {
734 Self::U24(n as u32)
735 } else if range < 1 << 32 {
736 Self::U32(n as u32)
737 } else {
738 panic!("packet number too large to encode")
739 }
740 }
741
742 pub(crate) fn len(self) -> usize {
743 use PacketNumber::*;
744 match self {
745 U8(_) => 1,
746 U16(_) => 2,
747 U24(_) => 3,
748 U32(_) => 4,
749 }
750 }
751
752 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
753 use PacketNumber::*;
754 match self {
755 U8(x) => w.write(x),
756 U16(x) => w.write(x),
757 U24(x) => w.put_uint(u64::from(x), 3),
758 U32(x) => w.write(x),
759 }
760 }
761
762 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
763 use PacketNumber::*;
764 let pn = match len {
765 1 => U8(r.get()?),
766 2 => U16(r.get()?),
767 3 => U24(r.get_uint(3) as u32),
768 4 => U32(r.get()?),
769 _ => unreachable!(),
770 };
771 Ok(pn)
772 }
773
774 pub(crate) fn decode_len(tag: u8) -> usize {
775 1 + (tag & 0x03) as usize
776 }
777
778 fn tag(self) -> u8 {
779 use PacketNumber::*;
780 match self {
781 U8(_) => 0b00,
782 U16(_) => 0b01,
783 U24(_) => 0b10,
784 U32(_) => 0b11,
785 }
786 }
787
788 pub(crate) fn expand(self, expected: u64) -> u64 {
789 use PacketNumber::*;
791 let truncated = match self {
792 U8(x) => u64::from(x),
793 U16(x) => u64::from(x),
794 U24(x) => u64::from(x),
795 U32(x) => u64::from(x),
796 };
797 let nbits = self.len() * 8;
798 let win = 1 << nbits;
799 let hwin = win / 2;
800 let mask = win - 1;
801 let candidate = (expected & !mask) | truncated;
810 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
811 candidate + win
812 } else if candidate > expected + hwin && candidate > win {
813 candidate - win
814 } else {
815 candidate
816 }
817 }
818}
819
/// A [`ConnectionIdParser`] for endpoints whose short-header connection IDs all have the
/// same length.
pub struct FixedLengthConnectionIdParser {
    /// Length in bytes of every connection ID this parser reads.
    expected_len: usize,
}
824
825impl FixedLengthConnectionIdParser {
826 pub fn new(expected_len: usize) -> Self {
828 Self { expected_len }
829 }
830}
831
832impl ConnectionIdParser for FixedLengthConnectionIdParser {
833 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
834 (buffer.remaining() >= self.expected_len)
835 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
836 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
837 }
838}
839
840pub trait ConnectionIdParser {
842 fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
844}
845
846#[derive(Clone, Copy, Debug, Eq, PartialEq)]
848pub(crate) enum LongHeaderType {
849 Initial,
850 Retry,
851 Standard(LongType),
852}
853
854impl LongHeaderType {
855 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
856 use {LongHeaderType::*, LongType::*};
857 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
858 Ok(match (b & 0x30) >> 4 {
859 0x0 => Initial,
860 0x1 => Standard(ZeroRtt),
861 0x2 => Standard(Handshake),
862 0x3 => Retry,
863 _ => unreachable!(),
864 })
865 }
866}
867
868impl From<LongHeaderType> for u8 {
869 fn from(ty: LongHeaderType) -> Self {
870 use {LongHeaderType::*, LongType::*};
871 match ty {
872 Initial => LONG_HEADER_FORM | FIXED_BIT,
873 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
874 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
875 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
876 }
877 }
878}
879
/// Long-header packet types that carry a packet number and a Length field, other than
/// Initial.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LongType {
    /// Handshake packet.
    Handshake,
    /// 0-RTT packet.
    ZeroRtt,
}
888
889#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
891pub enum PacketDecodeError {
892 #[error("unsupported version {version:x}")]
894 UnsupportedVersion {
895 src_cid: ConnectionId,
897 dst_cid: ConnectionId,
899 version: u32,
901 },
902 #[error("invalid header: {0}")]
904 InvalidHeader(&'static str),
905}
906
907impl From<coding::UnexpectedEnd> for PacketDecodeError {
908 fn from(_: coding::UnexpectedEnd) -> Self {
909 Self::InvalidHeader("unexpected end of packet")
910 }
911}
912
/// First-byte bit set for long headers, clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// First-byte "fixed" bit.
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Short-header spin bit.
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Reserved bits of a short header's first byte.
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Reserved bits of a long header's first byte.
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Short-header key phase bit.
const KEY_PHASE_BIT: u8 = 0b0000_0100;
919
/// Packet number spaces, ordered by when they come into use during a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SpaceId {
    /// Initial packets.
    Initial = 0,
    /// Handshake packets.
    Handshake = 1,
    /// 0-RTT and 1-RTT packets.
    Data = 2,
}
929
930impl SpaceId {
931 pub fn iter() -> impl Iterator<Item = Self> {
932 [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
933 }
934
935 pub fn next(&self) -> Self {
939 match self {
940 Self::Initial => Self::Handshake,
941 Self::Handshake => Self::Data,
942 Self::Data => Self::Data,
943 }
944 }
945}
946
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encodes `typed`, checks the bytes against `encoded`, then decodes them and checks the
    /// round trip reproduces `typed`.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        // The 3-byte form uses a separate code path (`put_uint`/`get_uint`).
        check_pn(PacketNumber::U24(0x00_4000), &hex!("004000"));
        check_pn(PacketNumber::U24(0xff_ffff), &hex!("ffffff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    #[test]
    fn pn_expand_rfc9000_example() {
        // RFC 9000 Appendix A.3: largest = 0xa82f30ea, 16-bit truncated value 0x9b32.
        assert_eq!(
            PacketNumber::U16(0x9b32).expand(0xa82f_30ea + 1),
            0xa82f_9b32
        );
        // Wrapping forward across a window boundary.
        assert_eq!(PacketNumber::U8(0x00).expand(0xff), 0x100);
    }

    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}