iroh_quinn_proto/connection/send_buffer.rs

use std::{collections::VecDeque, ops::Range};

use bytes::{Buf, BufMut, Bytes, BytesMut};

use crate::{VarInt, connection::streams::BytesOrSlice, range_set::ArrayRangeSet};

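/// Buffer of outgoing stream data that has not yet been fully acknowledged.
///
/// `data` holds the buffered bytes starting at the lowest not-fully-acked offset,
/// `unsent` is the offset of the first byte that has never been transmitted, `acks`
/// records acknowledged ranges that are not yet contiguous with the front of the buffer,
/// and `retransmits` holds previously sent ranges that are queued for retransmission.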
#[derive(Default, Debug)]
pub(super) struct SendBuffer {
    data: SendBufferData,
    unsent: u64,
    acks: ArrayRangeSet,
    retransmits: ArrayRangeSet,
}

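/// Writes no larger than this are coalesced into the in-progress tail segment; larger
/// writes are stored as their own zero-copy segment. 1452 bytes roughly matches a single
/// MTU-sized UDP payload, which is presumably the motivation for the value.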
const MAX_COMBINE: usize = 1452;

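/// The buffered bytes themselves, stored as a list of segments to avoid copying large writes.
///
/// `offset` is the stream offset of the first buffered byte (everything before it has
/// been fully acknowledged and discarded), `len` is the total number of buffered bytes,
/// `segments` holds completed segments in stream order, and `last_segment` is the
/// in-progress tail segment that small writes are coalesced into.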
#[derive(Default, Debug)]
struct SendBufferData {
    offset: u64,
    len: usize,
    segments: VecDeque<Bytes>,
    last_segment: BytesMut,
}

impl SendBufferData {
    fn len(&self) -> usize {
        self.len
    }

    #[inline(always)]
    fn range(&self) -> Range<u64> {
        self.offset..self.offset + self.len as u64
    }

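    /// Appends `data` to the buffer.
    ///
    /// Writes larger than `MAX_COMBINE` become their own segment without copying the
    /// bytes; smaller writes are copied into `last_segment`, which is split off into
    /// `segments` once it would grow past `MAX_COMBINE`.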
    fn append<'a>(&'a mut self, data: impl BytesOrSlice<'a>) {
        self.len += data.len();
        if data.len() > MAX_COMBINE {
            if !self.last_segment.is_empty() {
                self.segments.push_back(self.last_segment.split().freeze());
            }
            self.segments.push_back(data.into_bytes());
        } else {
            let rest = if self.last_segment.len() + data.len() > MAX_COMBINE
                && !self.last_segment.is_empty()
            {
                let capacity = MAX_COMBINE.saturating_sub(self.last_segment.len());
                let (curr, rest) = data.as_ref().split_at(capacity);
                self.last_segment.put_slice(curr);
                self.segments.push_back(self.last_segment.split().freeze());
                rest
            } else {
                data.as_ref()
            };
            self.last_segment.extend_from_slice(rest);
        }
    }

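    /// Discards the first `n` buffered bytes (capped at `self.len`), advancing `offset`
    /// accordingly. The segment queue is shrunk once it becomes mostly unused.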
    fn pop_front(&mut self, n: usize) {
        let mut n = n.min(self.len);
        self.len -= n;
        self.offset += n as u64;
        while n > 0 {
            let Some(front) = self.segments.front_mut() else {
                break;
            };
            if front.len() <= n {
                n -= front.len();
                self.segments.pop_front();
            } else {
                front.advance(n);
                n = 0;
            }
        }
        self.last_segment.advance(n);
        if self.segments.len() * 4 < self.segments.capacity() {
            self.segments.shrink_to_fit();
        }
    }

    fn segments_iter(&self) -> impl Iterator<Item = &[u8]> {
        self.segments
            .iter()
            .map(|x| x.as_ref())
            .chain(std::iter::once(self.last_segment.as_ref()))
    }

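    /// Returns buffered data for `offsets`, truncated at the end of the segment that
    /// contains `offsets.start`; callers must loop to read across segment boundaries.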
    #[cfg(any(test, feature = "bench"))]
    fn get(&self, offsets: Range<u64>) -> &[u8] {
        assert!(
            offsets.start >= self.range().start && offsets.end <= self.range().end,
            "Requested range is outside of buffered data"
        );
        let offsets = Range {
            start: (offsets.start - self.offset) as usize,
            end: (offsets.end - self.offset) as usize,
        };
        let mut segment_offset = 0;
        for segment in self.segments_iter() {
            if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() {
                let start = offsets.start - segment_offset;
                let end = offsets.end - segment_offset;

                return &segment[start..end.min(segment.len())];
            }
            segment_offset += segment.len();
        }

        unreachable!("impossible if segments and range are consistent");
    }

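    /// Copies the buffered data for `offsets` into `buf`, crossing segment boundaries as
    /// needed.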
    fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
        assert!(
            offsets.start >= self.range().start && offsets.end <= self.range().end,
            "Requested range is outside of buffered data"
        );
        let offsets = Range {
            start: (offsets.start - self.offset) as usize,
            end: (offsets.end - self.offset) as usize,
        };
        let mut segment_offset = 0;
        for segment in self.segments_iter() {
            let start = segment_offset.max(offsets.start);
            let end = (segment_offset + segment.len()).min(offsets.end);
            if start < end {
                buf.put_slice(&segment[start - segment_offset..end - segment_offset]);
            }
            segment_offset += segment.len();
            if segment_offset >= offsets.end {
                break;
            }
        }
    }

    #[cfg(test)]
    fn to_vec(&self) -> Vec<u8> {
        let mut result = Vec::with_capacity(self.len);
        for segment in self.segments_iter() {
            result.extend_from_slice(segment);
        }
        result
    }
}

impl SendBuffer {
    pub(super) fn new() -> Self {
        Self::default()
    }

    pub(super) fn write<'a>(&'a mut self, data: impl BytesOrSlice<'a>) {
        self.data.append(data);
    }

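    /// Records that the peer has acknowledged `range`.
    ///
    /// Ranges below the fully-acked offset are clamped away; once acknowledged ranges
    /// form a contiguous prefix, the corresponding data is discarded and any pending
    /// retransmissions of it are dropped.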
    pub(super) fn ack(&mut self, mut range: Range<u64>) {
        let base_offset = self.fully_acked_offset();
        range.start = base_offset.max(range.start);
        range.end = base_offset.max(range.end);

        self.acks.insert(range);

        while self.acks.min() == Some(self.fully_acked_offset()) {
            let prefix = self.acks.pop_min().unwrap();
            let to_advance = (prefix.end - prefix.start) as usize;
            self.data.pop_front(to_advance);
        }

        self.retransmits.remove(0..self.fully_acked_offset());
    }

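    /// Returns the range of data to send in the next stream frame, together with whether
    /// the frame needs an explicit length field.
    ///
    /// Pending retransmissions are served first; otherwise the next unsent data is
    /// returned and marked as sent. `max_len` is the space available for the frame, so
    /// room is reserved for the varint-encoded offset (when non-zero) and, if the data
    /// does not fill the remaining space, for an encoded length (in which case the
    /// returned flag is `true`).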
    pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range<u64>, bool) {
        debug_assert!(max_len >= 8 + 8);
        let mut encode_length = false;

        if let Some(range) = self.retransmits.pop_min() {
            if range.start != 0 {
                max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) });
            }
            if range.end - range.start < max_len as u64 {
                encode_length = true;
                max_len -= 8;
            }

            let end = range.end.min((max_len as u64).saturating_add(range.start));
            if end != range.end {
                self.retransmits.insert(end..range.end);
            }
            return (range.start..end, encode_length);
        }

        if self.unsent != 0 {
            max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) });
        }
        if self.offset() - self.unsent < max_len as u64 {
            encode_length = true;
            max_len -= 8;
        }

        let end = self
            .offset()
            .min((max_len as u64).saturating_add(self.unsent));
        let result = self.unsent..end;
        self.unsent = end;
        (result, encode_length)
    }

    #[cfg(any(test, feature = "bench"))]
    pub(super) fn get(&self, offsets: Range<u64>) -> &[u8] {
        self.data.get(offsets)
    }

    pub(super) fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
        self.data.get_into(offsets, buf)
    }

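    /// Queues previously transmitted `range` for retransmission, e.g. after it was deemed
    /// lost. The start of the range is clamped to the fully acknowledged offset.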
    pub(super) fn retransmit(&mut self, mut range: Range<u64>) {
        debug_assert!(range.end <= self.unsent, "unsent data can't be lost");
        range.start = range.start.max(self.fully_acked_offset());
        self.retransmits.insert(range);
    }

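    /// Marks all data as unsent again so it is retransmitted after a 0-RTT rejection.
    /// Only valid while nothing has been acknowledged yet.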
    pub(super) fn retransmit_all_for_0rtt(&mut self) {
        debug_assert_eq!(self.fully_acked_offset(), 0);
        self.unsent = 0;
    }

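    /// Offset up to which all data has been acknowledged (and already discarded).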
    fn fully_acked_offset(&self) -> u64 {
        self.data.range().start
    }

    pub(super) fn offset(&self) -> u64 {
        self.data.range().end
    }

    pub(super) fn is_fully_acked(&self) -> bool {
        self.data.len() == 0
    }

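    /// Whether there is anything left to send: either never-sent data or queued
    /// retransmissions.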
    pub(super) fn has_unsent_data(&self) -> bool {
        self.unsent != self.offset() || !self.retransmits.is_empty()
    }

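    /// Number of buffered bytes that have not yet been acknowledged, excluding ranges
    /// that were acknowledged out of order.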
    pub(super) fn unacked(&self) -> u64 {
        self.data.len() as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fragment_with_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG);
        assert_eq!(buf.poll_transmit(19), (0..11, true));
        assert_eq!(
            buf.poll_transmit(MSG.len() + 16 - 11),
            (11..MSG.len() as u64, true)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn fragment_without_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with some extra data!";
        buf.write(MSG);
        assert_eq!(buf.poll_transmit(19), (0..19, false));
        assert_eq!(
            buf.poll_transmit(MSG.len() - 19 + 1),
            (19..MSG.len() as u64, false)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

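    /// Checks that `poll_transmit` reserves room for the varint-encoded offset: as the
    /// offset crosses each varint size boundary, correspondingly more space must be
    /// requested to transmit the same amount of data.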
    #[test]
    fn reserves_encoded_offset() {
        let mut buf = SendBuffer::new();

        let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]);
        for _ in 0..1025 {
            buf.write(chunk.clone());
        }

        const SIZE1: u64 = 64;
        const SIZE2: u64 = 16 * 1024;
        const SIZE3: u64 = 1024 * 1024 * 1024;

        assert_eq!(buf.poll_transmit(16), (0..16, false));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        let mut transmitted = 16u64;

        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        buf.retransmit(transmitted..SIZE1);
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        transmitted = SIZE1;

        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        buf.retransmit(transmitted..SIZE2);
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        transmitted = SIZE2;

        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        buf.retransmit(transmitted..SIZE3);
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        transmitted = SIZE3;

        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
        buf.retransmit(transmitted..transmitted + chunk.len() as u64);
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
    }

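    /// Writes several segments larger than `MAX_COMBINE` and checks that `get` returns
    /// slices pointing into the original `Bytes` (i.e. no copies were made), and that
    /// acknowledgements, including out-of-order ones, discard data correctly.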
    #[test]
    fn multiple_large_segments() {
        const N: usize = 2000;
        const K: u64 = N as u64;
        fn dup(data: &[u8]) -> Bytes {
            let mut buf = BytesMut::with_capacity(data.len() * N);
            for c in data {
                for _ in 0..N {
                    buf.put_u8(*c);
                }
            }
            buf.freeze()
        }

        fn same(a: &[u8], b: &[u8]) -> bool {
            std::ptr::eq(a.as_ptr(), b.as_ptr())
        }

        let mut buf = SendBuffer::new();
        let msg: Bytes = dup(b"Hello, world!");
        let msg_len: u64 = msg.len() as u64;

        let seg1: Bytes = dup(b"He");
        buf.write(seg1.clone());
        let seg2: Bytes = dup(b"llo,");
        buf.write(seg2.clone());
        let seg3: Bytes = dup(b" w");
        buf.write(seg3.clone());
        let seg4: Bytes = dup(b"o");
        buf.write(seg4.clone());
        let seg5: Bytes = dup(b"rld!");
        buf.write(seg5.clone());
        assert_eq!(aggregate_unacked(&buf), msg);
        assert!(same(buf.get(0..5 * K), &seg1));
        assert!(same(buf.get(2 * K..8 * K), &seg2));
        assert!(same(buf.get(6 * K..8 * K), &seg3));
        assert!(same(buf.get(8 * 2000..msg_len), &seg4));
        assert!(same(buf.get(9 * 2000..msg_len), &seg5));
        buf.ack(0..K);
        assert_eq!(aggregate_unacked(&buf), &msg[N..]);
        buf.ack(0..3 * K);
        assert_eq!(aggregate_unacked(&buf), &msg[3 * N..]);
        buf.ack(3 * K..5 * K);
        assert_eq!(aggregate_unacked(&buf), &msg[5 * N..]);
        buf.ack(7 * K..9 * K);
        assert_eq!(aggregate_unacked(&buf), &msg[5 * N..]);
        buf.ack(4 * K..7 * K);
        assert_eq!(aggregate_unacked(&buf), &msg[9 * N..]);
        buf.ack(0..msg_len);
        assert_eq!(aggregate_unacked(&buf), &[] as &[u8]);
    }

    #[test]
    fn retransmit() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true));
        buf.retransmit(16..23);
        assert_eq!(buf.poll_transmit(16), (16..23, true));
    }

    #[test]
    fn ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG);
        assert_eq!(buf.poll_transmit(16), (0..8, true));
        buf.ack(0..8);
        assert_eq!(aggregate_unacked(&buf), &MSG[8..]);
    }

    #[test]
    fn reordered_ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        buf.ack(16..23);
        assert_eq!(aggregate_unacked(&buf), MSG);
        buf.ack(0..16);
        assert_eq!(aggregate_unacked(&buf), &MSG[23..]);
        assert!(buf.acks.is_empty());
    }

    fn aggregate_unacked(buf: &SendBuffer) -> Vec<u8> {
        buf.data.to_vec()
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_out_of_range() {
        let data = SendBufferData::default();
        data.get(0..1);
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_into_out_of_range() {
        let data = SendBufferData::default();
        let mut buf = Vec::new();
        data.get_into(0..1, &mut buf);
    }
}

#[cfg(all(test, not(target_family = "wasm")))]
mod proptests {
    use super::*;

    use proptest::prelude::*;
    use test_strategy::{Arbitrary, proptest};
    use crate::tests::subscribe;
    use tracing::trace;

    #[derive(Debug, Clone, Arbitrary)]
    enum Op {
        Write(#[strategy(proptest::collection::vec(any::<u8>(), 0..1024))] Vec<u8>),
        Ack(Range<u64>),
        Retransmit(Range<u64>),
        PollTransmit(#[strategy(16usize..1024)] usize),
    }

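    /// Maps an arbitrary proptest-generated range into `target`, so that generated ack
    /// and retransmit operations always fall within the data sent so far.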
    fn map_range(input: Range<u64>, target: Range<u64>) -> Range<u64> {
        if target.is_empty() {
            return target;
        }
        let size = target.end - target.start;
        let a = target.start + (input.start % size);
        let b = target.start + (input.end % size);
        a.min(b)..a.max(b)
    }

    #[proptest]
    fn send_buffer_matches_reference(
        #[strategy(proptest::collection::vec(any::<Op>(), 1..100))] ops: Vec<Op>,
    ) {
        let _guard = subscribe();
        let mut sb = SendBuffer::new();
        let mut buf = Vec::new();
        let mut max_send_offset = 0u64;
        let mut max_full_send_offset = 0u64;
        trace!("");
        for op in ops {
            match op {
                Op::Write(data) => {
                    trace!("Op::Write({})", data.len());
                    buf.extend_from_slice(&data);
                    sb.write(Bytes::from(data));
                }
                Op::Ack(range) => {
                    let range = map_range(range, 0..max_send_offset);
                    if range.contains(&max_full_send_offset) {
                        max_full_send_offset = range.end;
                    }
                    trace!("Op::Ack({:?})", range);
                    sb.ack(range);
                }
                Op::Retransmit(range) => {
                    let range = map_range(range, 0..max_send_offset);
                    trace!("Op::Retransmit({:?})", range);
                    sb.retransmit(range);
                }
                Op::PollTransmit(max_len) => {
                    trace!("Op::PollTransmit({})", max_len);
                    let (range, _partial) = sb.poll_transmit(max_len);
                    max_send_offset = max_send_offset.max(range.end);
                    assert!(
                        range.start >= max_full_send_offset,
                        "poll_transmit returned already fully acked data: range={:?}, max_full_send_offset={}",
                        range,
                        max_full_send_offset
                    );

                    let mut t1 = Vec::new();
                    sb.get_into(range.clone(), &mut t1);

                    let mut t2 = Vec::new();
                    t2.extend_from_slice(&buf[range.start as usize..range.end as usize]);

                    assert_eq!(t1, t2, "Data mismatch for range {:?}", range);
                }
            }
        }
        trace!("Op::Retransmit({:?})", 0..max_send_offset);
        sb.retransmit(0..max_send_offset);
        loop {
            trace!("Op::PollTransmit({})", 1024);
            let (range, _partial) = sb.poll_transmit(1024);
            if range.is_empty() {
                break;
            }
            trace!("Op::Ack({:?})", range);
            sb.ack(range);
        }
        assert!(
            sb.is_fully_acked(),
            "SendBuffer not fully acked at end of ops"
        );
    }
}

#[cfg(feature = "bench")]
pub mod send_buffer_benches {
    use bytes::Bytes;
    use criterion::Criterion;

    use super::SendBuffer;

    pub fn get_into_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_into_many_segments");
        let mut buf = SendBuffer::new();

        const SEGMENTS: u64 = 10000;
        const SEGMENT_SIZE: u64 = 10;
        const PACKET_SIZE: u64 = 1200;
        const BYTES: u64 = SEGMENTS * SEGMENT_SIZE;

        for i in 0..SEGMENTS {
            buf.write(Bytes::from(vec![i as u8; SEGMENT_SIZE as usize]));
        }

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_into", |b| {
            b.iter(|| {
                tgt.clear();
                buf.get_into(BYTES - PACKET_SIZE..BYTES, std::hint::black_box(&mut tgt));
            });
        });
    }

    pub fn get_loop_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_loop_many_segments");
        let mut buf = SendBuffer::new();

        const SEGMENTS: u64 = 10000;
        const SEGMENT_SIZE: u64 = 10;
        const PACKET_SIZE: u64 = 1200;
        const BYTES: u64 = SEGMENTS * SEGMENT_SIZE;

        for i in 0..SEGMENTS {
            buf.write(Bytes::from(vec![i as u8; SEGMENT_SIZE as usize]));
        }

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_loop", |b| {
            b.iter(|| {
                tgt.clear();
                let mut range = BYTES - PACKET_SIZE..BYTES;
                while range.start < range.end {
                    let slice = std::hint::black_box(buf.get(range.clone()));
                    range.start += slice.len() as u64;
                    tgt.extend_from_slice(slice);
                }
            });
        });
    }
}