1use std::{
12 collections::{HashMap, VecDeque},
13 io,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicUsize, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21
22use iroh::{
23 endpoint::{ConnectError, Connection},
24 Endpoint, EndpointId,
25};
26use n0_future::{
27 future::{self},
28 FuturesUnordered, MaybeFuture, Stream, StreamExt,
29};
30use snafu::Snafu;
31use tokio::sync::{
32 mpsc::{self, error::SendError as TokioSendError},
33 oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
38pub type OnConnected =
39 Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration options for the connection pool.
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long a connection may stay idle before its actor asks the pool to
    /// remove it.
    pub idle_timeout: Duration,
    /// Upper bound for establishing a connection. The time spent in
    /// `on_connected`, if set, counts towards this timeout.
    pub connect_timeout: Duration,
    /// Maximum number of connections the pool keeps at the same time.
    pub max_connections: usize,
    /// Optional callback that is awaited on every new connection before it is
    /// handed out. Skipped in the `Debug` output.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58 fn default() -> Self {
59 Self {
60 idle_timeout: Duration::from_secs(5),
61 connect_timeout: Duration::from_secs(1),
62 max_connections: 1024,
63 on_connected: None,
64 }
65 }
66}
67
68impl Options {
69 pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
71 where
72 F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
73 Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
74 {
75 self.on_connected = Some(Arc::new(move |ep, conn| {
76 let ep = ep.clone();
77 let conn = conn.clone();
78 Box::pin(f(ep, conn))
79 }));
80 self
81 }
82}
83
/// A reference to a pooled connection.
///
/// Dereferences to the underlying connection. The contained guard keeps the
/// connection counted as in use; dropping the reference releases that count.
#[derive(Debug)]
pub struct ConnectionRef {
    // The pooled connection itself.
    connection: iroh::endpoint::Connection,
    // Usage guard; only held for its `Drop` side effect.
    _permit: OneConnection,
}
90
91impl Deref for ConnectionRef {
92 type Target = iroh::endpoint::Connection;
93
94 fn deref(&self) -> &Self::Target {
95 &self.connection
96 }
97}
98
99impl ConnectionRef {
100 fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
101 Self {
102 connection,
103 _permit: counter,
104 }
105 }
106}
107
/// Error returned when a connection could not be obtained from the pool.
///
/// The error sources are wrapped in `Arc` so that the error is `Clone`; a
/// failed connect result is cloned for every request that asks for it.
#[derive(Debug, Clone, Snafu)]
#[snafu(module)]
pub enum PoolConnectError {
    /// The pool actor could not be reached or dropped the request.
    Shutdown,
    /// Connecting (including `on_connected`) did not finish within
    /// `connect_timeout`.
    Timeout,
    /// `max_connections` is reached and there is no idle connection to evict.
    TooManyConnections,
    /// The endpoint failed to connect.
    ConnectError { source: Arc<ConnectError> },
    /// The `on_connected` callback returned an error.
    OnConnectError { source: Arc<io::Error> },
}
126
127impl From<ConnectError> for PoolConnectError {
128 fn from(e: ConnectError) -> Self {
129 PoolConnectError::ConnectError {
130 source: Arc::new(e),
131 }
132 }
133}
134
135impl From<io::Error> for PoolConnectError {
136 fn from(e: io::Error) -> Self {
137 PoolConnectError::OnConnectError {
138 source: Arc::new(e),
139 }
140 }
141}
142
/// Error returned when a message could not be delivered to the pool actor.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum ConnectionPoolError {
    /// The channel to the pool actor is closed.
    Shutdown,
}
152
/// Messages handled by the pool's main actor.
enum ActorMessage {
    /// A caller asks for a reference to the connection named in the request.
    RequestRef(RequestRef),
    /// The connection for `id` has no outstanding references any more.
    ConnectionIdle { id: EndpointId },
    /// The entry for `id` should be removed from the pool.
    ConnectionShutdown { id: EndpointId },
}
158
/// A request for a connection reference, answered through a oneshot channel.
struct RequestRef {
    // Remote endpoint the caller wants to talk to.
    id: EndpointId,
    // Receives either the connection reference or the connect error.
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
163
/// State shared between the pool actor and all connection actors.
struct Context {
    // Pool configuration.
    options: Options,
    // Endpoint used to open outgoing connections.
    endpoint: Endpoint,
    // Handle to the pool actor; connection actors use it to report that they
    // are idle or should be removed.
    owner: ConnectionPool,
    // ALPN passed to every connect call.
    alpn: Vec<u8>,
}
170
171impl Context {
172 async fn run_connection_actor(
173 self: Arc<Self>,
174 node_id: EndpointId,
175 mut rx: mpsc::Receiver<RequestRef>,
176 ) {
177 let context = self;
178
179 let conn_fut = {
180 let context = context.clone();
181 async move {
182 let conn = context
183 .endpoint
184 .connect(node_id, &context.alpn)
185 .await
186 .map_err(PoolConnectError::from)?;
187 if let Some(on_connect) = &context.options.on_connected {
188 on_connect(&context.endpoint, &conn)
189 .await
190 .map_err(PoolConnectError::from)?;
191 }
192 Result::<Connection, PoolConnectError>::Ok(conn)
193 }
194 };
195
196 let state = conn_fut
198 .timeout(context.options.connect_timeout)
199 .await
200 .map_err(|_| PoolConnectError::Timeout)
201 .and_then(|r| r);
202 let conn_close = match &state {
203 Ok(conn) => {
204 let conn = conn.clone();
205 MaybeFuture::Some(async move { conn.closed().await })
206 }
207 Err(e) => {
208 debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
209 if context.owner.close(node_id).await.is_err() {
210 return;
211 }
212 MaybeFuture::None
213 }
214 };
215
216 let counter = ConnectionCounter::new();
217 let idle_timer = MaybeFuture::default();
218 let idle_stream = counter.clone().idle_stream();
219
220 tokio::pin!(idle_timer, idle_stream, conn_close);
221
222 loop {
223 tokio::select! {
224 biased;
225
226 handler = rx.recv() => {
228 match handler {
229 Some(RequestRef { id, tx }) => {
230 assert!(id == node_id, "Not for me!");
231 match &state {
232 Ok(state) => {
233 let res = ConnectionRef::new(state.clone(), counter.get_one());
234 info!(%node_id, "Handing out ConnectionRef {}", counter.current());
235
236 idle_timer.as_mut().set_none();
238 tx.send(Ok(res)).ok();
239 }
240 Err(cause) => {
241 tx.send(Err(cause.clone())).ok();
242 }
243 }
244 }
245 None => {
246 break;
248 }
249 }
250 }
251
252 _ = &mut conn_close => {
253 context.owner.close(node_id).await.ok();
255 }
256
257 _ = idle_stream.next() => {
258 if !counter.is_idle() {
259 continue;
260 };
261 trace!(%node_id, "Idle");
263 if context.owner.idle(node_id).await.is_err() {
264 break;
266 }
267 idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
269 }
270
271 _ = &mut idle_timer => {
273 trace!(%node_id, "Idle timer expired, requesting shutdown");
274 context.owner.close(node_id).await.ok();
275 }
277 }
278 }
279
280 if let Ok(connection) = state {
281 let reason = if counter.is_idle() { b"idle" } else { b"drop" };
282 connection.close(0u32.into(), reason);
283 }
284
285 trace!(%node_id, "Connection actor shutting down");
286 }
287}
288
/// The pool's main actor: owns the table of connections and the
/// per-connection actor tasks.
struct Actor {
    // Incoming messages from pool handles and connection actors.
    rx: mpsc::Receiver<ActorMessage>,
    // Request channel of the connection actor for each remote endpoint.
    connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
    // Shared state handed to every connection actor.
    context: Arc<Context>,
    // Idle connections, oldest at the front; the front is evicted first when
    // the pool is full.
    idle: VecDeque<EndpointId>,
    // Connection actors, polled from `Actor::run`.
    tasks: FuturesUnordered<future::Boxed<()>>,
}
299
300impl Actor {
301 pub fn new(
302 endpoint: Endpoint,
303 alpn: &[u8],
304 options: Options,
305 ) -> (Self, mpsc::Sender<ActorMessage>) {
306 let (tx, rx) = mpsc::channel(100);
307 (
308 Self {
309 rx,
310 connections: HashMap::new(),
311 idle: VecDeque::new(),
312 context: Arc::new(Context {
313 options,
314 alpn: alpn.to_vec(),
315 endpoint,
316 owner: ConnectionPool { tx: tx.clone() },
317 }),
318 tasks: FuturesUnordered::new(),
319 },
320 tx,
321 )
322 }
323
324 fn add_idle(&mut self, id: EndpointId) {
325 self.remove_idle(id);
326 self.idle.push_back(id);
327 }
328
329 fn remove_idle(&mut self, id: EndpointId) {
330 self.idle.retain(|&x| x != id);
331 }
332
333 fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
334 self.idle.pop_front()
335 }
336
337 fn remove_connection(&mut self, id: EndpointId) {
338 self.connections.remove(&id);
339 self.remove_idle(id);
340 }
341
342 async fn handle_msg(&mut self, msg: ActorMessage) {
343 match msg {
344 ActorMessage::RequestRef(mut msg) => {
345 let id = msg.id;
346 self.remove_idle(id);
347 if let Some(conn_tx) = self.connections.get(&id) {
349 if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
350 msg = e;
351 } else {
352 return;
353 }
354 self.remove_connection(id);
356 }
357
358 if self.connections.len() >= self.context.options.max_connections {
360 if let Some(idle) = self.pop_oldest_idle() {
361 trace!("removing oldest idle connection {}", idle);
363 self.connections.remove(&idle);
364 } else {
365 msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
366 return;
367 }
368 }
369 let (conn_tx, conn_rx) = mpsc::channel(100);
370 self.connections.insert(id, conn_tx.clone());
371
372 let context = self.context.clone();
373
374 self.tasks
375 .push(Box::pin(context.run_connection_actor(id, conn_rx)));
376
377 if conn_tx.send(msg).await.is_err() {
379 error!(%id, "Failed to send handler to new connection actor");
380 self.connections.remove(&id);
381 }
382 }
383 ActorMessage::ConnectionIdle { id } => {
384 self.add_idle(id);
385 trace!(%id, "connection idle");
386 }
387 ActorMessage::ConnectionShutdown { id } => {
388 self.remove_connection(id);
390 trace!(%id, "removed connection");
391 }
392 }
393 }
394
395 pub async fn run(mut self) {
396 loop {
397 tokio::select! {
398 biased;
399
400 msg = self.rx.recv() => {
401 if let Some(msg) = msg {
402 self.handle_msg(msg).await;
403 } else {
404 break;
405 }
406 }
407
408 _ = self.tasks.next(), if !self.tasks.is_empty() => {}
409 }
410 }
411 }
412}
413
/// Handle to a connection pool.
///
/// Clones share the same pool: every handle sends its messages to the same
/// actor.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    // Channel to the pool actor.
    tx: mpsc::Sender<ActorMessage>,
}
419
420impl ConnectionPool {
421 pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
422 let (actor, tx) = Actor::new(endpoint, alpn, options);
423
424 tokio::spawn(actor.run());
426
427 Self { tx }
428 }
429
430 pub async fn get_or_connect(
435 &self,
436 id: EndpointId,
437 ) -> std::result::Result<ConnectionRef, PoolConnectError> {
438 let (tx, rx) = oneshot::channel();
439 self.tx
440 .send(ActorMessage::RequestRef(RequestRef { id, tx }))
441 .await
442 .map_err(|_| PoolConnectError::Shutdown)?;
443 rx.await.map_err(|_| PoolConnectError::Shutdown)?
444 }
445
446 pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
451 self.tx
452 .send(ActorMessage::ConnectionShutdown { id })
453 .await
454 .map_err(|_| ConnectionPoolError::Shutdown)?;
455 Ok(())
456 }
457
458 pub(crate) async fn idle(
462 &self,
463 id: EndpointId,
464 ) -> std::result::Result<(), ConnectionPoolError> {
465 self.tx
466 .send(ActorMessage::ConnectionIdle { id })
467 .await
468 .map_err(|_| ConnectionPoolError::Shutdown)?;
469 Ok(())
470 }
471}
472
/// Shared state behind a [`ConnectionCounter`] and its guards.
#[derive(Debug)]
struct ConnectionCounterInner {
    // Number of guards currently alive.
    count: AtomicUsize,
    // Signalled when the last guard is dropped.
    notify: Notify,
}
478
/// Counts the outstanding references to one connection.
///
/// Clones share the same count.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}
483
484impl ConnectionCounter {
485 fn new() -> Self {
486 Self {
487 inner: Arc::new(ConnectionCounterInner {
488 count: Default::default(),
489 notify: Notify::new(),
490 }),
491 }
492 }
493
494 fn current(&self) -> usize {
495 self.inner.count.load(Ordering::SeqCst)
496 }
497
498 fn get_one(&self) -> OneConnection {
500 self.inner.count.fetch_add(1, Ordering::SeqCst);
501 OneConnection {
502 inner: self.inner.clone(),
503 }
504 }
505
506 fn is_idle(&self) -> bool {
507 self.inner.count.load(Ordering::SeqCst) == 0
508 }
509
510 fn idle_stream(self) -> impl Stream<Item = ()> {
519 n0_future::stream::unfold(self, |c| async move {
520 c.inner.notify.notified().await;
521 Some(((), c))
522 })
523 }
524}
525
/// Guard for one outstanding reference to a connection.
///
/// Dropping it decrements the shared count.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}
531
532impl Drop for OneConnection {
533 fn drop(&mut self) {
534 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
535 self.inner.notify.notify_waiters();
536 }
537 }
538}
539
// Integration tests that run the pool against local echo servers.
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::{Connection, ConnectionType},
        protocol::{AcceptError, ProtocolHandler, Router},
        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
    };
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use n0_snafu::ResultExt;
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that copies every incoming bidirectional stream back
    /// to its sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        /// Serves echo requests on `connection` until accepting a stream fails.
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` on a new bidirectional stream and returns the response,
    /// reading at most 1000 bytes.
    async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_id().e()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.e()?;
        send.write_all(text).await.e()?;
        send.finish().e()?;
        let response = recv.read_to_end(1000).await.e()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Starts one echo server and returns its address and router.
    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers, up to 16 at a time, and returns their ids,
    /// their routers and a static discovery provider built from their
    /// addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_endpoint_info(addrs);
        Ok((ids, routers, discovery))
    }

    /// Shuts down all routers, up to 16 at a time, ignoring shutdown errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    /// Pool options for the tests: a short idle timeout, a generous connect
    /// timeout and room for 32 connections.
    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    /// Echo client that takes its connections from a pool.
    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Echoes `text` via the pooled connection to `id`.
        ///
        /// The outer result reports pool errors. The inner one carries either
        /// the stable id of the connection that was used together with the
        /// response, or the echo error.
        async fn echo(
            &self,
            id: EndpointId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), n0_snafu::Error>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    /// Checks the errors reported for endpoints that cannot be reached.
    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        // Discovery starts out empty.
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            // Nothing is known about this id, so connecting is expected to
            // fail with a connect error.
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            // Made-up address info; presumably nothing listens on this port,
            // so the connect attempt is expected to time out.
            discovery.add_endpoint_info(EndpointAddr {
                id: non_listening,
                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        Ok(())
    }

    /// Checks that connections are reused while in use and replaced after
    /// the idle timeout.
    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            // Two requests in a row must use the same connection.
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // Wait much longer than the 100ms idle timeout from `test_options`.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            // The old connections are gone, so new ones must be used.
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Checks that a full pool makes room by evicting idle connections: 32
    /// servers are contacted with only 8 slots and an idle timeout (100s)
    /// that will not fire during the test.
    #[tokio::test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Checks that an error from the `on_connected` callback is reported as
    /// `OnConnectError`.
    #[tokio::test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        // Callback that always fails.
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Checks that `on_connected` can be used to wait until the connection
    /// is direct before it is handed out.
    #[tokio::test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        // Succeeds as soon as the connection type is reported as direct.
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_id() else {
                return Err(io::Error::other("unable to get endpoint id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Checks that the pool notices a connection being closed by its user
    /// and hands out a new connection afterwards.
    #[tokio::test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        conn.close(0u32.into(), b"test");
        // Give the pool time to process the close.
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}
856}