iroh_relay/protos/
streams.rs

//! Adapts a websocket connection into a stream and sink that exchange only binary
//! [`Bytes`]-based messages.
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use bytes::Bytes;
9use n0_future::{Sink, Stream, ready};
10#[cfg(not(wasm_browser))]
11use tokio::io::{AsyncRead, AsyncWrite};
12use tracing::warn;
13
14use crate::ExportKeyingMaterial;
15
16#[cfg(not(wasm_browser))]
17#[derive(derive_more::Debug)]
18pub(crate) struct WsBytesFramed<T> {
19    #[debug("WebSocketStream<T>")]
20    pub(crate) io: tokio_websockets::WebSocketStream<T>,
21}
22
/// Browser build of the framing adapter: wraps a `ws_stream_wasm::WsStream` so it can be used
/// as a stream and sink of binary [`Bytes`] messages.
///
/// In `Debug` output the wrapped socket is rendered as the fixed text `WebSocketStream`.
#[cfg(wasm_browser)]
#[derive(derive_more::Debug)]
pub(crate) struct WsBytesFramed {
    /// The wrapped browser websocket stream.
    #[debug("WebSocketStream")]
    pub(crate) io: ws_stream_wasm::WsStream,
}
29
30/// Error type for WebSocket stream operations.
31///
32/// This type alias represents errors that can occur during WebSocket communication.
33/// The underlying error type depends on the platform:
34/// - On non-browser platforms: `tokio_websockets::Error`
35/// - On browser WASM: `ws_stream_wasm::WsErr`
36#[cfg(not(wasm_browser))]
37pub type StreamError = tokio_websockets::Error;
38
/// Error raised while reading from or writing to the websocket.
///
/// The concrete type behind this alias is chosen per build target:
/// - non-browser builds: `tokio_websockets::Error`
/// - browser WASM builds: `ws_stream_wasm::WsErr`
#[cfg(wasm_browser)]
pub type StreamError = ws_stream_wasm::WsErr;
47
48/// Shorthand for a type that implements both a websocket-based stream & sink for [`Bytes`].
49pub trait BytesStreamSink:
50    Stream<Item = Result<Bytes, StreamError>> + Sink<Bytes, Error = StreamError> + Unpin
51{
52}
53
54impl<T> BytesStreamSink for T where
55    T: Stream<Item = Result<Bytes, StreamError>> + Sink<Bytes, Error = StreamError> + Unpin
56{
57}
58
59#[cfg(not(wasm_browser))]
60impl<IO: ExportKeyingMaterial + AsyncRead + AsyncWrite + Unpin> ExportKeyingMaterial
61    for WsBytesFramed<IO>
62{
63    fn export_keying_material<T: AsMut<[u8]>>(
64        &self,
65        output: T,
66        label: &[u8],
67        context: Option<&[u8]>,
68    ) -> Option<T> {
69        self.io
70            .get_ref()
71            .export_keying_material(output, label, context)
72    }
73}
74
/// Browser builds cannot export keying material: this always returns `None`.
#[cfg(wasm_browser)]
impl ExportKeyingMaterial for WsBytesFramed {
    fn export_keying_material<T: AsMut<[u8]>>(
        &self,
        _output: T,
        _label: &[u8],
        _context: Option<&[u8]>,
    ) -> Option<T> {
        // NOTE(review): presumably because the browser websocket API exposes no handle to the
        // underlying secure session — confirm.
        Option::None
    }
}
86
87#[cfg(not(wasm_browser))]
88impl<T: AsyncRead + AsyncWrite + Unpin> Stream for WsBytesFramed<T> {
89    type Item = Result<Bytes, StreamError>;
90
91    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        loop {
93            match ready!(Pin::new(&mut self.io).poll_next(cx)) {
94                None => return Poll::Ready(None),
95                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
96                Some(Ok(msg)) => {
97                    if msg.is_close() {
98                        // Indicate the stream is done when we receive a close message.
99                        // Note: We don't have to poll the stream to completion for it to close gracefully.
100                        return Poll::Ready(None);
101                    }
102                    if msg.is_ping() || msg.is_pong() {
103                        continue; // Responding appropriately to these is done inside of tokio_websockets/browser impls
104                    }
105                    if !msg.is_binary() {
106                        warn!(?msg, "Got websocket message of unsupported type, skipping.");
107                        continue;
108                    }
109                    return Poll::Ready(Some(Ok(msg.into_payload().into())));
110                }
111            }
112        }
113    }
114}
115
/// Produces the payload of every binary websocket message as [`Bytes`].
///
/// Messages of any other kind are logged at `warn` level and skipped. The stream ends when the
/// wrapped `WsStream` ends.
#[cfg(wasm_browser)]
impl Stream for WsBytesFramed {
    type Item = Result<Bytes, StreamError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        while let Some(msg) = ready!(Pin::new(&mut self.io).poll_next(cx)) {
            match msg {
                ws_stream_wasm::WsMessage::Binary(payload) => {
                    return Poll::Ready(Some(Ok(payload.into())));
                }
                msg => warn!(?msg, "Got websocket message of unsupported type, skipping."),
            }
        }
        // The wrapped stream is exhausted.
        Poll::Ready(None)
    }
}
135
136#[cfg(not(wasm_browser))]
137impl<T: AsyncRead + AsyncWrite + Unpin> Sink<Bytes> for WsBytesFramed<T> {
138    type Error = StreamError;
139
140    fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> {
141        let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes));
142        Pin::new(&mut self.io).start_send(msg)
143    }
144
145    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146        Pin::new(&mut self.io).poll_ready(cx).map_err(Into::into)
147    }
148
149    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150        Pin::new(&mut self.io).poll_flush(cx).map_err(Into::into)
151    }
152
153    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154        Pin::new(&mut self.io).poll_close(cx).map_err(Into::into)
155    }
156}
157
/// Sends each [`Bytes`] value as one binary message over the browser websocket.
///
/// All sink operations are forwarded to the wrapped `ws_stream_wasm::WsStream`.
#[cfg(wasm_browser)]
impl Sink<Bytes> for WsBytesFramed {
    type Error = StreamError;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = Pin::new(&mut self.io);
        inner.poll_ready(cx).map_err(Into::into)
    }

    fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> {
        // The browser message type owns its payload as a `Vec<u8>`.
        let frame = ws_stream_wasm::WsMessage::Binary(bytes.into());
        let inner = Pin::new(&mut self.io);
        inner.start_send(frame).map_err(Into::into)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = Pin::new(&mut self.io);
        inner.poll_flush(cx).map_err(Into::into)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = Pin::new(&mut self.io);
        inner.poll_close(cx).map_err(Into::into)
    }
}