1use std::{
12 collections::{HashMap, VecDeque},
13 io,
14 ops::Deref,
15 sync::{
16 atomic::{AtomicUsize, Ordering},
17 Arc,
18 },
19 time::Duration,
20};
21
22use iroh::{
23 endpoint::{ConnectError, Connection},
24 Endpoint, EndpointId,
25};
26use n0_error::{e, stack_error};
27use n0_future::{
28 future::{self},
29 FuturesUnordered, MaybeFuture, Stream, StreamExt,
30};
31use tokio::sync::{
32 mpsc::{self, error::SendError as TokioSendError},
33 oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
38pub type OnConnected =
39 Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration for a [`ConnectionPool`].
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long a connection may stay without any outstanding [`ConnectionRef`] before the
    /// pool closes it. The timer starts once the connection has been reported idle.
    pub idle_timeout: Duration,
    /// Upper bound for establishing a connection. It covers both dialing the remote
    /// endpoint and running the [`Options::on_connected`] hook.
    pub connect_timeout: Duration,
    /// Maximum number of connections the pool tracks at once. When the limit is reached,
    /// the oldest idle connection is evicted. If none is idle, the request fails with
    /// [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// Optional hook run on each newly established connection before it is handed out.
    /// Skipped in the `Debug` output because closures are not `Debug`.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58 fn default() -> Self {
59 Self {
60 idle_timeout: Duration::from_secs(5),
61 connect_timeout: Duration::from_secs(1),
62 max_connections: 1024,
63 on_connected: None,
64 }
65 }
66}
67
68impl Options {
69 pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
71 where
72 F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
73 Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
74 {
75 self.on_connected = Some(Arc::new(move |ep, conn| {
76 let ep = ep.clone();
77 let conn = conn.clone();
78 Box::pin(f(ep, conn))
79 }));
80 self
81 }
82}
83
84#[derive(Debug)]
86pub struct ConnectionRef {
87 connection: iroh::endpoint::Connection,
88 _permit: OneConnection,
89}
90
91impl Deref for ConnectionRef {
92 type Target = iroh::endpoint::Connection;
93
94 fn deref(&self) -> &Self::Target {
95 &self.connection
96 }
97}
98
99impl ConnectionRef {
100 fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
101 Self {
102 connection,
103 _permit: counter,
104 }
105 }
106}
107
/// Errors returned by [`ConnectionPool::get_or_connect`].
///
/// The enum is `Clone`, and its sources are wrapped in `Arc`. This lets a single failed
/// connection attempt be reported to every request queued for that endpoint.
#[stack_error(derive, add_meta)]
#[derive(Clone)]
pub enum PoolConnectError {
    /// The pool, or the connection actor serving the request, stopped before replying.
    #[error("Connection pool is shut down")]
    Shutdown {},
    /// Dialing plus the `on_connected` hook did not finish within
    /// [`Options::connect_timeout`].
    #[error("Timeout during connect")]
    Timeout {},
    /// The pool is at [`Options::max_connections`] and no idle connection could be evicted.
    #[error("Too many connections")]
    TooManyConnections {},
    /// Dialing the remote endpoint failed.
    #[error(transparent)]
    ConnectError { source: Arc<ConnectError> },
    /// The [`Options::on_connected`] hook returned an error.
    #[error(transparent)]
    OnConnectError {
        #[error(std_err)]
        source: Arc<io::Error>,
    },
}
134
/// Wraps a dial failure as [`PoolConnectError::ConnectError`].
impl From<ConnectError> for PoolConnectError {
    fn from(e: ConnectError) -> Self {
        e!(PoolConnectError::ConnectError, Arc::new(e))
    }
}

/// Wraps an `on_connected` hook failure as [`PoolConnectError::OnConnectError`].
///
/// NOTE(review): this applies to *every* `io::Error` converted with `?`/`from`, not only
/// hook errors. Within this file it is only used on the hook's result.
impl From<io::Error> for PoolConnectError {
    fn from(e: io::Error) -> Self {
        e!(PoolConnectError::OnConnectError, Arc::new(e))
    }
}
146
/// Error returned by pool operations that only need to reach the pool actor.
#[stack_error(derive, add_meta)]
pub enum ConnectionPoolError {
    /// The pool actor's message channel is closed.
    #[error("The connection pool has been shut down")]
    Shutdown {},
}
156
/// Messages processed by the pool [`Actor`].
enum ActorMessage {
    /// A caller wants a connection to `RequestRef::id`.
    RequestRef(RequestRef),
    /// Sent by a connection actor once its connection has no outstanding users.
    ConnectionIdle { id: EndpointId },
    /// Drop the connection actor registered for `id`.
    ///
    /// Sent by [`ConnectionPool::close`], and by connection actors whose connection
    /// failed, was closed, or stayed idle past the timeout.
    ConnectionShutdown { id: EndpointId },
}

/// A request for a connection to `id`; the outcome is delivered through `tx`.
struct RequestRef {
    /// The remote endpoint to connect to.
    id: EndpointId,
    /// One-shot reply channel back to the caller of [`ConnectionPool::get_or_connect`].
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
167
168struct Context {
169 options: Options,
170 endpoint: Endpoint,
171 owner: ConnectionPool,
172 alpn: Vec<u8>,
173}
174
175impl Context {
176 async fn run_connection_actor(
177 self: Arc<Self>,
178 node_id: EndpointId,
179 mut rx: mpsc::Receiver<RequestRef>,
180 ) {
181 let context = self;
182
183 let conn_fut = {
184 let context = context.clone();
185 async move {
186 let conn = context
187 .endpoint
188 .connect(node_id, &context.alpn)
189 .await
190 .map_err(PoolConnectError::from)?;
191 if let Some(on_connect) = &context.options.on_connected {
192 on_connect(&context.endpoint, &conn)
193 .await
194 .map_err(PoolConnectError::from)?;
195 }
196 Result::<Connection, PoolConnectError>::Ok(conn)
197 }
198 };
199
200 let state = conn_fut
202 .timeout(context.options.connect_timeout)
203 .await
204 .map_err(|_| e!(PoolConnectError::Timeout))
205 .and_then(|r| r);
206 let conn_close = match &state {
207 Ok(conn) => {
208 let conn = conn.clone();
209 MaybeFuture::Some(async move { conn.closed().await })
210 }
211 Err(e) => {
212 debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
213 if context.owner.close(node_id).await.is_err() {
214 return;
215 }
216 MaybeFuture::None
217 }
218 };
219
220 let counter = ConnectionCounter::new();
221 let idle_timer = MaybeFuture::default();
222 let idle_stream = counter.clone().idle_stream();
223
224 tokio::pin!(idle_timer, idle_stream, conn_close);
225
226 loop {
227 tokio::select! {
228 biased;
229
230 handler = rx.recv() => {
232 match handler {
233 Some(RequestRef { id, tx }) => {
234 assert!(id == node_id, "Not for me!");
235 match &state {
236 Ok(state) => {
237 let res = ConnectionRef::new(state.clone(), counter.get_one());
238 info!(%node_id, "Handing out ConnectionRef {}", counter.current());
239
240 idle_timer.as_mut().set_none();
242 tx.send(Ok(res)).ok();
243 }
244 Err(cause) => {
245 tx.send(Err(cause.clone())).ok();
246 }
247 }
248 }
249 None => {
250 break;
252 }
253 }
254 }
255
256 _ = &mut conn_close => {
257 context.owner.close(node_id).await.ok();
259 }
260
261 _ = idle_stream.next() => {
262 if !counter.is_idle() {
263 continue;
264 };
265 trace!(%node_id, "Idle");
267 if context.owner.idle(node_id).await.is_err() {
268 break;
270 }
271 idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
273 }
274
275 _ = &mut idle_timer => {
277 trace!(%node_id, "Idle timer expired, requesting shutdown");
278 context.owner.close(node_id).await.ok();
279 }
281 }
282 }
283
284 if let Ok(connection) = state {
285 let reason = if counter.is_idle() { b"idle" } else { b"drop" };
286 connection.close(0u32.into(), reason);
287 }
288
289 trace!(%node_id, "Connection actor shutting down");
290 }
291}
292
293struct Actor {
294 rx: mpsc::Receiver<ActorMessage>,
295 connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
296 context: Arc<Context>,
297 idle: VecDeque<EndpointId>,
300 tasks: FuturesUnordered<future::Boxed<()>>,
302}
303
304impl Actor {
305 pub fn new(
306 endpoint: Endpoint,
307 alpn: &[u8],
308 options: Options,
309 ) -> (Self, mpsc::Sender<ActorMessage>) {
310 let (tx, rx) = mpsc::channel(100);
311 (
312 Self {
313 rx,
314 connections: HashMap::new(),
315 idle: VecDeque::new(),
316 context: Arc::new(Context {
317 options,
318 alpn: alpn.to_vec(),
319 endpoint,
320 owner: ConnectionPool { tx: tx.clone() },
321 }),
322 tasks: FuturesUnordered::new(),
323 },
324 tx,
325 )
326 }
327
328 fn add_idle(&mut self, id: EndpointId) {
329 self.remove_idle(id);
330 self.idle.push_back(id);
331 }
332
333 fn remove_idle(&mut self, id: EndpointId) {
334 self.idle.retain(|&x| x != id);
335 }
336
337 fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
338 self.idle.pop_front()
339 }
340
341 fn remove_connection(&mut self, id: EndpointId) {
342 self.connections.remove(&id);
343 self.remove_idle(id);
344 }
345
346 async fn handle_msg(&mut self, msg: ActorMessage) {
347 match msg {
348 ActorMessage::RequestRef(mut msg) => {
349 let id = msg.id;
350 self.remove_idle(id);
351 if let Some(conn_tx) = self.connections.get(&id) {
353 if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
354 msg = e;
355 } else {
356 return;
357 }
358 self.remove_connection(id);
360 }
361
362 if self.connections.len() >= self.context.options.max_connections {
364 if let Some(idle) = self.pop_oldest_idle() {
365 trace!("removing oldest idle connection {}", idle);
367 self.connections.remove(&idle);
368 } else {
369 msg.tx
370 .send(Err(e!(PoolConnectError::TooManyConnections)))
371 .ok();
372 return;
373 }
374 }
375 let (conn_tx, conn_rx) = mpsc::channel(100);
376 self.connections.insert(id, conn_tx.clone());
377
378 let context = self.context.clone();
379
380 self.tasks
381 .push(Box::pin(context.run_connection_actor(id, conn_rx)));
382
383 if conn_tx.send(msg).await.is_err() {
385 error!(%id, "Failed to send handler to new connection actor");
386 self.connections.remove(&id);
387 }
388 }
389 ActorMessage::ConnectionIdle { id } => {
390 self.add_idle(id);
391 trace!(%id, "connection idle");
392 }
393 ActorMessage::ConnectionShutdown { id } => {
394 self.remove_connection(id);
396 trace!(%id, "removed connection");
397 }
398 }
399 }
400
401 pub async fn run(mut self) {
402 loop {
403 tokio::select! {
404 biased;
405
406 msg = self.rx.recv() => {
407 if let Some(msg) = msg {
408 self.handle_msg(msg).await;
409 } else {
410 break;
411 }
412 }
413
414 _ = self.tasks.next(), if !self.tasks.is_empty() => {}
415 }
416 }
417 }
418}
419
/// Handle to a pool of iroh connections keyed by remote [`EndpointId`].
///
/// All clones talk to the same background actor, which is spawned by
/// [`ConnectionPool::new`].
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    /// Channel to the pool actor.
    tx: mpsc::Sender<ActorMessage>,
}
425
426impl ConnectionPool {
427 pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
428 let (actor, tx) = Actor::new(endpoint, alpn, options);
429
430 tokio::spawn(actor.run());
432
433 Self { tx }
434 }
435
436 pub async fn get_or_connect(
441 &self,
442 id: EndpointId,
443 ) -> std::result::Result<ConnectionRef, PoolConnectError> {
444 let (tx, rx) = oneshot::channel();
445 self.tx
446 .send(ActorMessage::RequestRef(RequestRef { id, tx }))
447 .await
448 .map_err(|_| e!(PoolConnectError::Shutdown))?;
449 rx.await.map_err(|_| e!(PoolConnectError::Shutdown))?
450 }
451
452 pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
457 self.tx
458 .send(ActorMessage::ConnectionShutdown { id })
459 .await
460 .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
461 Ok(())
462 }
463
464 pub(crate) async fn idle(
468 &self,
469 id: EndpointId,
470 ) -> std::result::Result<(), ConnectionPoolError> {
471 self.tx
472 .send(ActorMessage::ConnectionIdle { id })
473 .await
474 .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
475 Ok(())
476 }
477}
478
/// Shared state behind a [`ConnectionCounter`] and its [`OneConnection`] permits.
#[derive(Debug)]
struct ConnectionCounterInner {
    /// Number of outstanding [`OneConnection`] permits.
    count: AtomicUsize,
    /// Signalled by `OneConnection::drop` when `count` drops to zero.
    notify: Notify,
}

/// Counts the live [`ConnectionRef`]s of a single pooled connection.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}
489
490impl ConnectionCounter {
491 fn new() -> Self {
492 Self {
493 inner: Arc::new(ConnectionCounterInner {
494 count: Default::default(),
495 notify: Notify::new(),
496 }),
497 }
498 }
499
500 fn current(&self) -> usize {
501 self.inner.count.load(Ordering::SeqCst)
502 }
503
504 fn get_one(&self) -> OneConnection {
506 self.inner.count.fetch_add(1, Ordering::SeqCst);
507 OneConnection {
508 inner: self.inner.clone(),
509 }
510 }
511
512 fn is_idle(&self) -> bool {
513 self.inner.count.load(Ordering::SeqCst) == 0
514 }
515
516 fn idle_stream(self) -> impl Stream<Item = ()> {
525 n0_future::stream::unfold(self, |c| async move {
526 c.inner.notify.notified().await;
527 Some(((), c))
528 })
529 }
530}
531
/// Usage permit returned by [`ConnectionCounter::get_one`]; decrements the counter on drop.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}
537
538impl Drop for OneConnection {
539 fn drop(&mut self) {
540 if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
541 self.inner.notify.notify_waiters();
542 }
543 }
544}
545
// Integration tests: spin up real iroh echo servers and exercise the pool against them.
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::{Connection, ConnectionType},
        protocol::{AcceptError, ProtocolHandler, Router},
        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
    };
    use n0_error::{AnyError, Result, StdResultExt};
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    /// ALPN of the echo test protocol.
    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that copies every incoming bidirectional stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            // Serve streams until the connection stops accepting new ones.
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` on a new bidirectional stream and returns the echoed bytes (up to 1000).
    async fn echo_client(conn: &Connection, text: &[u8]) -> Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_id().anyerr()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        send.write_all(text).await.anyerr()?;
        send.finish().anyerr()?;
        let response = recv.read_to_end(1000).await.anyerr()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Binds an endpoint serving [`Echo`] and returns its address and router.
    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers (up to 16 concurrently).
    ///
    /// Returns their ids, their routers, and a static discovery provider that knows
    /// their addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_endpoint_info(addrs);
        Ok((ids, routers, discovery))
    }

    /// Shuts down all routers, ignoring individual shutdown errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    /// Pool options with a short (100 ms) idle timeout so idle handling is observable quickly.
    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    /// Thin client that echoes through a pooled connection.
    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Echoes `text` to `id`.
        ///
        /// The outer `Result` carries pool errors. The inner one carries echo errors, or
        /// `(stable connection id, response)` on success.
        async fn echo(
            &self,
            id: EndpointId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), AnyError>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    // Errors surface as the right `PoolConnectError` variants: an unknown endpoint fails
    // to dial, and an address where nobody listens runs into the connect timeout.
    //
    // Both cases use the same key. The failed actor from the first case deregisters itself
    // before answering, so the second request gets a fresh connection attempt.
    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            discovery.add_endpoint_info(EndpointAddr {
                id: non_listening,
                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout { .. })));
        }
        Ok(())
    }

    // Back-to-back requests reuse the same connection. After sleeping well past the
    // 100 ms idle timeout, every connection must have been closed and redialed.
    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    // With max_connections (8) below the number of servers (32) and a long idle timeout,
    // requests only succeed if the pool evicts idle connections to make room.
    #[tokio::test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    // A failing on_connected hook must surface as OnConnectError.
    #[tokio::test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    // An on_connected hook (installed via `with_on_connected`) that waits until the
    // connection becomes direct.
    #[tokio::test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_id() else {
                return Err(io::Error::other("unable to get endpoint id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            // NOTE(review): only the outer (pool) result is checked here; an echo failure
            // in the inner result would still pass.
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    // Closing a pooled connection from the user side makes the pool drop it, so the next
    // request gets a new connection.
    #[tokio::test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        conn.close(0u32.into(), b"test");
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}