noq/send_stream.rs
1use std::{
2 future::{Future, poll_fn},
3 io,
4 pin::{Pin, pin},
5 task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use pin_project_lite::pin_project;
10use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
11use thiserror::Error;
12use tokio::sync::futures::OwnedNotified;
13
14use crate::{
15 VarInt,
16 connection::{ConnectionRef, State},
17};
18
19/// A stream that can only be used to send data
20///
21/// If dropped, streams that haven't been explicitly [`reset()`] will be implicitly [`finish()`]ed,
22/// continuing to (re)transmit previously written data until it has been fully acknowledged or the
23/// connection is closed.
24///
25/// # Cancellation
26///
27/// A `write` method is said to be *cancel-safe* when dropping its future before the future becomes
28/// ready will always result in no data being written to the stream. This is true of methods which
29/// succeed immediately when any progress is made, and is not true of methods which might need to
30/// perform multiple writes internally before succeeding. Each `write` method documents whether it is
31/// cancel-safe.
32///
33/// [`reset()`]: SendStream::reset
34/// [`finish()`]: SendStream::finish
35#[derive(Debug)]
36pub struct SendStream {
37 conn: ConnectionRef,
38 stream: StreamId,
39 is_0rtt: bool,
40}
41
42impl SendStream {
43 pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
44 Self {
45 conn,
46 stream,
47 is_0rtt,
48 }
49 }
50
51 /// Write a buffer into this stream, returning how many bytes were written
52 ///
53 /// Unless this method errors, it waits until some amount of `buf` can be written into this
54 /// stream, and then writes as much as it can without waiting again. Due to congestion and flow
55 /// control, this may be shorter than `buf.len()`. On success this yields the length of the
56 /// prefix that was written.
57 ///
58 /// # Cancel safety
59 ///
60 /// This method is cancellation safe. If this does not resolve, no bytes were written.
61 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
62 poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
63 }
64
65 /// Write a buffer into this stream in its entirety
66 ///
67 /// This method repeatedly calls [`write`](Self::write) until all bytes are written, or an
68 /// error occurs.
69 ///
70 /// # Cancel safety
71 ///
72 /// This method is *not* cancellation safe. Even if this does not resolve, some prefix of `buf`
73 /// may have been written when previously polled.
74 pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
75 while !buf.is_empty() {
76 let written = self.write(buf).await?;
77 buf = &buf[written..];
78 }
79 Ok(())
80 }
81
82 /// Write a slice of [`Bytes`] into this stream, returning how much was written
83 ///
84 /// Bytes to try to write are provided to this method as an array of cheaply cloneable chunks.
85 /// Unless this method errors, it waits until some amount of those bytes can be written into
86 /// this stream, and then writes as much as it can without waiting again. Due to congestion and
87 /// flow control, this may be less than the total number of bytes.
88 ///
89 /// On success, this method both mutates `bufs` and yields an informative [`Written`] struct
90 /// indicating how much was written:
91 ///
92 /// - [`Bytes`] chunks that were fully written are mutated to be [empty](Bytes::is_empty).
93 /// - If a [`Bytes`] chunk was partially written, it is [split to](Bytes::split_to) contain
94 /// only the suffix of bytes that were not written.
95 /// - The yielded [`Written`] struct indicates how many chunks were fully written as well as
96 /// how many bytes were written.
97 ///
98 /// # Cancel safety
99 ///
100 /// This method is cancellation safe. If this does not resolve, no bytes were written.
101 pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
102 poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
103 }
104
105 /// Write a single [`Bytes`] into this stream in its entirety
106 ///
107 /// Bytes to write are provided to this method as an single cheaply cloneable chunk. This
108 /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
109 /// or an error occurs.
110 ///
111 /// # Cancel safety
112 ///
113 /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
114 /// been written when previously polled.
115 pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
116 self.write_all_chunks(&mut [buf]).await?;
117 Ok(())
118 }
119
120 /// Write a slice of [`Bytes`] into this stream in its entirety
121 ///
122 /// Bytes to write are provided to this method as an array of cheaply cloneable chunks. This
123 /// method repeatedly calls [`write_chunks`](Self::write_chunks) until all bytes are written,
124 /// or an error occurs. This method mutates `bufs` by mutating all chunks to be
125 /// [empty](Bytes::is_empty).
126 ///
127 /// # Cancel safety
128 ///
129 /// This method is *not* cancellation safe. Even if this does not resolve, some bytes may have
130 /// been written when previously polled.
131 pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
132 while !bufs.is_empty() {
133 let written = self.write_chunks(bufs).await?;
134 bufs = &mut bufs[written.chunks..];
135 }
136 Ok(())
137 }
138
139 fn execute_poll<F, R>(
140 &mut self,
141 cx: &mut Context<'_>,
142 write_fn: F,
143 ) -> Poll<Result<R, WriteError>>
144 where
145 F: FnOnce(&mut proto::SendStream<'_>) -> Result<R, proto::WriteError>,
146 {
147 use proto::WriteError::*;
148 let mut conn = self.conn.state.lock("SendStream::poll_write");
149 if self.is_0rtt {
150 conn.check_0rtt()
151 .map_err(|()| WriteError::ZeroRttRejected)?;
152 }
153 if let Some(ref x) = conn.error {
154 return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
155 }
156
157 let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
158 Ok(result) => result,
159 Err(Blocked) => {
160 conn.blocked_writers.insert(self.stream, cx.waker().clone());
161 return Poll::Pending;
162 }
163 Err(Stopped(error_code)) => {
164 return Poll::Ready(Err(WriteError::Stopped(error_code)));
165 }
166 Err(ClosedStream) => {
167 return Poll::Ready(Err(WriteError::ClosedStream));
168 }
169 };
170
171 conn.wake();
172 Poll::Ready(Ok(result))
173 }
174
175 /// Notify the peer that no more data will ever be written to this stream
176 ///
177 /// It is an error to write to a [`SendStream`] after `finish()`ing it. [`reset()`](Self::reset)
178 /// may still be called after `finish` to abandon transmission of any stream data that might
179 /// still be buffered.
180 ///
181 /// To wait for the peer to receive all buffered stream data, see [`stopped()`](Self::stopped).
182 ///
183 /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
184 /// called. This error is harmless and serves only to indicate that the caller may have
185 /// incorrect assumptions about the stream's state.
186 pub fn finish(&mut self) -> Result<(), ClosedStream> {
187 let mut conn = self.conn.state.lock("finish");
188 match conn.inner.send_stream(self.stream).finish() {
189 Ok(()) => {
190 conn.wake();
191 Ok(())
192 }
193 Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
194 // Harmless. If the application needs to know about stopped streams at this point, it
195 // should call `stopped`.
196 Err(FinishError::Stopped(_)) => Ok(()),
197 }
198 }
199
200 /// Close the send stream immediately.
201 ///
202 /// No new data can be written after calling this method. Locally buffered data is dropped, and
203 /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
204 /// already been made to finish the stream, the peer may still receive all written data.
205 ///
206 /// May fail if [`finish()`](Self::finish) or [`reset()`](Self::reset) was previously
207 /// called. This error is harmless and serves only to indicate that the caller may have
208 /// incorrect assumptions about the stream's state.
209 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
210 let mut conn = self.conn.state.lock("SendStream::reset");
211 if self.is_0rtt && conn.check_0rtt().is_err() {
212 return Ok(());
213 }
214 conn.inner.send_stream(self.stream).reset(error_code)?;
215 conn.wake();
216 Ok(())
217 }
218
219 /// Set the priority of the send stream
220 ///
221 /// Every send stream has an initial priority of 0. Locally buffered data from streams with
222 /// higher priority will be transmitted before data from streams with lower priority. Changing
223 /// the priority of a stream with pending data may only take effect after that data has been
224 /// transmitted. Using many different priority levels per connection may have a negative
225 /// impact on performance.
226 pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
227 let mut conn = self.conn.state.lock("SendStream::set_priority");
228 conn.inner.send_stream(self.stream).set_priority(priority)?;
229 Ok(())
230 }
231
232 /// Get the priority of the send stream
233 pub fn priority(&self) -> Result<i32, ClosedStream> {
234 let mut conn = self.conn.state.lock("SendStream::priority");
235 conn.inner.send_stream(self.stream).priority()
236 }
237
238 /// Completes when the peer stops the stream or reads the stream to completion
239 ///
240 /// Yields `Some` with the stop error code if the peer stops the stream. Yields `None` if the
241 /// local side [`finish()`](Self::finish)es the stream and then the peer acknowledges receipt
242 /// of all stream data (although not necessarily the processing of it), after which the peer
243 /// closing the stream is no longer meaningful.
244 ///
245 /// For a variety of reasons, the peer may not send acknowledgements immediately upon receiving
246 /// data. As such, relying on `stopped` to know when the peer has read a stream to completion
247 /// may introduce more latency than using an application-level response of some sort.
248 pub fn stopped(&self) -> Stopped {
249 let notified = {
250 // Create an `OwnedNotified` to move into the future. By creating it before the first poll,
251 // we make sure that we don't miss any notifications.
252 let mut conn = self.conn.state.lock("SendStream::stopped");
253 conn.stopped
254 .entry(self.stream)
255 .or_default()
256 .clone()
257 .notified_owned()
258 };
259 Stopped {
260 conn: self.conn.clone(),
261 stream: self.stream,
262 is_0rtt: self.is_0rtt,
263 notified,
264 }
265 }
266
267 /// Get the identity of this stream
268 pub fn id(&self) -> StreamId {
269 self.stream
270 }
271
272 /// Attempt to write bytes from buf into the stream.
273 ///
274 /// On success, returns Poll::Ready(Ok(num_bytes_written)).
275 ///
276 /// If the stream is not ready for writing, the method returns Poll::Pending and arranges
277 /// for the current task (via cx.waker().wake_by_ref()) to receive a notification when the
278 /// stream becomes writable or is closed.
279 pub fn poll_write(
280 self: Pin<&mut Self>,
281 cx: &mut Context<'_>,
282 buf: &[u8],
283 ) -> Poll<Result<usize, WriteError>> {
284 pin!(self.get_mut().write(buf)).as_mut().poll(cx)
285 }
286}
287
288/// Check if a send stream is stopped.
289///
290/// Returns `Some` if the stream is stopped or the connection is closed.
291/// Returns `None` if the stream is not stopped.
292fn send_stream_stopped(
293 conn: &mut State,
294 stream: StreamId,
295 is_0rtt: bool,
296) -> Option<Result<Option<VarInt>, StoppedError>> {
297 if is_0rtt && conn.check_0rtt().is_err() {
298 return Some(Err(StoppedError::ZeroRttRejected));
299 }
300 match conn.inner.send_stream(stream).stopped() {
301 Err(ClosedStream { .. }) => Some(Ok(None)),
302 Ok(Some(error_code)) => Some(Ok(Some(error_code))),
303 Ok(None) => conn.error.clone().map(|error| Err(error.into())),
304 }
305}
306
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    // Delegates to the inherent `SendStream::poll_write` and converts its error to `io::Error`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        SendStream::poll_write(self, cx, buf).map_err(io::Error::from)
    }

    // This wrapper keeps no write buffer of its own, so there is never anything to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    // Closing is finishing the stream; an already finished/reset stream surfaces as an error.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let outcome = self.get_mut().finish();
        Poll::Ready(outcome.map_err(io::Error::from))
    }
}
325
326impl tokio::io::AsyncWrite for SendStream {
327 fn poll_write(
328 self: Pin<&mut Self>,
329 cx: &mut Context<'_>,
330 buf: &[u8],
331 ) -> Poll<io::Result<usize>> {
332 self.poll_write(cx, buf).map_err(Into::into)
333 }
334
335 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
336 Poll::Ready(Ok(()))
337 }
338
339 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
340 Poll::Ready(self.get_mut().finish().map_err(Into::into))
341 }
342}
343
344impl Drop for SendStream {
345 fn drop(&mut self) {
346 let mut conn = self.conn.state.lock("SendStream::drop");
347
348 // clean up any previously registered wakers
349 conn.blocked_writers.remove(&self.stream);
350
351 if conn.error.is_some() || (self.is_0rtt && conn.check_0rtt().is_err()) {
352 return;
353 }
354 match conn.inner.send_stream(self.stream).finish() {
355 Ok(()) => conn.wake(),
356 Err(FinishError::Stopped(reason)) => {
357 if conn.inner.send_stream(self.stream).reset(reason).is_ok() {
358 conn.wake();
359 }
360 }
361 // Already finished or reset, which is fine.
362 Err(FinishError::ClosedStream) => {}
363 }
364 }
365}
366
367/// Errors that arise from writing to a stream
368#[derive(Debug, Error, Clone, PartialEq, Eq)]
369pub enum WriteError {
370 /// The peer is no longer accepting data on this stream
371 ///
372 /// Carries an application-defined error code.
373 #[error("sending stopped by peer: error {0}")]
374 Stopped(VarInt),
375 /// The connection was lost
376 #[error("connection lost")]
377 ConnectionLost(#[from] ConnectionError),
378 /// The stream has already been finished or reset
379 #[error("closed stream")]
380 ClosedStream,
381 /// This was a 0-RTT stream and the server rejected it
382 ///
383 /// Can only occur on clients for 0-RTT streams, which can be opened using
384 /// [`Connecting::into_0rtt()`].
385 ///
386 /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
387 #[error("0-RTT rejected")]
388 ZeroRttRejected,
389}
390
391impl From<ClosedStream> for WriteError {
392 #[inline]
393 fn from(_: ClosedStream) -> Self {
394 Self::ClosedStream
395 }
396}
397
398impl From<StoppedError> for WriteError {
399 fn from(x: StoppedError) -> Self {
400 match x {
401 StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
402 StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
403 }
404 }
405}
406
407impl From<WriteError> for io::Error {
408 fn from(x: WriteError) -> Self {
409 use WriteError::*;
410 let kind = match x {
411 Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
412 ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
413 };
414 Self::new(kind, x)
415 }
416}
417
418/// Errors that arise while monitoring for a send stream stop from the peer
419#[derive(Debug, Error, Clone, PartialEq, Eq)]
420pub enum StoppedError {
421 /// The connection was lost
422 #[error("connection lost")]
423 ConnectionLost(#[from] ConnectionError),
424 /// This was a 0-RTT stream and the server rejected it
425 ///
426 /// Can only occur on clients for 0-RTT streams, which can be opened using
427 /// [`Connecting::into_0rtt()`].
428 ///
429 /// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
430 #[error("0-RTT rejected")]
431 ZeroRttRejected,
432}
433
434impl From<StoppedError> for io::Error {
435 fn from(x: StoppedError) -> Self {
436 use StoppedError::*;
437 let kind = match x {
438 ZeroRttRejected => io::ErrorKind::ConnectionReset,
439 ConnectionLost(_) => io::ErrorKind::NotConnected,
440 };
441 Self::new(kind, x)
442 }
443}
444
pin_project! {
    /// Future returned from [`SendStream::stopped`].
    ///
    /// Resolves to `Some(code)` if the peer stops the stream, `None` once the stream is closed
    /// at the protocol level, or an error if the connection fails or 0-RTT was rejected.
    #[derive(Debug)]
    pub struct Stopped {
        // Shared handle to the connection whose state is inspected on every poll.
        conn: ConnectionRef,
        // The stream being watched.
        stream: StreamId,
        // Copied from the originating `SendStream`; enables the 0-RTT rejection check.
        is_0rtt: bool,
        // Notification tied to this stream's entry in the connection's `stopped` map; it is
        // created in `SendStream::stopped` before the first poll so no notification is missed.
        #[pin]
        notified: OwnedNotified,
    }
}
456
457impl Future for Stopped {
458 type Output = Result<Option<VarInt>, StoppedError>;
459
460 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461 let mut this = self.project();
462 loop {
463 let mut conn = this.conn.state.lock("SendStream::stopped");
464 // Check if the stream is stopped before polling the notify. This makes sure that
465 // no wakeups are missed.
466 if let Some(output) = send_stream_stopped(&mut conn, *this.stream, *this.is_0rtt) {
467 return Poll::Ready(output);
468 }
469 std::task::ready!(this.notified.as_mut().poll(cx));
470 }
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::Stopped;

    /// Compiles only when `T` is both `Send` and `Sync`.
    fn require_send_sync<T: Send + Sync>() {}

    // Never run: the assertion happens at compile time, so the function merely has to exist.
    #[allow(dead_code)]
    fn stopped_future_is_send_sync() {
        require_send_sync::<Stopped>();
    }
}