1use std::{
12    collections::{HashMap, VecDeque},
13    io,
14    ops::Deref,
15    sync::{
16        atomic::{AtomicUsize, Ordering},
17        Arc,
18    },
19    time::Duration,
20};
21
22use iroh::{
23    endpoint::{ConnectError, Connection},
24    Endpoint, EndpointId,
25};
26use n0_future::{
27    future::{self},
28    FuturesUnordered, MaybeFuture, Stream, StreamExt,
29};
30use snafu::Snafu;
31use tokio::sync::{
32    mpsc::{self, error::SendError as TokioSendError},
33    oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
/// Callback the pool runs right after a new connection has been established.
///
/// It is called with the pool's [`Endpoint`] and the new [`Connection`] and returns
/// a boxed future. If that future resolves to an error, the connect attempt fails
/// with [`PoolConnectError::OnConnectError`].
pub type OnConnected =
    Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration options for a [`ConnectionPool`].
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long a connection may stay idle (no outstanding [`ConnectionRef`])
    /// before its actor asks the pool to shut it down.
    pub idle_timeout: Duration,
    /// Upper bound for establishing a connection. The `on_connected` callback,
    /// if set, runs inside the same timeout.
    pub connect_timeout: Duration,
    /// Maximum number of connections the pool tracks at once. When the limit is
    /// reached, the oldest idle connection is dropped to make room; if none is
    /// idle, the request fails with [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// Optional callback that runs after connecting and before the connection is
    /// handed out, e.g. to wait for the connection to reach some state.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58    fn default() -> Self {
59        Self {
60            idle_timeout: Duration::from_secs(5),
61            connect_timeout: Duration::from_secs(1),
62            max_connections: 1024,
63            on_connected: None,
64        }
65    }
66}
67
68impl Options {
69    pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
71    where
72        F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
73        Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
74    {
75        self.on_connected = Some(Arc::new(move |ep, conn| {
76            let ep = ep.clone();
77            let conn = conn.clone();
78            Box::pin(f(ep, conn))
79        }));
80        self
81    }
82}
83
/// A handle to a connection that is owned by a [`ConnectionPool`].
///
/// Dereferences to the underlying connection. The handle also carries a guard
/// that keeps the connection counted as "in use" until the handle is dropped,
/// so keep it alive for as long as the connection is being used.
#[derive(Debug)]
pub struct ConnectionRef {
    // The pooled connection this handle gives access to.
    connection: iroh::endpoint::Connection,
    // Guard whose drop decrements the per-connection usage counter.
    _permit: OneConnection,
}
90
91impl Deref for ConnectionRef {
92    type Target = iroh::endpoint::Connection;
93
94    fn deref(&self) -> &Self::Target {
95        &self.connection
96    }
97}
98
99impl ConnectionRef {
100    fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
101        Self {
102            connection,
103            _permit: counter,
104        }
105    }
106}
107
/// Error returned when a connection cannot be acquired from the pool.
///
/// Covers failures of the underlying connect as well as pool specific
/// conditions such as timeouts and the connection limit.
// NOTE(review): snafu may use a variant's `///` doc comment as its `Display`
// text when no `#[snafu(display(...))]` is given, so the variants below are
// annotated with plain `//` comments to keep the messages unchanged - confirm.
#[derive(Debug, Clone, Snafu)]
#[snafu(module)]
pub enum PoolConnectError {
    // The pool actor is gone: the request could not be sent or was never answered.
    Shutdown,
    // Connecting (including the `on_connected` callback) exceeded `connect_timeout`.
    Timeout,
    // The pool is at `max_connections` and no idle connection could be evicted.
    TooManyConnections,
    // The endpoint failed to connect. Wrapped in `Arc` so the error can be cloned.
    ConnectError { source: Arc<ConnectError> },
    // The `on_connected` callback returned an error. Wrapped in `Arc` for cloning.
    OnConnectError { source: Arc<io::Error> },
}
126
127impl From<ConnectError> for PoolConnectError {
128    fn from(e: ConnectError) -> Self {
129        PoolConnectError::ConnectError {
130            source: Arc::new(e),
131        }
132    }
133}
134
135impl From<io::Error> for PoolConnectError {
136    fn from(e: io::Error) -> Self {
137        PoolConnectError::OnConnectError {
138            source: Arc::new(e),
139        }
140    }
141}
142
/// Error returned by pool operations that only need to reach the pool actor.
// NOTE(review): variant comments use `//` rather than `///` because snafu may
// derive `Display` text from doc comments - confirm before converting.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum ConnectionPoolError {
    // The message could not be delivered because the pool actor is gone.
    Shutdown,
}
152
/// Messages understood by the pool's main [`Actor`].
enum ActorMessage {
    /// A caller asks for a [`ConnectionRef`] to some endpoint.
    RequestRef(RequestRef),
    /// A connection actor reports that its connection has no outstanding references.
    ConnectionIdle { id: EndpointId },
    /// The connection for `id` should be removed from the pool.
    ConnectionShutdown { id: EndpointId },
}
158
/// A request for a connection to `id`, answered through a oneshot channel.
struct RequestRef {
    /// The remote endpoint to connect to.
    id: EndpointId,
    /// Where the connection reference, or the reason it could not be obtained, is sent.
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
163
/// State shared between the main actor and all per-connection actors.
struct Context {
    /// Pool configuration (timeouts, limits, `on_connected` callback).
    options: Options,
    /// Endpoint used to open outgoing connections.
    endpoint: Endpoint,
    /// Handle to the pool, used by connection actors to report idle/close events.
    owner: ConnectionPool,
    /// ALPN protocol identifier passed to every connect call.
    alpn: Vec<u8>,
}
170
impl Context {
    /// Runs the actor that owns the connection to a single remote endpoint.
    ///
    /// The actor first tries to connect to `node_id` (bounded by
    /// `options.connect_timeout`) and then serves [`RequestRef`] messages from `rx`:
    /// on success it hands out [`ConnectionRef`]s, on failure it answers every
    /// request with a clone of the connect error. It reports idleness and asks
    /// for its own removal through `owner`, and exits once `rx` is closed or the
    /// pool can no longer be notified of idleness. On exit an established
    /// connection is closed.
    async fn run_connection_actor(
        self: Arc<Self>,
        node_id: EndpointId,
        mut rx: mpsc::Receiver<RequestRef>,
    ) {
        let context = self;

        // Connect and, if configured, run the `on_connected` callback. Both steps
        // live in one future so that the timeout applied below covers them together.
        let conn_fut = {
            let context = context.clone();
            async move {
                let conn = context
                    .endpoint
                    .connect(node_id, &context.alpn)
                    .await
                    .map_err(PoolConnectError::from)?;
                if let Some(on_connect) = &context.options.on_connected {
                    on_connect(&context.endpoint, &conn)
                        .await
                        .map_err(PoolConnectError::from)?;
                }
                Result::<Connection, PoolConnectError>::Ok(conn)
            }
        };

        // `state` is the established connection or the reason connecting failed.
        // An elapsed timeout becomes `Timeout`; `and_then` flattens the nested result.
        let state = conn_fut
            .timeout(context.options.connect_timeout)
            .await
            .map_err(|_| PoolConnectError::Timeout)
            .and_then(|r| r);
        // Future that resolves when the connection is closed; absent if we never connected.
        let conn_close = match &state {
            Ok(conn) => {
                let conn = conn.clone();
                MaybeFuture::Some(async move { conn.closed().await })
            }
            Err(e) => {
                debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
                // Ask the pool to remove this actor. If the pool is gone there is
                // nobody left to serve, so stop right away.
                if context.owner.close(node_id).await.is_err() {
                    return;
                }
                MaybeFuture::None
            }
        };

        // Tracks how many `ConnectionRef`s are currently handed out.
        let counter = ConnectionCounter::new();
        // Not armed initially; set when the connection becomes idle.
        let idle_timer = MaybeFuture::default();
        let idle_stream = counter.clone().idle_stream();

        tokio::pin!(idle_timer, idle_stream, conn_close);

        loop {
            tokio::select! {
                // Poll branches in the order written, so pending requests are served first.
                biased;

                // Handle new work
                handler = rx.recv() => {
                    match handler {
                        Some(RequestRef { id, tx }) => {
                            assert!(id == node_id, "Not for me!");
                            match &state {
                                Ok(state) => {
                                    let res = ConnectionRef::new(state.clone(), counter.get_one());
                                    info!(%node_id, "Handing out ConnectionRef {}", counter.current());

                                    // The connection is in use again: clear the idle timer.
                                    idle_timer.as_mut().set_none();
                                    // The requester may have gone away; ignore a failed send.
                                    tx.send(Ok(res)).ok();
                                }
                                Err(cause) => {
                                    // Connecting failed: report the same error to every requester.
                                    tx.send(Err(cause.clone())).ok();
                                }
                            }
                        }
                        None => {
                            // Channel closed, the pool dropped our sender - exit.
                            break;
                        }
                    }
                }

                _ = &mut conn_close => {
                    // The connection was closed by somebody; ask the pool to remove us.
                    context.owner.close(node_id).await.ok();
                }

                _ = idle_stream.next() => {
                    // A reference may have been handed out since the notification was
                    // sent, so re-check before reporting idleness.
                    if !counter.is_idle() {
                        continue;
                    };
                    trace!(%node_id, "Idle");
                    // Notify the pool that we are idle.
                    if context.owner.idle(node_id).await.is_err() {
                        // If we can't notify the pool, it is shutting down.
                        break;
                    }
                    // Start the idle timer.
                    idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
                }

                // Idle timeout - request shutdown.
                _ = &mut idle_timer => {
                    trace!(%node_id, "Idle timer expired, requesting shutdown");
                    context.owner.close(node_id).await.ok();
                    // No break here: the loop ends once the pool closes our channel.
                    }
            }
        }

        // Close an established connection, telling the peer whether it was unused
        // ("idle") or still referenced ("drop") at the time.
        if let Ok(connection) = state {
            let reason = if counter.is_idle() { b"idle" } else { b"drop" };
            connection.close(0u32.into(), reason);
        }

        trace!(%node_id, "Connection actor shutting down");
    }
}
288
/// The pool's main actor: routes requests to per-connection actors and tracks
/// which connections exist and which of them are idle.
struct Actor {
    /// Incoming messages from pool handles and connection actors.
    rx: mpsc::Receiver<ActorMessage>,
    /// Sender to the connection actor of each known endpoint.
    connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
    /// State shared with every connection actor.
    context: Arc<Context>,
    /// Idle connections, oldest at the front and most recently idle at the back.
    idle: VecDeque<EndpointId>,
    /// Futures of the running connection actors, driven by [`Actor::run`].
    tasks: FuturesUnordered<future::Boxed<()>>,
}
299
300impl Actor {
301    pub fn new(
302        endpoint: Endpoint,
303        alpn: &[u8],
304        options: Options,
305    ) -> (Self, mpsc::Sender<ActorMessage>) {
306        let (tx, rx) = mpsc::channel(100);
307        (
308            Self {
309                rx,
310                connections: HashMap::new(),
311                idle: VecDeque::new(),
312                context: Arc::new(Context {
313                    options,
314                    alpn: alpn.to_vec(),
315                    endpoint,
316                    owner: ConnectionPool { tx: tx.clone() },
317                }),
318                tasks: FuturesUnordered::new(),
319            },
320            tx,
321        )
322    }
323
324    fn add_idle(&mut self, id: EndpointId) {
325        self.remove_idle(id);
326        self.idle.push_back(id);
327    }
328
329    fn remove_idle(&mut self, id: EndpointId) {
330        self.idle.retain(|&x| x != id);
331    }
332
333    fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
334        self.idle.pop_front()
335    }
336
337    fn remove_connection(&mut self, id: EndpointId) {
338        self.connections.remove(&id);
339        self.remove_idle(id);
340    }
341
342    async fn handle_msg(&mut self, msg: ActorMessage) {
343        match msg {
344            ActorMessage::RequestRef(mut msg) => {
345                let id = msg.id;
346                self.remove_idle(id);
347                if let Some(conn_tx) = self.connections.get(&id) {
349                    if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
350                        msg = e;
351                    } else {
352                        return;
353                    }
354                    self.remove_connection(id);
356                }
357
358                if self.connections.len() >= self.context.options.max_connections {
360                    if let Some(idle) = self.pop_oldest_idle() {
361                        trace!("removing oldest idle connection {}", idle);
363                        self.connections.remove(&idle);
364                    } else {
365                        msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
366                        return;
367                    }
368                }
369                let (conn_tx, conn_rx) = mpsc::channel(100);
370                self.connections.insert(id, conn_tx.clone());
371
372                let context = self.context.clone();
373
374                self.tasks
375                    .push(Box::pin(context.run_connection_actor(id, conn_rx)));
376
377                if conn_tx.send(msg).await.is_err() {
379                    error!(%id, "Failed to send handler to new connection actor");
380                    self.connections.remove(&id);
381                }
382            }
383            ActorMessage::ConnectionIdle { id } => {
384                self.add_idle(id);
385                trace!(%id, "connection idle");
386            }
387            ActorMessage::ConnectionShutdown { id } => {
388                self.remove_connection(id);
390                trace!(%id, "removed connection");
391            }
392        }
393    }
394
395    pub async fn run(mut self) {
396        loop {
397            tokio::select! {
398                biased;
399
400                msg = self.rx.recv() => {
401                    if let Some(msg) = msg {
402                        self.handle_msg(msg).await;
403                    } else {
404                        break;
405                    }
406                }
407
408                _ = self.tasks.next(), if !self.tasks.is_empty() => {}
409            }
410        }
411    }
412}
413
/// A pool of connections to remote endpoints for a single ALPN.
///
/// This is a cheaply cloneable handle: it only holds the sender used to talk
/// to the pool actor.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    // Channel to the pool actor.
    tx: mpsc::Sender<ActorMessage>,
}
419
420impl ConnectionPool {
421    pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
422        let (actor, tx) = Actor::new(endpoint, alpn, options);
423
424        tokio::spawn(actor.run());
426
427        Self { tx }
428    }
429
430    pub async fn get_or_connect(
435        &self,
436        id: EndpointId,
437    ) -> std::result::Result<ConnectionRef, PoolConnectError> {
438        let (tx, rx) = oneshot::channel();
439        self.tx
440            .send(ActorMessage::RequestRef(RequestRef { id, tx }))
441            .await
442            .map_err(|_| PoolConnectError::Shutdown)?;
443        rx.await.map_err(|_| PoolConnectError::Shutdown)?
444    }
445
446    pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
451        self.tx
452            .send(ActorMessage::ConnectionShutdown { id })
453            .await
454            .map_err(|_| ConnectionPoolError::Shutdown)?;
455        Ok(())
456    }
457
458    pub(crate) async fn idle(
462        &self,
463        id: EndpointId,
464    ) -> std::result::Result<(), ConnectionPoolError> {
465        self.tx
466            .send(ActorMessage::ConnectionIdle { id })
467            .await
468            .map_err(|_| ConnectionPoolError::Shutdown)?;
469        Ok(())
470    }
471}
472
/// Shared state behind a [`ConnectionCounter`] and its [`OneConnection`] guards.
#[derive(Debug)]
struct ConnectionCounterInner {
    /// Number of guards currently alive.
    count: AtomicUsize,
    /// Signalled when the last guard is dropped.
    notify: Notify,
}
478
/// Counts how many references to a connection are currently handed out.
///
/// Clones share the same underlying counter.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}
483
484impl ConnectionCounter {
485    fn new() -> Self {
486        Self {
487            inner: Arc::new(ConnectionCounterInner {
488                count: Default::default(),
489                notify: Notify::new(),
490            }),
491        }
492    }
493
494    fn current(&self) -> usize {
495        self.inner.count.load(Ordering::SeqCst)
496    }
497
498    fn get_one(&self) -> OneConnection {
500        self.inner.count.fetch_add(1, Ordering::SeqCst);
501        OneConnection {
502            inner: self.inner.clone(),
503        }
504    }
505
506    fn is_idle(&self) -> bool {
507        self.inner.count.load(Ordering::SeqCst) == 0
508    }
509
510    fn idle_stream(self) -> impl Stream<Item = ()> {
519        n0_future::stream::unfold(self, |c| async move {
520            c.inner.notify.notified().await;
521            Some(((), c))
522        })
523    }
524}
525
/// Guard representing one handed-out reference to a connection.
///
/// Dropping it decrements the shared count it was created from.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}
531
532impl Drop for OneConnection {
533    fn drop(&mut self) {
534        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
535            self.inner.notify.notify_waiters();
536        }
537    }
538}
539
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::{Connection, ConnectionType},
        protocol::{AcceptError, ProtocolHandler, Router},
        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
    };
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use n0_snafu::ResultExt;
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    /// ALPN used by the echo protocol in these tests.
    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that copies every incoming bidirectional stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            // Serve streams until accepting one fails.
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` over a new bidirectional stream and returns the response
    /// (read up to 1000 bytes).
    async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_id().e()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.e()?;
        send.write_all(text).await.e()?;
        send.finish().e()?;
        let response = recv.read_to_end(1000).await.e()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Starts one echo server and returns its address together with its router.
    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers (up to 16 concurrently) and returns their ids,
    /// their routers and a static discovery provider that knows their addresses.
    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_endpoint_info(addrs);
        Ok((ids, routers, discovery))
    }

    /// Shuts down all routers (up to 16 concurrently), ignoring shutdown errors.
    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    /// Pool options for tests: short idle timeout, generous connect timeout.
    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    /// Echo client that gets its connections from a [`ConnectionPool`].
    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Echoes `text` via the endpoint `id`.
        ///
        /// The outer result reports pool errors, the inner one errors of the echo
        /// exchange itself. On success returns the stable id of the connection
        /// that was used together with the response.
        async fn echo(
            &self,
            id: EndpointId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), n0_snafu::Error>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    /// Connecting to unknown or unreachable endpoints yields the expected pool errors.
    #[tokio::test]
    async fn connection_pool_errors() -> TestResult<()> {
        // Discovery starts out empty.
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            // No address information is registered for this id: expect a connect error.
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
            // Register a made-up address for the id.
            discovery.add_endpoint_info(EndpointAddr {
                id: non_listening,
                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
                    .into_iter()
                    .collect(),
            });
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            // With an address registered for the id, the test expects a timeout.
            assert!(matches!(res, Err(PoolConnectError::Timeout)));
        }
        Ok(())
    }

    /// Connections are reused while fresh and replaced after the idle timeout.
    #[tokio::test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // Two requests in quick succession must share one connection.
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // Sleep well past the 100ms idle timeout from `test_options`.
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // After the sleep a different connection must be in use.
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// With more endpoints than `max_connections`, requests still succeed
    /// when issued one after another.
    #[tokio::test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                // Long idle timeout and a limit far below the number of servers.
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// A failing `on_connected` callback surfaces as `OnConnectError`.
    #[tokio::test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        // Callback that always fails.
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// An `on_connected` callback can wait for the connection to become direct.
    #[tokio::test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        // Resolve once the connection type is reported as direct.
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_id() else {
                return Err(io::Error::other("unable to get endpoint id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// After a connection is closed by the user, the pool hands out a new one.
    #[tokio::test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        // Close the connection behind the pool's back, then give it time to notice.
        conn.close(0u32.into(), b"test");
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}