iroh_quinn_proto/connection/send_buffer.rs

use std::{collections::VecDeque, ops::Range};

use bytes::{Buf, BufMut, Bytes};

use crate::{VarInt, range_set::ArrayRangeSet};

/// Buffer of outgoing retransmittable stream data
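///
/// A minimal sketch of the intended flow, assuming `max_len` and `out` are provided by the
/// framing code driving this buffer:
///
/// ```ignore
/// let mut send = SendBuffer::new();
/// send.write(Bytes::from_static(b"hello world"));
/// let (range, _encode_length) = send.poll_transmit(max_len);
/// send.get_into(range.clone(), &mut out); // copy the payload into the STREAM frame
/// send.ack(range);                        // discard once the peer acknowledges it
/// // or, if the packet carrying it is deemed lost:
/// // send.retransmit(range);
/// ```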
#[derive(Default, Debug)]
pub(super) struct SendBuffer {
    /// Data queued by the application that has to be retained for resends.
    ///
    /// Only data up to the highest contiguous acknowledged offset can be discarded.
    /// We could discard acknowledged data within this buffer, but it would require a more
    /// complex data structure. Instead, we track acknowledged ranges in `acks`.
    ///
    /// `data` keeps track of the base offset of the buffered data.
    data: SendBufferData,
    /// The first offset that hasn't been sent even once
    ///
    /// Always lies in `data.range()`
    unsent: u64,
    /// Acknowledged ranges which couldn't be discarded yet as they don't include the earliest
    /// offset still held in `data`
    ///
    /// All ranges must be within `data.range().start..unsent`, since data that has never been
    /// sent can't be acknowledged.
    // TODO: Recover storage from these by compacting (#700)
    acks: ArrayRangeSet,
    /// Previously transmitted ranges deemed lost and marked for retransmission
    ///
    /// All ranges must be within `data.range().start..unsent`, since data that has never been
    /// sent can't be retransmitted.
    ///
    /// This should usually not overlap with `acks`, but this is not strictly enforced.
    retransmits: ArrayRangeSet,
}

/// This is where the data of the send buffer lives. It supports appending at the end,
/// removing from the front, and retrieving data by range.
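///
/// A short sketch of the offset bookkeeping (offsets are absolute stream offsets):
///
/// ```ignore
/// let mut data = SendBufferData::default();
/// data.append(Bytes::from_static(b"hello"));
/// data.append(Bytes::from_static(b" world"));
/// assert_eq!(data.range(), 0..11);
/// data.pop_front(5);               // discard "hello"
/// assert_eq!(data.range(), 5..11); // the base offset keeps counting from the stream start
/// ```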
#[derive(Default, Debug)]
struct SendBufferData {
    /// Start offset of the buffered data
    offset: u64,
    /// Buffered data segments
    segments: VecDeque<Bytes>,
    /// Total size of the data in `segments`
    len: usize,
}

impl SendBufferData {
    /// Total size of buffered data
    fn len(&self) -> usize {
        self.len
    }

    /// Range of buffered data
    #[inline(always)]
    fn range(&self) -> Range<u64> {
        self.offset..self.offset + self.len as u64
    }

    /// Append data to the end of the buffer
    fn append(&mut self, data: Bytes) {
        self.len += data.len();
        self.segments.push_back(data);
    }

    /// Discard data from the front of the buffer
    ///
    /// Calling this with `n > len()` is allowed and will simply clear the buffer.
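    ///
    /// ```ignore
    /// // sketch, assuming 4 bytes are buffered starting at offset 0:
    /// data.pop_front(10);             // n > len() just clears the buffer
    /// assert_eq!(data.len(), 0);
    /// assert_eq!(data.range(), 4..4); // the offset only advances past what was buffered
    /// ```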
    fn pop_front(&mut self, n: usize) {
        let mut n = n.min(self.len);
        self.len -= n;
        self.offset += n as u64;
        while n > 0 {
            let front = self.segments.front_mut().expect("Expected buffered data");

            if front.len() <= n {
                // Remove the whole front segment
                n -= front.len();
                self.segments.pop_front();
            } else {
                // Advance within the front segment
                front.advance(n);
                n = 0;
            }
        }
        if self.segments.len() * 4 < self.segments.capacity() {
            self.segments.shrink_to_fit();
        }
    }

    /// Returns data which is associated with a range
    ///
    /// May return only a prefix of the requested range if it crosses a segment boundary.
    /// Requesting a range outside of the buffered data will panic.
    #[cfg(any(test, feature = "bench"))]
    fn get(&self, offsets: Range<u64>) -> &[u8] {
        assert!(
            offsets.start >= self.range().start && offsets.end <= self.range().end,
            "Requested range is outside of buffered data"
        );
        // translate to offsets relative to the start of the buffered data, as usize
        let offsets = Range {
            start: (offsets.start - self.offset) as usize,
            end: (offsets.end - self.offset) as usize,
        };
        let mut segment_offset = 0;
        for segment in self.segments.iter() {
            if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() {
                let start = offsets.start - segment_offset;
                let end = offsets.end - segment_offset;

                return &segment[start..end.min(segment.len())];
            }
            segment_offset += segment.len();
        }

        unreachable!("impossible if segments and range are consistent");
    }

    /// Copy the data for the given range of stream offsets into `buf`
    ///
    /// Unlike `get`, this crosses segment boundaries and always yields the full range.
    /// Requesting a range outside of the buffered data will panic.
    fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
        assert!(
            offsets.start >= self.range().start && offsets.end <= self.range().end,
            "Requested range is outside of buffered data"
        );
        // translate to offsets relative to the start of the buffered data, as usize
        let offsets = Range {
            start: (offsets.start - self.offset) as usize,
            end: (offsets.end - self.offset) as usize,
        };
        let mut segment_offset = 0;
        for segment in self.segments.iter() {
            // intersect the segment's range with the requested range
            let start = segment_offset.max(offsets.start);
            let end = (segment_offset + segment.len()).min(offsets.end);
            if start < end {
                // the segment overlaps the requested range
                buf.put_slice(&segment[start - segment_offset..end - segment_offset]);
            }
            segment_offset += segment.len();
            if segment_offset >= offsets.end {
                // we are beyond the requested range
                break;
            }
        }
    }

    #[cfg(test)]
    fn to_vec(&self) -> Vec<u8> {
        let mut result = Vec::with_capacity(self.len);
        for segment in self.segments.iter() {
            result.extend_from_slice(&segment[..]);
        }
        result
    }
}

impl SendBuffer {
    /// Construct an empty buffer at the initial offset
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Append application data to the end of the stream
    pub(super) fn write(&mut self, data: Bytes) {
        self.data.append(data);
    }

    /// Discard a range of acknowledged stream data
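    ///
    /// Data is only discarded once the acknowledged prefix is contiguous. A sketch, assuming
    /// bytes `0..10` have been sent:
    ///
    /// ```ignore
    /// buf.ack(5..10); // everything is retained; 0..5 is still outstanding
    /// buf.ack(0..5);  // now the whole prefix 0..10 is discarded
    /// ```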
    pub(super) fn ack(&mut self, mut range: Range<u64>) {
        // Clamp the range to data which is still tracked
        let base_offset = self.data.range().start;
        range.start = base_offset.max(range.start);
        range.end = base_offset.max(range.end);

        self.acks.insert(range);

        while self.acks.min() == Some(self.data.range().start) {
            let prefix = self.acks.pop_min().unwrap();
            let to_advance = (prefix.end - prefix.start) as usize;
            self.data.pop_front(to_advance);
        }
    }

    /// Compute the next range to transmit on this stream and update state to account for that
    /// transmission.
    ///
    /// `max_len` here covers not just the payload but also the space needed to encode the
    /// offset and length of the data to send. The caller has to guarantee that there is at
    /// least enough space available to write maximum-sized metadata
    /// (8 byte offset + 8 byte length).
    ///
    /// The method returns a tuple:
    /// - The first element is the range of data to send
    /// - The second element indicates whether the length needs to be encoded
    ///   in the STREAM frame's metadata (`true`), or whether it can be omitted
    ///   since the selected range will fill the whole packet.
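    ///
    /// For example (a sketch, assuming more than 1200 bytes are buffered and nothing has been
    /// sent yet):
    ///
    /// ```ignore
    /// // offset 0 needs no encoding and the frame is filled completely,
    /// // so the length can be omitted
    /// assert_eq!(buf.poll_transmit(1200), (0..1200, false));
    /// ```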
    pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range<u64>, bool) {
        debug_assert!(max_len >= 8 + 8);
        let mut encode_length = false;

        if let Some(range) = self.retransmits.pop_min() {
            // Retransmit sent data

            // When the offset is known, we know how many bytes are required to encode it.
            // Offset 0 requires no space
            if range.start != 0 {
                max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) });
            }
            if range.end - range.start < max_len as u64 {
                encode_length = true;
                max_len -= 8;
            }

            let end = range.end.min((max_len as u64).saturating_add(range.start));
            if end != range.end {
                self.retransmits.insert(end..range.end);
            }
            return (range.start..end, encode_length);
        }

        // Transmit new data

        // When the offset is known, we know how many bytes are required to encode it.
        // Offset 0 requires no space
        if self.unsent != 0 {
            max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) });
        }
        if self.offset() - self.unsent < max_len as u64 {
            encode_length = true;
            max_len -= 8;
        }

        let end = self
            .offset()
            .min((max_len as u64).saturating_add(self.unsent));
        let result = self.unsent..end;
        self.unsent = end;
        (result, encode_length)
    }

    /// Returns data which is associated with a range
    ///
    /// This function can return a subset of the range, if the data is stored
    /// in noncontiguous fashion in the send buffer. In this case callers
    /// should call the function again with an incremented start offset to
    /// retrieve more data.
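    ///
    /// A typical call pattern, reassembling a range that may span several segments
    /// (`range` and `out` are the caller's; compare the `get_loop` bench below):
    ///
    /// ```ignore
    /// while range.start < range.end {
    ///     let chunk = buf.get(range.clone());
    ///     out.extend_from_slice(chunk);
    ///     range.start += chunk.len() as u64;
    /// }
    /// ```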
    #[cfg(any(test, feature = "bench"))]
    pub(super) fn get(&self, offsets: Range<u64>) -> &[u8] {
        self.data.get(offsets)
    }

    /// Copy the data for the given range of stream offsets into `buf`
    pub(super) fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
        self.data.get_into(offsets, buf)
    }

    /// Queue a range of sent but unacknowledged data to be retransmitted
    pub(super) fn retransmit(&mut self, range: Range<u64>) {
        debug_assert!(range.end <= self.unsent, "unsent data can't be lost");
        self.retransmits.insert(range);
    }

    /// Reset the transmit state so that all buffered data is sent (again), e.g. when rejected
    /// 0-RTT data has to be retransmitted
    pub(super) fn retransmit_all_for_0rtt(&mut self) {
        // We should still have all the data, since no acks can have arrived yet
        debug_assert_eq!(self.data.range().start, 0);
        self.unsent = 0;
    }

    /// First stream offset unwritten by the application, i.e. the offset that the next write will
    /// begin at
    pub(super) fn offset(&self) -> u64 {
        self.data.range().end
    }

    /// Whether all sent data has been acknowledged
    pub(super) fn is_fully_acked(&self) -> bool {
        self.data.len() == 0
    }

    /// Whether there's data to send
    ///
    /// There may be sent unacknowledged data even when this is false.
    pub(super) fn has_unsent_data(&self) -> bool {
        self.unsent != self.offset() || !self.retransmits.is_empty()
    }

    /// Compute the amount of data that hasn't been acknowledged
    pub(super) fn unacked(&self) -> u64 {
        self.data.len() as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>()
    }
}

#[cfg(test)]
mod tests {
    use test_strategy::{Arbitrary, proptest};
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    #[test]
    fn fragment_with_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        // 0 byte offset => 19 bytes left => the 13 byte message can't fill the frame,
        // so the length is encoded; with 8 bytes reserved for it, 11 payload bytes fit
        assert_eq!(buf.poll_transmit(19), (0..11, true));
        assert_eq!(
            buf.poll_transmit(MSG.len() + 16 - 11),
            (11..MSG.len() as u64, true)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn fragment_without_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with some extra data!";
        buf.write(MSG.into());
        // 0 byte offset => 19 bytes left => the 34 byte message can fill the frame,
        // so the length can be omitted
        assert_eq!(buf.poll_transmit(19), (0..19, false));
        assert_eq!(
            buf.poll_transmit(MSG.len() - 19 + 1),
            (19..MSG.len() as u64, false)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn reserves_encoded_offset() {
        let mut buf = SendBuffer::new();

        // Pretend we have more than 1 GB of data in the buffer
        let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]);
        for _ in 0..1025 {
            buf.write(chunk.clone());
        }

        const SIZE1: u64 = 64;
        const SIZE2: u64 = 16 * 1024;
        const SIZE3: u64 = 1024 * 1024 * 1024;

        // Offset 0 requires no space
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        let mut transmitted = 16u64;

        // Offset 16 requires 1 byte
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        buf.retransmit(transmitted..SIZE1);
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        transmitted = SIZE1;

        // Offset 64 requires 2 bytes
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        buf.retransmit(transmitted..SIZE2);
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        transmitted = SIZE2;

        // Offset 16384 requires 4 bytes
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        buf.retransmit(transmitted..SIZE3);
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        transmitted = SIZE3;

        // Offset 1 GB requires 8 bytes
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
        buf.retransmit(transmitted..transmitted + chunk.len() as u64);
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
    }

    #[test]
    fn multiple_segments() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        const MSG_LEN: u64 = MSG.len() as u64;

        const SEG1: &[u8] = b"He";
        buf.write(SEG1.into());
        const SEG2: &[u8] = b"llo,";
        buf.write(SEG2.into());
        const SEG3: &[u8] = b" w";
        buf.write(SEG3.into());
        const SEG4: &[u8] = b"o";
        buf.write(SEG4.into());
        const SEG5: &[u8] = b"rld!";
        buf.write(SEG5.into());

        assert_eq!(aggregate_unacked(&buf), MSG);

        assert_eq!(buf.poll_transmit(16), (0..8, true));
        assert_eq!(buf.get(0..5), SEG1);
        assert_eq!(buf.get(2..8), SEG2);
        assert_eq!(buf.get(6..8), SEG3);

        assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, true));
        assert_eq!(buf.get(8..MSG_LEN), SEG4);
        assert_eq!(buf.get(9..MSG_LEN), SEG5);

        assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true));

        // Now drain the segments
        buf.ack(0..1);
        assert_eq!(aggregate_unacked(&buf), &MSG[1..]);
        buf.ack(0..3);
        assert_eq!(aggregate_unacked(&buf), &MSG[3..]);
        buf.ack(3..5);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(7..9);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(4..7);
        assert_eq!(aggregate_unacked(&buf), &MSG[9..]);
        buf.ack(0..MSG_LEN);
        assert_eq!(aggregate_unacked(&buf), &[] as &[u8]);
    }
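
    // Minimal check that `get_into` crosses segment boundaries; the segments mirror
    // `multiple_segments` above.
    #[test]
    fn get_into_spans_segments() {
        let mut buf = SendBuffer::new();
        buf.write(Bytes::from_static(b"He"));
        buf.write(Bytes::from_static(b"llo,"));
        buf.write(Bytes::from_static(b" world!"));

        let mut out: Vec<u8> = Vec::new();
        buf.get_into(1..9, &mut out);
        assert_eq!(&out[..], b"ello, wo");
    }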

    #[test]
    fn retransmit() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        // Transmit two frames
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        // Lose the first, but not the second
        buf.retransmit(0..16);
        // Ensure we only retransmit the lost frame, then continue sending fresh data
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true));
        // Lose the second frame
        buf.retransmit(16..23);
        assert_eq!(buf.poll_transmit(16), (16..23, true));
    }

    #[test]
    fn ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..8, true));
        buf.ack(0..8);
        assert_eq!(aggregate_unacked(&buf), &MSG[8..]);
    }
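
    // Walks the bookkeeping queries (`has_unsent_data`, `unacked`, `is_fully_acked`) through
    // a single write/send/ack cycle.
    #[test]
    fn tracking_queries() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        assert!(buf.has_unsent_data());
        assert!(!buf.is_fully_acked());
        assert_eq!(buf.unacked(), MSG.len() as u64);

        // send everything in one frame
        assert_eq!(buf.poll_transmit(MSG.len() + 8), (0..MSG.len() as u64, true));
        assert!(!buf.has_unsent_data());
        assert_eq!(buf.unacked(), MSG.len() as u64);

        buf.ack(0..MSG.len() as u64);
        assert!(buf.is_fully_acked());
        assert_eq!(buf.unacked(), 0);
    }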

    #[test]
    fn reordered_ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        buf.ack(16..23);
        assert_eq!(aggregate_unacked(&buf), MSG);
        buf.ack(0..16);
        assert_eq!(aggregate_unacked(&buf), &MSG[23..]);
        assert!(buf.acks.is_empty());
    }

    fn aggregate_unacked(buf: &SendBuffer) -> Vec<u8> {
        buf.data.to_vec()
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_out_of_range() {
        let data = SendBufferData::default();
        data.get(0..1);
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_into_out_of_range() {
        let data = SendBufferData::default();
        let mut buf = Vec::new();
        data.get_into(0..1, &mut buf);
    }

    #[derive(Debug, Clone, Arbitrary)]
    enum Op {
        // write the given bytes
        Write(#[strategy(proptest::collection::vec(any::<u8>(), 0..1024))] Vec<u8>),
        // ack a random range, the value is the rng seed
        Ack(u64),
        // retransmit a random range, the value is the rng seed
        Retransmit(u64),
        // poll_transmit with the given max len
        PollTransmit(#[strategy(16usize..1024)] usize),
    }

    fn random_range(seed: u64, range: Range<u64>) -> Range<u64> {
        if range.is_empty() {
            return range;
        }
        let mut rng = StdRng::seed_from_u64(seed);
        let a = rng.random_range(range.clone());
        let b = rng.random_range(range);
        if a < b { a..b } else { b..a }
    }

    #[proptest]
    fn send_buffer_matches_reference(
        #[strategy(proptest::collection::vec(any::<Op>(), 1..100))] ops: Vec<Op>,
    ) {
        let mut sb = SendBuffer::new();
        // total bytes written so far
        let mut total_bytes = 0u64;
        // max offset that has been returned by poll_transmit
        let mut max_send_offset = 0u64;
        // max offset up to which data has been fully acked
        let mut max_full_send_offset = 0u64;
        println!();
        for op in ops {
            match op {
                Op::Write(data) => {
                    total_bytes += data.len() as u64;
                    println!("Writing {} bytes", data.len());
                    sb.write(Bytes::from(data));
                }
                Op::Ack(seed) => {
                    // only generate acks for data that has been sent
                    let range = random_range(seed, 0..max_send_offset + 1);
                    if range.is_empty() {
                        continue;
                    }
                    // update fully acked range
                    if range.contains(&max_full_send_offset) {
                        max_full_send_offset = range.end;
                    }
                    println!("Acking range: {:?}", range);
                    sb.ack(range);
                    // the buffer may discard a longer contiguous prefix than tracked above;
                    // never generate retransmits for data that has already been discarded
                    max_full_send_offset = max_full_send_offset.max(sb.data.range().start);
                }
                Op::Retransmit(seed) => {
                    let range = random_range(seed, max_full_send_offset..max_send_offset + 1);
                    if range.is_empty() {
                        continue;
                    }
                    println!("Retransmitting range: {:?}", range);
                    sb.retransmit(range);
                }
                Op::PollTransmit(max_len) => {
                    let (range, _encode_length) = sb.poll_transmit(max_len);
                    max_send_offset = max_send_offset.max(range.end);
                    let mut buf = Vec::new();
                    println!("Getting data for range: {:?}", range);
                    sb.get_into(range, &mut buf);
                }
            }
        }
    }
}

#[cfg(feature = "bench")]
pub mod send_buffer_benches {
    //! Bench fns for SendBuffer
    //!
    //! These are defined here and re-exported via `bench_exports` in lib.rs,
    //! so we can access the private `SendBuffer` struct.
    use criterion::Criterion;
    use bytes::Bytes;
    use super::SendBuffer;

    /// Pathological case: many segments, reading from the end
    pub fn get_into_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_into_many_segments");
        let mut buf = SendBuffer::new();

        const SEGMENTS: u64 = 10000;
        const SEGMENT_SIZE: u64 = 10;
        const PACKET_SIZE: u64 = 1200;
        const BYTES: u64 = SEGMENTS * SEGMENT_SIZE;

        // 10000 segments of 10 bytes each = 100KB total (same data size)
        for i in 0..SEGMENTS {
            buf.write(Bytes::from(vec![i as u8; SEGMENT_SIZE as usize]));
        }

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_into", |b| {
            b.iter(|| {
                // Get from the end (scans past nearly all 10000 segments)
                tgt.clear();
                buf.get_into(BYTES - PACKET_SIZE..BYTES, std::hint::black_box(&mut tgt));
            });
        });
    }

    /// Get segments the old way, using a loop of `get` calls
    pub fn get_loop_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_loop_many_segments");
        let mut buf = SendBuffer::new();

        const SEGMENTS: u64 = 10000;
        const SEGMENT_SIZE: u64 = 10;
        const PACKET_SIZE: u64 = 1200;
        const BYTES: u64 = SEGMENTS * SEGMENT_SIZE;

        // 10000 segments of 10 bytes each = 100KB total (same data size)
        for i in 0..SEGMENTS {
            buf.write(Bytes::from(vec![i as u8; SEGMENT_SIZE as usize]));
        }

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_loop", |b| {
            b.iter(|| {
                // Get from the end (very slow: each get call scans the preceding segments again)
                tgt.clear();
                let mut range = BYTES - PACKET_SIZE..BYTES;
                while range.start < range.end {
                    let slice = std::hint::black_box(buf.get(range.clone()));
                    range.start += slice.len() as u64;
                    tgt.extend_from_slice(slice);
                }
            });
        });
    }
}