noq_proto/
varint.rs

1use std::{convert::TryInto, fmt};
2
3use bytes::{Buf, BufMut};
4use thiserror::Error;
5
6use crate::coding::{self, Decodable, Encodable, UnexpectedEnd};
7
/// An unsigned integer strictly below 2^62
///
/// Every value of this type is suitable for encoding as a QUIC variable-length integer.
// Ideally we could tell Rust that the two most significant bits are free, so that they could be
// reused as enum discriminants
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
15
16impl VarInt {
17    /// The largest representable value
18    pub const MAX: Self = Self((1 << 62) - 1);
19    /// The largest encoded value length
20    pub const MAX_SIZE: usize = 8;
21
22    /// Construct a `VarInt` infallibly
23    pub const fn from_u32(x: u32) -> Self {
24        Self(x as u64)
25    }
26
27    /// Succeeds iff `x` < 2^62
28    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
29        if x < 2u64.pow(62) {
30            Ok(Self(x))
31        } else {
32            Err(VarIntBoundsExceeded)
33        }
34    }
35
36    /// Create a VarInt without ensuring it's in range
37    ///
38    /// # Safety
39    ///
40    /// `x` must be less than 2^62.
41    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
42        Self(x)
43    }
44
45    /// Extract the integer value
46    pub const fn into_inner(self) -> u64 {
47        self.0
48    }
49
50    /// Saturating integer addition. Computes self + rhs, saturating at the numeric bounds instead
51    /// of overflowing.
52    pub fn saturating_add(self, rhs: impl Into<Self>) -> Self {
53        let rhs = rhs.into();
54        let inner = self.0.saturating_add(rhs.0).min(Self::MAX.0);
55        Self(inner)
56    }
57
58    /// Compute the number of bytes needed to encode this value
59    pub(crate) const fn size(self) -> usize {
60        let x = self.0;
61        if x < 2u64.pow(6) {
62            1
63        } else if x < 2u64.pow(14) {
64            2
65        } else if x < 2u64.pow(30) {
66            4
67        } else if x < 2u64.pow(62) {
68            8
69        } else {
70            panic!("malformed VarInt");
71        }
72    }
73}
74
75impl From<VarInt> for u64 {
76    fn from(x: VarInt) -> Self {
77        x.0
78    }
79}
80
81impl From<u8> for VarInt {
82    fn from(x: u8) -> Self {
83        Self(x.into())
84    }
85}
86
87impl From<u16> for VarInt {
88    fn from(x: u16) -> Self {
89        Self(x.into())
90    }
91}
92
93impl From<u32> for VarInt {
94    fn from(x: u32) -> Self {
95        Self(x.into())
96    }
97}
98
99impl std::convert::TryFrom<u64> for VarInt {
100    type Error = VarIntBoundsExceeded;
101    /// Succeeds iff `x` < 2^62
102    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
103        Self::from_u64(x)
104    }
105}
106
107impl std::convert::TryFrom<u128> for VarInt {
108    type Error = VarIntBoundsExceeded;
109    /// Succeeds iff `x` < 2^62
110    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
111        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
112    }
113}
114
115impl std::convert::TryFrom<usize> for VarInt {
116    type Error = VarIntBoundsExceeded;
117    /// Succeeds iff `x` < 2^62
118    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
119        Self::try_from(x as u64)
120    }
121}
122
123impl fmt::Debug for VarInt {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        self.0.fmt(f)
126    }
127}
128
129impl fmt::Display for VarInt {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        self.0.fmt(f)
132    }
133}
134
/// Draws a raw value from the inclusive range `0..=VarInt::MAX` and wraps it
#[cfg(feature = "arbitrary")]
impl<'arbitrary> arbitrary::Arbitrary<'arbitrary> for VarInt {
    fn arbitrary(u: &mut arbitrary::Unstructured<'arbitrary>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=Self::MAX.0).map(Self)
    }
}
141
/// Proptest support: generates values over the inclusive range `0..=VarInt::MAX`
#[cfg(test)]
impl proptest::arbitrary::Arbitrary for VarInt {
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::strategy::Strategy;

        let raw_values = 0..=Self::MAX.0;
        let wrapped = raw_values.prop_map(Self);
        wrapped.boxed()
    }
}
152
/// Strategy producing raw `u64` values that are valid `VarInt` contents (`0..=VarInt::MAX`)
#[cfg(test)]
pub(crate) fn varint_u64() -> impl proptest::strategy::Strategy<Value = u64> {
    std::ops::RangeInclusive::new(0, VarInt::MAX.0)
}
158
159/// Error returned when constructing a `VarInt` from a value >= 2^62
160#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
161#[error("value too large for varint encoding")]
162pub struct VarIntBoundsExceeded;
163
164impl Decodable for VarInt {
165    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
166        if !r.has_remaining() {
167            return Err(UnexpectedEnd);
168        }
169        let mut buf = [0; 8];
170        buf[0] = r.get_u8();
171        let tag = buf[0] >> 6;
172        buf[0] &= 0b0011_1111;
173        let x = match tag {
174            0b00 => u64::from(buf[0]),
175            0b01 => {
176                if r.remaining() < 1 {
177                    return Err(UnexpectedEnd);
178                }
179                r.copy_to_slice(&mut buf[1..2]);
180                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
181            }
182            0b10 => {
183                if r.remaining() < 3 {
184                    return Err(UnexpectedEnd);
185                }
186                r.copy_to_slice(&mut buf[1..4]);
187                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
188            }
189            0b11 => {
190                if r.remaining() < 7 {
191                    return Err(UnexpectedEnd);
192                }
193                r.copy_to_slice(&mut buf[1..8]);
194                u64::from_be_bytes(buf)
195            }
196            _ => unreachable!(),
197        };
198        Ok(Self(x))
199    }
200}
201
202impl Encodable for VarInt {
203    fn encode<B: BufMut>(&self, w: &mut B) {
204        let x = self.0;
205        if x < 2u64.pow(6) {
206            w.put_u8(x as u8);
207        } else if x < 2u64.pow(14) {
208            w.put_u16((0b01 << 14) | x as u16);
209        } else if x < 2u64.pow(30) {
210            w.put_u32((0b10 << 30) | x as u32);
211        } else if x < 2u64.pow(62) {
212            w.put_u64((0b11 << 62) | x);
213        } else {
214            unreachable!("malformed VarInt")
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// `saturating_add` adds normally within range and clamps at `VarInt::MAX`
    #[test]
    fn test_saturating_add() {
        // In range: u32::MAX + 1 is still far below 2^62, so the sum is exact.
        let start: VarInt = u32::MAX.into();
        let expected = VarInt::from_u64(u64::from(u32::MAX) + 1).unwrap();
        let sum = start.saturating_add(1u8);
        assert_eq!(sum, expected);

        // Out of range: going past the maximum stays pinned at the maximum.
        let clamped = VarInt::MAX.saturating_add(1u8);
        assert_eq!(clamped, VarInt::MAX);
    }
}
233}