1use std::{
5 collections::HashSet,
6 sync::{
7 Arc,
8 atomic::{AtomicU64, Ordering},
9 },
10};
11
12use dashmap::DashMap;
13use iroh_base::EndpointId;
14use n0_future::IterExt;
15use tokio::sync::mpsc::error::TrySendError;
16use tracing::{debug, trace};
17
18use super::client::{Client, Config, ForwardPacketError};
19use crate::{
20 protos::{relay::Datagrams, streams::BytesStreamSink},
21 server::{client::SendError, metrics::Metrics},
22};
23
24#[derive(Debug, Default, Clone)]
26pub struct Clients(Arc<Inner>);
31
32#[derive(Debug, Default)]
33struct Inner {
34 clients: DashMap<EndpointId, ClientState>,
36 sent_to: DashMap<EndpointId, HashSet<EndpointId>>,
38 next_connection_id: AtomicU64,
40}
41
42#[derive(Debug)]
43struct ClientState {
44 active: Client,
45 inactive: Vec<Client>,
46}
47
48impl ClientState {
49 async fn shutdown_all(mut self) {
50 [self.active]
51 .into_iter()
52 .chain(self.inactive.drain(..))
53 .map(Client::shutdown)
54 .join_all()
55 .await;
56 }
57}
58
59impl Clients {
60 pub async fn shutdown(&self) {
66 let keys: Vec<_> = self.0.clients.iter().map(|x| *x.key()).collect();
67 trace!("shutting down {} clients", keys.len());
68 let clients = keys.into_iter().filter_map(|k| self.0.clients.remove(&k));
69 n0_future::join_all(clients.map(|(_, state)| state.shutdown_all())).await;
70 }
71
72 pub fn register<S>(&self, client_config: Config<S>, metrics: Arc<Metrics>)
74 where
75 S: BytesStreamSink + Send + 'static,
76 {
77 let endpoint_id = client_config.endpoint_id;
78 let connection_id = self.get_connection_id();
79 trace!(remote_endpoint = %endpoint_id.fmt_short(), "registering client");
80
81 let client = Client::new(client_config, connection_id, self, metrics);
82 match self.0.clients.entry(endpoint_id) {
83 dashmap::Entry::Occupied(mut entry) => {
84 let state = entry.get_mut();
85 let old_client = std::mem::replace(&mut state.active, client);
86 debug!(
87 remote_endpoint = %endpoint_id.fmt_short(),
88 "multiple connections found, deactivating old connection",
89 );
90 old_client
91 .try_send_health("Another endpoint connected with the same endpoint id. No more messages will be received".to_string())
92 .ok();
93 state.inactive.push(old_client);
94 }
95 dashmap::Entry::Vacant(entry) => {
96 entry.insert(ClientState {
97 active: client,
98 inactive: Vec::new(),
99 });
100 }
101 }
102 }
103
104 fn get_connection_id(&self) -> u64 {
105 self.0.next_connection_id.fetch_add(1, Ordering::Relaxed)
106 }
107
108 pub(super) fn unregister(&self, connection_id: u64, endpoint_id: EndpointId) {
114 trace!(
115 endpoint_id = %endpoint_id.fmt_short(),
116 connection_id, "unregistering client"
117 );
118
119 self.0.clients.remove_if_mut(&endpoint_id, |_id, state| {
120 if state.active.connection_id() == connection_id {
121 if let Some(last_inactive_client) = state.inactive.pop() {
123 state.active = last_inactive_client;
125 false
127 } else {
128 if let Some((_, sent_to)) = self.0.sent_to.remove(&endpoint_id) {
130 for key in sent_to {
131 match state.active.try_send_peer_gone(key) {
132 Ok(_) => {}
133 Err(TrySendError::Full(_)) => {
134 debug!(
135 dst = %key.fmt_short(),
136 "client too busy to receive packet, dropping packet"
137 );
138 }
139 Err(TrySendError::Closed(_)) => {
140 debug!(
141 dst = %key.fmt_short(),
142 "can no longer write to client, dropping packet"
143 );
144 }
145 }
146 }
147 }
148 true
150 }
151 } else {
152 state
154 .inactive
155 .retain(|client| client.connection_id() != connection_id);
156 false
158 }
159 });
160 }
161
162 pub(super) fn send_packet(
164 &self,
165 dst: EndpointId,
166 data: Datagrams,
167 src: EndpointId,
168 metrics: &Metrics,
169 ) -> Result<(), ForwardPacketError> {
170 let Some(client) = self.0.clients.get(&dst) else {
171 debug!(dst = %dst.fmt_short(), "no connected client, dropped packet");
172 metrics.send_packets_dropped.inc();
173 return Ok(());
174 };
175 match client.active.try_send_packet(src, data) {
176 Ok(_) => {
177 self.0.sent_to.entry(src).or_default().insert(dst);
179 Ok(())
180 }
181 Err(TrySendError::Full(_)) => {
182 debug!(
183 dst = %dst.fmt_short(),
184 "client too busy to receive packet, dropping packet"
185 );
186 Err(ForwardPacketError::new(SendError::Full))
187 }
188 Err(TrySendError::Closed(_)) => {
189 debug!(
190 dst = %dst.fmt_short(),
191 "can no longer write to client, dropping message and pruning connection"
192 );
193 client.active.start_shutdown();
194 Err(ForwardPacketError::new(SendError::Closed))
195 }
196 }
197 }
198
199 #[cfg(test)]
200 fn active_connection_id(&self, endpoint_id: EndpointId) -> Option<u64> {
201 self.0
202 .clients
203 .get(&endpoint_id)
204 .map(|s| s.active.connection_id())
205 }
206}
207
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use iroh_base::SecretKey;
    use n0_error::{Result, StdResultExt};
    use n0_future::{Stream, StreamExt};
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;

    use super::*;
    use crate::{
        client::conn::Conn,
        protos::{common::FrameType, relay::RelayToClientMsg, streams::WsBytesFramed},
        server::streams::{MaybeTlsStream, RateLimited, ServerRelayedStream},
    };

    /// Reads the next frame from `stream` and returns it if it is of type `frame_type`.
    ///
    /// Fails if the stream has ended, if it yields an error, or if the frame is of a
    /// different type.
    async fn recv_frame<E, S>(frame_type: FrameType, mut stream: S) -> Result<RelayToClientMsg>
    where
        E: std::error::Error + Sync + Send + 'static,
        S: Stream<Item = Result<RelayToClientMsg, E>> + Unpin,
    {
        let frame = match stream.next().await {
            None => n0_error::bail_any!("Unexpected EOF, expected frame {frame_type:?}"),
            Some(received) => received.anyerr()?,
        };
        if frame_type != frame.typ() {
            n0_error::bail_any!(
                "Unexpected frame, got {:?}, but expected {:?}",
                frame.typ(),
                frame_type
            );
        }
        Ok(frame)
    }

    /// Creates the two ends of an in-memory connection for the endpoint `key`: the
    /// server-side [`Config`] to register, and the client-side [`Conn`] to read from.
    fn test_client_builder(
        key: EndpointId,
    ) -> (Config<WsBytesFramed<RateLimited<MaybeTlsStream>>>, Conn) {
        let (server_io, client_io) = tokio::io::duplex(1024);
        let config = Config {
            endpoint_id: key,
            stream: ServerRelayedStream::test(server_io),
            write_timeout: Duration::from_secs(1),
            channel_capacity: 10,
        };
        (config, Conn::test(client_io))
    }

    /// Re-checks `done` every 100ms and returns as soon as it yields `true`.
    ///
    /// Fails with a "timeout" error if that has not happened after one second.
    async fn wait_until(mut done: impl FnMut() -> bool) -> Result {
        let poll = async {
            while !done() {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        };
        tokio::time::timeout(Duration::from_secs(1), poll)
            .await
            .std_context("timeout")?;
        Ok(())
    }

    /// Registers one client, relays a packet to it, shuts its connection down and
    /// waits for its entry to disappear from the registry.
    #[tokio::test]
    #[traced_test]
    async fn test_clients() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let a_key = SecretKey::generate(&mut rng).public();
        let b_key = SecretKey::generate(&mut rng).public();

        let (a_config, mut a_conn) = test_client_builder(a_key);

        let registry = Clients::default();
        let metrics = Arc::new(Metrics::default());
        registry.register(a_config, Arc::clone(&metrics));

        // A packet addressed to `a` must show up on a's connection, attributed to `b`.
        let payload = b"hello world!";
        registry.send_packet(a_key, Datagrams::from(&payload[..]), b_key, &metrics)?;
        let received = recv_frame(FrameType::RelayToClientDatagram, &mut a_conn).await?;
        let expected = RelayToClientMsg::Datagrams {
            remote_endpoint_id: b_key,
            datagrams: payload.to_vec().into(),
        };
        assert_eq!(received, expected);

        // The map guard is a temporary here and is released at the end of the statement.
        registry
            .0
            .clients
            .get(&a_key)
            .unwrap()
            .active
            .start_shutdown();

        // The entry is expected to vanish once the connection has shut down.
        wait_until(|| !registry.0.clients.contains_key(&a_key)).await?;
        registry.shutdown().await;

        Ok(())
    }

    /// Registers two connections under the same endpoint id and checks that the newer
    /// one takes over, that the older one is reinstated once the newer one is shut
    /// down, and that the entry disappears after the older one is shut down as well.
    #[tokio::test]
    #[traced_test]
    async fn test_clients_same_endpoint_id() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);
        let a_key = SecretKey::generate(&mut rng).public();
        let b_key = SecretKey::generate(&mut rng).public();

        let (a1_config, mut a1_conn) = test_client_builder(a_key);

        let registry = Clients::default();
        let metrics = Arc::new(Metrics::default());

        let payload = b"hello world!";
        // Every relayed packet in this test is expected to look like this.
        let expected = || RelayToClientMsg::Datagrams {
            remote_endpoint_id: b_key,
            datagrams: payload.to_vec().into(),
        };

        // First connection: it is active and receives the packet.
        registry.register(a1_config, Arc::clone(&metrics));
        let a1_conn_id = registry.active_connection_id(a_key).unwrap();

        registry.send_packet(a_key, Datagrams::from(&payload[..]), b_key, &metrics)?;
        let received = recv_frame(FrameType::RelayToClientDatagram, &mut a1_conn).await?;
        assert_eq!(received, expected());

        // Second connection with the same endpoint id: it replaces the first as active.
        let (a2_config, mut a2_conn) = test_client_builder(a_key);
        registry.register(a2_config, Arc::clone(&metrics));
        let a2_conn_id = registry.active_connection_id(a_key).unwrap();
        assert!(a2_conn_id != a1_conn_id);

        // The replaced connection is informed through a health frame.
        let _health = recv_frame(FrameType::Health, &mut a1_conn).await?;

        // Packets are now delivered to the second connection.
        registry.send_packet(a_key, Datagrams::from(&payload[..]), b_key, &metrics)?;
        let received = recv_frame(FrameType::RelayToClientDatagram, &mut a2_conn).await?;
        assert_eq!(received, expected());

        // Shut the second connection down; the guard must be released before waiting.
        {
            let state = registry.0.clients.get(&a_key).unwrap();
            state.active.start_shutdown();
        }
        wait_until(|| registry.active_connection_id(a_key) != Some(a2_conn_id)).await?;

        // The first connection is active again and receives packets.
        assert_eq!(registry.active_connection_id(a_key), Some(a1_conn_id));
        registry.send_packet(a_key, Datagrams::from(&payload[..]), b_key, &metrics)?;
        let received = recv_frame(FrameType::RelayToClientDatagram, &mut a1_conn).await?;
        assert_eq!(received, expected());

        // Shut the remaining connection down; afterwards the entry must be gone.
        {
            let state = registry.0.clients.get(&a_key).unwrap();
            state.active.start_shutdown();
        }
        wait_until(|| !registry.0.clients.contains_key(&a_key)).await?;

        registry.shutdown().await;

        Ok(())
    }
}