iroh_quinn_proto/connection/streams/
send.rs1use bytes::Bytes;
2use thiserror::Error;
3
4use crate::{
5 VarInt,
6 connection::{send_buffer::SendBuffer, streams::BytesOrSlice},
7 frame,
8};
9
10#[derive(Debug)]
11pub(super) struct Send {
12 pub(super) max_data: u64,
13 pub(super) state: SendState,
14 pub(super) pending: SendBuffer,
15 pub(super) priority: i32,
16 pub(super) fin_pending: bool,
18 pub(super) connection_blocked: bool,
20 pub(super) stop_reason: Option<VarInt>,
22}
23
24impl Send {
25 pub(super) fn new(max_data: VarInt) -> Box<Self> {
26 Box::new(Self {
27 max_data: max_data.into(),
28 state: SendState::Ready,
29 pending: SendBuffer::new(),
30 priority: 0,
31 fin_pending: false,
32 connection_blocked: false,
33 stop_reason: None,
34 })
35 }
36
37 pub(super) fn is_reset(&self) -> bool {
39 matches!(self.state, SendState::ResetSent)
40 }
41
42 pub(super) fn finish(&mut self) -> Result<(), FinishError> {
43 if let Some(error_code) = self.stop_reason {
44 Err(FinishError::Stopped(error_code))
45 } else if self.state == SendState::Ready {
46 self.state = SendState::DataSent {
47 finish_acked: false,
48 };
49 self.fin_pending = true;
50 Ok(())
51 } else {
52 Err(FinishError::ClosedStream)
53 }
54 }
55
56 pub(super) fn write<'a, S: BytesSource<'a>>(
57 &mut self,
58 source: &'a mut S,
59 limit: u64,
60 ) -> Result<Written, WriteError> {
61 if !self.is_writable() {
62 return Err(WriteError::ClosedStream);
63 }
64 if let Some(error_code) = self.stop_reason {
65 return Err(WriteError::Stopped(error_code));
66 }
67 let budget = self.max_data - self.pending.offset();
68 if budget == 0 {
69 return Err(WriteError::Blocked);
70 }
71 let mut limit = limit.min(budget) as usize;
72
73 let mut result = Written::default();
74 loop {
75 let (chunk, chunks_consumed) = source.pop_chunk(limit);
76 result.chunks += chunks_consumed;
77 result.bytes += chunk.len();
78
79 if chunk.is_empty() {
80 break;
81 }
82
83 limit -= chunk.len();
84 self.pending.write(chunk);
85 }
86
87 Ok(result)
88 }
89
90 pub(super) fn reset(&mut self) {
92 use SendState::*;
93 if let DataSent { .. } | Ready = self.state {
94 self.state = ResetSent;
95 }
96 }
97
98 pub(super) fn try_stop(&mut self, error_code: VarInt) -> bool {
103 if self.stop_reason.is_none() {
104 self.stop_reason = Some(error_code);
105 true
106 } else {
107 false
108 }
109 }
110
111 pub(super) fn ack(&mut self, frame: frame::StreamMeta) -> bool {
113 self.pending.ack(frame.offsets);
114 match self.state {
115 SendState::DataSent {
116 ref mut finish_acked,
117 } => {
118 *finish_acked |= frame.fin;
119 *finish_acked && self.pending.is_fully_acked()
120 }
121 _ => false,
122 }
123 }
124
125 pub(super) fn increase_max_data(&mut self, offset: u64) -> bool {
129 if offset <= self.max_data || self.state != SendState::Ready {
130 return false;
131 }
132 let was_blocked = self.pending.offset() == self.max_data;
133 self.max_data = offset;
134 was_blocked
135 }
136
137 pub(super) fn offset(&self) -> u64 {
138 self.pending.offset()
139 }
140
141 pub(super) fn is_pending(&self) -> bool {
142 self.pending.has_unsent_data() || self.fin_pending
143 }
144
145 pub(super) fn is_writable(&self) -> bool {
146 matches!(self.state, SendState::Ready)
147 }
148}
149
150pub(crate) struct BytesArray<'a> {
155 chunks: &'a mut [Bytes],
157 consumed: usize,
159}
160
161impl<'a> BytesArray<'a> {
162 pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
163 Self {
164 chunks,
165 consumed: 0,
166 }
167 }
168}
169
170impl<'a> BytesSource<'a> for BytesArray<'a> {
171 fn pop_chunk<'b>(&'b mut self, limit: usize) -> (impl BytesOrSlice<'b>, usize)
172 where
173 'a: 'b,
174 {
175 let mut chunks_consumed = 0;
178
179 while self.consumed < self.chunks.len() {
180 let chunk = &mut self.chunks[self.consumed];
181
182 if chunk.len() <= limit {
183 let chunk = std::mem::take(chunk);
184 self.consumed += 1;
185 chunks_consumed += 1;
186 if chunk.is_empty() {
187 continue;
188 }
189 return (chunk, chunks_consumed);
190 } else if limit > 0 {
191 let chunk = chunk.split_to(limit);
192 return (chunk, chunks_consumed);
193 } else {
194 break;
195 }
196 }
197
198 (Bytes::new(), chunks_consumed)
199 }
200}
201
202pub(crate) struct ByteSlice<'a> {
208 data: &'a [u8],
210}
211
212impl<'a> ByteSlice<'a> {
213 pub(crate) fn from_slice(data: &'a [u8]) -> Self {
214 Self { data }
215 }
216}
217
218impl<'a> BytesSource<'a> for ByteSlice<'a> {
219 fn pop_chunk<'b>(&'b mut self, limit: usize) -> (impl BytesOrSlice<'b>, usize)
220 where
221 'a: 'b,
222 {
223 let limit = limit.min(self.data.len());
224 if limit == 0 {
225 return (&[][..], 0);
226 }
227
228 let chunk = &self.data[..limit];
229 self.data = &self.data[chunk.len()..];
230
231 let chunks_consumed = usize::from(self.data.is_empty());
232 (chunk, chunks_consumed)
233 }
234}
235
236pub(super) trait BytesSource<'a> {
241 fn pop_chunk<'b>(&'b mut self, limit: usize) -> (impl BytesOrSlice<'b>, usize)
254 where
255 'a: 'b;
256}
257
/// Result of a successful write: how much was taken from the source.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Written {
    /// Number of bytes moved into the stream's send buffer.
    pub bytes: usize,
    /// Number of source chunks that were fully consumed.
    pub chunks: usize,
}
268
269#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
271pub enum WriteError {
272 #[error("unable to accept further writes")]
279 Blocked,
280 #[error("stopped by peer: code {0}")]
287 Stopped(VarInt),
288 #[error("closed stream")]
290 ClosedStream,
291}
292
293#[derive(Debug, Copy, Clone, Eq, PartialEq)]
294pub(super) enum SendState {
295 Ready,
297 DataSent { finish_acked: bool },
299 ResetSent,
301}
302
303#[derive(Debug, Error, Clone, PartialEq, Eq)]
305pub enum FinishError {
306 #[error("stopped by peer: code {0}")]
313 Stopped(VarInt),
314 #[error("closed stream")]
316 ClosedStream,
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains a `BytesArray` under every limit from 0 up to and including the full payload
    /// length. Checks the returned bytes and the consumed/popped chunk accounting.
    #[test]
    fn bytes_array() {
        let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
        // Inclusive upper bound: with `0..full.len()` the `limit == full.len()` branch below
        // could never run, so the fully-consumed case was never asserted.
        for limit in 0..=full.len() {
            let mut chunks = [
                Bytes::from_static(b""),
                Bytes::from_static(b"Hello "),
                Bytes::from_static(b"Wo"),
                Bytes::from_static(b""),
                Bytes::from_static(b"r"),
                Bytes::from_static(b"ld"),
                Bytes::from_static(b""),
                Bytes::from_static(b" 12345678"),
                Bytes::from_static(b"9 ABCDE"),
                Bytes::from_static(b"F"),
                Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"),
            ];
            let num_chunks = chunks.len();
            let last_chunk_len = chunks[chunks.len() - 1].len();

            let mut array = BytesArray::from_chunks(&mut chunks);

            let mut buf = Vec::new();
            let mut chunks_popped = 0;
            let mut chunks_consumed = 0;
            let mut remaining = limit;
            loop {
                let (chunk, consumed) = array.pop_chunk(remaining);
                chunks_consumed += consumed;

                if !chunk.is_empty() {
                    buf.extend_from_slice(chunk.as_ref());
                    remaining -= chunk.len();
                    chunks_popped += 1;
                } else {
                    break;
                }
            }

            assert_eq!(&buf[..], &full[..limit]);

            if limit == full.len() {
                // Every chunk is consumed; the three empty chunks are consumed without
                // ever being popped.
                assert_eq!(chunks_consumed, num_chunks);
                assert_eq!(chunks_consumed, chunks_popped + 3);
            } else if limit > full.len() - last_chunk_len {
                // The last chunk is only partially taken: it is popped but not consumed,
                // while the three empty chunks are consumed but not popped.
                assert_eq!(chunks_consumed, num_chunks - 1);
                assert_eq!(chunks_consumed, chunks_popped + 2);
            }
        }
    }

    /// Drains a `ByteSlice` under every limit from 0 up to and including the full payload
    /// length. The slice is a single chunk that is consumed only when fully drained.
    #[test]
    fn byte_slice() {
        let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
        // Inclusive upper bound so that the `limit == full.len()` branch is exercised.
        for limit in 0..=full.len() {
            let mut array = ByteSlice::from_slice(&full[..]);

            let mut buf = Vec::new();
            let mut chunks_popped = 0;
            let mut chunks_consumed = 0;
            let mut remaining = limit;
            loop {
                let (chunk, consumed) = array.pop_chunk(remaining);
                chunks_consumed += consumed;

                if !chunk.is_empty() {
                    buf.extend_from_slice(chunk.as_ref());
                    remaining -= chunk.len();
                    chunks_popped += 1;
                } else {
                    break;
                }
            }

            assert_eq!(&buf[..], &full[..limit]);
            if limit != 0 {
                assert_eq!(chunks_popped, 1);
            } else {
                assert_eq!(chunks_popped, 0);
            }

            if limit == full.len() {
                assert_eq!(chunks_consumed, 1);
            } else {
                assert_eq!(chunks_consumed, 0);
            }
        }
    }
}