noq_proto/
packet.rs

1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7    ConnectionId, PathId,
8    coding::{self, BufExt, BufMutExt},
9    connection::{EncryptionLevel, SpaceKind},
10    crypto,
11};
12
13/// Decodes a QUIC packet's invariant header
14///
15/// Due to packet number encryption, it is impossible to fully decode a header
16/// (which includes a variable-length packet number) without crypto context.
17/// The crypto context (represented by the `Crypto` type in noq) is usually
18/// part of the `Connection`, or can be derived from the destination CID for
19/// Initial packets.
20///
21/// To cope with this, we decode the invariant header (which should be stable
22/// across QUIC versions), which gives us the destination CID and allows us
23/// to inspect the version and packet type (which depends on the version).
24/// This information allows us to fully decode and decrypt the packet.
25#[cfg_attr(test, derive(Clone))]
26#[derive(Debug)]
27pub struct PartialDecode {
28    plain_header: ProtectedHeader,
29    buf: io::Cursor<BytesMut>,
30}
31
32#[allow(clippy::len_without_is_empty)]
33impl PartialDecode {
34    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
35    pub fn new(
36        bytes: BytesMut,
37        cid_parser: &(impl ConnectionIdParser + ?Sized),
38        supported_versions: &[u32],
39        grease_quic_bit: bool,
40    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
41        let mut buf = io::Cursor::new(bytes);
42        let plain_header =
43            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
44        let dgram_len = buf.get_ref().len();
45        let packet_len = plain_header
46            .payload_len()
47            .map(|len| (buf.position() + len) as usize)
48            .unwrap_or(dgram_len);
49        match dgram_len.cmp(&packet_len) {
50            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
51            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
52                "packet too short to contain payload length",
53            )),
54            Ordering::Greater => {
55                let rest = Some(buf.get_mut().split_off(packet_len));
56                Ok((Self { plain_header, buf }, rest))
57            }
58        }
59    }
60
61    /// The underlying partially-decoded packet data
62    pub(crate) fn data(&self) -> &[u8] {
63        self.buf.get_ref()
64    }
65
66    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
67        self.plain_header.as_initial()
68    }
69
70    pub(crate) fn has_long_header(&self) -> bool {
71        !matches!(self.plain_header, ProtectedHeader::Short { .. })
72    }
73
74    pub(crate) fn is_initial(&self) -> bool {
75        self.encryption_level() == Some(EncryptionLevel::Initial)
76    }
77
78    pub(crate) fn encryption_level(&self) -> Option<EncryptionLevel> {
79        use ProtectedHeader::*;
80        match self.plain_header {
81            Initial { .. } => Some(EncryptionLevel::Initial),
82            Long { ty, .. } => Some(match ty {
83                LongType::Handshake => EncryptionLevel::Handshake,
84                LongType::ZeroRtt => EncryptionLevel::ZeroRtt,
85            }),
86            Short { .. } => Some(EncryptionLevel::OneRtt),
87            _ => None,
88        }
89    }
90
91    pub(crate) fn is_0rtt(&self) -> bool {
92        match self.plain_header {
93            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
94            _ => false,
95        }
96    }
97
98    /// The destination connection ID of the packet
99    pub fn dst_cid(&self) -> ConnectionId {
100        self.plain_header.dst_cid()
101    }
102
103    /// Length of QUIC packet being decoded
104    #[allow(unreachable_pub)] // fuzzing only
105    pub fn len(&self) -> usize {
106        self.buf.get_ref().len()
107    }
108
109    pub(crate) fn finish(
110        self,
111        header_crypto: Option<&dyn crypto::HeaderKey>,
112    ) -> Result<Packet, PacketDecodeError> {
113        use ProtectedHeader::*;
114        let Self {
115            plain_header,
116            mut buf,
117        } = self;
118
119        if let Initial(ProtectedInitialHeader {
120            dst_cid,
121            src_cid,
122            token_pos,
123            version,
124            ..
125        }) = plain_header
126        {
127            let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
128            let header_len = buf.position() as usize;
129            let mut bytes = buf.into_inner();
130
131            let header_data = bytes.split_to(header_len).freeze();
132            let token = header_data.slice(token_pos.start..token_pos.end);
133            return Ok(Packet {
134                header: Header::Initial(InitialHeader {
135                    dst_cid,
136                    src_cid,
137                    token,
138                    number,
139                    version,
140                }),
141                header_data,
142                payload: bytes,
143            });
144        }
145
146        let header = match plain_header {
147            Long {
148                ty,
149                dst_cid,
150                src_cid,
151                version,
152                ..
153            } => Header::Long {
154                ty,
155                dst_cid,
156                src_cid,
157                number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
158                version,
159            },
160            Retry {
161                dst_cid,
162                src_cid,
163                version,
164            } => Header::Retry {
165                dst_cid,
166                src_cid,
167                version,
168            },
169            Short { spin, dst_cid, .. } => {
170                let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
171                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
172                Header::Short {
173                    spin,
174                    key_phase,
175                    dst_cid,
176                    number,
177                }
178            }
179            VersionNegotiate {
180                random,
181                dst_cid,
182                src_cid,
183            } => Header::VersionNegotiate {
184                random,
185                dst_cid,
186                src_cid,
187            },
188            Initial { .. } => unreachable!(),
189        };
190
191        let header_len = buf.position() as usize;
192        let mut bytes = buf.into_inner();
193        Ok(Packet {
194            header,
195            header_data: bytes.split_to(header_len).freeze(),
196            payload: bytes,
197        })
198    }
199
200    fn decrypt_header(
201        buf: &mut io::Cursor<BytesMut>,
202        header_crypto: &dyn crypto::HeaderKey,
203    ) -> Result<PacketNumber, PacketDecodeError> {
204        let packet_length = buf.get_ref().len();
205        let pn_offset = buf.position() as usize;
206        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
207            return Err(PacketDecodeError::InvalidHeader(
208                "packet too short to extract header protection sample",
209            ));
210        }
211
212        header_crypto.decrypt(pn_offset, buf.get_mut());
213
214        let len = PacketNumber::decode_len(buf.get_ref()[0]);
215        PacketNumber::decode(len, buf)
216    }
217}
218
219/// A buffer that can tell how much has been written to it already
220///
221/// This is commonly used for when a buffer is passed and the user may not write past a
222/// given size. It allows the user of such a buffer to know the current cursor position in
223/// the buffer. The maximum write size is usually passed in the same unit as
224/// [`BufLen::len`]: bytes since the buffer start.
225pub(crate) trait BufLen {
226    /// Returns the number of bytes written into the buffer so far
227    fn len(&self) -> usize;
228}
229
230impl BufLen for Vec<u8> {
231    fn len(&self) -> usize {
232        self.len()
233    }
234}
235
236pub(crate) struct Packet {
237    pub(crate) header: Header,
238    pub(crate) header_data: Bytes,
239    pub(crate) payload: BytesMut,
240}
241
242impl Packet {
243    pub(crate) fn reserved_bits_valid(&self) -> bool {
244        let mask = match self.header {
245            Header::Short { .. } => SHORT_RESERVED_BITS,
246            _ => LONG_RESERVED_BITS,
247        };
248        self.header_data[0] & mask == 0
249    }
250}
251
252pub(crate) struct InitialPacket {
253    pub(crate) header: InitialHeader,
254    pub(crate) header_data: Bytes,
255    pub(crate) payload: BytesMut,
256}
257
258impl From<InitialPacket> for Packet {
259    fn from(x: InitialPacket) -> Self {
260        Self {
261            header: Header::Initial(x.header),
262            header_data: x.header_data,
263            payload: x.payload,
264        }
265    }
266}
267
268#[cfg_attr(test, derive(Clone))]
269#[derive(Debug)]
270pub(crate) enum Header {
271    Initial(InitialHeader),
272    Long {
273        ty: LongType,
274        dst_cid: ConnectionId,
275        src_cid: ConnectionId,
276        number: PacketNumber,
277        version: u32,
278    },
279    Retry {
280        dst_cid: ConnectionId,
281        src_cid: ConnectionId,
282        version: u32,
283    },
284    Short {
285        spin: bool,
286        key_phase: bool,
287        dst_cid: ConnectionId,
288        number: PacketNumber,
289    },
290    VersionNegotiate {
291        random: u8,
292        src_cid: ConnectionId,
293        dst_cid: ConnectionId,
294    },
295}
296
297impl Header {
298    pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
299        use Header::*;
300        let start = w.len();
301        match *self {
302            Initial(InitialHeader {
303                ref dst_cid,
304                ref src_cid,
305                ref token,
306                number,
307                version,
308            }) => {
309                w.write(u8::from(LongHeaderType::Initial) | number.tag());
310                w.write(version);
311                dst_cid.encode_long(w);
312                src_cid.encode_long(w);
313                w.write_var(token.len() as u64);
314                w.put_slice(token);
315                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
316                number.encode(w);
317                PartialEncode {
318                    start,
319                    header_len: w.len() - start,
320                    pn: Some((number.len(), true)),
321                }
322            }
323            Long {
324                ty,
325                ref dst_cid,
326                ref src_cid,
327                number,
328                version,
329            } => {
330                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
331                w.write(version);
332                dst_cid.encode_long(w);
333                src_cid.encode_long(w);
334                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
335                number.encode(w);
336                PartialEncode {
337                    start,
338                    header_len: w.len() - start,
339                    pn: Some((number.len(), true)),
340                }
341            }
342            Retry {
343                ref dst_cid,
344                ref src_cid,
345                version,
346            } => {
347                w.write(u8::from(LongHeaderType::Retry));
348                w.write(version);
349                dst_cid.encode_long(w);
350                src_cid.encode_long(w);
351                PartialEncode {
352                    start,
353                    header_len: w.len() - start,
354                    pn: None,
355                }
356            }
357            Short {
358                spin,
359                key_phase,
360                ref dst_cid,
361                number,
362            } => {
363                w.write(
364                    FIXED_BIT
365                        | if key_phase { KEY_PHASE_BIT } else { 0 }
366                        | if spin { SPIN_BIT } else { 0 }
367                        | number.tag(),
368                );
369                w.put_slice(dst_cid);
370                number.encode(w);
371                PartialEncode {
372                    start,
373                    header_len: w.len() - start,
374                    pn: Some((number.len(), false)),
375                }
376            }
377            VersionNegotiate {
378                ref random,
379                ref dst_cid,
380                ref src_cid,
381            } => {
382                w.write(0x80u8 | random);
383                w.write::<u32>(0);
384                dst_cid.encode_long(w);
385                src_cid.encode_long(w);
386                PartialEncode {
387                    start,
388                    header_len: w.len() - start,
389                    pn: None,
390                }
391            }
392        }
393    }
394
395    /// Whether the packet is encrypted on the wire
396    pub(crate) fn is_protected(&self) -> bool {
397        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
398    }
399
400    pub(crate) fn number(&self) -> Option<PacketNumber> {
401        use Header::*;
402        Some(match *self {
403            Initial(InitialHeader { number, .. }) => number,
404            Long { number, .. } => number,
405            Short { number, .. } => number,
406            _ => {
407                return None;
408            }
409        })
410    }
411
412    pub(crate) fn space(&self) -> SpaceKind {
413        use Header::*;
414        match *self {
415            Short { .. } => SpaceKind::Data,
416            Long {
417                ty: LongType::ZeroRtt,
418                ..
419            } => SpaceKind::Data,
420            Long {
421                ty: LongType::Handshake,
422                ..
423            } => SpaceKind::Handshake,
424            _ => SpaceKind::Initial,
425        }
426    }
427
428    pub(crate) fn key_phase(&self) -> bool {
429        match *self {
430            Self::Short { key_phase, .. } => key_phase,
431            _ => false,
432        }
433    }
434
435    pub(crate) fn is_short(&self) -> bool {
436        matches!(*self, Self::Short { .. })
437    }
438
439    pub(crate) fn is_1rtt(&self) -> bool {
440        self.is_short()
441    }
442
443    pub(crate) fn is_0rtt(&self) -> bool {
444        matches!(
445            *self,
446            Self::Long {
447                ty: LongType::ZeroRtt,
448                ..
449            }
450        )
451    }
452
453    pub(crate) fn dst_cid(&self) -> ConnectionId {
454        use Header::*;
455        match *self {
456            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
457            Long { dst_cid, .. } => dst_cid,
458            Retry { dst_cid, .. } => dst_cid,
459            Short { dst_cid, .. } => dst_cid,
460            VersionNegotiate { dst_cid, .. } => dst_cid,
461        }
462    }
463
464    /// Is this packet allowed to be coalesced with others?
465    ///
466    /// Ref <https://www.rfc-editor.org/rfc/rfc9000.html#name-coalescing-packets>
467    pub(crate) fn can_coalesce(&self) -> bool {
468        use Header::*;
469        match *self {
470            Initial(_) => true,
471            Long { .. } => true,
472            Retry { .. } => false,
473            Short { .. } => false,
474            VersionNegotiate { .. } => false,
475        }
476    }
477
478    /// Whether the payload of this packet contains QUIC frames
479    pub(crate) fn has_frames(&self) -> bool {
480        use Header::*;
481        match *self {
482            Initial(_) => true,
483            Long { .. } => true,
484            Retry { .. } => false,
485            Short { .. } => true,
486            VersionNegotiate { .. } => false,
487        }
488    }
489
490    #[cfg(feature = "qlog")]
491    pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
492        match self {
493            Self::Initial(initial_header) => Some(initial_header.src_cid),
494            Self::Long { src_cid, .. } => Some(*src_cid),
495            Self::Retry { src_cid, .. } => Some(*src_cid),
496            Self::Short { .. } => None,
497            Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
498        }
499    }
500}
501
502pub(crate) struct PartialEncode {
503    pub(crate) start: usize,
504    pub(crate) header_len: usize,
505    // Packet number length, payload length needed
506    pn: Option<(usize, bool)>,
507}
508
509impl PartialEncode {
510    pub(crate) fn finish(
511        self,
512        buf: &mut [u8],
513        header_crypto: &dyn crypto::HeaderKey,
514        crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
515    ) {
516        let Self { header_len, pn, .. } = self;
517        let (pn_len, write_len) = match pn {
518            Some((pn_len, write_len)) => (pn_len, write_len),
519            None => return,
520        };
521
522        let pn_pos = header_len - pn_len;
523        if write_len {
524            let len = buf.len() - header_len + pn_len;
525            assert!(len < 2usize.pow(14)); // Fits in reserved space
526            let mut slice = &mut buf[pn_pos - 2..pn_pos];
527            slice.put_u16(len as u16 | (0b01 << 14));
528        }
529
530        if let Some((packet_number, path_id, crypto)) = crypto {
531            crypto.encrypt(path_id, packet_number, buf, header_len);
532        }
533
534        debug_assert!(
535            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
536            "packet must be padded to at least {} bytes for header protection sampling",
537            pn_pos + 4 + header_crypto.sample_size()
538        );
539        header_crypto.encrypt(pn_pos, buf);
540    }
541
542    /// Creates a [`PartialEncode`] that has not encoded a header into the buffer.
543    ///
544    /// This is used exclusively for testing as such a type is otherwise invalid.
545    #[cfg(test)]
546    pub(crate) fn no_header() -> Self {
547        Self {
548            start: 0,
549            header_len: 0,
550            pn: None,
551        }
552    }
553}
554
555/// Plain packet header
556#[derive(Clone, Debug)]
557pub enum ProtectedHeader {
558    /// An Initial packet header
559    Initial(ProtectedInitialHeader),
560    /// A Long packet header, as used during the handshake
561    Long {
562        /// Type of the Long header packet
563        ty: LongType,
564        /// Destination Connection ID
565        dst_cid: ConnectionId,
566        /// Source Connection ID
567        src_cid: ConnectionId,
568        /// Length of the packet payload
569        len: u64,
570        /// QUIC version
571        version: u32,
572    },
573    /// A Retry packet header
574    Retry {
575        /// Destination Connection ID
576        dst_cid: ConnectionId,
577        /// Source Connection ID
578        src_cid: ConnectionId,
579        /// QUIC version
580        version: u32,
581    },
582    /// A short packet header, as used during the data phase
583    Short {
584        /// Spin bit
585        spin: bool,
586        /// Destination Connection ID
587        dst_cid: ConnectionId,
588    },
589    /// A Version Negotiation packet header
590    VersionNegotiate {
591        /// Random value
592        random: u8,
593        /// Destination Connection ID
594        dst_cid: ConnectionId,
595        /// Source Connection ID
596        src_cid: ConnectionId,
597    },
598}
599
600impl ProtectedHeader {
601    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
602        match self {
603            Self::Initial(x) => Some(x),
604            _ => None,
605        }
606    }
607
608    /// The destination Connection ID of the packet
609    pub fn dst_cid(&self) -> ConnectionId {
610        use ProtectedHeader::*;
611        match self {
612            Initial(header) => header.dst_cid,
613            &Long { dst_cid, .. } => dst_cid,
614            &Retry { dst_cid, .. } => dst_cid,
615            &Short { dst_cid, .. } => dst_cid,
616            &VersionNegotiate { dst_cid, .. } => dst_cid,
617        }
618    }
619
620    fn payload_len(&self) -> Option<u64> {
621        use ProtectedHeader::*;
622        match self {
623            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
624            _ => None,
625        }
626    }
627
628    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
629    pub fn decode(
630        buf: &mut io::Cursor<BytesMut>,
631        cid_parser: &(impl ConnectionIdParser + ?Sized),
632        supported_versions: &[u32],
633        grease_quic_bit: bool,
634    ) -> Result<Self, PacketDecodeError> {
635        let first = buf.get::<u8>()?;
636        if !grease_quic_bit && first & FIXED_BIT == 0 {
637            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
638        }
639        if first & LONG_HEADER_FORM == 0 {
640            let spin = first & SPIN_BIT != 0;
641
642            Ok(Self::Short {
643                spin,
644                dst_cid: cid_parser.parse(buf)?,
645            })
646        } else {
647            let version = buf.get::<u32>()?;
648
649            let dst_cid = ConnectionId::decode_long(buf)
650                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
651            let src_cid = ConnectionId::decode_long(buf)
652                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
653
654            // TODO: Support long CIDs for compatibility with future QUIC versions
655            if version == 0 {
656                let random = first & !LONG_HEADER_FORM;
657                return Ok(Self::VersionNegotiate {
658                    random,
659                    dst_cid,
660                    src_cid,
661                });
662            }
663
664            if !supported_versions.contains(&version) {
665                return Err(PacketDecodeError::UnsupportedVersion {
666                    src_cid,
667                    dst_cid,
668                    version,
669                });
670            }
671
672            match LongHeaderType::from_byte(first)? {
673                LongHeaderType::Initial => {
674                    let token_len = buf.get_var()? as usize;
675                    let token_start = buf.position() as usize;
676                    if token_len > buf.remaining() {
677                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
678                    }
679                    buf.advance(token_len);
680
681                    let len = buf.get_var()?;
682                    Ok(Self::Initial(ProtectedInitialHeader {
683                        dst_cid,
684                        src_cid,
685                        token_pos: token_start..token_start + token_len,
686                        len,
687                        version,
688                    }))
689                }
690                LongHeaderType::Retry => Ok(Self::Retry {
691                    dst_cid,
692                    src_cid,
693                    version,
694                }),
695                LongHeaderType::Standard(ty) => Ok(Self::Long {
696                    ty,
697                    dst_cid,
698                    src_cid,
699                    len: buf.get_var()?,
700                    version,
701                }),
702            }
703        }
704    }
705}
706
707/// Header of an Initial packet, before decryption
708#[derive(Clone, Debug)]
709pub struct ProtectedInitialHeader {
710    /// Destination Connection ID
711    pub dst_cid: ConnectionId,
712    /// Source Connection ID
713    pub src_cid: ConnectionId,
714    /// The position of a token in the packet buffer
715    pub token_pos: Range<usize>,
716    /// Length of the packet payload
717    pub len: u64,
718    /// QUIC version
719    pub version: u32,
720}
721
722#[derive(Clone, Debug)]
723pub(crate) struct InitialHeader {
724    pub(crate) dst_cid: ConnectionId,
725    pub(crate) src_cid: ConnectionId,
726    pub(crate) token: Bytes,
727    pub(crate) number: PacketNumber,
728    pub(crate) version: u32,
729}
730
731// An encoded packet number
732#[derive(Debug, Copy, Clone, Eq, PartialEq)]
733pub(crate) enum PacketNumber {
734    U8(u8),
735    U16(u16),
736    U24(u32),
737    U32(u32),
738}
739
740impl PacketNumber {
741    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
742        let range = (n - largest_acked) * 2;
743        if range < 1 << 8 {
744            Self::U8(n as u8)
745        } else if range < 1 << 16 {
746            Self::U16(n as u16)
747        } else if range < 1 << 24 {
748            Self::U24(n as u32)
749        } else if range < 1 << 32 {
750            Self::U32(n as u32)
751        } else {
752            panic!("packet number too large to encode")
753        }
754    }
755
756    pub(crate) fn len(self) -> usize {
757        use PacketNumber::*;
758        match self {
759            U8(_) => 1,
760            U16(_) => 2,
761            U24(_) => 3,
762            U32(_) => 4,
763        }
764    }
765
766    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
767        use PacketNumber::*;
768        match self {
769            U8(x) => w.write(x),
770            U16(x) => w.write(x),
771            U24(x) => w.put_uint(u64::from(x), 3),
772            U32(x) => w.write(x),
773        }
774    }
775
776    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
777        use PacketNumber::*;
778        let pn = match len {
779            1 => U8(r.get()?),
780            2 => U16(r.get()?),
781            3 => U24(r.get_uint(3) as u32),
782            4 => U32(r.get()?),
783            _ => unreachable!(),
784        };
785        Ok(pn)
786    }
787
788    pub(crate) fn decode_len(tag: u8) -> usize {
789        1 + (tag & 0x03) as usize
790    }
791
792    fn tag(self) -> u8 {
793        use PacketNumber::*;
794        match self {
795            U8(_) => 0b00,
796            U16(_) => 0b01,
797            U24(_) => 0b10,
798            U32(_) => 0b11,
799        }
800    }
801
802    pub(crate) fn expand(self, expected: u64) -> u64 {
803        // From Appendix A
804        use PacketNumber::*;
805        let truncated = match self {
806            U8(x) => u64::from(x),
807            U16(x) => u64::from(x),
808            U24(x) => u64::from(x),
809            U32(x) => u64::from(x),
810        };
811        let nbits = self.len() * 8;
812        let win = 1 << nbits;
813        let hwin = win / 2;
814        let mask = win - 1;
815        // The incoming packet number should be greater than expected - hwin and less than or equal
816        // to expected + hwin
817        //
818        // This means we can't just strip the trailing bits from expected and add the truncated
819        // because that might yield a value outside the window.
820        //
821        // The following code calculates a candidate value and makes sure it's within the packet
822        // number window.
823        let candidate = (expected & !mask) | truncated;
824        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
825            candidate + win
826        } else if candidate > expected + hwin && candidate > win {
827            candidate - win
828        } else {
829            candidate
830        }
831    }
832}
833
834/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length
835pub struct FixedLengthConnectionIdParser {
836    expected_len: usize,
837}
838
839impl FixedLengthConnectionIdParser {
840    /// Create a new instance of `FixedLengthConnectionIdParser`
841    pub fn new(expected_len: usize) -> Self {
842        Self { expected_len }
843    }
844}
845
846impl ConnectionIdParser for FixedLengthConnectionIdParser {
847    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
848        (buffer.remaining() >= self.expected_len)
849            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
850            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
851    }
852}
853
854/// Parse connection id in short header packet
855pub trait ConnectionIdParser {
856    /// Parse a connection id from given buffer
857    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
858}
859
860/// Long packet type including non-uniform cases
861#[derive(Clone, Copy, Debug, Eq, PartialEq)]
862pub(crate) enum LongHeaderType {
863    Initial,
864    Retry,
865    Standard(LongType),
866}
867
868impl LongHeaderType {
869    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
870        use {LongHeaderType::*, LongType::*};
871        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
872        Ok(match (b & 0x30) >> 4 {
873            0x0 => Initial,
874            0x1 => Standard(ZeroRtt),
875            0x2 => Standard(Handshake),
876            0x3 => Retry,
877            _ => unreachable!(),
878        })
879    }
880}
881
882impl From<LongHeaderType> for u8 {
883    fn from(ty: LongHeaderType) -> Self {
884        use {LongHeaderType::*, LongType::*};
885        match ty {
886            Initial => LONG_HEADER_FORM | FIXED_BIT,
887            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
888            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
889            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
890        }
891    }
892}
893
894/// Long packet types with uniform header structure
895#[derive(Clone, Copy, Debug, Eq, PartialEq)]
896pub enum LongType {
897    /// Handshake packet
898    Handshake,
899    /// 0-RTT packet
900    ZeroRtt,
901}
902
903/// Packet decode error
904#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
905pub enum PacketDecodeError {
906    /// Packet uses a QUIC version that is not supported
907    #[error("unsupported version {version:x}")]
908    UnsupportedVersion {
909        /// Source Connection ID
910        src_cid: ConnectionId,
911        /// Destination Connection ID
912        dst_cid: ConnectionId,
913        /// The version that was unsupported
914        version: u32,
915    },
916    /// The packet header is invalid
917    #[error("invalid header: {0}")]
918    InvalidHeader(&'static str),
919}
920
921impl From<coding::UnexpectedEnd> for PacketDecodeError {
922    fn from(_: coding::UnexpectedEnd) -> Self {
923        Self::InvalidHeader("unexpected end of packet")
924    }
925}
926
927pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
928pub(crate) const FIXED_BIT: u8 = 0x40;
929pub(crate) const SPIN_BIT: u8 = 0x20;
930const SHORT_RESERVED_BITS: u8 = 0x18;
931const LONG_RESERVED_BITS: u8 = 0x0c;
932const KEY_PHASE_BIT: u8 = 0x04;
933
934/// Packet number space identifiers
935#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
936pub(crate) enum SpaceId {
937    /// Unprotected packets, used to bootstrap the handshake
938    Initial = 0,
939    Handshake = 1,
940    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
941    Data = 2,
942}
943
944impl SpaceId {
945    pub(crate) fn iter() -> impl Iterator<Item = Self> {
946        [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
947    }
948
949    /// Returns the next higher packet space.
950    ///
951    /// Returns `None` if at  [`SpaceId::Data`].
952    pub(crate) fn next(&self) -> Option<Self> {
953        match self {
954            Self::Initial => Some(Self::Handshake),
955            Self::Handshake => Some(Self::Data),
956            Self::Data => None,
957        }
958    }
959
960    /// Returns the encryption level for this packet space.
961    pub(crate) fn encryption_level(self) -> EncryptionLevel {
962        match self {
963            Self::Initial => EncryptionLevel::Initial,
964            Self::Handshake => EncryptionLevel::Handshake,
965            Self::Data => EncryptionLevel::OneRtt,
966        }
967    }
968
969    /// Returns the [`SpaceKind`] for this packet space.
970    pub(crate) fn kind(self) -> SpaceKind {
971        match self {
972            Self::Initial => SpaceKind::Initial,
973            Self::Handshake => SpaceKind::Handshake,
974            Self::Data => SpaceKind::Data,
975        }
976    }
977}
978
979#[cfg(test)]
980mod tests {
981    use super::*;
982    use hex_literal::hex;
983    use std::io;
984
985    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
986        let mut buf = Vec::new();
987        typed.encode(&mut buf);
988        assert_eq!(&buf[..], encoded);
989        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
990        assert_eq!(typed, decoded);
991    }
992
993    #[test]
994    fn roundtrip_packet_numbers() {
995        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
996        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
997        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
998        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
999        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
1000    }
1001
1002    #[test]
1003    fn pn_encode() {
1004        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
1005        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
1006        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
1007    }
1008
1009    #[test]
1010    fn pn_expand_roundtrip() {
1011        for expected in 0..1024 {
1012            for actual in expected..1024 {
1013                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
1014            }
1015        }
1016    }
1017
1018    #[cfg(all(feature = "rustls", any(feature = "aws-lc-rs", feature = "ring")))]
1019    #[test]
1020    fn header_encoding() {
1021        use crate::Side;
1022        use crate::crypto::rustls::{configured_provider, initial_keys, initial_suite_from_provider};
1023        use rustls::quic::Version;
1024
1025        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
1026        let provider = configured_provider();
1027
1028        let suite = initial_suite_from_provider(&provider).unwrap();
1029        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
1030        let mut buf = Vec::new();
1031        let header = Header::Initial(InitialHeader {
1032            number: PacketNumber::U8(0),
1033            src_cid: ConnectionId::new(&[]),
1034            dst_cid: dcid,
1035            token: Bytes::new(),
1036            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
1037        });
1038        let encode = header.encode(&mut buf);
1039        let header_len = buf.len();
1040        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
1041        encode.finish(
1042            &mut buf,
1043            &*client.header.local,
1044            Some((0, PathId::ZERO, &*client.packet.local)),
1045        );
1046
1047        for byte in &buf {
1048            print!("{byte:02x}");
1049        }
1050        println!();
1051        assert_eq!(
1052            buf[..],
1053            hex!(
1054                "c8000000010806b858ec6f80452b00004021be
1055                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
1056            )[..]
1057        );
1058
1059        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
1060        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
1061        let decode = PartialDecode::new(
1062            buf.as_slice().into(),
1063            &FixedLengthConnectionIdParser::new(0),
1064            &supported_versions,
1065            false,
1066        )
1067        .unwrap()
1068        .0;
1069        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
1070        assert_eq!(
1071            packet.header_data[..],
1072            hex!("c0000000010806b858ec6f80452b0000402100")[..]
1073        );
1074        server
1075            .packet
1076            .remote
1077            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
1078            .unwrap();
1079        assert_eq!(packet.payload[..], [0; 16]);
1080        match packet.header {
1081            Header::Initial(InitialHeader {
1082                number: PacketNumber::U8(0),
1083                ..
1084            }) => {}
1085            _ => {
1086                panic!("unexpected header {:?}", packet.header);
1087            }
1088        }
1089    }
1090}