iroh_quinn_proto/connection/
send_buffer.rs1use std::{collections::VecDeque, ops::Range};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4
5use crate::{VarInt, range_set::ArrayRangeSet};
6
7#[derive(Default, Debug)]
9pub(super) struct SendBuffer {
10 data: SendBufferData,
18 unsent: u64,
22 acks: ArrayRangeSet,
29 retransmits: ArrayRangeSet,
36}
37
/// Largest write, in bytes, that is copied into the shared tail segment.
///
/// Larger writes are stored as their own segment without copying. The tail segment is frozen
/// and a new one started whenever adding a write would grow it past this size.
const MAX_COMBINE: usize = 1024;
42
43#[derive(Default, Debug)]
46struct SendBufferData {
47 offset: u64,
49 len: usize,
51 segments: VecDeque<Bytes>,
53 last_segment: BytesMut,
55}
56
57impl SendBufferData {
58 fn len(&self) -> usize {
60 self.len
61 }
62
63 #[inline(always)]
65 fn range(&self) -> Range<u64> {
66 self.offset..self.offset + self.len as u64
67 }
68
69 fn append(&mut self, data: Bytes) {
71 self.len += data.len();
72 if data.len() > MAX_COMBINE {
73 if !self.last_segment.is_empty() {
75 self.segments.push_back(self.last_segment.split().freeze());
76 }
77 self.segments.push_back(data);
78 } else {
79 if self.last_segment.len() + data.len() > MAX_COMBINE && !self.last_segment.is_empty() {
81 self.segments.push_back(self.last_segment.split().freeze());
82 }
83 self.last_segment.extend_from_slice(&data);
84 }
85 }
86
87 fn pop_front(&mut self, n: usize) {
91 let mut n = n.min(self.len);
92 self.len -= n;
93 self.offset += n as u64;
94 while n > 0 {
95 let Some(front) = self.segments.front_mut() else {
97 break;
98 };
99 if front.len() <= n {
100 n -= front.len();
102 self.segments.pop_front();
103 } else {
104 front.advance(n);
106 n = 0;
107 }
108 }
109 self.last_segment.advance(n);
111 if self.segments.len() * 4 < self.segments.capacity() {
113 self.segments.shrink_to_fit();
114 }
115 }
116
117 fn segments_iter(&self) -> impl Iterator<Item = &[u8]> {
121 self.segments
122 .iter()
123 .map(|x| x.as_ref())
124 .chain(std::iter::once(self.last_segment.as_ref()))
125 }
126
127 #[cfg(any(test, feature = "bench"))]
131 fn get(&self, offsets: Range<u64>) -> &[u8] {
132 assert!(
133 offsets.start >= self.range().start && offsets.end <= self.range().end,
134 "Requested range is outside of buffered data"
135 );
136 let offsets = Range {
138 start: (offsets.start - self.offset) as usize,
139 end: (offsets.end - self.offset) as usize,
140 };
141 let mut segment_offset = 0;
142 for segment in self.segments_iter() {
143 if offsets.start >= segment_offset && offsets.start < segment_offset + segment.len() {
144 let start = offsets.start - segment_offset;
145 let end = offsets.end - segment_offset;
146
147 return &segment[start..end.min(segment.len())];
148 }
149 segment_offset += segment.len();
150 }
151
152 unreachable!("impossible if segments and range are consistent");
153 }
154
155 fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
156 assert!(
157 offsets.start >= self.range().start && offsets.end <= self.range().end,
158 "Requested range is outside of buffered data"
159 );
160 let offsets = Range {
162 start: (offsets.start - self.offset) as usize,
163 end: (offsets.end - self.offset) as usize,
164 };
165 let mut segment_offset = 0;
166 for segment in self.segments_iter() {
167 let start = segment_offset.max(offsets.start);
169 let end = (segment_offset + segment.len()).min(offsets.end);
170 if start < end {
171 buf.put_slice(&segment[start - segment_offset..end - segment_offset]);
173 }
174 segment_offset += segment.len();
175 if segment_offset >= offsets.end {
176 break;
178 }
179 }
180 }
181
182 #[cfg(test)]
183 fn to_vec(&self) -> Vec<u8> {
184 let mut result = Vec::with_capacity(self.len);
185 for segment in self.segments_iter() {
186 result.extend_from_slice(segment);
187 }
188 result
189 }
190}
191
192impl SendBuffer {
193 pub(super) fn new() -> Self {
195 Self::default()
196 }
197
198 pub(super) fn write(&mut self, data: Bytes) {
200 self.data.append(data);
201 }
202
203 pub(super) fn ack(&mut self, mut range: Range<u64>) {
205 let base_offset = self.data.range().start;
207 range.start = base_offset.max(range.start);
208 range.end = base_offset.max(range.end);
209
210 self.acks.insert(range);
211
212 while self.acks.min() == Some(self.data.range().start) {
213 let prefix = self.acks.pop_min().unwrap();
214 let to_advance = (prefix.end - prefix.start) as usize;
215 self.data.pop_front(to_advance);
216 }
217 }
218
219 pub(super) fn poll_transmit(&mut self, mut max_len: usize) -> (Range<u64>, bool) {
233 debug_assert!(max_len >= 8 + 8);
234 let mut encode_length = false;
235
236 if let Some(range) = self.retransmits.pop_min() {
237 if range.start != 0 {
242 max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(range.start) });
243 }
244 if range.end - range.start < max_len as u64 {
245 encode_length = true;
246 max_len -= 8;
247 }
248
249 let end = range.end.min((max_len as u64).saturating_add(range.start));
250 if end != range.end {
251 self.retransmits.insert(end..range.end);
252 }
253 return (range.start..end, encode_length);
254 }
255
256 if self.unsent != 0 {
261 max_len -= VarInt::size(unsafe { VarInt::from_u64_unchecked(self.unsent) });
262 }
263 if self.offset() - self.unsent < max_len as u64 {
264 encode_length = true;
265 max_len -= 8;
266 }
267
268 let end = self
269 .offset()
270 .min((max_len as u64).saturating_add(self.unsent));
271 let result = self.unsent..end;
272 self.unsent = end;
273 (result, encode_length)
274 }
275
276 #[cfg(any(test, feature = "bench"))]
283 pub(super) fn get(&self, offsets: Range<u64>) -> &[u8] {
284 self.data.get(offsets)
285 }
286
287 pub(super) fn get_into(&self, offsets: Range<u64>, buf: &mut impl BufMut) {
288 self.data.get_into(offsets, buf)
289 }
290
291 pub(super) fn retransmit(&mut self, range: Range<u64>) {
293 debug_assert!(range.end <= self.unsent, "unsent data can't be lost");
294 self.retransmits.insert(range);
295 }
296
297 pub(super) fn retransmit_all_for_0rtt(&mut self) {
298 debug_assert_eq!(self.offset(), self.data.len() as u64);
299 self.unsent = 0;
300 }
301
302 pub(super) fn offset(&self) -> u64 {
305 self.data.range().end
306 }
307
308 pub(super) fn is_fully_acked(&self) -> bool {
310 self.data.len() == 0
311 }
312
313 pub(super) fn has_unsent_data(&self) -> bool {
317 self.unsent != self.offset() || !self.retransmits.is_empty()
318 }
319
320 pub(super) fn unacked(&self) -> u64 {
322 self.data.len() as u64 - self.acks.iter().map(|x| x.end - x.start).sum::<u64>()
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fragment_with_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        // 0 byte offset => 19 bytes left => 13 byte data isn't enough
        // with 8 bytes reserved for length 11 payload bytes will fit
        assert_eq!(buf.poll_transmit(19), (0..11, true));
        assert_eq!(
            buf.poll_transmit(MSG.len() + 16 - 11),
            (11..MSG.len() as u64, true)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn fragment_without_length() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with some extra data!";
        buf.write(MSG.into());
        // 0 byte offset => 19 bytes left => can be filled by 34 bytes payload
        assert_eq!(buf.poll_transmit(19), (0..19, false));
        assert_eq!(
            buf.poll_transmit(MSG.len() - 19 + 1),
            (19..MSG.len() as u64, false)
        );
        assert_eq!(
            buf.poll_transmit(58),
            (MSG.len() as u64..MSG.len() as u64, true)
        );
    }

    #[test]
    fn reserves_encoded_offset() {
        let mut buf = SendBuffer::new();

        // Buffer just over 1 GiB without copying: one static 1 MiB chunk, shared 1025 times.
        let chunk: Bytes = Bytes::from_static(&[0; 1024 * 1024]);
        for _ in 0..1025 {
            buf.write(chunk.clone());
        }

        const SIZE1: u64 = 64;
        const SIZE2: u64 = 16 * 1024;
        const SIZE3: u64 = 1024 * 1024 * 1024;

        // Offset 0 requires no space
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        buf.retransmit(0..16);
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        let mut transmitted = 16u64;

        // Offset 16 requires 1 byte
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        buf.retransmit(transmitted..SIZE1);
        assert_eq!(
            buf.poll_transmit((SIZE1 - transmitted + 1) as usize),
            (transmitted..SIZE1, false)
        );
        transmitted = SIZE1;

        // Offset 64 requires 2 bytes
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        buf.retransmit(transmitted..SIZE2);
        assert_eq!(
            buf.poll_transmit((SIZE2 - transmitted + 2) as usize),
            (transmitted..SIZE2, false)
        );
        transmitted = SIZE2;

        // Offset 16K requires 4 bytes
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        buf.retransmit(transmitted..SIZE3);
        assert_eq!(
            buf.poll_transmit((SIZE3 - transmitted + 4) as usize),
            (transmitted..SIZE3, false)
        );
        transmitted = SIZE3;

        // Offset 1G requires 8 bytes
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
        buf.retransmit(transmitted..transmitted + chunk.len() as u64);
        assert_eq!(
            buf.poll_transmit(chunk.len() + 8),
            (transmitted..transmitted + chunk.len() as u64, false)
        );
    }

    #[test]
    fn multiple_segments() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        const MSG_LEN: u64 = MSG.len() as u64;

        // These writes are all below MAX_COMBINE, so they are coalesced into one segment and
        // `get` returns the full requested range rather than stopping at write boundaries.
        const SEG1: &[u8] = b"He";
        buf.write(SEG1.into());
        const SEG2: &[u8] = b"llo,";
        buf.write(SEG2.into());
        const SEG3: &[u8] = b" w";
        buf.write(SEG3.into());
        const SEG4: &[u8] = b"o";
        buf.write(SEG4.into());
        const SEG5: &[u8] = b"rld!";
        buf.write(SEG5.into());

        assert_eq!(aggregate_unacked(&buf), MSG);

        assert_eq!(buf.poll_transmit(16), (0..8, true));
        assert_eq!(buf.get(0..5), &MSG[0..5]);
        assert_eq!(buf.get(2..8), &MSG[2..8]);
        assert_eq!(buf.get(6..8), &MSG[6..8]);

        assert_eq!(buf.poll_transmit(16), (8..MSG_LEN, true));
        assert_eq!(buf.get(8..MSG_LEN), &MSG[8..]);
        assert_eq!(buf.get(9..MSG_LEN), &MSG[9..]);

        assert_eq!(buf.poll_transmit(42), (MSG_LEN..MSG_LEN, true));

        buf.ack(0..1);
        assert_eq!(aggregate_unacked(&buf), &MSG[1..]);
        buf.ack(0..3);
        assert_eq!(aggregate_unacked(&buf), &MSG[3..]);
        buf.ack(3..5);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(7..9);
        assert_eq!(aggregate_unacked(&buf), &MSG[5..]);
        buf.ack(4..7);
        assert_eq!(aggregate_unacked(&buf), &MSG[9..]);
        buf.ack(0..MSG_LEN);
        assert_eq!(aggregate_unacked(&buf), &[] as &[u8]);
    }

    #[test]
    fn large_writes_stay_separate() {
        let mut buf = SendBuffer::new();
        // Both writes exceed MAX_COMBINE, so each becomes its own segment.
        let first = Bytes::from(vec![1u8; MAX_COMBINE + 1]);
        let second = Bytes::from(vec![2u8; MAX_COMBINE + 1]);
        buf.write(first.clone());
        buf.write(second.clone());
        let split = first.len() as u64;
        let len = split + second.len() as u64;

        // `get` stops at the segment boundary...
        assert_eq!(buf.get(0..len), &first[..]);
        assert_eq!(buf.get(split..len), &second[..]);

        // ...while `get_into` stitches segments together.
        let mut out = Vec::new();
        buf.get_into(split - 2..split + 2, &mut out);
        assert_eq!(out, [1u8, 1, 2, 2]);

        // Discarding across the boundary drops the first segment and trims the second.
        buf.ack(0..split + 1);
        assert_eq!(aggregate_unacked(&buf), &second[1..]);
    }

    #[test]
    fn retransmit() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        // Transmit two frames
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        // Lose the first, but not the second
        buf.retransmit(0..16);
        // Only the lost frame is resent, then fresh data continues
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (23..MSG.len() as u64, true));
        // Lose the second frame
        buf.retransmit(16..23);
        assert_eq!(buf.poll_transmit(16), (16..23, true));
    }

    #[test]
    fn ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..8, true));
        buf.ack(0..8);
        assert_eq!(aggregate_unacked(&buf), &MSG[8..]);
    }

    #[test]
    fn reordered_ack() {
        let mut buf = SendBuffer::new();
        const MSG: &[u8] = b"Hello, world with extra data!";
        buf.write(MSG.into());
        assert_eq!(buf.poll_transmit(16), (0..16, false));
        assert_eq!(buf.poll_transmit(16), (16..23, true));
        // An ack that doesn't start at the front can't discard anything yet
        buf.ack(16..23);
        assert_eq!(aggregate_unacked(&buf), MSG);
        buf.ack(0..16);
        assert_eq!(aggregate_unacked(&buf), &MSG[23..]);
        assert!(buf.acks.is_empty());
    }

    /// All bytes still held by the buffer, in stream order.
    fn aggregate_unacked(buf: &SendBuffer) -> Vec<u8> {
        buf.data.to_vec()
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_out_of_range() {
        let data = SendBufferData::default();
        data.get(0..1);
    }

    #[test]
    #[should_panic(expected = "Requested range is outside of buffered data")]
    fn send_buffer_get_into_out_of_range() {
        let data = SendBufferData::default();
        let mut buf = Vec::new();
        data.get_into(0..1, &mut buf);
    }
}
540
#[cfg(feature = "bench")]
pub mod send_buffer_benches {
    //! Criterion benchmarks for reading a packet-sized range out of a `SendBuffer` filled by
    //! many small writes. They live in this file so they can reach the `pub(super)` API.

    use bytes::Bytes;
    use criterion::Criterion;

    use super::SendBuffer;

    /// Number of writes used to fill the buffer.
    const SEGMENTS: u64 = 10000;
    /// Size of each write in bytes.
    const SEGMENT_SIZE: u64 = 10;
    /// Size of the range read per iteration, roughly one packet.
    const PACKET_SIZE: u64 = 1200;
    /// Total number of bytes written.
    const BYTES: u64 = SEGMENTS * SEGMENT_SIZE;

    /// Builds a buffer from `SEGMENTS` writes of `SEGMENT_SIZE` bytes each.
    ///
    /// Write `i` is filled with the byte `i as u8`; the value wraps deliberately.
    fn filled_buffer() -> SendBuffer {
        let mut buf = SendBuffer::new();
        for i in 0..SEGMENTS {
            buf.write(Bytes::from(vec![i as u8; SEGMENT_SIZE as usize]));
        }
        buf
    }

    /// Benchmarks copying the last `PACKET_SIZE` bytes with a single `get_into` call.
    pub fn get_into_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_into_many_segments");
        let buf = filled_buffer();

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_into", |b| {
            b.iter(|| {
                // Reuse the allocation so only the lookup and copy are measured.
                tgt.clear();
                buf.get_into(BYTES - PACKET_SIZE..BYTES, std::hint::black_box(&mut tgt));
            });
        });
        group.finish();
    }

    /// Benchmarks copying the last `PACKET_SIZE` bytes by calling `get` repeatedly, one
    /// segment-bounded slice at a time.
    pub fn get_loop_many_segments(criterion: &mut Criterion) {
        let mut group = criterion.benchmark_group("get_loop_many_segments");
        let buf = filled_buffer();

        let mut tgt = Vec::with_capacity(PACKET_SIZE as usize);
        group.bench_function("get_loop", |b| {
            b.iter(|| {
                // Reuse the allocation so only the lookups and copies are measured.
                tgt.clear();
                let mut range = BYTES - PACKET_SIZE..BYTES;
                while range.start < range.end {
                    let slice = std::hint::black_box(buf.get(range.clone()));
                    range.start += slice.len() as u64;
                    tgt.extend_from_slice(slice);
                }
            });
        });
        group.finish();
    }
}