iroh_relay/server/
streams.rs

1//! Streams used in the server-side implementation of iroh relays.
2
3use std::{
4    pin::Pin,
5    sync::{Arc, atomic::AtomicBool},
6    task::{Context, Poll},
7};
8
9use n0_error::{ensure, stack_error};
10use n0_future::{FutureExt, Sink, Stream, ready, time};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tracing::{instrument, warn};
13
14use super::{ClientRateLimit, Metrics};
15use crate::{
16    ExportKeyingMaterial, KeyCache, MAX_PACKET_SIZE,
17    protos::{
18        relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg},
19        streams::{StreamError, WsBytesFramed},
20    },
21};
22
/// The relay's connection to a client.
///
/// This implements
/// - a [`Stream`] of [`ClientToRelayMsg`]s that are received from the client,
/// - a [`Sink`] of [`RelayToClientMsg`]s that can be sent to the client.
///
/// Generic over the inner stream type to support different WebSocket implementations.
#[derive(Debug)]
pub struct RelayedStream<S> {
    /// The underlying framed byte stream (e.g. a WebSocket connection).
    pub(crate) inner: S,
    /// Cache used when decoding public keys from incoming frames.
    pub(crate) key_cache: KeyCache,
}
35
36impl<S> RelayedStream<S> {
37    /// Creates a new RelayedStream from an inner stream and key cache.
38    ///
39    /// This is the primary constructor for external integrations using custom
40    /// WebSocket implementations.
41    pub fn new(inner: S, key_cache: KeyCache) -> Self {
42        Self { inner, key_cache }
43    }
44}
45
/// Type alias for the standard server-side relay stream:
/// a rate-limited, possibly TLS-wrapped TCP stream, framed as WebSocket bytes.
#[allow(dead_code)]
pub(crate) type ServerRelayedStream = RelayedStream<WsBytesFramed<RateLimited<MaybeTlsStream>>>;
49
#[cfg(test)]
impl ServerRelayedStream {
    /// Creates a relay stream over an in-memory duplex pipe without rate limiting.
    pub(crate) fn test(stream: tokio::io::DuplexStream) -> Self {
        let stream = MaybeTlsStream::Test(stream);
        let stream = RateLimited::unlimited(stream, Arc::new(Metrics::default()));
        Self::from_rate_limited(stream)
    }

    /// Creates a relay stream over an in-memory duplex pipe with the given
    /// rate-limit configuration.
    ///
    /// Returns an [`InvalidBucketConfig`] error if the bucket parameters are
    /// invalid (see [`RateLimited::new`]).
    pub(crate) fn test_limited(
        stream: tokio::io::DuplexStream,
        max_burst_bytes: u32,
        bytes_per_second: u32,
    ) -> Result<Self, InvalidBucketConfig> {
        let stream = MaybeTlsStream::Test(stream);
        let stream = RateLimited::new(
            stream,
            max_burst_bytes,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;
        Ok(Self::from_rate_limited(stream))
    }

    /// Shared construction path for the test helpers: wraps the (possibly
    /// rate-limited) IO in a WebSocket server with the standard frame limits.
    fn from_rate_limited(stream: RateLimited<MaybeTlsStream>) -> Self {
        Self {
            inner: WsBytesFramed {
                io: tokio_websockets::ServerBuilder::new()
                    .limits(Self::limits())
                    .serve(stream),
            },
            key_cache: KeyCache::test(),
        }
    }

    /// WebSocket limits: cap payloads at the relay protocol's maximum frame size.
    fn limits() -> tokio_websockets::Limits {
        tokio_websockets::Limits::default()
            .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE))
    }
}
92
/// Relay send errors
///
/// Returned by the [`Sink`] implementation on [`RelayedStream`].
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum SendError {
    /// Error from the underlying WebSocket stream
    #[error(transparent)]
    StreamError {
        #[error(from, std_err)]
        /// The underlying stream error
        source: StreamError,
    },
    /// Packet size exceeds the maximum allowed size
    #[error("Packet exceeds max packet size")]
    ExceedsMaxPacketSize {
        /// The size of the packet that was too large
        size: usize,
    },
    /// Attempted to send an empty packet
    #[error("Attempted to send empty packet")]
    EmptyPacket {},
}
114
115impl<S> Sink<RelayToClientMsg> for RelayedStream<S>
116where
117    S: Sink<bytes::Bytes, Error = StreamError> + Unpin,
118{
119    type Error = SendError;
120
121    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122        Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
123    }
124
125    fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> {
126        let size = item.encoded_len();
127        ensure!(
128            size <= MAX_PACKET_SIZE,
129            SendError::ExceedsMaxPacketSize { size }
130        );
131        if let RelayToClientMsg::Datagrams { datagrams, .. } = &item {
132            ensure!(!datagrams.contents.is_empty(), SendError::EmptyPacket);
133        }
134
135        Pin::new(&mut self.inner)
136            .start_send(item.to_bytes().freeze())
137            .map_err(Into::into)
138    }
139
140    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141        Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
142    }
143
144    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145        Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
146    }
147}
148
/// Relay receive errors
///
/// Returned by the [`Stream`] implementation on [`RelayedStream`].
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum RecvError {
    /// Error decoding the relay protocol message
    #[error(transparent)]
    Proto {
        /// The protocol decoding error
        source: ProtoError,
    },
    /// Error from the underlying WebSocket stream
    #[error(transparent)]
    StreamError {
        #[error(std_err)]
        /// The underlying stream error
        source: StreamError,
    },
}
167
168impl<S> Stream for RelayedStream<S>
169where
170    S: Stream<Item = Result<bytes::Bytes, StreamError>> + Unpin,
171{
172    type Item = Result<ClientToRelayMsg, RecvError>;
173
174    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175        Poll::Ready(match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
176            Some(Ok(msg)) => {
177                Some(ClientToRelayMsg::from_bytes(msg, &self.key_cache).map_err(Into::into))
178            }
179            Some(Err(e)) => Some(Err(e.into())),
180            None => None,
181        })
182    }
183}
184
/// The main underlying IO stream type used for the relay server.
///
/// Allows choosing whether or not the underlying [`tokio::net::TcpStream`] is served over Tls
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum MaybeTlsStream {
    /// A plain non-Tls [`tokio::net::TcpStream`]
    Plain(tokio::net::TcpStream),
    /// A Tls wrapped [`tokio::net::TcpStream`]
    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
    /// An in-memory bidirectional pipe.
    #[cfg(test)]
    Test(tokio::io::DuplexStream),
}
199
200impl MaybeTlsStream {
201    /// Tries to disable the nagle algorithm on the TCP stream.
202    ///
203    /// This sets the NO_DELAY option on the TCP stream, which turns off the
204    /// nagle algorithm for coalecing writes together.
205    ///
206    /// If this fails, this will print a warning the first time it fails.
207    pub fn disable_nagle(&self) {
208        let stream = match self {
209            #[cfg(test)]
210            Self::Test(_) => return,
211            Self::Plain(stream) => stream,
212            Self::Tls(tls_stream) => tls_stream.get_ref().0,
213        };
214
215        if stream.set_nodelay(true).is_err() {
216            use std::sync::atomic::Ordering::Relaxed;
217
218            static FAILED_NO_DELAY: AtomicBool = AtomicBool::new(false);
219            if !FAILED_NO_DELAY.swap(true, Relaxed) {
220                warn!(
221                    "Failed to set TCP socket to NO_DELAY (turning off Nagle failed). This will impair relay performance."
222                );
223            }
224        }
225    }
226}
227
228impl ExportKeyingMaterial for MaybeTlsStream {
229    fn export_keying_material<T: AsMut<[u8]>>(
230        &self,
231        output: T,
232        label: &[u8],
233        context: Option<&[u8]>,
234    ) -> Option<T> {
235        let Self::Tls(tls) = self else {
236            return None;
237        };
238
239        tls.get_ref()
240            .1
241            .export_keying_material(output, label, context)
242            .ok()
243    }
244}
245
246impl AsyncRead for MaybeTlsStream {
247    fn poll_read(
248        mut self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250        buf: &mut tokio::io::ReadBuf<'_>,
251    ) -> Poll<std::io::Result<()>> {
252        match &mut *self {
253            MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
254            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
255            #[cfg(test)]
256            MaybeTlsStream::Test(s) => Pin::new(s).poll_read(cx, buf),
257        }
258    }
259}
260
261impl AsyncWrite for MaybeTlsStream {
262    fn poll_flush(
263        mut self: Pin<&mut Self>,
264        cx: &mut Context<'_>,
265    ) -> Poll<std::result::Result<(), std::io::Error>> {
266        match &mut *self {
267            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
268            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
269            #[cfg(test)]
270            MaybeTlsStream::Test(s) => Pin::new(s).poll_flush(cx),
271        }
272    }
273
274    fn poll_shutdown(
275        mut self: Pin<&mut Self>,
276        cx: &mut Context<'_>,
277    ) -> Poll<std::result::Result<(), std::io::Error>> {
278        match &mut *self {
279            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
280            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
281            #[cfg(test)]
282            MaybeTlsStream::Test(s) => Pin::new(s).poll_shutdown(cx),
283        }
284    }
285
286    fn poll_write(
287        mut self: Pin<&mut Self>,
288        cx: &mut Context<'_>,
289        buf: &[u8],
290    ) -> Poll<std::result::Result<usize, std::io::Error>> {
291        match &mut *self {
292            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
293            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
294            #[cfg(test)]
295            MaybeTlsStream::Test(s) => Pin::new(s).poll_write(cx, buf),
296        }
297    }
298
299    fn poll_write_vectored(
300        mut self: Pin<&mut Self>,
301        cx: &mut Context<'_>,
302        bufs: &[std::io::IoSlice<'_>],
303    ) -> Poll<std::result::Result<usize, std::io::Error>> {
304        match &mut *self {
305            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
306            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
307            #[cfg(test)]
308            MaybeTlsStream::Test(s) => Pin::new(s).poll_write_vectored(cx, bufs),
309        }
310    }
311
312    fn is_write_vectored(&self) -> bool {
313        match self {
314            MaybeTlsStream::Plain(s) => s.is_write_vectored(),
315            MaybeTlsStream::Tls(s) => s.is_write_vectored(),
316            #[cfg(test)]
317            MaybeTlsStream::Test(s) => s.is_write_vectored(),
318        }
319    }
320}
321
/// Rate limiter for reading from a [`RelayedStream`].
///
/// The writes to the sink are not rate limited.
///
/// A read that overdraws the token bucket is still delivered, but subsequent
/// reads are delayed until the bucket has refilled. While waiting for the
/// refill, the underlying stream is no longer polled.
#[derive(Debug)]
pub(crate) struct RateLimited<S> {
    /// The wrapped IO stream.
    inner: S,
    /// Token bucket; `None` means no rate limiting at all.
    bucket: Option<Bucket>,
    /// Sleep until the bucket is refilled, set while currently rate-limited.
    bucket_refilled: Option<Pin<Box<time::Sleep>>>,
    /// Keeps track if this stream was ever rate-limited.
    limited_once: bool,
    metrics: Arc<Metrics>,
}
337
/// A token bucket tracking a byte budget over time.
#[derive(Debug)]
struct Bucket {
    // The current bucket fill (bytes); may go negative when overdrawn
    fill: i64,
    // The maximum bucket fill (bytes)
    max: i64,
    // The bucket's last fill time
    last_fill: time::Instant,
    // Interval length of one refill
    refill_period: time::Duration,
    // How much we re-fill per refill period (bytes)
    refill: i64,
}
351
/// Error returned when the token bucket parameters are invalid: a
/// non-positive size, rate or refill period, or a per-period refill amount
/// that rounds down to zero.
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
pub struct InvalidBucketConfig {
    max: i64,
    bytes_per_second: i64,
    refill_period: time::Duration,
}
359
impl Bucket {
    /// Creates a bucket holding at most `max` bytes, refilled at
    /// `bytes_per_second` in whole increments of one `refill_period`.
    ///
    /// The bucket starts full. Returns [`InvalidBucketConfig`] if any
    /// parameter is non-positive or the per-period refill rounds down to zero.
    fn new(
        max: i64,
        bytes_per_second: i64,
        refill_period: time::Duration,
    ) -> Result<Self, InvalidBucketConfig> {
        // milliseconds is the tokio timer resolution
        let refill = bytes_per_second.saturating_mul(refill_period.as_millis() as i64) / 1000;
        ensure!(
            max > 0 && bytes_per_second > 0 && refill_period.as_millis() as u32 > 0 && refill > 0,
            InvalidBucketConfig {
                max,
                bytes_per_second,
                refill_period
            }
        );
        Ok(Self {
            fill: max,
            max,
            last_fill: time::Instant::now(),
            refill_period,
            refill,
        })
    }

    /// Adds tokens for every *full* refill period elapsed since `last_fill`,
    /// capping the fill at `max`.
    fn update_state(&mut self) {
        let now = time::Instant::now();
        // div safety: self.refill_period.as_millis() is checked to be non-null in constructor
        let refill_periods = now.saturating_duration_since(self.last_fill).as_millis() as u32
            / self.refill_period.as_millis() as u32;
        if refill_periods == 0 {
            // Nothing to do - we won't refill yet
            return;
        }

        self.fill = self
            .fill
            .saturating_add(refill_periods as i64 * self.refill);
        self.fill = std::cmp::min(self.fill, self.max);
        // Advance last_fill only by whole periods so partial progress towards
        // the next refill is not lost.
        self.last_fill += self.refill_period * refill_periods;
    }

    /// Consumes `bytes` tokens from the bucket.
    ///
    /// Returns `Ok(())` if the bucket still has a positive fill afterwards.
    /// Otherwise the bucket is overdrawn (fill <= 0) and `Err(instant)` is
    /// returned with the earliest time the fill becomes positive again; the
    /// caller is expected to wait until then.
    fn consume(&mut self, bytes: usize) -> Result<(), time::Instant> {
        let bytes = i64::try_from(bytes).unwrap_or(i64::MAX);
        self.update_state();

        self.fill = self.fill.saturating_sub(bytes);

        if self.fill > 0 {
            return Ok(());
        }

        // Overdrawn: compute how many refill periods are needed to get the
        // fill back above zero.
        let missing = self.fill.saturating_neg();

        let periods_needed = (missing / self.refill) + 1;
        let periods_needed = u32::try_from(periods_needed).unwrap_or(u32::MAX);

        Err(self.last_fill + periods_needed * self.refill_period)
    }
}
420
421impl<S> RateLimited<S> {
422    pub(crate) fn from_cfg(
423        cfg: Option<ClientRateLimit>,
424        io: S,
425        metrics: Arc<Metrics>,
426    ) -> Result<Self, InvalidBucketConfig> {
427        match cfg {
428            Some(cfg) => {
429                let bytes_per_second = cfg.bytes_per_second.into();
430                let max_burst_bytes = cfg.max_burst_bytes.map_or(bytes_per_second / 10, u32::from);
431                Self::new(io, max_burst_bytes, bytes_per_second, metrics)
432            }
433            None => Ok(Self::unlimited(io, metrics)),
434        }
435    }
436
437    pub(crate) fn new(
438        inner: S,
439        max_burst_bytes: u32,
440        bytes_per_second: u32,
441        metrics: Arc<Metrics>,
442    ) -> Result<Self, InvalidBucketConfig> {
443        Ok(Self {
444            inner,
445            bucket: Some(Bucket::new(
446                max_burst_bytes as i64,
447                bytes_per_second as i64,
448                time::Duration::from_millis(100),
449            )?),
450            bucket_refilled: None,
451            limited_once: false,
452            metrics,
453        })
454    }
455
456    pub(crate) fn unlimited(inner: S, metrics: Arc<Metrics>) -> Self {
457        Self {
458            inner,
459            bucket: None,
460            bucket_refilled: None,
461            limited_once: false,
462            metrics,
463        }
464    }
465
466    /// Records metrics about being rate-limited.
467    fn record_rate_limited(&mut self, bytes: usize) {
468        // TODO: add a label for the frame type.
469        self.metrics.bytes_rx_ratelimited_total.inc_by(bytes as u64);
470        if !self.limited_once {
471            self.metrics.conns_rx_ratelimited_total.inc();
472            self.limited_once = true;
473        }
474    }
475}
476
impl<S: ExportKeyingMaterial> ExportKeyingMaterial for RateLimited<S> {
    /// Forwards keying-material export to the wrapped stream unchanged.
    fn export_keying_material<T: AsMut<[u8]>>(
        &self,
        output: T,
        label: &[u8],
        context: Option<&[u8]>,
    ) -> Option<T> {
        self.inner.export_keying_material(output, label, context)
    }
}
487
impl<S: AsyncRead + Unpin> AsyncRead for RateLimited<S> {
    /// Reads from the inner stream, charging the token bucket for the bytes read.
    ///
    /// A read that overdraws the bucket is still delivered; the *next* read is
    /// delayed until the bucket's refill time via `bucket_refilled`.
    #[instrument(name = "rate_limited_poll_read", skip_all)]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = &mut *self;
        let Some(bucket) = &mut this.bucket else {
            // If there is no rate-limiter, then directly poll the inner.
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        };

        // If we're currently limited, wait until we've got some bucket space again
        if let Some(bucket_refilled) = &mut this.bucket_refilled {
            ready!(bucket_refilled.poll(cx));
            this.bucket_refilled = None;
        }

        // We're not currently limited, let's read

        // Poll inner for a new item.
        let bytes_before = buf.remaining();
        ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
        // Compute bytes read from the change in remaining buffer capacity.
        let bytes_read = bytes_before - buf.remaining();

        // Record how much we've read, rate limit accordingly, if need be.
        if let Err(refill_time) = bucket.consume(bytes_read) {
            this.record_rate_limited(bytes_read);
            this.bucket_refilled = Some(Box::pin(time::sleep_until(refill_time)));
        }

        Poll::Ready(Ok(()))
    }
}
523
524impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimited<S> {
525    fn poll_write(
526        mut self: Pin<&mut Self>,
527        cx: &mut std::task::Context<'_>,
528        buf: &[u8],
529    ) -> Poll<Result<usize, std::io::Error>> {
530        Pin::new(&mut self.inner).poll_write(cx, buf)
531    }
532
533    fn poll_flush(
534        mut self: Pin<&mut Self>,
535        cx: &mut std::task::Context<'_>,
536    ) -> Poll<Result<(), std::io::Error>> {
537        Pin::new(&mut self.inner).poll_flush(cx)
538    }
539
540    fn poll_shutdown(
541        mut self: Pin<&mut Self>,
542        cx: &mut std::task::Context<'_>,
543    ) -> Poll<Result<(), std::io::Error>> {
544        Pin::new(&mut self.inner).poll_shutdown(cx)
545    }
546}
547
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use n0_error::{Result, StdResultExt};
    use n0_future::time;
    use n0_tracing_test::traced_test;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::Bucket;
    use crate::server::{Metrics, streams::RateLimited};

    /// Verifies the observed throughput of a rate-limited read matches the
    /// configured bytes-per-second (tokio time is paused, so this is exact).
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_ratelimiter() -> Result {
        let (read, mut write) = tokio::io::duplex(4096);

        let send_total = 10 * 1024 * 1024; // 10MiB
        let send_data = vec![42u8; send_total];

        let bytes_per_second = 12_345;

        // Burst of a tenth of a second's worth, matching the server default.
        let mut rate_limited = RateLimited::new(
            read,
            bytes_per_second / 10,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;

        let before = time::Instant::now();
        // Read and write concurrently so the duplex pipe never deadlocks.
        n0_future::future::try_zip(
            async {
                let mut remaining = send_total;
                let mut buf = [0u8; 4096];
                while remaining > 0 {
                    remaining -= rate_limited.read(&mut buf).await?;
                }
                Ok(())
            },
            async {
                write.write_all(&send_data).await?;
                write.flush().await
            },
        )
        .await
        .anyerr()?;

        let duration = time::Instant::now().duration_since(before);
        assert_ne!(duration.as_millis(), 0);

        let actual_bytes_per_second = send_total as f64 / duration.as_secs_f64();
        println!("{actual_bytes_per_second}");
        assert_eq!(actual_bytes_per_second.round() as u32, bytes_per_second);

        Ok(())
    }

    /// Extreme refill rates must not overflow: consuming regularly from a
    /// maxed-out bucket should always succeed.
    #[tokio::test(start_paused = true)]
    async fn test_bucket_high_refill() -> Result {
        let bytes_per_second = i64::MAX;
        let mut bucket = Bucket::new(i64::MAX, bytes_per_second, time::Duration::from_millis(100))?;
        for _ in 0..100 {
            time::sleep(time::Duration::from_millis(100)).await;
            assert!(bucket.consume(1_000_000).is_ok());
        }

        Ok(())
    }

    /// Extreme consumption must not overflow: each oversized consume should
    /// yield a finite wait time that can actually be slept until.
    #[tokio::test(start_paused = true)]
    async fn smoke_test_bucket_high_consume() -> Result {
        let bytes_per_second = 123_456;
        let mut bucket = Bucket::new(
            bytes_per_second / 10,
            bytes_per_second,
            time::Duration::from_millis(100),
        )?;
        for _ in 0..100 {
            let Err(until) = bucket.consume(usize::MAX) else {
                panic!("i64::MAX shouldn't be within limits");
            };
            time::sleep_until(until).await;
        }

        Ok(())
    }
}