iroh_quinn_proto/
packet.rs

1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7    ConnectionId, PathId,
8    coding::{self, BufExt, BufMutExt},
9    crypto,
10};
11
/// Decodes a QUIC packet's invariant header
///
/// Due to packet number encryption, it is impossible to fully decode a header
/// (which includes a variable-length packet number) without crypto context.
/// The crypto context (the header protection key, see [`crypto::HeaderKey`]) is usually
/// part of the `Connection`, or can be derived from the destination CID for
/// Initial packets.
///
/// To cope with this, we decode the invariant header (which should be stable
/// across QUIC versions), which gives us the destination CID and allows us
/// to inspect the version and packet type (which depends on the version).
/// This information allows us to fully decode and decrypt the packet.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    /// Header fields decoded so far; the packet number is still header-protected
    plain_header: ProtectedHeader,
    /// The bytes of this single packet, positioned just past the decoded header fields
    /// (i.e. at the protected packet number, for packet types that have one)
    buf: io::Cursor<BytesMut>,
}
30
31#[allow(clippy::len_without_is_empty)]
32impl PartialDecode {
33    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
34    pub fn new(
35        bytes: BytesMut,
36        cid_parser: &(impl ConnectionIdParser + ?Sized),
37        supported_versions: &[u32],
38        grease_quic_bit: bool,
39    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
40        let mut buf = io::Cursor::new(bytes);
41        let plain_header =
42            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
43        let dgram_len = buf.get_ref().len();
44        let packet_len = plain_header
45            .payload_len()
46            .map(|len| (buf.position() + len) as usize)
47            .unwrap_or(dgram_len);
48        match dgram_len.cmp(&packet_len) {
49            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
50            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
51                "packet too short to contain payload length",
52            )),
53            Ordering::Greater => {
54                let rest = Some(buf.get_mut().split_off(packet_len));
55                Ok((Self { plain_header, buf }, rest))
56            }
57        }
58    }
59
60    /// The underlying partially-decoded packet data
61    pub(crate) fn data(&self) -> &[u8] {
62        self.buf.get_ref()
63    }
64
65    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
66        self.plain_header.as_initial()
67    }
68
69    pub(crate) fn has_long_header(&self) -> bool {
70        !matches!(self.plain_header, ProtectedHeader::Short { .. })
71    }
72
73    pub(crate) fn is_initial(&self) -> bool {
74        self.space() == Some(SpaceId::Initial)
75    }
76
77    pub(crate) fn space(&self) -> Option<SpaceId> {
78        use ProtectedHeader::*;
79        match self.plain_header {
80            Initial { .. } => Some(SpaceId::Initial),
81            Long {
82                ty: LongType::Handshake,
83                ..
84            } => Some(SpaceId::Handshake),
85            Long {
86                ty: LongType::ZeroRtt,
87                ..
88            } => Some(SpaceId::Data),
89            Short { .. } => Some(SpaceId::Data),
90            _ => None,
91        }
92    }
93
94    pub(crate) fn is_0rtt(&self) -> bool {
95        match self.plain_header {
96            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
97            _ => false,
98        }
99    }
100
101    /// The destination connection ID of the packet
102    pub fn dst_cid(&self) -> ConnectionId {
103        self.plain_header.dst_cid()
104    }
105
106    /// Length of QUIC packet being decoded
107    #[allow(unreachable_pub)] // fuzzing only
108    pub fn len(&self) -> usize {
109        self.buf.get_ref().len()
110    }
111
112    pub(crate) fn finish(
113        self,
114        header_crypto: Option<&dyn crypto::HeaderKey>,
115    ) -> Result<Packet, PacketDecodeError> {
116        use ProtectedHeader::*;
117        let Self {
118            plain_header,
119            mut buf,
120        } = self;
121
122        if let Initial(ProtectedInitialHeader {
123            dst_cid,
124            src_cid,
125            token_pos,
126            version,
127            ..
128        }) = plain_header
129        {
130            let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
131            let header_len = buf.position() as usize;
132            let mut bytes = buf.into_inner();
133
134            let header_data = bytes.split_to(header_len).freeze();
135            let token = header_data.slice(token_pos.start..token_pos.end);
136            return Ok(Packet {
137                header: Header::Initial(InitialHeader {
138                    dst_cid,
139                    src_cid,
140                    token,
141                    number,
142                    version,
143                }),
144                header_data,
145                payload: bytes,
146            });
147        }
148
149        let header = match plain_header {
150            Long {
151                ty,
152                dst_cid,
153                src_cid,
154                version,
155                ..
156            } => Header::Long {
157                ty,
158                dst_cid,
159                src_cid,
160                number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
161                version,
162            },
163            Retry {
164                dst_cid,
165                src_cid,
166                version,
167            } => Header::Retry {
168                dst_cid,
169                src_cid,
170                version,
171            },
172            Short { spin, dst_cid, .. } => {
173                let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
174                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
175                Header::Short {
176                    spin,
177                    key_phase,
178                    dst_cid,
179                    number,
180                }
181            }
182            VersionNegotiate {
183                random,
184                dst_cid,
185                src_cid,
186            } => Header::VersionNegotiate {
187                random,
188                dst_cid,
189                src_cid,
190            },
191            Initial { .. } => unreachable!(),
192        };
193
194        let header_len = buf.position() as usize;
195        let mut bytes = buf.into_inner();
196        Ok(Packet {
197            header,
198            header_data: bytes.split_to(header_len).freeze(),
199            payload: bytes,
200        })
201    }
202
203    fn decrypt_header(
204        buf: &mut io::Cursor<BytesMut>,
205        header_crypto: &dyn crypto::HeaderKey,
206    ) -> Result<PacketNumber, PacketDecodeError> {
207        let packet_length = buf.get_ref().len();
208        let pn_offset = buf.position() as usize;
209        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
210            return Err(PacketDecodeError::InvalidHeader(
211                "packet too short to extract header protection sample",
212            ));
213        }
214
215        header_crypto.decrypt(pn_offset, buf.get_mut());
216
217        let len = PacketNumber::decode_len(buf.get_ref()[0]);
218        PacketNumber::decode(len, buf)
219    }
220}
221
/// A buffer that can tell how much has been written to it already
///
/// This is used when a buffer is passed and the user must not write past a
/// given size. It allows the user of such a buffer to know the current cursor position in
/// the buffer. The maximum write size is usually passed in the same unit as
/// [`BufLen::len`]: bytes since the buffer start.
pub(crate) trait BufLen {
    /// Returns the number of bytes written into the buffer so far
    fn len(&self) -> usize;
}
232
233impl BufLen for Vec<u8> {
234    fn len(&self) -> usize {
235        self.len()
236    }
237}
238
/// A packet whose header protection has been removed, split into header and payload
pub(crate) struct Packet {
    /// The decoded header
    pub(crate) header: Header,
    /// The raw, unprotected header bytes (including the packet number, if any)
    pub(crate) header_data: Bytes,
    /// Everything after the header; still packet-encrypted for protected packet types
    pub(crate) payload: BytesMut,
}
244
245impl Packet {
246    pub(crate) fn reserved_bits_valid(&self) -> bool {
247        let mask = match self.header {
248            Header::Short { .. } => SHORT_RESERVED_BITS,
249            _ => LONG_RESERVED_BITS,
250        };
251        self.header_data[0] & mask == 0
252    }
253}
254
/// A [`Packet`] statically known to be an Initial packet
pub(crate) struct InitialPacket {
    /// The decoded Initial header
    pub(crate) header: InitialHeader,
    /// The raw, unprotected header bytes
    pub(crate) header_data: Bytes,
    /// Everything after the header
    pub(crate) payload: BytesMut,
}
260
261impl From<InitialPacket> for Packet {
262    fn from(x: InitialPacket) -> Self {
263        Self {
264            header: Header::Initial(x.header),
265            header_data: x.header_data,
266            payload: x.payload,
267        }
268    }
269}
270
/// A packet header with the packet number (if any) in the clear
///
/// Produced by [`PartialDecode::finish`] when receiving and serialized by
/// [`Header::encode`] when sending.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// Initial packet header
    Initial(InitialHeader),
    /// Handshake or 0-RTT packet header
    Long {
        /// Which uniform long packet type this is
        ty: LongType,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// Truncated packet number
        number: PacketNumber,
        /// QUIC version
        version: u32,
    },
    /// Retry packet header (no packet number)
    Retry {
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// QUIC version
        version: u32,
    },
    /// Short (1-RTT) packet header
    Short {
        /// Spin bit
        spin: bool,
        /// Key phase bit
        key_phase: bool,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Truncated packet number
        number: PacketNumber,
    },
    /// Version Negotiation packet header (no packet number)
    VersionNegotiate {
        /// Low seven bits of the first byte
        random: u8,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// Destination Connection ID
        dst_cid: ConnectionId,
    },
}
299
300impl Header {
301    pub(crate) fn encode(&self, w: &mut (impl BufMut + BufLen)) -> PartialEncode {
302        use Header::*;
303        let start = w.len();
304        match *self {
305            Initial(InitialHeader {
306                ref dst_cid,
307                ref src_cid,
308                ref token,
309                number,
310                version,
311            }) => {
312                w.write(u8::from(LongHeaderType::Initial) | number.tag());
313                w.write(version);
314                dst_cid.encode_long(w);
315                src_cid.encode_long(w);
316                w.write_var(token.len() as u64);
317                w.put_slice(token);
318                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
319                number.encode(w);
320                PartialEncode {
321                    start,
322                    header_len: w.len() - start,
323                    pn: Some((number.len(), true)),
324                }
325            }
326            Long {
327                ty,
328                ref dst_cid,
329                ref src_cid,
330                number,
331                version,
332            } => {
333                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
334                w.write(version);
335                dst_cid.encode_long(w);
336                src_cid.encode_long(w);
337                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
338                number.encode(w);
339                PartialEncode {
340                    start,
341                    header_len: w.len() - start,
342                    pn: Some((number.len(), true)),
343                }
344            }
345            Retry {
346                ref dst_cid,
347                ref src_cid,
348                version,
349            } => {
350                w.write(u8::from(LongHeaderType::Retry));
351                w.write(version);
352                dst_cid.encode_long(w);
353                src_cid.encode_long(w);
354                PartialEncode {
355                    start,
356                    header_len: w.len() - start,
357                    pn: None,
358                }
359            }
360            Short {
361                spin,
362                key_phase,
363                ref dst_cid,
364                number,
365            } => {
366                w.write(
367                    FIXED_BIT
368                        | if key_phase { KEY_PHASE_BIT } else { 0 }
369                        | if spin { SPIN_BIT } else { 0 }
370                        | number.tag(),
371                );
372                w.put_slice(dst_cid);
373                number.encode(w);
374                PartialEncode {
375                    start,
376                    header_len: w.len() - start,
377                    pn: Some((number.len(), false)),
378                }
379            }
380            VersionNegotiate {
381                ref random,
382                ref dst_cid,
383                ref src_cid,
384            } => {
385                w.write(0x80u8 | random);
386                w.write::<u32>(0);
387                dst_cid.encode_long(w);
388                src_cid.encode_long(w);
389                PartialEncode {
390                    start,
391                    header_len: w.len() - start,
392                    pn: None,
393                }
394            }
395        }
396    }
397
398    /// Whether the packet is encrypted on the wire
399    pub(crate) fn is_protected(&self) -> bool {
400        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
401    }
402
403    pub(crate) fn number(&self) -> Option<PacketNumber> {
404        use Header::*;
405        Some(match *self {
406            Initial(InitialHeader { number, .. }) => number,
407            Long { number, .. } => number,
408            Short { number, .. } => number,
409            _ => {
410                return None;
411            }
412        })
413    }
414
415    pub(crate) fn space(&self) -> SpaceId {
416        use Header::*;
417        match *self {
418            Short { .. } => SpaceId::Data,
419            Long {
420                ty: LongType::ZeroRtt,
421                ..
422            } => SpaceId::Data,
423            Long {
424                ty: LongType::Handshake,
425                ..
426            } => SpaceId::Handshake,
427            _ => SpaceId::Initial,
428        }
429    }
430
431    pub(crate) fn key_phase(&self) -> bool {
432        match *self {
433            Self::Short { key_phase, .. } => key_phase,
434            _ => false,
435        }
436    }
437
438    pub(crate) fn is_short(&self) -> bool {
439        matches!(*self, Self::Short { .. })
440    }
441
442    pub(crate) fn is_1rtt(&self) -> bool {
443        self.is_short()
444    }
445
446    pub(crate) fn is_0rtt(&self) -> bool {
447        matches!(
448            *self,
449            Self::Long {
450                ty: LongType::ZeroRtt,
451                ..
452            }
453        )
454    }
455
456    pub(crate) fn dst_cid(&self) -> ConnectionId {
457        use Header::*;
458        match *self {
459            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
460            Long { dst_cid, .. } => dst_cid,
461            Retry { dst_cid, .. } => dst_cid,
462            Short { dst_cid, .. } => dst_cid,
463            VersionNegotiate { dst_cid, .. } => dst_cid,
464        }
465    }
466
467    /// Whether the payload of this packet contains QUIC frames
468    pub(crate) fn has_frames(&self) -> bool {
469        use Header::*;
470        match *self {
471            Initial(_) => true,
472            Long { .. } => true,
473            Retry { .. } => false,
474            Short { .. } => true,
475            VersionNegotiate { .. } => false,
476        }
477    }
478
479    #[cfg(feature = "qlog")]
480    pub(crate) fn src_cid(&self) -> Option<ConnectionId> {
481        match self {
482            Self::Initial(initial_header) => Some(initial_header.src_cid),
483            Self::Long { src_cid, .. } => Some(*src_cid),
484            Self::Retry { src_cid, .. } => Some(*src_cid),
485            Self::Short { .. } => None,
486            Self::VersionNegotiate { src_cid, .. } => Some(*src_cid),
487        }
488    }
489}
490
/// A header written by [`Header::encode`] whose length field and protection are still pending
pub(crate) struct PartialEncode {
    /// Offset in the output buffer at which the header starts
    pub(crate) start: usize,
    /// Length in bytes of the encoded header, including the packet number
    pub(crate) header_len: usize,
    /// Packet number length, and whether a payload length field must be patched in;
    /// `None` for packets without a packet number (Retry, Version Negotiation)
    pn: Option<(usize, bool)>,
}
497
498impl PartialEncode {
499    pub(crate) fn finish(
500        self,
501        buf: &mut [u8],
502        header_crypto: &dyn crypto::HeaderKey,
503        crypto: Option<(u64, PathId, &dyn crypto::PacketKey)>,
504    ) {
505        let Self { header_len, pn, .. } = self;
506        let (pn_len, write_len) = match pn {
507            Some((pn_len, write_len)) => (pn_len, write_len),
508            None => return,
509        };
510
511        let pn_pos = header_len - pn_len;
512        if write_len {
513            let len = buf.len() - header_len + pn_len;
514            assert!(len < 2usize.pow(14)); // Fits in reserved space
515            let mut slice = &mut buf[pn_pos - 2..pn_pos];
516            slice.put_u16(len as u16 | (0b01 << 14));
517        }
518
519        if let Some((packet_number, path_id, crypto)) = crypto {
520            crypto.encrypt(path_id, packet_number, buf, header_len);
521        }
522
523        debug_assert!(
524            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
525            "packet must be padded to at least {} bytes for header protection sampling",
526            pn_pos + 4 + header_crypto.sample_size()
527        );
528        header_crypto.encrypt(pn_pos, buf);
529    }
530
531    #[cfg(test)]
532    pub(crate) fn dummy() -> PartialEncode {
533        PartialEncode {
534            start: 0,
535            header_len: 2,
536            pn: Some((1, true)),
537        }
538    }
539}
540
/// Plain packet header
///
/// The fields that can be read before header protection is removed: everything except
/// the packet number (and, for short headers, the key phase bit).
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header
    Initial(ProtectedInitialHeader),
    /// A Long packet header, as used during the handshake
    Long {
        /// Type of the Long header packet
        ty: LongType,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// Value of the Length field: packet number plus payload, in bytes
        len: u64,
        /// QUIC version
        version: u32,
    },
    /// A Retry packet header
    Retry {
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
        /// QUIC version
        version: u32,
    },
    /// A short packet header, as used during the data phase
    Short {
        /// Spin bit
        spin: bool,
        /// Destination Connection ID
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation packet header
    VersionNegotiate {
        /// Random value (the low seven bits of the first byte)
        random: u8,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// Source Connection ID
        src_cid: ConnectionId,
    },
}
585
586impl ProtectedHeader {
587    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
588        match self {
589            Self::Initial(x) => Some(x),
590            _ => None,
591        }
592    }
593
594    /// The destination Connection ID of the packet
595    pub fn dst_cid(&self) -> ConnectionId {
596        use ProtectedHeader::*;
597        match self {
598            Initial(header) => header.dst_cid,
599            &Long { dst_cid, .. } => dst_cid,
600            &Retry { dst_cid, .. } => dst_cid,
601            &Short { dst_cid, .. } => dst_cid,
602            &VersionNegotiate { dst_cid, .. } => dst_cid,
603        }
604    }
605
606    fn payload_len(&self) -> Option<u64> {
607        use ProtectedHeader::*;
608        match self {
609            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
610            _ => None,
611        }
612    }
613
614    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
615    pub fn decode(
616        buf: &mut io::Cursor<BytesMut>,
617        cid_parser: &(impl ConnectionIdParser + ?Sized),
618        supported_versions: &[u32],
619        grease_quic_bit: bool,
620    ) -> Result<Self, PacketDecodeError> {
621        let first = buf.get::<u8>()?;
622        if !grease_quic_bit && first & FIXED_BIT == 0 {
623            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
624        }
625        if first & LONG_HEADER_FORM == 0 {
626            let spin = first & SPIN_BIT != 0;
627
628            Ok(Self::Short {
629                spin,
630                dst_cid: cid_parser.parse(buf)?,
631            })
632        } else {
633            let version = buf.get::<u32>()?;
634
635            let dst_cid = ConnectionId::decode_long(buf)
636                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
637            let src_cid = ConnectionId::decode_long(buf)
638                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
639
640            // TODO: Support long CIDs for compatibility with future QUIC versions
641            if version == 0 {
642                let random = first & !LONG_HEADER_FORM;
643                return Ok(Self::VersionNegotiate {
644                    random,
645                    dst_cid,
646                    src_cid,
647                });
648            }
649
650            if !supported_versions.contains(&version) {
651                return Err(PacketDecodeError::UnsupportedVersion {
652                    src_cid,
653                    dst_cid,
654                    version,
655                });
656            }
657
658            match LongHeaderType::from_byte(first)? {
659                LongHeaderType::Initial => {
660                    let token_len = buf.get_var()? as usize;
661                    let token_start = buf.position() as usize;
662                    if token_len > buf.remaining() {
663                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
664                    }
665                    buf.advance(token_len);
666
667                    let len = buf.get_var()?;
668                    Ok(Self::Initial(ProtectedInitialHeader {
669                        dst_cid,
670                        src_cid,
671                        token_pos: token_start..token_start + token_len,
672                        len,
673                        version,
674                    }))
675                }
676                LongHeaderType::Retry => Ok(Self::Retry {
677                    dst_cid,
678                    src_cid,
679                    version,
680                }),
681                LongHeaderType::Standard(ty) => Ok(Self::Long {
682                    ty,
683                    dst_cid,
684                    src_cid,
685                    len: buf.get_var()?,
686                    version,
687                }),
688            }
689        }
690    }
691}
692
/// Header of an Initial packet, before decryption
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination Connection ID
    pub dst_cid: ConnectionId,
    /// Source Connection ID
    pub src_cid: ConnectionId,
    /// The position of a token in the packet buffer, as a byte range from the start of the packet
    pub token_pos: Range<usize>,
    /// Value of the Length field: packet number plus payload, in bytes
    pub len: u64,
    /// QUIC version
    pub version: u32,
}
707
/// A fully decoded Initial packet header
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    /// Destination Connection ID
    pub(crate) dst_cid: ConnectionId,
    /// Source Connection ID
    pub(crate) src_cid: ConnectionId,
    /// Address validation token (may be empty)
    pub(crate) token: Bytes,
    /// Truncated packet number
    pub(crate) number: PacketNumber,
    /// QUIC version
    pub(crate) version: u32,
}
716
/// A truncated packet number as encoded on the wire, tagged with its encoded width
///
/// Use [`PacketNumber::expand`] to recover the full 62-bit packet number.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding
    U8(u8),
    /// Two-byte encoding
    U16(u16),
    /// Three-byte encoding, held in a `u32`
    U24(u32),
    /// Four-byte encoding
    U32(u32),
}
725
726impl PacketNumber {
727    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
728        let range = (n - largest_acked) * 2;
729        if range < 1 << 8 {
730            Self::U8(n as u8)
731        } else if range < 1 << 16 {
732            Self::U16(n as u16)
733        } else if range < 1 << 24 {
734            Self::U24(n as u32)
735        } else if range < 1 << 32 {
736            Self::U32(n as u32)
737        } else {
738            panic!("packet number too large to encode")
739        }
740    }
741
742    pub(crate) fn len(self) -> usize {
743        use PacketNumber::*;
744        match self {
745            U8(_) => 1,
746            U16(_) => 2,
747            U24(_) => 3,
748            U32(_) => 4,
749        }
750    }
751
752    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
753        use PacketNumber::*;
754        match self {
755            U8(x) => w.write(x),
756            U16(x) => w.write(x),
757            U24(x) => w.put_uint(u64::from(x), 3),
758            U32(x) => w.write(x),
759        }
760    }
761
762    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
763        use PacketNumber::*;
764        let pn = match len {
765            1 => U8(r.get()?),
766            2 => U16(r.get()?),
767            3 => U24(r.get_uint(3) as u32),
768            4 => U32(r.get()?),
769            _ => unreachable!(),
770        };
771        Ok(pn)
772    }
773
774    pub(crate) fn decode_len(tag: u8) -> usize {
775        1 + (tag & 0x03) as usize
776    }
777
778    fn tag(self) -> u8 {
779        use PacketNumber::*;
780        match self {
781            U8(_) => 0b00,
782            U16(_) => 0b01,
783            U24(_) => 0b10,
784            U32(_) => 0b11,
785        }
786    }
787
788    pub(crate) fn expand(self, expected: u64) -> u64 {
789        // From Appendix A
790        use PacketNumber::*;
791        let truncated = match self {
792            U8(x) => u64::from(x),
793            U16(x) => u64::from(x),
794            U24(x) => u64::from(x),
795            U32(x) => u64::from(x),
796        };
797        let nbits = self.len() * 8;
798        let win = 1 << nbits;
799        let hwin = win / 2;
800        let mask = win - 1;
801        // The incoming packet number should be greater than expected - hwin and less than or equal
802        // to expected + hwin
803        //
804        // This means we can't just strip the trailing bits from expected and add the truncated
805        // because that might yield a value outside the window.
806        //
807        // The following code calculates a candidate value and makes sure it's within the packet
808        // number window.
809        let candidate = (expected & !mask) | truncated;
810        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
811            candidate + win
812        } else if candidate > expected + hwin && candidate > win {
813            candidate - win
814        } else {
815            candidate
816        }
817    }
818}
819
/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length
///
/// Appropriate when every connection ID this endpoint issues has the same length, because
/// a short header gives no length for its destination connection ID.
pub struct FixedLengthConnectionIdParser {
    /// Number of bytes read as the connection ID
    expected_len: usize,
}

impl FixedLengthConnectionIdParser {
    /// Create a parser that reads exactly `expected_len` bytes as the connection ID
    pub fn new(expected_len: usize) -> Self {
        FixedLengthConnectionIdParser { expected_len }
    }
}
831
832impl ConnectionIdParser for FixedLengthConnectionIdParser {
833    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
834        (buffer.remaining() >= self.expected_len)
835            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
836            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
837    }
838}
839
/// Parse connection id in short header packet
///
/// Short headers do not encode the length of the destination connection ID, so the
/// receiving endpoint must know how to delimit the IDs it issued.
pub trait ConnectionIdParser {
    /// Parse a connection id from given buffer
    ///
    /// Implementations advance `buf` past the connection ID and return an error (for example
    /// [`PacketDecodeError::InvalidHeader`]) when the buffer is too short.
    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
}
845
/// Long packet type including non-uniform cases
///
/// Initial and Retry headers have type-specific fields; the remaining types share one
/// layout and are grouped under [`LongHeaderType::Standard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongHeaderType {
    /// Initial packet (carries a token)
    Initial,
    /// Retry packet (no packet number)
    Retry,
    /// Handshake or 0-RTT packet
    Standard(LongType),
}
853
854impl LongHeaderType {
855    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
856        use {LongHeaderType::*, LongType::*};
857        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
858        Ok(match (b & 0x30) >> 4 {
859            0x0 => Initial,
860            0x1 => Standard(ZeroRtt),
861            0x2 => Standard(Handshake),
862            0x3 => Retry,
863            _ => unreachable!(),
864        })
865    }
866}
867
868impl From<LongHeaderType> for u8 {
869    fn from(ty: LongHeaderType) -> Self {
870        use {LongHeaderType::*, LongType::*};
871        match ty {
872            Initial => LONG_HEADER_FORM | FIXED_BIT,
873            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
874            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
875            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
876        }
877    }
878}
879
/// Long packet types with uniform header structure
///
/// These share the same long-header layout: version, both CIDs, a Length field, and a
/// packet number.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet
    Handshake,
    /// 0-RTT packet
    ZeroRtt,
}
888
/// Packet decode error
///
/// Returned while decoding the invariant header or removing header protection.
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum PacketDecodeError {
    /// Packet uses a QUIC version that is not supported
    ///
    /// Carries both connection IDs so that a Version Negotiation response can be built.
    #[error("unsupported version {version:x}")]
    UnsupportedVersion {
        /// Source Connection ID
        src_cid: ConnectionId,
        /// Destination Connection ID
        dst_cid: ConnectionId,
        /// The version that was unsupported
        version: u32,
    },
    /// The packet header is invalid
    ///
    /// The string describes which check failed.
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
}
906
907impl From<coding::UnexpectedEnd> for PacketDecodeError {
908    fn from(_: coding::UnexpectedEnd) -> Self {
909        Self::InvalidHeader("unexpected end of packet")
910    }
911}
912
/// First-byte bit that selects the long header form
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// First-byte fixed bit; a clear bit is rejected on decode unless QUIC bit greasing is allowed
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Spin bit of a short header
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Reserved bits of a short-header first byte; must be zero once unprotected
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Reserved bits of a long-header first byte; must be zero once unprotected
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Key phase bit of a short header
const KEY_PHASE_BIT: u8 = 0b0000_0100;
919
/// Packet number space identifiers
///
/// Variants are ordered by handshake progression, so they compare with `<` in that order.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Unprotected packets, used to bootstrap the handshake
    Initial = 0,
    /// Handshake packets
    Handshake = 1,
    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
    Data = 2,
}
929
930impl SpaceId {
931    pub fn iter() -> impl Iterator<Item = Self> {
932        [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
933    }
934
935    /// Returns the next higher packet space.
936    ///
937    /// Keeps returning [`SpaceId::Data`] as the highest space.
938    pub fn next(&self) -> Self {
939        match self {
940            Self::Initial => Self::Handshake,
941            Self::Handshake => Self::Data,
942            Self::Data => Self::Data,
943        }
944    }
945}
946
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encode `typed`, compare it with the expected wire bytes, then decode it back and
    /// check that the round trip is lossless.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` picks the narrowest width covering twice the unacked range.
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Truncating with `new` and reconstructing with `expand` must give back the original
    /// packet number for every pair in the tested range.
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Protect an Initial header with the client's initial keys, compare it with a fixed
    /// expected byte string, then decode and decrypt it with the server's keys.
    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // 16 zero bytes of payload plus room for the AEAD tag.
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, PathId::ZERO, &*client.packet.local)),
        );

        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // After header protection is removed, the first byte and packet number are in the clear.
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(PathId::ZERO, 0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}