noq_proto/connection/streams/
mod.rs1use std::{
2 collections::{BinaryHeap, hash_map},
3 io,
4};
5
6use bytes::Bytes;
7use thiserror::Error;
8use tracing::trace;
9
10use super::spaces::Retransmits;
11use crate::{
12 Dir, StreamId, VarInt,
13 connection::streams::state::{get_or_insert_recv, get_or_insert_send},
14 frame,
15};
16
17mod recv;
18use recv::Recv;
19pub use recv::{Chunks, ReadError, ReadableError};
20
21mod send;
22pub(crate) use send::{ByteSlice, BytesArray};
23use send::{BytesSource, Send, SendState};
24pub use send::{FinishError, WriteError};
25pub(crate) use send::Written;
26
27mod state;
28#[allow(unreachable_pub)] pub use state::StreamsState;
30
31pub struct Streams<'a> {
33 pub(super) state: &'a mut StreamsState,
34 pub(super) conn_state: &'a super::State,
35}
36
37#[allow(clippy::needless_lifetimes)] impl<'a> Streams<'a> {
39 #[cfg(fuzzing)]
40 pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
41 Self { state, conn_state }
42 }
43
44 pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
48 if self.conn_state.is_closed() {
49 return None;
50 }
51
52 if self.state.next[dir as usize] >= self.state.max[dir as usize] {
53 self.state.streams_blocked[dir as usize] = true;
54 return None;
55 }
56
57 self.state.next[dir as usize] += 1;
58 let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1);
59 self.state.insert(false, id);
60 self.state.send_streams += 1;
61 Some(id)
62 }
63
64 pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
69 if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] {
70 return None;
71 }
72
73 let x = self.state.next_reported_remote[dir as usize];
74 self.state.next_reported_remote[dir as usize] = x + 1;
75 if dir == Dir::Bi {
76 self.state.send_streams += 1;
77 }
78
79 Some(StreamId::new(!self.state.side, dir, x))
80 }
81
82 #[cfg(fuzzing)]
83 pub fn state(&mut self) -> &mut StreamsState {
84 self.state
85 }
86
87 pub fn send_streams(&self) -> usize {
89 self.state.send_streams
90 }
91
92 pub fn remote_open_streams(&self, dir: Dir) -> u64 {
98 self.state.next_remote[dir as usize]
100 - (self.state.max_remote[dir as usize]
101 - self.state.allocated_remote_count[dir as usize])
102 }
103}
104
105pub struct RecvStream<'a> {
107 pub(super) id: StreamId,
108 pub(super) state: &'a mut StreamsState,
109 pub(super) pending: &'a mut Retransmits,
110}
111
112impl RecvStream<'_> {
113 pub fn read(&mut self, ordered: bool) -> Result<Chunks<'_>, ReadableError> {
130 Chunks::new(self.id, ordered, self.state, self.pending)
131 }
132
133 pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
138 let mut entry = match self.state.recv.entry(self.id) {
139 hash_map::Entry::Occupied(s) => s,
140 hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }),
141 };
142 let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut());
143
144 let (read_credits, stop_sending) = stream.stop()?;
145 if stop_sending.should_transmit() {
146 self.pending.stop_sending.push(frame::StopSending {
147 id: self.id,
148 error_code,
149 });
150 }
151
152 if !stream.final_offset_unknown() {
156 let recv = entry.remove().expect("must have recv when stopping");
157 self.state.stream_recv_freed(self.id, recv);
158 }
159
160 if self.state.add_read_credits(read_credits).should_transmit() {
161 self.pending.max_data = true;
162 }
163
164 Ok(())
165 }
166
167 pub fn bytes_read(&self) -> Result<u64, ClosedStream> {
172 let recv = self
173 .state
174 .recv
175 .get(&self.id)
176 .and_then(|s| s.as_ref())
177 .and_then(|s| s.as_open_recv())
178 .ok_or(ClosedStream { _private: () })?;
179 Ok(recv.assembler.bytes_read())
180 }
181
182 pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> {
187 let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else {
188 return Err(ClosedStream { _private: () });
189 };
190 let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else {
191 return Ok(None);
192 };
193 if s.stopped {
194 return Err(ClosedStream { _private: () });
195 }
196 let Some(code) = s.reset_code() else {
197 return Ok(None);
198 };
199
200 let (_, recv) = entry.remove_entry();
203 self.state
204 .stream_recv_freed(self.id, recv.expect("must have recv on reset"));
205 self.state.queue_max_stream_id(self.pending);
206
207 Ok(Some(code))
208 }
209}
210
211pub struct SendStream<'a> {
213 pub(super) id: StreamId,
214 pub(super) state: &'a mut StreamsState,
215 pub(super) pending: &'a mut Retransmits,
216 pub(super) conn_state: &'a super::State,
217}
218
219#[allow(clippy::needless_lifetimes)] impl<'a> SendStream<'a> {
221 #[cfg(fuzzing)]
222 pub fn new(
223 id: StreamId,
224 state: &'a mut StreamsState,
225 pending: &'a mut Retransmits,
226 conn_state: &'a super::State,
227 ) -> Self {
228 Self {
229 id,
230 state,
231 pending,
232 conn_state,
233 }
234 }
235
236 pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
240 Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
241 }
242
243 pub fn write_chunks(&mut self, data: &mut &mut [Bytes]) -> Result<usize, WriteError> {
251 let written = self.write_source(&mut BytesArray::from_chunks(data))?;
252 *data = &mut std::mem::take(data)[written.chunks..];
253 Ok(written.bytes)
254 }
255
256 fn write_source<'b, B: BytesSource<'b>>(
257 &mut self,
258 source: &'b mut B,
259 ) -> Result<Written, WriteError> {
260 if self.conn_state.is_closed() {
261 trace!(%self.id, "write blocked; connection draining");
262 return Err(WriteError::Blocked);
263 }
264
265 let limit = self.state.write_limit();
266
267 let max_send_data = self.state.max_send_data(self.id);
268
269 let stream = self
270 .state
271 .send
272 .get_mut(&self.id)
273 .map(get_or_insert_send(max_send_data))
274 .ok_or(WriteError::ClosedStream)?;
275
276 if limit == 0 {
277 trace!(
278 stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent,
279 "write blocked by connection-level flow control or send window"
280 );
281 if !stream.connection_blocked {
282 stream.connection_blocked = true;
283 self.state.connection_blocked.push(self.id);
284 }
285 return Err(WriteError::Blocked);
286 }
287
288 let was_pending = stream.is_pending();
289 let written = stream.write(source, limit)?;
290 self.state.data_sent += written.bytes as u64;
291 self.state.unacked_data += written.bytes as u64;
292 trace!(stream = %self.id, "wrote {} bytes", written.bytes);
293 if !was_pending {
294 self.state.pending.push_pending(self.id, stream.priority);
295 }
296 Ok(written)
297 }
298
299 pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> {
301 match self.state.send.get(&self.id).as_ref() {
302 Some(Some(s)) => Ok(s.stop_reason),
303 Some(None) => Ok(None),
304 None => Err(ClosedStream { _private: () }),
305 }
306 }
307
308 pub fn finish(&mut self) -> Result<(), FinishError> {
314 let max_send_data = self.state.max_send_data(self.id);
315 let stream = self
316 .state
317 .send
318 .get_mut(&self.id)
319 .map(get_or_insert_send(max_send_data))
320 .ok_or(FinishError::ClosedStream)?;
321
322 let was_pending = stream.is_pending();
323 stream.finish()?;
324 if !was_pending {
325 self.state.pending.push_pending(self.id, stream.priority);
326 }
327
328 Ok(())
329 }
330
331 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
336 let max_send_data = self.state.max_send_data(self.id);
337 let stream = self
338 .state
339 .send
340 .get_mut(&self.id)
341 .map(get_or_insert_send(max_send_data))
342 .ok_or(ClosedStream { _private: () })?;
343
344 if matches!(stream.state, SendState::ResetSent) {
345 return Err(ClosedStream { _private: () });
347 }
348
349 self.state.unacked_data -= stream.pending.unacked();
353 stream.reset();
354 self.pending.reset_stream.push((self.id, error_code));
355
356 Ok(())
358 }
359
360 pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> {
365 let max_send_data = self.state.max_send_data(self.id);
366 let stream = self
367 .state
368 .send
369 .get_mut(&self.id)
370 .map(get_or_insert_send(max_send_data))
371 .ok_or(ClosedStream { _private: () })?;
372
373 stream.priority = priority;
374 Ok(())
375 }
376
377 pub fn priority(&self) -> Result<i32, ClosedStream> {
382 let stream = self
383 .state
384 .send
385 .get(&self.id)
386 .ok_or(ClosedStream { _private: () })?;
387
388 Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default())
389 }
390}
391
392struct PendingStreamsQueue {
394 streams: BinaryHeap<PendingStream>,
395 next: Option<PendingStream>,
398 recency: u64,
401}
402
403impl PendingStreamsQueue {
404 fn new() -> Self {
405 Self {
406 streams: BinaryHeap::new(),
407 next: None,
408 recency: u64::MAX,
409 }
410 }
411
412 fn reinsert_pending(&mut self, id: StreamId, priority: i32) {
414 assert!(self.next.is_none());
415
416 self.next = Some(PendingStream {
417 priority,
418 recency: self.recency, id,
420 });
421 }
422
423 fn push_pending(&mut self, id: StreamId, priority: i32) {
425 self.recency -= 1;
434 self.streams.push(PendingStream {
435 priority,
436 recency: self.recency,
437 id,
438 });
439 }
440
441 fn pop(&mut self) -> Option<PendingStream> {
442 self.next.take().or_else(|| self.streams.pop())
443 }
444
445 fn clear(&mut self) {
446 self.next = None;
447 self.streams.clear();
448 }
449
450 fn iter(&self) -> impl Iterator<Item = &PendingStream> {
451 self.next.iter().chain(self.streams.iter())
452 }
453
454 #[cfg(test)]
455 fn len(&self) -> usize {
456 self.streams.len() + self.next.is_some() as usize
457 }
458}
459
/// One entry of the pending-streams queue.
///
/// The derived `Ord` compares fields lexicographically in declaration order, so the field order
/// below is load-bearing. Higher `priority` wins, then higher `recency`, then `id` as the final
/// tiebreak. Do not reorder these fields.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
struct PendingStream {
    /// Scheduling priority; larger values are popped first from the max-heap.
    priority: i32,
    /// Stamp from the queue's down-counting `recency` counter. Among equal priorities, a larger
    /// value (an earlier push) is popped first.
    recency: u64,
    /// The stream this entry refers to; compared last.
    id: StreamId,
}
478
479#[derive(Debug, PartialEq, Eq)]
481pub enum StreamEvent {
482 Opened {
484 dir: Dir,
486 },
487 Readable {
489 id: StreamId,
491 },
492 Writable {
496 id: StreamId,
498 },
499 Finished {
501 id: StreamId,
503 },
504 Stopped {
506 id: StreamId,
508 error_code: VarInt,
510 },
511 Available {
513 dir: Dir,
515 },
516}
517
/// Signal returned by state changes indicating whether a frame should now be queued.
///
/// Marked `#[must_use]` so the signal cannot be dropped silently.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);

impl ShouldTransmit {
    /// Returns `true` if the caller should enqueue the associated frame.
    pub fn should_transmit(self) -> bool {
        let Self(transmit) = self;
        transmit
    }
}
532
533#[derive(Debug, Default, Error, Clone, PartialEq, Eq)]
535#[error("closed stream")]
536pub struct ClosedStream {
537 _private: (),
538}
539
540impl From<ClosedStream> for io::Error {
541 fn from(x: ClosedStream) -> Self {
542 Self::new(io::ErrorKind::NotConnected, x)
543 }
544}
545
/// Identifies one half (direction) of a stream.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum StreamHalf {
    /// The sending half.
    Send,
    /// The receiving half.
    Recv,
}
551
552pub(super) trait BytesOrSlice<'a>: AsRef<[u8]> + 'a {
554 fn len(&self) -> usize {
555 self.as_ref().len()
556 }
557 fn is_empty(&self) -> bool {
558 self.as_ref().is_empty()
559 }
560 fn into_bytes(self) -> Bytes;
561}
562
563impl BytesOrSlice<'_> for Bytes {
564 fn into_bytes(self) -> Bytes {
565 self
566 }
567}
568
569impl<'a> BytesOrSlice<'a> for &'a [u8] {
570 fn into_bytes(self) -> Bytes {
571 Bytes::copy_from_slice(self)
572 }
573}