noq/
send_stream.rs

1use std::{
2    future::{Future, poll_fn},
3    io,
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use pin_project_lite::pin_project;
10use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
11use thiserror::Error;
12use tokio::sync::futures::OwnedNotified;
13
14use crate::{
15    VarInt,
16    connection::{ConnectionRef, State},
17};
18
19/// A stream that can only be used to send data
20///
21/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
22/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
23/// connection is closed.
24///
25/// # Cancellation
26///
27/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
28/// ready will always result in no data being written to the stream. This is true of methods which
29/// succeed immediately when any progress is made, and is not true of methods which might need to
30/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
31/// cancel-safe.
32///
33/// [`reset()`]: SendStream::reset
34/// [`finish()`]: SendStream::finish
35#[derive(Debug)]
36pub struct SendStream {
37    conn: ConnectionRef,
38    stream: StreamId,
39    is_0rtt: bool,
40}
41
42impl SendStream {
43    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
44        Self {
45            conn,
46            stream,
47            is_0rtt,
48        }
49    }
50
51    /// Write a buffer into this stream, returning how many bytes were written
52    ///
53    /// Unless this method errors, it waits until some amount of `buf` can be written into this
54    /// stream, and then writes as much as it can without waiting again. Due to congestion and flow
55    /// control, this may be shorter than `buf.len()`. On success this yields the length of the
56    /// prefix that was written.
57    ///
58    /// # Cancel safety
59    ///
60    /// This method is cancellation safe. If this does not resolve, no bytes were written.
61    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
62        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
63    }
64
65    /// Write a buffer into this stream in its entirety
66    ///
67    /// This method repeatedly calls [`write`](Self::write) until all bytes are written, or an
68    /// error occurs.
69    ///
70    /// # Cancel safety
71    ///
72    /// This method is *not* cancellation safe. Even if this does not resolve, some prefix of `buf`
73    /// may have been written when previously polled.
74    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
75        while !buf.is_empty() {
76            let written = self.write(buf).await?;
77            buf = &buf[written..];
78        }
79        Ok(())
80    }
81
82    /// Write a slice of [`Bytes`] into this stream, returning how much was written
83    ///
84    /// Bytes to try to write are provided to this method as an array of cheaply cloneable chunks.
85    /// Unless this method errors, it waits until some amount of those bytes can be written into
86    /// this stream, and then writes as much as it can without waiting again. Due to congestion and
87    /// flow control, this may be less than the total number of bytes.
88    ///
89    /// On success, this method both mutates `bufs` and yields an informative [`Written`] struct
90    /// indicating how much was written:
91    ///
92    /// - [`Bytes`] chunks that were fully written are mutated to be [empty](Bytes::is_empty).
93    /// - If a [`Bytes`] chunk was partially written, it is [split to](Bytes::split_to) contain
94    ///   only the suffix of bytes that were not written.
95    /// - The yielded [`Written`] struct indicates how many chunks were fully written as well as
96    ///   how many bytes were written.
97    ///
98    /// # Cancel safety
99    ///
100    /// This method is cancellation safe. If this does not resolve, no bytes were written.
101    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
102        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
103    }
104
105    /// Write a single [`Bytes`] into this stream in its entirety
106    ///
107    /// Bytes to write are provided to this method as an single cheaply cloneable chunk. This
108    /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
109    /// or an error occurs.
110    ///
111    /// # Cancel safety
112    ///
113    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
114    /// been written when previously polled.
115    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
116        self.write_all_chunks(&mut [buf]).await?;
117        Ok(())
118    }
119
120    /// Write a slice of [`Bytes`] into this stream in its entirety
121    ///
122    /// Bytes to write are provided to this method as an array of cheaply cloneable chunks. This
123    /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
124    /// or an error occurs. This method mutates `bufs` by mutating all chunks to be
125    /// [empty](Bytes::is_empty).
126    ///
127    /// # Cancel safety
128    ///
129    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
130    /// been written when previously polled.
131    pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
132        while !bufs.is_empty() {
133            let written = self.write_chunks(bufs).await?;
134            bufs = &mut bufs[written.chunks..];
135        }
136        Ok(())
137    }
138
139    fn execute_poll<F, R>(
140        &mut self,
141        cx: &mut Context<'_>,
142        write_fn: F,
143    ) -> Poll<Result<R, WriteError>>
144    where
145        F: FnOnce(&mut proto::SendStream<'_>) -> Result<R, proto::WriteError>,
146    {
147        use proto::WriteError::*;
148        let mut conn = self.conn.state.lock("SendStream::poll_write");
149        if self.is_0rtt {
150            conn.check_0rtt()
151                .map_err(|()| WriteError::ZeroRttRejected)?;
152        }
153        if let Some(ref x) = conn.error {
154            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
155        }
156
157        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
158            Ok(result) => result,
159            Err(Blocked) => {
160                conn.blocked_writers.insert(self.stream, cx.waker().clone());
161                return Poll::Pending;
162            }
163            Err(Stopped(error_code)) => {
164                return Poll::Ready(Err(WriteError::Stopped(error_code)));
165            }
166            Err(ClosedStream) => {
167                return Poll::Ready(Err(WriteError::ClosedStream));
168            }
169        };
170
171        conn.wake();
172        Poll::Ready(Ok(result))
173    }
174
175    /// Notify the peer that no more data will ever be written to this stream
176    ///
177    /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
178    /// may still be called after `finish` to abandon transmission of any stream data that might
179    /// still be buffered.
180    ///
181    /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
182    ///
183    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
184    /// called. This error is harmless and serves only to indicate that the caller may have
185    /// incorrect assumptions about the stream's state.
186    pub fn finish(&mut self) -> Result<(), ClosedStream> {
187        let mut conn = self.conn.state.lock("finish");
188        match conn.inner.send_stream(self.stream).finish() {
189            Ok(()) => {
190                conn.wake();
191                Ok(())
192            }
193            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
194            // Harmless. If the application needs to know about stopped streams at this point, it
195            // should call `stopped`.
196            Err(FinishError::Stopped(_)) => Ok(()),
197        }
198    }
199
200    /// Close the send stream immediately.
201    ///
202    /// No new data can be written after calling this method. Locally buffered data is dropped, and
203    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
204    /// already been made to finish the stream, the peer may still receive all written data.
205    ///
206    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
207    /// called. This error is harmless and serves only to indicate that the caller may have
208    /// incorrect assumptions about the stream's state.
209    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
210        let mut conn = self.conn.state.lock("SendStream::reset");
211        if self.is_0rtt && conn.check_0rtt().is_err() {
212            return Ok(());
213        }
214        conn.inner.send_stream(self.stream).reset(error_code)?;
215        conn.wake();
216        Ok(())
217    }
218
219    /// Set the priority of the send stream
220    ///
221    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
222    /// higher priority will be transmitted before data from streams with lower priority. Changing
223    /// the priority of a stream with pending data may only take effect after that data has been
224    /// transmitted. Using many different priority levels per connection may have a negative
225    /// impact on performance.
226    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
227        let mut conn = self.conn.state.lock("SendStream::set_priority");
228        conn.inner.send_stream(self.stream).set_priority(priority)?;
229        Ok(())
230    }
231
232    /// Get the priority of the send stream
233    pub fn priority(&self) -> Result<i32, ClosedStream> {
234        let mut conn = self.conn.state.lock("SendStream::priority");
235        conn.inner.send_stream(self.stream).priority()
236    }
237
238    /// Completes when the peer stops the stream or reads the stream to completion
239    ///
240    /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
241    /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
242    /// of all stream data (although not necessarily the processing of it), after which the peer
243    /// closing the stream is no longer meaningful.
244    ///
245    /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
246    /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
247    /// may introduce more latency than using an application-level response of some sort.
248    pub fn stopped(&self) -> Stopped {
249        let notified = {
250            // Create an `OwnedNotified` to move into the future. By creating it before the first poll,
251            // we make sure that we don't miss any notifications.
252            let mut conn = self.conn.state.lock("SendStream::stopped");
253            conn.stopped
254                .entry(self.stream)
255                .or_default()
256                .clone()
257                .notified_owned()
258        };
259        Stopped {
260            conn: self.conn.clone(),
261            stream: self.stream,
262            is_0rtt: self.is_0rtt,
263            notified,
264        }
265    }
266
267    /// Get the identity of this stream
268    pub fn id(&self) -> StreamId {
269        self.stream
270    }
271
272    /// Attempt to write bytes from buf into the stream.
273    ///
274    /// On success, returns Poll::Ready(Ok(num_bytes_written)).
275    ///
276    /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
277    /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
278    /// stream becomes writable or is closed.
279    pub fn poll_write(
280        self: Pin<&mut Self>,
281        cx: &mut Context<'_>,
282        buf: &[u8],
283    ) -> Poll<Result<usize, WriteError>> {
284        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
285    }
286}
287
288/// Check if a send stream is stopped.
289///
290/// Returns `Some` if the stream is stopped or the connection is closed.
291/// Returns `None` if the stream is not stopped.
292fn send_stream_stopped(
293    conn: &mut State,
294    stream: StreamId,
295    is_0rtt: bool,
296) -> Option<Result<Option<VarInt>, StoppedError>> {
297    if is_0rtt && conn.check_0rtt().is_err() {
298        return Some(Err(StoppedError::ZeroRttRejected));
299    }
300    match conn.inner.send_stream(stream).stopped() {
301        Err(ClosedStream { .. }) => Some(Ok(None)),
302        Ok(Some(error_code)) => Some(Ok(Some(error_code))),
303        Ok(None) => conn.error.clone().map(|error| Err(error.into())),
304    }
305}
306
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    /// Forwards to the inherent `SendStream::poll_write`, converting the error into `io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let progress = self.poll_write(cx, buf);
        progress.map_err(io::Error::from)
    }

    /// Always ready: this implementation keeps no buffer of its own to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream. `finish` does not wait, so this completes on the first poll.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let finished: io::Result<()> = self.get_mut().finish().map_err(Into::into);
        Poll::Ready(finished)
    }
}
325
326impl tokio::io::AsyncWrite for SendStream {
327    fn poll_write(
328        self: Pin<&mut Self>,
329        cx: &mut Context<'_>,
330        buf: &[u8],
331    ) -> Poll<io::Result<usize>> {
332        self.poll_write(cx, buf).map_err(Into::into)
333    }
334
335    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
336        Poll::Ready(Ok(()))
337    }
338
339    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
340        Poll::Ready(self.get_mut().finish().map_err(Into::into))
341    }
342}
343
344impl Drop for SendStream {
345    fn drop(&mut self) {
346        let mut conn = self.conn.state.lock("SendStream::drop");
347
348        // clean up any previously registered wakers
349        conn.blocked_writers.remove(&self.stream);
350
351        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
352            return;
353        }
354        match conn.inner.send_stream(self.stream).finish() {
355            Ok(()) => conn.wake(),
356            Err(FinishError::Stopped(reason)) => {
357                if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
358                    conn.wake();
359                }
360            }
361            // Already finished or reset, which is fine.
362            Err(FinishError::ClosedStream) => {}
363        }
364    }
365}
366
367/// Errors that arise from writing to a stream
368#[derive(Debug, Error, Clone, PartialEq, Eq)]
369pub enum WriteError {
370    /// The peer is no longer accepting data on this stream
371    ///
372    /// Carries an application-defined error code.
373    #[error("sending stopped by peer: error {0}")]
374    Stopped(VarInt),
375    /// The connection was lost
376    #[error("connection lost")]
377    ConnectionLost(#[from] ConnectionError),
378    /// The stream has already been finished or reset
379    #[error("closed stream")]
380    ClosedStream,
381    /// This was a 0-RTT stream and the server rejected it
382    ///
383    /// Can only occur on clients for 0-RTT streams, which can be opened using
384    /// [`Connecting::into_0rtt()`].
385    ///
386    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
387    #[error("0-RTT rejected")]
388    ZeroRttRejected,
389}
390
391impl From<ClosedStream> for WriteError {
392    #[inline]
393    fn from(_: ClosedStream) -> Self {
394        Self::ClosedStream
395    }
396}
397
398impl From<StoppedError> for WriteError {
399    fn from(x: StoppedError) -> Self {
400        match x {
401            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
402            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
403        }
404    }
405}
406
407impl From<WriteError> for io::Error {
408    fn from(x: WriteError) -> Self {
409        use WriteError::*;
410        let kind = match x {
411            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
412            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
413        };
414        Self::new(kind, x)
415    }
416}
417
418/// Errors that arise while monitoring for a send stream stop from the peer
419#[derive(Debug, Error, Clone, PartialEq, Eq)]
420pub enum StoppedError {
421    /// The connection was lost
422    #[error("connection lost")]
423    ConnectionLost(#[from] ConnectionError),
424    /// This was a 0-RTT stream and the server rejected it
425    ///
426    /// Can only occur on clients for 0-RTT streams, which can be opened using
427    /// [`Connecting::into_0rtt()`].
428    ///
429    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
430    #[error("0-RTT rejected")]
431    ZeroRttRejected,
432}
433
434impl From<StoppedError> for io::Error {
435    fn from(x: StoppedError) -> Self {
436        use StoppedError::*;
437        let kind = match x {
438            ZeroRttRejected => io::ErrorKind::ConnectionReset,
439            ConnectionLost(_) => io::ErrorKind::NotConnected,
440        };
441        Self::new(kind, x)
442    }
443}
444
445pin_project! {
446    /// Future returned from [`SendStream::stopped`].
447    #[derive(Debug)]
448    pub struct Stopped {
449        conn: ConnectionRef,
450        stream: StreamId,
451        is_0rtt: bool,
452        #[pin]
453        notified: OwnedNotified,
454    }
455}
456
457impl Future for Stopped {
458    type Output = Result<Option<VarInt>, StoppedError>;
459
460    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461        let mut this = self.project();
462        loop {
463            let mut conn = this.conn.state.lock("SendStream::stopped");
464            // Check if the stream is stopped before polling the notify. This makes sure that
465            // no wakeups are missed.
466            if let Some(output) = send_stream_stopped(&mut conn, *this.stream, *this.is_0rtt) {
467                return Poll::Ready(output);
468            }
469            std::task::ready!(this.notified.as_mut().poll(cx));
470        }
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::Stopped;

    /// Compile-time check: this only type-checks when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    // Never called at runtime. It exists so that the bound check is type-checked for `Stopped`.
    #[allow(dead_code)]
    fn stopped_future_is_send_sync() {
        assert_send_sync::<Stopped>();
    }
}