iroh_gossip/
net.rs

1//! Networking for the `iroh-gossip` protocol
2
3#[cfg(test)]
4use std::sync::atomic::AtomicBool;
5use std::{
6    collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque},
7    ops::DerefMut,
8    sync::Arc,
9    time::Duration,
10};
11
12use bytes::Bytes;
13use iroh::{
14    endpoint::{ConnectError, Connection},
15    protocol::{AcceptError, ProtocolHandler},
16    Endpoint, NodeAddr, NodeId, Watcher,
17};
18use irpc::{
19    channel::{self, mpsc::RecvError},
20    WithChannels,
21};
22use n0_future::{
23    stream::Boxed as BoxStream,
24    task::{self, AbortOnDropHandle},
25    time::Instant,
26    MergeUnbounded, Stream, StreamExt,
27};
28use n0_watcher::{Direct, Watchable};
29use rand::rngs::StdRng;
30use snafu::Snafu;
31use tokio::{
32    sync::{broadcast, mpsc},
33    task::JoinSet,
34};
35use tracing::{debug, error_span, info, instrument, trace, warn, Instrument};
36
37use self::{
38    dialer::Dialer,
39    discovery::GossipDiscovery,
40    net_proto::GossipMessage,
41    util::{AddrInfo, ConnectionCounter, Guarded, IrohRemoteConnection, Timers},
42};
43use crate::{
44    api::{self, GossipApi},
45    metrics::{inc, Metrics},
46    net::util::accept_stream,
47    proto::{self, Config, HyparviewConfig, PeerData, PlumtreeConfig, TopicId},
48};
49
50mod dialer;
51mod discovery;
52mod util;
53
/// ALPN protocol name that identifies `iroh-gossip` connections.
pub const GOSSIP_ALPN: &[u8] = b"/iroh-gossip/1";
56
// Aliases for the `proto::topic` types, instantiated with `NodeId` as the peer identifier
// (and `StdRng` as the random number generator of the protocol state).
type InEvent = proto::topic::InEvent<NodeId>;
type OutEvent = proto::topic::OutEvent<NodeId>;
type Timer = proto::topic::Timer<NodeId>;
type ProtoMessage = proto::topic::Message<NodeId>;
type ProtoEvent = proto::topic::Event<NodeId>;
type State = proto::topic::State<NodeId, StdRng>;
type Command = proto::topic::Command<NodeId>;
64
/// Publish and subscribe on gossiping topics.
///
/// Each topic is a separate broadcast tree with separate memberships.
/// A topic has to be joined before you can publish or subscribe on the topic.
/// To join the swarm for a topic, you have to know the [`NodeId`] of at least one peer that also joined the topic.
///
/// Messages published on the swarm will be delivered to all peers that joined the swarm for that
/// topic. You will also be relaying (gossiping) messages published by other peers.
///
/// With the default settings, the protocol will maintain up to 5 peer connections per topic.
///
/// Even though the [`Gossip`] is created from an [`Endpoint`], it does not accept connections
/// itself. You should run an accept loop on the [`Endpoint`] yourself, check the ALPN protocol of incoming
/// connections, and if the ALPN protocol equals [`GOSSIP_ALPN`], forward the connection to the
/// gossip actor through [Self::handle_connection].
///
/// The gossip actor will, however, initiate new connections to other peers by itself.
#[derive(Debug, Clone)]
pub struct Gossip(Arc<Inner>);
84
85impl std::ops::Deref for Gossip {
86    type Target = GossipApi;
87    fn deref(&self) -> &Self::Target {
88        &self.0.api
89    }
90}
91
/// Messages sent to the main actor over its local channel.
#[derive(derive_more::Debug)]
enum LocalActorMessage {
    /// A connection from the given node that the actor should handle as an accepted connection.
    #[debug("HandleConnection({})", _0.fmt_short())]
    HandleConnection(NodeId, Connection),
    /// Request to connect to the given node on behalf of the given topic.
    #[debug("Connect({}, {})", _0.fmt_short(), _1.fmt_short())]
    Connect(NodeId, TopicId),
    /// Peer data received for the given node; the actor decodes it as address info.
    #[debug("SetPeerData({}, {})", _0.fmt_short(), _1.as_bytes().len())]
    SetPeerData(NodeId, PeerData),
}
101
/// Shared state behind a [`Gossip`] handle.
#[derive(Debug)]
struct Inner {
    /// API handle that [`Gossip`] dereferences to.
    api: GossipApi,
    /// Sender for [`LocalActorMessage`]s to the main actor.
    local_tx: mpsc::Sender<LocalActorMessage>,
    /// Handle of the spawned actor task, kept alive for as long as this struct exists.
    _actor_handle: AbortOnDropHandle<()>,
    /// Copy of `max_message_size` from the config the actor was created with.
    max_message_size: usize,
    /// Metrics shared with the actor.
    metrics: Arc<Metrics>,
}
110
111impl ProtocolHandler for Gossip {
112    async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
113        let remote = connection.remote_node_id()?;
114        self.handle_connection(remote, connection)
115            .await
116            .map_err(AcceptError::from_err)?;
117        Ok(())
118    }
119
120    async fn shutdown(&self) {
121        // TODO: Graceful shutdown?
122    }
123}
124
/// Builder to configure and construct [`Gossip`].
#[derive(Debug, Clone)]
pub struct Builder {
    /// Protocol configuration passed to the gossip actor.
    config: proto::Config,
    /// Custom ALPN, if one was set with [`Builder::alpn`].
    alpn: Option<Bytes>,
}
131
132impl Builder {
133    /// Sets the maximum message size in bytes.
134    /// By default this is `4096` bytes.
135    pub fn max_message_size(mut self, size: usize) -> Self {
136        self.config.max_message_size = size;
137        self
138    }
139
140    /// Set the membership configuration.
141    pub fn membership_config(mut self, config: HyparviewConfig) -> Self {
142        self.config.membership = config;
143        self
144    }
145
146    /// Set the broadcast configuration.
147    pub fn broadcast_config(mut self, config: PlumtreeConfig) -> Self {
148        self.config.broadcast = config;
149        self
150    }
151
152    /// Set the ALPN this gossip instance uses.
153    ///
154    /// It has to be the same for all peers in the network. If you set a custom ALPN,
155    /// you have to use the same ALPN when registering the [`Gossip`] in on a iroh
156    /// router with [`RouterBuilder::accept`].
157    ///
158    /// [`RouterBuilder::accept`]: iroh::protocol::RouterBuilder::accept
159    pub fn alpn(mut self, alpn: impl AsRef<[u8]>) -> Self {
160        self.alpn = Some(alpn.as_ref().to_vec().into());
161        self
162    }
163
164    /// Spawn a gossip actor and get a handle for it
165    pub fn spawn(self, endpoint: Endpoint) -> Gossip {
166        Gossip::new(endpoint, self.config, self.alpn)
167    }
168}
169
170impl Gossip {
171    /// Creates a default `Builder`, with the endpoint set.
172    pub fn builder() -> Builder {
173        Builder {
174            config: Default::default(),
175            alpn: None,
176        }
177    }
178
179    /// Listen on a quinn endpoint for incoming RPC connections.
180    #[cfg(feature = "rpc")]
181    pub async fn listen(self, endpoint: quinn::Endpoint) {
182        self.0.api.listen(endpoint).await
183    }
184
185    /// Get the maximum message size configured for this gossip actor.
186    pub fn max_message_size(&self) -> usize {
187        self.0.max_message_size
188    }
189
190    /// Handle an incoming [`Connection`].
191    ///
192    /// Make sure to check the ALPN protocol yourself before passing the connection.
193    pub async fn handle_connection(
194        &self,
195        remote: NodeId,
196        connection: Connection,
197    ) -> Result<(), ActorStoppedError> {
198        self.0
199            .local_tx
200            .send(LocalActorMessage::HandleConnection(remote, connection))
201            .await
202            .map_err(|_| ActorStoppedSnafu.build())?;
203        Ok(())
204    }
205
206    /// Returns the metrics tracked for this gossip instance.
207    pub fn metrics(&self) -> &Arc<Metrics> {
208        &self.0.metrics
209    }
210
211    fn new(endpoint: Endpoint, config: Config, alpn: Option<Bytes>) -> Self {
212        let metrics = Arc::new(Metrics::default());
213        let max_message_size = config.max_message_size;
214        let me = endpoint.node_id();
215        let (api_tx, local_tx, actor) = Actor::new(endpoint, config, alpn, metrics.clone());
216        let actor_task = task::spawn(
217            actor
218                .run()
219                .instrument(error_span!("gossip", me=%me.fmt_short())),
220        );
221
222        Self(Arc::new(Inner {
223            local_tx,
224            max_message_size,
225            api: GossipApi::local(api_tx),
226            metrics,
227            _actor_handle: AbortOnDropHandle::new(actor_task),
228        }))
229    }
230
231    #[cfg(test)]
232    fn new_with_actor(endpoint: Endpoint, config: Config, alpn: Option<Bytes>) -> (Self, Actor) {
233        let metrics = Arc::new(Metrics::default());
234        let max_message_size = config.max_message_size;
235        let (api_tx, local_tx, actor) = Actor::new(endpoint, config, alpn, metrics.clone());
236        let handle = Self(Arc::new(Inner {
237            local_tx,
238            max_message_size,
239            api: GossipApi::local(api_tx),
240            metrics,
241            _actor_handle: AbortOnDropHandle::new(task::spawn(std::future::pending())),
242        }));
243        (handle, actor)
244    }
245}
246
/// Request types exchanged between gossip nodes.
mod net_proto {
    use irpc::{channel::mpsc, rpc_requests};
    use serde::{Deserialize, Serialize};

    use crate::proto::TopicId;

    /// Request to join the topic with the given id.
    #[derive(Debug, Serialize, Deserialize, Clone)]
    #[non_exhaustive]
    pub struct JoinRequest {
        pub topic_id: TopicId,
    }

    /// Requests a node can send to a remote node.
    // NOTE(review): `GossipMessage` is presumably generated by the `rpc_requests` attribute - confirm.
    #[rpc_requests(message = GossipMessage)]
    #[derive(Debug, Serialize, Deserialize)]
    pub enum Request {
        /// Join a topic; both channel directions carry `ProtoMessage`s.
        #[rpc(tx=mpsc::Sender<super::ProtoMessage>, rx=mpsc::Receiver<super::ProtoMessage>)]
        Join(JoinRequest),
    }
}
266
/// Error emitted when the gossip actor stopped.
///
/// Returned by [`Gossip::handle_connection`] when the connection cannot be sent to the actor.
#[derive(Debug, Snafu)]
pub struct ActorStoppedError;
270
/// Messages sent from the main actor (or from tasks it spawned) to a topic actor.
#[derive(strum::Display)]
enum ActorToTopic {
    /// A join request received through the API.
    Api(ApiJoinRequest),
    /// A pair of message channels with `remote` is available for this topic.
    Connected {
        remote: NodeId,
        tx: Guarded<channel::mpsc::Sender<ProtoMessage>>,
        rx: Guarded<channel::mpsc::Receiver<ProtoMessage>>,
    },
    /// Connecting to the node, or opening a stream with it, failed.
    ConnectionFailed(NodeId),
}
281
/// An API join request together with its channels.
type ApiJoinRequest = WithChannels<api::JoinRequest, api::Request>;
/// Stream of commands received from an API subscriber.
type ApiRecvStream = BoxStream<Result<api::Command, RecvError>>;
/// Stream of protocol messages received from a remote, tagged with the remote's node id.
type RemoteRecvStream = BoxStream<(NodeId, Result<Option<ProtoMessage>, RecvError>)>;
/// Merged stream of incoming requests from all connections, tagged with the remote's node id.
type AcceptRemoteRequestsStream =
    MergeUnbounded<BoxStream<(NodeId, std::io::Result<Guarded<net_proto::GossipMessage>>)>>;
287
/// The main gossip actor: owns the endpoint, the connections and the handles to all topic actors.
struct Actor {
    /// Node id of the local endpoint.
    me: NodeId,
    /// Endpoint used to dial remote nodes.
    endpoint: Endpoint,
    /// ALPN used when dialing remote nodes.
    alpn: Bytes,
    /// Protocol configuration, cloned for every new topic actor.
    config: Config,
    /// Receiver for messages from [`Gossip`] handles and topic actors.
    local_rx: mpsc::Receiver<LocalActorMessage>,
    /// Sender side of `local_rx`, cloned for every new topic actor.
    local_tx: mpsc::Sender<LocalActorMessage>,
    /// Receiver for API requests.
    api_rx: mpsc::Receiver<api::RpcMessage>,
    /// Handles to the running topic actors.
    topics: HashMap<TopicId, TopicHandle>,
    /// Topics that wait for a queued dial to a remote to finish.
    pending_remotes_with_topics: HashMap<NodeId, HashSet<TopicId>>,
    /// Tasks running the topic actors; each task returns its actor when it finishes.
    topic_tasks: JoinSet<TopicActor>,
    /// State of the most recent connection per remote node.
    remotes: HashMap<NodeId, RemoteState>,
    /// Tasks that finish once a connection was closed, returning the node id and connection.
    close_connections: JoinSet<(NodeId, Connection)>,
    /// Queued dials; yields the node id and the dial result.
    dialer: Dialer,
    /// Our own encoded address info, watched by the topic actors.
    our_peer_data: n0_watcher::Watchable<PeerData>,
    /// Metrics shared with the [`Gossip`] handle and the topic actors.
    metrics: Arc<Metrics>,
    /// Stream of updates to our own node address.
    node_addr_updates: BoxStream<NodeAddr>,
    /// Incoming requests from all connections.
    accepting: AcceptRemoteRequestsStream,
    /// Discovery service that is fed with node addresses decoded from received peer data.
    discovery: GossipDiscovery,
}
308
309impl Actor {
310    fn new(
311        endpoint: Endpoint,
312        config: Config,
313        alpn: Option<Bytes>,
314        metrics: Arc<Metrics>,
315    ) -> (
316        mpsc::Sender<api::RpcMessage>,
317        mpsc::Sender<LocalActorMessage>,
318        Self,
319    ) {
320        let (api_tx, api_rx) = tokio::sync::mpsc::channel(16);
321        let (local_tx, local_rx) = tokio::sync::mpsc::channel(16);
322
323        let me = endpoint.node_id();
324        let node_addr_updates = endpoint.watch_node_addr().stream();
325        let discovery = GossipDiscovery::default();
326        endpoint.discovery().add(discovery.clone());
327        let initial_peer_data = AddrInfo::from(endpoint.node_addr()).encode();
328        // let peer_data = endpoint
329        //     .watch_node_addr()
330        //     .map(|addr| AddrInfo::from(addr).encode())
331        //     .unwrap();
332        (
333            api_tx,
334            local_tx.clone(),
335            Actor {
336                endpoint,
337                me,
338                config,
339                api_rx,
340                local_tx,
341                local_rx,
342                node_addr_updates: Box::pin(node_addr_updates),
343                dialer: Dialer::default(),
344                our_peer_data: Watchable::new(initial_peer_data),
345                alpn: alpn.unwrap_or_else(|| crate::ALPN.to_vec().into()),
346                metrics: metrics.clone(),
347                topics: Default::default(),
348                pending_remotes_with_topics: Default::default(),
349                remotes: Default::default(),
350                close_connections: JoinSet::new(),
351                topic_tasks: JoinSet::new(),
352                accepting: Default::default(),
353                discovery,
354            },
355        )
356    }
357
358    async fn run(mut self) {
359        while self.tick().await {}
360    }
361
    /// Test helper: runs the actor until it stops, inside a `gossip` tracing span.
    #[cfg(test)]
    #[instrument("gossip", skip_all, fields(me=%self.me.fmt_short()))]
    pub(crate) async fn finish(self) {
        self.run().await
    }
367
368    #[cfg(test)]
369    #[instrument("gossip", skip_all, fields(me=%self.me.fmt_short()))]
370    pub(crate) async fn steps(&mut self, n: usize) -> Result<(), ActorStoppedError> {
371        for _ in 0..n {
372            if !self.tick().await {
373                return Err(ActorStoppedError);
374            }
375        }
376        Ok(())
377    }
378
    /// Waits for the next event from any of the actor's sources and handles it.
    ///
    /// Returns `false` if the actor should stop, which is the case when the node address
    /// stream ended or when all API senders were dropped. Returns `true` otherwise.
    async fn tick(&mut self) -> bool {
        trace!("wait for tick");
        self.metrics.actor_tick_main.inc();
        tokio::select! {
            // Our own node address changed: publish the newly encoded peer data.
            addr = self.node_addr_updates.next() => {
                trace!("tick: node_addr_update");
                match addr {
                    None => {
                        warn!("address stream returned None - endpoint has shut down");
                        false
                    }
                    Some(addr) => {
                        let data = AddrInfo::from(addr).encode();
                        self.our_peer_data.set(data).ok();
                        true
                    }
                }
            }
            // Message from a `Gossip` handle or from a topic actor.
            Some(msg) = self.local_rx.recv() => {
                trace!("tick: local_rx {msg:?}");
                match msg {
                    LocalActorMessage::HandleConnection(node_id, connection) => {
                        self.handle_remote_connection(node_id, Ok(connection), Direction::Accept).await;
                    }
                    LocalActorMessage::Connect(node_id, topic_id) => {
                        self.connect(node_id, topic_id);
                    }
                    LocalActorMessage::SetPeerData(node_id, data) => {
                        match AddrInfo::decode(&data) {
                            Err(err) => warn!(remote=%node_id.fmt_short(), ?err, len=data.inner().len(), "Failed to decode peer data"),
                            Ok(info) => {
                                debug!(peer = ?node_id, "add known addrs: {info:?}");
                                let node_addr = info.into_node_addr(node_id);
                                self.discovery.add(node_addr);
                            }
                        }
                    }
                }
                true
            }
            // A queued dial finished, successfully or not.
            Some((node_id, res)) = self.dialer.next(), if !self.dialer.is_empty() => {
                trace!(remote=%node_id.fmt_short(), ok=res.is_ok(), "tick: dialed");
                self.handle_remote_connection(node_id, res, Direction::Dial).await;
                true
            }
            // A remote sent a request on one of our connections.
            Some((node_id, res)) = self.accepting.next(), if !self.accepting.is_empty() => {
                trace!(remote=%node_id.fmt_short(), res=?res.as_ref().map(|_| ()), "tick: accepting");
                match res {
                    Ok(request) => self.handle_remote_message(node_id, request).await,
                    Err(reason) => {
                        debug!(remote=%node_id.fmt_short(), ?reason, "accept loop for remote closed");
                    }
                }
                true
            }
            // Request from the API.
            msg = self.api_rx.recv() => {
                trace!(some=msg.is_some(), "tick: api_rx");
                match msg {
                    Some(msg) => {
                        self.handle_api_message(msg).await;
                        true
                    }
                    None => {
                        trace!("all api senders dropped, stop actor");
                        false
                    }
                }
            }
            // A connection was closed: forget its state, unless a newer connection to the
            // same node has replaced it in the meantime.
            Some(res) = self.close_connections.join_next(), if !self.close_connections.is_empty() => {
                let (node_id, connection) = res.expect("connection task panicked");
                trace!(remote=%node_id.fmt_short(), "tick: connection closed");
                if let Some(state) = self.remotes.get(&node_id) {
                    if state.same_connection(&connection) {
                        self.remotes.remove(&node_id);
                    }
                }
                true
            }
            // A topic actor finished: drop its handle.
            Some(actor) = self.topic_tasks.join_next(), if !self.topic_tasks.is_empty() => {
                let actor = actor.expect("topic actor task panicked");
                trace!(topic=%actor.topic_id.fmt_short(), "tick: topic actor finished");
                self.topics.remove(&actor.topic_id);
                true
            }
            else => unreachable!("reached else arm, but all fallible cases should be handled"),
        }
    }
466
    /// Test helper: returns the endpoint this actor uses.
    #[cfg(test)]
    fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }
471
472    fn drain_pending_dials(
473        &mut self,
474        remote: &NodeId,
475    ) -> impl Iterator<Item = (TopicId, &TopicHandle)> {
476        self.pending_remotes_with_topics
477            .remove(remote)
478            .into_iter()
479            .flatten()
480            .flat_map(|topic_id| self.topics.get(&topic_id).map(|handle| (topic_id, handle)))
481    }
482
483    fn connect(&mut self, remote: NodeId, topic_id: TopicId) {
484        let Some(handle) = self.topics.get(&topic_id) else {
485            return;
486        };
487        if let Some(state) = self.remotes.get(&remote) {
488            let tx = handle.tx.clone();
489            let state = state.clone();
490            // TODO: Track task?
491            task::spawn(async move {
492                let msg = state.open_topic(topic_id).await;
493                tx.send(msg).await.ok();
494            });
495        } else {
496            self.dialer
497                .queue_dial(&self.endpoint, remote, self.alpn.clone());
498            self.pending_remotes_with_topics
499                .entry(remote)
500                .or_default()
501                .insert(topic_id);
502        }
503    }
504
    /// Handles the outcome of a dial, or a connection that was accepted, for the node `remote`.
    ///
    /// On failure, every topic that waited for a dial to `remote` is told that the connection
    /// failed. On success, the waiting topics are opened on the new connection, incoming
    /// requests on the connection are added to `self.accepting`, a task is spawned that
    /// finishes once the connection is closed, and the new state replaces any previous entry
    /// for `remote` in `self.remotes`.
    #[instrument("connection", skip_all, fields(remote=%remote.fmt_short()))]
    async fn handle_remote_connection(
        &mut self,
        remote: NodeId,
        res: Result<Connection, ConnectError>,
        direction: Direction,
    ) {
        match (res.as_ref(), direction) {
            (Ok(_), Direction::Dial) => inc(&self.metrics.peers_dialed_success),
            (Err(_), Direction::Dial) => inc(&self.metrics.peers_dialed_failure),
            (Ok(_), Direction::Accept) => inc(&self.metrics.peers_accepted),
            (Err(_), Direction::Accept) => {}
        }
        let connection = match res {
            Err(err) => {
                debug!(?err, "Connection failed");
                for (_, handle) in self.drain_pending_dials(&remote) {
                    handle
                        .send(ActorToTopic::ConnectionFailed(remote))
                        .await
                        .ok();
                }
                return;
            }
            Ok(connection) => connection,
        };

        let state = RemoteState::new(remote, connection.clone(), direction);

        // Open requests for pending topics.
        for (topic_id, handle) in self.drain_pending_dials(&remote) {
            let tx = handle.tx.clone();
            let state = state.clone();
            task::spawn(
                async move {
                    let msg = state.open_topic(topic_id).await;
                    tx.send(msg).await.ok();
                }
                .instrument(tracing::Span::current()),
            );
        }

        // Read incoming requests.
        // Each incoming request is wrapped in a guard obtained from the connection's counter.
        let counter = state.counter.clone();
        self.accepting.push(Box::pin(
            accept_stream::<net_proto::Request>(connection.clone())
                .map(move |req| (remote, req.map(|r| counter.guard(r)))),
        ));

        // Close on idle (if dialed) or await close (if accepted).
        let counter = state.counter.clone();
        let fut = async move {
            match direction {
                Direction::Dial => {
                    counter.idle_for(Duration::from_millis(500)).await;
                    info!("close connection (from dial): unused");
                    connection.close(1u32.into(), b"idle");
                }
                Direction::Accept => {
                    let reason = connection.closed().await;
                    info!(?reason, "connection closed (from accept)")
                }
            }
            // Returned to the main loop, which uses it to remove the matching remote state.
            (remote, connection)
        };
        self.close_connections
            .spawn(fut.instrument(error_span!("conn", remote=%remote.fmt_short())));

        self.remotes.insert(remote, state);
    }
575
576    #[instrument("request", skip_all, fields(remote=%remote.fmt_short()))]
577    async fn handle_remote_message(&mut self, remote: NodeId, request: Guarded<GossipMessage>) {
578        let (request, guard) = request.split();
579        let (topic_id, request) = match request {
580            GossipMessage::Join(req) => (req.inner.topic_id, req),
581        };
582        if let Some(topic) = self.topics.get(&topic_id) {
583            if let Err(_err) = topic
584                .send(ActorToTopic::Connected {
585                    remote,
586                    tx: Guarded::new(request.tx, guard.clone()),
587                    rx: Guarded::new(request.rx, guard.clone()),
588                })
589                .await
590            {
591                warn!(topic=%topic_id.fmt_short(), "Topic actor dead");
592            }
593        } else {
594            debug!(topic=%topic_id.fmt_short(), "ignore request: unknown topic");
595        }
596    }
597
598    async fn handle_api_message(&mut self, msg: api::RpcMessage) {
599        let (topic_id, msg) = match msg {
600            api::RpcMessage::Join(msg) => (msg.inner.topic_id, msg),
601        };
602        let topic = self.topics.entry(topic_id).or_insert_with(|| {
603            let (handle, actor) = TopicHandle::new(
604                self.me,
605                topic_id,
606                self.config.clone(),
607                self.local_tx.clone(),
608                self.our_peer_data.watch(),
609                self.metrics.clone(),
610            );
611            self.topic_tasks.spawn(
612                actor
613                    .run()
614                    .instrument(error_span!("topic", topic=%topic_id.fmt_short())),
615            );
616            handle
617        });
618        if topic.send(ActorToTopic::Api(msg)).await.is_err() {
619            warn!(topic=%topic_id.fmt_short(), "Topic actor dead");
620        }
621    }
622}
623
/// State the main actor keeps for a connection to a remote node.
#[derive(Clone)]
struct RemoteState {
    /// Node id of the remote.
    node_id: NodeId,
    /// Stable id of the connection, used to tell connections to the same node apart.
    conn_id: usize,
    /// RPC client that sends requests over the connection.
    client: irpc::Client<net_proto::Request>,
    /// Whether we dialed or accepted the connection.
    #[allow(dead_code)]
    direction: Direction,
    /// Counter that hands out guards for uses of the connection.
    counter: ConnectionCounter,
}
633
634impl RemoteState {
635    fn new(node_id: NodeId, connection: Connection, direction: Direction) -> Self {
636        let conn_id = connection.stable_id();
637        let irpc_conn = IrohRemoteConnection::new(connection);
638        let client = irpc::Client::boxed(irpc_conn);
639        let counter = ConnectionCounter::new();
640        RemoteState {
641            client,
642            direction,
643            conn_id,
644            counter,
645            node_id,
646        }
647    }
648
649    fn same_connection(&self, conn: &Connection) -> bool {
650        self.conn_id == conn.stable_id()
651    }
652
653    async fn open_topic(&self, topic_id: TopicId) -> ActorToTopic {
654        let guard = self.counter.get_one();
655        let req = net_proto::JoinRequest { topic_id };
656        match self.client.bidi_streaming(req.clone(), 64, 64).await {
657            Ok((tx, rx)) => ActorToTopic::Connected {
658                remote: self.node_id,
659                tx: Guarded::new(tx, guard.clone()),
660                rx: Guarded::new(rx, guard),
661            },
662            Err(err) => {
663                warn!(?topic_id, ?err, "failed to open stream with remote");
664                ActorToTopic::ConnectionFailed(self.node_id)
665            }
666        }
667    }
668}
669
/// How a connection was established.
#[derive(Debug, Copy, Clone)]
enum Direction {
    /// We dialed the remote.
    Dial,
    /// The connection was passed to us as an incoming connection.
    Accept,
}
675
/// Handle the main actor keeps for a topic actor.
struct TopicHandle {
    /// Sender for messages to the topic actor.
    tx: mpsc::Sender<ActorToTopic>,
    /// Flag shared with the topic actor, readable through `joined()` in tests.
    #[cfg(test)]
    joined: Arc<AtomicBool>,
}
681
682impl TopicHandle {
683    fn new(
684        me: NodeId,
685        topic_id: TopicId,
686        config: proto::Config,
687        to_actor_tx: mpsc::Sender<LocalActorMessage>,
688        peer_data: Direct<PeerData>,
689        metrics: Arc<Metrics>,
690    ) -> (Self, TopicActor) {
691        let (tx, rx) = mpsc::channel(16);
692        // TODO: peer_data
693        let state = State::new(me, None, config);
694        #[cfg(test)]
695        let joined = Arc::new(AtomicBool::new(false));
696        let (forward_event_tx, _) = broadcast::channel(512);
697        let actor = TopicActor {
698            topic_id,
699            state,
700            actor_rx: rx,
701            to_actor_tx,
702            peer_data,
703            forward_event_tx,
704            metrics,
705            init: false,
706            #[cfg(test)]
707            joined: joined.clone(),
708            timers: Default::default(),
709            neighbors: Default::default(),
710            out_events: Default::default(),
711            api_receivers: Default::default(),
712            remote_senders: Default::default(),
713            remote_receivers: Default::default(),
714            drop_peers_queue: Default::default(),
715            forward_event_tasks: Default::default(),
716        };
717        let handle = Self {
718            tx,
719            #[cfg(test)]
720            joined,
721        };
722        (handle, actor)
723    }
724
725    async fn send(&self, msg: ActorToTopic) -> Result<(), mpsc::error::SendError<ActorToTopic>> {
726        self.tx.send(msg).await
727    }
728
729    #[cfg(test)]
730    fn joined(&self) -> bool {
731        self.joined.load(std::sync::atomic::Ordering::Relaxed)
732    }
733}
734
/// Actor that runs the gossip protocol state for a single topic.
struct TopicActor {
    /// Id of the topic this actor is responsible for.
    topic_id: TopicId,
    /// Sender for messages to the main actor.
    to_actor_tx: mpsc::Sender<LocalActorMessage>,
    /// Protocol state; turns incoming events into outgoing events.
    state: State,
    /// Receiver for messages from the main actor.
    actor_rx: mpsc::Receiver<ActorToTopic>,
    /// Timers scheduled by the protocol state.
    timers: Timers<Timer>,
    /// Set of neighbors, sent as the initial neighbors to each new API subscriber.
    neighbors: BTreeSet<NodeId>,
    /// Watcher for our own peer data.
    peer_data: Direct<PeerData>,
    /// Outgoing events that have not been processed yet.
    out_events: VecDeque<OutEvent>,
    /// Command streams of all API subscribers.
    api_receivers: MergeUnbounded<ApiRecvStream>,
    /// Senders for messages to remote nodes.
    remote_senders: HashMap<NodeId, MaybeSender>,
    /// Message streams from all connected remote nodes.
    remote_receivers: MergeUnbounded<RemoteRecvStream>,
    /// Broadcast sender; each API subscriber gets a receiver through `subscribe()`.
    forward_event_tx: broadcast::Sender<ProtoEvent>,
    /// Tasks that forward events to the API subscribers.
    forward_event_tasks: JoinSet<()>,
    /// Flag shared with the [`TopicHandle`] in tests.
    #[cfg(test)]
    joined: Arc<AtomicBool>,
    /// Set to `true` once the first API join request was received.
    init: bool,
    /// Peers for which a `PeerDisconnected` event is emitted at the end of the current loop
    /// iteration.
    drop_peers_queue: HashSet<NodeId>,
    /// Metrics shared with the main actor.
    metrics: Arc<Metrics>,
}
755
756impl TopicActor {
    /// Runs the topic actor until it is finished and then returns it.
    ///
    /// The loop ends when no branch of the `select!` is enabled anymore, when the channel to
    /// the main actor is closed, or when at least one API join request was received and all
    /// API receivers and event forwarding tasks are gone.
    pub async fn run(mut self) -> Self {
        self.metrics.topics_joined.inc();
        let peer_data = self.peer_data.clone().stream();
        tokio::pin!(peer_data);
        loop {
            tokio::select! {
                // Message from the main actor.
                Some(msg) = self.actor_rx.recv() => {
                    trace!("tick: actor_rx {msg}");
                    self.handle_actor_message(msg).await;
                },
                // Command from an API subscriber.
                Some(cmd) = self.api_receivers.next(), if !self.api_receivers.is_empty() => {
                    self.handle_api_command(cmd).await;
                }
                // Message from a remote node, or the end of its stream.
                Some((remote, message)) = self.remote_receivers.next(), if !self.remote_receivers.is_empty() => {
                    trace!(remote=%remote.fmt_short(), msg=?message, "tick: remote_rx");
                    self.handle_remote_message(remote, message).await;
                }
                // Our own peer data changed.
                Some(data) = peer_data.next() => {
                    self.handle_in_event(InEvent::UpdatePeerData(data)).await;
                }
                // Handle all timers that expired up to now.
                _ = self.timers.wait_next() => {
                    let now = Instant::now();
                    while let Some((_instant, timer)) = self.timers.pop_before(now) {
                        self.handle_in_event(InEvent::TimerExpired(timer)).await;
                    }
                }
                // Reap finished event forwarding tasks.
                _ = self.forward_event_tasks.join_next(), if !self.forward_event_tasks.is_empty() => {}
                else => break,
            }

            // Report the queued peers as disconnected to the protocol state.
            if !self.drop_peers_queue.is_empty() {
                let now = Instant::now();
                for peer in self.drop_peers_queue.drain() {
                    self.out_events
                        .extend(self.state.handle(InEvent::PeerDisconnected(peer), now));
                }
                self.process_out_events(now).await;
            }

            if self.to_actor_tx.is_closed() {
                warn!("Channel to main actor closed: abort topic loop");
                break;
            }
            if self.init && self.api_receivers.is_empty() && self.forward_event_tasks.is_empty() {
                debug!("Closing topic: All API subscribers dropped");
                break;
            }
        }
        self.metrics.topics_quit.inc();
        self
    }
808
809    async fn handle_actor_message(&mut self, msg: ActorToTopic) {
810        match msg {
811            ActorToTopic::Connected { remote, rx, tx } => {
812                self.remote_receivers
813                    .push(Box::pin(into_stream(rx).map(move |msg| (remote, msg))));
814                let sender = self.remote_senders.entry(remote).or_default();
815                if let Err(err) = sender.init(tx).await {
816                    warn!("Remote failed while pushing queued messages: {err:?}");
817                }
818            }
819            ActorToTopic::Api(req) => {
820                self.init = true;
821                let WithChannels { inner, tx, rx, .. } = req;
822                let initial_neighbors = self.neighbors.clone().into_iter();
823                self.forward_event_tasks.spawn(
824                    forward_events(tx, self.forward_event_tx.subscribe(), initial_neighbors)
825                        .instrument(tracing::Span::current()),
826                );
827                self.api_receivers.push(Box::pin(into_stream2(rx)));
828                self.handle_in_event(InEvent::Command(Command::Join(
829                    inner.bootstrap.into_iter().collect(),
830                )))
831                .await;
832            }
833            ActorToTopic::ConnectionFailed(node_id) => {
834                self.handle_in_event(InEvent::PeerDisconnected(node_id))
835                    .await
836            }
837        }
838    }
839
840    async fn handle_remote_message(
841        &mut self,
842        remote: NodeId,
843        message: Result<Option<ProtoMessage>, RecvError>,
844    ) {
845        let event = match message {
846            Ok(Some(message)) => InEvent::RecvMessage(remote, message),
847            Ok(None) => {
848                debug!(remote=%remote.fmt_short(), "Recv stream from remote closed");
849                InEvent::PeerDisconnected(remote)
850            }
851            Err(err) => {
852                warn!(remote=%remote.fmt_short(), ?err, "Recv stream from remote failed");
853                InEvent::PeerDisconnected(remote)
854            }
855        };
856        self.handle_in_event(event).await;
857    }
858
859    async fn handle_api_command(&mut self, command: Result<api::Command, RecvError>) {
860        let Ok(command) = command else {
861            return;
862        };
863        trace!("tick: api command {command}");
864        self.handle_in_event(InEvent::Command(command.into())).await;
865    }
866
867    async fn handle_in_event(&mut self, event: InEvent) {
868        trace!("tick: in event {event:?}");
869        let now = Instant::now();
870        self.metrics.track_in_event(&event);
871        self.out_events.extend(self.state.handle(event, now));
872        self.process_out_events(now).await;
873    }
874
875    async fn process_out_events(&mut self, now: Instant) {
876        while let Some(event) = self.out_events.pop_front() {
877            trace!("tick: out event {event:?}");
878            self.metrics.track_out_event(&event);
879            match event {
880                OutEvent::SendMessage(node_id, message) => {
881                    self.send(node_id, message).await;
882                }
883                OutEvent::EmitEvent(event) => {
884                    self.handle_event(event);
885                }
886                OutEvent::ScheduleTimer(delay, timer) => {
887                    self.timers.insert(now + delay, timer);
888                }
889                OutEvent::DisconnectPeer(node_id) => {
890                    self.remote_senders.remove(&node_id);
891                }
892                OutEvent::PeerData(node_id, peer_data) => {
893                    self.to_actor_tx
894                        .send(LocalActorMessage::SetPeerData(node_id, peer_data))
895                        .await
896                        .ok();
897                }
898            }
899        }
900    }
901
902    #[instrument(skip_all, fields(remote=%remote.fmt_short()))]
903    async fn send(&mut self, remote: NodeId, message: ProtoMessage) {
904        let sender = match self.remote_senders.entry(remote) {
905            hash_map::Entry::Occupied(entry) => entry.into_mut(),
906            hash_map::Entry::Vacant(entry) => {
907                debug!("requesting new connection");
908                self.to_actor_tx
909                    .send(LocalActorMessage::Connect(remote, self.topic_id))
910                    .await
911                    .ok();
912                entry.insert(Default::default())
913            }
914        };
915        if let Err(err) = sender.send(message).await {
916            warn!(?err, remote=%remote.fmt_short(), "failed to send message");
917            self.drop_peers_queue.insert(remote);
918        }
919    }
920
921    fn handle_event(&mut self, event: ProtoEvent) {
922        match &event {
923            ProtoEvent::NeighborUp(n) => {
924                #[cfg(test)]
925                self.joined
926                    .store(true, std::sync::atomic::Ordering::Relaxed);
927                self.neighbors.insert(*n);
928            }
929            ProtoEvent::NeighborDown(n) => {
930                self.neighbors.remove(n);
931            }
932            ProtoEvent::Received(_) => {}
933        }
934        self.forward_event_tx.send(event).ok();
935    }
936}
937
938async fn forward_events(
939    tx: channel::mpsc::Sender<api::Event>,
940    mut sub: broadcast::Receiver<ProtoEvent>,
941    initial_neighbors: impl Iterator<Item = NodeId>,
942) {
943    for neighbor in initial_neighbors {
944        if let Err(_err) = tx.send(api::Event::NeighborUp(neighbor)).await {
945            break;
946        }
947    }
948    loop {
949        let event = tokio::select! {
950            biased;
951            event = sub.recv() => event,
952            _ = tx.closed() => break
953        };
954        let event: api::Event = match event {
955            Ok(event) => event.into(),
956            Err(broadcast::error::RecvError::Lagged(_)) => api::Event::Lagged,
957            Err(broadcast::error::RecvError::Closed) => break,
958        };
959        if let Err(_err) = tx.send(event).await {
960            break;
961        }
962    }
963}
964
/// Sending side of a per-remote message channel that may not be set up yet.
#[derive(Debug)]
enum MaybeSender {
    /// A sender has been installed; messages are passed to it directly.
    Active(Guarded<channel::mpsc::Sender<ProtoMessage>>),
    /// No sender has been installed yet; messages are buffered in order until
    /// `init` is called.
    Pending(Vec<ProtoMessage>),
}
970
971impl MaybeSender {
972    async fn send(&mut self, message: ProtoMessage) -> Result<(), channel::SendError> {
973        match self {
974            Self::Active(sender) => sender.send(message).await,
975            Self::Pending(messages) => {
976                messages.push(message);
977                Ok(())
978            }
979        }
980    }
981
982    async fn init(
983        &mut self,
984        sender: Guarded<channel::mpsc::Sender<ProtoMessage>>,
985    ) -> Result<(), channel::SendError> {
986        debug!("Initializing new sender");
987        *self = match self {
988            Self::Active(_old) => {
989                debug!("Dropping old sender");
990                Self::Active(sender)
991            }
992            Self::Pending(queue) => {
993                debug!("Sending {} queued messages", queue.len());
994                for msg in queue.drain(..) {
995                    sender.send(msg).await?;
996                }
997                Self::Active(sender)
998            }
999        };
1000        Ok(())
1001    }
1002}
1003
1004impl Default for MaybeSender {
1005    fn default() -> Self {
1006        Self::Pending(Vec::new())
1007    }
1008}
1009
1010// TODO: Upstream to irpc: This differs from Receiver::into_stream: it returns
1011// None after the first error, whereas upstream would loop on the error
1012fn into_stream<T: irpc::RpcMessage>(
1013    receiver: impl DerefMut<Target = channel::mpsc::Receiver<T>> + Send + Sync + 'static,
1014) -> impl Stream<Item = Result<Option<T>, RecvError>> + Send + Sync + 'static {
1015    n0_future::stream::unfold(Some(receiver), |recv| async move {
1016        let mut recv = recv?;
1017        let res = recv.recv().await;
1018        match res {
1019            Err(err) => Some((Err(err), None)),
1020            Ok(Some(res)) => Some((Ok(Some(res)), Some(recv))),
1021            Ok(None) => Some((Ok(None), None)),
1022        }
1023    })
1024}
1025
1026fn into_stream2<T: irpc::RpcMessage>(
1027    receiver: channel::mpsc::Receiver<T>,
1028) -> impl Stream<Item = Result<T, RecvError>> + Send + Sync + 'static {
1029    n0_future::stream::unfold(Some(receiver), |recv| async move {
1030        let mut recv = recv?;
1031        match recv.recv().await {
1032            Err(err) => Some((Err(err), None)),
1033            Ok(Some(res)) => Some((Ok(res), Some(recv))),
1034            Ok(None) => None,
1035        }
1036    })
1037}
1038
1039#[cfg(test)]
1040pub(crate) mod tests {
1041    use std::{future::Future, time::Duration};
1042
1043    use bytes::Bytes;
1044    use futures_concurrency::future::TryJoin;
1045    use iroh::{
1046        discovery::static_provider::StaticProvider, endpoint::BindError, protocol::Router,
1047        NodeAddr, RelayMap, RelayMode, SecretKey,
1048    };
1049    use n0_snafu::{Result, ResultExt};
1050    use rand::{CryptoRng, Rng, SeedableRng};
1051    use tokio::{spawn, time::timeout};
1052    use tokio_util::sync::CancellationToken;
1053    use tracing::info;
1054    use tracing_test::traced_test;
1055
1056    use super::*;
1057    use crate::{
1058        api::{ApiError, Event, GossipReceiver, GossipSender},
1059        ALPN,
1060    };
1061
1062    impl Gossip {
1063        pub(super) async fn t_new(
1064            rng: &mut rand_chacha::ChaCha12Rng,
1065            config: proto::Config,
1066            relay_map: RelayMap,
1067            cancel: &CancellationToken,
1068        ) -> n0_snafu::Result<(Self, Endpoint, impl Future<Output = ()>, impl Drop)> {
1069            let (gossip, actor, ep_handle) =
1070                Gossip::t_new_with_actor(rng, config, relay_map, cancel).await?;
1071            let ep = actor.endpoint().clone();
1072            let me = ep.node_id().fmt_short();
1073            let actor_handle =
1074                task::spawn(actor.run().instrument(tracing::error_span!("gossip", %me)));
1075            Ok((gossip, ep, ep_handle, AbortOnDropHandle::new(actor_handle)))
1076        }
1077        pub(super) async fn t_new_with_actor(
1078            rng: &mut rand_chacha::ChaCha12Rng,
1079            config: proto::Config,
1080            relay_map: RelayMap,
1081            cancel: &CancellationToken,
1082        ) -> n0_snafu::Result<(Self, Actor, impl Future<Output = ()>)> {
1083            let endpoint = Endpoint::builder()
1084                .secret_key(SecretKey::generate(rng))
1085                .relay_mode(RelayMode::Custom(relay_map))
1086                .insecure_skip_relay_cert_verify(true)
1087                .bind()
1088                .await?;
1089
1090            endpoint.online().await;
1091            let (gossip, mut actor) = Gossip::new_with_actor(endpoint.clone(), config, None);
1092            actor.node_addr_updates = Box::pin(n0_future::stream::pending());
1093            let router = Router::builder(endpoint)
1094                .accept(GOSSIP_ALPN, gossip.clone())
1095                .spawn();
1096            let cancel = cancel.clone();
1097            let router_task = tokio::task::spawn(async move {
1098                cancel.cancelled().await;
1099                router.shutdown().await.ok();
1100                drop(router);
1101            });
1102            let router_fut = async move {
1103                router_task.await.expect("router task panicked");
1104            };
1105            Ok((gossip, actor, router_fut))
1106        }
1107    }
1108
1109    pub(crate) async fn create_endpoint(
1110        rng: &mut rand_chacha::ChaCha12Rng,
1111        relay_map: RelayMap,
1112        static_provider: Option<StaticProvider>,
1113    ) -> Result<Endpoint, BindError> {
1114        let ep = Endpoint::builder()
1115            .secret_key(SecretKey::generate(rng))
1116            .alpns(vec![ALPN.to_vec()])
1117            .relay_mode(RelayMode::Custom(relay_map))
1118            .insecure_skip_relay_cert_verify(true)
1119            .bind()
1120            .await?;
1121
1122        if let Some(static_provider) = static_provider {
1123            ep.discovery().add(static_provider);
1124        }
1125        ep.online().await;
1126        Ok(ep)
1127    }
1128
1129    async fn endpoint_loop(
1130        endpoint: Endpoint,
1131        gossip: Gossip,
1132        cancel: CancellationToken,
1133    ) -> Result<()> {
1134        loop {
1135            tokio::select! {
1136                biased;
1137                _ = cancel.cancelled() => break,
1138                incoming = endpoint.accept() => match incoming {
1139                    None => break,
1140                    Some(incoming) => {
1141                        let connecting = match incoming.accept() {
1142                            Ok(connecting) => connecting,
1143                            Err(err) => {
1144                                warn!("incoming connection failed: {err:#}");
1145                                // we can carry on in these cases:
1146                                // this can be caused by retransmitted datagrams
1147                                continue;
1148                            }
1149                        };
1150                        let connection = connecting.await.e()?;
1151                        let remote_node_id = connection.remote_node_id()?;
1152                        gossip.handle_connection(remote_node_id, connection).await?
1153                    }
1154                }
1155            }
1156        }
1157        Ok(())
1158    }
1159
    /// Smoke test with three nodes on one topic.
    ///
    /// Node 1 joins without bootstrap, node 2 bootstraps from node 1 and node 3
    /// bootstraps from node 2. Node 1 then broadcasts `len` messages, and both
    /// node 2 and node 3 must receive exactly these messages in order.
    /// Incoming connections are accepted by manual `endpoint_loop` tasks.
    #[tokio::test]
    #[traced_test]
    async fn gossip_net_smoke() {
        let mut rng = rand_chacha::ChaCha12Rng::seed_from_u64(1);
        let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();

        // shared by all endpoints, so addresses added below are visible to each
        let static_provider = StaticProvider::new();

        let ep1 = create_endpoint(&mut rng, relay_map.clone(), Some(static_provider.clone()))
            .await
            .unwrap();
        let ep2 = create_endpoint(&mut rng, relay_map.clone(), Some(static_provider.clone()))
            .await
            .unwrap();
        let ep3 = create_endpoint(&mut rng, relay_map.clone(), Some(static_provider.clone()))
            .await
            .unwrap();

        let go1 = Gossip::builder().spawn(ep1.clone());
        let go2 = Gossip::builder().spawn(ep2.clone());
        let go3 = Gossip::builder().spawn(ep3.clone());
        debug!("peer1 {:?}", ep1.node_id());
        debug!("peer2 {:?}", ep2.node_id());
        debug!("peer3 {:?}", ep3.node_id());
        let pi1 = ep1.node_id();
        let pi2 = ep2.node_id();

        // one accept loop per endpoint, all stopped through the same token
        let cancel = CancellationToken::new();
        let tasks = [
            spawn(endpoint_loop(ep1.clone(), go1.clone(), cancel.clone())),
            spawn(endpoint_loop(ep2.clone(), go2.clone(), cancel.clone())),
            spawn(endpoint_loop(ep3.clone(), go3.clone(), cancel.clone())),
        ];

        debug!("----- adding peers  ----- ");
        let topic: TopicId = blake3::hash(b"foobar").into();

        // only nodes 1 and 2 are used as bootstrap nodes, so only their
        // addresses need to be known
        let addr1 = NodeAddr::new(pi1).with_relay_url(relay_url.clone());
        let addr2 = NodeAddr::new(pi2).with_relay_url(relay_url);
        static_provider.add_node_info(addr1.clone());
        static_provider.add_node_info(addr2.clone());

        debug!("----- joining  ----- ");
        // join the topics and wait for the connection to succeed
        let [sub1, mut sub2, mut sub3] = [
            go1.subscribe_and_join(topic, vec![]),
            go2.subscribe_and_join(topic, vec![pi1]),
            go3.subscribe_and_join(topic, vec![pi2]),
        ]
        .try_join()
        .await
        .unwrap();

        let (sink1, _stream1) = sub1.split();

        // number of messages to broadcast
        let len = 2;

        // publish messages on node1
        let pub1 = spawn(async move {
            for i in 0..len {
                let message = format!("hi{i}");
                info!("go1 broadcast: {message:?}");
                sink1.broadcast(message.into_bytes().into()).await.unwrap();
                tokio::time::sleep(Duration::from_micros(1)).await;
            }
        });

        // wait for messages on node2
        let sub2 = spawn(async move {
            let mut recv = vec![];
            loop {
                let ev = sub2.next().await.unwrap().unwrap();
                info!("go2 event: {ev:?}");
                if let Event::Received(msg) = ev {
                    recv.push(msg.content);
                }
                if recv.len() == len {
                    return recv;
                }
            }
        });

        // wait for messages on node3
        let sub3 = spawn(async move {
            let mut recv = vec![];
            loop {
                let ev = sub3.next().await.unwrap().unwrap();
                info!("go3 event: {ev:?}");
                if let Event::Received(msg) = ev {
                    recv.push(msg.content);
                }
                if recv.len() == len {
                    return recv;
                }
            }
        });

        // each task must finish within the timeout and must not have panicked
        timeout(Duration::from_secs(10), pub1)
            .await
            .unwrap()
            .unwrap();
        let recv2 = timeout(Duration::from_secs(10), sub2)
            .await
            .unwrap()
            .unwrap();
        let recv3 = timeout(Duration::from_secs(10), sub3)
            .await
            .unwrap()
            .unwrap();

        let expected: Vec<Bytes> = (0..len)
            .map(|i| Bytes::from(format!("hi{i}").into_bytes()))
            .collect();
        assert_eq!(recv2, expected);
        assert_eq!(recv3, expected);

        // stop the accept loops and make sure they exited without error
        cancel.cancel();
        for t in tasks {
            timeout(Duration::from_secs(10), t)
                .await
                .unwrap()
                .unwrap()
                .unwrap();
        }
    }
1285
1286    /// Test that when a gossip topic is no longer needed it's actually unsubscribed.
1287    ///
1288    /// This test will:
1289    /// - Create two endpoints, the first using manual event loop.
1290    /// - Subscribe both nodes to the same topic. The first node will subscribe twice and connect
1291    ///   to the second node. The second node will subscribe without bootstrap.
1292    /// - Ensure that the first node removes the subscription iff all topic handles have been
1293    ///   dropped.
1294    // NOTE: this is a regression test.
1295    #[tokio::test]
1296    #[traced_test]
1297    async fn subscription_cleanup() -> Result {
1298        let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
1299        let ct = CancellationToken::new();
1300        let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
1301
1302        // create the first node with a manual actor loop
1303        let (go1, mut actor, ep1_handle) =
1304            Gossip::t_new_with_actor(rng, Default::default(), relay_map.clone(), &ct).await?;
1305
1306        // create the second node with the usual actor loop
1307        let (go2, ep2, ep2_handle, _test_actor_handle) =
1308            Gossip::t_new(rng, Default::default(), relay_map, &ct).await?;
1309
1310        let node_id1 = actor.endpoint().node_id();
1311        let node_id2 = ep2.node_id();
1312        tracing::info!(
1313            node_1 = %node_id1.fmt_short(),
1314            node_2 = %node_id2.fmt_short(),
1315            "nodes ready"
1316        );
1317
1318        let topic: TopicId = blake3::hash(b"subscription_cleanup").into();
1319        tracing::info!(%topic, "joining");
1320
1321        // create the tasks for each gossip instance:
1322        // - second node subscribes once without bootstrap and listens to events
1323        // - first node subscribes twice with the second node as bootstrap. This is done on command
1324        //   from the main task (this)
1325
1326        // second node
1327        let ct2 = ct.clone();
1328        let go2_task = async move {
1329            let (_pub_tx, mut sub_rx) = go2.subscribe_and_join(topic, vec![]).await?.split();
1330
1331            let subscribe_fut = async {
1332                while let Some(ev) = sub_rx.try_next().await? {
1333                    match ev {
1334                        Event::Lagged => tracing::debug!("missed some messages :("),
1335                        Event::Received(_) => unreachable!("test does not send messages"),
1336                        other => tracing::debug!(?other, "gs event"),
1337                    }
1338                }
1339
1340                tracing::debug!("subscribe stream ended");
1341                Ok::<_, n0_snafu::Error>(())
1342            };
1343
1344            tokio::select! {
1345                _ = ct2.cancelled() => Ok(()),
1346                res = subscribe_fut => res,
1347            }
1348        }
1349        .instrument(tracing::debug_span!("node_2", node_id2=%node_id2.fmt_short()));
1350        let go2_handle = task::spawn(go2_task);
1351
1352        // first node
1353        let addr2 = NodeAddr::new(node_id2).with_relay_url(relay_url);
1354        let static_provider = StaticProvider::new();
1355        static_provider.add_node_info(addr2);
1356        actor.endpoint().discovery().add(static_provider);
1357        // we use a channel to signal advancing steps to the task
1358        let (tx, mut rx) = mpsc::channel::<()>(1);
1359        let ct1 = ct.clone();
1360        let go1_task = async move {
1361            // first subscribe is done immediately
1362            tracing::info!("subscribing the first time");
1363            let sub_1a = go1.subscribe_and_join(topic, vec![node_id2]).await?;
1364
1365            // wait for signal to subscribe a second time
1366            rx.recv().await.expect("signal for second subscribe");
1367            tracing::info!("subscribing a second time");
1368            let sub_1b = go1.subscribe_and_join(topic, vec![node_id2]).await?;
1369            drop(sub_1a);
1370
1371            // wait for signal to drop the second handle as well
1372            rx.recv().await.expect("signal for second subscribe");
1373            tracing::info!("dropping all handles");
1374            drop(sub_1b);
1375
1376            // wait for cancellation
1377            ct1.cancelled().await;
1378            drop(go1);
1379
1380            Ok::<_, n0_snafu::Error>(())
1381        }
1382        .instrument(tracing::debug_span!("node_1", node_id1 = %node_id1.fmt_short()));
1383        let go1_handle = task::spawn(go1_task);
1384
1385        // advance and check that the topic is now subscribed
1386        actor.steps(4).await?; // api_rx subscribe;
1387                               // internal_rx connection request (from topic actor);
1388                               // dialer connected;
1389                               // internal_rx update peer data (from topic actor);
1390        tracing::info!("subscribe and join done, should be joined");
1391        let state = actor.topics.get(&topic).expect("get registered topic");
1392        assert!(state.joined());
1393
1394        // signal the second subscribe, we should remain subscribed
1395        tx.send(()).await.e()?;
1396        actor.steps(1).await?; // api_rx subscribe;
1397        let state = actor.topics.get(&topic).expect("get registered topic");
1398        assert!(state.joined());
1399
1400        // signal to drop the second handle, the topic should no longer be subscribed
1401        tx.send(()).await.e()?;
1402        actor.steps(1).await?; // topic task finished
1403
1404        assert!(!actor.topics.contains_key(&topic));
1405
1406        // cleanup and ensure everything went as expected
1407        ct.cancel();
1408        let wait = Duration::from_secs(5);
1409        timeout(wait, ep1_handle).await.e()?;
1410        timeout(wait, ep2_handle).await.e()?;
1411        timeout(wait, go1_handle).await.e()?.e()??;
1412        timeout(wait, go2_handle).await.e()?.e()??;
1413        timeout(wait, actor.finish()).await.e()?;
1414
1415        Ok(())
1416    }
1417
    /// Test that nodes can reconnect to each other.
    ///
    /// This test will create two nodes subscribed to the same topic. The second node will
    /// unsubscribe and then resubscribe and connection between the nodes should succeed both
    /// times.
    // NOTE: This is a regression test
    #[tokio::test(flavor = "multi_thread")]
    #[traced_test]
    async fn can_reconnect() -> Result {
        let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
        let ct = CancellationToken::new();
        let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();

        let (go1, ep1, ep1_handle, _test_actor_handle1) =
            Gossip::t_new(rng, Default::default(), relay_map.clone(), &ct).await?;

        let (go2, ep2, ep2_handle, _test_actor_handle2) =
            Gossip::t_new(rng, Default::default(), relay_map, &ct).await?;

        let node_id1 = ep1.node_id();
        let node_id2 = ep2.node_id();
        tracing::info!(
            node_1 = %node_id1.fmt_short(),
            node_2 = %node_id2.fmt_short(),
            "nodes ready"
        );

        let topic: TopicId = blake3::hash(b"can_reconnect").into();
        tracing::info!(%topic, "joining");

        // channel used to signal the second gossip instance to advance the test
        let (tx, mut rx) = mpsc::channel::<()>(1);
        let addr1 = NodeAddr::new(node_id1).with_relay_url(relay_url.clone());
        let static_provider = StaticProvider::new();
        static_provider.add_node_info(addr1);
        ep2.discovery().add(static_provider.clone());
        let go2_task = async move {
            // the first subscription has no bootstrap nodes: node 1 connects to us
            let mut sub = go2.subscribe(topic, Vec::new()).await?;
            sub.joined().await?;

            rx.recv().await.expect("signal to unsubscribe");
            tracing::info!("unsubscribing");
            drop(sub);

            rx.recv().await.expect("signal to subscribe again");
            tracing::info!("resubscribing");
            // this time node 2 initiates the connection by bootstrapping from node 1
            let mut sub = go2.subscribe(topic, vec![node_id1]).await?;

            sub.joined().await?;

            Result::<_, ApiError>::Ok(())
        }
        .instrument(tracing::debug_span!("node_2", node_id2=%node_id2.fmt_short()));
        let go2_handle = task::spawn(go2_task);

        let addr2 = NodeAddr::new(node_id2).with_relay_url(relay_url);
        static_provider.add_node_info(addr2);
        ep1.discovery().add(static_provider);

        let mut sub = go1.subscribe(topic, vec![node_id2]).await?;
        // wait for subscribed notification
        sub.joined().await?;
        info!("go1 joined");

        // signal node_2 to unsubscribe
        tx.send(()).await.e()?;

        info!("wait for neighbor down");
        // we should receive a Neighbor down event
        let conn_timeout = Duration::from_millis(1000);
        let ev = timeout(conn_timeout, sub.try_next()).await.e()??;
        assert_eq!(ev, Some(Event::NeighborDown(node_id2)));
        tracing::info!("node 2 left");

        // signal node_2 to subscribe again
        tx.send(()).await.e()?;

        info!("wait for neighbor up");
        // the next event on node 1 must be node 2 coming back
        let conn_timeout = Duration::from_millis(1000);
        let ev = timeout(conn_timeout, sub.try_next()).await.e()??;
        assert_eq!(ev, Some(Event::NeighborUp(node_id2)));
        tracing::info!("node 2 rejoined!");

        // wait for go2 to also be rejoined, then the task terminates
        let wait = Duration::from_secs(2);
        timeout(wait, go2_handle).await.e()?.e()??;
        ct.cancel();
        // cleanup and ensure everything went as expected
        timeout(wait, ep1_handle).await.e()?;
        timeout(wait, ep2_handle).await.e()?;

        Ok(())
    }
1511
    /// Test that a node which is killed abruptly can come back with the same
    /// secret key and still deliver messages.
    ///
    /// A receiver node stays alive for the whole test. A broadcaster node is run
    /// in its own thread and runtime, sends one message and is then cancelled
    /// without any graceful shutdown. A second broadcaster using the same
    /// secret key is started afterwards, and its message must arrive as well.
    #[tokio::test]
    #[traced_test]
    async fn can_die_and_reconnect() -> Result {
        /// Runs a future in a separate runtime on a separate thread, cancelling everything
        /// abruptly once `cancel` is invoked.
        fn run_in_thread<T: Send + 'static>(
            cancel: CancellationToken,
            fut: impl std::future::Future<Output = T> + Send + 'static,
        ) -> std::thread::JoinHandle<Option<T>> {
            std::thread::spawn(move || {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .unwrap();
                rt.block_on(async move { cancel.run_until_cancelled(fut).await })
            })
        }

        /// Spawns a new endpoint and gossip instance.
        async fn spawn_gossip(
            secret_key: SecretKey,
            relay_map: RelayMap,
        ) -> Result<(Router, Gossip), BindError> {
            let ep = Endpoint::builder()
                .secret_key(secret_key)
                .relay_mode(RelayMode::Custom(relay_map))
                .insecure_skip_relay_cert_verify(true)
                .bind()
                .await?;
            let gossip = Gossip::builder().spawn(ep.clone());
            let router = Router::builder(ep).accept(ALPN, gossip.clone()).spawn();
            Ok((router, gossip))
        }

        /// Spawns a gossip node, and broadcasts a single message, then sleep until cancelled externally.
        async fn broadcast_once(
            secret_key: SecretKey,
            relay_map: RelayMap,
            bootstrap_addr: NodeAddr,
            topic_id: TopicId,
            message: String,
        ) -> Result {
            let (router, gossip) = spawn_gossip(secret_key, relay_map).await?;
            info!(node_id = %router.endpoint().node_id().fmt_short(), "broadcast node spawned");
            let bootstrap = vec![bootstrap_addr.node_id];
            let static_provider = StaticProvider::new();
            static_provider.add_node_info(bootstrap_addr);
            router.endpoint().discovery().add(static_provider);
            let mut topic = gossip.subscribe_and_join(topic_id, bootstrap).await?;
            topic.broadcast(message.as_bytes().to_vec().into()).await?;
            // never resolves: the node stays up until the surrounding future is cancelled
            std::future::pending::<()>().await;
            Ok(())
        }

        let (relay_map, _relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
        let mut rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
        let topic_id = TopicId::from_bytes(rng.random());

        // spawn a gossip node, send the node's address on addr_tx,
        // then forward every received message on msgs_recv_tx until the
        // topic stream ends.
        let (addr_tx, addr_rx) = tokio::sync::oneshot::channel();
        let (msgs_recv_tx, mut msgs_recv_rx) = tokio::sync::mpsc::channel(3);
        let recv_task = tokio::task::spawn({
            let relay_map = relay_map.clone();
            let secret_key = SecretKey::generate(&mut rng);
            async move {
                let (router, gossip) = spawn_gossip(secret_key, relay_map).await?;
                // wait for the relay to be set. iroh currently has issues when trying
                // to immediately reconnect with changed direct addresses, but when the
                // relay path is available it works.
                // See https://github.com/n0-computer/iroh/pull/3372
                router.endpoint().online().await;
                let addr = router.endpoint().node_addr();
                info!(node_id = %addr.node_id.fmt_short(), "recv node spawned");
                addr_tx.send(addr).unwrap();
                let mut topic = gossip.subscribe_and_join(topic_id, vec![]).await?;
                while let Some(event) = topic.try_next().await.unwrap() {
                    if let Event::Received(message) = event {
                        let message = std::str::from_utf8(&message.content).e()?.to_string();
                        msgs_recv_tx.send(message).await.e()?;
                    }
                }
                Result::<_, n0_snafu::Error>::Ok(())
            }
        });

        let node0_addr = addr_rx.await.e()?;
        let max_wait = Duration::from_secs(5);

        // spawn a node, send a message, and then abruptly terminate the node ungracefully
        // after the message was received on our receiver node.
        let cancel = CancellationToken::new();
        let secret = SecretKey::generate(&mut rng);
        let join_handle_1 = run_in_thread(
            cancel.clone(),
            broadcast_once(
                secret.clone(),
                relay_map.clone(),
                node0_addr.clone(),
                topic_id,
                "msg1".to_string(),
            ),
        );
        // assert that we received the message on the receiver node.
        let msg = timeout(max_wait, msgs_recv_rx.recv()).await.e()?.unwrap();
        assert_eq!(&msg, "msg1");
        info!("kill broadcast node");
        cancel.cancel();

        // spawns the node again with the same node id, and send another message
        let cancel = CancellationToken::new();
        let join_handle_2 = run_in_thread(
            cancel.clone(),
            broadcast_once(
                secret.clone(),
                relay_map.clone(),
                node0_addr.clone(),
                topic_id,
                "msg2".to_string(),
            ),
        );
        // assert that we received the message on the receiver node.
        // this means that the reconnect with the same node id worked.
        let msg = timeout(max_wait, msgs_recv_rx.recv()).await.e()?.unwrap();
        assert_eq!(&msg, "msg2");
        info!("kill broadcast node");
        cancel.cancel();

        info!("kill recv node");
        recv_task.abort();
        // `None` means the broadcaster was cancelled rather than finishing on its own
        assert!(join_handle_1.join().unwrap().is_none());
        assert!(join_handle_2.join().unwrap().is_none());

        Ok(())
    }
1647
1648    #[tokio::test]
1649    #[traced_test]
1650    async fn gossip_change_alpn() -> n0_snafu::Result<()> {
1651        let alpn = b"my-gossip-alpn";
1652        let topic_id = TopicId::from([0u8; 32]);
1653
1654        let ep1 = Endpoint::builder().bind().await?;
1655        let ep2 = Endpoint::builder().bind().await?;
1656        let gossip1 = Gossip::builder().alpn(alpn).spawn(ep1.clone());
1657        let gossip2 = Gossip::builder().alpn(alpn).spawn(ep2.clone());
1658        let router1 = Router::builder(ep1).accept(alpn, gossip1.clone()).spawn();
1659        let router2 = Router::builder(ep2).accept(alpn, gossip2.clone()).spawn();
1660
1661        let addr1 = router1.endpoint().node_addr();
1662        let id1 = addr1.node_id;
1663        let static_provider = StaticProvider::new();
1664        static_provider.add_node_info(addr1);
1665        router2.endpoint().discovery().add(static_provider);
1666
1667        let mut topic1 = gossip1.subscribe(topic_id, vec![]).await?;
1668        let mut topic2 = gossip2.subscribe(topic_id, vec![id1]).await?;
1669
1670        timeout(Duration::from_secs(3), topic1.joined())
1671            .await
1672            .e()??;
1673        timeout(Duration::from_secs(3), topic2.joined())
1674            .await
1675            .e()??;
1676        router1.shutdown().await.e()?;
1677        router2.shutdown().await.e()?;
1678        Ok(())
1679    }
1680
1681    #[tokio::test]
1682    #[traced_test]
1683    async fn gossip_rely_on_gossip_discovery() -> n0_snafu::Result<()> {
1684        let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
1685
1686        async fn spawn(
1687            rng: &mut impl CryptoRng,
1688        ) -> n0_snafu::Result<(NodeId, Router, Gossip, GossipSender, GossipReceiver)> {
1689            let topic_id = TopicId::from([0u8; 32]);
1690            let ep = Endpoint::builder()
1691                .secret_key(SecretKey::generate(rng))
1692                .relay_mode(RelayMode::Disabled)
1693                .bind()
1694                .await?;
1695            let node_id = ep.node_id();
1696            let gossip = Gossip::builder().spawn(ep.clone());
1697            let router = Router::builder(ep)
1698                .accept(GOSSIP_ALPN, gossip.clone())
1699                .spawn();
1700            let topic = gossip.subscribe(topic_id, vec![]).await?;
1701            let (sender, receiver) = topic.split();
1702            Ok((node_id, router, gossip, sender, receiver))
1703        }
1704
1705        // spawn 3 nodes without relay or discovery
1706        let (n1, r1, _g1, _tx1, mut rx1) = spawn(rng).await?;
1707        let (n2, r2, _g2, tx2, mut rx2) = spawn(rng).await?;
1708        let (n3, r3, _g3, tx3, mut rx3) = spawn(rng).await?;
1709
1710        println!("nodes {:?}", [n1, n2, n3]);
1711
1712        // create a static discovery that has only node 1 addr info set
1713        let addr1 = r1.endpoint().node_addr();
1714        let disco = StaticProvider::new();
1715        disco.add_node_info(addr1);
1716
1717        // add addr info of node1 to node2 and join node1
1718        r2.endpoint().discovery().add(disco.clone());
1719        tx2.join_peers(vec![n1]).await?;
1720
1721        // await join node2 -> nodde1
1722        timeout(Duration::from_secs(3), rx1.joined()).await.e()??;
1723        timeout(Duration::from_secs(3), rx2.joined()).await.e()??;
1724
1725        // add addr info of node1 to node3 and join node1
1726        r3.endpoint().discovery().add(disco.clone());
1727        tx3.join_peers(vec![n1]).await?;
1728
1729        // await join at node3: n1 and n2
1730        // n2 only works because because we use gossip discovery!
1731        let ev = timeout(Duration::from_secs(3), rx3.next()).await.e()?;
1732        assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
1733        let ev = timeout(Duration::from_secs(3), rx3.next()).await.e()?;
1734        assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
1735
1736        assert_eq!(sorted(rx3.neighbors()), sorted([n1, n2]));
1737
1738        let ev = timeout(Duration::from_secs(3), rx2.next()).await.e()?;
1739        assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
1740
1741        let ev = timeout(Duration::from_secs(3), rx1.next()).await.e()?;
1742        assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
1743
1744        tokio::try_join!(r1.shutdown(), r2.shutdown(), r3.shutdown()).e()?;
1745        Ok(())
1746    }
1747
1748    fn sorted<T: Ord>(input: impl IntoIterator<Item = T>) -> Vec<T> {
1749        let mut out: Vec<_> = input.into_iter().collect();
1750        out.sort();
1751        out
1752    }
1753}