1use std::{
12 collections::{HashMap, VecDeque},
13 io,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicUsize, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21
22use iroh::{
23 endpoint::{ConnectError, Connection},
24 Endpoint, EndpointId,
25};
26use n0_future::{
27 future::{self},
28 FuturesUnordered, MaybeFuture, Stream, StreamExt,
29};
30use snafu::Snafu;
31use tokio::sync::{
32 mpsc::{self, error::SendError as TokioSendError},
33 oneshot, Notify,
34};
35use tracing::{debug, error, info, trace};
36
37pub type OnConnected =
38 Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
39
40#[derive(derive_more::Debug, Clone)]
42pub struct Options {
43 pub idle_timeout: Duration,
45 pub connect_timeout: Duration,
47 pub max_connections: usize,
49 #[debug(skip)]
53 pub on_connected: Option<OnConnected>,
54}
55
56impl Default for Options {
57 fn default() -> Self {
58 Self {
59 idle_timeout: Duration::from_secs(5),
60 connect_timeout: Duration::from_secs(1),
61 max_connections: 1024,
62 on_connected: None,
63 }
64 }
65}
66
67impl Options {
68 pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
70 where
71 F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
72 Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
73 {
74 self.on_connected = Some(Arc::new(move |ep, conn| {
75 let ep = ep.clone();
76 let conn = conn.clone();
77 Box::pin(f(ep, conn))
78 }));
79 self
80 }
81}
82
83#[derive(Debug)]
85pub struct ConnectionRef {
86 connection: iroh::endpoint::Connection,
87 _permit: OneConnection,
88}
89
90impl Deref for ConnectionRef {
91 type Target = iroh::endpoint::Connection;
92
93 fn deref(&self) -> &Self::Target {
94 &self.connection
95 }
96}
97
98impl ConnectionRef {
99 fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
100 Self {
101 connection,
102 _permit: counter,
103 }
104 }
105}
106
107#[derive(Debug, Clone, Snafu)]
112#[snafu(module)]
113pub enum PoolConnectError {
114 Shutdown,
116 Timeout,
118 TooManyConnections,
120 ConnectError { source: Arc<ConnectError> },
122 OnConnectError { source: Arc<io::Error> },
124}
125
126impl From<ConnectError> for PoolConnectError {
127 fn from(e: ConnectError) -> Self {
128 PoolConnectError::ConnectError {
129 source: Arc::new(e),
130 }
131 }
132}
133
134impl From<io::Error> for PoolConnectError {
135 fn from(e: io::Error) -> Self {
136 PoolConnectError::OnConnectError {
137 source: Arc::new(e),
138 }
139 }
140}
141
142#[derive(Debug, Snafu)]
146#[snafu(module)]
147pub enum ConnectionPoolError {
148 Shutdown,
150}
151
152enum ActorMessage {
153 RequestRef(RequestRef),
154 ConnectionIdle { id: EndpointId },
155 ConnectionShutdown { id: EndpointId },
156}
157
158struct RequestRef {
159 id: EndpointId,
160 tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
161}
162
163struct Context {
164 options: Options,
165 endpoint: Endpoint,
166 owner: ConnectionPool,
167 alpn: Vec<u8>,
168}
169
170impl Context {
171 async fn run_connection_actor(
172 self: Arc<Self>,
173 node_id: EndpointId,
174 mut rx: mpsc::Receiver<RequestRef>,
175 ) {
176 let context = self;
177
178 let conn_fut = {
179 let context = context.clone();
180 async move {
181 let conn = context
182 .endpoint
183 .connect(node_id, &context.alpn)
184 .await
185 .map_err(PoolConnectError::from)?;
186 if let Some(on_connect) = &context.options.on_connected {
187 on_connect(&context.endpoint, &conn)
188 .await
189 .map_err(PoolConnectError::from)?;
190 }
191 Result::<Connection, PoolConnectError>::Ok(conn)
192 }
193 };
194
195 let state = n0_future::time::timeout(context.options.connect_timeout, conn_fut)
197 .await
198 .map_err(|_| PoolConnectError::Timeout)
199 .and_then(|r| r);
200 let conn_close = match &state {
201 Ok(conn) => {
202 let conn = conn.clone();
203 MaybeFuture::Some(async move { conn.closed().await })
204 }
205 Err(e) => {
206 debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
207 if context.owner.close(node_id).await.is_err() {
208 return;
209 }
210 MaybeFuture::None
211 }
212 };
213
214 let counter = ConnectionCounter::new();
215 let idle_timer = MaybeFuture::default();
216 let idle_stream = counter.clone().idle_stream();
217
218 tokio::pin!(idle_timer, idle_stream, conn_close);
219
220 loop {
221 tokio::select! {
222 biased;
223
224 handler = rx.recv() => {
226 match handler {
227 Some(RequestRef { id, tx }) => {
228 assert!(id == node_id, "Not for me!");
229 match &state {
230 Ok(state) => {
231 let res = ConnectionRef::new(state.clone(), counter.get_one());
232 info!(%node_id, "Handing out ConnectionRef {}", counter.current());
233
234 idle_timer.as_mut().set_none();
236 tx.send(Ok(res)).ok();
237 }
238 Err(cause) => {
239 tx.send(Err(cause.clone())).ok();
240 }
241 }
242 }
243 None => {
244 break;
246 }
247 }
248 }
249
250 _ = &mut conn_close => {
251 context.owner.close(node_id).await.ok();
253 }
254
255 _ = idle_stream.next() => {
256 if !counter.is_idle() {
257 continue;
258 };
259 trace!(%node_id, "Idle");
261 if context.owner.idle(node_id).await.is_err() {
262 break;
264 }
265 idle_timer.as_mut().set_future(n0_future::time::sleep(context.options.idle_timeout));
267 }
268
269 _ = &mut idle_timer => {
271 trace!(%node_id, "Idle timer expired, requesting shutdown");
272 context.owner.close(node_id).await.ok();
273 }
275 }
276 }
277
278 if let Ok(connection) = state {
279 let reason = if counter.is_idle() { b"idle" } else { b"drop" };
280 connection.close(0u32.into(), reason);
281 }
282
283 trace!(%node_id, "Connection actor shutting down");
284 }
285}
286
287struct Actor {
288 rx: mpsc::Receiver<ActorMessage>,
289 connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
290 context: Arc<Context>,
291 idle: VecDeque<EndpointId>,
294 tasks: FuturesUnordered<future::Boxed<()>>,
296}
297
298impl Actor {
299 pub fn new(
300 endpoint: Endpoint,
301 alpn: &[u8],
302 options: Options,
303 ) -> (Self, mpsc::Sender<ActorMessage>) {
304 let (tx, rx) = mpsc::channel(100);
305 (
306 Self {
307 rx,
308 connections: HashMap::new(),
309 idle: VecDeque::new(),
310 context: Arc::new(Context {
311 options,
312 alpn: alpn.to_vec(),
313 endpoint,
314 owner: ConnectionPool { tx: tx.clone() },
315 }),
316 tasks: FuturesUnordered::new(),
317 },
318 tx,
319 )
320 }
321
322 fn add_idle(&mut self, id: EndpointId) {
323 self.remove_idle(id);
324 self.idle.push_back(id);
325 }
326
327 fn remove_idle(&mut self, id: EndpointId) {
328 self.idle.retain(|&x| x != id);
329 }
330
331 fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
332 self.idle.pop_front()
333 }
334
335 fn remove_connection(&mut self, id: EndpointId) {
336 self.connections.remove(&id);
337 self.remove_idle(id);
338 }
339
340 async fn handle_msg(&mut self, msg: ActorMessage) {
341 match msg {
342 ActorMessage::RequestRef(mut msg) => {
343 let id = msg.id;
344 self.remove_idle(id);
345 if let Some(conn_tx) = self.connections.get(&id) {
347 if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
348 msg = e;
349 } else {
350 return;
351 }
352 self.remove_connection(id);
354 }
355
356 if self.connections.len() >= self.context.options.max_connections {
358 if let Some(idle) = self.pop_oldest_idle() {
359 trace!("removing oldest idle connection {}", idle);
361 self.connections.remove(&idle);
362 } else {
363 msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
364 return;
365 }
366 }
367 let (conn_tx, conn_rx) = mpsc::channel(100);
368 self.connections.insert(id, conn_tx.clone());
369
370 let context = self.context.clone();
371
372 self.tasks
373 .push(Box::pin(context.run_connection_actor(id, conn_rx)));
374
375 if conn_tx.send(msg).await.is_err() {
377 error!(%id, "Failed to send handler to new connection actor");
378 self.connections.remove(&id);
379 }
380 }
381 ActorMessage::ConnectionIdle { id } => {
382 self.add_idle(id);
383 trace!(%id, "connection idle");
384 }
385 ActorMessage::ConnectionShutdown { id } => {
386 self.remove_connection(id);
388 trace!(%id, "removed connection");
389 }
390 }
391 }
392
393 pub async fn run(mut self) {
394 loop {
395 tokio::select! {
396 biased;
397
398 msg = self.rx.recv() => {
399 if let Some(msg) = msg {
400 self.handle_msg(msg).await;
401 } else {
402 break;
403 }
404 }
405
406 _ = self.tasks.next(), if !self.tasks.is_empty() => {}
407 }
408 }
409 }
410}
411
412#[derive(Debug, Clone)]
414pub struct ConnectionPool {
415 tx: mpsc::Sender<ActorMessage>,
416}
417
418impl ConnectionPool {
419 pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
420 let (actor, tx) = Actor::new(endpoint, alpn, options);
421
422 n0_future::task::spawn(actor.run());
424
425 Self { tx }
426 }
427
428 pub async fn get_or_connect(
433 &self,
434 id: EndpointId,
435 ) -> std::result::Result<ConnectionRef, PoolConnectError> {
436 let (tx, rx) = oneshot::channel();
437 self.tx
438 .send(ActorMessage::RequestRef(RequestRef { id, tx }))
439 .await
440 .map_err(|_| PoolConnectError::Shutdown)?;
441 rx.await.map_err(|_| PoolConnectError::Shutdown)?
442 }
443
444 pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
449 self.tx
450 .send(ActorMessage::ConnectionShutdown { id })
451 .await
452 .map_err(|_| ConnectionPoolError::Shutdown)?;
453 Ok(())
454 }
455
456 pub(crate) async fn idle(
460 &self,
461 id: EndpointId,
462 ) -> std::result::Result<(), ConnectionPoolError> {
463 self.tx
464 .send(ActorMessage::ConnectionIdle { id })
465 .await
466 .map_err(|_| ConnectionPoolError::Shutdown)?;
467 Ok(())
468 }
469}
470
471#[derive(Debug)]
472struct ConnectionCounterInner {
473 count: AtomicUsize,
474 notify: Notify,
475}
476
477#[derive(Debug, Clone)]
478struct ConnectionCounter {
479 inner: Arc<ConnectionCounterInner>,
480}
481
482impl ConnectionCounter {
483 fn new() -> Self {
484 Self {
485 inner: Arc::new(ConnectionCounterInner {
486 count: Default::default(),
487 notify: Notify::new(),
488 }),
489 }
490 }
491
492 fn current(&self) -> usize {
493 self.inner.count.load(Ordering::SeqCst)
494 }
495
496 fn get_one(&self) -> OneConnection {
498 self.inner.count.fetch_add(1, Ordering::SeqCst);
499 OneConnection {
500 inner: self.inner.clone(),
501 }
502 }
503
504 fn is_idle(&self) -> bool {
505 self.inner.count.load(Ordering::SeqCst) == 0
506 }
507
508 fn idle_stream(self) -> impl Stream<Item = ()> {
517 n0_future::stream::unfold(self, |c| async move {
518 c.inner.notify.notified().await;
519 Some(((), c))
520 })
521 }
522}
523
524#[derive(Debug)]
526struct OneConnection {
527 inner: Arc<ConnectionCounterInner>,
528}
529
530impl Drop for OneConnection {
531 fn drop(&mut self) {
532 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
533 self.inner.notify.notify_waiters();
534 }
535 }
536}
537
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::{Connection, ConnectionType},
        protocol::{AcceptError, ProtocolHandler, Router},
        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
    };
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use n0_snafu::ResultExt;
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that copies every incoming bidirectional stream back to its sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` on a new bidirectional stream of `conn` and returns the echoed bytes.
    async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_id().e()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.e()?;
        send.write_all(text).await.e()?;
        send.finish().e()?;
        let response = recv.read_to_end(1000).await.e()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Starts one echo server and returns its address together with its router.
    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers plus a discovery service that knows all their addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_endpoint_info(addrs);
        Ok((ids, routers, discovery))
    }

    /// Shuts down all routers concurrently, ignoring individual shutdown errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    /// Pool options for tests: short idle timeout, generous connect timeout.
    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    /// Thin client that echoes through a pooled connection.
    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Returns the pooled connection's stable id together with the echoed bytes.
        /// Pool errors go in the outer `Result`, echo errors in the inner one.
        async fn echo(
            &self,
            id: EndpointId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), n0_snafu::Error>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    /// Unknown endpoints fail with a connect error; known but unreachable ones time out.
    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            // A key distinct from `non_existing`, so this case does not depend on the pool
            // having already discarded the failed actor from the case above.
            let non_listening = SecretKey::from_bytes(&[1; 32]).public();
            discovery.add_endpoint_info(EndpointAddr {
                id: non_listening,
                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        Ok(())
    }

    /// Back-to-back requests reuse a connection; after the idle timeout a new one is used.
    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // Well past `idle_timeout` (100 ms), so every connection should have been closed.
        n0_future::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// More endpoints than `max_connections`: the pool must evict idle connections
    /// instead of failing with `TooManyConnections`.
    #[tokio::test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// A failing `on_connected` hook surfaces as `OnConnectError`.
    #[tokio::test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// An `on_connected` hook that waits for a direct path lets requests succeed.
    #[tokio::test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_id() else {
                return Err(io::Error::other("unable to get endpoint id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Closing a pooled connection makes the next request get a fresh connection.
    #[tokio::test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        conn.close(0u32.into(), b"test");
        n0_future::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}