iroh_quinn_proto/
packet.rs

1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7    ConnectionId, PathId,
8    coding::{self, BufExt, BufMutExt},
9    crypto,
10};
11
12/// Decodes a QUIC packet's invariant header
13///
14/// Due to packet number encryption, it is impossible to fully decode a header
15/// (which includes a variable-length packet number) without crypto context.
16/// The crypto context (represented by the `Crypto` type in Quinn) is usually
17/// part of the `Connection`, or can be derived from the destination CID for
18/// Initial packets.
19///
20/// To cope with this, we decode the invariant header (which should be stable
21/// across QUIC versions), which gives us the destination CID and allows us
22/// to inspect the version and packet type (which depends on the version).
23/// This information allows us to fully decode and decrypt the packet.
24#[cfg_attr(test, derive(Clone))]
25#[derive(Debug)]
26pub struct PartialDecode {
27    plain_header: ProtectedHeader,
28    buf: io::Cursor<BytesMut>,
29}
30
31#[allow(clippy::len_without_is_empty)]
32impl PartialDecode {
33    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
34    pub fn new(
35        bytes: BytesMut,
36        cid_parser: &(impl ConnectionIdParser + ?Sized),
37        supported_versions: &[u32],
38        grease_quic_bit: bool,
39    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
40        let mut buf = io::Cursor::new(bytes);
41        let plain_header =
42            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
43        let dgram_len = buf.get_ref().len();
44        let packet_len = plain_header
45            .payload_len()
46            .map(|len| (buf.position() + len) as usize)
47            .unwrap_or(dgram_len);
48        match dgram_len.cmp(&packet_len) {
49            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
50            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
51                "packet too short to contain payload length",
52            )),
53            Ordering::Greater => {
54                let rest = Some(buf.get_mut().split_off(packet_len));
55                Ok((Self { plain_header, buf }, rest))
56            }
57        }
58    }
59
60    /// The underlying partially-decoded packet data
61    pub(crate) fn data(&self) -> &[u8] {
62        self.buf.get_ref()
63    }
64
65    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
66        self.plain_header.as_initial()
67    }
68
69    pub(crate) fn has_long_header(&self) -> bool {
70        !matches!(self.plain_header, ProtectedHeader::Short { .. })
71    }
72
73    pub(crate) fn is_initial(&self) -> bool {
74        self.space() == Some(SpaceId::Initial)
75    }
76
77    pub(crate) fn space(&self) -> Option<SpaceId> {
78        use ProtectedHeader::*;
79        match self.plain_header {
80            Initial { .. } => Some(SpaceId::Initial),
81            Long {
82                ty: LongType::Handshake,
83                ..
84            } => Some(SpaceId::Handshake),
85            Long {
86                ty: LongType::ZeroRtt,
87                ..
88            } => Some(SpaceId::Data),
89            Short { .. } => Some(SpaceId::Data),
90            _ => None,
91        }
92    }
93
94    pub(crate) fn is_0rtt(&self) -> bool {
95        match self.plain_header {
96            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
97            _ => false,
98        }
99    }
100
101    /// The destination connection ID of the packet
102    pub fn dst_cid(&self) -> ConnectionId {
103        self.plain_header.dst_cid()
104    }
105
106    /// Length of QUIC packet being decoded
107    #[allow(unreachable_pub)] // fuzzing only
108    pub fn len(&self) -> usize {
109        self.buf.get_ref().len()
110    }
111
112    pub(crate) fn finish(
113        self,
114        header_crypto: Option<&dyn crypto::HeaderKey>,
115    ) -> Result<Packet, PacketDecodeError> {
116        use ProtectedHeader::*;
117        let Self {
118            plain_header,
119            mut buf,
120        } = self;
121
122        if let Initial(ProtectedInitialHeader {
123            dst_cid,
124            src_cid,
125            token_pos,
126            version,
127            ..
128        }) = plain_header
129        {
130            let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
131            let header_len = buf.position() as usize;
132            let mut bytes = buf.into_inner();
133
134            let header_data = bytes.split_to(header_len).freeze();
135            let token = header_data.slice(token_pos.start..token_pos.end);
136            return Ok(Packet {
137                header: Header::Initial(InitialHeader {
138                    dst_cid,
139                    src_cid,
140                    token,
141                    number,
142                    version,
143                }),
144                header_data,
145                payload: bytes,
146            });
147        }
148
149        let header = match plain_header {
150            Long {
151                ty,
152                dst_cid,
153                src_cid,
154                version,
155                ..
156            } => Header::Long {
157                ty,
158                dst_cid,
159                src_cid,
160                number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
161                version,
162            },
163            Retry {
164                dst_cid,
165                src_cid,
166                version,
167            } => Header::Retry {
168                dst_cid,
169                src_cid,
170                version,
171            },
172            Short { spin, dst_cid, .. } => {
173                let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
174                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
175                Header::Short {
176                    spin,
177                    key_phase,
178                    dst_cid,
179                    number,
180                }
181            }
182            VersionNegotiate {
183                random,
184                dst_cid,
185                src_cid,
186            } => Header::VersionNegotiate {
187                random,
188                dst_cid,
189                src_cid,
190            },
191            Initial { .. } => unreachable!(),
192        };
193
194        let header_len = buf.position() as usize;
195        let mut bytes = buf.into_inner();
196        Ok(Packet {
197            header,
198            header_data: bytes.split_to(header_len).freeze(),
199            payload: bytes,
200        })
201    }
202
203    fn decrypt_header(
204        buf: &mut io::Cursor<BytesMut>,
205        header_crypto: &dyn crypto::HeaderKey,
206    ) -> Result<PacketNumber, PacketDecodeError> {
207        let packet_length = buf.get_ref().len();
208        let pn_offset = buf.position() as usize;
209        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
210            return Err(PacketDecodeError::InvalidHeader(
211                "packet too short to extract header protection sample",
212            ));
213        }
214
215        header_crypto.decrypt(pn_offset, buf.get_mut());
216
217        let len = PacketNumber::decode_len(buf.get_ref()[0]);
218        PacketNumber::decode(len, buf)
219    }
220}
221
/// A write buffer that reports how many bytes it already holds
///
/// Code that receives a buffer it must not fill past some limit uses this to learn the
/// current write position. Such limits are normally given in the same unit as
/// [`BufLen::len`]: a byte offset from the start of the buffer.
pub(crate) trait BufLen {
    /// Number of bytes written to the buffer so far
    fn len(&self) -> usize;
}

impl BufLen for Vec<u8> {
    fn len(&self) -> usize {
        // Fully qualified so this visibly dispatches to the inherent `Vec::len`
        Vec::len(self)
    }
}
238
239pub(crate) struct Packet {
240    pub(crate) header: Header,
241    pub(crate) header_data: Bytes,
242    pub(crate) payload: BytesMut,
243}
244
245impl Packet {
246    pub(crate) fn reserved_bits_valid(&self) -> bool {
247        let mask = match self.header {
248            Header::Short { .. } => SHORT_RESERVED_BITS,
249            _ => LONG_RESERVED_BITS,
250        };
251        self.header_data[0] & mask == 0
252    }
253}
254
255pub(crate) struct InitialPacket {
256    pub(crate) header: InitialHeader,
257    pub(crate) header_data: Bytes,
258    pub(crate) payload: BytesMut,
259}
260
261impl From<InitialPacket> for Packet {
262    fn from(x: InitialPacket) -> Self {
263        Self {
264            header: Header::Initial(x.header),
265            header_data: x.header_data,
266            payload: x.payload,
267        }
268    }
269}
270
271#[cfg_attr(test, derive(Clone))]
272#[derive(Debug)]
273pub(crate) enum Header {
274    Initial(InitialHeader),
275    Long {
276        ty: LongType,
277        dst_cid: ConnectionId,
278        src_cid: ConnectionId,
279        number: PacketNumber,
280        version: u32,
281    },
282    Retry {
283        dst_cid: ConnectionId,
284        src_cid: ConnectionId,
285        version: u32,
286    },
287    Short {
288        spin: bool,
289        key_phase: bool,
290        dst_cid: ConnectionId,
291        number: PacketNumber,
292    },
293    VersionNegotiate {
294        random: u8,
295        src_cid: ConnectionId,
296        dst_cid: ConnectionId,
297    },
298}
299
impl Header {
    /// Write this header to `w` and return what is needed to finish the packet later.
    ///
    /// The packet number, if any, is written in the clear; header protection is applied
    /// afterwards by `PartialEncode::finish`. Initial and other long headers also get a
    /// two-byte zero placeholder for the payload length, which `finish` overwrites once
    /// the payload size is known.
    pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
        use Header::*;
        // Offsets in the returned `PartialEncode` are measured from here
        let start = w.len();
        match *self {
            Initial(InitialHeader {
                ref dst_cid,
                ref src_cid,
                ref token,
                number,
                version,
            }) => {
                // Form/fixed/type bits, with the packet number length in the low bits
                w.write(u8::from(LongHeaderType::Initial) | number.tag());
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                // The token is prefixed with its length as a varint
                w.write_var(token.len() as u64);
                w.put_slice(token);
                w.write::<u16>(0); // Placeholder for payload length; see `PartialEncode::finish`
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: Some((number.len(), true)),
                }
            }
            Long {
                ty,
                ref dst_cid,
                ref src_cid,
                number,
                version,
            } => {
                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                w.write::<u16>(0); // Placeholder for payload length; see `PartialEncode::finish`
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: Some((number.len(), true)),
                }
            }
            Retry {
                ref dst_cid,
                ref src_cid,
                version,
            } => {
                // No length field or packet number: nothing for `finish` to patch
                w.write(u8::from(LongHeaderType::Retry));
                w.write(version);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: None,
                }
            }
            Short {
                spin,
                key_phase,
                ref dst_cid,
                number,
            } => {
                w.write(
                    FIXED_BIT
                        | if key_phase { KEY_PHASE_BIT } else { 0 }
                        | if spin { SPIN_BIT } else { 0 }
                        | number.tag(),
                );
                // Short headers carry the destination CID without a length prefix
                w.put_slice(dst_cid);
                number.encode(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    // No length field to patch in a short header
                    pn: Some((number.len(), false)),
                }
            }
            VersionNegotiate {
                ref random,
                ref dst_cid,
                ref src_cid,
            } => {
                w.write(0x80u8 | random);
                // Version 0 identifies a Version Negotiation packet
                w.write::<u32>(0);
                dst_cid.encode_long(w);
                src_cid.encode_long(w);
                PartialEncode {
                    start,
                    header_len: w.len() - start,
                    pn: None,
                }
            }
        }
    }

    /// Whether the packet is encrypted on the wire
    pub(crate) fn is_protected(&self) -> bool {
        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
    }

    /// The truncated packet number, or `None` for Retry and Version Negotiation headers
    pub(crate) fn number(&self) -> Option<PacketNumber> {
        use Header::*;
        Some(match *self {
            Initial(InitialHeader { number, .. }) => number,
            Long { number, .. } => number,
            Short { number, .. } => number,
            _ => {
                return None;
            }
        })
    }

    /// Packet number space of the packet
    ///
    /// Retry and Version Negotiation headers fall through to [`SpaceId::Initial`].
    pub(crate) fn space(&self) -> SpaceId {
        use Header::*;
        match *self {
            Short { .. } => SpaceId::Data,
            Long {
                ty: LongType::ZeroRtt,
                ..
            } => SpaceId::Data,
            Long {
                ty: LongType::Handshake,
                ..
            } => SpaceId::Handshake,
            _ => SpaceId::Initial,
        }
    }

    /// The key phase bit of a short header; `false` for every other header type
    pub(crate) fn key_phase(&self) -> bool {
        match *self {
            Self::Short { key_phase, .. } => key_phase,
            _ => false,
        }
    }

    /// Whether this is a short header
    pub(crate) fn is_short(&self) -> bool {
        matches!(*self, Self::Short { .. })
    }

    /// Whether this is a 1-RTT packet, i.e. one with a short header
    pub(crate) fn is_1rtt(&self) -> bool {
        self.is_short()
    }

    /// Whether this is a 0-RTT packet
    pub(crate) fn is_0rtt(&self) -> bool {
        matches!(
            *self,
            Self::Long {
                ty: LongType::ZeroRtt,
                ..
            }
        )
    }

    /// The destination connection ID, which every header type carries
    pub(crate) fn dst_cid(&self) -> ConnectionId {
        use Header::*;
        match *self {
            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
            Long { dst_cid, .. } => dst_cid,
            Retry { dst_cid, .. } => dst_cid,
            Short { dst_cid, .. } => dst_cid,
            VersionNegotiate { dst_cid, .. } => dst_cid,
        }
    }

    /// Whether the payload of this packet contains QUIC frames
    pub(crate) fn has_frames(&self) -> bool {
        use Header::*;
        match *self {
            Initial(_) => true,
            Long { .. } => true,
            Retry { .. } => false,
            Short { .. } => true,
            VersionNegotiate { .. } => false,
        }
    }

    /// The source connection ID, or `None` for short headers, which do not carry one
    #[cfg(feature = "qlog")]
    pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
        match self {
            Self::Initial(initial_header) => Some(initial_header.src_cid),
            Self::Long { src_cid, .. } => Some(*src_cid),
            Self::Retry { src_cid, .. } => Some(*src_cid),
            Self::Short { .. } => None,
            Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
        }
    }
}
490
/// A header that has been written but still needs its length and protection filled in
pub(crate) struct PartialEncode {
    /// Write offset at which the header's first byte was placed
    pub(crate) start: usize,
    /// Number of header bytes written, including the packet number
    pub(crate) header_len: usize,
    /// `(packet number length, whether the 2-byte payload length placeholder must be
    /// patched)`; `None` for headers without a packet number (Retry, Version Negotiation)
    pn: Option<(usize, bool)>,
}
497
impl PartialEncode {
    /// Complete a packet whose header was written by `Header::encode`.
    ///
    /// All offsets here are relative to the start of `buf`, so `buf` must begin at the
    /// header's first byte and hold the whole packet (header followed by payload). The
    /// entire `buf` is handed to the packet key's `encrypt`, so any extra space it needs
    /// (e.g. for an authentication tag — TODO confirm against `PacketKey::encrypt`) must
    /// already be included.
    ///
    /// `crypto`, when given, is `(packet number, path ID, packet key)` used to encrypt the
    /// payload.
    ///
    /// The steps run in this order:
    /// 1. overwrite the payload length placeholder, if the header has one;
    /// 2. encrypt the payload with the packet key, if `crypto` is given;
    /// 3. apply header protection, which samples bytes after the packet number and
    ///    therefore has to see the final (encrypted) payload.
    ///
    /// Headers without a packet number (Retry, Version Negotiation) are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the computed length does not fit the two-byte placeholder (`>= 2^14`).
    /// In debug builds, also panics if `buf` is too short for the header protection sample.
    pub(crate) fn finish(
        self,
        buf: &mut [u8],
        header_crypto: &dyn crypto::HeaderKey,
        crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
    ) {
        let Self { header_len, pn, .. } = self;
        let (pn_len, write_len) = match pn {
            Some((pn_len, write_len)) => (pn_len, write_len),
            None => return,
        };

        // The packet number is the last field of the header
        let pn_pos = header_len - pn_len;
        if write_len {
            // The Length field counts the packet number plus everything after it
            let len = buf.len() - header_len + pn_len;
            assert!(len < 2usize.pow(14)); // Fits in reserved space
            // The placeholder occupies the two bytes right before the packet number; the
            // top bits `0b01` mark a 2-byte varint
            let mut slice = &mut buf[pn_pos - 2..pn_pos];
            slice.put_u16(len as u16 | (0b01 << 14));
        }

        if let Some((packet_number, path_id, crypto)) = crypto {
            crypto.encrypt(path_id, packet_number, buf, header_len);
        }

        // The sample starts 4 bytes past the packet number offset
        debug_assert!(
            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
            "packet must be padded to at least {} bytes for header protection sampling",
            pn_pos + 4 + header_crypto.sample_size()
        );
        header_crypto.encrypt(pn_pos, buf);
    }
}
531
532/// Plain packet header
533#[derive(Clone, Debug)]
534pub enum ProtectedHeader {
535    /// An Initial packet header
536    Initial(ProtectedInitialHeader),
537    /// A Long packet header, as used during the handshake
538    Long {
539        /// Type of the Long header packet
540        ty: LongType,
541        /// Destination Connection ID
542        dst_cid: ConnectionId,
543        /// Source Connection ID
544        src_cid: ConnectionId,
545        /// Length of the packet payload
546        len: u64,
547        /// QUIC version
548        version: u32,
549    },
550    /// A Retry packet header
551    Retry {
552        /// Destination Connection ID
553        dst_cid: ConnectionId,
554        /// Source Connection ID
555        src_cid: ConnectionId,
556        /// QUIC version
557        version: u32,
558    },
559    /// A short packet header, as used during the data phase
560    Short {
561        /// Spin bit
562        spin: bool,
563        /// Destination Connection ID
564        dst_cid: ConnectionId,
565    },
566    /// A Version Negotiation packet header
567    VersionNegotiate {
568        /// Random value
569        random: u8,
570        /// Destination Connection ID
571        dst_cid: ConnectionId,
572        /// Source Connection ID
573        src_cid: ConnectionId,
574    },
575}
576
577impl ProtectedHeader {
578    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
579        match self {
580            Self::Initial(x) => Some(x),
581            _ => None,
582        }
583    }
584
585    /// The destination Connection ID of the packet
586    pub fn dst_cid(&self) -> ConnectionId {
587        use ProtectedHeader::*;
588        match self {
589            Initial(header) => header.dst_cid,
590            &Long { dst_cid, .. } => dst_cid,
591            &Retry { dst_cid, .. } => dst_cid,
592            &Short { dst_cid, .. } => dst_cid,
593            &VersionNegotiate { dst_cid, .. } => dst_cid,
594        }
595    }
596
597    fn payload_len(&self) -> Option<u64> {
598        use ProtectedHeader::*;
599        match self {
600            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
601            _ => None,
602        }
603    }
604
605    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
606    pub fn decode(
607        buf: &mut io::Cursor<BytesMut>,
608        cid_parser: &(impl ConnectionIdParser + ?Sized),
609        supported_versions: &[u32],
610        grease_quic_bit: bool,
611    ) -> Result<Self, PacketDecodeError> {
612        let first = buf.get::<u8>()?;
613        if !grease_quic_bit && first & FIXED_BIT == 0 {
614            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
615        }
616        if first & LONG_HEADER_FORM == 0 {
617            let spin = first & SPIN_BIT != 0;
618
619            Ok(Self::Short {
620                spin,
621                dst_cid: cid_parser.parse(buf)?,
622            })
623        } else {
624            let version = buf.get::<u32>()?;
625
626            let dst_cid = ConnectionId::decode_long(buf)
627                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
628            let src_cid = ConnectionId::decode_long(buf)
629                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
630
631            // TODO: Support long CIDs for compatibility with future QUIC versions
632            if version == 0 {
633                let random = first & !LONG_HEADER_FORM;
634                return Ok(Self::VersionNegotiate {
635                    random,
636                    dst_cid,
637                    src_cid,
638                });
639            }
640
641            if !supported_versions.contains(&version) {
642                return Err(PacketDecodeError::UnsupportedVersion {
643                    src_cid,
644                    dst_cid,
645                    version,
646                });
647            }
648
649            match LongHeaderType::from_byte(first)? {
650                LongHeaderType::Initial => {
651                    let token_len = buf.get_var()? as usize;
652                    let token_start = buf.position() as usize;
653                    if token_len > buf.remaining() {
654                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
655                    }
656                    buf.advance(token_len);
657
658                    let len = buf.get_var()?;
659                    Ok(Self::Initial(ProtectedInitialHeader {
660                        dst_cid,
661                        src_cid,
662                        token_pos: token_start..token_start + token_len,
663                        len,
664                        version,
665                    }))
666                }
667                LongHeaderType::Retry => Ok(Self::Retry {
668                    dst_cid,
669                    src_cid,
670                    version,
671                }),
672                LongHeaderType::Standard(ty) => Ok(Self::Long {
673                    ty,
674                    dst_cid,
675                    src_cid,
676                    len: buf.get_var()?,
677                    version,
678                }),
679            }
680        }
681    }
682}
683
684/// Header of an Initial packet, before decryption
685#[derive(Clone, Debug)]
686pub struct ProtectedInitialHeader {
687    /// Destination Connection ID
688    pub dst_cid: ConnectionId,
689    /// Source Connection ID
690    pub src_cid: ConnectionId,
691    /// The position of a token in the packet buffer
692    pub token_pos: Range<usize>,
693    /// Length of the packet payload
694    pub len: u64,
695    /// QUIC version
696    pub version: u32,
697}
698
699#[derive(Clone, Debug)]
700pub(crate) struct InitialHeader {
701    pub(crate) dst_cid: ConnectionId,
702    pub(crate) src_cid: ConnectionId,
703    pub(crate) token: Bytes,
704    pub(crate) number: PacketNumber,
705    pub(crate) version: u32,
706}
707
/// A packet number truncated to 1–4 bytes, as carried in a packet header
///
/// `U24` stores its value in a `u32` because Rust has no 24-bit integer type.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding
    U8(u8),
    /// Two-byte encoding
    U16(u16),
    /// Three-byte encoding
    U24(u32),
    /// Four-byte encoding
    U32(u32),
}
716
717impl PacketNumber {
718    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
719        let range = (n - largest_acked) * 2;
720        if range < 1 << 8 {
721            Self::U8(n as u8)
722        } else if range < 1 << 16 {
723            Self::U16(n as u16)
724        } else if range < 1 << 24 {
725            Self::U24(n as u32)
726        } else if range < 1 << 32 {
727            Self::U32(n as u32)
728        } else {
729            panic!("packet number too large to encode")
730        }
731    }
732
733    pub(crate) fn len(self) -> usize {
734        use PacketNumber::*;
735        match self {
736            U8(_) => 1,
737            U16(_) => 2,
738            U24(_) => 3,
739            U32(_) => 4,
740        }
741    }
742
743    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
744        use PacketNumber::*;
745        match self {
746            U8(x) => w.write(x),
747            U16(x) => w.write(x),
748            U24(x) => w.put_uint(u64::from(x), 3),
749            U32(x) => w.write(x),
750        }
751    }
752
753    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
754        use PacketNumber::*;
755        let pn = match len {
756            1 => U8(r.get()?),
757            2 => U16(r.get()?),
758            3 => U24(r.get_uint(3) as u32),
759            4 => U32(r.get()?),
760            _ => unreachable!(),
761        };
762        Ok(pn)
763    }
764
765    pub(crate) fn decode_len(tag: u8) -> usize {
766        1 + (tag & 0x03) as usize
767    }
768
769    fn tag(self) -> u8 {
770        use PacketNumber::*;
771        match self {
772            U8(_) => 0b00,
773            U16(_) => 0b01,
774            U24(_) => 0b10,
775            U32(_) => 0b11,
776        }
777    }
778
779    pub(crate) fn expand(self, expected: u64) -> u64 {
780        // From Appendix A
781        use PacketNumber::*;
782        let truncated = match self {
783            U8(x) => u64::from(x),
784            U16(x) => u64::from(x),
785            U24(x) => u64::from(x),
786            U32(x) => u64::from(x),
787        };
788        let nbits = self.len() * 8;
789        let win = 1 << nbits;
790        let hwin = win / 2;
791        let mask = win - 1;
792        // The incoming packet number should be greater than expected - hwin and less than or equal
793        // to expected + hwin
794        //
795        // This means we can't just strip the trailing bits from expected and add the truncated
796        // because that might yield a value outside the window.
797        //
798        // The following code calculates a candidate value and makes sure it's within the packet
799        // number window.
800        let candidate = (expected & !mask) | truncated;
801        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
802            candidate + win
803        } else if candidate > expected + hwin && candidate > win {
804            candidate - win
805        } else {
806            candidate
807        }
808    }
809}
810
811/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length
812pub struct FixedLengthConnectionIdParser {
813    expected_len: usize,
814}
815
816impl FixedLengthConnectionIdParser {
817    /// Create a new instance of `FixedLengthConnectionIdParser`
818    pub fn new(expected_len: usize) -> Self {
819        Self { expected_len }
820    }
821}
822
823impl ConnectionIdParser for FixedLengthConnectionIdParser {
824    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
825        (buffer.remaining() >= self.expected_len)
826            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
827            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
828    }
829}
830
831/// Parse connection id in short header packet
832pub trait ConnectionIdParser {
833    /// Parse a connection id from given buffer
834    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
835}
836
837/// Long packet type including non-uniform cases
838#[derive(Clone, Copy, Debug, Eq, PartialEq)]
839pub(crate) enum LongHeaderType {
840    Initial,
841    Retry,
842    Standard(LongType),
843}
844
845impl LongHeaderType {
846    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
847        use {LongHeaderType::*, LongType::*};
848        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
849        Ok(match (b & 0x30) >> 4 {
850            0x0 => Initial,
851            0x1 => Standard(ZeroRtt),
852            0x2 => Standard(Handshake),
853            0x3 => Retry,
854            _ => unreachable!(),
855        })
856    }
857}
858
859impl From<LongHeaderType> for u8 {
860    fn from(ty: LongHeaderType) -> Self {
861        use {LongHeaderType::*, LongType::*};
862        match ty {
863            Initial => LONG_HEADER_FORM | FIXED_BIT,
864            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
865            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
866            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
867        }
868    }
869}
870
/// Long-header packet types that share a uniform header layout
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// A Handshake packet
    Handshake,
    /// A 0-RTT packet
    ZeroRtt,
}
879
880/// Packet decode error
881#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
882pub enum PacketDecodeError {
883    /// Packet uses a QUIC version that is not supported
884    #[error("unsupported version {version:x}")]
885    UnsupportedVersion {
886        /// Source Connection ID
887        src_cid: ConnectionId,
888        /// Destination Connection ID
889        dst_cid: ConnectionId,
890        /// The version that was unsupported
891        version: u32,
892    },
893    /// The packet header is invalid
894    #[error("invalid header: {0}")]
895    InvalidHeader(&'static str),
896}
897
898impl From<coding::UnexpectedEnd> for PacketDecodeError {
899    fn from(_: coding::UnexpectedEnd) -> Self {
900        Self::InvalidHeader("unexpected end of packet")
901    }
902}
903
/// First-byte bit that selects the long header form
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// First-byte bit that is required to be set unless QUIC bit greasing is allowed
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Spin bit of a short header's first byte
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Reserved bits of a short header's (unprotected) first byte
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Reserved bits of a long header's (unprotected) first byte
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Key phase bit of a short header's (unprotected) first byte
const KEY_PHASE_BIT: u8 = 0b0000_0100;
910
/// Packet number space identifiers
///
/// Spaces are ordered by handshake progress: Initial < Handshake < Data.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Initial packets, which bootstrap the handshake; they are protected with keys
    /// derived from the destination connection ID rather than negotiated ones
    Initial = 0,
    /// Handshake packets
    Handshake = 1,
    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
    Data = 2,
}

impl SpaceId {
    /// Iterate over all packet number spaces in ascending order
    pub fn iter() -> impl Iterator<Item = Self> {
        [Self::Initial, Self::Handshake, Self::Data].iter().copied()
    }

    /// Returns the next higher packet space.
    ///
    /// Keeps returning [`SpaceId::Data`] as the highest space.
    pub fn next(&self) -> Self {
        match self {
            Self::Initial => Self::Handshake,
            Self::Handshake => Self::Data,
            Self::Data => Self::Data,
        }
    }
}
937
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encode `typed`, compare with `encoded`, then decode it back and compare with `typed`
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` picks 1, 2 and 3 bytes as the distance from `largest_acked` grows
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Truncating with `new` and recovering with `expand` yields the original number
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Known-answer test: encode and protect an Initial header with client keys, compare
    /// against fixed bytes, then decode and decrypt it from the server side
    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // 16 zero bytes of payload plus room for the packet key's tag
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // Unprotected header: first byte 0xc0 (1-byte packet number) and Length 0x4021,
        // i.e. 33 = 1 packet number byte + 16 payload bytes + 16 tag bytes
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}