noq_proto/
packet.rs

1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7    ConnectionId, PathId,
8    coding::{self, BufExt, BufMutExt},
9    connection::{EncryptionLevel, SpaceKind},
10    crypto,
11};
12
/// Decodes a QUIC packet's invariant header
///
/// Due to packet number encryption, it is impossible to fully decode a header
/// (which includes a variable-length packet number) without crypto context.
/// The crypto context (represented by the `Crypto` type in noq) is usually
/// part of the `Connection`, or can be derived from the destination CID for
/// Initial packets.
///
/// To cope with this, we decode the invariant header (which should be stable
/// across QUIC versions), which gives us the destination CID and allows us
/// to inspect the version and packet type (which depends on the version).
/// This information allows us to fully decode and decrypt the packet.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    /// The header fields that can be parsed before header protection is removed.
    plain_header: ProtectedHeader,
    /// The bytes of this one packet (any trailing coalesced data has been split off).
    ///
    /// The cursor sits just past `plain_header`, i.e. at the packet number for packet types
    /// that carry one; header protection removal uses this position as the packet number offset.
    buf: io::Cursor<BytesMut>,
}
31
32#[allow(clippy::len_without_is_empty)]
33impl PartialDecode {
34    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
35    pub fn new(
36        bytes: BytesMut,
37        cid_parser: &(impl ConnectionIdParser + ?Sized),
38        supported_versions: &[u32],
39        grease_quic_bit: bool,
40    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
41        let mut buf = io::Cursor::new(bytes);
42        let plain_header =
43            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
44        let dgram_len = buf.get_ref().len();
45        let packet_len = plain_header
46            .payload_len()
47            .map(|len| (buf.position() + len) as usize)
48            .unwrap_or(dgram_len);
49        match dgram_len.cmp(&packet_len) {
50            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
51            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
52                "packet too short to contain payload length",
53            )),
54            Ordering::Greater => {
55                let rest = Some(buf.get_mut().split_off(packet_len));
56                Ok((Self { plain_header, buf }, rest))
57            }
58        }
59    }
60
61    /// The underlying partially-decoded packet data
62    pub(crate) fn data(&self) -> &[u8] {
63        self.buf.get_ref()
64    }
65
66    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
67        self.plain_header.as_initial()
68    }
69
70    pub(crate) fn has_long_header(&self) -> bool {
71        !matches!(self.plain_header, ProtectedHeader::Short { .. })
72    }
73
74    pub(crate) fn is_initial(&self) -> bool {
75        self.encryption_level() == Some(EncryptionLevel::Initial)
76    }
77
78    pub(crate) fn encryption_level(&self) -> Option<EncryptionLevel> {
79        use ProtectedHeader::*;
80        match self.plain_header {
81            Initial { .. } => Some(EncryptionLevel::Initial),
82            Long { ty, .. } => Some(match ty {
83                LongType::Handshake => EncryptionLevel::Handshake,
84                LongType::ZeroRtt => EncryptionLevel::ZeroRtt,
85            }),
86            Short { .. } => Some(EncryptionLevel::OneRtt),
87            _ => None,
88        }
89    }
90
91    pub(crate) fn is_0rtt(&self) -> bool {
92        match self.plain_header {
93            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
94            _ => false,
95        }
96    }
97
98    /// The destination connection ID of the packet
99    pub fn dst_cid(&self) -> ConnectionId {
100        self.plain_header.dst_cid()
101    }
102
103    /// Length of QUIC packet being decoded
104    #[allow(unreachable_pub)] // fuzzing only
105    pub fn len(&self) -> usize {
106        self.buf.get_ref().len()
107    }
108
109    pub(crate) fn finish(
110        self,
111        header_crypto: Option<&dyn crypto::HeaderKey>,
112    ) -> Result<Packet, PacketDecodeError> {
113        use ProtectedHeader::*;
114        let Self {
115            plain_header,
116            mut buf,
117        } = self;
118
119        if let Initial(ProtectedInitialHeader {
120            dst_cid,
121            src_cid,
122            token_pos,
123            version,
124            ..
125        }) = plain_header
126        {
127            let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
128            let header_len = buf.position() as usize;
129            let mut bytes = buf.into_inner();
130
131            let header_data = bytes.split_to(header_len).freeze();
132            let token = header_data.slice(token_pos.start..token_pos.end);
133            return Ok(Packet {
134                header: Header::Initial(InitialHeader {
135                    dst_cid,
136                    src_cid,
137                    token,
138                    number,
139                    version,
140                }),
141                header_data,
142                payload: bytes,
143            });
144        }
145
146        let header = match plain_header {
147            Long {
148                ty,
149                dst_cid,
150                src_cid,
151                version,
152                ..
153            } => Header::Long {
154                ty,
155                dst_cid,
156                src_cid,
157                number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
158                version,
159            },
160            Retry {
161                dst_cid,
162                src_cid,
163                version,
164            } => Header::Retry {
165                dst_cid,
166                src_cid,
167                version,
168            },
169            Short { spin, dst_cid, .. } => {
170                let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
171                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
172                Header::Short {
173                    spin,
174                    key_phase,
175                    dst_cid,
176                    number,
177                }
178            }
179            VersionNegotiate {
180                random,
181                dst_cid,
182                src_cid,
183            } => Header::VersionNegotiate {
184                random,
185                dst_cid,
186                src_cid,
187            },
188            Initial { .. } => unreachable!(),
189        };
190
191        let header_len = buf.position() as usize;
192        let mut bytes = buf.into_inner();
193        Ok(Packet {
194            header,
195            header_data: bytes.split_to(header_len).freeze(),
196            payload: bytes,
197        })
198    }
199
200    fn decrypt_header(
201        buf: &mut io::Cursor<BytesMut>,
202        header_crypto: &dyn crypto::HeaderKey,
203    ) -> Result<PacketNumber, PacketDecodeError> {
204        let packet_length = buf.get_ref().len();
205        let pn_offset = buf.position() as usize;
206        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
207            return Err(PacketDecodeError::InvalidHeader(
208                "packet too short to extract header protection sample",
209            ));
210        }
211
212        header_crypto.decrypt(pn_offset, buf.get_mut());
213
214        let len = PacketNumber::decode_len(buf.get_ref()[0]);
215        PacketNumber::decode(len, buf)
216    }
217}
218
/// A write buffer that reports how many bytes it already holds
///
/// Callers that must not write beyond some limit use this to learn the current write
/// position. Such limits are normally expressed in the same unit as [`BufLen::len`], namely
/// bytes counted from the start of the buffer.
pub(crate) trait BufLen {
    /// Number of bytes written into the buffer so far
    fn len(&self) -> usize;
}

impl BufLen for Vec<u8> {
    fn len(&self) -> usize {
        // Delegate explicitly to the inherent method; the written length is the vector length.
        Vec::len(self)
    }
}
235
236/// A received packet with header protection removed.
237// TODO(flub): Would be grand to make this typesafe by adding a generic that indicates
238//    whether the payload is encrypted or decrypted.
239pub(crate) struct Packet {
240    /// The decoded header.
241    pub(crate) header: Header,
242    /// The encoded header data, with header protection removed (i.e. decrypted).
243    pub(crate) header_data: Bytes,
244    /// Packet payload, still encrypted when created, later decrypted in-place.
245    pub(crate) payload: BytesMut,
246}
247
248impl Packet {
249    pub(crate) fn reserved_bits_valid(&self) -> bool {
250        let mask = match self.header {
251            Header::Short { .. } => SHORT_RESERVED_BITS,
252            _ => LONG_RESERVED_BITS,
253        };
254        self.header_data[0] & mask == 0
255    }
256}
257
258pub(crate) struct InitialPacket {
259    pub(crate) header: InitialHeader,
260    pub(crate) header_data: Bytes,
261    pub(crate) payload: BytesMut,
262}
263
264impl From<InitialPacket> for Packet {
265    fn from(x: InitialPacket) -> Self {
266        Self {
267            header: Header::Initial(x.header),
268            header_data: x.header_data,
269            payload: x.payload,
270        }
271    }
272}
273
/// A fully decoded packet header, including the packet number where the packet type has one.
///
/// Produced by `PartialDecode::finish` for received packets and serialized by
/// [`Header::encode`] for outgoing ones.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// An Initial packet header
    Initial(InitialHeader),
    /// A Handshake or 0-RTT long header (see [`LongType`])
    Long {
        ty: LongType,
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        /// Packet number in its truncated on-the-wire form
        number: PacketNumber,
        version: u32,
    },
    /// A Retry packet header; Retry packets carry no packet number
    Retry {
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        version: u32,
    },
    /// A short (1-RTT) packet header
    Short {
        spin: bool,
        key_phase: bool,
        dst_cid: ConnectionId,
        /// Packet number in its truncated on-the-wire form
        number: PacketNumber,
    },
    /// A Version Negotiation packet header; it carries no packet number
    VersionNegotiate {
        /// The first header byte with the header form bit masked off
        random: u8,
        src_cid: ConnectionId,
        dst_cid: ConnectionId,
    },
}
302
303impl Header {
304    pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
305        use Header::*;
306        let start = w.len();
307        match *self {
308            Initial(InitialHeader {
309                ref dst_cid,
310                ref src_cid,
311                ref token,
312                number,
313                version,
314            }) => {
315                w.write(u8::from(LongHeaderType::Initial) | number.tag());
316                w.write(version);
317                dst_cid.encode_long(w);
318                src_cid.encode_long(w);
319                w.write_var(token.len() as u64);
320                w.put_slice(token);
321                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
322                number.encode(w);
323                PartialEncode {
324                    start,
325                    header_len: w.len() - start,
326                    pn: Some((number.len(), true)),
327                }
328            }
329            Long {
330                ty,
331                ref dst_cid,
332                ref src_cid,
333                number,
334                version,
335            } => {
336                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
337                w.write(version);
338                dst_cid.encode_long(w);
339                src_cid.encode_long(w);
340                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
341                number.encode(w);
342                PartialEncode {
343                    start,
344                    header_len: w.len() - start,
345                    pn: Some((number.len(), true)),
346                }
347            }
348            Retry {
349                ref dst_cid,
350                ref src_cid,
351                version,
352            } => {
353                w.write(u8::from(LongHeaderType::Retry));
354                w.write(version);
355                dst_cid.encode_long(w);
356                src_cid.encode_long(w);
357                PartialEncode {
358                    start,
359                    header_len: w.len() - start,
360                    pn: None,
361                }
362            }
363            Short {
364                spin,
365                key_phase,
366                ref dst_cid,
367                number,
368            } => {
369                w.write(
370                    FIXED_BIT
371                        | if key_phase { KEY_PHASE_BIT } else { 0 }
372                        | if spin { SPIN_BIT } else { 0 }
373                        | number.tag(),
374                );
375                w.put_slice(dst_cid);
376                number.encode(w);
377                PartialEncode {
378                    start,
379                    header_len: w.len() - start,
380                    pn: Some((number.len(), false)),
381                }
382            }
383            VersionNegotiate {
384                ref random,
385                ref dst_cid,
386                ref src_cid,
387            } => {
388                w.write(0x80u8 | random);
389                w.write::<u32>(0);
390                dst_cid.encode_long(w);
391                src_cid.encode_long(w);
392                PartialEncode {
393                    start,
394                    header_len: w.len() - start,
395                    pn: None,
396                }
397            }
398        }
399    }
400
401    /// Whether the packet is encrypted on the wire
402    pub(crate) fn is_protected(&self) -> bool {
403        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
404    }
405
406    pub(crate) fn number(&self) -> Option<PacketNumber> {
407        use Header::*;
408        Some(match *self {
409            Initial(InitialHeader { number, .. }) => number,
410            Long { number, .. } => number,
411            Short { number, .. } => number,
412            _ => {
413                return None;
414            }
415        })
416    }
417
418    pub(crate) fn space(&self) -> SpaceKind {
419        use Header::*;
420        match *self {
421            Short { .. } => SpaceKind::Data,
422            Long {
423                ty: LongType::ZeroRtt,
424                ..
425            } => SpaceKind::Data,
426            Long {
427                ty: LongType::Handshake,
428                ..
429            } => SpaceKind::Handshake,
430            _ => SpaceKind::Initial,
431        }
432    }
433
434    pub(crate) fn key_phase(&self) -> bool {
435        match *self {
436            Self::Short { key_phase, .. } => key_phase,
437            _ => false,
438        }
439    }
440
441    pub(crate) fn is_short(&self) -> bool {
442        matches!(*self, Self::Short { .. })
443    }
444
445    pub(crate) fn is_1rtt(&self) -> bool {
446        self.is_short()
447    }
448
449    pub(crate) fn is_0rtt(&self) -> bool {
450        matches!(
451            *self,
452            Self::Long {
453                ty: LongType::ZeroRtt,
454                ..
455            }
456        )
457    }
458
459    pub(crate) fn dst_cid(&self) -> ConnectionId {
460        use Header::*;
461        match *self {
462            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
463            Long { dst_cid, .. } => dst_cid,
464            Retry { dst_cid, .. } => dst_cid,
465            Short { dst_cid, .. } => dst_cid,
466            VersionNegotiate { dst_cid, .. } => dst_cid,
467        }
468    }
469
470    /// Is this packet allowed to be coalesced with others?
471    ///
472    /// Ref <https://www.rfc-editor.org/rfc/rfc9000.html#name-coalescing-packets>
473    pub(crate) fn can_coalesce(&self) -> bool {
474        use Header::*;
475        match *self {
476            Initial(_) => true,
477            Long { .. } => true,
478            Retry { .. } => false,
479            Short { .. } => false,
480            VersionNegotiate { .. } => false,
481        }
482    }
483
484    /// Whether the payload of this packet contains QUIC frames
485    pub(crate) fn has_frames(&self) -> bool {
486        use Header::*;
487        match *self {
488            Initial(_) => true,
489            Long { .. } => true,
490            Retry { .. } => false,
491            Short { .. } => true,
492            VersionNegotiate { .. } => false,
493        }
494    }
495
496    #[cfg(feature = "qlog")]
497    pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
498        match self {
499            Self::Initial(initial_header) => Some(initial_header.src_cid),
500            Self::Long { src_cid, .. } => Some(*src_cid),
501            Self::Retry { src_cid, .. } => Some(*src_cid),
502            Self::Short { .. } => None,
503            Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
504        }
505    }
506}
507
508pub(crate) struct PartialEncode {
509    pub(crate) start: usize,
510    pub(crate) header_len: usize,
511    // Packet number length, payload length needed
512    pn: Option<(usize, bool)>,
513}
514
515impl PartialEncode {
516    pub(crate) fn finish(
517        self,
518        buf: &mut [u8],
519        header_crypto: &dyn crypto::HeaderKey,
520        crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
521    ) {
522        let Self { header_len, pn, .. } = self;
523        let Some((pn_len, write_len)) = pn else {
524            return;
525        };
526
527        let pn_pos = header_len - pn_len;
528        if write_len {
529            let len = buf.len() - header_len + pn_len;
530            assert!(len < 2usize.pow(14)); // Fits in reserved space
531            let mut slice = &mut buf[pn_pos - 2..pn_pos];
532            slice.put_u16(len as u16 | (0b01 << 14));
533        }
534
535        if let Some((packet_number, path_id, crypto)) = crypto {
536            crypto.encrypt(path_id, packet_number, buf, header_len);
537        }
538
539        debug_assert!(
540            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
541            "packet must be padded to at least {} bytes for header protection sampling",
542            pn_pos + 4 + header_crypto.sample_size()
543        );
544        header_crypto.encrypt(pn_pos, buf);
545    }
546
547    /// Creates a [`PartialEncode`] that has not encoded a header into the buffer.
548    ///
549    /// This is used exclusively for testing as such a type is otherwise invalid.
550    #[cfg(test)]
551    pub(crate) fn no_header() -> Self {
552        Self {
553            start: 0,
554            header_len: 0,
555            pn: None,
556        }
557    }
558}
559
/// Packet header as parsed before header protection is removed
///
/// Contains only the fields readable without header protection keys; in particular it holds
/// no packet number.
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header
    Initial(ProtectedInitialHeader),
    /// A Long packet header, as used during the handshake
    Long {
        /// Type of the Long header packet
        ty: LongType,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// Value of the Length field: the number of bytes covering the packet number and payload
        len: u64,
        /// QUIC version
        version: u32,
    },
    /// A Retry packet header
    Retry {
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// QUIC version
        version: u32,
    },
    /// A short packet header, as used during the data phase
    Short {
        /// Spin bit
        spin: bool,
        /// Destination Connection ID
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation packet header
    VersionNegotiate {
        /// The first header byte with the header form bit masked off
        random: u8,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
    },
}
604
605impl ProtectedHeader {
606    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
607        match self {
608            Self::Initial(x) => Some(x),
609            _ => None,
610        }
611    }
612
613    /// The destination Connection ID of the packet
614    pub fn dst_cid(&self) -> ConnectionId {
615        use ProtectedHeader::*;
616        match self {
617            Initial(header) => header.dst_cid,
618            &Long { dst_cid, .. } => dst_cid,
619            &Retry { dst_cid, .. } => dst_cid,
620            &Short { dst_cid, .. } => dst_cid,
621            &VersionNegotiate { dst_cid, .. } => dst_cid,
622        }
623    }
624
625    fn payload_len(&self) -> Option<u64> {
626        use ProtectedHeader::*;
627        match self {
628            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
629            _ => None,
630        }
631    }
632
633    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
634    pub fn decode(
635        buf: &mut io::Cursor<BytesMut>,
636        cid_parser: &(impl ConnectionIdParser + ?Sized),
637        supported_versions: &[u32],
638        grease_quic_bit: bool,
639    ) -> Result<Self, PacketDecodeError> {
640        let first = buf.get::<u8>()?;
641        if !grease_quic_bit && first & FIXED_BIT == 0 {
642            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
643        }
644        if first & LONG_HEADER_FORM == 0 {
645            let spin = first & SPIN_BIT != 0;
646
647            Ok(Self::Short {
648                spin,
649                dst_cid: cid_parser.parse(buf)?,
650            })
651        } else {
652            let version = buf.get::<u32>()?;
653
654            let dst_cid = ConnectionId::decode_long(buf)
655                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
656            let src_cid = ConnectionId::decode_long(buf)
657                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
658
659            // TODO: Support long CIDs for compatibility with future QUIC versions
660            if version == 0 {
661                let random = first & !LONG_HEADER_FORM;
662                return Ok(Self::VersionNegotiate {
663                    random,
664                    dst_cid,
665                    src_cid,
666                });
667            }
668
669            if !supported_versions.contains(&version) {
670                return Err(PacketDecodeError::UnsupportedVersion {
671                    src_cid,
672                    dst_cid,
673                    version,
674                });
675            }
676
677            match LongHeaderType::from_byte(first)? {
678                LongHeaderType::Initial => {
679                    let token_len = buf.get_var()? as usize;
680                    let token_start = buf.position() as usize;
681                    if token_len > buf.remaining() {
682                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
683                    }
684                    buf.advance(token_len);
685
686                    let len = buf.get_var()?;
687                    Ok(Self::Initial(ProtectedInitialHeader {
688                        dst_cid,
689                        src_cid,
690                        token_pos: token_start..token_start + token_len,
691                        len,
692                        version,
693                    }))
694                }
695                LongHeaderType::Retry => Ok(Self::Retry {
696                    dst_cid,
697                    src_cid,
698                    version,
699                }),
700                LongHeaderType::Standard(ty) => Ok(Self::Long {
701                    ty,
702                    dst_cid,
703                    src_cid,
704                    len: buf.get_var()?,
705                    version,
706                }),
707            }
708        }
709    }
710}
711
/// Header of an Initial packet, before header protection is removed
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination Connection ID
    pub dst_cid: ConnectionId,
    /// Source Connection ID
    pub src_cid: ConnectionId,
    /// Byte range of the token, as offsets from the start of the buffer the header was
    /// decoded from (empty if the packet carries no token)
    pub token_pos: Range<usize>,
    /// Value of the Length field: the number of bytes covering the packet number and payload
    pub len: u64,
    /// QUIC version
    pub version: u32,
}
726
/// Header of an Initial packet with header protection removed
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    pub(crate) dst_cid: ConnectionId,
    pub(crate) src_cid: ConnectionId,
    /// The token carried in the header; empty if the packet had none
    pub(crate) token: Bytes,
    /// Packet number in its truncated on-the-wire form
    pub(crate) number: PacketNumber,
    pub(crate) version: u32,
}
735
/// A packet number truncated to the number of bytes it occupies on the wire
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// 1-byte encoding
    U8(u8),
    /// 2-byte encoding
    U16(u16),
    /// 3-byte encoding; only the low 24 bits are written
    U24(u32),
    /// 4-byte encoding
    U32(u32),
}
744
745impl PacketNumber {
746    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
747        let range = (n - largest_acked) * 2;
748        if range < 1 << 8 {
749            Self::U8(n as u8)
750        } else if range < 1 << 16 {
751            Self::U16(n as u16)
752        } else if range < 1 << 24 {
753            Self::U24(n as u32)
754        } else if range < 1 << 32 {
755            Self::U32(n as u32)
756        } else {
757            panic!("packet number too large to encode")
758        }
759    }
760
761    pub(crate) fn len(self) -> usize {
762        use PacketNumber::*;
763        match self {
764            U8(_) => 1,
765            U16(_) => 2,
766            U24(_) => 3,
767            U32(_) => 4,
768        }
769    }
770
771    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
772        use PacketNumber::*;
773        match self {
774            U8(x) => w.write(x),
775            U16(x) => w.write(x),
776            U24(x) => w.put_uint(u64::from(x), 3),
777            U32(x) => w.write(x),
778        }
779    }
780
781    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
782        use PacketNumber::*;
783        let pn = match len {
784            1 => U8(r.get()?),
785            2 => U16(r.get()?),
786            3 => U24(r.get_uint(3) as u32),
787            4 => U32(r.get()?),
788            _ => unreachable!(),
789        };
790        Ok(pn)
791    }
792
793    pub(crate) fn decode_len(tag: u8) -> usize {
794        1 + (tag & 0x03) as usize
795    }
796
797    fn tag(self) -> u8 {
798        use PacketNumber::*;
799        match self {
800            U8(_) => 0b00,
801            U16(_) => 0b01,
802            U24(_) => 0b10,
803            U32(_) => 0b11,
804        }
805    }
806
807    pub(crate) fn expand(self, expected: u64) -> u64 {
808        // From Appendix A
809        use PacketNumber::*;
810        let truncated = match self {
811            U8(x) => u64::from(x),
812            U16(x) => u64::from(x),
813            U24(x) => u64::from(x),
814            U32(x) => u64::from(x),
815        };
816        let nbits = self.len() * 8;
817        let win = 1 << nbits;
818        let hwin = win / 2;
819        let mask = win - 1;
820        // The incoming packet number should be greater than expected - hwin and less than or equal
821        // to expected + hwin
822        //
823        // This means we can't just strip the trailing bits from expected and add the truncated
824        // because that might yield a value outside the window.
825        //
826        // The following code calculates a candidate value and makes sure it's within the packet
827        // number window.
828        let candidate = (expected & !mask) | truncated;
829        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
830            candidate + win
831        } else if candidate > expected + hwin && candidate > win {
832            candidate - win
833        } else {
834            candidate
835        }
836    }
837}
838
839/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length
840pub struct FixedLengthConnectionIdParser {
841    expected_len: usize,
842}
843
844impl FixedLengthConnectionIdParser {
845    /// Create a new instance of `FixedLengthConnectionIdParser`
846    pub fn new(expected_len: usize) -> Self {
847        Self { expected_len }
848    }
849}
850
851impl ConnectionIdParser for FixedLengthConnectionIdParser {
852    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
853        (buffer.remaining() >= self.expected_len)
854            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
855            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
856    }
857}
858
859/// Parse connection id in short header packet
860pub trait ConnectionIdParser {
861    /// Parse a connection id from given buffer
862    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
863}
864
865/// Long packet type including non-uniform cases
866#[derive(Clone, Copy, Debug, Eq, PartialEq)]
867pub(crate) enum LongHeaderType {
868    Initial,
869    Retry,
870    Standard(LongType),
871}
872
873impl LongHeaderType {
874    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
875        use {LongHeaderType::*, LongType::*};
876        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
877        Ok(match (b & 0x30) >> 4 {
878            0x0 => Initial,
879            0x1 => Standard(ZeroRtt),
880            0x2 => Standard(Handshake),
881            0x3 => Retry,
882            _ => unreachable!(),
883        })
884    }
885}
886
887impl From<LongHeaderType> for u8 {
888    fn from(ty: LongHeaderType) -> Self {
889        use {LongHeaderType::*, LongType::*};
890        match ty {
891            Initial => LONG_HEADER_FORM | FIXED_BIT,
892            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
893            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
894            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
895        }
896    }
897}
898
/// Long packet types with uniform header structure
///
/// Handshake and 0-RTT packets share the same long header layout. Initial and Retry packets,
/// whose layouts differ, are not represented here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet
    Handshake,
    /// 0-RTT packet
    ZeroRtt,
}
907
908/// Packet decode error
909#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
910pub enum PacketDecodeError {
911    /// Packet uses a QUIC version that is not supported
912    #[error("unsupported version {version:x}")]
913    UnsupportedVersion {
914        /// Source Connection ID
915        src_cid: ConnectionId,
916        /// Destination Connection ID
917        dst_cid: ConnectionId,
918        /// The version that was unsupported
919        version: u32,
920    },
921    /// The packet header is invalid
922    #[error("invalid header: {0}")]
923    InvalidHeader(&'static str),
924}
925
926impl From<coding::UnexpectedEnd> for PacketDecodeError {
927    fn from(_: coding::UnexpectedEnd) -> Self {
928        Self::InvalidHeader("unexpected end of packet")
929    }
930}
931
/// Header form bit of the first byte: set for long headers, clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
/// The fixed bit ("QUIC bit"); required to be set unless QUIC bit greasing is enabled.
pub(crate) const FIXED_BIT: u8 = 0x40;
/// Spin bit of a short header.
pub(crate) const SPIN_BIT: u8 = 0x20;
/// Reserved bits of a short header; must be zero once header protection is removed.
const SHORT_RESERVED_BITS: u8 = 0x18;
/// Reserved bits of a long header; must be zero once header protection is removed.
const LONG_RESERVED_BITS: u8 = 0x0c;
/// Key phase bit of a short header; header-protected on the wire.
const KEY_PHASE_BIT: u8 = 0x04;
938
939/// Packet number space identifiers
940#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
941pub(crate) enum SpaceId {
942    /// Unprotected packets, used to bootstrap the handshake
943    Initial = 0,
944    Handshake = 1,
945    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
946    Data = 2,
947}
948
949impl SpaceId {
950    pub(crate) fn iter() -> impl Iterator<Item = Self> {
951        [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
952    }
953
954    /// Returns the next higher packet space.
955    ///
956    /// Returns `None` if at  [`SpaceId::Data`].
957    pub(crate) fn next(&self) -> Option<Self> {
958        match self {
959            Self::Initial => Some(Self::Handshake),
960            Self::Handshake => Some(Self::Data),
961            Self::Data => None,
962        }
963    }
964
965    /// Returns the encryption level for this packet space.
966    pub(crate) fn encryption_level(self) -> EncryptionLevel {
967        match self {
968            Self::Initial => EncryptionLevel::Initial,
969            Self::Handshake => EncryptionLevel::Handshake,
970            Self::Data => EncryptionLevel::OneRtt,
971        }
972    }
973
974    /// Returns the [`SpaceKind`] for this packet space.
975    pub(crate) fn kind(self) -> SpaceKind {
976        match self {
977            Self::Initial => SpaceKind::Initial,
978            Self::Handshake => SpaceKind::Handshake,
979            Self::Data => SpaceKind::Data,
980        }
981    }
982}
983
/// Unit tests for packet number encoding/expansion and, when a rustls crypto
/// provider feature is enabled, an end-to-end Initial header protect/unprotect
/// round trip against a fixed byte vector.
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encodes `typed`, checks that the wire bytes equal `encoded`, then decodes
    /// those bytes with `typed.len()` as the width and checks the round trip.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    /// Each fixed-width variant is written big-endian at its own width,
    /// including boundary values.
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` picks the encoded width. With a second argument of 0
    /// (presumably the largest acknowledged packet number; confirm against
    /// `PacketNumber::new`), 0x10 fits in 1 byte, 0x100 in 2 and 0x10000 in 3.
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// For every pair `expected <= actual < 1024`, truncating `actual` relative
    /// to `expected` and expanding it again with the same `expected` must give
    /// back `actual`.
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Builds a protected Initial packet with client keys, compares it with a
    /// known byte vector, then decodes and decrypts it with server keys.
    #[cfg(all(feature = "rustls", any(feature = "aws-lc-rs", feature = "ring")))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{configured_provider, initial_keys, initial_suite_from_provider};
        use rustls::quic::Version;

        // Initial keys are derived from this fixed destination connection ID.
        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = configured_provider();

        let suite = initial_suite_from_provider(&provider).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        // Initial header with packet number 0, an empty source CID and no token.
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // Append a 16-byte all-zero payload plus room for the AEAD tag.
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        // Seal the payload with the client packet key and apply header
        // protection with the client header key.
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        // Dump the protected packet as hex, which helps when the expected
        // vector below needs to be regenerated.
        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        // Parse the invariant header. The final `false` is `grease_quic_bit`.
        // `.0` drops the trailing-data half of the result.
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        // Remove header protection using the server's view of the client header key.
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // Once unprotected, the first byte (c8 -> c0) and the packet number byte
        // (be -> 00) differ from the protected bytes asserted above.
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        // The decrypted payload must be the 16 zero bytes that were sealed.
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}