// iroh_blobs/util/connection_pool.rs

1//! A simple iroh connection pool
2//!
3//! Entry point is [`ConnectionPool`]. You create a connection pool for a specific
4//! ALPN and [`Options`]. Then the pool will manage connections for you.
5//!
6//! Access to connections is via the [`ConnectionPool::get_or_connect`] method, which
7//! gives you access to a connection via a [`ConnectionRef`] if possible.
8//!
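//! It is important that you keep the [`ConnectionRef`] alive while you are using
//! the connection. Once the last reference to a connection is dropped, the pool
//! treats that connection as idle: it may close it after
//! [`Options::idle_timeout`] or evict it to make room for new connections.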
11use std::{
12    collections::{HashMap, VecDeque},
13    io,
14    ops::Deref,
15    sync::{
16        atomic::{AtomicUsize, Ordering},
17        Arc,
18    },
19    time::Duration,
20};
21
22use iroh::{
23    endpoint::{ConnectError, Connection},
24    Endpoint, EndpointId,
25};
26use n0_error::{e, stack_error};
27use n0_future::{
28    future::{self},
29    FuturesUnordered, MaybeFuture, Stream, StreamExt,
30};
31use tokio::sync::{
32    mpsc::{self, error::SendError as TokioSendError},
33    oneshot, Notify,
34};
35use tokio_util::time::FutureExt as TimeFutureExt;
36use tracing::{debug, error, info, trace};
37
/// Callback run on a freshly established connection before the pool hands it
/// out.
///
/// It receives the pool's endpoint and the new connection. Returning an error
/// fails the connection attempt with [`PoolConnectError::OnConnectError`]. The
/// time spent in the callback counts towards [`Options::connect_timeout`].
pub type OnConnected =
    Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40
/// Configuration options for the connection pool
#[derive(derive_more::Debug, Clone)]
pub struct Options {
    /// How long to keep idle connections around.
    ///
    /// The timer starts once the last [`ConnectionRef`] for a connection has been
    /// dropped and is cancelled when a new reference is handed out. When it
    /// expires, the connection is removed from the pool and closed.
    pub idle_timeout: Duration,
    /// Timeout for connect. This includes the time spent in on_connected, if set.
    pub connect_timeout: Duration,
    /// Maximum number of connections to hand out.
    ///
    /// When the limit is reached, the least recently idle connection is evicted
    /// to make room. If no connection is idle, the request fails with
    /// [`PoolConnectError::TooManyConnections`].
    pub max_connections: usize,
    /// An optional callback that can be used to wait for the connection to enter some state.
    /// An example usage could be to wait for the connection to become direct before handing
    /// it out to the user.
    #[debug(skip)]
    pub on_connected: Option<OnConnected>,
}
56
57impl Default for Options {
58    fn default() -> Self {
59        Self {
60            idle_timeout: Duration::from_secs(5),
61            connect_timeout: Duration::from_secs(1),
62            max_connections: 1024,
63            on_connected: None,
64        }
65    }
66}
67
68impl Options {
69    /// Set the on_connected callback
70    pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
71    where
72        F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
73        Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
74    {
75        self.on_connected = Some(Arc::new(move |ep, conn| {
76            let ep = ep.clone();
77            let conn = conn.clone();
78            Box::pin(f(ep, conn))
79        }));
80        self
81    }
82}
83
/// A reference to a connection that is owned by a connection pool.
///
/// Dereferences to the underlying [`iroh::endpoint::Connection`]. While any
/// `ConnectionRef` for a connection is alive, the pool does not treat that
/// connection as idle.
#[derive(Debug)]
pub struct ConnectionRef {
    connection: iroh::endpoint::Connection,
    // Guard that keeps this reference counted as "in use" until it is dropped.
    _permit: OneConnection,
}
90
91impl Deref for ConnectionRef {
92    type Target = iroh::endpoint::Connection;
93
94    fn deref(&self) -> &Self::Target {
95        &self.connection
96    }
97}
98
99impl ConnectionRef {
100    fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
101        Self {
102            connection,
103            _permit: counter,
104        }
105    }
106}
107
/// Error when a connection can not be acquired
///
/// This includes the normal iroh connection errors as well as pool specific
/// errors such as timeouts and connection limits.
///
/// The error is `Clone` so that a single failed connection attempt can be
/// reported to every request queued for that endpoint. Wrapped source errors
/// are therefore held in `Arc`s.
#[stack_error(derive, add_meta)]
#[derive(Clone)]
pub enum PoolConnectError {
    /// Connection pool is shut down
    #[error("Connection pool is shut down")]
    Shutdown {},
    /// Timeout during connect
    ///
    /// Connecting, including the `on_connected` callback, took longer than
    /// [`Options::connect_timeout`].
    #[error("Timeout during connect")]
    Timeout {},
    /// Too many connections
    ///
    /// [`Options::max_connections`] was reached and no idle connection could be
    /// evicted to make room.
    #[error("Too many connections")]
    TooManyConnections {},
    /// Error during connect
    #[error(transparent)]
    ConnectError { source: Arc<ConnectError> },
    /// Error during on_connect callback
    #[error(transparent)]
    OnConnectError {
        #[error(std_err)]
        source: Arc<io::Error>,
    },
}
134
135impl From<ConnectError> for PoolConnectError {
136    fn from(e: ConnectError) -> Self {
137        e!(PoolConnectError::ConnectError, Arc::new(e))
138    }
139}
140
141impl From<io::Error> for PoolConnectError {
142    fn from(e: io::Error) -> Self {
143        e!(PoolConnectError::OnConnectError, Arc::new(e))
144    }
145}
146
/// Error when calling a fn on the [`ConnectionPool`].
///
/// The only thing that can go wrong is that the connection pool is shut down,
/// i.e. its actor can no longer receive messages.
#[stack_error(derive, add_meta)]
pub enum ConnectionPoolError {
    /// The connection pool has been shut down
    #[error("The connection pool has been shut down")]
    Shutdown {},
}
156
/// Messages processed by the pool actor.
enum ActorMessage {
    /// A caller wants a [`ConnectionRef`] for an endpoint.
    RequestRef(RequestRef),
    /// A connection actor reports that no [`ConnectionRef`]s are outstanding.
    ConnectionIdle { id: EndpointId },
    /// Remove the connection actor for `id`. Dropping its request sender makes
    /// the connection actor exit.
    ConnectionShutdown { id: EndpointId },
}
162
/// A request for a connection to endpoint `id`; the outcome is sent on `tx`.
struct RequestRef {
    id: EndpointId,
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
167
/// State shared (behind an `Arc`) by the pool actor and all connection actors.
struct Context {
    options: Options,
    endpoint: Endpoint,
    // Handle back to the pool actor, used by connection actors to report idle
    // and shutdown events.
    owner: ConnectionPool,
    alpn: Vec<u8>,
}
174
impl Context {
    /// Per-endpoint connection actor.
    ///
    /// It makes a single connection attempt to `node_id`, including the optional
    /// `on_connected` callback, all bounded by `connect_timeout`. Every
    /// [`RequestRef`] arriving on `rx` is then answered from that one result:
    /// either a new [`ConnectionRef`] to the shared connection or a clone of the
    /// connect error.
    ///
    /// It reports back to the pool through `context.owner`:
    /// - `close`: when connecting failed, when the connection was closed, or when
    ///   the idle timer expired;
    /// - `idle`: when the last outstanding [`ConnectionRef`] was dropped.
    ///
    /// The loop ends when `rx` is closed (the pool dropped our sender) or when the
    /// pool can no longer be reached. On exit an established connection is closed
    /// with reason `idle` or `drop`, depending on whether references are still
    /// outstanding.
    async fn run_connection_actor(
        self: Arc<Self>,
        node_id: EndpointId,
        mut rx: mpsc::Receiver<RequestRef>,
    ) {
        let context = self;

        // Connecting and the on_connected hook form one future, so the timeout
        // below covers both.
        let conn_fut = {
            let context = context.clone();
            async move {
                let conn = context
                    .endpoint
                    .connect(node_id, &context.alpn)
                    .await
                    .map_err(PoolConnectError::from)?;
                if let Some(on_connect) = &context.options.on_connected {
                    on_connect(&context.endpoint, &conn)
                        .await
                        .map_err(PoolConnectError::from)?;
                }
                Result::<Connection, PoolConnectError>::Ok(conn)
            }
        };

        // Connect to the node
        let state = conn_fut
            .timeout(context.options.connect_timeout)
            .await
            .map_err(|_| e!(PoolConnectError::Timeout))
            // flatten Result<Result<Connection, _>, _> into Result<Connection, _>
            .and_then(|r| r);
        // Completes when an established connection gets closed. On connect failure
        // there is nothing to watch, so it is `MaybeFuture::None`.
        let conn_close = match &state {
            Ok(conn) => {
                let conn = conn.clone();
                MaybeFuture::Some(async move { conn.closed().await })
            }
            Err(e) => {
                debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
                // Ask the pool to remove us, but keep running: queued requests
                // still get the error, and the loop ends once the pool closes
                // our channel.
                if context.owner.close(node_id).await.is_err() {
                    return;
                }
                MaybeFuture::None
            }
        };

        // Counts outstanding ConnectionRefs; its idle stream fires when the
        // last one is dropped (not initially).
        let counter = ConnectionCounter::new();
        // Idle timer starts unset; it is armed once we report idle.
        let idle_timer = MaybeFuture::default();
        let idle_stream = counter.clone().idle_stream();

        tokio::pin!(idle_timer, idle_stream, conn_close);

        loop {
            tokio::select! {
                // Poll branches in the order written, so queued requests are
                // served before close/idle handling.
                biased;

                // Handle new work
                handler = rx.recv() => {
                    match handler {
                        Some(RequestRef { id, tx }) => {
                            // The pool routes requests per endpoint id.
                            assert!(id == node_id, "Not for me!");
                            match &state {
                                Ok(state) => {
                                    let res = ConnectionRef::new(state.clone(), counter.get_one());
                                    info!(%node_id, "Handing out ConnectionRef {}", counter.current());

                                    // clear the idle timer
                                    idle_timer.as_mut().set_none();
                                    // The requester may have gone away; ignore that.
                                    tx.send(Ok(res)).ok();
                                }
                                Err(cause) => {
                                    tx.send(Err(cause.clone())).ok();
                                }
                            }
                        }
                        None => {
                            // Channel closed - exit
                            break;
                        }
                    }
                }

                _ = &mut conn_close => {
                    // connection was closed by somebody, notify owner that we should be removed
                    context.owner.close(node_id).await.ok();
                }

                _ = idle_stream.next() => {
                    // The notification is only a hint: a new ref may have been
                    // handed out since, so re-check the count.
                    if !counter.is_idle() {
                        continue;
                    };
                    // notify the pool that we are idle.
                    trace!(%node_id, "Idle");
                    if context.owner.idle(node_id).await.is_err() {
                        // If we can't notify the pool, we are shutting down
                        break;
                    }
                    // set the idle timer
                    idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
                }

                // Idle timeout - request shutdown
                _ = &mut idle_timer => {
                    trace!(%node_id, "Idle timer expired, requesting shutdown");
                    context.owner.close(node_id).await.ok();
                    // Don't break here - wait for main actor to close our channel
                }
            }
        }

        // Close the connection ourselves; the reason tells the peer whether we
        // were idle or were dropped while references were still outstanding.
        if let Ok(connection) = state {
            let reason = if counter.is_idle() { b"idle" } else { b"drop" };
            connection.close(0u32.into(), reason);
        }

        trace!(%node_id, "Connection actor shutting down");
    }
}
292
/// The pool actor: routes requests to per-endpoint connection actors and
/// enforces the connection limit.
struct Actor {
    // Incoming messages from pool handles and from connection actors.
    rx: mpsc::Receiver<ActorMessage>,
    // Request channel of the connection actor for each endpoint.
    connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
    context: Arc<Context>,
    // idle set (most recent last)
    // todo: use a better data structure if this becomes a performance issue
    idle: VecDeque<EndpointId>,
    // per connection tasks; these are polled by this actor, not spawned
    tasks: FuturesUnordered<future::Boxed<()>>,
}
303
304impl Actor {
305    pub fn new(
306        endpoint: Endpoint,
307        alpn: &[u8],
308        options: Options,
309    ) -> (Self, mpsc::Sender<ActorMessage>) {
310        let (tx, rx) = mpsc::channel(100);
311        (
312            Self {
313                rx,
314                connections: HashMap::new(),
315                idle: VecDeque::new(),
316                context: Arc::new(Context {
317                    options,
318                    alpn: alpn.to_vec(),
319                    endpoint,
320                    owner: ConnectionPool { tx: tx.clone() },
321                }),
322                tasks: FuturesUnordered::new(),
323            },
324            tx,
325        )
326    }
327
328    fn add_idle(&mut self, id: EndpointId) {
329        self.remove_idle(id);
330        self.idle.push_back(id);
331    }
332
333    fn remove_idle(&mut self, id: EndpointId) {
334        self.idle.retain(|&x| x != id);
335    }
336
337    fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
338        self.idle.pop_front()
339    }
340
341    fn remove_connection(&mut self, id: EndpointId) {
342        self.connections.remove(&id);
343        self.remove_idle(id);
344    }
345
346    async fn handle_msg(&mut self, msg: ActorMessage) {
347        match msg {
348            ActorMessage::RequestRef(mut msg) => {
349                let id = msg.id;
350                self.remove_idle(id);
351                // Try to send to existing connection actor
352                if let Some(conn_tx) = self.connections.get(&id) {
353                    if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
354                        msg = e;
355                    } else {
356                        return;
357                    }
358                    // Connection actor died, remove it
359                    self.remove_connection(id);
360                }
361
362                // No connection actor or it died - check limits
363                if self.connections.len() >= self.context.options.max_connections {
364                    if let Some(idle) = self.pop_oldest_idle() {
365                        // remove the oldest idle connection to make room for one more
366                        trace!("removing oldest idle connection {}", idle);
367                        self.connections.remove(&idle);
368                    } else {
369                        msg.tx
370                            .send(Err(e!(PoolConnectError::TooManyConnections)))
371                            .ok();
372                        return;
373                    }
374                }
375                let (conn_tx, conn_rx) = mpsc::channel(100);
376                self.connections.insert(id, conn_tx.clone());
377
378                let context = self.context.clone();
379
380                self.tasks
381                    .push(Box::pin(context.run_connection_actor(id, conn_rx)));
382
383                // Send the handler to the new actor
384                if conn_tx.send(msg).await.is_err() {
385                    error!(%id, "Failed to send handler to new connection actor");
386                    self.connections.remove(&id);
387                }
388            }
389            ActorMessage::ConnectionIdle { id } => {
390                self.add_idle(id);
391                trace!(%id, "connection idle");
392            }
393            ActorMessage::ConnectionShutdown { id } => {
394                // Remove the connection from our map - this closes the channel
395                self.remove_connection(id);
396                trace!(%id, "removed connection");
397            }
398        }
399    }
400
401    pub async fn run(mut self) {
402        loop {
403            tokio::select! {
404                biased;
405
406                msg = self.rx.recv() => {
407                    if let Some(msg) = msg {
408                        self.handle_msg(msg).await;
409                    } else {
410                        break;
411                    }
412                }
413
414                _ = self.tasks.next(), if !self.tasks.is_empty() => {}
415            }
416        }
417    }
418}
419
/// A connection pool
///
/// This is a cheap-to-clone handle; all clones send their messages to the same
/// background actor.
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    tx: mpsc::Sender<ActorMessage>,
}
425
426impl ConnectionPool {
427    pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
428        let (actor, tx) = Actor::new(endpoint, alpn, options);
429
430        // Spawn the main actor
431        tokio::spawn(actor.run());
432
433        Self { tx }
434    }
435
436    /// Returns either a fresh connection or a reference to an existing one.
437    ///
438    /// This is guaranteed to return after approximately [Options::connect_timeout]
439    /// with either an error or a connection.
440    pub async fn get_or_connect(
441        &self,
442        id: EndpointId,
443    ) -> std::result::Result<ConnectionRef, PoolConnectError> {
444        let (tx, rx) = oneshot::channel();
445        self.tx
446            .send(ActorMessage::RequestRef(RequestRef { id, tx }))
447            .await
448            .map_err(|_| e!(PoolConnectError::Shutdown))?;
449        rx.await.map_err(|_| e!(PoolConnectError::Shutdown))?
450    }
451
452    /// Close an existing connection, if it exists
453    ///
454    /// This will finish pending tasks and close the connection. New tasks will
455    /// get a new connection if they are submitted after this call
456    pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
457        self.tx
458            .send(ActorMessage::ConnectionShutdown { id })
459            .await
460            .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
461        Ok(())
462    }
463
464    /// Notify the connection pool that a connection is idle.
465    ///
466    /// Should only be called from connection handlers.
467    pub(crate) async fn idle(
468        &self,
469        id: EndpointId,
470    ) -> std::result::Result<(), ConnectionPoolError> {
471        self.tx
472            .send(ActorMessage::ConnectionIdle { id })
473            .await
474            .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
475        Ok(())
476    }
477}
478
/// Shared state behind a [`ConnectionCounter`].
#[derive(Debug)]
struct ConnectionCounterInner {
    // Number of outstanding `OneConnection` guards.
    count: AtomicUsize,
    // Signalled when `count` drops back to zero.
    notify: Notify,
}
484
/// Counts the outstanding references to one pooled connection.
///
/// Clones share the same underlying count.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}
489
490impl ConnectionCounter {
491    fn new() -> Self {
492        Self {
493            inner: Arc::new(ConnectionCounterInner {
494                count: Default::default(),
495                notify: Notify::new(),
496            }),
497        }
498    }
499
500    fn current(&self) -> usize {
501        self.inner.count.load(Ordering::SeqCst)
502    }
503
504    /// Increase the connection count and return a guard for the new connection
505    fn get_one(&self) -> OneConnection {
506        self.inner.count.fetch_add(1, Ordering::SeqCst);
507        OneConnection {
508            inner: self.inner.clone(),
509        }
510    }
511
512    fn is_idle(&self) -> bool {
513        self.inner.count.load(Ordering::SeqCst) == 0
514    }
515
516    /// Infinite stream that yields when the connection is briefly idle.
517    ///
518    /// Note that you still have to check if the connection is still idle when
519    /// you get the notification.
520    ///
521    /// Also note that this stream is triggered on [OneConnection::drop], so it
522    /// won't trigger initially even though a [ConnectionCounter] starts up as
523    /// idle.
524    fn idle_stream(self) -> impl Stream<Item = ()> {
525        n0_future::stream::unfold(self, |c| async move {
526            c.inner.notify.notified().await;
527            Some(((), c))
528        })
529    }
530}
531
/// Guard for one connection
///
/// Holds one unit of a [`ConnectionCounter`]'s count; dropping the guard
/// releases that unit again.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}
537
538impl Drop for OneConnection {
539    fn drop(&mut self) {
540        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
541            self.inner.notify.notify_waiters();
542        }
543    }
544}
545
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::Duration};

    use iroh::{
        discovery::static_provider::StaticProvider,
        endpoint::{Connection, ConnectionType},
        protocol::{AcceptError, ProtocolHandler, Router},
        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
    };
    use n0_error::{AnyError, Result, StdResultExt};
    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
    use testresult::TestResult;
    use tracing::trace;

    use super::{ConnectionPool, Options, PoolConnectError};
    use crate::util::connection_pool::OnConnected;

    const ECHO_ALPN: &[u8] = b"echo";

    /// Protocol handler that echoes every bidirectional stream back to the sender.
    #[derive(Debug, Clone)]
    struct Echo;

    impl ProtocolHandler for Echo {
        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
            let conn_id = connection.stable_id();
            let id = connection.remote_id().map_err(AcceptError::from_err)?;
            trace!(%id, %conn_id, "Accepting echo connection");
            loop {
                match connection.accept_bi().await {
                    Ok((mut send, mut recv)) => {
                        trace!(%id, %conn_id, "Accepted echo request");
                        tokio::io::copy(&mut recv, &mut send).await?;
                        send.finish().map_err(AcceptError::from_err)?;
                    }
                    Err(e) => {
                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
                        break;
                    }
                }
            }
            Ok(())
        }
    }

    /// Sends `text` on a new bi-stream of `conn` and returns the echoed bytes.
    async fn echo_client(conn: &Connection, text: &[u8]) -> Result<Vec<u8>> {
        let conn_id = conn.stable_id();
        let id = conn.remote_id().anyerr()?;
        trace!(%id, %conn_id, "Sending echo request");
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        send.write_all(text).await.anyerr()?;
        send.finish().anyerr()?;
        let response = recv.read_to_end(1000).await.anyerr()?;
        trace!(%id, %conn_id, "Received echo response");
        Ok(response)
    }

    /// Starts one echo server and returns its address and router.
    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
        let endpoint = iroh::Endpoint::builder()
            .alpns(vec![ECHO_ALPN.to_vec()])
            .bind()
            .await?;
        endpoint.online().await;
        let addr = endpoint.addr();
        let router = iroh::protocol::Router::builder(endpoint)
            .accept(ECHO_ALPN, Echo)
            .spawn();

        Ok((addr, router))
    }

    /// Starts `n` echo servers and a static discovery provider that knows them all.
    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
        let res = stream::iter(0..n)
            .map(|_| echo_server())
            .buffered_unordered(16)
            .collect::<Vec<_>>()
            .await;
        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
        let discovery = StaticProvider::from_endpoint_info(addrs);
        Ok((ids, routers, discovery))
    }

    async fn shutdown_routers(routers: Vec<Router>) {
        stream::iter(routers)
            .for_each_concurrent(16, |router| async move {
                let _ = router.shutdown().await;
            })
            .await;
    }

    fn test_options() -> Options {
        Options {
            idle_timeout: Duration::from_millis(100),
            connect_timeout: Duration::from_secs(5),
            max_connections: 32,
            on_connected: None,
        }
    }

    struct EchoClient {
        pool: ConnectionPool,
    }

    impl EchoClient {
        /// Echoes `text` via a pooled connection to `id`.
        ///
        /// The outer error is the pool error; the inner result is the echo
        /// outcome together with the stable id of the connection used.
        async fn echo(
            &self,
            id: EndpointId,
            text: Vec<u8>,
        ) -> Result<Result<(usize, Vec<u8>), AnyError>, PoolConnectError> {
            let conn = self.pool.get_or_connect(id).await?;
            let id = conn.stable_id();
            match echo_client(&conn, &text).await {
                Ok(res) => Ok(Ok((id, res))),
                Err(e) => Ok(Err(e)),
            }
        }
    }

    #[tokio::test]
    // #[traced_test]
    async fn connection_pool_errors() -> TestResult<()> {
        // set up static discovery for all addrs
        let discovery = StaticProvider::new();
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        {
            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
            // trying to connect to a non-existing id will fail with ConnectError
            // because we don't have any information about the endpoint.
            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
        }
        {
            // Use a different key than above, so this case does not depend on
            // the pool having already discarded the failed connection actor for
            // `non_existing` (which would otherwise answer with its cached error).
            let non_listening = SecretKey::from_bytes(&[1; 32]).public();
            // make up fake node info
            discovery.add_endpoint_info(EndpointAddr {
                id: non_listening,
                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
                    .into_iter()
                    .collect(),
            });
            // trying to connect to an id for which we have info, but the other
            // end is not listening, will lead to a timeout.
            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
            assert!(matches!(res, Err(PoolConnectError::Timeout { .. })));
        }
        Ok(())
    }

    #[tokio::test]
    // #[traced_test]
    async fn connection_pool_smoke() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        // build a client endpoint that can resolve all the endpoint ids
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
        let client = EchoClient { pool };
        let mut connection_ids = BTreeMap::new();
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (cid1, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            // back-to-back requests reuse the pooled connection
            assert_eq!(cid1, cid2);
            connection_ids.insert(id, cid1);
        }
        // wait well past idle_timeout so all connections get closed
        tokio::time::sleep(Duration::from_millis(1000)).await;
        for id in &ids {
            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
            let (cid2, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
            assert_ne!(cid1, cid2);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Tests that idle connections are being reclaimed to make room if we hit the
    /// maximum connection limit.
    #[tokio::test]
    // #[traced_test]
    async fn connection_pool_idle() -> TestResult<()> {
        let n = 32;
        let (ids, routers, discovery) = echo_servers(n).await?;
        // build a client endpoint that can resolve all the endpoint ids
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery.clone())
            .bind()
            .await?;
        let pool = ConnectionPool::new(
            endpoint.clone(),
            ECHO_ALPN,
            Options {
                idle_timeout: Duration::from_secs(100),
                max_connections: 8,
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let (_, res) = client.echo(*id, msg.clone()).await??;
            assert_eq!(res, msg);
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Uses an on_connected callback that just errors out every time.
    ///
    /// This is a basic smoke test that on_connected gets called at all.
    #[tokio::test]
    // #[traced_test]
    async fn on_connected_error() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected: OnConnected =
            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            Options {
                on_connected: Some(on_connected),
                ..test_options()
            },
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Uses an on_connected callback to ensure that the connection is direct.
    #[tokio::test]
    // #[traced_test]
    async fn on_connected_direct() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;
        let on_connected = |ep: Endpoint, conn: Connection| async move {
            let Ok(id) = conn.remote_id() else {
                return Err(io::Error::other("unable to get endpoint id"));
            };
            let Some(watcher) = ep.conn_type(id) else {
                return Err(io::Error::other("unable to get conn_type watcher"));
            };
            let mut stream = watcher.stream();
            while let Some(status) = stream.next().await {
                if let ConnectionType::Direct { .. } = status {
                    return Ok(());
                }
            }
            Err(io::Error::other("connection closed before becoming direct"))
        };
        let pool = ConnectionPool::new(
            endpoint,
            ECHO_ALPN,
            test_options().with_on_connected(on_connected),
        );
        let client = EchoClient { pool };
        let msg = b"Hello, pool!".to_vec();
        for id in &ids {
            let res = client.echo(*id, msg.clone()).await;
            assert!(res.is_ok());
        }
        shutdown_routers(routers).await;
        Ok(())
    }

    /// Check that when a connection is closed, the pool will give you a new
    /// connection next time you want one.
    ///
    /// This test fails if the connection watch is disabled.
    #[tokio::test]
    // #[traced_test]
    async fn watch_close() -> TestResult<()> {
        let n = 1;
        let (ids, routers, discovery) = echo_servers(n).await?;
        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
            .discovery(discovery)
            .bind()
            .await?;

        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid1 = conn.stable_id();
        conn.close(0u32.into(), b"test");
        tokio::time::sleep(Duration::from_millis(500)).await;
        let conn = pool.get_or_connect(ids[0]).await?;
        let cid2 = conn.stable_id();
        assert_ne!(cid1, cid2);
        shutdown_routers(routers).await;
        Ok(())
    }
}