1use std::{
12 collections::{HashMap, VecDeque},
13 io,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicUsize, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21
22use iroh::{
23 endpoint::{ConnectError, Connection},
24 Endpoint, EndpointId,
25};
26use n0_error::{e, stack_error};
27use n0_future::{
28 future::{self},
29 FuturesUnordered, MaybeFuture, Stream, StreamExt,
30};
31use tokio::sync::{
32 mpsc::{self, error::SendError as TokioSendError},
33 oneshot, Notify,
34};
35use tracing::{debug, error, info, trace};
36
37pub type OnConnected =
38 Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
39
/// Configuration for a [`ConnectionPool`].
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long a connection with no outstanding [`ConnectionRef`]s is kept open
    /// after it was reported idle, before the pool closes it.
    pub idle_timeout: Duration,
    /// Upper bound for establishing a connection, including the time spent in the
    /// `on_connected` callback. Exceeding it yields [`PoolConnectError::Timeout`].
    pub connect_timeout: Duration,
    /// Maximum number of connections tracked by the pool. When the limit is reached the
    /// oldest idle connection is evicted; if none is idle, the request fails with
    /// [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// Optional async callback run on every new connection before it is handed out.
    /// If it fails, requests for that endpoint get [`PoolConnectError::OnConnectError`].
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
55
56impl Default for Options {
57 fn default() -> Self {
58 Self {
59 idle_timeout: Duration::from_secs(5),
60 connect_timeout: Duration::from_secs(1),
61 max_connections: 1024,
62 on_connected: None,
63 }
64 }
65}
66
67impl Options {
68 pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
70 where
71 F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
72 Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
73 {
74 self.on_connected = Some(Arc::new(move |ep, conn| {
75 let ep = ep.clone();
76 let conn = conn.clone();
77 Box::pin(f(ep, conn))
78 }));
79 self
80 }
81}
82
83#[derive(Debug)]
85pub struct ConnectionRef {
86 connection: iroh::endpoint::Connection,
87 _permit: OneConnection,
88}
89
90impl Deref for ConnectionRef {
91 type Target = iroh::endpoint::Connection;
92
93 fn deref(&self) -> &Self::Target {
94 &self.connection
95 }
96}
97
98impl ConnectionRef {
99 fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
100 Self {
101 connection,
102 _permit: counter,
103 }
104 }
105}
106
/// Error returned by [`ConnectionPool::get_or_connect`].
///
/// The type is `Clone` because a single failed connection attempt is reported to
/// every request queued for that endpoint; non-`Clone` sources are wrapped in `Arc`.
#[stack_error(derive, add_meta)]
#[derive(Clone)]
pub enum PoolConnectError {
    /// The pool actor is no longer accepting or answering requests.
    #[error("Connection pool is shut down")]
    Shutdown {},
    /// Connecting (including the `on_connected` callback) exceeded `connect_timeout`.
    #[error("Timeout during connect")]
    Timeout {},
    /// `max_connections` was reached and no idle connection could be evicted.
    #[error("Too many connections")]
    TooManyConnections {},
    /// The endpoint failed to establish the connection.
    #[error(transparent)]
    ConnectError { source: Arc<ConnectError> },
    /// The `on_connected` callback returned an error.
    #[error(transparent)]
    OnConnectError {
        #[error(std_err)]
        source: Arc<io::Error>,
    },
}
133
134impl From<ConnectError> for PoolConnectError {
135 fn from(e: ConnectError) -> Self {
136 e!(PoolConnectError::ConnectError, Arc::new(e))
137 }
138}
139
140impl From<io::Error> for PoolConnectError {
141 fn from(e: io::Error) -> Self {
142 e!(PoolConnectError::OnConnectError, Arc::new(e))
143 }
144}
145
/// Error returned when a message cannot be delivered to the pool actor.
#[stack_error(derive, add_meta)]
pub enum ConnectionPoolError {
    /// The pool actor's inbox is closed.
    #[error("The connection pool has been shut down")]
    Shutdown {},
}
155
/// Messages handled by the pool actor.
enum ActorMessage {
    /// A caller wants a [`ConnectionRef`] for `RequestRef::id`.
    RequestRef(RequestRef),
    /// Sent by a connection actor when its last [`ConnectionRef`] was dropped;
    /// the endpoint becomes a candidate for eviction.
    ConnectionIdle { id: EndpointId },
    /// Remove the connection actor for `id` from the pool (dropping its sender).
    ConnectionShutdown { id: EndpointId },
}
161
/// A request for a connection to `id`, answered through the one-shot `tx`.
struct RequestRef {
    id: EndpointId,
    // Receives either a live ConnectionRef or the (cloned) connect error.
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
166
167struct Context {
168 options: Options,
169 endpoint: Endpoint,
170 owner: ConnectionPool,
171 alpn: Vec<u8>,
172}
173
174impl Context {
175 async fn run_connection_actor(
176 self: Arc<Self>,
177 node_id: EndpointId,
178 mut rx: mpsc::Receiver<RequestRef>,
179 ) {
180 let context = self;
181
182 let conn_fut = {
183 let context = context.clone();
184 async move {
185 let conn = context
186 .endpoint
187 .connect(node_id, &context.alpn)
188 .await
189 .map_err(PoolConnectError::from)?;
190 if let Some(on_connect) = &context.options.on_connected {
191 on_connect(&context.endpoint, &conn)
192 .await
193 .map_err(PoolConnectError::from)?;
194 }
195 Result::<Connection, PoolConnectError>::Ok(conn)
196 }
197 };
198
199 let state = n0_future::time::timeout(context.options.connect_timeout, conn_fut)
201 .await
202 .map_err(|_| e!(PoolConnectError::Timeout))
203 .and_then(|r| r);
204 let conn_close = match &state {
205 Ok(conn) => {
206 let conn = conn.clone();
207 MaybeFuture::Some(async move { conn.closed().await })
208 }
209 Err(e) => {
210 debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
211 if context.owner.close(node_id).await.is_err() {
212 return;
213 }
214 MaybeFuture::None
215 }
216 };
217
218 let counter = ConnectionCounter::new();
219 let idle_timer = MaybeFuture::default();
220 let idle_stream = counter.clone().idle_stream();
221
222 tokio::pin!(idle_timer, idle_stream, conn_close);
223
224 loop {
225 tokio::select! {
226 biased;
227
228 handler = rx.recv() => {
230 match handler {
231 Some(RequestRef { id, tx }) => {
232 assert!(id == node_id, "Not for me!");
233 match &state {
234 Ok(state) => {
235 let res = ConnectionRef::new(state.clone(), counter.get_one());
236 info!(%node_id, "Handing out ConnectionRef {}", counter.current());
237
238 idle_timer.as_mut().set_none();
240 tx.send(Ok(res)).ok();
241 }
242 Err(cause) => {
243 tx.send(Err(cause.clone())).ok();
244 }
245 }
246 }
247 None => {
248 break;
250 }
251 }
252 }
253
254 _ = &mut conn_close => {
255 context.owner.close(node_id).await.ok();
257 }
258
259 _ = idle_stream.next() => {
260 if !counter.is_idle() {
261 continue;
262 };
263 trace!(%node_id, "Idle");
265 if context.owner.idle(node_id).await.is_err() {
266 break;
268 }
269 idle_timer.as_mut().set_future(n0_future::time::sleep(context.options.idle_timeout));
271 }
272
273 _ = &mut idle_timer => {
275 trace!(%node_id, "Idle timer expired, requesting shutdown");
276 context.owner.close(node_id).await.ok();
277 }
279 }
280 }
281
282 if let Ok(connection) = state {
283 let reason = if counter.is_idle() { b"idle" } else { b"drop" };
284 connection.close(0u32.into(), reason);
285 }
286
287 trace!(%node_id, "Connection actor shutting down");
288 }
289}
290
291struct Actor {
292 rx: mpsc::Receiver<ActorMessage>,
293 connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
294 context: Arc<Context>,
295 idle: VecDeque<EndpointId>,
298 tasks: FuturesUnordered<future::Boxed<()>>,
300}
301
302impl Actor {
303 pub fn new(
304 endpoint: Endpoint,
305 alpn: &[u8],
306 options: Options,
307 ) -> (Self, mpsc::Sender<ActorMessage>) {
308 let (tx, rx) = mpsc::channel(100);
309 (
310 Self {
311 rx,
312 connections: HashMap::new(),
313 idle: VecDeque::new(),
314 context: Arc::new(Context {
315 options,
316 alpn: alpn.to_vec(),
317 endpoint,
318 owner: ConnectionPool { tx: tx.clone() },
319 }),
320 tasks: FuturesUnordered::new(),
321 },
322 tx,
323 )
324 }
325
326 fn add_idle(&mut self, id: EndpointId) {
327 self.remove_idle(id);
328 self.idle.push_back(id);
329 }
330
331 fn remove_idle(&mut self, id: EndpointId) {
332 self.idle.retain(|&x| x != id);
333 }
334
335 fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
336 self.idle.pop_front()
337 }
338
339 fn remove_connection(&mut self, id: EndpointId) {
340 self.connections.remove(&id);
341 self.remove_idle(id);
342 }
343
344 async fn handle_msg(&mut self, msg: ActorMessage) {
345 match msg {
346 ActorMessage::RequestRef(mut msg) => {
347 let id = msg.id;
348 self.remove_idle(id);
349 if let Some(conn_tx) = self.connections.get(&id) {
351 if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
352 msg = e;
353 } else {
354 return;
355 }
356 self.remove_connection(id);
358 }
359
360 if self.connections.len() >= self.context.options.max_connections {
362 if let Some(idle) = self.pop_oldest_idle() {
363 trace!("removing oldest idle connection {}", idle);
365 self.connections.remove(&idle);
366 } else {
367 msg.tx
368 .send(Err(e!(PoolConnectError::TooManyConnections)))
369 .ok();
370 return;
371 }
372 }
373 let (conn_tx, conn_rx) = mpsc::channel(100);
374 self.connections.insert(id, conn_tx.clone());
375
376 let context = self.context.clone();
377
378 self.tasks
379 .push(Box::pin(context.run_connection_actor(id, conn_rx)));
380
381 if conn_tx.send(msg).await.is_err() {
383 error!(%id, "Failed to send handler to new connection actor");
384 self.connections.remove(&id);
385 }
386 }
387 ActorMessage::ConnectionIdle { id } => {
388 self.add_idle(id);
389 trace!(%id, "connection idle");
390 }
391 ActorMessage::ConnectionShutdown { id } => {
392 self.remove_connection(id);
394 trace!(%id, "removed connection");
395 }
396 }
397 }
398
399 pub async fn run(mut self) {
400 loop {
401 tokio::select! {
402 biased;
403
404 msg = self.rx.recv() => {
405 if let Some(msg) = msg {
406 self.handle_msg(msg).await;
407 } else {
408 break;
409 }
410 }
411
412 _ = self.tasks.next(), if !self.tasks.is_empty() => {}
413 }
414 }
415 }
416}
417
418#[derive(Debug, Clone)]
420pub struct ConnectionPool {
421 tx: mpsc::Sender<ActorMessage>,
422}
423
424impl ConnectionPool {
425 pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
426 let (actor, tx) = Actor::new(endpoint, alpn, options);
427
428 n0_future::task::spawn(actor.run());
430
431 Self { tx }
432 }
433
434 pub async fn get_or_connect(
439 &self,
440 id: EndpointId,
441 ) -> std::result::Result<ConnectionRef, PoolConnectError> {
442 let (tx, rx) = oneshot::channel();
443 self.tx
444 .send(ActorMessage::RequestRef(RequestRef { id, tx }))
445 .await
446 .map_err(|_| e!(PoolConnectError::Shutdown))?;
447 rx.await.map_err(|_| e!(PoolConnectError::Shutdown))?
448 }
449
450 pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
455 self.tx
456 .send(ActorMessage::ConnectionShutdown { id })
457 .await
458 .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
459 Ok(())
460 }
461
462 pub(crate) async fn idle(
466 &self,
467 id: EndpointId,
468 ) -> std::result::Result<(), ConnectionPoolError> {
469 self.tx
470 .send(ActorMessage::ConnectionIdle { id })
471 .await
472 .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
473 Ok(())
474 }
475}
476
477#[derive(Debug)]
478struct ConnectionCounterInner {
479 count: AtomicUsize,
480 notify: Notify,
481}
482
483#[derive(Debug, Clone)]
484struct ConnectionCounter {
485 inner: Arc<ConnectionCounterInner>,
486}
487
488impl ConnectionCounter {
489 fn new() -> Self {
490 Self {
491 inner: Arc::new(ConnectionCounterInner {
492 count: Default::default(),
493 notify: Notify::new(),
494 }),
495 }
496 }
497
498 fn current(&self) -> usize {
499 self.inner.count.load(Ordering::SeqCst)
500 }
501
502 fn get_one(&self) -> OneConnection {
504 self.inner.count.fetch_add(1, Ordering::SeqCst);
505 OneConnection {
506 inner: self.inner.clone(),
507 }
508 }
509
510 fn is_idle(&self) -> bool {
511 self.inner.count.load(Ordering::SeqCst) == 0
512 }
513
514 fn idle_stream(self) -> impl Stream<Item = ()> {
523 n0_future::stream::unfold(self, |c| async move {
524 c.inner.notify.notified().await;
525 Some(((), c))
526 })
527 }
528}
529
530#[derive(Debug)]
532struct OneConnection {
533 inner: Arc<ConnectionCounterInner>,
534}
535
536impl Drop for OneConnection {
537 fn drop(&mut self) {
538 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
539 self.inner.notify.notify_waiters();
540 }
541 }
542}
543
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::{Connection, ConnectionType},
        protocol::{AcceptError, ProtocolHandler, Router},
        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
    };
    use n0_error::{AnyError, Result, StdResultExt};
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    /// ALPN spoken by the test echo servers and requested by the pools under test.
    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that copies every bidirectional stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_id();
            trace!(%id, %conn_id, "Accepting echo connection");
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        // No more streams can be accepted (e.g. the connection closed).
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` on a new bidirectional stream of `conn` and returns the reply
    /// (read up to 1000 bytes).
    async fn echo_client(conn: &Connection, text: &[u8]) -> Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_id();
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        send.write_all(text).await.anyerr()?;
        send.finish().anyerr()?;
        let response = recv.read_to_end(1000).await.anyerr()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Starts one echo server and returns its address and router.
    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers (up to 16 concurrently) and returns their ids, their
    /// routers, and a static discovery provider that knows all of their addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_endpoint_info(addrs);
        Ok((ids, routers, discovery))
    }

    /// Shuts down all routers, up to 16 at a time, ignoring shutdown errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    /// Pool options for tests: short idle timeout, generous connect timeout.
    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    /// Echo client that obtains its connections from a [`ConnectionPool`].
    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Echoes `text` to `id`. The outer `Result` is the pool error; the inner one
        /// is the echo outcome, paired with the connection's `stable_id` on success.
        async fn echo(
            &self,
            id: EndpointId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), AnyError>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    /// An endpoint without address info yields `ConnectError`; an endpoint whose
    /// address has no listener yields `Timeout`.
    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            // NOTE(review): this reuses the `[0; 32]` key from the case above, so it relies
            // on the pool having removed the failed connection actor for that key.
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            discovery.add_endpoint_info(EndpointAddr {
                id: non_listening,
                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout { .. })));
        }
        Ok(())
    }

    /// Repeated requests reuse one connection per endpoint. After sleeping well past
    /// `idle_timeout`, the connections are closed and new ones are created.
    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        n0_future::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// More endpoints (32) than `max_connections` (8), with an idle timeout too long to
    /// fire. All requests can only succeed if idle connections are evicted.
    #[tokio::test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// A failing `on_connected` callback surfaces as `OnConnectError`.
    #[tokio::test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// An `on_connected` callback (via `with_on_connected`) that waits until the
    /// connection type becomes `Direct` before the connection is handed out.
    #[tokio::test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let id = conn.remote_id();
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Closing a pooled connection from the user side makes the pool drop it, so the
    /// next request gets a new connection.
    #[tokio::test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        conn.close(0u32.into(), b"test");
        n0_future::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}