noq/send_stream.rs
1use std::{
2 future::{Future, poll_fn},
3 io,
4 pin::{Pin, pin},
5 task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use pin_project_lite::pin_project;
10use proto::{ClosedStream, ConnectionError, FinishError, StreamId};
11use thiserror::Error;
12use tokio::sync::futures::OwnedNotified;
13
14use crate::{
15 VarInt,
16 connection::{ConnectionRef, State},
17};
18
19/// A stream that can only be used to send data
20///
21/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
22/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
23/// connection is closed.
24///
25/// # Cancellation
26///
27/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
28/// ready will always result in no data being written to the stream. This is true of methods which
29/// succeed immediately when any progress is made, and is not true of methods which might need to
30/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
31/// cancel-safe.
32///
33/// [`reset()`]: SendStream::reset
34/// [`finish()`]: SendStream::finish
35#[derive(Debug)]
36pub struct SendStream {
37 conn: ConnectionRef,
38 stream: StreamId,
39 is_0rtt: bool,
40}
41
42impl SendStream {
43 pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
44 Self {
45 conn,
46 stream,
47 is_0rtt,
48 }
49 }
50
51 /// Write a buffer into this stream, returning how many bytes were written
52 ///
53 /// Unless this method errors, it waits until some amount of `buf` can be written into this
54 /// stream, and then writes as much as it can without waiting again. Due to congestion and flow
55 /// control, this may be shorter than `buf.len()`. On success this yields the length of the
56 /// prefix that was written.
57 ///
58 /// # Cancel safety
59 ///
60 /// This method is cancellation safe. If this does not resolve, no bytes were written.
61 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
62 poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
63 }
64
65 /// Write a buffer into this stream in its entirety
66 ///
67 /// This method repeatedly calls [`write`](Self::write) until all bytes are written, or an
68 /// error occurs.
69 ///
70 /// # Cancel safety
71 ///
72 /// This method is *not* cancellation safe. Even if this does not resolve, some prefix of `buf`
73 /// may have been written when previously polled.
74 pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
75 while !buf.is_empty() {
76 let written = self.write(buf).await?;
77 buf = &buf[written..];
78 }
79 Ok(())
80 }
81
82 /// Writes [`Bytes`] from a slice of buffers into this stream, returning how many bytes were.
83 /// written
84 ///
85 /// Bytes to try to write are provided to this method as an array of cheaply cloneable chunks.
86 /// Unless this method errors, it waits until some amount of those bytes can be written into
87 /// this stream, and then writes as much as it can without waiting again. Due to congestion and
88 /// flow control, this may be less than the total number of bytes.
89 ///
90 /// On success, this method both mutates `bufs` and returns the number of bytes written:
91 ///
92 /// - `bufs` is advanced past chunks that were fully written.
93 /// - If a [`Bytes`] chunk was partially written, the chunk at the new front of `bufs` is
94 /// [split to](Bytes::split_to) contain only the suffix of bytes that were not written.
95 ///
96 /// # Cancel safety
97 ///
98 /// This method is cancellation safe. If this does not resolve, no bytes were written.
99 pub async fn write_many_chunks(
100 &mut self,
101 bufs: &mut &mut [Bytes],
102 ) -> Result<usize, WriteError> {
103 poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
104 }
105
106 /// Writes a single [`Bytes`] into this stream in its entirety.
107 ///
108 /// Bytes to write are provided to this method as a single cheaply cloneable chunk. This
109 /// method repeatedly calls [`write_many_chunks`](Self::write_many_chunks) until all bytes
110 /// are written, or an error occurs.
111 ///
112 /// # Cancel safety
113 ///
114 /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
115 /// been written when previously polled.
116 pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
117 self.write_all_chunks(&mut [buf]).await
118 }
119
120 /// Writes a slice of [`Bytes`] into this stream in its entirety.
121 ///
122 /// Bytes to write are provided to this method as an array of cheaply cloneable chunks. This
123 /// method repeatedly calls [`write_many_chunks`](Self::write_many_chunks) until all bytes are
124 /// written, or an error occurs.
125 ///
126 /// # Cancel safety
127 ///
128 /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
129 /// been written when previously polled.
130 pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
131 let mut bufs = &mut bufs[..];
132 while !bufs.is_empty() {
133 self.write_many_chunks(&mut bufs).await?;
134 }
135 Ok(())
136 }
137
138 fn execute_poll<F, R>(
139 &mut self,
140 cx: &mut Context<'_>,
141 write_fn: F,
142 ) -> Poll<Result<R, WriteError>>
143 where
144 F: FnOnce(&mut proto::SendStream<'_>) -> Result<R, proto::WriteError>,
145 {
146 use proto::WriteError::*;
147 let mut conn = self.conn.lock_and_wake("SendStream::poll_write");
148 if self.is_0rtt && conn.check_0rtt().is_err() {
149 conn.skip_waking();
150 return Poll::Ready(Err(WriteError::ZeroRttRejected));
151 }
152 if let Some(conn_err) = conn.error.clone() {
153 conn.skip_waking();
154 return Poll::Ready(Err(WriteError::ConnectionLost(conn_err)));
155 }
156
157 let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
158 Ok(result) => result,
159 Err(Blocked) => {
160 conn.blocked_writers.insert(self.stream, cx.waker().clone());
161 conn.skip_waking();
162 return Poll::Pending;
163 }
164 Err(Stopped(error_code)) => {
165 conn.skip_waking();
166 return Poll::Ready(Err(WriteError::Stopped(error_code)));
167 }
168 Err(ClosedStream) => {
169 conn.skip_waking();
170 return Poll::Ready(Err(WriteError::ClosedStream));
171 }
172 };
173
174 Poll::Ready(Ok(result))
175 }
176
177 /// Notify the peer that no more data will ever be written to this stream
178 ///
179 /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
180 /// may still be called after `finish` to abandon transmission of any stream data that might
181 /// still be buffered.
182 ///
183 /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
184 ///
185 /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
186 /// called. This error is harmless and serves only to indicate that the caller may have
187 /// incorrect assumptions about the stream's state.
188 pub fn finish(&mut self) -> Result<(), ClosedStream> {
189 let mut conn = self.conn.lock_and_wake("finish");
190 if let Err(e) = conn.inner.send_stream(self.stream).finish() {
191 conn.skip_waking();
192 match e {
193 FinishError::ClosedStream => Err(ClosedStream::default()),
194 // Harmless. If the application needs to know about stopped streams at this point, it
195 // should call `stopped`.
196 FinishError::Stopped(_) => Ok(()),
197 }
198 } else {
199 Ok(())
200 }
201 }
202
203 /// Close the send stream immediately.
204 ///
205 /// No new data can be written after calling this method. Locally buffered data is dropped, and
206 /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
207 /// already been made to finish the stream, the peer may still receive all written data.
208 ///
209 /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
210 /// called. This error is harmless and serves only to indicate that the caller may have
211 /// incorrect assumptions about the stream's state.
212 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
213 let mut conn = self.conn.lock_and_wake("SendStream::reset");
214 if self.is_0rtt && conn.check_0rtt().is_err() {
215 conn.skip_waking();
216 return Ok(());
217 }
218 conn.inner.send_stream(self.stream).reset(error_code)?;
219 Ok(())
220 }
221
222 /// Set the priority of the send stream
223 ///
224 /// Every send stream has an initial priority of 0. Locally buffered data from streams with
225 /// higher priority will be transmitted before data from streams with lower priority. Changing
226 /// the priority of a stream with pending data may only take effect after that data has been
227 /// transmitted. Using many different priority levels per connection may have a negative
228 /// impact on performance.
229 pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
230 let mut conn = self.conn.lock_without_waking("SendStream::set_priority");
231 conn.inner.send_stream(self.stream).set_priority(priority)?;
232 Ok(())
233 }
234
235 /// Get the priority of the send stream
236 pub fn priority(&self) -> Result<i32, ClosedStream> {
237 let mut conn = self.conn.lock_without_waking("SendStream::priority");
238 conn.inner.send_stream(self.stream).priority()
239 }
240
241 /// Completes when the peer stops the stream or reads the stream to completion
242 ///
243 /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
244 /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
245 /// of all stream data (although not necessarily the processing of it), after which the peer
246 /// closing the stream is no longer meaningful.
247 ///
248 /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
249 /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
250 /// may introduce more latency than using an application-level response of some sort.
251 pub fn stopped(&self) -> Stopped {
252 let notified = {
253 // Create an `OwnedNotified` to move into the future. By creating it before the first poll,
254 // we make sure that we don't miss any notifications.
255 let mut conn = self.conn.lock_without_waking("SendStream::stopped");
256 conn.stopped
257 .entry(self.stream)
258 .or_default()
259 .clone()
260 .notified_owned()
261 };
262 Stopped {
263 conn: self.conn.clone(),
264 stream: self.stream,
265 is_0rtt: self.is_0rtt,
266 notified,
267 }
268 }
269
270 /// Get the identity of this stream
271 pub fn id(&self) -> StreamId {
272 self.stream
273 }
274
275 /// Attempt to write bytes from buf into the stream.
276 ///
277 /// On success, returns Poll::Ready(Ok(num_bytes_written)).
278 ///
279 /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
280 /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
281 /// stream becomes writable or is closed.
282 pub fn poll_write(
283 self: Pin<&mut Self>,
284 cx: &mut Context<'_>,
285 buf: &[u8],
286 ) -> Poll<Result<usize, WriteError>> {
287 pin!(self.get_mut().write(buf)).as_mut().poll(cx)
288 }
289}
290
291/// Check if a send stream is stopped.
292///
293/// Returns `Some` if the stream is stopped or the connection is closed.
294/// Returns `None` if the stream is not stopped.
295fn send_stream_stopped(
296 conn: &mut State,
297 stream: StreamId,
298 is_0rtt: bool,
299) -> Option<Result<Option<VarInt>, StoppedError>> {
300 if is_0rtt && conn.check_0rtt().is_err() {
301 return Some(Err(StoppedError::ZeroRttRejected));
302 }
303 match conn.inner.send_stream(stream).stopped() {
304 Err(ClosedStream { .. }) => Some(Ok(None)),
305 Ok(Some(error_code)) => Some(Ok(Some(error_code))),
306 Ok(None) => conn.error.clone().map(|error| Err(error.into())),
307 }
308}
309
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    /// Forwards to the inherent [`SendStream::poll_write`], converting the error to `io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        SendStream::poll_write(self, cx, buf).map_err(io::Error::from)
    }

    /// Always ready: `poll_write` hands data straight to the connection state, and this type
    /// keeps no buffer of its own to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream without waiting for the peer to acknowledge the data.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let finished = self.get_mut().finish();
        Poll::Ready(finished.map_err(io::Error::from))
    }
}
328
329impl tokio::io::AsyncWrite for SendStream {
330 fn poll_write(
331 self: Pin<&mut Self>,
332 cx: &mut Context<'_>,
333 buf: &[u8],
334 ) -> Poll<io::Result<usize>> {
335 self.poll_write(cx, buf).map_err(Into::into)
336 }
337
338 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
339 Poll::Ready(Ok(()))
340 }
341
342 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
343 Poll::Ready(self.get_mut().finish().map_err(Into::into))
344 }
345}
346
347impl Drop for SendStream {
348 fn drop(&mut self) {
349 let mut conn = self.conn.lock_and_wake("SendStream::drop");
350
351 // clean up any previously registered wakers
352 conn.blocked_writers.remove(&self.stream);
353
354 if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
355 conn.skip_waking();
356 return;
357 }
358 match conn.inner.send_stream(self.stream).finish() {
359 Ok(()) => {}
360 Err(FinishError::Stopped(reason)) => {
361 if conn.inner.send_stream(self.stream).reset(reason).is_err() {
362 conn.skip_waking()
363 }
364 }
365 // Already finished or reset, which is fine.
366 Err(FinishError::ClosedStream) => {
367 conn.skip_waking();
368 }
369 }
370 }
371}
372
373/// Errors that arise from writing to a stream
374#[derive(Debug, Error, Clone, PartialEq, Eq)]
375pub enum WriteError {
376 /// The peer is no longer accepting data on this stream
377 ///
378 /// Carries an application-defined error code.
379 #[error("sending stopped by peer: error {0}")]
380 Stopped(VarInt),
381 /// The connection was lost
382 #[error("connection lost")]
383 ConnectionLost(#[from] ConnectionError),
384 /// The stream has already been finished or reset
385 #[error("closed stream")]
386 ClosedStream,
387 /// This was a 0-RTT stream and the server rejected it
388 ///
389 /// Can only occur on clients for 0-RTT streams, which can be opened using
390 /// [`Connecting::into_0rtt()`].
391 ///
392 /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
393 #[error("0-RTT rejected")]
394 ZeroRttRejected,
395}
396
397impl From<ClosedStream> for WriteError {
398 #[inline]
399 fn from(_: ClosedStream) -> Self {
400 Self::ClosedStream
401 }
402}
403
404impl From<StoppedError> for WriteError {
405 fn from(x: StoppedError) -> Self {
406 match x {
407 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
408 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
409 }
410 }
411}
412
413impl From<WriteError> for io::Error {
414 fn from(x: WriteError) -> Self {
415 use WriteError::*;
416 let kind = match x {
417 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
418 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
419 };
420 Self::new(kind, x)
421 }
422}
423
424/// Errors that arise while monitoring for a send stream stop from the peer
425#[derive(Debug, Error, Clone, PartialEq, Eq)]
426pub enum StoppedError {
427 /// The connection was lost
428 #[error("connection lost")]
429 ConnectionLost(#[from] ConnectionError),
430 /// This was a 0-RTT stream and the server rejected it
431 ///
432 /// Can only occur on clients for 0-RTT streams, which can be opened using
433 /// [`Connecting::into_0rtt()`].
434 ///
435 /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
436 #[error("0-RTT rejected")]
437 ZeroRttRejected,
438}
439
440impl From<StoppedError> for io::Error {
441 fn from(x: StoppedError) -> Self {
442 use StoppedError::*;
443 let kind = match x {
444 ZeroRttRejected => io::ErrorKind::ConnectionReset,
445 ConnectionLost(_) => io::ErrorKind::NotConnected,
446 };
447 Self::new(kind, x)
448 }
449}
450
451pin_project! {
452 /// Future returned from [`SendStream::stopped`].
453 #[derive(Debug)]
454 pub struct Stopped {
455 conn: ConnectionRef,
456 stream: StreamId,
457 is_0rtt: bool,
458 #[pin]
459 notified: OwnedNotified,
460 }
461}
462
463impl Future for Stopped {
464 type Output = Result<Option<VarInt>, StoppedError>;
465
466 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
467 let mut this = self.project();
468 loop {
469 let mut conn = this.conn.lock_without_waking("SendStream::stopped");
470 // Check if the stream is stopped before polling the notify. This makes sure that
471 // no wakeups are missed.
472 if let Some(output) = send_stream_stopped(&mut conn, *this.stream, *this.is_0rtt) {
473 return Poll::Ready(output);
474 }
475 std::task::ready!(this.notified.as_mut().poll(cx));
476 }
477 }
478}
479
#[cfg(test)]
mod tests {
    /// Compiles only if `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    // Never executed: this function only has to type-check.
    #[allow(dead_code)]
    fn test_bounds() {
        assert_send_sync::<super::Stopped>();
    }
}