iroh_quinn_proto/connection/
send_buffer.rs1use std::{collections::VecDeque, ops::Range};
2
3use bytes::{Buf, BufMut, Bytes};
4
5use crate::{VarInt, range_set::ArrayRangeSet};
6
7#[derive(Default, Debug)]
9pub(super) struct SendBuffer {
10 data: SendBufferData,
18 unsent: u64,
22 acks: ArrayRangeSet,
29 retransmits: ArrayRangeSet,
36}
37
38#[derive(Default, Debug)]
41struct SendBufferData {
42 offset: u64,
44 segments: VecDeque<Bytes>,
46 len: usize,
48}
49
50impl SendBufferData {
51 fn len(&self) -> usize {
53 self.len
54 }
55
56 #[inline(always)]
58 fn range(&self) -> Range<u64> {
59 self.offset..self.offset + self.len as u64
60 }
61
62 fn append(&mut self, data: Bytes) {
64 self.len += data.len();
65 self.segments.push_back(data);
66 }
67
68 fn pop_front(&mut self, n: usize) {
72 let mut n = n.min(self.len);
73 self.len -= n;
74 self.offset += n as u64;
75 while n > 0 {
76 let front = self.segments.front_mut().expect("Expected buffered data");
77
78 if front.len() <= n {
79 n -= front.len();
81 self.segments.pop_front();
82 } else {
83 front.advance(n);
85 n = 0;
86 }
87 }
88 if self.segments.len() * 4 < self.segments.capacity() {
89 self.segments.shrink_to_fit();
90 }
91 }
92
93 #[cfg(any(test, feature = "bench"))]
97 fn get(&self, offsets: Range<u64>) -> &[u8] {
98 assert!(
99 offsets.start >= self.range().start && offsets.end <= self.range().end,
100 "Requested range is outside of buffered data"
101 );
102 let offsets = Range {
104 start: (offsets.start - self.offset) as usize,
105 end: (offsets.end - self.offset) as usize,
106 };
107 let mut segment_offset = 0;
108 for segment in self.segments.iter() {
109 if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() {
110 let start = offsets.start - segment_offset;
111 let end = offsets.end - segment_offset;
112
113 return &segment[start..end.min(segment.len())];
114 }
115 segment_offset += segment.len();
116 }
117
118 unreachable!("impossible if segments and range are consistent");
119 }
120
121 fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
122 assert!(
123 offsets.start >= self.range().start && offsets.end <= self.range().end,
124 "Requested range is outside of buffered data"
125 );
126 let offsets = Range {
128 start: (offsets.start - self.offset) as usize,
129 end: (offsets.end - self.offset) as usize,
130 };
131 let mut segment_offset = 0;
132 for segment in self.segments.iter() {
133 let start = segment_offset.max(offsets.start);
135 let end = (segment_offset + segment.len()).min(offsets.end);
136 if start < end {
137 buf.put_slice(&segment[start - segment_offset..end - segment_offset]);
139 }
140 segment_offset += segment.len();
141 if segment_offset >= offsets.end {
142 break;
144 }
145 }
146 }
147
148 #[cfg(test)]
149 fn to_vec(&self) -> Vec<u8> {
150 let mut result = Vec::with_capacity(self.len);
151 for segment in self.segments.iter() {
152 result.extend_from_slice(&segment[..]);
153 }
154 result
155 }
156}
157
158impl SendBuffer {
159 pub(super) fn new() -> Self {
161 Self::default()
162 }
163
164 pub(super) fn write(&mut self, data: Bytes) {
166 self.data.append(data);
167 }
168
169 pub(super) fn ack(&mut self, mut range: Range<u64>) {
171 let base_offset = self.data.range().start;
173 range.start = base_offset.max(range.start);
174 range.end = base_offset.max(range.end);
175
176 self.acks.insert(range);
177
178 while self.acks.min() == Some(self.data.range().start) {
179 let prefix = self.acks.pop_min().unwrap();
180 let to_advance = (prefix.end - prefix.start) as usize;
181 self.data.pop_front(to_advance);
182 }
183 }
184
185 pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range<u64>, bool) {
199 debug_assert!(max_len >= 8 + 8);
200 let mut encode_length = false;
201
202 if let Some(range) = self.retransmits.pop_min() {
203 if range.start != 0 {
208 max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) });
209 }
210 if range.end - range.start < max_len as u64 {
211 encode_length = true;
212 max_len -= 8;
213 }
214
215 let end = range.end.min((max_len as u64).saturating_add(range.start));
216 if end != range.end {
217 self.retransmits.insert(end..range.end);
218 }
219 return (range.start..end, encode_length);
220 }
221
222 if self.unsent != 0 {
227 max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) });
228 }
229 if self.offset() - self.unsent < max_len as u64 {
230 encode_length = true;
231 max_len -= 8;
232 }
233
234 let end = self
235 .offset()
236 .min((max_len as u64).saturating_add(self.unsent));
237 let result = self.unsent..end;
238 self.unsent = end;
239 (result, encode_length)
240 }
241
242 #[cfg(any(test, feature = "bench"))]
249 pub(super) fn get(&self, offsets: Range<u64>) -> &[u8] {
250 self.data.get(offsets)
251 }
252
253 pub(super) fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
254 self.data.get_into(offsets, buf)
255 }
256
257 pub(super) fn retransmit(&mut self, range: Range<u64>) {
259 debug_assert!(range.end <= self.unsent, "unsent data can't be lost");
260 self.retransmits.insert(range);
261 }
262
263 pub(super) fn retransmit_all_for_0rtt(&mut self) {
264 debug_assert_eq!(self.data.range().start, 0);
266 self.unsent = 0;
267 }
268
269 pub(super) fn offset(&self) -> u64 {
272 self.data.range().end
273 }
274
275 pub(super) fn is_fully_acked(&self) -> bool {
277 self.data.len() == 0
278 }
279
280 pub(super) fn has_unsent_data(&self) -> bool {
284 self.unsent != self.offset() || !self.retransmits.is_empty()
285 }
286
287 pub(super) fn unacked(&self) -> u64 {
289 self.data.len() as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>()
290 }
291}
292
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::StdRng};
    use test_strategy::{Arbitrary, proptest};

    use super::*;

    #[test]
    fn fragment_with_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(19), (0..11, true));
        assert_eq!(
            buf.poll_transmit(MSG.len() + 16 - 11),
            (11..MSG.len() as u64, true)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn fragment_without_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with some extra data!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(19), (0..19, false));
        assert_eq!(
            buf.poll_transmit(MSG.len() - 19 + 1),
            (19..MSG.len() as u64, false)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn reserves_encoded_offset() {
        let mut buf = SendBuffer::new();

        let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]);
        for _ in 0..1025 {
            buf.write(chunk.clone());
        }

        const SIZE1: u64 = 64;
        const SIZE2: u64 = 16 * 1024;
        const SIZE3: u64 = 1024 * 1024 * 1024;

        assert_eq!(buf.poll_transmit(16), (0..16, false));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        let mut transmitted = 16u64;

        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        buf.retransmit(transmitted..SIZE1);
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        transmitted = SIZE1;

        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        buf.retransmit(transmitted..SIZE2);
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        transmitted = SIZE2;

        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        buf.retransmit(transmitted..SIZE3);
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        transmitted = SIZE3;

        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
        buf.retransmit(transmitted..transmitted + chunk.len() as u64);
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
    }

    #[test]
    fn multiple_segments() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        const MSG_LEN: u64 = MSG.len() as u64;

        const SEG1: &[u8] = b"He";
        buf.write(SEG1.into());
        const SEG2: &[u8] = b"llo,";
        buf.write(SEG2.into());
        const SEG3: &[u8] = b" w";
        buf.write(SEG3.into());
        const SEG4: &[u8] = b"o";
        buf.write(SEG4.into());
        const SEG5: &[u8] = b"rld!";
        buf.write(SEG5.into());

        assert_eq!(aggregate_unacked(&buf), MSG);

        assert_eq!(buf.poll_transmit(16), (0..8, true));
        assert_eq!(buf.get(0..5), SEG1);
        assert_eq!(buf.get(2..8), SEG2);
        assert_eq!(buf.get(6..8), SEG3);

        assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, true));
        assert_eq!(buf.get(8..MSG_LEN), SEG4);
        assert_eq!(buf.get(9..MSG_LEN), SEG5);

        assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true));

        buf.ack(0..1);
        assert_eq!(aggregate_unacked(&buf), &MSG[1..]);
        buf.ack(0..3);
        assert_eq!(aggregate_unacked(&buf), &MSG[3..]);
        buf.ack(3..5);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(7..9);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(4..7);
        assert_eq!(aggregate_unacked(&buf), &MSG[9..]);
        buf.ack(0..MSG_LEN);
        assert_eq!(aggregate_unacked(&buf), &[] as &[u8]);
    }

    #[test]
    fn retransmit() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true));
        buf.retransmit(16..23);
        assert_eq!(buf.poll_transmit(16), (16..23, true));
    }

    #[test]
    fn ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..8, true));
        buf.ack(0..8);
        assert_eq!(aggregate_unacked(&buf), &MSG[8..]);
    }

    #[test]
    fn reordered_ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        buf.ack(16..23);
        assert_eq!(aggregate_unacked(&buf), MSG);
        buf.ack(0..16);
        assert_eq!(aggregate_unacked(&buf), &MSG[23..]);
        assert!(buf.acks.is_empty());
    }

    /// All bytes still held by the buffer, concatenated.
    fn aggregate_unacked(buf: &SendBuffer) -> Vec<u8> {
        buf.data.to_vec()
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_out_of_range() {
        let data = SendBufferData::default();
        data.get(0..1);
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_into_out_of_range() {
        let data = SendBufferData::default();
        let mut buf = Vec::new();
        data.get_into(0..1, &mut buf);
    }

    /// Operations applied by `send_buffer_matches_reference`.
    #[derive(Debug, Clone, Arbitrary)]
    enum Op {
        /// Append the given bytes to the stream.
        Write(#[strategy(proptest::collection::vec(any::<u8>(), 0..1024))] Vec<u8>),
        /// Acknowledge a random range of already transmitted data.
        Ack(u64),
        /// Declare a random range of transmitted, still buffered data lost.
        Retransmit(u64),
        /// Ask for the next range to send, with the given frame size limit.
        PollTransmit(#[strategy(16usize..1024)] usize),
    }

    /// Deterministically derives a sub-range of `range` from `seed`.
    fn random_range(seed: u64, range: Range<u64>) -> Range<u64> {
        if range.is_empty() {
            return range;
        }
        let mut rng = StdRng::seed_from_u64(seed);
        let a = rng.random_range(range.clone());
        let b = rng.random_range(range);
        if a < b { a..b } else { b..a }
    }

    #[proptest]
    fn send_buffer_matches_reference(
        #[strategy(proptest::collection::vec(any::<Op>(), 1..100))] ops: Vec<Op>,
    ) {
        let mut sb = SendBuffer::new();
        // Reference model: every byte written so far, and whether it has been acknowledged.
        let mut written: Vec<u8> = Vec::new();
        let mut acked: Vec<bool> = Vec::new();
        // Highest offset ever returned by `poll_transmit`; nothing beyond it was sent.
        let mut max_send_offset = 0u64;
        for op in ops {
            // End of the contiguous acknowledged prefix; the buffer discards exactly these
            // bytes, so they can no longer be read back or declared lost.
            let acked_prefix = acked.iter().position(|&a| !a).unwrap_or(acked.len()) as u64;
            match op {
                Op::Write(data) => {
                    written.extend_from_slice(&data);
                    acked.resize(written.len(), false);
                    sb.write(Bytes::from(data));
                }
                Op::Ack(seed) => {
                    let range = random_range(seed, 0..max_send_offset + 1);
                    if range.is_empty() {
                        continue;
                    }
                    acked[range.start as usize..range.end as usize].fill(true);
                    sb.ack(range);
                }
                Op::Retransmit(seed) => {
                    let range = random_range(seed, acked_prefix..max_send_offset + 1);
                    if range.is_empty() {
                        continue;
                    }
                    sb.retransmit(range);
                }
                Op::PollTransmit(max_len) => {
                    let (range, _encode_length) = sb.poll_transmit(max_len);
                    assert!(
                        range.end <= written.len() as u64,
                        "sent unwritten data: {range:?}"
                    );
                    max_send_offset = max_send_offset.max(range.end);
                    // A queued retransmission may have been acknowledged after it was
                    // queued; only the part that is still buffered can be read back.
                    let start = range.start.max(acked_prefix);
                    if start < range.end {
                        let mut out = Vec::new();
                        sb.get_into(start..range.end, &mut out);
                        assert_eq!(out, &written[start as usize..range.end as usize]);
                    }
                }
            }
            let unacked = acked.iter().filter(|&&a| !a).count() as u64;
            assert_eq!(sb.unacked(), unacked);
        }
    }
}
584
#[cfg(feature = "bench")]
pub mod send_buffer_benches {
    use bytes::Bytes;
    use criterion::Criterion;

    use super::SendBuffer;

    /// Number of segments written before benchmarking.
    const SEGMENTS: u64 = 10000;
    /// Size of each written segment, deliberately tiny so a packet spans many segments.
    const SEGMENT_SIZE: u64 = 10;
    /// Amount of data read per iteration, roughly one packet.
    const PACKET_SIZE: u64 = 1200;
    /// Total number of buffered bytes.
    const BYTES: u64 = SEGMENTS * SEGMENT_SIZE;

    /// Builds a buffer of `SEGMENTS` segments of `SEGMENT_SIZE` bytes each, where
    /// segment `i` is filled with the byte `i as u8`.
    fn buffer_with_many_segments() -> SendBuffer {
        let mut buf = SendBuffer::new();
        for i in 0..SEGMENTS {
            buf.write(Bytes::from(vec![i as u8; SEGMENT_SIZE as usize]));
        }
        buf
    }

    /// Measures copying the last packet's worth of data with a single `get_into` call.
    pub fn get_into_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_into_many_segments");
        let buf = buffer_with_many_segments();

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_into", |b| {
            b.iter(|| {
                tgt.clear();
                buf.get_into(BYTES - PACKET_SIZE..BYTES, std::hint::black_box(&mut tgt));
            });
        });
        group.finish();
    }

    /// Measures copying the same data by calling `get` repeatedly, one segment at a time.
    pub fn get_loop_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_loop_many_segments");
        let buf = buffer_with_many_segments();

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_loop", |b| {
            b.iter(|| {
                tgt.clear();
                let mut range = BYTES - PACKET_SIZE..BYTES;
                while range.start < range.end {
                    let slice = std::hint::black_box(buf.get(range.clone()));
                    range.start += slice.len() as u64;
                    tgt.extend_from_slice(slice);
                }
            });
        });
        group.finish();
    }
}