noq/
send_stream.rs

1use std::{
2    future::{Future, poll_fn},
3    io,
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use pin_project_lite::pin_project;
10use proto::{ClosedStream, ConnectionError, FinishError, StreamId};
11use thiserror::Error;
12use tokio::sync::futures::OwnedNotified;
13
14use crate::{
15    VarInt,
16    connection::{ConnectionRef, State},
17};
18
19/// A stream that can only be used to send data
20///
21/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
22/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
23/// connection is closed.
24///
25/// # Cancellation
26///
27/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
28/// ready will always result in no data being written to the stream. This is true of methods which
29/// succeed immediately when any progress is made, and is not true of methods which might need to
30/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
31/// cancel-safe.
32///
33/// [`reset()`]: SendStream::reset
34/// [`finish()`]: SendStream::finish
35#[derive(Debug)]
36pub struct SendStream {
37    conn: ConnectionRef,
38    stream: StreamId,
39    is_0rtt: bool,
40}
41
42impl SendStream {
43    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
44        Self {
45            conn,
46            stream,
47            is_0rtt,
48        }
49    }
50
51    /// Write a buffer into this stream, returning how many bytes were written
52    ///
53    /// Unless this method errors, it waits until some amount of `buf` can be written into this
54    /// stream, and then writes as much as it can without waiting again. Due to congestion and flow
55    /// control, this may be shorter than `buf.len()`. On success this yields the length of the
56    /// prefix that was written.
57    ///
58    /// # Cancel safety
59    ///
60    /// This method is cancellation safe. If this does not resolve, no bytes were written.
61    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
62        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
63    }
64
65    /// Write a buffer into this stream in its entirety
66    ///
67    /// This method repeatedly calls [`write`](Self::write) until all bytes are written, or an
68    /// error occurs.
69    ///
70    /// # Cancel safety
71    ///
72    /// This method is *not* cancellation safe. Even if this does not resolve, some prefix of `buf`
73    /// may have been written when previously polled.
74    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
75        while !buf.is_empty() {
76            let written = self.write(buf).await?;
77            buf = &buf[written..];
78        }
79        Ok(())
80    }
81
82    /// Writes [`Bytes`] from a slice of buffers into this stream, returning how many bytes were.
83    /// written
84    ///
85    /// Bytes to try to write are provided to this method as an array of cheaply cloneable chunks.
86    /// Unless this method errors, it waits until some amount of those bytes can be written into
87    /// this stream, and then writes as much as it can without waiting again. Due to congestion and
88    /// flow control, this may be less than the total number of bytes.
89    ///
90    /// On success, this method both mutates `bufs` and returns the number of bytes written:
91    ///
92    /// - `bufs` is advanced past chunks that were fully written.
93    /// - If a [`Bytes`] chunk was partially written, the chunk at the new front of `bufs` is
94    ///   [split to](Bytes::split_to) contain only the suffix of bytes that were not written.
95    ///
96    /// # Cancel safety
97    ///
98    /// This method is cancellation safe. If this does not resolve, no bytes were written.
99    pub async fn write_many_chunks(
100        &mut self,
101        bufs: &mut &mut [Bytes],
102    ) -> Result<usize, WriteError> {
103        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
104    }
105
106    /// Writes a single [`Bytes`] into this stream in its entirety.
107    ///
108    /// Bytes to write are provided to this method as a single cheaply cloneable chunk. This
109    /// method repeatedly calls [`write_many_chunks`](Self::write_many_chunks) until all bytes
110    /// are written, or an error occurs.
111    ///
112    /// # Cancel safety
113    ///
114    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
115    /// been written when previously polled.
116    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
117        self.write_all_chunks(&mut [buf]).await
118    }
119
120    /// Writes a slice of [`Bytes`] into this stream in its entirety.
121    ///
122    /// Bytes to write are provided to this method as an array of cheaply cloneable chunks. This
123    /// method repeatedly calls [`write_many_chunks`](Self::write_many_chunks) until all bytes are
124    /// written, or an error occurs.
125    ///
126    /// # Cancel safety
127    ///
128    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
129    /// been written when previously polled.
130    pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
131        let mut bufs = &mut bufs[..];
132        while !bufs.is_empty() {
133            self.write_many_chunks(&mut bufs).await?;
134        }
135        Ok(())
136    }
137
138    fn execute_poll<F, R>(
139        &mut self,
140        cx: &mut Context<'_>,
141        write_fn: F,
142    ) -> Poll<Result<R, WriteError>>
143    where
144        F: FnOnce(&mut proto::SendStream<'_>) -> Result<R, proto::WriteError>,
145    {
146        use proto::WriteError::*;
147        let mut conn = self.conn.lock_and_wake("SendStream::poll_write");
148        if self.is_0rtt && conn.check_0rtt().is_err() {
149            conn.skip_waking();
150            return Poll::Ready(Err(WriteError::ZeroRttRejected));
151        }
152        if let Some(conn_err) = conn.error.clone() {
153            conn.skip_waking();
154            return Poll::Ready(Err(WriteError::ConnectionLost(conn_err)));
155        }
156
157        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
158            Ok(result) => result,
159            Err(Blocked) => {
160                conn.blocked_writers.insert(self.stream, cx.waker().clone());
161                conn.skip_waking();
162                return Poll::Pending;
163            }
164            Err(Stopped(error_code)) => {
165                conn.skip_waking();
166                return Poll::Ready(Err(WriteError::Stopped(error_code)));
167            }
168            Err(ClosedStream) => {
169                conn.skip_waking();
170                return Poll::Ready(Err(WriteError::ClosedStream));
171            }
172        };
173
174        Poll::Ready(Ok(result))
175    }
176
177    /// Notify the peer that no more data will ever be written to this stream
178    ///
179    /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
180    /// may still be called after `finish` to abandon transmission of any stream data that might
181    /// still be buffered.
182    ///
183    /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
184    ///
185    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
186    /// called. This error is harmless and serves only to indicate that the caller may have
187    /// incorrect assumptions about the stream's state.
188    pub fn finish(&mut self) -> Result<(), ClosedStream> {
189        let mut conn = self.conn.lock_and_wake("finish");
190        if let Err(e) = conn.inner.send_stream(self.stream).finish() {
191            conn.skip_waking();
192            match e {
193                FinishError::ClosedStream => Err(ClosedStream::default()),
194                // Harmless. If the application needs to know about stopped streams at this point, it
195                // should call `stopped`.
196                FinishError::Stopped(_) => Ok(()),
197            }
198        } else {
199            Ok(())
200        }
201    }
202
203    /// Close the send stream immediately.
204    ///
205    /// No new data can be written after calling this method. Locally buffered data is dropped, and
206    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
207    /// already been made to finish the stream, the peer may still receive all written data.
208    ///
209    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
210    /// called. This error is harmless and serves only to indicate that the caller may have
211    /// incorrect assumptions about the stream's state.
212    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
213        let mut conn = self.conn.lock_and_wake("SendStream::reset");
214        if self.is_0rtt && conn.check_0rtt().is_err() {
215            conn.skip_waking();
216            return Ok(());
217        }
218        conn.inner.send_stream(self.stream).reset(error_code)?;
219        Ok(())
220    }
221
222    /// Set the priority of the send stream
223    ///
224    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
225    /// higher priority will be transmitted before data from streams with lower priority. Changing
226    /// the priority of a stream with pending data may only take effect after that data has been
227    /// transmitted. Using many different priority levels per connection may have a negative
228    /// impact on performance.
229    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
230        let mut conn = self.conn.lock_without_waking("SendStream::set_priority");
231        conn.inner.send_stream(self.stream).set_priority(priority)?;
232        Ok(())
233    }
234
235    /// Get the priority of the send stream
236    pub fn priority(&self) -> Result<i32, ClosedStream> {
237        let mut conn = self.conn.lock_without_waking("SendStream::priority");
238        conn.inner.send_stream(self.stream).priority()
239    }
240
241    /// Completes when the peer stops the stream or reads the stream to completion
242    ///
243    /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
244    /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
245    /// of all stream data (although not necessarily the processing of it), after which the peer
246    /// closing the stream is no longer meaningful.
247    ///
248    /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
249    /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
250    /// may introduce more latency than using an application-level response of some sort.
251    pub fn stopped(&self) -> Stopped {
252        let notified = {
253            // Create an `OwnedNotified` to move into the future. By creating it before the first poll,
254            // we make sure that we don't miss any notifications.
255            let mut conn = self.conn.lock_without_waking("SendStream::stopped");
256            conn.stopped
257                .entry(self.stream)
258                .or_default()
259                .clone()
260                .notified_owned()
261        };
262        Stopped {
263            conn: self.conn.clone(),
264            stream: self.stream,
265            is_0rtt: self.is_0rtt,
266            notified,
267        }
268    }
269
270    /// Get the identity of this stream
271    pub fn id(&self) -> StreamId {
272        self.stream
273    }
274
275    /// Attempt to write bytes from buf into the stream.
276    ///
277    /// On success, returns Poll::Ready(Ok(num_bytes_written)).
278    ///
279    /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
280    /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
281    /// stream becomes writable or is closed.
282    pub fn poll_write(
283        self: Pin<&mut Self>,
284        cx: &mut Context<'_>,
285        buf: &[u8],
286    ) -> Poll<Result<usize, WriteError>> {
287        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
288    }
289}
290
291/// Check if a send stream is stopped.
292///
293/// Returns `Some` if the stream is stopped or the connection is closed.
294/// Returns `None` if the stream is not stopped.
295fn send_stream_stopped(
296    conn: &mut State,
297    stream: StreamId,
298    is_0rtt: bool,
299) -> Option<Result<Option<VarInt>, StoppedError>> {
300    if is_0rtt && conn.check_0rtt().is_err() {
301        return Some(Err(StoppedError::ZeroRttRejected));
302    }
303    match conn.inner.send_stream(stream).stopped() {
304        Err(ClosedStream { .. }) => Some(Ok(None)),
305        Ok(Some(error_code)) => Some(Ok(Some(error_code))),
306        Ok(None) => conn.error.clone().map(|error| Err(error.into())),
307    }
308}
309
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    /// Delegates to the inherent [`SendStream::poll_write`], converting errors to [`io::Error`].
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        SendStream::poll_write(self, cx, buf).map_err(io::Error::from)
    }

    /// No-op: always reports immediate success.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream via [`SendStream::finish`] and completes immediately, without
    /// waiting for the peer to acknowledge the data.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let finished = self.get_mut().finish();
        Poll::Ready(finished.map_err(Into::into))
    }
}
328
329impl tokio::io::AsyncWrite for SendStream {
330    fn poll_write(
331        self: Pin<&mut Self>,
332        cx: &mut Context<'_>,
333        buf: &[u8],
334    ) -> Poll<io::Result<usize>> {
335        self.poll_write(cx, buf).map_err(Into::into)
336    }
337
338    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
339        Poll::Ready(Ok(()))
340    }
341
342    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
343        Poll::Ready(self.get_mut().finish().map_err(Into::into))
344    }
345}
346
347impl Drop for SendStream {
348    fn drop(&mut self) {
349        let mut conn = self.conn.lock_and_wake("SendStream::drop");
350
351        // clean up any previously registered wakers
352        conn.blocked_writers.remove(&self.stream);
353
354        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
355            conn.skip_waking();
356            return;
357        }
358        match conn.inner.send_stream(self.stream).finish() {
359            Ok(()) => {}
360            Err(FinishError::Stopped(reason)) => {
361                if conn.inner.send_stream(self.stream).reset(reason).is_err() {
362                    conn.skip_waking()
363                }
364            }
365            // Already finished or reset, which is fine.
366            Err(FinishError::ClosedStream) => {
367                conn.skip_waking();
368            }
369        }
370    }
371}
372
373/// Errors that arise from writing to a stream
374#[derive(Debug, Error, Clone, PartialEq, Eq)]
375pub enum WriteError {
376    /// The peer is no longer accepting data on this stream
377    ///
378    /// Carries an application-defined error code.
379    #[error("sending stopped by peer: error {0}")]
380    Stopped(VarInt),
381    /// The connection was lost
382    #[error("connection lost")]
383    ConnectionLost(#[from] ConnectionError),
384    /// The stream has already been finished or reset
385    #[error("closed stream")]
386    ClosedStream,
387    /// This was a 0-RTT stream and the server rejected it
388    ///
389    /// Can only occur on clients for 0-RTT streams, which can be opened using
390    /// [`Connecting::into_0rtt()`].
391    ///
392    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
393    #[error("0-RTT rejected")]
394    ZeroRttRejected,
395}
396
397impl From<ClosedStream> for WriteError {
398    #[inline]
399    fn from(_: ClosedStream) -> Self {
400        Self::ClosedStream
401    }
402}
403
404impl From<StoppedError> for WriteError {
405    fn from(x: StoppedError) -> Self {
406        match x {
407            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
408            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
409        }
410    }
411}
412
413impl From<WriteError> for io::Error {
414    fn from(x: WriteError) -> Self {
415        use WriteError::*;
416        let kind = match x {
417            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
418            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
419        };
420        Self::new(kind, x)
421    }
422}
423
424/// Errors that arise while monitoring for a send stream stop from the peer
425#[derive(Debug, Error, Clone, PartialEq, Eq)]
426pub enum StoppedError {
427    /// The connection was lost
428    #[error("connection lost")]
429    ConnectionLost(#[from] ConnectionError),
430    /// This was a 0-RTT stream and the server rejected it
431    ///
432    /// Can only occur on clients for 0-RTT streams, which can be opened using
433    /// [`Connecting::into_0rtt()`].
434    ///
435    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
436    #[error("0-RTT rejected")]
437    ZeroRttRejected,
438}
439
440impl From<StoppedError> for io::Error {
441    fn from(x: StoppedError) -> Self {
442        use StoppedError::*;
443        let kind = match x {
444            ZeroRttRejected => io::ErrorKind::ConnectionReset,
445            ConnectionLost(_) => io::ErrorKind::NotConnected,
446        };
447        Self::new(kind, x)
448    }
449}
450
pin_project! {
    /// Future returned from [`SendStream::stopped`].
    ///
    /// Resolves to `Ok(Some(code))` if the peer stopped the stream, `Ok(None)` if the stream is
    /// reported closed, or a [`StoppedError`] if 0-RTT was rejected or the connection was lost.
    #[derive(Debug)]
    pub struct Stopped {
        // Connection owning the stream; locked (without waking) on every poll.
        conn: ConnectionRef,
        // The stream being watched.
        stream: StreamId,
        // Copied from the originating `SendStream`; enables the 0-RTT rejection check.
        is_0rtt: bool,
        // Listener on this stream's entry in the connection's `stopped` map. It is created in
        // `SendStream::stopped` before the first poll so no notification is missed.
        #[pin]
        notified: OwnedNotified,
    }
}
462
463impl Future for Stopped {
464    type Output = Result<Option<VarInt>, StoppedError>;
465
466    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
467        let mut this = self.project();
468        loop {
469            let mut conn = this.conn.lock_without_waking("SendStream::stopped");
470            // Check if the stream is stopped before polling the notify. This makes sure that
471            // no wakeups are missed.
472            if let Some(output) = send_stream_stopped(&mut conn, *this.stream, *this.is_0rtt) {
473                return Poll::Ready(output);
474            }
475            std::task::ready!(this.notified.as_mut().poll(cx));
476        }
477    }
478}
479
#[cfg(test)]
mod tests {
    /// Compiles only when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    // Never executed: type-checking this call is the whole test.
    #[allow(dead_code)]
    fn stopped_future_is_send_sync() {
        assert_send_sync::<super::Stopped>();
    }
}