iroh_relay/server/clients.rs

1//! The "Server" side of the client. Uses the `ClientConnManager`.
2// Based on tailscale/derp/derp_server.go
3
4use std::{
5    collections::HashSet,
6    sync::{
7        Arc,
8        atomic::{AtomicU64, Ordering},
9    },
10};
11
12use dashmap::DashMap;
13use iroh_base::EndpointId;
14use n0_future::IterExt;
15use tokio::sync::mpsc::error::TrySendError;
16use tracing::{debug, trace};
17
18use super::client::{Client, Config, ForwardPacketError};
19use crate::{
20    protos::{relay::Datagrams, streams::BytesStreamSink},
21    server::{client::SendError, metrics::Metrics},
22};
23
/// Registry of connected relay clients.
///
/// This type manages the collection of active client connections and
/// handles routing messages between them.
#[derive(Debug, Default, Clone)]
pub struct Clients(Arc<Inner>);
31
/// Shared state behind [`Clients`].
#[derive(Debug, Default)]
struct Inner {
    /// All currently connected clients, keyed by their endpoint id.
    clients: DashMap<EndpointId, ClientState>,
    /// For each sender, the set of endpoints it has sent packets to.
    ///
    /// Consulted on unregister to notify peers that the sender is gone.
    sent_to: DashMap<EndpointId, HashSet<EndpointId>>,
    /// Monotonic counter handing out a unique id per registered connection.
    next_connection_id: AtomicU64,
}
41
/// Connection state for a single endpoint id.
///
/// Multiple connections may exist for the same endpoint id: the most
/// recently registered one is `active`, older ones are kept in `inactive`
/// and can be promoted back to active when the active connection goes away.
#[derive(Debug)]
struct ClientState {
    /// The most recently registered connection; receives forwarded packets.
    active: Client,
    /// Older, superseded connections for the same endpoint id.
    inactive: Vec<Client>,
}
47
48impl ClientState {
49    async fn shutdown_all(mut self) {
50        [self.active]
51            .into_iter()
52            .chain(self.inactive.drain(..))
53            .map(Client::shutdown)
54            .join_all()
55            .await;
56    }
57}
58
impl Clients {
    /// Shuts down all connected clients.
    ///
    /// This method gracefully disconnects all active client connections managed by
    /// this registry. It will wait for all clients to complete their shutdown before
    /// returning.
    pub async fn shutdown(&self) {
        // Collect the keys first so that no iterator guard into the map is
        // held while entries are removed below.
        let keys: Vec<_> = self.0.clients.iter().map(|x| *x.key()).collect();
        trace!("shutting down {} clients", keys.len());
        // `remove` yields `None` for clients that unregistered themselves
        // between collecting the keys and this point; those are skipped.
        let clients = keys.into_iter().filter_map(|k| self.0.clients.remove(&k));
        // Shut down every client's active and inactive connections concurrently.
        n0_future::join_all(clients.map(|(_, state)| state.shutdown_all())).await;
    }

    /// Builds the client handler and starts the read & write loops for the connection.
    ///
    /// If another connection is already registered for the same endpoint id, that
    /// connection is demoted to inactive (and receives a best-effort health message
    /// saying it was superseded) while the new connection becomes the active one.
    pub fn register<S>(&self, client_config: Config<S>, metrics: Arc<Metrics>)
    where
        S: BytesStreamSink + Send + 'static,
    {
        let endpoint_id = client_config.endpoint_id;
        let connection_id = self.get_connection_id();
        trace!(remote_endpoint = %endpoint_id.fmt_short(), "registering client");

        let client = Client::new(client_config, connection_id, self, metrics);
        match self.0.clients.entry(endpoint_id) {
            dashmap::Entry::Occupied(mut entry) => {
                let state = entry.get_mut();
                // Swap the new connection in as active; keep the old one around.
                let old_client = std::mem::replace(&mut state.active, client);
                debug!(
                    remote_endpoint = %endpoint_id.fmt_short(),
                    "multiple connections found, deactivating old connection",
                );
                // Best-effort: the old connection may already be unable to
                // receive messages, in which case the error is ignored.
                old_client
                    .try_send_health("Another endpoint connected with the same endpoint id. No more messages will be received".to_string())
                    .ok();
                state.inactive.push(old_client);
            }
            dashmap::Entry::Vacant(entry) => {
                entry.insert(ClientState {
                    active: client,
                    inactive: Vec::new(),
                });
            }
        }
    }

    /// Returns a fresh, unique connection id.
    fn get_connection_id(&self) -> u64 {
        // Relaxed ordering suffices: the counter only needs to produce unique
        // values; it synchronizes no other memory operations.
        self.0.next_connection_id.fetch_add(1, Ordering::Relaxed)
    }

    /// Removes the client from the map of clients, & sends a notification
    /// to each client that peers has sent data to, to let them know that
    /// peer is gone from the network.
    ///
    /// Must be passed a matching connection_id.
    pub(super) fn unregister(&self, connection_id: u64, endpoint_id: EndpointId) {
        trace!(
            endpoint_id = %endpoint_id.fmt_short(),
            connection_id, "unregistering client"
        );

        self.0.clients.remove_if_mut(&endpoint_id, |_id, state| {
            if state.active.connection_id() == connection_id {
                // The unregistering client is the currently active client
                if let Some(last_inactive_client) = state.inactive.pop() {
                    // There is an inactive client, promote to active again.
                    state.active = last_inactive_client;
                    // Don't remove the entry from client map.
                    false
                } else {
                    // No inactive clients: Inform other peers that this peer is now gone.
                    // NOTE(review): the PeerGone frames below are queued on
                    // `state.active` — the very connection being unregistered —
                    // not on the clients listed in `sent_to` that the doc
                    // comment says should be notified. Verify the intended
                    // recipient; looking up other entries inside this closure
                    // would risk deadlocking the map, so any fix must move the
                    // notification outside `remove_if_mut`.
                    if let Some((_, sent_to)) = self.0.sent_to.remove(&endpoint_id) {
                        for key in sent_to {
                            match state.active.try_send_peer_gone(key) {
                                Ok(_) => {}
                                Err(TrySendError::Full(_)) => {
                                    debug!(
                                        dst = %key.fmt_short(),
                                        "client too busy to receive packet, dropping packet"
                                    );
                                }
                                Err(TrySendError::Closed(_)) => {
                                    debug!(
                                        dst = %key.fmt_short(),
                                        "can no longer write to client, dropping packet"
                                    );
                                }
                            }
                        }
                    }
                    // Remove entry from the client map.
                    true
                }
            } else {
                // The unregistering client is already inactive. Remove from the list of inactive clients.
                state
                    .inactive
                    .retain(|client| client.connection_id() != connection_id);
                // Active client is unmodified: keep entry in map.
                false
            }
        });
    }

    /// Attempt to send a packet to client with [`EndpointId`] `dst`.
    ///
    /// Returns `Ok(())` when the packet was queued, and also when no client is
    /// connected for `dst` (the packet is then dropped and counted in
    /// `send_packets_dropped`). Errors when the destination's channel is full
    /// or closed; a closed destination additionally has its shutdown started.
    pub(super) fn send_packet(
        &self,
        dst: EndpointId,
        data: Datagrams,
        src: EndpointId,
        metrics: &Metrics,
    ) -> Result<(), ForwardPacketError> {
        let Some(client) = self.0.clients.get(&dst) else {
            debug!(dst = %dst.fmt_short(), "no connected client, dropped packet");
            metrics.send_packets_dropped.inc();
            return Ok(());
        };
        match client.active.try_send_packet(src, data) {
            Ok(_) => {
                // Record sent_to relationship
                self.0.sent_to.entry(src).or_default().insert(dst);
                Ok(())
            }
            Err(TrySendError::Full(_)) => {
                debug!(
                    dst = %dst.fmt_short(),
                    "client too busy to receive packet, dropping packet"
                );
                Err(ForwardPacketError::new(SendError::Full))
            }
            Err(TrySendError::Closed(_)) => {
                debug!(
                    dst = %dst.fmt_short(),
                    "can no longer write to client, dropping message and pruning connection"
                );
                client.active.start_shutdown();
                Err(ForwardPacketError::new(SendError::Closed))
            }
        }
    }

    /// Test helper: returns the connection id of the currently active
    /// connection for `endpoint_id`, if one is registered.
    #[cfg(test)]
    fn active_connection_id(&self, endpoint_id: EndpointId) -> Option<u64> {
        self.0
            .clients
            .get(&endpoint_id)
            .map(|s| s.active.connection_id())
    }
}
207
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use iroh_base::SecretKey;
    use n0_error::{Result, StdResultExt};
    use n0_future::{Stream, StreamExt};
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;

    use super::*;
    use crate::{
        client::conn::Conn,
        protos::{common::FrameType, relay::RelayToClientMsg, streams::WsBytesFramed},
        server::streams::{MaybeTlsStream, RateLimited, ServerRelayedStream},
    };

    /// Receives the next frame from `stream` and checks that it has the
    /// expected `frame_type`; fails on a type mismatch, a stream error, or EOF.
    async fn recv_frame<
        E: std::error::Error + Sync + Send + 'static,
        S: Stream<Item = Result<RelayToClientMsg, E>> + Unpin,
    >(
        frame_type: FrameType,
        mut stream: S,
    ) -> Result<RelayToClientMsg> {
        match stream.next().await {
            Some(Ok(frame)) => {
                if frame_type != frame.typ() {
                    n0_error::bail_any!(
                        "Unexpected frame, got {:?}, but expected {:?}",
                        frame.typ(),
                        frame_type
                    );
                }
                Ok(frame)
            }
            Some(Err(err)) => Err(err).anyerr(),
            None => n0_error::bail_any!("Unexpected EOF, expected frame {frame_type:?}"),
        }
    }

    /// Builds a server-side [`Config`] for `key` plus the client-side [`Conn`]
    /// end, connected through an in-memory duplex stream.
    fn test_client_builder(
        key: EndpointId,
    ) -> (Config<WsBytesFramed<RateLimited<MaybeTlsStream>>>, Conn) {
        let (server, client) = tokio::io::duplex(1024);
        (
            Config {
                endpoint_id: key,
                stream: ServerRelayedStream::test(server),
                write_timeout: Duration::from_secs(1),
                channel_capacity: 10,
            },
            Conn::test(client),
        )
    }

    /// Registers a client, routes a packet to it, and checks that its entry is
    /// removed from the registry after the client shuts down.
    #[tokio::test]
    #[traced_test]
    async fn test_clients() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let a_key = SecretKey::generate(&mut rng).public();
        let b_key = SecretKey::generate(&mut rng).public();

        let (builder_a, mut a_rw) = test_client_builder(a_key);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        clients.register(builder_a, metrics.clone());

        // send packet
        let data = b"hello world!";
        clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: b_key,
                datagrams: data.to_vec().into(),
            }
        );

        {
            let client = clients.0.clients.get(&a_key).unwrap();
            // shutdown client a, this should trigger the removal from the clients list
            client.active.start_shutdown();
        }

        // need to wait a moment for the removal to be processed
        let c = clients.clone();
        tokio::time::timeout(Duration::from_secs(1), async move {
            loop {
                if !c.0.clients.contains_key(&a_key) {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        })
        .await
        .std_context("timeout")?;
        clients.shutdown().await;

        Ok(())
    }

    /// Exercises the active/inactive promotion logic when two connections
    /// register under the same endpoint id.
    #[tokio::test]
    #[traced_test]
    async fn test_clients_same_endpoint_id() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let a_key = SecretKey::generate(&mut rng).public();
        let b_key = SecretKey::generate(&mut rng).public();

        let (a1_builder, mut a1_rw) = test_client_builder(a_key);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());

        // register client a
        clients.register(a1_builder, metrics.clone());
        let a1_conn_id = clients.active_connection_id(a_key).unwrap();

        // send packet and verify it is send to a1
        let data = b"hello world!";
        clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a1_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: b_key,
                datagrams: data.to_vec().into(),
            }
        );

        // register new client with same endpoint id
        let (a2_builder, mut a2_rw) = test_client_builder(a_key);
        clients.register(a2_builder, metrics.clone());
        let a2_conn_id = clients.active_connection_id(a_key).unwrap();
        assert!(a2_conn_id != a1_conn_id);

        // a1 is marked inactive and should receive a health frame
        let _frame = recv_frame(FrameType::Health, &mut a1_rw).await?;

        // send packet and verify it is send to a2
        clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a2_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: b_key,
                datagrams: data.to_vec().into(),
            }
        );

        // disconnect a2
        clients
            .0
            .clients
            .get(&a_key)
            .unwrap()
            .active
            .start_shutdown();

        // need to wait a moment for the removal to be processed
        tokio::time::timeout(Duration::from_secs(1), {
            let clients = clients.clone();
            async move {
                // wait until the active connection is no longer a2 (which we unregistered)
                while clients.active_connection_id(a_key) == Some(a2_conn_id) {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        })
        .await
        .std_context("timeout")?;

        // a1 should be marked active again now, and receive sent messages
        assert_eq!(clients.active_connection_id(a_key), Some(a1_conn_id));
        clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a1_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: b_key,
                datagrams: data.to_vec().into(),
            }
        );

        // after shutting down the now-active client, there should no longer be an entry for that endpoint id
        clients
            .0
            .clients
            .get(&a_key)
            .unwrap()
            .active
            .start_shutdown();

        // need to wait a moment for the removal to be processed
        tokio::time::timeout(Duration::from_secs(1), {
            let clients = clients.clone();
            async move {
                // wait until the entry for a_key has been removed entirely
                while clients.0.clients.contains_key(&a_key) {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        })
        .await
        .std_context("timeout")?;

        clients.shutdown().await;

        Ok(())
    }
}