1use std::{collections::HashSet, sync::Arc, time::Duration};
4
5use iroh_base::EndpointId;
6use n0_error::{e, stack_error};
7use n0_future::{SinkExt, StreamExt};
8use rand::Rng;
9use time::{Date, OffsetDateTime};
10use tokio::{
11 sync::mpsc::{self, error::TrySendError},
12 time::MissedTickBehavior,
13};
14use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
15use tracing::{Instrument, debug, trace, warn};
16
17use crate::{
18 PingTracker,
19 protos::{
20 relay::{ClientToRelayMsg, Datagrams, PING_INTERVAL, RelayToClientMsg},
21 streams::BytesStreamSink,
22 },
23 server::{
24 clients::Clients,
25 metrics::Metrics,
26 streams::{RecvError as RelayRecvError, RelayedStream, SendError as RelaySendError},
27 },
28};
29
/// A datagram sent by another client that is queued for delivery to this client.
#[derive(Debug, Clone)]
pub(super) struct Packet {
    /// Endpoint id of the sender; written out as `remote_endpoint_id` in the outgoing frame.
    src: EndpointId,
    /// The datagram payload to deliver.
    data: Datagrams,
}
38
/// Configuration used to create a [`Client`] for a newly accepted connection.
#[derive(Debug)]
pub struct Config<S> {
    /// Endpoint id of the remote client this connection belongs to.
    pub endpoint_id: EndpointId,
    /// Framed stream used to talk to the client.
    pub stream: RelayedStream<S>,
    /// Maximum time a single frame write to the client may take before it fails.
    pub write_timeout: Duration,
    /// Capacity of each outgoing queue (one for datagrams, one for control messages).
    pub channel_capacity: usize,
}
53
/// Handle to a connected client.
///
/// The actual I/O happens in a spawned actor task. This handle queues data for that task
/// and can ask it to shut down. Dropping the handle aborts the task.
#[derive(Debug)]
pub struct Client {
    /// Endpoint id of the connected client.
    endpoint_id: EndpointId,
    /// Id of this particular connection, as supplied by the caller of `Client::new`.
    connection_id: u64,
    /// Cancelled to ask the actor loop to flush and exit.
    done: CancellationToken,
    /// Task running the actor; aborted when this handle is dropped.
    handle: AbortOnDropHandle<()>,
    /// Queue of datagrams from other clients waiting to be written to this client.
    packet_queue: mpsc::Sender<Packet>,
    /// Queue of control messages waiting to be written to this client.
    message_queue: mpsc::Sender<RelayToClientMsg>,
}
73
74impl Client {
75 pub(super) fn new<S>(
79 config: Config<S>,
80 connection_id: u64,
81 clients: &Clients,
82 metrics: Arc<Metrics>,
83 ) -> Client
84 where
85 S: BytesStreamSink + Send + 'static,
86 {
87 let Config {
88 endpoint_id,
89 stream,
90 write_timeout,
91 channel_capacity,
92 } = config;
93
94 let (packet_send_queue_s, packet_send_queue_r) = mpsc::channel(channel_capacity);
95 let (message_send_queue_s, message_send_queue_r) = mpsc::channel(channel_capacity);
96 let done = CancellationToken::new();
97
98 let actor = Actor {
99 stream,
100 timeout: write_timeout,
101 packet_send_queue: packet_send_queue_r,
102 message_send_queue: message_send_queue_r,
103 endpoint_id,
104 connection_id,
105 clients: clients.clone(),
106 client_counter: ClientCounter::default(),
107 ping_tracker: PingTracker::default(),
108 metrics,
109 };
110
111 let io_done = done.clone();
113 let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!(
114 "client-connection-actor",
115 remote_endpoint = %endpoint_id.fmt_short(),
116 connection_id = connection_id
117 )));
118
119 Client {
120 endpoint_id,
121 connection_id,
122 handle: AbortOnDropHandle::new(handle),
123 done,
124 packet_queue: packet_send_queue_s,
125 message_queue: message_send_queue_s,
126 }
127 }
128
129 pub(super) fn connection_id(&self) -> u64 {
130 self.connection_id
131 }
132
133 pub(super) async fn shutdown(self) {
137 self.start_shutdown();
138 if let Err(e) = self.handle.await {
139 warn!(
140 remote_endpoint = %self.endpoint_id.fmt_short(),
141 "error closing actor loop: {e:#?}",
142 );
143 };
144 }
145
146 pub(super) fn start_shutdown(&self) {
148 self.done.cancel();
149 }
150
151 pub(super) fn try_send_packet(
152 &self,
153 src: EndpointId,
154 data: Datagrams,
155 ) -> Result<(), TrySendError<Packet>> {
156 self.packet_queue.try_send(Packet { src, data })
157 }
158
159 pub(super) fn try_send_peer_gone(
160 &self,
161 key: EndpointId,
162 ) -> Result<(), TrySendError<RelayToClientMsg>> {
163 self.message_queue
164 .try_send(RelayToClientMsg::EndpointGone(key))
165 }
166
167 pub(super) fn try_send_health(
168 &self,
169 problem: String,
170 ) -> Result<(), TrySendError<RelayToClientMsg>> {
171 self.message_queue
172 .try_send(RelayToClientMsg::Health { problem })
173 }
174}
175
#[stack_error(derive, add_meta, from_sources)]
/// Errors that can occur while handling a single frame read from the client.
#[allow(missing_docs)]
#[non_exhaustive]
pub enum HandleFrameError {
    // Forwarding a datagram to another client failed.
    #[error(transparent)]
    ForwardPacket { source: ForwardPacketError },
    // The client stream ended (yielded `None`).
    #[error("Stream terminated")]
    StreamTerminated {},
    // Reading a frame from the client failed.
    #[error(transparent)]
    Recv { source: RelayRecvError },
    // Writing a reply (e.g. a pong) to the client failed.
    #[error(transparent)]
    Send { source: WriteFrameError },
}
190
#[stack_error(derive, add_meta, from_sources)]
/// Errors from writing a frame to the client.
#[allow(missing_docs)]
#[non_exhaustive]
pub enum WriteFrameError {
    // The underlying stream rejected the frame.
    #[error(transparent)]
    Stream { source: RelaySendError },
    // The write did not complete within the configured write timeout.
    #[error(transparent)]
    Timeout {
        #[error(std_err)]
        source: tokio::time::error::Elapsed,
    },
}
204
#[stack_error(derive, add_meta)]
/// Errors that terminate the client connection actor's main loop.
#[allow(missing_docs)]
#[non_exhaustive]
pub enum RunError {
    #[error(transparent)]
    ForwardPacket {
        #[error(from)]
        source: ForwardPacketError,
    },
    // Flushing the stream on cancellation failed; the underlying error is not kept.
    #[error("Flush")]
    Flush {},
    #[error(transparent)]
    HandleFrame {
        #[error(from)]
        source: HandleFrameError,
    },
    // Writing a queued datagram to the client failed.
    #[error("Failed to send packet")]
    PacketSend { source: WriteFrameError },
    // A queue receiver returned `None`: its sender (held by the `Client` handle) is gone.
    #[error("Handle was dropped")]
    HandleDropped {},
    // Writing a control message or keep-alive ping to the client failed.
    #[error("Writing a frame failed")]
    WriteFrame { source: WriteFrameError },
    // Flushing the stream after a loop iteration failed; the underlying error is not kept.
    #[error("Tick flush")]
    TickFlush {},
}
231
/// Per-connection actor that owns the client stream.
///
/// It reads frames from the client and forwards or answers them. It also writes queued
/// datagrams and control messages to the client and sends periodic keep-alive pings.
#[derive(Debug)]
struct Actor<S> {
    /// Framed connection to the client.
    stream: RelayedStream<S>,
    /// Upper bound on how long a single frame write may take.
    timeout: Duration,
    /// Datagrams from other clients waiting to be written to this client.
    packet_send_queue: mpsc::Receiver<Packet>,
    /// Control messages (e.g. `EndpointGone`, `Health`) waiting to be written to this client.
    message_send_queue: mpsc::Receiver<RelayToClientMsg>,
    /// Endpoint id of the connected client.
    endpoint_id: EndpointId,
    /// Id of this connection, passed to `Clients::unregister` together with `endpoint_id`.
    connection_id: u64,
    /// Registry of connected clients, used to forward datagrams and to unregister on exit.
    clients: Clients,
    /// Feeds the `unique_client_keys` metric.
    client_counter: ClientCounter,
    /// Produces ping payloads, consumes pongs and exposes the pong timeout.
    ping_tracker: PingTracker,
    metrics: Arc<Metrics>,
}
270
271impl<S> Actor<S>
272where
273 S: BytesStreamSink,
274{
275 async fn run(mut self, done: CancellationToken) {
276 self.metrics.accepts.inc();
280 if self.client_counter.update(self.endpoint_id) {
281 self.metrics.unique_client_keys.inc();
282 }
283 match self.run_inner(done).await {
284 Err(e) => {
285 warn!("actor errored {e:#}, exiting");
286 }
287 Ok(()) => {
288 debug!("actor finished, exiting");
289 }
290 }
291
292 self.clients
293 .unregister(self.connection_id, self.endpoint_id);
294 self.metrics.disconnects.inc();
295 }
296
297 async fn run_inner(&mut self, done: CancellationToken) -> Result<(), RunError> {
298 let next_interval = || {
300 let random_secs = rand::rng().random_range(1..=5);
301 Duration::from_secs(random_secs) + PING_INTERVAL
302 };
303
304 let mut ping_interval = tokio::time::interval(next_interval());
305 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
307 ping_interval.tick().await;
308
309 loop {
310 tokio::select! {
311 biased;
312
313 _ = done.cancelled() => {
314 trace!("actor loop cancelled, exiting");
315 self.stream.flush().await.map_err(|_| e!(RunError::Flush))?;
317 break;
318 }
319 maybe_frame = self.stream.next() => {
320 self
321 .handle_frame(maybe_frame)
322 .await?;
323 ping_interval.reset();
325 }
326 packet = self.packet_send_queue.recv() => {
328 let packet = packet.ok_or_else(|| e!(RunError::HandleDropped))?;
329 self.send_packet(packet)
330 .await
331 .map_err(|err| e!(RunError::PacketSend, err))?;
332 }
333 message = self.message_send_queue.recv() => {
335 let message = message .ok_or_else(|| e!(RunError::HandleDropped))?;
336 trace!("send {message:?}");
337 self.write_frame(message)
338 .await
339 .map_err(|err| e!(RunError::WriteFrame, err))?;
340 }
341 _ = self.ping_tracker.timeout() => {
342 trace!("pong timed out");
343 break;
344 }
345 _ = ping_interval.tick() => {
346 trace!("keep alive ping");
347 ping_interval.reset_after(next_interval());
349 let data = self.ping_tracker.new_ping();
350 self.write_frame(RelayToClientMsg::Ping(data))
351 .await
352 .map_err(|err| e!(RunError::WriteFrame, err))?;
353 }
354 }
355
356 self.stream
357 .flush()
358 .await
359 .map_err(|_| e!(RunError::TickFlush))?;
360 }
361 Ok(())
362 }
363
364 async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), WriteFrameError> {
368 tokio::time::timeout(self.timeout, self.stream.send(frame)).await??;
369 Ok(())
370 }
371
372 async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
377 let remote_endpoint_id = packet.src;
378 let datagrams = packet.data;
379
380 if let Ok(len) = datagrams.contents.len().try_into() {
381 self.metrics.bytes_sent.inc_by(len);
382 }
383 self.write_frame(RelayToClientMsg::Datagrams {
384 remote_endpoint_id,
385 datagrams,
386 })
387 .await
388 }
389
390 async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
391 trace!("send packet");
392 match self.send_raw(packet).await {
393 Ok(()) => {
394 self.metrics.send_packets_sent.inc();
395 Ok(())
396 }
397 Err(err) => {
398 self.metrics.send_packets_dropped.inc();
399 Err(err)
400 }
401 }
402 }
403
404 async fn handle_frame(
406 &mut self,
407 maybe_frame: Option<Result<ClientToRelayMsg, RelayRecvError>>,
408 ) -> Result<(), HandleFrameError> {
409 trace!(?maybe_frame, "handle incoming frame");
410 let frame = match maybe_frame {
411 Some(frame) => frame?,
412 None => return Err(e!(HandleFrameError::StreamTerminated)),
413 };
414
415 match frame {
416 ClientToRelayMsg::Datagrams {
417 dst_endpoint_id: dst_key,
418 datagrams,
419 } => {
420 let packet_len = datagrams.contents.len();
421 if let Err(err @ ForwardPacketError { .. }) =
422 self.handle_frame_send_packet(dst_key, datagrams)
423 {
424 warn!("failed to handle send packet frame: {err:#}");
425 }
426 self.metrics.bytes_recv.inc_by(packet_len as u64);
427 }
428 ClientToRelayMsg::Ping(data) => {
429 self.metrics.got_ping.inc();
430 self.write_frame(RelayToClientMsg::Pong(data)).await?;
432 self.metrics.sent_pong.inc();
433 }
434 ClientToRelayMsg::Pong(data) => {
435 self.ping_tracker.pong_received(data);
436 }
437 }
438 Ok(())
439 }
440
441 fn handle_frame_send_packet(
442 &self,
443 dst: EndpointId,
444 data: Datagrams,
445 ) -> Result<(), ForwardPacketError> {
446 self.metrics.send_packets_recv.inc();
447 self.clients
448 .send_packet(dst, data, self.endpoint_id, &self.metrics)?;
449
450 Ok(())
451 }
452}
453
/// Reason a datagram could not be forwarded to another client.
///
/// NOTE(review): produced outside this view (presumably by `Clients::send_packet`). The
/// variants look like they mirror the destination queue being full or closed — confirm there.
#[derive(Debug)]
pub(crate) enum SendError {
    Full,
    Closed,
}
459
#[stack_error(derive, add_meta)]
/// Forwarding a datagram received from one client to its destination client failed.
///
/// The display message includes the [`SendError`] reason.
#[error("failed to forward packet: {reason:?}")]
pub struct ForwardPacketError {
    reason: SendError,
}
470
/// Set of client endpoint ids seen since the UTC date last changed.
#[derive(Debug)]
struct ClientCounter {
    // Endpoint ids recorded since `last_clear_date`.
    clients: HashSet<EndpointId>,
    // UTC date on which `clients` was created or last cleared.
    last_clear_date: Date,
}
477
478impl Default for ClientCounter {
479 fn default() -> Self {
480 Self {
481 clients: HashSet::new(),
482 last_clear_date: OffsetDateTime::now_utc().date(),
483 }
484 }
485}
486
487impl ClientCounter {
488 fn check_and_clear(&mut self) {
489 let today = OffsetDateTime::now_utc().date();
490 if today != self.last_clear_date {
491 self.clients.clear();
492 self.last_clear_date = today;
493 }
494 }
495
496 fn update(&mut self, client: EndpointId) -> bool {
498 self.check_and_clear();
499 self.clients.insert(client)
500 }
501}
502
#[cfg(test)]
mod tests {
    use iroh_base::SecretKey;
    use n0_error::{Result, StdResultExt, bail_any};
    use n0_future::Stream;
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;
    use tracing::info;

    use super::*;
    use crate::{client::conn::Conn, protos::common::FrameType};

    /// Reads the next frame from `stream` and checks that its type is `frame_type`.
    ///
    /// Fails if the stream yields an error, ends, or produces a frame of another type.
    async fn recv_frame<
        E: std::error::Error + Sync + Send + 'static,
        S: Stream<Item = Result<RelayToClientMsg, E>> + Unpin,
    >(
        frame_type: FrameType,
        mut stream: S,
    ) -> Result<RelayToClientMsg> {
        match stream.next().await {
            Some(Ok(frame)) => {
                if frame_type != frame.typ() {
                    bail_any!(
                        "Unexpected frame, got {:?}, but expected {:?}",
                        frame.typ(),
                        frame_type
                    );
                }
                Ok(frame)
            }
            Some(Err(err)) => Err(err).anyerr(),
            None => bail_any!("Unexpected EOF, expected frame {frame_type:?}"),
        }
    }

    /// Drives an `Actor` over an in-memory duplex pipe. It checks that queued datagrams and
    /// control messages reach the client and that client pings are answered with pongs. It
    /// then sends a datagram to an unregistered endpoint and shuts the actor down.
    #[tokio::test]
    #[traced_test]
    async fn test_client_actor_basic() -> Result {
        // Fixed seed so the generated endpoint ids are deterministic.
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let (send_queue_s, send_queue_r) = mpsc::channel(10);
        let (message_s, message_r) = mpsc::channel(10);

        let endpoint_id = SecretKey::generate(&mut rng).public();
        // `io` is the relay side given to the actor; `io_rw` plays the role of the client.
        let (io, io_rw) = tokio::io::duplex(1024);
        let mut io_rw = Conn::test(io_rw);
        let stream = RelayedStream::test(io);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        // Built directly (not via `Client::new`) so the test owns the queue senders.
        let actor = Actor {
            stream,
            timeout: Duration::from_secs(1),
            packet_send_queue: send_queue_r,
            message_send_queue: message_r,
            connection_id: 0,
            endpoint_id,
            clients: clients.clone(),
            client_counter: ClientCounter::default(),
            ping_tracker: PingTracker::default(),
            metrics,
        };

        let done = CancellationToken::new();
        let io_done = done.clone();
        let handle = tokio::task::spawn(async move { actor.run(io_done).await });

        // Relay -> client direction: items pushed onto the actor's queues.
        println!("-- write");
        let data = b"hello world!";

        println!(" send packet");
        let packet = Packet {
            src: endpoint_id,
            data: Datagrams::from(&data[..]),
        };
        send_queue_s
            .send(packet.clone())
            .await
            .std_context("send")?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut io_rw)
            .await
            .anyerr()?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: endpoint_id,
                datagrams: data.to_vec().into()
            }
        );

        println!("send peer gone");
        message_s
            .send(RelayToClientMsg::EndpointGone(endpoint_id))
            .await
            .std_context("send")?;
        let frame = recv_frame(FrameType::EndpointGone, &mut io_rw)
            .await
            .anyerr()?;
        assert_eq!(frame, RelayToClientMsg::EndpointGone(endpoint_id));

        // Client -> relay direction: frames written by the fake client.
        println!("--read");

        let data = b"pingpong";
        io_rw.send(ClientToRelayMsg::Ping(*data)).await?;

        println!(" recv pong");
        let frame = recv_frame(FrameType::Pong, &mut io_rw).await?;
        assert_eq!(frame, RelayToClientMsg::Pong(*data));

        // `target` is never registered in `clients`; per `handle_frame`, forwarding
        // failures are only logged, so this must not kill the actor.
        let target = SecretKey::generate(&mut rng).public();

        println!(" send packet");
        let data = b"hello world!";
        io_rw
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: target,
                datagrams: Datagrams::from(data),
            })
            .await
            .std_context("send")?;

        done.cancel();
        handle.await.std_context("join")?;
        Ok(())
    }

    /// Checks that a rate-limited `RelayedStream` delivers one frame of `LIMIT` bytes and
    /// holds back the next one until (paused, auto-advancing) time has moved forward.
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_rate_limit() -> Result {
        const LIMIT: u32 = 50;
        const MAX_FRAMES: u32 = 100;

        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _);
        let mut frame_writer = Conn::test(io_write);
        // NOTE(review): the two numeric arguments are presumably the limiter's rate and burst
        // settings — confirm their order against `RelayedStream::test_limited`.
        let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?;

        let data = Datagrams::from(b"hello world!!!!!");
        let target = SecretKey::generate(&mut rng).public();
        let frame = ClientToRelayMsg::Datagrams {
            dst_endpoint_id: target,
            datagrams: data.clone(),
        };
        // The test relies on one encoded frame being exactly `LIMIT` bytes.
        let frame_len = frame.to_bytes().len();
        assert_eq!(frame_len, LIMIT as usize);

        // First frame: expected to pass straight through.
        info!("-- send packet");
        frame_writer.send(frame.clone()).await.std_context("send")?;
        frame_writer.flush().await.std_context("flush")?;
        let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
            .await
            .expect("timeout")
            .expect("option")
            .expect("ok");
        assert_eq!(recv_frame, frame);

        // Second frame: expected to be throttled, so the short read times out.
        info!("-- send packet");
        frame_writer.send(frame.clone()).await.std_context("send")?;
        frame_writer.flush().await.std_context("flush")?;
        let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await;
        assert!(res.is_err(), "expecting a timeout");
        info!("-- timeout happened");

        // After advancing time, the throttled frame becomes readable.
        info!("-- sleep");
        tokio::time::sleep(Duration::from_secs(1)).await;

        let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
            .await
            .expect("timeout")
            .expect("option")
            .expect("ok");
        assert_eq!(recv_frame, frame);

        Ok(())
    }
}