iroh_quinn/
send_stream.rs

1use std::{
2    future::{Future, poll_fn},
3    io,
4    pin::{Pin, pin},
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
10use thiserror::Error;
11
12use crate::{
13    VarInt,
14    connection::{ConnectionRef, State},
15};
16
/// A stream that can only be used to send data
///
/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
/// connection is closed. If the peer has already stopped the stream, dropping it instead resets
/// the stream using the peer's error code. Nothing is done on drop if the connection has already
/// failed or if this is a 0-RTT stream whose 0-RTT data was rejected.
///
/// # Cancellation
///
/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
/// ready will always result in no data being written to the stream. This is true of methods which
/// succeed immediately when any progress is made, and is not true of methods which might need to
/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
/// cancel-safe.
///
/// [`reset()`]: SendStream::reset
/// [`finish()`]: SendStream::finish
#[derive(Debug)]
pub struct SendStream {
    // Handle to the shared connection; operations lock its `state` to reach the protocol layer.
    conn: ConnectionRef,
    // Identifier of this stream within the connection.
    stream: StreamId,
    // Whether this stream was opened as 0-RTT; such streams check whether 0-RTT was rejected
    // before doing any work.
    is_0rtt: bool,
}
39
40impl SendStream {
41    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
42        Self {
43            conn,
44            stream,
45            is_0rtt,
46        }
47    }
48
49    /// Write a buffer into this stream, returning how many bytes were written
50    ///
51    /// Unless this method errors, it waits until some amount of `buf` can be written into this
52    /// stream, and then writes as much as it can without waiting again. Due to congestion and flow
53    /// control, this may be shorter than `buf.len()`. On success this yields the length of the
54    /// prefix that was written.
55    ///
56    /// # Cancel safety
57    ///
58    /// This method is cancellation safe. If this does not resolve, no bytes were written.
59    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
60        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
61    }
62
63    /// Write a buffer into this stream in its entirety
64    ///
65    /// This method repeatedly calls [`write`](Self::write) until all bytes are written, or an
66    /// error occurs.
67    ///
68    /// # Cancel safety
69    ///
70    /// This method is *not* cancellation safe. Even if this does not resolve, some prefix of `buf`
71    /// may have been written when previously polled.
72    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
73        while !buf.is_empty() {
74            let written = self.write(buf).await?;
75            buf = &buf[written..];
76        }
77        Ok(())
78    }
79
80    /// Write a slice of [`Bytes`] into this stream, returning how much was written
81    ///
82    /// Bytes to try to write are provided to this method as an array of cheaply cloneable chunks.
83    /// Unless this method errors, it waits until some amount of those bytes can be written into
84    /// this stream, and then writes as much as it can without waiting again. Due to congestion and
85    /// flow control, this may be less than the total number of bytes.
86    ///
87    /// On success, this method both mutates `bufs` and yields an informative [`Written`] struct
88    /// indicating how much was written:
89    ///
90    /// - [`Bytes`] chunks that were fully written are mutated to be [empty](Bytes::is_empty).
91    /// - If a [`Bytes`] chunk was partially written, it is [split to](Bytes::split_to) contain
92    ///   only the suffix of bytes that were not written.
93    /// - The yielded [`Written`] struct indicates how many chunks were fully written as well as
94    ///   how many bytes were written.
95    ///
96    /// # Cancel safety
97    ///
98    /// This method is cancellation safe. If this does not resolve, no bytes were written.
99    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
100        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
101    }
102
103    /// Write a single [`Bytes`] into this stream in its entirety
104    ///
105    /// Bytes to write are provided to this method as an single cheaply cloneable chunk. This
106    /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
107    /// or an error occurs.
108    ///
109    /// # Cancel safety
110    ///
111    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
112    /// been written when previously polled.
113    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
114        self.write_all_chunks(&mut [buf]).await?;
115        Ok(())
116    }
117
118    /// Write a slice of [`Bytes`] into this stream in its entirety
119    ///
120    /// Bytes to write are provided to this method as an array of cheaply cloneable chunks. This
121    /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
122    /// or an error occurs. This method mutates `bufs` by mutating all chunks to be
123    /// [empty](Bytes::is_empty).
124    ///
125    /// # Cancel safety
126    ///
127    /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
128    /// been written when previously polled.
129    pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
130        while !bufs.is_empty() {
131            let written = self.write_chunks(bufs).await?;
132            bufs = &mut bufs[written.chunks..];
133        }
134        Ok(())
135    }
136
137    fn execute_poll<F, R>(
138        &mut self,
139        cx: &mut Context<'_>,
140        write_fn: F,
141    ) -> Poll<Result<R, WriteError>>
142    where
143        F: FnOnce(&mut proto::SendStream<'_>) -> Result<R, proto::WriteError>,
144    {
145        use proto::WriteError::*;
146        let mut conn = self.conn.state.lock("SendStream::poll_write");
147        if self.is_0rtt {
148            conn.check_0rtt()
149                .map_err(|()| WriteError::ZeroRttRejected)?;
150        }
151        if let Some(ref x) = conn.error {
152            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
153        }
154
155        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
156            Ok(result) => result,
157            Err(Blocked) => {
158                conn.blocked_writers.insert(self.stream, cx.waker().clone());
159                return Poll::Pending;
160            }
161            Err(Stopped(error_code)) => {
162                return Poll::Ready(Err(WriteError::Stopped(error_code)));
163            }
164            Err(ClosedStream) => {
165                return Poll::Ready(Err(WriteError::ClosedStream));
166            }
167        };
168
169        conn.wake();
170        Poll::Ready(Ok(result))
171    }
172
173    /// Notify the peer that no more data will ever be written to this stream
174    ///
175    /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
176    /// may still be called after `finish` to abandon transmission of any stream data that might
177    /// still be buffered.
178    ///
179    /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
180    ///
181    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
182    /// called. This error is harmless and serves only to indicate that the caller may have
183    /// incorrect assumptions about the stream's state.
184    pub fn finish(&mut self) -> Result<(), ClosedStream> {
185        let mut conn = self.conn.state.lock("finish");
186        match conn.inner.send_stream(self.stream).finish() {
187            Ok(()) => {
188                conn.wake();
189                Ok(())
190            }
191            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
192            // Harmless. If the application needs to know about stopped streams at this point, it
193            // should call `stopped`.
194            Err(FinishError::Stopped(_)) => Ok(()),
195        }
196    }
197
198    /// Close the send stream immediately.
199    ///
200    /// No new data can be written after calling this method. Locally buffered data is dropped, and
201    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
202    /// already been made to finish the stream, the peer may still receive all written data.
203    ///
204    /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
205    /// called. This error is harmless and serves only to indicate that the caller may have
206    /// incorrect assumptions about the stream's state.
207    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
208        let mut conn = self.conn.state.lock("SendStream::reset");
209        if self.is_0rtt && conn.check_0rtt().is_err() {
210            return Ok(());
211        }
212        conn.inner.send_stream(self.stream).reset(error_code)?;
213        conn.wake();
214        Ok(())
215    }
216
217    /// Set the priority of the send stream
218    ///
219    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
220    /// higher priority will be transmitted before data from streams with lower priority. Changing
221    /// the priority of a stream with pending data may only take effect after that data has been
222    /// transmitted. Using many different priority levels per connection may have a negative
223    /// impact on performance.
224    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
225        let mut conn = self.conn.state.lock("SendStream::set_priority");
226        conn.inner.send_stream(self.stream).set_priority(priority)?;
227        Ok(())
228    }
229
230    /// Get the priority of the send stream
231    pub fn priority(&self) -> Result<i32, ClosedStream> {
232        let mut conn = self.conn.state.lock("SendStream::priority");
233        conn.inner.send_stream(self.stream).priority()
234    }
235
236    /// Completes when the peer stops the stream or reads the stream to completion
237    ///
238    /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
239    /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
240    /// of all stream data (although not necessarily the processing of it), after which the peer
241    /// closing the stream is no longer meaningful.
242    ///
243    /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
244    /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
245    /// may introduce more latency than using an application-level response of some sort.
246    pub fn stopped(
247        &self,
248    ) -> impl Future<Output = Result<Option<VarInt>, StoppedError>> + Send + Sync + 'static {
249        let conn = self.conn.clone();
250        let stream = self.stream;
251        let is_0rtt = self.is_0rtt;
252        async move {
253            loop {
254                // The `Notify::notified` future needs to be created while the lock is being held,
255                // otherwise a wakeup could be missed if triggered in between releasing the lock
256                // and creating the future.
257                // The lock may only be held in a block without `await`s, otherwise the future
258                // becomes `!Send`. `Notify::notified` is lifetime-bound to `Notify`, therefore
259                // we need to declare `notify` outside of the block, and initialize it inside.
260                let notify;
261                {
262                    let mut conn = conn.state.lock("SendStream::stopped");
263                    if let Some(output) = send_stream_stopped(&mut conn, stream, is_0rtt) {
264                        return output;
265                    }
266
267                    notify = conn.stopped.entry(stream).or_default().clone();
268                    notify.notified()
269                }
270                .await
271            }
272        }
273    }
274
275    /// Get the identity of this stream
276    pub fn id(&self) -> StreamId {
277        self.stream
278    }
279
280    /// Attempt to write bytes from buf into the stream.
281    ///
282    /// On success, returns Poll::Ready(Ok(num_bytes_written)).
283    ///
284    /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
285    /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
286    /// stream becomes writable or is closed.
287    pub fn poll_write(
288        self: Pin<&mut Self>,
289        cx: &mut Context<'_>,
290        buf: &[u8],
291    ) -> Poll<Result<usize, WriteError>> {
292        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
293    }
294}
295
296/// Check if a send stream is stopped.
297///
298/// Returns `Some` if the stream is stopped or the connection is closed.
299/// Returns `None` if the stream is not stopped.
300fn send_stream_stopped(
301    conn: &mut State,
302    stream: StreamId,
303    is_0rtt: bool,
304) -> Option<Result<Option<VarInt>, StoppedError>> {
305    if is_0rtt && conn.check_0rtt().is_err() {
306        return Some(Err(StoppedError::ZeroRttRejected));
307    }
308    match conn.inner.send_stream(stream).stopped() {
309        Err(ClosedStream { .. }) => Some(Ok(None)),
310        Ok(Some(error_code)) => Some(Ok(Some(error_code))),
311        Ok(None) => conn.error.clone().map(|error| Err(error.into())),
312    }
313}
314
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Forward to the inherent (cancel-safe) `poll_write`, adapting only the error type.
        SendStream::poll_write(self, cx, buf).map_err(io::Error::from)
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // `SendStream` keeps no write buffer of its own, so there is never anything to flush.
        Poll::Ready(Ok(()))
    }

    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Closing the writer means finishing the stream; `finish` completes synchronously.
        let outcome = self.get_mut().finish();
        Poll::Ready(outcome.map_err(Into::into))
    }
}
333
334impl tokio::io::AsyncWrite for SendStream {
335    fn poll_write(
336        self: Pin<&mut Self>,
337        cx: &mut Context<'_>,
338        buf: &[u8],
339    ) -> Poll<io::Result<usize>> {
340        self.poll_write(cx, buf).map_err(Into::into)
341    }
342
343    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
344        Poll::Ready(Ok(()))
345    }
346
347    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
348        Poll::Ready(self.get_mut().finish().map_err(Into::into))
349    }
350}
351
352impl Drop for SendStream {
353    fn drop(&mut self) {
354        let mut conn = self.conn.state.lock("SendStream::drop");
355
356        // clean up any previously registered wakers
357        conn.blocked_writers.remove(&self.stream);
358
359        if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
360            return;
361        }
362        match conn.inner.send_stream(self.stream).finish() {
363            Ok(()) => conn.wake(),
364            Err(FinishError::Stopped(reason)) => {
365                if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
366                    conn.wake();
367                }
368            }
369            // Already finished or reset, which is fine.
370            Err(FinishError::ClosedStream) => {}
371        }
372    }
373}
374
/// Errors that arise from writing to a stream
///
/// Converts into [`io::Error`] for the `AsyncWrite` adapters: `Stopped` and `ZeroRttRejected`
/// become [`io::ErrorKind::ConnectionReset`], while `ConnectionLost` and `ClosedStream` become
/// [`io::ErrorKind::NotConnected`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
    /// The peer is no longer accepting data on this stream
    ///
    /// Carries an application-defined error code.
    #[error("sending stopped by peer: error {0}")]
    Stopped(VarInt),
    /// The connection was lost
    ///
    /// Carries the error recorded for the connection.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream has already been finished or reset
    #[error("closed stream")]
    ClosedStream,
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
398
399impl From<ClosedStream> for WriteError {
400    #[inline]
401    fn from(_: ClosedStream) -> Self {
402        Self::ClosedStream
403    }
404}
405
406impl From<StoppedError> for WriteError {
407    fn from(x: StoppedError) -> Self {
408        match x {
409            StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
410            StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
411        }
412    }
413}
414
415impl From<WriteError> for io::Error {
416    fn from(x: WriteError) -> Self {
417        use WriteError::*;
418        let kind = match x {
419            Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
420            ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
421        };
422        Self::new(kind, x)
423    }
424}
425
/// Errors that arise while monitoring for a send stream stop from the peer
///
/// Converts losslessly into [`WriteError`], and into [`io::Error`] with `ZeroRttRejected` as
/// [`io::ErrorKind::ConnectionReset`] and `ConnectionLost` as [`io::ErrorKind::NotConnected`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum StoppedError {
    /// The connection was lost
    ///
    /// Carries the error recorded for the connection.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// This was a 0-RTT stream and the server rejected it
    ///
    /// Can only occur on clients for 0-RTT streams, which can be opened using
    /// [`Connecting::into_0rtt()`].
    ///
    /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
441
442impl From<StoppedError> for io::Error {
443    fn from(x: StoppedError) -> Self {
444        use StoppedError::*;
445        let kind = match x {
446            ZeroRttRejected => io::ErrorKind::ConnectionReset,
447            ConnectionLost(_) => io::ErrorKind::NotConnected,
448        };
449        Self::new(kind, x)
450    }
451}