iroh_relay/protos/streams.rs

1//! Implements logic for abstracting over a websocket stream that allows sending only [`Bytes`]-based
2//! messages.
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use n0_future::{Sink, Stream, ready};
10#[cfg(not(wasm_browser))]
11use tokio::io::{AsyncRead, AsyncWrite};
12use tracing::warn;
13
14use crate::ExportKeyingMaterial;
15
16#[cfg(not(wasm_browser))]
17#[derive(derive_more::Debug)]
18pub(crate) struct WsBytesFramed<T> {
19    #[debug("WebSocketStream<T>")]
20    pub(crate) io: tokio_websockets::WebSocketStream<T>,
21}
22
23#[cfg(wasm_browser)]
24#[derive(derive_more::Debug)]
25pub(crate) struct WsBytesFramed {
26    #[debug("WebSocketStream")]
27    pub(crate) io: ws_stream_wasm::WsStream,
28}
29
30#[cfg(not(wasm_browser))]
31pub(crate) type StreamError = tokio_websockets::Error;
32
33#[cfg(wasm_browser)]
34pub(crate) type StreamError = ws_stream_wasm::WsErr;
35
36/// Shorthand for a type that implements both a websocket-based stream & sink for [`Bytes`].
37pub(crate) trait BytesStreamSink:
38    Stream<Item = Result<Bytes, StreamError>> + Sink<Bytes, Error = StreamError> + Unpin
39{
40}
41
42impl<T> BytesStreamSink for T where
43    T: Stream<Item = Result<Bytes, StreamError>> + Sink<Bytes, Error = StreamError> + Unpin
44{
45}
46
47#[cfg(not(wasm_browser))]
48impl<IO: ExportKeyingMaterial + AsyncRead + AsyncWrite + Unpin> ExportKeyingMaterial
49    for WsBytesFramed<IO>
50{
51    fn export_keying_material<T: AsMut<[u8]>>(
52        &self,
53        output: T,
54        label: &[u8],
55        context: Option<&[u8]>,
56    ) -> Option<T> {
57        self.io
58            .get_ref()
59            .export_keying_material(output, label, context)
60    }
61}
62
/// Keying-material export on `wasm_browser` builds: never available.
///
/// All arguments are ignored and `None` is always returned.
#[cfg(wasm_browser)]
impl ExportKeyingMaterial for WsBytesFramed {
    fn export_keying_material<T: AsMut<[u8]>>(
        &self,
        _output: T,
        _label: &[u8],
        _context: Option<&[u8]>,
    ) -> Option<T> {
        // NOTE(review): presumably the browser websocket API exposes no access to the
        // underlying TLS session, so there is nothing to export. Confirm.
        None
    }
}
74
75#[cfg(not(wasm_browser))]
76impl<T: AsyncRead + AsyncWrite + Unpin> Stream for WsBytesFramed<T> {
77    type Item = Result<Bytes, StreamError>;
78
79    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80        loop {
81            match ready!(Pin::new(&mut self.io).poll_next(cx)) {
82                None => return Poll::Ready(None),
83                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
84                Some(Ok(msg)) => {
85                    if msg.is_close() {
86                        // Indicate the stream is done when we receive a close message.
87                        // Note: We don't have to poll the stream to completion for it to close gracefully.
88                        return Poll::Ready(None);
89                    }
90                    if msg.is_ping() || msg.is_pong() {
91                        continue; // Responding appropriately to these is done inside of tokio_websockets/browser impls
92                    }
93                    if !msg.is_binary() {
94                        warn!(?msg, "Got websocket message of unsupported type, skipping.");
95                        continue;
96                    }
97                    return Poll::Ready(Some(Ok(msg.into_payload().into())));
98                }
99            }
100        }
101    }
102}
103
/// Yields the contents of each binary websocket message as [`Bytes`] (`wasm_browser` builds).
///
/// The stream ends when the underlying `WsStream` ends. Every non-binary message is logged with
/// a warning and skipped.
#[cfg(wasm_browser)]
impl Stream for WsBytesFramed {
    type Item = Result<Bytes, StreamError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let Some(msg) = ready!(Pin::new(&mut self.io).poll_next(cx)) else {
                return Poll::Ready(None);
            };

            if let ws_stream_wasm::WsMessage::Binary(payload) = msg {
                return Poll::Ready(Some(Ok(payload.into())));
            }

            warn!(?msg, "Got websocket message of unsupported type, skipping.");
        }
    }
}
123
124#[cfg(not(wasm_browser))]
125impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Bytes> for WsBytesFramed<T> {
126    type Error = StreamError;
127
128    fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> {
129        let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes));
130        Pin::new(&mut self.io).start_send(msg)
131    }
132
133    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134        Pin::new(&mut self.io).poll_ready(cx).map_err(Into::into)
135    }
136
137    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138        Pin::new(&mut self.io).poll_flush(cx).map_err(Into::into)
139    }
140
141    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142        Pin::new(&mut self.io).poll_close(cx).map_err(Into::into)
143    }
144}
145
/// Sends each [`Bytes`] item as one binary websocket message (`wasm_browser` builds).
///
/// Readiness, flushing and closing are forwarded to the wrapped `WsStream`. Its errors are
/// converted into [`StreamError`].
#[cfg(wasm_browser)]
impl Sink<Bytes> for WsBytesFramed {
    type Error = StreamError;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = Pin::new(&mut self.io);
        io.poll_ready(cx).map_err(Into::into)
    }

    fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> {
        // ws_stream_wasm binary messages own a `Vec<u8>`.
        let frame: Vec<u8> = bytes.into();
        let io = Pin::new(&mut self.io);
        io.start_send(ws_stream_wasm::WsMessage::Binary(frame))
            .map_err(Into::into)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = Pin::new(&mut self.io);
        io.poll_flush(cx).map_err(Into::into)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let io = Pin::new(&mut self.io);
        io.poll_close(cx).map_err(Into::into)
    }
}