1use std::{
4 pin::Pin,
5 sync::{Arc, atomic::AtomicBool},
6 task::{Context, Poll},
7};
8
9use n0_error::{ensure, stack_error};
10use n0_future::{FutureExt, Sink, Stream, ready, time};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tracing::{instrument, warn};
13
14use super::{ClientRateLimit, Metrics};
15use crate::{
16 ExportKeyingMaterial, KeyCache, MAX_PACKET_SIZE,
17 protos::{
18 relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg},
19 streams::{StreamError, WsBytesFramed},
20 },
21};
22
#[derive(Debug)]
/// A relay connection that pairs an underlying framed transport with a [`KeyCache`].
///
/// When `S` is a sink of [`bytes::Bytes`] this type is a sink of [`RelayToClientMsg`];
/// when `S` is a stream of byte frames it is a stream of decoded [`ClientToRelayMsg`]s.
pub struct RelayedStream<S> {
    /// The underlying transport that carries the encoded frames.
    pub(crate) inner: S,
    /// Cache passed to [`ClientToRelayMsg::from_bytes`] when decoding incoming frames.
    pub(crate) key_cache: KeyCache,
}
35
36impl<S> RelayedStream<S> {
37 pub fn new(inner: S, key_cache: KeyCache) -> Self {
42 Self { inner, key_cache }
43 }
44}
45
#[allow(dead_code)]
// NOTE(review): presumably the alias is unused in some build configurations, hence the
// `dead_code` allowance — confirm.
/// The concrete [`RelayedStream`] stack: websocket byte framing ([`WsBytesFramed`]) over a
/// [`RateLimited`] [`MaybeTlsStream`].
pub(crate) type ServerRelayedStream = RelayedStream<WsBytesFramed<RateLimited<MaybeTlsStream>>>;
49
#[cfg(test)]
impl ServerRelayedStream {
    /// Test helper: serves a websocket over `stream` without any rate limit.
    pub(crate) fn test(stream: tokio::io::DuplexStream) -> Self {
        let transport =
            RateLimited::unlimited(MaybeTlsStream::Test(stream), Arc::new(Metrics::default()));
        let io = tokio_websockets::ServerBuilder::new()
            .limits(Self::limits())
            .serve(transport);
        Self {
            inner: WsBytesFramed { io },
            key_cache: KeyCache::test(),
        }
    }

    /// Test helper: serves a websocket over `stream`, rate limiting reads to
    /// `bytes_per_second` with a burst allowance of `max_burst_bytes`.
    ///
    /// Fails with [`InvalidBucketConfig`] when the limiter rejects the parameters.
    pub(crate) fn test_limited(
        stream: tokio::io::DuplexStream,
        max_burst_bytes: u32,
        bytes_per_second: u32,
    ) -> Result<Self, InvalidBucketConfig> {
        let transport = RateLimited::new(
            MaybeTlsStream::Test(stream),
            max_burst_bytes,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;
        let io = tokio_websockets::ServerBuilder::new()
            .limits(Self::limits())
            .serve(transport);
        Ok(Self {
            inner: WsBytesFramed { io },
            key_cache: KeyCache::test(),
        })
    }

    /// Websocket limits used by the test constructors: the default limits with the
    /// maximum payload length capped at the relay protocol's `MAX_FRAME_SIZE`.
    fn limits() -> tokio_websockets::Limits {
        let max_payload = Some(crate::protos::relay::MAX_FRAME_SIZE);
        tokio_websockets::Limits::default().max_payload_len(max_payload)
    }
}
92
#[stack_error(derive, add_meta)]
#[non_exhaustive]
/// Errors returned when sending a [`RelayToClientMsg`] through a [`RelayedStream`].
pub enum SendError {
    #[error(transparent)]
    /// The underlying stream reported an error.
    StreamError {
        #[error(from, std_err)]
        source: StreamError,
    },
    #[error("Packet exceeds max packet size")]
    /// The encoded message is larger than [`MAX_PACKET_SIZE`]; `size` is its encoded length.
    ExceedsMaxPacketSize {
        size: usize,
    },
    #[error("Attempted to send empty packet")]
    /// A datagrams message with empty contents was handed to the sink.
    EmptyPacket {},
}
114
115impl<S> Sink<RelayToClientMsg> for RelayedStream<S>
116where
117 S: Sink<bytes::Bytes, Error = StreamError> + Unpin,
118{
119 type Error = SendError;
120
121 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122 Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
123 }
124
125 fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> {
126 let size = item.encoded_len();
127 ensure!(
128 size <= MAX_PACKET_SIZE,
129 SendError::ExceedsMaxPacketSize { size }
130 );
131 if let RelayToClientMsg::Datagrams { datagrams, .. } = &item {
132 ensure!(!datagrams.contents.is_empty(), SendError::EmptyPacket);
133 }
134
135 Pin::new(&mut self.inner)
136 .start_send(item.to_bytes().freeze())
137 .map_err(Into::into)
138 }
139
140 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141 Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
142 }
143
144 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145 Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
146 }
147}
148
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
/// Errors yielded when receiving a [`ClientToRelayMsg`] from a [`RelayedStream`].
pub enum RecvError {
    #[error(transparent)]
    /// A received frame could not be decoded into a [`ClientToRelayMsg`].
    Proto {
        source: ProtoError,
    },
    #[error(transparent)]
    /// The underlying stream reported an error.
    StreamError {
        #[error(std_err)]
        source: StreamError,
    },
}
167
168impl<S> Stream for RelayedStream<S>
169where
170 S: Stream<Item = Result<bytes::Bytes, StreamError>> + Unpin,
171{
172 type Item = Result<ClientToRelayMsg, RecvError>;
173
174 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175 Poll::Ready(match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
176 Some(Ok(msg)) => {
177 Some(ClientToRelayMsg::from_bytes(msg, &self.key_cache).map_err(Into::into))
178 }
179 Some(Err(e)) => Some(Err(e.into())),
180 None => None,
181 })
182 }
183}
184
#[derive(Debug)]
/// The byte stream underneath a relay connection: either plain TCP or TLS over TCP.
///
/// Test builds add a variant backed by an in-memory duplex stream.
#[allow(clippy::large_enum_variant)]
pub enum MaybeTlsStream {
    /// An unencrypted TCP stream.
    Plain(tokio::net::TcpStream),
    /// A server-side TLS stream over TCP.
    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
    #[cfg(test)]
    /// An in-memory stream, only available in tests.
    Test(tokio::io::DuplexStream),
}
199
200impl MaybeTlsStream {
201 pub fn disable_nagle(&self) {
208 let stream = match self {
209 #[cfg(test)]
210 Self::Test(_) => return,
211 Self::Plain(stream) => stream,
212 Self::Tls(tls_stream) => tls_stream.get_ref().0,
213 };
214
215 if stream.set_nodelay(true).is_err() {
216 use std::sync::atomic::Ordering::Relaxed;
217
218 static FAILED_NO_DELAY: AtomicBool = AtomicBool::new(false);
219 if !FAILED_NO_DELAY.swap(true, Relaxed) {
220 warn!(
221 "Failed to set TCP socket to NO_DELAY (turning off Nagle failed). This will impair relay performance."
222 );
223 }
224 }
225 }
226}
227
228impl ExportKeyingMaterial for MaybeTlsStream {
229 fn export_keying_material<T: AsMut<[u8]>>(
230 &self,
231 output: T,
232 label: &[u8],
233 context: Option<&[u8]>,
234 ) -> Option<T> {
235 let Self::Tls(tls) = self else {
236 return None;
237 };
238
239 tls.get_ref()
240 .1
241 .export_keying_material(output, label, context)
242 .ok()
243 }
244}
245
246impl AsyncRead for MaybeTlsStream {
247 fn poll_read(
248 mut self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 buf: &mut tokio::io::ReadBuf<'_>,
251 ) -> Poll<std::io::Result<()>> {
252 match &mut *self {
253 MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
254 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
255 #[cfg(test)]
256 MaybeTlsStream::Test(s) => Pin::new(s).poll_read(cx, buf),
257 }
258 }
259}
260
261impl AsyncWrite for MaybeTlsStream {
262 fn poll_flush(
263 mut self: Pin<&mut Self>,
264 cx: &mut Context<'_>,
265 ) -> Poll<std::result::Result<(), std::io::Error>> {
266 match &mut *self {
267 MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
268 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
269 #[cfg(test)]
270 MaybeTlsStream::Test(s) => Pin::new(s).poll_flush(cx),
271 }
272 }
273
274 fn poll_shutdown(
275 mut self: Pin<&mut Self>,
276 cx: &mut Context<'_>,
277 ) -> Poll<std::result::Result<(), std::io::Error>> {
278 match &mut *self {
279 MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
280 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
281 #[cfg(test)]
282 MaybeTlsStream::Test(s) => Pin::new(s).poll_shutdown(cx),
283 }
284 }
285
286 fn poll_write(
287 mut self: Pin<&mut Self>,
288 cx: &mut Context<'_>,
289 buf: &[u8],
290 ) -> Poll<std::result::Result<usize, std::io::Error>> {
291 match &mut *self {
292 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
293 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
294 #[cfg(test)]
295 MaybeTlsStream::Test(s) => Pin::new(s).poll_write(cx, buf),
296 }
297 }
298
299 fn poll_write_vectored(
300 mut self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 bufs: &[std::io::IoSlice<'_>],
303 ) -> Poll<std::result::Result<usize, std::io::Error>> {
304 match &mut *self {
305 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
306 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
307 #[cfg(test)]
308 MaybeTlsStream::Test(s) => Pin::new(s).poll_write_vectored(cx, bufs),
309 }
310 }
311
312 fn is_write_vectored(&self) -> bool {
313 match self {
314 MaybeTlsStream::Plain(s) => s.is_write_vectored(),
315 MaybeTlsStream::Tls(s) => s.is_write_vectored(),
316 #[cfg(test)]
317 MaybeTlsStream::Test(s) => s.is_write_vectored(),
318 }
319 }
320}
321
#[derive(Debug)]
/// Wrapper around a byte stream that rate limits reads; writes pass through untouched.
///
/// Bytes read from the inner stream are charged against a [`Bucket`]. Once the bucket is
/// exhausted, further reads wait until the bucket has refilled.
pub(crate) struct RateLimited<S> {
    /// The wrapped stream.
    inner: S,
    /// The read budget; `None` disables rate limiting entirely.
    bucket: Option<Bucket>,
    /// Set while reads are blocked: a timer that completes once the bucket has refilled.
    bucket_refilled: Option<Pin<Box<time::Sleep>>>,
    /// Whether this stream has been rate limited at least once, so the per-connection
    /// metric is only incremented a single time.
    limited_once: bool,
    /// Metrics that record rate-limiting events.
    metrics: Arc<Metrics>,
}
337
#[derive(Debug)]
/// Byte budget for the rate limiter that is topped up in discrete refill periods.
struct Bucket {
    /// Bytes currently available. Drops to zero or below when more was consumed than was
    /// available.
    fill: i64,
    /// Upper bound for `fill`; also the initial fill level.
    max: i64,
    /// Instant up to which refills have been accounted for.
    last_fill: time::Instant,
    /// Length of one refill period.
    refill_period: time::Duration,
    /// Bytes added to `fill` per elapsed refill period.
    refill: i64,
}
351
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
// Returned by `Bucket::new` when the parameters fail validation; the fields carry the
// rejected values.
pub struct InvalidBucketConfig {
    max: i64,
    bytes_per_second: i64,
    refill_period: time::Duration,
}
359
360impl Bucket {
361 fn new(
362 max: i64,
363 bytes_per_second: i64,
364 refill_period: time::Duration,
365 ) -> Result<Self, InvalidBucketConfig> {
366 let refill = bytes_per_second.saturating_mul(refill_period.as_millis() as i64) / 1000;
368 ensure!(
369 max > 0 && bytes_per_second > 0 && refill_period.as_millis() as u32 > 0 && refill > 0,
370 InvalidBucketConfig {
371 max,
372 bytes_per_second,
373 refill_period
374 }
375 );
376 Ok(Self {
377 fill: max,
378 max,
379 last_fill: time::Instant::now(),
380 refill_period,
381 refill,
382 })
383 }
384
385 fn update_state(&mut self) {
386 let now = time::Instant::now();
387 let refill_periods = now.saturating_duration_since(self.last_fill).as_millis() as u32
389 / self.refill_period.as_millis() as u32;
390 if refill_periods == 0 {
391 return;
393 }
394
395 self.fill = self
396 .fill
397 .saturating_add(refill_periods as i64 * self.refill);
398 self.fill = std::cmp::min(self.fill, self.max);
399 self.last_fill += self.refill_period * refill_periods;
400 }
401
402 fn consume(&mut self, bytes: usize) -> Result<(), time::Instant> {
403 let bytes = i64::try_from(bytes).unwrap_or(i64::MAX);
404 self.update_state();
405
406 self.fill = self.fill.saturating_sub(bytes);
407
408 if self.fill > 0 {
409 return Ok(());
410 }
411
412 let missing = self.fill.saturating_neg();
413
414 let periods_needed = (missing / self.refill) + 1;
415 let periods_needed = u32::try_from(periods_needed).unwrap_or(u32::MAX);
416
417 Err(self.last_fill + periods_needed * self.refill_period)
418 }
419}
420
421impl<S> RateLimited<S> {
422 pub(crate) fn from_cfg(
423 cfg: Option<ClientRateLimit>,
424 io: S,
425 metrics: Arc<Metrics>,
426 ) -> Result<Self, InvalidBucketConfig> {
427 match cfg {
428 Some(cfg) => {
429 let bytes_per_second = cfg.bytes_per_second.into();
430 let max_burst_bytes = cfg.max_burst_bytes.map_or(bytes_per_second / 10, u32::from);
431 Self::new(io, max_burst_bytes, bytes_per_second, metrics)
432 }
433 None => Ok(Self::unlimited(io, metrics)),
434 }
435 }
436
437 pub(crate) fn new(
438 inner: S,
439 max_burst_bytes: u32,
440 bytes_per_second: u32,
441 metrics: Arc<Metrics>,
442 ) -> Result<Self, InvalidBucketConfig> {
443 Ok(Self {
444 inner,
445 bucket: Some(Bucket::new(
446 max_burst_bytes as i64,
447 bytes_per_second as i64,
448 time::Duration::from_millis(100),
449 )?),
450 bucket_refilled: None,
451 limited_once: false,
452 metrics,
453 })
454 }
455
456 pub(crate) fn unlimited(inner: S, metrics: Arc<Metrics>) -> Self {
457 Self {
458 inner,
459 bucket: None,
460 bucket_refilled: None,
461 limited_once: false,
462 metrics,
463 }
464 }
465
466 fn record_rate_limited(&mut self, bytes: usize) {
468 self.metrics.bytes_rx_ratelimited_total.inc_by(bytes as u64);
470 if !self.limited_once {
471 self.metrics.conns_rx_ratelimited_total.inc();
472 self.limited_once = true;
473 }
474 }
475}
476
477impl<S: ExportKeyingMaterial> ExportKeyingMaterial for RateLimited<S> {
478 fn export_keying_material<T: AsMut<[u8]>>(
479 &self,
480 output: T,
481 label: &[u8],
482 context: Option<&[u8]>,
483 ) -> Option<T> {
484 self.inner.export_keying_material(output, label, context)
485 }
486}
487
impl<S: AsyncRead + Unpin> AsyncRead for RateLimited<S> {
    // Reads from the inner stream and charges the bytes read against the bucket.
    // While the bucket is exhausted, reads stay pending until the refill timer fires.
    #[instrument(name = "rate_limited_poll_read", skip_all)]
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = &mut *self;
        let Some(bucket) = &mut this.bucket else {
            // No bucket means no rate limiting: read straight from the inner stream.
            return Pin::new(&mut this.inner).poll_read(cx, buf);
        };

        // A pending refill timer means we are currently limited: wait for it to finish
        // before touching the inner stream, then clear it.
        if let Some(bucket_refilled) = &mut this.bucket_refilled {
            ready!(bucket_refilled.poll(cx));
            this.bucket_refilled = None;
        }

        // Not limited (any more): read, and measure how many bytes were added to `buf`.
        let bytes_before = buf.remaining();
        ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
        let bytes_read = bytes_before - buf.remaining();

        // Charge the bytes just read. The data is still handed to the caller; when the
        // bucket is exhausted, only the *next* read is delayed until `refill_time`.
        if let Err(refill_time) = bucket.consume(bytes_read) {
            this.record_rate_limited(bytes_read);
            this.bucket_refilled = Some(Box::pin(time::sleep_until(refill_time)));
        }

        Poll::Ready(Ok(()))
    }
}
523
524impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimited<S> {
525 fn poll_write(
526 mut self: Pin<&mut Self>,
527 cx: &mut std::task::Context<'_>,
528 buf: &[u8],
529 ) -> Poll<Result<usize, std::io::Error>> {
530 Pin::new(&mut self.inner).poll_write(cx, buf)
531 }
532
533 fn poll_flush(
534 mut self: Pin<&mut Self>,
535 cx: &mut std::task::Context<'_>,
536 ) -> Poll<Result<(), std::io::Error>> {
537 Pin::new(&mut self.inner).poll_flush(cx)
538 }
539
540 fn poll_shutdown(
541 mut self: Pin<&mut Self>,
542 cx: &mut std::task::Context<'_>,
543 ) -> Poll<Result<(), std::io::Error>> {
544 Pin::new(&mut self.inner).poll_shutdown(cx)
545 }
546}
547
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use n0_error::{Result, StdResultExt};
    use n0_future::time;
    use n0_tracing_test::traced_test;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::Bucket;
    use crate::server::{Metrics, streams::RateLimited};

    // Pushes 10 MiB through a rate-limited duplex stream and checks that the observed
    // throughput, rounded to whole bytes per second, equals the configured rate.
    // The runtime starts with its clock paused (`start_paused = true`).
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_ratelimiter() -> Result {
        let (read, mut write) = tokio::io::duplex(4096);

        // 10 MiB of payload.
        let send_total = 10 * 1024 * 1024;
        let send_data = vec![42u8; send_total];

        let bytes_per_second = 12_345;

        // Burst size is a tenth of the per-second rate.
        let mut rate_limited = RateLimited::new(
            read,
            bytes_per_second / 10,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;

        let before = time::Instant::now();
        // Run reader and writer concurrently; the duplex buffer is far smaller than the
        // payload, so the writer can only make progress as the reader drains it.
        n0_future::future::try_zip(
            async {
                let mut remaining = send_total;
                let mut buf = [0u8; 4096];
                while remaining > 0 {
                    remaining -= rate_limited.read(&mut buf).await?;
                }
                Ok(())
            },
            async {
                write.write_all(&send_data).await?;
                write.flush().await
            },
        )
        .await
        .anyerr()?;

        let duration = time::Instant::now().duration_since(before);
        // The transfer must have taken measurable time, otherwise nothing was limited.
        assert_ne!(duration.as_millis(), 0);

        let actual_bytes_per_second = send_total as f64 / duration.as_secs_f64();
        println!("{actual_bytes_per_second}");
        assert_eq!(actual_bytes_per_second.round() as u32, bytes_per_second);

        Ok(())
    }

    // A bucket configured with the largest possible capacity and rate must keep
    // accepting modest consumption across repeated refills.
    #[tokio::test(start_paused = true)]
    async fn test_bucket_high_refill() -> Result {
        let bytes_per_second = i64::MAX;
        let mut bucket = Bucket::new(i64::MAX, bytes_per_second, time::Duration::from_millis(100))?;
        for _ in 0..100 {
            time::sleep(time::Duration::from_millis(100)).await;
            assert!(bucket.consume(1_000_000).is_ok());
        }

        Ok(())
    }

    // Consuming the largest possible amount must always be rejected, and repeatedly
    // doing so and sleeping until the returned instant must not panic.
    #[tokio::test(start_paused = true)]
    async fn smoke_test_bucket_high_consume() -> Result {
        let bytes_per_second = 123_456;
        let mut bucket = Bucket::new(
            bytes_per_second / 10,
            bytes_per_second,
            time::Duration::from_millis(100),
        )?;
        for _ in 0..100 {
            let Err(until) = bucket.consume(usize::MAX) else {
                panic!("i64::MAX shouldn't be within limits");
            };
            time::sleep_until(until).await;
        }

        Ok(())
    }
}