iroh_blobs/util/
connection_pool.rs

1//! A simple iroh connection pool
2//!
3//! Entry point is [`ConnectionPool`]. You create a connection pool for a specific
4//! ALPN and [`Options`]. Then the pool will manage connections for you.
5//!
6//! Access to connections is via the [`ConnectionPool::get_or_connect`] method, which
7//! gives you access to a connection via a [`ConnectionRef`] if possible.
8//!
9//! It is important that you keep the [`ConnectionRef`] alive while you are using
10//! the connection.
11use std::{
12    collections::{HashMap, VecDeque},
13    io,
14    ops::Deref,
15    sync::{
16        atomic::{AtomicUsize, Ordering},
17        Arc,
18    },
19    time::Duration,
20};
21
22use iroh::{
23    endpoint::{ConnectError, Connection},
24    Endpoint, EndpointId,
25};
26use n0_error::{e, stack_error};
27use n0_future::{
28    future::{self},
29    FuturesUnordered, MaybeFuture, Stream, StreamExt,
30};
31use tokio::sync::{
32    mpsc::{self, error::SendError as TokioSendError},
33    oneshot, Notify,
34};
35use tracing::{debug, error, info, trace};
36
37pub type OnConnected =
38    Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
39
40/// Configuration options for the connection pool
41#[derive(derive_more::Debug, Clone)]
42pub struct Options {
43    /// How long to keep idle connections around.
44    pub idle_timeout: Duration,
45    /// Timeout for connect. This includes the time spent in on_connect, if set.
46    pub connect_timeout: Duration,
47    /// Maximum number of connections to hand out.
48    pub max_connections: usize,
49    /// An optional callback that can be used to wait for the connection to enter some state.
50    /// An example usage could be to wait for the connection to become direct before handing
51    /// it out to the user.
52    #[debug(skip)]
53    pub on_connected: Option<OnConnected>,
54}
55
56impl Default for Options {
57    fn default() -> Self {
58        Self {
59            idle_timeout: Duration::from_secs(5),
60            connect_timeout: Duration::from_secs(1),
61            max_connections: 1024,
62            on_connected: None,
63        }
64    }
65}
66
67impl Options {
68    /// Set the on_connected callback
69    pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
70    where
71        F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
72        Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
73    {
74        self.on_connected = Some(Arc::new(move |ep, conn| {
75            let ep = ep.clone();
76            let conn = conn.clone();
77            Box::pin(f(ep, conn))
78        }));
79        self
80    }
81}
82
83/// A reference to a connection that is owned by a connection pool.
84#[derive(Debug)]
85pub struct ConnectionRef {
86    connection: iroh::endpoint::Connection,
87    _permit: OneConnection,
88}
89
90impl Deref for ConnectionRef {
91    type Target = iroh::endpoint::Connection;
92
93    fn deref(&self) -> &Self::Target {
94        &self.connection
95    }
96}
97
98impl ConnectionRef {
99    fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
100        Self {
101            connection,
102            _permit: counter,
103        }
104    }
105}
106
107/// Error when a connection can not be acquired
108///
109/// This includes the normal iroh connection errors as well as pool specific
110/// errors such as timeouts and connection limits.
111#[stack_error(derive, add_meta)]
112#[derive(Clone)]
113pub enum PoolConnectError {
114    /// Connection pool is shut down
115    #[error("Connection pool is shut down")]
116    Shutdown {},
117    /// Timeout during connect
118    #[error("Timeout during connect")]
119    Timeout {},
120    /// Too many connections
121    #[error("Too many connections")]
122    TooManyConnections {},
123    /// Error during connect
124    #[error(transparent)]
125    ConnectError { source: Arc<ConnectError> },
126    /// Error during on_connect callback
127    #[error(transparent)]
128    OnConnectError {
129        #[error(std_err)]
130        source: Arc<io::Error>,
131    },
132}
133
134impl From<ConnectError> for PoolConnectError {
135    fn from(e: ConnectError) -> Self {
136        e!(PoolConnectError::ConnectError, Arc::new(e))
137    }
138}
139
140impl From<io::Error> for PoolConnectError {
141    fn from(e: io::Error) -> Self {
142        e!(PoolConnectError::OnConnectError, Arc::new(e))
143    }
144}
145
146/// Error when calling a fn on the [`ConnectionPool`].
147///
148/// The only thing that can go wrong is that the connection pool is shut down.
149#[stack_error(derive, add_meta)]
150pub enum ConnectionPoolError {
151    /// The connection pool has been shut down
152    #[error("The connection pool has been shut down")]
153    Shutdown {},
154}
155
156enum ActorMessage {
157    RequestRef(RequestRef),
158    ConnectionIdle { id: EndpointId },
159    ConnectionShutdown { id: EndpointId },
160}
161
162struct RequestRef {
163    id: EndpointId,
164    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
165}
166
167struct Context {
168    options: Options,
169    endpoint: Endpoint,
170    owner: ConnectionPool,
171    alpn: Vec<u8>,
172}
173
174impl Context {
175    async fn run_connection_actor(
176        self: Arc<Self>,
177        node_id: EndpointId,
178        mut rx: mpsc::Receiver<RequestRef>,
179    ) {
180        let context = self;
181
182        let conn_fut = {
183            let context = context.clone();
184            async move {
185                let conn = context
186                    .endpoint
187                    .connect(node_id, &context.alpn)
188                    .await
189                    .map_err(PoolConnectError::from)?;
190                if let Some(on_connect) = &context.options.on_connected {
191                    on_connect(&context.endpoint, &conn)
192                        .await
193                        .map_err(PoolConnectError::from)?;
194                }
195                Result::<Connection, PoolConnectError>::Ok(conn)
196            }
197        };
198
199        // Connect to the node
200        let state = n0_future::time::timeout(context.options.connect_timeout, conn_fut)
201            .await
202            .map_err(|_| e!(PoolConnectError::Timeout))
203            .and_then(|r| r);
204        let conn_close = match &state {
205            Ok(conn) => {
206                let conn = conn.clone();
207                MaybeFuture::Some(async move { conn.closed().await })
208            }
209            Err(e) => {
210                debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
211                if context.owner.close(node_id).await.is_err() {
212                    return;
213                }
214                MaybeFuture::None
215            }
216        };
217
218        let counter = ConnectionCounter::new();
219        let idle_timer = MaybeFuture::default();
220        let idle_stream = counter.clone().idle_stream();
221
222        tokio::pin!(idle_timer, idle_stream, conn_close);
223
224        loop {
225            tokio::select! {
226                biased;
227
228                // Handle new work
229                handler = rx.recv() => {
230                    match handler {
231                        Some(RequestRef { id, tx }) => {
232                            assert!(id == node_id, "Not for me!");
233                            match &state {
234                                Ok(state) => {
235                                    let res = ConnectionRef::new(state.clone(), counter.get_one());
236                                    info!(%node_id, "Handing out ConnectionRef {}", counter.current());
237
238                                    // clear the idle timer
239                                    idle_timer.as_mut().set_none();
240                                    tx.send(Ok(res)).ok();
241                                }
242                                Err(cause) => {
243                                    tx.send(Err(cause.clone())).ok();
244                                }
245                            }
246                        }
247                        None => {
248                            // Channel closed - exit
249                            break;
250                        }
251                    }
252                }
253
254                _ = &mut conn_close => {
255                    // connection was closed by somebody, notify owner that we should be removed
256                    context.owner.close(node_id).await.ok();
257                }
258
259                _ = idle_stream.next() => {
260                    if !counter.is_idle() {
261                        continue;
262                    };
263                    // notify the pool that we are idle.
264                    trace!(%node_id, "Idle");
265                    if context.owner.idle(node_id).await.is_err() {
266                        // If we can't notify the pool, we are shutting down
267                        break;
268                    }
269                    // set the idle timer
270                    idle_timer.as_mut().set_future(n0_future::time::sleep(context.options.idle_timeout));
271                }
272
273                // Idle timeout - request shutdown
274                _ = &mut idle_timer => {
275                    trace!(%node_id, "Idle timer expired, requesting shutdown");
276                    context.owner.close(node_id).await.ok();
277                    // Don't break here - wait for main actor to close our channel
278                }
279            }
280        }
281
282        if let Ok(connection) = state {
283            let reason = if counter.is_idle() { b"idle" } else { b"drop" };
284            connection.close(0u32.into(), reason);
285        }
286
287        trace!(%node_id, "Connection actor shutting down");
288    }
289}
290
291struct Actor {
292    rx: mpsc::Receiver<ActorMessage>,
293    connections: HashMap<EndpointId, mpsc::Sender<RequestRef>>,
294    context: Arc<Context>,
295    // idle set (most recent last)
296    // todo: use a better data structure if this becomes a performance issue
297    idle: VecDeque<EndpointId>,
298    // per connection tasks
299    tasks: FuturesUnordered<future::Boxed<()>>,
300}
301
302impl Actor {
303    pub fn new(
304        endpoint: Endpoint,
305        alpn: &[u8],
306        options: Options,
307    ) -> (Self, mpsc::Sender<ActorMessage>) {
308        let (tx, rx) = mpsc::channel(100);
309        (
310            Self {
311                rx,
312                connections: HashMap::new(),
313                idle: VecDeque::new(),
314                context: Arc::new(Context {
315                    options,
316                    alpn: alpn.to_vec(),
317                    endpoint,
318                    owner: ConnectionPool { tx: tx.clone() },
319                }),
320                tasks: FuturesUnordered::new(),
321            },
322            tx,
323        )
324    }
325
326    fn add_idle(&mut self, id: EndpointId) {
327        self.remove_idle(id);
328        self.idle.push_back(id);
329    }
330
331    fn remove_idle(&mut self, id: EndpointId) {
332        self.idle.retain(|&x| x != id);
333    }
334
335    fn pop_oldest_idle(&mut self) -> Option<EndpointId> {
336        self.idle.pop_front()
337    }
338
339    fn remove_connection(&mut self, id: EndpointId) {
340        self.connections.remove(&id);
341        self.remove_idle(id);
342    }
343
344    async fn handle_msg(&mut self, msg: ActorMessage) {
345        match msg {
346            ActorMessage::RequestRef(mut msg) => {
347                let id = msg.id;
348                self.remove_idle(id);
349                // Try to send to existing connection actor
350                if let Some(conn_tx) = self.connections.get(&id) {
351                    if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
352                        msg = e;
353                    } else {
354                        return;
355                    }
356                    // Connection actor died, remove it
357                    self.remove_connection(id);
358                }
359
360                // No connection actor or it died - check limits
361                if self.connections.len() >= self.context.options.max_connections {
362                    if let Some(idle) = self.pop_oldest_idle() {
363                        // remove the oldest idle connection to make room for one more
364                        trace!("removing oldest idle connection {}", idle);
365                        self.connections.remove(&idle);
366                    } else {
367                        msg.tx
368                            .send(Err(e!(PoolConnectError::TooManyConnections)))
369                            .ok();
370                        return;
371                    }
372                }
373                let (conn_tx, conn_rx) = mpsc::channel(100);
374                self.connections.insert(id, conn_tx.clone());
375
376                let context = self.context.clone();
377
378                self.tasks
379                    .push(Box::pin(context.run_connection_actor(id, conn_rx)));
380
381                // Send the handler to the new actor
382                if conn_tx.send(msg).await.is_err() {
383                    error!(%id, "Failed to send handler to new connection actor");
384                    self.connections.remove(&id);
385                }
386            }
387            ActorMessage::ConnectionIdle { id } => {
388                self.add_idle(id);
389                trace!(%id, "connection idle");
390            }
391            ActorMessage::ConnectionShutdown { id } => {
392                // Remove the connection from our map - this closes the channel
393                self.remove_connection(id);
394                trace!(%id, "removed connection");
395            }
396        }
397    }
398
399    pub async fn run(mut self) {
400        loop {
401            tokio::select! {
402                biased;
403
404                msg = self.rx.recv() => {
405                    if let Some(msg) = msg {
406                        self.handle_msg(msg).await;
407                    } else {
408                        break;
409                    }
410                }
411
412                _ = self.tasks.next(), if !self.tasks.is_empty() => {}
413            }
414        }
415    }
416}
417
418/// A connection pool
419#[derive(Debug, Clone)]
420pub struct ConnectionPool {
421    tx: mpsc::Sender<ActorMessage>,
422}
423
424impl ConnectionPool {
425    pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
426        let (actor, tx) = Actor::new(endpoint, alpn, options);
427
428        // Spawn the main actor
429        n0_future::task::spawn(actor.run());
430
431        Self { tx }
432    }
433
434    /// Returns either a fresh connection or a reference to an existing one.
435    ///
436    /// This is guaranteed to return after approximately [Options::connect_timeout]
437    /// with either an error or a connection.
438    pub async fn get_or_connect(
439        &self,
440        id: EndpointId,
441    ) -> std::result::Result<ConnectionRef, PoolConnectError> {
442        let (tx, rx) = oneshot::channel();
443        self.tx
444            .send(ActorMessage::RequestRef(RequestRef { id, tx }))
445            .await
446            .map_err(|_| e!(PoolConnectError::Shutdown))?;
447        rx.await.map_err(|_| e!(PoolConnectError::Shutdown))?
448    }
449
450    /// Close an existing connection, if it exists
451    ///
452    /// This will finish pending tasks and close the connection. New tasks will
453    /// get a new connection if they are submitted after this call
454    pub async fn close(&self, id: EndpointId) -> std::result::Result<(), ConnectionPoolError> {
455        self.tx
456            .send(ActorMessage::ConnectionShutdown { id })
457            .await
458            .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
459        Ok(())
460    }
461
462    /// Notify the connection pool that a connection is idle.
463    ///
464    /// Should only be called from connection handlers.
465    pub(crate) async fn idle(
466        &self,
467        id: EndpointId,
468    ) -> std::result::Result<(), ConnectionPoolError> {
469        self.tx
470            .send(ActorMessage::ConnectionIdle { id })
471            .await
472            .map_err(|_| e!(ConnectionPoolError::Shutdown))?;
473        Ok(())
474    }
475}
476
477#[derive(Debug)]
478struct ConnectionCounterInner {
479    count: AtomicUsize,
480    notify: Notify,
481}
482
483#[derive(Debug, Clone)]
484struct ConnectionCounter {
485    inner: Arc<ConnectionCounterInner>,
486}
487
488impl ConnectionCounter {
489    fn new() -> Self {
490        Self {
491            inner: Arc::new(ConnectionCounterInner {
492                count: Default::default(),
493                notify: Notify::new(),
494            }),
495        }
496    }
497
498    fn current(&self) -> usize {
499        self.inner.count.load(Ordering::SeqCst)
500    }
501
502    /// Increase the connection count and return a guard for the new connection
503    fn get_one(&self) -> OneConnection {
504        self.inner.count.fetch_add(1, Ordering::SeqCst);
505        OneConnection {
506            inner: self.inner.clone(),
507        }
508    }
509
510    fn is_idle(&self) -> bool {
511        self.inner.count.load(Ordering::SeqCst) == 0
512    }
513
514    /// Infinite stream that yields when the connection is briefly idle.
515    ///
516    /// Note that you still have to check if the connection is still idle when
517    /// you get the notification.
518    ///
519    /// Also note that this stream is triggered on [OneConnection::drop], so it
520    /// won't trigger initially even though a [ConnectionCounter] starts up as
521    /// idle.
522    fn idle_stream(self) -> impl Stream<Item = ()> {
523        n0_future::stream::unfold(self, |c| async move {
524            c.inner.notify.notified().await;
525            Some(((), c))
526        })
527    }
528}
529
530/// Guard for one connection
531#[derive(Debug)]
532struct OneConnection {
533    inner: Arc<ConnectionCounterInner>,
534}
535
536impl Drop for OneConnection {
537    fn drop(&mut self) {
538        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
539            self.inner.notify.notify_waiters();
540        }
541    }
542}
543
544#[cfg(test)]
545mod tests {
546    use std::{collections::BTreeMap, sync::Arc, time::Duration};
547
548    use iroh::{
549        discovery::static_provider::StaticProvider,
550        endpoint::{Connection, ConnectionType},
551        protocol::{AcceptError, ProtocolHandler, Router},
552        Endpoint, EndpointAddr, EndpointId, RelayMode, SecretKey, TransportAddr, Watcher,
553    };
554    use n0_error::{AnyError, Result, StdResultExt};
555    use n0_future::{io, stream, BufferedStreamExt, StreamExt};
556    use testresult::TestResult;
557    use tracing::trace;
558
559    use super::{ConnectionPool, Options, PoolConnectError};
560    use crate::util::connection_pool::OnConnected;
561
562    const ECHO_ALPN: &[u8] = b"echo";
563
564    #[derive(Debug, Clone)]
565    struct Echo;
566
567    impl ProtocolHandler for Echo {
568        async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
569            let conn_id = connection.stable_id();
570            let id = connection.remote_id();
571            trace!(%id, %conn_id, "Accepting echo connection");
572            loop {
573                match connection.accept_bi().await {
574                    Ok((mut send, mut recv)) => {
575                        trace!(%id, %conn_id, "Accepted echo request");
576                        tokio::io::copy(&mut recv, &mut send).await?;
577                        send.finish().map_err(AcceptError::from_err)?;
578                    }
579                    Err(e) => {
580                        trace!(%id, %conn_id, "Failed to accept echo request {e}");
581                        break;
582                    }
583                }
584            }
585            Ok(())
586        }
587    }
588
589    async fn echo_client(conn: &Connection, text: &[u8]) -> Result<Vec<u8>> {
590        let conn_id = conn.stable_id();
591        let id = conn.remote_id();
592        trace!(%id, %conn_id, "Sending echo request");
593        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
594        send.write_all(text).await.anyerr()?;
595        send.finish().anyerr()?;
596        let response = recv.read_to_end(1000).await.anyerr()?;
597        trace!(%id, %conn_id, "Received echo response");
598        Ok(response)
599    }
600
601    async fn echo_server() -> TestResult<(EndpointAddr, Router)> {
602        let endpoint = iroh::Endpoint::builder()
603            .alpns(vec![ECHO_ALPN.to_vec()])
604            .bind()
605            .await?;
606        endpoint.online().await;
607        let addr = endpoint.addr();
608        let router = iroh::protocol::Router::builder(endpoint)
609            .accept(ECHO_ALPN, Echo)
610            .spawn();
611
612        Ok((addr, router))
613    }
614
615    async fn echo_servers(n: usize) -> TestResult<(Vec<EndpointId>, Vec<Router>, StaticProvider)> {
616        let res = stream::iter(0..n)
617            .map(|_| echo_server())
618            .buffered_unordered(16)
619            .collect::<Vec<_>>()
620            .await;
621        let res: Vec<(EndpointAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
622        let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
623        let ids = addrs.iter().map(|a| a.id).collect::<Vec<_>>();
624        let discovery = StaticProvider::from_endpoint_info(addrs);
625        Ok((ids, routers, discovery))
626    }
627
628    async fn shutdown_routers(routers: Vec<Router>) {
629        stream::iter(routers)
630            .for_each_concurrent(16, |router| async move {
631                let _ = router.shutdown().await;
632            })
633            .await;
634    }
635
636    fn test_options() -> Options {
637        Options {
638            idle_timeout: Duration::from_millis(100),
639            connect_timeout: Duration::from_secs(5),
640            max_connections: 32,
641            on_connected: None,
642        }
643    }
644
645    struct EchoClient {
646        pool: ConnectionPool,
647    }
648
649    impl EchoClient {
650        async fn echo(
651            &self,
652            id: EndpointId,
653            text: Vec<u8>,
654        ) -> Result<Result<(usize, Vec<u8>), AnyError>, PoolConnectError> {
655            let conn = self.pool.get_or_connect(id).await?;
656            let id = conn.stable_id();
657            match echo_client(&conn, &text).await {
658                Ok(res) => Ok(Ok((id, res))),
659                Err(e) => Ok(Err(e)),
660            }
661        }
662    }
663
664    #[tokio::test]
665    // #[traced_test]
666    async fn connection_pool_errors() -> TestResult<()> {
667        // set up static discovery for all addrs
668        let discovery = StaticProvider::new();
669        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
670            .discovery(discovery.clone())
671            .bind()
672            .await?;
673        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
674        let client = EchoClient { pool };
675        {
676            let non_existing = SecretKey::from_bytes(&[0; 32]).public();
677            let res = client.echo(non_existing, b"Hello, world!".to_vec()).await;
678            // trying to connect to a non-existing id will fail with ConnectError
679            // because we don't have any information about the endpoint.
680            assert!(matches!(res, Err(PoolConnectError::ConnectError { .. })));
681        }
682        {
683            let non_listening = SecretKey::from_bytes(&[0; 32]).public();
684            // make up fake node info
685            discovery.add_endpoint_info(EndpointAddr {
686                id: non_listening,
687                addrs: vec![TransportAddr::Ip("127.0.0.1:12121".parse().unwrap())]
688                    .into_iter()
689                    .collect(),
690            });
691            // trying to connect to an id for which we have info, but the other
692            // end is not listening, will lead to a timeout.
693            let res = client.echo(non_listening, b"Hello, world!".to_vec()).await;
694            assert!(matches!(res, Err(PoolConnectError::Timeout { .. })));
695        }
696        Ok(())
697    }
698
699    #[tokio::test]
700    // #[traced_test]
701    async fn connection_pool_smoke() -> TestResult<()> {
702        let n = 32;
703        let (ids, routers, discovery) = echo_servers(n).await?;
704        // build a client endpoint that can resolve all the endpoint ids
705        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
706            .discovery(discovery.clone())
707            .bind()
708            .await?;
709        let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
710        let client = EchoClient { pool };
711        let mut connection_ids = BTreeMap::new();
712        let msg = b"Hello, pool!".to_vec();
713        for id in &ids {
714            let (cid1, res) = client.echo(*id, msg.clone()).await??;
715            assert_eq!(res, msg);
716            let (cid2, res) = client.echo(*id, msg.clone()).await??;
717            assert_eq!(res, msg);
718            assert_eq!(cid1, cid2);
719            connection_ids.insert(id, cid1);
720        }
721        n0_future::time::sleep(Duration::from_millis(1000)).await;
722        for id in &ids {
723            let cid1 = *connection_ids.get(id).expect("Connection ID not found");
724            let (cid2, res) = client.echo(*id, msg.clone()).await??;
725            assert_eq!(res, msg);
726            assert_ne!(cid1, cid2);
727        }
728        shutdown_routers(routers).await;
729        Ok(())
730    }
731
732    /// Tests that idle connections are being reclaimed to make room if we hit the
733    /// maximum connection limit.
734    #[tokio::test]
735    // #[traced_test]
736    async fn connection_pool_idle() -> TestResult<()> {
737        let n = 32;
738        let (ids, routers, discovery) = echo_servers(n).await?;
739        // build a client endpoint that can resolve all the endpoint ids
740        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
741            .discovery(discovery.clone())
742            .bind()
743            .await?;
744        let pool = ConnectionPool::new(
745            endpoint.clone(),
746            ECHO_ALPN,
747            Options {
748                idle_timeout: Duration::from_secs(100),
749                max_connections: 8,
750                ..test_options()
751            },
752        );
753        let client = EchoClient { pool };
754        let msg = b"Hello, pool!".to_vec();
755        for id in &ids {
756            let (_, res) = client.echo(*id, msg.clone()).await??;
757            assert_eq!(res, msg);
758        }
759        shutdown_routers(routers).await;
760        Ok(())
761    }
762
763    /// Uses an on_connected callback that just errors out every time.
764    ///
765    /// This is a basic smoke test that on_connected gets called at all.
766    #[tokio::test]
767    // #[traced_test]
768    async fn on_connected_error() -> TestResult<()> {
769        let n = 1;
770        let (ids, routers, discovery) = echo_servers(n).await?;
771        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
772            .discovery(discovery)
773            .bind()
774            .await?;
775        let on_connected: OnConnected =
776            Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
777        let pool = ConnectionPool::new(
778            endpoint,
779            ECHO_ALPN,
780            Options {
781                on_connected: Some(on_connected),
782                ..test_options()
783            },
784        );
785        let client = EchoClient { pool };
786        let msg = b"Hello, pool!".to_vec();
787        for id in &ids {
788            let res = client.echo(*id, msg.clone()).await;
789            assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
790        }
791        shutdown_routers(routers).await;
792        Ok(())
793    }
794
795    /// Uses an on_connected callback to ensure that the connection is direct.
796    #[tokio::test]
797    // #[traced_test]
798    async fn on_connected_direct() -> TestResult<()> {
799        let n = 1;
800        let (ids, routers, discovery) = echo_servers(n).await?;
801        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
802            .discovery(discovery)
803            .bind()
804            .await?;
805        let on_connected = |ep: Endpoint, conn: Connection| async move {
806            let id = conn.remote_id();
807            let Some(watcher) = ep.conn_type(id) else {
808                return Err(io::Error::other("unable to get conn_type watcher"));
809            };
810            let mut stream = watcher.stream();
811            while let Some(status) = stream.next().await {
812                if let ConnectionType::Direct { .. } = status {
813                    return Ok(());
814                }
815            }
816            Err(io::Error::other("connection closed before becoming direct"))
817        };
818        let pool = ConnectionPool::new(
819            endpoint,
820            ECHO_ALPN,
821            test_options().with_on_connected(on_connected),
822        );
823        let client = EchoClient { pool };
824        let msg = b"Hello, pool!".to_vec();
825        for id in &ids {
826            let res = client.echo(*id, msg.clone()).await;
827            assert!(res.is_ok());
828        }
829        shutdown_routers(routers).await;
830        Ok(())
831    }
832
833    /// Check that when a connection is closed, the pool will give you a new
834    /// connection next time you want one.
835    ///
836    /// This test fails if the connection watch is disabled.
837    #[tokio::test]
838    // #[traced_test]
839    async fn watch_close() -> TestResult<()> {
840        let n = 1;
841        let (ids, routers, discovery) = echo_servers(n).await?;
842        let endpoint = iroh::Endpoint::empty_builder(RelayMode::Default)
843            .discovery(discovery)
844            .bind()
845            .await?;
846
847        let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options());
848        let conn = pool.get_or_connect(ids[0]).await?;
849        let cid1 = conn.stable_id();
850        conn.close(0u32.into(), b"test");
851        n0_future::time::sleep(Duration::from_millis(500)).await;
852        let conn = pool.get_or_connect(ids[0]).await?;
853        let cid2 = conn.stable_id();
854        assert_ne!(cid1, cid2);
855        shutdown_routers(routers).await;
856        Ok(())
857    }
858}