iroh_relay/server/
client.rs

1//! The server-side representation of an ongoing client relaying connection.
2
3use std::{collections::HashSet, sync::Arc, time::Duration};
4
5use iroh_base::EndpointId;
6use n0_error::{e, stack_error};
7use n0_future::{SinkExt, StreamExt};
8use rand::Rng;
9use time::{Date, OffsetDateTime};
10use tokio::{
11    sync::mpsc::{self, error::TrySendError},
12    time::MissedTickBehavior,
13};
14use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
15use tracing::{Instrument, debug, trace, warn};
16
17use crate::{
18    PingTracker,
19    protos::{
20        relay::{ClientToRelayMsg, Datagrams, PING_INTERVAL, RelayToClientMsg},
21        streams::BytesStreamSink,
22    },
23    server::{
24        clients::Clients,
25        metrics::Metrics,
26        streams::{RecvError as RelayRecvError, RelayedStream, SendError as RelaySendError},
27    },
28};
29
/// A request to write datagrams to a client.
///
/// Built by [`Client::try_send_packet`] and written to the client by the connection actor
/// as a [`RelayToClientMsg::Datagrams`] frame.
#[derive(Debug, Clone)]
pub(super) struct Packet {
    /// The endpoint the datagrams were sent from.
    src: EndpointId,
    /// The datagrams to deliver.
    data: Datagrams,
}
38
/// Configuration for a [`Client`].
///
/// Generic over the stream type to support different WebSocket implementations.
#[derive(Debug)]
pub struct Config<S> {
    /// The endpoint ID of the client.
    pub endpoint_id: EndpointId,
    /// The relayed stream connection to the client.
    pub stream: RelayedStream<S>,
    /// Maximum time a single frame write to the client may take before it counts as failed.
    pub write_timeout: Duration,
    /// Capacity of each internal queue; the packet queue and the message queue are both
    /// created with this capacity.
    pub channel_capacity: usize,
}
53
/// The [`Server`] side representation of a [`Client`]'s connection.
///
/// This is a handle to the actor task that performs the actual IO: it can queue data for
/// the client and request the actor to shut down.
///
/// [`Server`]: crate::server::Server
/// [`Client`]: crate::client::Client
#[derive(Debug)]
pub struct Client {
    /// Identity of the connected peer.
    endpoint_id: EndpointId,
    /// Connection identifier.
    connection_id: u64,
    /// Used to close the connection loop; cancelled when shutdown is started.
    done: CancellationToken,
    /// Actor handle. Dropping it aborts the actor task.
    handle: AbortOnDropHandle<()>,
    /// Channel to send packets intended for the client.
    packet_queue: mpsc::Sender<Packet>,
    /// Channel to send non-packet messages to the client.
    message_queue: mpsc::Sender<RelayToClientMsg>,
}
73
74impl Client {
75    /// Creates a client from a connection & starts a read and write loop to handle io to and from
76    /// the client
77    /// Call [`Client::shutdown`] to close the read and write loops before dropping the [`Client`]
78    pub(super) fn new<S>(
79        config: Config<S>,
80        connection_id: u64,
81        clients: &Clients,
82        metrics: Arc<Metrics>,
83    ) -> Client
84    where
85        S: BytesStreamSink + Send + 'static,
86    {
87        let Config {
88            endpoint_id,
89            stream,
90            write_timeout,
91            channel_capacity,
92        } = config;
93
94        let (packet_send_queue_s, packet_send_queue_r) = mpsc::channel(channel_capacity);
95        let (message_send_queue_s, message_send_queue_r) = mpsc::channel(channel_capacity);
96        let done = CancellationToken::new();
97
98        let actor = Actor {
99            stream,
100            timeout: write_timeout,
101            packet_send_queue: packet_send_queue_r,
102            message_send_queue: message_send_queue_r,
103            endpoint_id,
104            connection_id,
105            clients: clients.clone(),
106            client_counter: ClientCounter::default(),
107            ping_tracker: PingTracker::default(),
108            metrics,
109        };
110
111        // start io loop
112        let io_done = done.clone();
113        let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!(
114            "client-connection-actor",
115            remote_endpoint = %endpoint_id.fmt_short(),
116            connection_id = connection_id
117        )));
118
119        Client {
120            endpoint_id,
121            connection_id,
122            handle: AbortOnDropHandle::new(handle),
123            done,
124            packet_queue: packet_send_queue_s,
125            message_queue: message_send_queue_s,
126        }
127    }
128
    /// Returns the identifier of this connection.
    pub(super) fn connection_id(&self) -> u64 {
        self.connection_id
    }
132
133    /// Shutdown the reader and writer loops and closes the connection.
134    ///
135    /// Any shutdown errors will be logged as warnings.
136    pub(super) async fn shutdown(self) {
137        self.start_shutdown();
138        if let Err(e) = self.handle.await {
139            warn!(
140                remote_endpoint = %self.endpoint_id.fmt_short(),
141                "error closing actor loop: {e:#?}",
142            );
143        };
144    }
145
    /// Starts the process of shutdown.
    ///
    /// Cancels the token the actor loop listens on and returns immediately, without
    /// waiting for the actor to exit.
    pub(super) fn start_shutdown(&self) {
        self.done.cancel();
    }
150
151    pub(super) fn try_send_packet(
152        &self,
153        src: EndpointId,
154        data: Datagrams,
155    ) -> Result<(), TrySendError<Packet>> {
156        self.packet_queue.try_send(Packet { src, data })
157    }
158
159    pub(super) fn try_send_peer_gone(
160        &self,
161        key: EndpointId,
162    ) -> Result<(), TrySendError<RelayToClientMsg>> {
163        self.message_queue
164            .try_send(RelayToClientMsg::EndpointGone(key))
165    }
166
167    pub(super) fn try_send_health(
168        &self,
169        problem: String,
170    ) -> Result<(), TrySendError<RelayToClientMsg>> {
171        self.message_queue
172            .try_send(RelayToClientMsg::Health { problem })
173    }
174}
175
/// Error when handling an incoming frame from a client.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum HandleFrameError {
    // Forwarding a packet to its destination failed.
    #[error(transparent)]
    ForwardPacket { source: ForwardPacketError },
    // The client stream ended: reading the next frame yielded `None`.
    #[error("Stream terminated")]
    StreamTerminated {},
    // Reading a frame from the client stream failed.
    #[error(transparent)]
    Recv { source: RelayRecvError },
    // Writing a reply frame (e.g. a pong) to the client failed.
    #[error(transparent)]
    Send { source: WriteFrameError },
}
190
/// Error when writing a frame to a client.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum WriteFrameError {
    // The underlying stream reported a send error.
    #[error(transparent)]
    Stream { source: RelaySendError },
    // The write did not complete within the configured write timeout.
    #[error(transparent)]
    Timeout {
        #[error(std_err)]
        source: tokio::time::error::Elapsed,
    },
}
204
/// Error that terminates the main loop of a client connection actor.
#[stack_error(derive, add_meta)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum RunError {
    // Forwarding a packet to another client failed.
    #[error(transparent)]
    ForwardPacket {
        #[error(from)]
        source: ForwardPacketError,
    },
    // The final flush performed after cancellation failed.
    #[error("Flush")]
    Flush {},
    // Handling a frame received from the client failed.
    #[error(transparent)]
    HandleFrame {
        #[error(from)]
        source: HandleFrameError,
    },
    // Writing a queued packet to the client failed.
    #[error("Failed to send packet")]
    PacketSend { source: WriteFrameError },
    // One of the actor's queues was closed, i.e. receiving from it yielded `None`.
    #[error("Handle was dropped")]
    HandleDropped {},
    // Writing a queued message or a ping to the client failed.
    #[error("Writing a frame failed")]
    WriteFrame { source: WriteFrameError },
    // The flush performed at the end of a loop iteration failed.
    #[error("Tick flush")]
    TickFlush {},
}
231
/// Manages all the reads and writes to this client. It periodically sends a ping
/// to the client to keep the connection alive.
///
/// Call `run` to manage the input and output to and from the connection and the server.
/// Once it hits its first read or write error, or one of its queues is closed,
/// the loop ends with an error.
/// If writes do not complete in the given `timeout`, it will also error.
///
/// On the "write" side, the [`Actor`] can send the client:
///  - a ping frame, sent on a jittered interval
///  - a pong frame, in reply to a ping received from the client
///  - non-packet messages queued by the server, such as an endpoint-gone notification
///    or a health message
///  - packets from other peers
///
/// On the "read" side, it can receive:
///  - datagrams, which are handed to [`Clients`] for delivery to the destination
///    endpoint ID
///  - a ping, which is answered with a pong
///  - a pong, which is passed on to the ping tracker
#[derive(Debug)]
struct Actor<S> {
    /// IO Stream to talk to the client
    stream: RelayedStream<S>,
    /// Maximum time we wait to complete a write to the client
    timeout: Duration,
    /// Receiver for packets to be sent to the client.
    packet_send_queue: mpsc::Receiver<Packet>,
    /// Receiver for non-packet messages to be sent to the client.
    message_send_queue: mpsc::Receiver<RelayToClientMsg>,
    /// [`EndpointId`] of this client
    endpoint_id: EndpointId,
    /// Connection identifier.
    connection_id: u64,
    /// Reference to the other connected clients.
    clients: Clients,
    /// Statistics about the connected clients
    client_counter: ClientCounter,
    /// Tracks the pings sent to the client and the pongs received for them.
    ping_tracker: PingTracker,
    /// Metrics updated by this actor.
    metrics: Arc<Metrics>,
}
270
271impl<S> Actor<S>
272where
273    S: BytesStreamSink,
274{
275    async fn run(mut self, done: CancellationToken) {
276        // Note the accept and disconnects metrics must be in a pair.  Technically the
277        // connection is accepted long before this in the HTTP server, but it is clearer to
278        // handle the metric here.
279        self.metrics.accepts.inc();
280        if self.client_counter.update(self.endpoint_id) {
281            self.metrics.unique_client_keys.inc();
282        }
283        match self.run_inner(done).await {
284            Err(e) => {
285                warn!("actor errored {e:#}, exiting");
286            }
287            Ok(()) => {
288                debug!("actor finished, exiting");
289            }
290        }
291
292        self.clients
293            .unregister(self.connection_id, self.endpoint_id);
294        self.metrics.disconnects.inc();
295    }
296
    /// The main loop of the actor.
    ///
    /// Each iteration waits for exactly one event and then flushes the stream. Because the
    /// `select!` is `biased`, the branches are checked in the order they are written:
    /// cancellation, incoming frames, queued packets, queued messages, pong timeout and
    /// finally the keep-alive ping tick.
    ///
    /// Returns `Ok(())` when the loop ends through cancellation or a pong timeout, and a
    /// [`RunError`] on the first failed read, write or flush, or when a queue is closed.
    async fn run_inner(&mut self, done: CancellationToken) -> Result<(), RunError> {
        // Add some jitter to ping pong interactions, to avoid all pings being sent at the same time
        let next_interval = || {
            let random_secs = rand::rng().random_range(1..=5);
            Duration::from_secs(random_secs) + PING_INTERVAL
        };

        let mut ping_interval = tokio::time::interval(next_interval());
        ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
        // The first tick completes immediately; consume it so that the first ping is only
        // sent after a full interval.
        ping_interval.tick().await;

        loop {
            tokio::select! {
                biased;

                // Highest priority: shutdown was requested.
                _ = done.cancelled() => {
                    trace!("actor loop cancelled, exiting");
                    // final flush
                    self.stream.flush().await.map_err(|_| e!(RunError::Flush))?;
                    break;
                }
                // Next: frames received from the client.
                maybe_frame = self.stream.next() => {
                    self
                        .handle_frame(maybe_frame)
                        .await?;
                    // reset the ping interval, we just received a message
                    ping_interval.reset();
                }
                // Then: sending regular packets
                packet = self.packet_send_queue.recv() => {
                    let packet = packet.ok_or_else(|| e!(RunError::HandleDropped))?;
                    self.send_packet(packet)
                        .await
                        .map_err(|err| e!(RunError::PacketSend, err))?;
                }
                // Then: sending other messages
                message = self.message_send_queue.recv() => {
                    let message = message .ok_or_else(|| e!(RunError::HandleDropped))?;
                    trace!("send {message:?}");
                    self.write_frame(message)
                        .await
                        .map_err(|err| e!(RunError::WriteFrame, err))?;
                }
                // The ping tracker signalled a timeout: end the loop without an error.
                _ = self.ping_tracker.timeout() => {
                    trace!("pong timed out");
                    break;
                }
                // Lowest priority: time to send the next keep-alive ping.
                _ = ping_interval.tick() => {
                    trace!("keep alive ping");
                    // new interval
                    ping_interval.reset_after(next_interval());
                    let data = self.ping_tracker.new_ping();
                    self.write_frame(RelayToClientMsg::Ping(data))
                        .await
                        .map_err(|err| e!(RunError::WriteFrame, err))?;
                }
            }

            // Flush after every handled event; not reached on the `break` paths above.
            self.stream
                .flush()
                .await
                .map_err(|_| e!(RunError::TickFlush))?;
        }
        Ok(())
    }
363
364    /// Writes the given frame to the connection.
365    ///
366    /// Errors if the send does not happen within the `timeout` duration
367    async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), WriteFrameError> {
368        tokio::time::timeout(self.timeout, self.stream.send(frame)).await??;
369        Ok(())
370    }
371
372    /// Writes contents to the client in a `RECV_PACKET` frame.
373    ///
374    /// Errors if the send does not happen within the `timeout` duration
375    /// Does not flush.
376    async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
377        let remote_endpoint_id = packet.src;
378        let datagrams = packet.data;
379
380        if let Ok(len) = datagrams.contents.len().try_into() {
381            self.metrics.bytes_sent.inc_by(len);
382        }
383        self.write_frame(RelayToClientMsg::Datagrams {
384            remote_endpoint_id,
385            datagrams,
386        })
387        .await
388    }
389
390    async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
391        trace!("send packet");
392        match self.send_raw(packet).await {
393            Ok(()) => {
394                self.metrics.send_packets_sent.inc();
395                Ok(())
396            }
397            Err(err) => {
398                self.metrics.send_packets_dropped.inc();
399                Err(err)
400            }
401        }
402    }
403
404    /// Handles frame read results.
405    async fn handle_frame(
406        &mut self,
407        maybe_frame: Option<Result<ClientToRelayMsg, RelayRecvError>>,
408    ) -> Result<(), HandleFrameError> {
409        trace!(?maybe_frame, "handle incoming frame");
410        let frame = match maybe_frame {
411            Some(frame) => frame?,
412            None => return Err(e!(HandleFrameError::StreamTerminated)),
413        };
414
415        match frame {
416            ClientToRelayMsg::Datagrams {
417                dst_endpoint_id: dst_key,
418                datagrams,
419            } => {
420                let packet_len = datagrams.contents.len();
421                if let Err(err @ ForwardPacketError { .. }) =
422                    self.handle_frame_send_packet(dst_key, datagrams)
423                {
424                    warn!("failed to handle send packet frame: {err:#}");
425                }
426                self.metrics.bytes_recv.inc_by(packet_len as u64);
427            }
428            ClientToRelayMsg::Ping(data) => {
429                self.metrics.got_ping.inc();
430                // TODO: add rate limiter
431                self.write_frame(RelayToClientMsg::Pong(data)).await?;
432                self.metrics.sent_pong.inc();
433            }
434            ClientToRelayMsg::Pong(data) => {
435                self.ping_tracker.pong_received(data);
436            }
437        }
438        Ok(())
439    }
440
441    fn handle_frame_send_packet(
442        &self,
443        dst: EndpointId,
444        data: Datagrams,
445    ) -> Result<(), ForwardPacketError> {
446        self.metrics.send_packets_recv.inc();
447        self.clients
448            .send_packet(dst, data, self.endpoint_id, &self.metrics)?;
449
450        Ok(())
451    }
452}
453
/// The reason a packet could not be forwarded, carried by [`ForwardPacketError`].
#[derive(Debug)]
pub(crate) enum SendError {
    // NOTE(review): presumably the destination client's send queue was full — confirm
    // against the code constructing this value.
    Full,
    // NOTE(review): presumably the destination client's send queue was closed — confirm
    // against the code constructing this value.
    Closed,
}
459
/// Error returned when forwarding a packet to a client fails.
///
/// This error occurs when the relay server cannot deliver a packet to its intended
/// recipient, typically due to the client's send queue being full or the client
/// disconnecting.
#[stack_error(derive, add_meta)]
#[error("failed to forward packet: {reason:?}")]
pub struct ForwardPacketError {
    // Why the packet could not be forwarded; shown in the error message.
    reason: SendError,
}
470
/// Tracks which unique endpoints have been seen during the current UTC date.
///
/// The set of seen endpoints is cleared whenever the UTC date differs from the date of
/// the last clear.
#[derive(Debug)]
struct ClientCounter {
    /// Endpoints seen since the last clear.
    clients: HashSet<EndpointId>,
    /// UTC date at which the set was created or last cleared.
    last_clear_date: Date,
}
477
478impl Default for ClientCounter {
479    fn default() -> Self {
480        Self {
481            clients: HashSet::new(),
482            last_clear_date: OffsetDateTime::now_utc().date(),
483        }
484    }
485}
486
487impl ClientCounter {
488    fn check_and_clear(&mut self) {
489        let today = OffsetDateTime::now_utc().date();
490        if today != self.last_clear_date {
491            self.clients.clear();
492            self.last_clear_date = today;
493        }
494    }
495
496    /// Marks this endpoint as seen, returns whether it is new today or not.
497    fn update(&mut self, client: EndpointId) -> bool {
498        self.check_and_clear();
499        self.clients.insert(client)
500    }
501}
502
503#[cfg(test)]
504mod tests {
505    use iroh_base::SecretKey;
506    use n0_error::{Result, StdResultExt, bail_any};
507    use n0_future::Stream;
508    use n0_tracing_test::traced_test;
509    use rand::SeedableRng;
510    use tracing::info;
511
512    use super::*;
513    use crate::{client::conn::Conn, protos::common::FrameType};
514
515    async fn recv_frame<
516        E: std::error::Error + Sync + Send + 'static,
517        S: Stream<Item = Result<RelayToClientMsg, E>> + Unpin,
518    >(
519        frame_type: FrameType,
520        mut stream: S,
521    ) -> Result<RelayToClientMsg> {
522        match stream.next().await {
523            Some(Ok(frame)) => {
524                if frame_type != frame.typ() {
525                    bail_any!(
526                        "Unexpected frame, got {:?}, but expected {:?}",
527                        frame.typ(),
528                        frame_type
529                    );
530                }
531                Ok(frame)
532            }
533            Some(Err(err)) => Err(err).anyerr(),
534            None => bail_any!("Unexpected EOF, expected frame {frame_type:?}"),
535        }
536    }
537
    /// Runs an [`Actor`] over an in-memory duplex stream and exercises both directions:
    /// frames queued on the server side must arrive at the client, and a ping sent by the
    /// client must be answered with a matching pong.
    #[tokio::test]
    #[traced_test]
    async fn test_client_actor_basic() -> Result {
        // Seeded RNG so that the generated keys are the same on every run.
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        // Sender halves stand in for the queues normally held by `Client`.
        let (send_queue_s, send_queue_r) = mpsc::channel(10);
        let (message_s, message_r) = mpsc::channel(10);

        let endpoint_id = SecretKey::generate(&mut rng).public();
        // `io` is the actor's end of the connection, `io_rw` plays the remote client.
        let (io, io_rw) = tokio::io::duplex(1024);
        let mut io_rw = Conn::test(io_rw);
        let stream = RelayedStream::test(io);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        let actor = Actor {
            stream,
            timeout: Duration::from_secs(1),
            packet_send_queue: send_queue_r,
            message_send_queue: message_r,
            connection_id: 0,
            endpoint_id,
            clients: clients.clone(),
            client_counter: ClientCounter::default(),
            ping_tracker: PingTracker::default(),
            metrics,
        };

        let done = CancellationToken::new();
        let io_done = done.clone();
        let handle = tokio::task::spawn(async move { actor.run(io_done).await });

        // Write tests
        println!("-- write");
        let data = b"hello world!";

        // send packet
        println!("  send packet");
        let packet = Packet {
            src: endpoint_id,
            data: Datagrams::from(&data[..]),
        };
        send_queue_s
            .send(packet.clone())
            .await
            .std_context("send")?;
        // The queued packet must reach the client as a datagram frame.
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut io_rw)
            .await
            .anyerr()?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: endpoint_id,
                datagrams: data.to_vec().into()
            }
        );

        // send peer_gone
        println!("send peer gone");
        message_s
            .send(RelayToClientMsg::EndpointGone(endpoint_id))
            .await
            .std_context("send")?;
        // Messages from the message queue are written to the client unchanged.
        let frame = recv_frame(FrameType::EndpointGone, &mut io_rw)
            .await
            .anyerr()?;
        assert_eq!(frame, RelayToClientMsg::EndpointGone(endpoint_id));

        // Read tests
        println!("--read");

        // send ping, expect pong
        let data = b"pingpong";
        io_rw.send(ClientToRelayMsg::Ping(*data)).await?;

        // recv pong
        println!(" recv pong");
        let frame = recv_frame(FrameType::Pong, &mut io_rw).await?;
        assert_eq!(frame, RelayToClientMsg::Pong(*data));

        // A second key that is not registered with `clients`.
        let target = SecretKey::generate(&mut rng).public();

        // send packet
        println!("  send packet");
        let data = b"hello world!";
        // Only the send itself is checked here; nothing is asserted about how the actor
        // handles this frame.
        io_rw
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: target,
                datagrams: Datagrams::from(data),
            })
            .await
            .std_context("send")?;

        // Stop the actor and wait for its task to finish.
        done.cancel();
        handle.await.std_context("join")?;
        Ok(())
    }
635
    /// Checks that a rate limited [`RelayedStream`] delays frames exceeding the limit:
    /// the first frame arrives right away, the second one only after waiting.
    ///
    /// The test starts with tokio's clock paused (`start_paused = true`).
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_rate_limit() -> Result {
        const LIMIT: u32 = 50;
        const MAX_FRAMES: u32 = 100;

        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        // Build the rate limited stream.
        // The duplex buffer is sized to hold MAX_FRAMES frames of LIMIT bytes each.
        let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _);
        let mut frame_writer = Conn::test(io_write);
        // Rate limiter allowing LIMIT bytes/s
        // NOTE(review): the meaning of the two numeric arguments is defined by
        // `RelayedStream::test_limited` — confirm there.
        let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?;

        // Prepare a frame to send, assert its size.
        let data = Datagrams::from(b"hello world!!!!!");
        let target = SecretKey::generate(&mut rng).public();
        let frame = ClientToRelayMsg::Datagrams {
            dst_endpoint_id: target,
            datagrams: data.clone(),
        };
        // The payload is chosen such that one encoded frame is exactly LIMIT bytes.
        let frame_len = frame.to_bytes().len();
        assert_eq!(frame_len, LIMIT as usize);

        // Send a frame, it should arrive.
        info!("-- send packet");
        frame_writer.send(frame.clone()).await.std_context("send")?;
        frame_writer.flush().await.std_context("flush")?;
        let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
            .await
            .expect("timeout")
            .expect("option")
            .expect("ok");
        assert_eq!(recv_frame, frame);

        // Next frame does not arrive.
        info!("-- send packet");
        frame_writer.send(frame.clone()).await.std_context("send")?;
        frame_writer.flush().await.std_context("flush")?;
        // Reading must not complete within 100ms.
        let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await;
        assert!(res.is_err(), "expecting a timeout");
        info!("-- timeout happened");

        // Wait long enough.
        info!("-- sleep");
        tokio::time::sleep(Duration::from_secs(1)).await;

        // Frame arrives.
        let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
            .await
            .expect("timeout")
            .expect("option")
            .expect("ok");
        assert_eq!(recv_frame, frame);

        Ok(())
    }
693}