1#[cfg(test)]
4use std::sync::atomic::AtomicBool;
5use std::{
6 collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque},
7 ops::DerefMut,
8 sync::Arc,
9 time::Duration,
10};
11
12use bytes::Bytes;
13use iroh::{
14 endpoint::{ConnectError, Connection},
15 protocol::{AcceptError, ProtocolHandler},
16 Endpoint, NodeAddr, NodeId, Watcher,
17};
18use irpc::{
19 channel::{self, mpsc::RecvError},
20 WithChannels,
21};
22use n0_future::{
23 stream::Boxed as BoxStream,
24 task::{self, AbortOnDropHandle},
25 time::Instant,
26 MergeUnbounded, Stream, StreamExt,
27};
28use n0_watcher::{Direct, Watchable};
29use rand::rngs::StdRng;
30use snafu::Snafu;
31use tokio::{
32 sync::{broadcast, mpsc},
33 task::JoinSet,
34};
35use tracing::{debug, error_span, info, instrument, trace, warn, Instrument};
36
37use self::{
38 dialer::Dialer,
39 discovery::GossipDiscovery,
40 net_proto::GossipMessage,
41 util::{AddrInfo, ConnectionCounter, Guarded, IrohRemoteConnection, Timers},
42};
43use crate::{
44 api::{self, GossipApi},
45 metrics::{inc, Metrics},
46 net::util::accept_stream,
47 proto::{self, Config, HyparviewConfig, PeerData, PlumtreeConfig, TopicId},
48};
49
50mod dialer;
51mod discovery;
52mod util;
53
/// ALPN identifier for the gossip protocol (`/iroh-gossip/1`).
///
/// NOTE(review): `Actor::new` falls back to `crate::ALPN` (not this constant) when no custom
/// ALPN is configured for dialing; confirm the two values agree.
pub const GOSSIP_ALPN: &[u8] = b"/iroh-gossip/1";

// Aliases that specialize the generic `proto::topic` types to iroh `NodeId` peers
// (and, for `State`, to a `StdRng` random source).
type InEvent = proto::topic::InEvent<NodeId>;
type OutEvent = proto::topic::OutEvent<NodeId>;
type Timer = proto::topic::Timer<NodeId>;
type ProtoMessage = proto::topic::Message<NodeId>;
type ProtoEvent = proto::topic::Event<NodeId>;
type State = proto::topic::State<NodeId, StdRng>;
type Command = proto::topic::Command<NodeId>;
64
/// Handle to a running gossip protocol instance.
///
/// Cloning is cheap: all clones share the same reference-counted inner state. The handle
/// dereferences to [`GossipApi`] for subscribing to topics. When the last clone is dropped,
/// the inner `AbortOnDropHandle` is dropped and the background actor task is aborted.
#[derive(Debug, Clone)]
pub struct Gossip(Arc<Inner>);
84
85impl std::ops::Deref for Gossip {
86 type Target = GossipApi;
87 fn deref(&self) -> &Self::Target {
88 &self.0.api
89 }
90}
91
/// Internal messages delivered to the main [`Actor`] over its `local_tx` channel.
///
/// Sent by [`Gossip::handle_connection`] and by topic actors.
#[derive(derive_more::Debug)]
enum LocalActorMessage {
    /// Hand over an accepted connection from the given remote node.
    #[debug("HandleConnection({})", _0.fmt_short())]
    HandleConnection(NodeId, Connection),
    /// Make sure the given topic gets a stream to the remote node (dialing if needed).
    #[debug("Connect({}, {})", _0.fmt_short(), _1.fmt_short())]
    Connect(NodeId, TopicId),
    /// Peer data received for a remote node; decoded into addresses for discovery.
    #[debug("SetPeerData({}, {})", _0.fmt_short(), _1.as_bytes().len())]
    SetPeerData(NodeId, PeerData),
}
101
/// State shared by all clones of a [`Gossip`] handle.
#[derive(Debug)]
struct Inner {
    /// Local API client over the actor's API channel; the target of `Deref for Gossip`.
    api: GossipApi,
    /// Channel for internal messages to the actor (e.g. incoming connections).
    local_tx: mpsc::Sender<LocalActorMessage>,
    /// Background task handle: the actor in `Gossip::new`, a pending placeholder in
    /// `new_with_actor`. Aborts the task when this `Inner` is dropped.
    _actor_handle: AbortOnDropHandle<()>,
    /// Copy of `config.max_message_size`, returned by `Gossip::max_message_size`.
    max_message_size: usize,
    /// Metrics shared with the actor.
    metrics: Arc<Metrics>,
}
110
111impl ProtocolHandler for Gossip {
112 async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
113 let remote = connection.remote_node_id()?;
114 self.handle_connection(remote, connection)
115 .await
116 .map_err(AcceptError::from_err)?;
117 Ok(())
118 }
119
120 async fn shutdown(&self) {
121 }
123}
124
/// Builder for a [`Gossip`] instance; create one with [`Gossip::builder`].
#[derive(Debug, Clone)]
pub struct Builder {
    /// Protocol configuration handed to the actor in [`Builder::spawn`].
    config: proto::Config,
    /// Optional ALPN override for outgoing dials; `None` falls back to `crate::ALPN`.
    alpn: Option<Bytes>,
}
131
132impl Builder {
133 pub fn max_message_size(mut self, size: usize) -> Self {
136 self.config.max_message_size = size;
137 self
138 }
139
140 pub fn membership_config(mut self, config: HyparviewConfig) -> Self {
142 self.config.membership = config;
143 self
144 }
145
146 pub fn broadcast_config(mut self, config: PlumtreeConfig) -> Self {
148 self.config.broadcast = config;
149 self
150 }
151
152 pub fn alpn(mut self, alpn: impl AsRef<[u8]>) -> Self {
160 self.alpn = Some(alpn.as_ref().to_vec().into());
161 self
162 }
163
164 pub fn spawn(self, endpoint: Endpoint) -> Gossip {
166 Gossip::new(endpoint, self.config, self.alpn)
167 }
168}
169
170impl Gossip {
171 pub fn builder() -> Builder {
173 Builder {
174 config: Default::default(),
175 alpn: None,
176 }
177 }
178
179 #[cfg(feature = "rpc")]
181 pub async fn listen(self, endpoint: quinn::Endpoint) {
182 self.0.api.listen(endpoint).await
183 }
184
185 pub fn max_message_size(&self) -> usize {
187 self.0.max_message_size
188 }
189
190 pub async fn handle_connection(
194 &self,
195 remote: NodeId,
196 connection: Connection,
197 ) -> Result<(), ActorStoppedError> {
198 self.0
199 .local_tx
200 .send(LocalActorMessage::HandleConnection(remote, connection))
201 .await
202 .map_err(|_| ActorStoppedSnafu.build())?;
203 Ok(())
204 }
205
206 pub fn metrics(&self) -> &Arc<Metrics> {
208 &self.0.metrics
209 }
210
211 fn new(endpoint: Endpoint, config: Config, alpn: Option<Bytes>) -> Self {
212 let metrics = Arc::new(Metrics::default());
213 let max_message_size = config.max_message_size;
214 let me = endpoint.node_id();
215 let (api_tx, local_tx, actor) = Actor::new(endpoint, config, alpn, metrics.clone());
216 let actor_task = task::spawn(
217 actor
218 .run()
219 .instrument(error_span!("gossip", me=%me.fmt_short())),
220 );
221
222 Self(Arc::new(Inner {
223 local_tx,
224 max_message_size,
225 api: GossipApi::local(api_tx),
226 metrics,
227 _actor_handle: AbortOnDropHandle::new(actor_task),
228 }))
229 }
230
231 #[cfg(test)]
232 fn new_with_actor(endpoint: Endpoint, config: Config, alpn: Option<Bytes>) -> (Self, Actor) {
233 let metrics = Arc::new(Metrics::default());
234 let max_message_size = config.max_message_size;
235 let (api_tx, local_tx, actor) = Actor::new(endpoint, config, alpn, metrics.clone());
236 let handle = Self(Arc::new(Inner {
237 local_tx,
238 max_message_size,
239 api: GossipApi::local(api_tx),
240 metrics,
241 _actor_handle: AbortOnDropHandle::new(task::spawn(std::future::pending())),
242 }));
243 (handle, actor)
244 }
245}
246
// Wire protocol spoken between gossip peers over irpc.
mod net_proto {
    use irpc::{channel::mpsc, rpc_requests};
    use serde::{Deserialize, Serialize};

    use crate::proto::TopicId;

    // Payload of a `Join` request: the topic the remote wants to exchange messages on.
    #[derive(Debug, Serialize, Deserialize, Clone)]
    #[non_exhaustive]
    pub struct JoinRequest {
        pub topic_id: TopicId,
    }

    // Requests a peer can send. `rpc_requests` generates the `GossipMessage` enum, whose
    // variants carry the request (`inner`) together with its `tx`/`rx` channels.
    #[rpc_requests(message = GossipMessage)]
    #[derive(Debug, Serialize, Deserialize)]
    pub enum Request {
        // Opens a bidirectional stream of protocol messages for one topic.
        #[rpc(tx=mpsc::Sender<super::ProtoMessage>, rx=mpsc::Receiver<super::ProtoMessage>)]
        Join(JoinRequest),
    }
}
266
// Error returned when the gossip actor is no longer running, i.e. its channel is closed.
// These are deliberately plain `//` comments: snafu can derive the `Display` message from
// `///` doc comments, so adding one here would change the error text.
#[derive(Debug, Snafu)]
pub struct ActorStoppedError;
270
/// Messages from the main [`Actor`] to a topic actor.
#[derive(strum::Display)]
enum ActorToTopic {
    /// A local API join request for this topic.
    Api(ApiJoinRequest),
    /// A stream pair with `remote` for this topic is ready (dialed or accepted).
    Connected {
        remote: NodeId,
        tx: Guarded<channel::mpsc::Sender<ProtoMessage>>,
        rx: Guarded<channel::mpsc::Receiver<ProtoMessage>>,
    },
    /// Connecting to, or opening a topic stream with, the given node failed.
    ConnectionFailed(NodeId),
}
281
// A local API join request together with its channels.
type ApiJoinRequest = WithChannels<api::JoinRequest, api::Request>;
// Commands arriving from one local API subscriber.
type ApiRecvStream = BoxStream<Result<api::Command, RecvError>>;
// Protocol messages from one remote, tagged with its node id; `Ok(None)` marks the stream end.
type RemoteRecvStream = BoxStream<(NodeId, Result<Option<ProtoMessage>, RecvError>)>;
// Incoming irpc requests from all connected remotes, merged into a single stream.
type AcceptRemoteRequestsStream =
    MergeUnbounded<BoxStream<(NodeId, std::io::Result<Guarded<net_proto::GossipMessage>>)>>;
287
/// Main gossip actor: owns the connections to remotes and one [`TopicActor`] per topic.
struct Actor {
    /// Our own node id.
    me: NodeId,
    endpoint: Endpoint,
    /// ALPN used for outgoing dials.
    alpn: Bytes,
    /// Configuration cloned into every new topic actor.
    config: Config,
    /// Internal message channel; the sender half is also cloned into topic actors.
    local_rx: mpsc::Receiver<LocalActorMessage>,
    local_tx: mpsc::Sender<LocalActorMessage>,
    /// Incoming API requests; the actor stops once all senders are dropped.
    api_rx: mpsc::Receiver<api::RpcMessage>,
    /// Handles to the running topic actors.
    topics: HashMap<TopicId, TopicHandle>,
    /// Topics waiting for a dial to the given remote to finish.
    pending_remotes_with_topics: HashMap<NodeId, HashSet<TopicId>>,
    /// Running topic actor tasks; each returns its actor when it finishes.
    topic_tasks: JoinSet<TopicActor>,
    /// Current connection state per remote.
    remotes: HashMap<NodeId, RemoteState>,
    /// Tasks that complete once a connection is closed (dialed ones are closed when idle).
    close_connections: JoinSet<(NodeId, Connection)>,
    /// Outgoing dials.
    dialer: Dialer,
    /// Our encoded address info, watched by every topic actor.
    our_peer_data: n0_watcher::Watchable<PeerData>,
    metrics: Arc<Metrics>,
    /// Stream of changes to our own node address.
    node_addr_updates: BoxStream<NodeAddr>,
    /// Incoming requests from all connections.
    accepting: AcceptRemoteRequestsStream,
    /// Discovery source fed with addresses decoded from peer data.
    discovery: GossipDiscovery,
}
308
309impl Actor {
310 fn new(
311 endpoint: Endpoint,
312 config: Config,
313 alpn: Option<Bytes>,
314 metrics: Arc<Metrics>,
315 ) -> (
316 mpsc::Sender<api::RpcMessage>,
317 mpsc::Sender<LocalActorMessage>,
318 Self,
319 ) {
320 let (api_tx, api_rx) = tokio::sync::mpsc::channel(16);
321 let (local_tx, local_rx) = tokio::sync::mpsc::channel(16);
322
323 let me = endpoint.node_id();
324 let node_addr_updates = endpoint.watch_node_addr().stream();
325 let discovery = GossipDiscovery::default();
326 endpoint.discovery().add(discovery.clone());
327 let initial_peer_data = AddrInfo::from(endpoint.node_addr()).encode();
328 (
333 api_tx,
334 local_tx.clone(),
335 Actor {
336 endpoint,
337 me,
338 config,
339 api_rx,
340 local_tx,
341 local_rx,
342 node_addr_updates: Box::pin(node_addr_updates),
343 dialer: Dialer::default(),
344 our_peer_data: Watchable::new(initial_peer_data),
345 alpn: alpn.unwrap_or_else(|| crate::ALPN.to_vec().into()),
346 metrics: metrics.clone(),
347 topics: Default::default(),
348 pending_remotes_with_topics: Default::default(),
349 remotes: Default::default(),
350 close_connections: JoinSet::new(),
351 topic_tasks: JoinSet::new(),
352 accepting: Default::default(),
353 discovery,
354 },
355 )
356 }
357
358 async fn run(mut self) {
359 while self.tick().await {}
360 }
361
362 #[cfg(test)]
363 #[instrument("gossip", skip_all, fields(me=%self.me.fmt_short()))]
364 pub(crate) async fn finish(self) {
365 self.run().await
366 }
367
368 #[cfg(test)]
369 #[instrument("gossip", skip_all, fields(me=%self.me.fmt_short()))]
370 pub(crate) async fn steps(&mut self, n: usize) -> Result<(), ActorStoppedError> {
371 for _ in 0..n {
372 if !self.tick().await {
373 return Err(ActorStoppedError);
374 }
375 }
376 Ok(())
377 }
378
379 async fn tick(&mut self) -> bool {
380 trace!("wait for tick");
381 self.metrics.actor_tick_main.inc();
382 tokio::select! {
383 addr = self.node_addr_updates.next() => {
384 trace!("tick: node_addr_update");
385 match addr {
386 None => {
387 warn!("address stream returned None - endpoint has shut down");
388 false
389 }
390 Some(addr) => {
391 let data = AddrInfo::from(addr).encode();
392 self.our_peer_data.set(data).ok();
393 true
394 }
395 }
396 }
397 Some(msg) = self.local_rx.recv() => {
398 trace!("tick: local_rx {msg:?}");
399 match msg {
400 LocalActorMessage::HandleConnection(node_id, connection) => {
401 self.handle_remote_connection(node_id, Ok(connection), Direction::Accept).await;
402 }
403 LocalActorMessage::Connect(node_id, topic_id) => {
404 self.connect(node_id, topic_id);
405 }
406 LocalActorMessage::SetPeerData(node_id, data) => {
407 match AddrInfo::decode(&data) {
408 Err(err) => warn!(remote=%node_id.fmt_short(), ?err, len=data.inner().len(), "Failed to decode peer data"),
409 Ok(info) => {
410 debug!(peer = ?node_id, "add known addrs: {info:?}");
411 let node_addr = info.into_node_addr(node_id);
412 self.discovery.add(node_addr);
413 }
414 }
415 }
416 }
417 true
418 }
419 Some((node_id, res)) = self.dialer.next(), if !self.dialer.is_empty() => {
420 trace!(remote=%node_id.fmt_short(), ok=res.is_ok(), "tick: dialed");
421 self.handle_remote_connection(node_id, res, Direction::Dial).await;
422 true
423 }
424 Some((node_id, res)) = self.accepting.next(), if !self.accepting.is_empty() => {
425 trace!(remote=%node_id.fmt_short(), res=?res.as_ref().map(|_| ()), "tick: accepting");
426 match res {
427 Ok(request) => self.handle_remote_message(node_id, request).await,
428 Err(reason) => {
429 debug!(remote=%node_id.fmt_short(), ?reason, "accept loop for remote closed");
430 }
431 }
432 true
433 }
434 msg = self.api_rx.recv() => {
435 trace!(some=msg.is_some(), "tick: api_rx");
436 match msg {
437 Some(msg) => {
438 self.handle_api_message(msg).await;
439 true
440 }
441 None => {
442 trace!("all api senders dropped, stop actor");
443 false
444 }
445 }
446 }
447 Some(res) = self.close_connections.join_next(), if !self.close_connections.is_empty() => {
448 let (node_id, connection) = res.expect("connection task panicked");
449 trace!(remote=%node_id.fmt_short(), "tick: connection closed");
450 if let Some(state) = self.remotes.get(&node_id) {
451 if state.same_connection(&connection) {
452 self.remotes.remove(&node_id);
453 }
454 }
455 true
456 }
457 Some(actor) = self.topic_tasks.join_next(), if !self.topic_tasks.is_empty() => {
458 let actor = actor.expect("topic actor task panicked");
459 trace!(topic=%actor.topic_id.fmt_short(), "tick: topic actor finished");
460 self.topics.remove(&actor.topic_id);
461 true
462 }
463 else => unreachable!("reached else arm, but all fallible cases should be handled"),
464 }
465 }
466
467 #[cfg(test)]
468 fn endpoint(&self) -> &Endpoint {
469 &self.endpoint
470 }
471
472 fn drain_pending_dials(
473 &mut self,
474 remote: &NodeId,
475 ) -> impl Iterator<Item = (TopicId, &TopicHandle)> {
476 self.pending_remotes_with_topics
477 .remove(remote)
478 .into_iter()
479 .flatten()
480 .flat_map(|topic_id| self.topics.get(&topic_id).map(|handle| (topic_id, handle)))
481 }
482
483 fn connect(&mut self, remote: NodeId, topic_id: TopicId) {
484 let Some(handle) = self.topics.get(&topic_id) else {
485 return;
486 };
487 if let Some(state) = self.remotes.get(&remote) {
488 let tx = handle.tx.clone();
489 let state = state.clone();
490 task::spawn(async move {
492 let msg = state.open_topic(topic_id).await;
493 tx.send(msg).await.ok();
494 });
495 } else {
496 self.dialer
497 .queue_dial(&self.endpoint, remote, self.alpn.clone());
498 self.pending_remotes_with_topics
499 .entry(remote)
500 .or_default()
501 .insert(topic_id);
502 }
503 }
504
505 #[instrument("connection", skip_all, fields(remote=%remote.fmt_short()))]
506 async fn handle_remote_connection(
507 &mut self,
508 remote: NodeId,
509 res: Result<Connection, ConnectError>,
510 direction: Direction,
511 ) {
512 match (res.as_ref(), direction) {
513 (Ok(_), Direction::Dial) => inc(&self.metrics.peers_dialed_success),
514 (Err(_), Direction::Dial) => inc(&self.metrics.peers_dialed_failure),
515 (Ok(_), Direction::Accept) => inc(&self.metrics.peers_accepted),
516 (Err(_), Direction::Accept) => {}
517 }
518 let connection = match res {
519 Err(err) => {
520 debug!(?err, "Connection failed");
521 for (_, handle) in self.drain_pending_dials(&remote) {
522 handle
523 .send(ActorToTopic::ConnectionFailed(remote))
524 .await
525 .ok();
526 }
527 return;
528 }
529 Ok(connection) => connection,
530 };
531
532 let state = RemoteState::new(remote, connection.clone(), direction);
533
534 for (topic_id, handle) in self.drain_pending_dials(&remote) {
536 let tx = handle.tx.clone();
537 let state = state.clone();
538 task::spawn(
539 async move {
540 let msg = state.open_topic(topic_id).await;
541 tx.send(msg).await.ok();
542 }
543 .instrument(tracing::Span::current()),
544 );
545 }
546
547 let counter = state.counter.clone();
549 self.accepting.push(Box::pin(
550 accept_stream::<net_proto::Request>(connection.clone())
551 .map(move |req| (remote, req.map(|r| counter.guard(r)))),
552 ));
553
554 let counter = state.counter.clone();
556 let fut = async move {
557 match direction {
558 Direction::Dial => {
559 counter.idle_for(Duration::from_millis(500)).await;
560 info!("close connection (from dial): unused");
561 connection.close(1u32.into(), b"idle");
562 }
563 Direction::Accept => {
564 let reason = connection.closed().await;
565 info!(?reason, "connection closed (from accept)")
566 }
567 }
568 (remote, connection)
569 };
570 self.close_connections
571 .spawn(fut.instrument(error_span!("conn", remote=%remote.fmt_short())));
572
573 self.remotes.insert(remote, state);
574 }
575
576 #[instrument("request", skip_all, fields(remote=%remote.fmt_short()))]
577 async fn handle_remote_message(&mut self, remote: NodeId, request: Guarded<GossipMessage>) {
578 let (request, guard) = request.split();
579 let (topic_id, request) = match request {
580 GossipMessage::Join(req) => (req.inner.topic_id, req),
581 };
582 if let Some(topic) = self.topics.get(&topic_id) {
583 if let Err(_err) = topic
584 .send(ActorToTopic::Connected {
585 remote,
586 tx: Guarded::new(request.tx, guard.clone()),
587 rx: Guarded::new(request.rx, guard.clone()),
588 })
589 .await
590 {
591 warn!(topic=%topic_id.fmt_short(), "Topic actor dead");
592 }
593 } else {
594 debug!(topic=%topic_id.fmt_short(), "ignore request: unknown topic");
595 }
596 }
597
598 async fn handle_api_message(&mut self, msg: api::RpcMessage) {
599 let (topic_id, msg) = match msg {
600 api::RpcMessage::Join(msg) => (msg.inner.topic_id, msg),
601 };
602 let topic = self.topics.entry(topic_id).or_insert_with(|| {
603 let (handle, actor) = TopicHandle::new(
604 self.me,
605 topic_id,
606 self.config.clone(),
607 self.local_tx.clone(),
608 self.our_peer_data.watch(),
609 self.metrics.clone(),
610 );
611 self.topic_tasks.spawn(
612 actor
613 .run()
614 .instrument(error_span!("topic", topic=%topic_id.fmt_short())),
615 );
616 handle
617 });
618 if topic.send(ActorToTopic::Api(msg)).await.is_err() {
619 warn!(topic=%topic_id.fmt_short(), "Topic actor dead");
620 }
621 }
622}
623
/// Per-remote connection state kept by the main [`Actor`].
#[derive(Clone)]
struct RemoteState {
    node_id: NodeId,
    /// `Connection::stable_id` of the underlying connection, used to tell connections apart.
    conn_id: usize,
    /// irpc client over the connection, used to open topic streams.
    client: irpc::Client<net_proto::Request>,
    // Stored at construction but never read in this file, hence the `allow`.
    // NOTE(review): confirm whether this field is still needed.
    #[allow(dead_code)]
    direction: Direction,
    /// Tracks users of the connection; dialed connections are closed once it reports idle.
    counter: ConnectionCounter,
}
633
634impl RemoteState {
635 fn new(node_id: NodeId, connection: Connection, direction: Direction) -> Self {
636 let conn_id = connection.stable_id();
637 let irpc_conn = IrohRemoteConnection::new(connection);
638 let client = irpc::Client::boxed(irpc_conn);
639 let counter = ConnectionCounter::new();
640 RemoteState {
641 client,
642 direction,
643 conn_id,
644 counter,
645 node_id,
646 }
647 }
648
649 fn same_connection(&self, conn: &Connection) -> bool {
650 self.conn_id == conn.stable_id()
651 }
652
653 async fn open_topic(&self, topic_id: TopicId) -> ActorToTopic {
654 let guard = self.counter.get_one();
655 let req = net_proto::JoinRequest { topic_id };
656 match self.client.bidi_streaming(req.clone(), 64, 64).await {
657 Ok((tx, rx)) => ActorToTopic::Connected {
658 remote: self.node_id,
659 tx: Guarded::new(tx, guard.clone()),
660 rx: Guarded::new(rx, guard),
661 },
662 Err(err) => {
663 warn!(?topic_id, ?err, "failed to open stream with remote");
664 ActorToTopic::ConnectionFailed(self.node_id)
665 }
666 }
667 }
668}
669
/// Which side initiated a connection.
#[derive(Debug, Copy, Clone)]
enum Direction {
    /// We dialed the remote.
    Dial,
    /// The remote connected to us.
    Accept,
}
675
/// The main actor's handle to a running [`TopicActor`].
struct TopicHandle {
    /// Channel into the topic actor.
    tx: mpsc::Sender<ActorToTopic>,
    /// Test-only flag, set to `true` when the topic reports a `NeighborUp` event.
    #[cfg(test)]
    joined: Arc<AtomicBool>,
}
681
682impl TopicHandle {
683 fn new(
684 me: NodeId,
685 topic_id: TopicId,
686 config: proto::Config,
687 to_actor_tx: mpsc::Sender<LocalActorMessage>,
688 peer_data: Direct<PeerData>,
689 metrics: Arc<Metrics>,
690 ) -> (Self, TopicActor) {
691 let (tx, rx) = mpsc::channel(16);
692 let state = State::new(me, None, config);
694 #[cfg(test)]
695 let joined = Arc::new(AtomicBool::new(false));
696 let (forward_event_tx, _) = broadcast::channel(512);
697 let actor = TopicActor {
698 topic_id,
699 state,
700 actor_rx: rx,
701 to_actor_tx,
702 peer_data,
703 forward_event_tx,
704 metrics,
705 init: false,
706 #[cfg(test)]
707 joined: joined.clone(),
708 timers: Default::default(),
709 neighbors: Default::default(),
710 out_events: Default::default(),
711 api_receivers: Default::default(),
712 remote_senders: Default::default(),
713 remote_receivers: Default::default(),
714 drop_peers_queue: Default::default(),
715 forward_event_tasks: Default::default(),
716 };
717 let handle = Self {
718 tx,
719 #[cfg(test)]
720 joined,
721 };
722 (handle, actor)
723 }
724
725 async fn send(&self, msg: ActorToTopic) -> Result<(), mpsc::error::SendError<ActorToTopic>> {
726 self.tx.send(msg).await
727 }
728
729 #[cfg(test)]
730 fn joined(&self) -> bool {
731 self.joined.load(std::sync::atomic::Ordering::Relaxed)
732 }
733}
734
/// Per-topic actor that drives the gossip protocol state machine for one topic.
struct TopicActor {
    topic_id: TopicId,
    /// Channel back to the main actor (connect requests, peer data).
    to_actor_tx: mpsc::Sender<LocalActorMessage>,
    /// Protocol state machine.
    state: State,
    /// Messages from the main actor.
    actor_rx: mpsc::Receiver<ActorToTopic>,
    /// Scheduled protocol timers.
    timers: Timers<Timer>,
    /// Current neighbors, replayed to new API subscribers.
    neighbors: BTreeSet<NodeId>,
    /// Our own peer data; updates are fed into the state machine.
    peer_data: Direct<PeerData>,
    /// Protocol output events still to be processed.
    out_events: VecDeque<OutEvent>,
    /// Command streams from all local API subscribers.
    api_receivers: MergeUnbounded<ApiRecvStream>,
    /// Outgoing message sinks per remote (buffered until a stream is connected).
    remote_senders: HashMap<NodeId, MaybeSender>,
    /// Incoming message streams from all remotes.
    remote_receivers: MergeUnbounded<RemoteRecvStream>,
    /// Broadcasts protocol events to all API subscribers.
    forward_event_tx: broadcast::Sender<ProtoEvent>,
    /// One task per API subscriber, forwarding events from `forward_event_tx`.
    forward_event_tasks: JoinSet<()>,
    #[cfg(test)]
    joined: Arc<AtomicBool>,
    /// Whether at least one API join request has been handled.
    init: bool,
    /// Remotes whose send failed; disconnected after the current loop iteration.
    drop_peers_queue: HashSet<NodeId>,
    metrics: Arc<Metrics>,
}
755
756impl TopicActor {
757 pub async fn run(mut self) -> Self {
758 self.metrics.topics_joined.inc();
759 let peer_data = self.peer_data.clone().stream();
760 tokio::pin!(peer_data);
761 loop {
762 tokio::select! {
763 Some(msg) = self.actor_rx.recv() => {
764 trace!("tick: actor_rx {msg}");
765 self.handle_actor_message(msg).await;
766 },
767 Some(cmd) = self.api_receivers.next(), if !self.api_receivers.is_empty() => {
768 self.handle_api_command(cmd).await;
769 }
770 Some((remote, message)) = self.remote_receivers.next(), if !self.remote_receivers.is_empty() => {
771 trace!(remote=%remote.fmt_short(), msg=?message, "tick: remote_rx");
772 self.handle_remote_message(remote, message).await;
773 }
774 Some(data) = peer_data.next() => {
775 self.handle_in_event(InEvent::UpdatePeerData(data)).await;
776 }
777 _ = self.timers.wait_next() => {
778 let now = Instant::now();
779 while let Some((_instant, timer)) = self.timers.pop_before(now) {
780 self.handle_in_event(InEvent::TimerExpired(timer)).await;
781 }
782 }
783 _ = self.forward_event_tasks.join_next(), if !self.forward_event_tasks.is_empty() => {}
784 else => break,
785 }
786
787 if !self.drop_peers_queue.is_empty() {
788 let now = Instant::now();
789 for peer in self.drop_peers_queue.drain() {
790 self.out_events
791 .extend(self.state.handle(InEvent::PeerDisconnected(peer), now));
792 }
793 self.process_out_events(now).await;
794 }
795
796 if self.to_actor_tx.is_closed() {
797 warn!("Channel to main actor closed: abort topic loop");
798 break;
799 }
800 if self.init && self.api_receivers.is_empty() && self.forward_event_tasks.is_empty() {
801 debug!("Closing topic: All API subscribers dropped");
802 break;
803 }
804 }
805 self.metrics.topics_quit.inc();
806 self
807 }
808
809 async fn handle_actor_message(&mut self, msg: ActorToTopic) {
810 match msg {
811 ActorToTopic::Connected { remote, rx, tx } => {
812 self.remote_receivers
813 .push(Box::pin(into_stream(rx).map(move |msg| (remote, msg))));
814 let sender = self.remote_senders.entry(remote).or_default();
815 if let Err(err) = sender.init(tx).await {
816 warn!("Remote failed while pushing queued messages: {err:?}");
817 }
818 }
819 ActorToTopic::Api(req) => {
820 self.init = true;
821 let WithChannels { inner, tx, rx, .. } = req;
822 let initial_neighbors = self.neighbors.clone().into_iter();
823 self.forward_event_tasks.spawn(
824 forward_events(tx, self.forward_event_tx.subscribe(), initial_neighbors)
825 .instrument(tracing::Span::current()),
826 );
827 self.api_receivers.push(Box::pin(into_stream2(rx)));
828 self.handle_in_event(InEvent::Command(Command::Join(
829 inner.bootstrap.into_iter().collect(),
830 )))
831 .await;
832 }
833 ActorToTopic::ConnectionFailed(node_id) => {
834 self.handle_in_event(InEvent::PeerDisconnected(node_id))
835 .await
836 }
837 }
838 }
839
840 async fn handle_remote_message(
841 &mut self,
842 remote: NodeId,
843 message: Result<Option<ProtoMessage>, RecvError>,
844 ) {
845 let event = match message {
846 Ok(Some(message)) => InEvent::RecvMessage(remote, message),
847 Ok(None) => {
848 debug!(remote=%remote.fmt_short(), "Recv stream from remote closed");
849 InEvent::PeerDisconnected(remote)
850 }
851 Err(err) => {
852 warn!(remote=%remote.fmt_short(), ?err, "Recv stream from remote failed");
853 InEvent::PeerDisconnected(remote)
854 }
855 };
856 self.handle_in_event(event).await;
857 }
858
859 async fn handle_api_command(&mut self, command: Result<api::Command, RecvError>) {
860 let Ok(command) = command else {
861 return;
862 };
863 trace!("tick: api command {command}");
864 self.handle_in_event(InEvent::Command(command.into())).await;
865 }
866
867 async fn handle_in_event(&mut self, event: InEvent) {
868 trace!("tick: in event {event:?}");
869 let now = Instant::now();
870 self.metrics.track_in_event(&event);
871 self.out_events.extend(self.state.handle(event, now));
872 self.process_out_events(now).await;
873 }
874
875 async fn process_out_events(&mut self, now: Instant) {
876 while let Some(event) = self.out_events.pop_front() {
877 trace!("tick: out event {event:?}");
878 self.metrics.track_out_event(&event);
879 match event {
880 OutEvent::SendMessage(node_id, message) => {
881 self.send(node_id, message).await;
882 }
883 OutEvent::EmitEvent(event) => {
884 self.handle_event(event);
885 }
886 OutEvent::ScheduleTimer(delay, timer) => {
887 self.timers.insert(now + delay, timer);
888 }
889 OutEvent::DisconnectPeer(node_id) => {
890 self.remote_senders.remove(&node_id);
891 }
892 OutEvent::PeerData(node_id, peer_data) => {
893 self.to_actor_tx
894 .send(LocalActorMessage::SetPeerData(node_id, peer_data))
895 .await
896 .ok();
897 }
898 }
899 }
900 }
901
902 #[instrument(skip_all, fields(remote=%remote.fmt_short()))]
903 async fn send(&mut self, remote: NodeId, message: ProtoMessage) {
904 let sender = match self.remote_senders.entry(remote) {
905 hash_map::Entry::Occupied(entry) => entry.into_mut(),
906 hash_map::Entry::Vacant(entry) => {
907 debug!("requesting new connection");
908 self.to_actor_tx
909 .send(LocalActorMessage::Connect(remote, self.topic_id))
910 .await
911 .ok();
912 entry.insert(Default::default())
913 }
914 };
915 if let Err(err) = sender.send(message).await {
916 warn!(?err, remote=%remote.fmt_short(), "failed to send message");
917 self.drop_peers_queue.insert(remote);
918 }
919 }
920
921 fn handle_event(&mut self, event: ProtoEvent) {
922 match &event {
923 ProtoEvent::NeighborUp(n) => {
924 #[cfg(test)]
925 self.joined
926 .store(true, std::sync::atomic::Ordering::Relaxed);
927 self.neighbors.insert(*n);
928 }
929 ProtoEvent::NeighborDown(n) => {
930 self.neighbors.remove(n);
931 }
932 ProtoEvent::Received(_) => {}
933 }
934 self.forward_event_tx.send(event).ok();
935 }
936}
937
938async fn forward_events(
939 tx: channel::mpsc::Sender<api::Event>,
940 mut sub: broadcast::Receiver<ProtoEvent>,
941 initial_neighbors: impl Iterator<Item = NodeId>,
942) {
943 for neighbor in initial_neighbors {
944 if let Err(_err) = tx.send(api::Event::NeighborUp(neighbor)).await {
945 break;
946 }
947 }
948 loop {
949 let event = tokio::select! {
950 biased;
951 event = sub.recv() => event,
952 _ = tx.closed() => break
953 };
954 let event: api::Event = match event {
955 Ok(event) => event.into(),
956 Err(broadcast::error::RecvError::Lagged(_)) => api::Event::Lagged,
957 Err(broadcast::error::RecvError::Closed) => break,
958 };
959 if let Err(_err) = tx.send(event).await {
960 break;
961 }
962 }
963}
964
/// Outgoing message sink for one remote within a topic.
#[derive(Debug)]
enum MaybeSender {
    /// A stream to the remote is open.
    Active(Guarded<channel::mpsc::Sender<ProtoMessage>>),
    /// No stream yet; messages are buffered until `init` is called.
    Pending(Vec<ProtoMessage>),
}
970
971impl MaybeSender {
972 async fn send(&mut self, message: ProtoMessage) -> Result<(), channel::SendError> {
973 match self {
974 Self::Active(sender) => sender.send(message).await,
975 Self::Pending(messages) => {
976 messages.push(message);
977 Ok(())
978 }
979 }
980 }
981
982 async fn init(
983 &mut self,
984 sender: Guarded<channel::mpsc::Sender<ProtoMessage>>,
985 ) -> Result<(), channel::SendError> {
986 debug!("Initializing new sender");
987 *self = match self {
988 Self::Active(_old) => {
989 debug!("Dropping old sender");
990 Self::Active(sender)
991 }
992 Self::Pending(queue) => {
993 debug!("Sending {} queued messages", queue.len());
994 for msg in queue.drain(..) {
995 sender.send(msg).await?;
996 }
997 Self::Active(sender)
998 }
999 };
1000 Ok(())
1001 }
1002}
1003
1004impl Default for MaybeSender {
1005 fn default() -> Self {
1006 Self::Pending(Vec::new())
1007 }
1008}
1009
1010fn into_stream<T: irpc::RpcMessage>(
1013 receiver: impl DerefMut<Target = channel::mpsc::Receiver<T>> + Send + Sync + 'static,
1014) -> impl Stream<Item = Result<Option<T>, RecvError>> + Send + Sync + 'static {
1015 n0_future::stream::unfold(Some(receiver), |recv| async move {
1016 let mut recv = recv?;
1017 let res = recv.recv().await;
1018 match res {
1019 Err(err) => Some((Err(err), None)),
1020 Ok(Some(res)) => Some((Ok(Some(res)), Some(recv))),
1021 Ok(None) => Some((Ok(None), None)),
1022 }
1023 })
1024}
1025
1026fn into_stream2<T: irpc::RpcMessage>(
1027 receiver: channel::mpsc::Receiver<T>,
1028) -> impl Stream<Item = Result<T, RecvError>> + Send + Sync + 'static {
1029 n0_future::stream::unfold(Some(receiver), |recv| async move {
1030 let mut recv = recv?;
1031 match recv.recv().await {
1032 Err(err) => Some((Err(err), None)),
1033 Ok(Some(res)) => Some((Ok(res), Some(recv))),
1034 Ok(None) => None,
1035 }
1036 })
1037}
1038
1039#[cfg(test)]
1040pub(crate) mod tests {
1041 use std::{future::Future, time::Duration};
1042
1043 use bytes::Bytes;
1044 use futures_concurrency::future::TryJoin;
1045 use iroh::{
1046 discovery::static_provider::StaticProvider, endpoint::BindError, protocol::Router,
1047 NodeAddr, RelayMap, RelayMode, SecretKey,
1048 };
1049 use n0_snafu::{Result, ResultExt};
1050 use rand::{CryptoRng, Rng, SeedableRng};
1051 use tokio::{spawn, time::timeout};
1052 use tokio_util::sync::CancellationToken;
1053 use tracing::info;
1054 use tracing_test::traced_test;
1055
1056 use super::*;
1057 use crate::{
1058 api::{ApiError, Event, GossipReceiver, GossipSender},
1059 ALPN,
1060 };
1061
    impl Gossip {
        /// Test helper: like [`Gossip::t_new_with_actor`], but spawns the actor on a task.
        ///
        /// Returns four things:
        /// - the gossip handle;
        /// - its endpoint;
        /// - a future that completes after the router has been shut down once `cancel` fires;
        /// - a guard that aborts the actor task when dropped.
        pub(super) async fn t_new(
            rng: &mut rand_chacha::ChaCha12Rng,
            config: proto::Config,
            relay_map: RelayMap,
            cancel: &CancellationToken,
        ) -> n0_snafu::Result<(Self, Endpoint, impl Future<Output = ()>, impl Drop)> {
            let (gossip, actor, ep_handle) =
                Gossip::t_new_with_actor(rng, config, relay_map, cancel).await?;
            let ep = actor.endpoint().clone();
            let me = ep.node_id().fmt_short();
            let actor_handle =
                task::spawn(actor.run().instrument(tracing::error_span!("gossip", %me)));
            Ok((gossip, ep, ep_handle, AbortOnDropHandle::new(actor_handle)))
        }

        /// Test helper: binds a fresh endpoint using `relay_map` and creates a gossip instance
        /// whose actor is returned to the caller instead of being spawned.
        ///
        /// The endpoint is awaited until `online()` resolves. The actor's address-update stream
        /// is replaced with a never-yielding pending stream. A router accepting `GOSSIP_ALPN`
        /// runs until `cancel` fires; the returned future completes after it has shut down.
        pub(super) async fn t_new_with_actor(
            rng: &mut rand_chacha::ChaCha12Rng,
            config: proto::Config,
            relay_map: RelayMap,
            cancel: &CancellationToken,
        ) -> n0_snafu::Result<(Self, Actor, impl Future<Output = ()>)> {
            let endpoint = Endpoint::builder()
                .secret_key(SecretKey::generate(rng))
                .relay_mode(RelayMode::Custom(relay_map))
                .insecure_skip_relay_cert_verify(true)
                .bind()
                .await?;

            endpoint.online().await;
            let (gossip, mut actor) = Gossip::new_with_actor(endpoint.clone(), config, None);
            actor.node_addr_updates = Box::pin(n0_future::stream::pending());
            let router = Router::builder(endpoint)
                .accept(GOSSIP_ALPN, gossip.clone())
                .spawn();
            let cancel = cancel.clone();
            // Keeps the router alive until cancellation, then shuts it down.
            let router_task = tokio::task::spawn(async move {
                cancel.cancelled().await;
                router.shutdown().await.ok();
                drop(router);
            });
            let router_fut = async move {
                router_task.await.expect("router task panicked");
            };
            Ok((gossip, actor, router_fut))
        }
    }
1108
1109 pub(crate) async fn create_endpoint(
1110 rng: &mut rand_chacha::ChaCha12Rng,
1111 relay_map: RelayMap,
1112 static_provider: Option<StaticProvider>,
1113 ) -> Result<Endpoint, BindError> {
1114 let ep = Endpoint::builder()
1115 .secret_key(SecretKey::generate(rng))
1116 .alpns(vec![ALPN.to_vec()])
1117 .relay_mode(RelayMode::Custom(relay_map))
1118 .insecure_skip_relay_cert_verify(true)
1119 .bind()
1120 .await?;
1121
1122 if let Some(static_provider) = static_provider {
1123 ep.discovery().add(static_provider);
1124 }
1125 ep.online().await;
1126 Ok(ep)
1127 }
1128
1129 async fn endpoint_loop(
1130 endpoint: Endpoint,
1131 gossip: Gossip,
1132 cancel: CancellationToken,
1133 ) -> Result<()> {
1134 loop {
1135 tokio::select! {
1136 biased;
1137 _ = cancel.cancelled() => break,
1138 incoming = endpoint.accept() => match incoming {
1139 None => break,
1140 Some(incoming) => {
1141 let connecting = match incoming.accept() {
1142 Ok(connecting) => connecting,
1143 Err(err) => {
1144 warn!("incoming connection failed: {err:#}");
1145 continue;
1148 }
1149 };
1150 let connection = connecting.await.e()?;
1151 let remote_node_id = connection.remote_node_id()?;
1152 gossip.handle_connection(remote_node_id, connection).await?
1153 }
1154 }
1155 }
1156 }
1157 Ok(())
1158 }
1159
1160 #[tokio::test]
1161 #[traced_test]
1162 async fn gossip_net_smoke() {
1163 let mut rng = rand_chacha::ChaCha12Rng::seed_from_u64(1);
1164 let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
1165
1166 let static_provider = StaticProvider::new();
1167
1168 let ep1 = create_endpoint(&mut rng, relay_map.clone(), Some(static_provider.clone()))
1169 .await
1170 .unwrap();
1171 let ep2 = create_endpoint(&mut rng, relay_map.clone(), Some(static_provider.clone()))
1172 .await
1173 .unwrap();
1174 let ep3 = create_endpoint(&mut rng, relay_map.clone(), Some(static_provider.clone()))
1175 .await
1176 .unwrap();
1177
1178 let go1 = Gossip::builder().spawn(ep1.clone());
1179 let go2 = Gossip::builder().spawn(ep2.clone());
1180 let go3 = Gossip::builder().spawn(ep3.clone());
1181 debug!("peer1 {:?}", ep1.node_id());
1182 debug!("peer2 {:?}", ep2.node_id());
1183 debug!("peer3 {:?}", ep3.node_id());
1184 let pi1 = ep1.node_id();
1185 let pi2 = ep2.node_id();
1186
1187 let cancel = CancellationToken::new();
1188 let tasks = [
1189 spawn(endpoint_loop(ep1.clone(), go1.clone(), cancel.clone())),
1190 spawn(endpoint_loop(ep2.clone(), go2.clone(), cancel.clone())),
1191 spawn(endpoint_loop(ep3.clone(), go3.clone(), cancel.clone())),
1192 ];
1193
1194 debug!("----- adding peers ----- ");
1195 let topic: TopicId = blake3::hash(b"foobar").into();
1196
1197 let addr1 = NodeAddr::new(pi1).with_relay_url(relay_url.clone());
1198 let addr2 = NodeAddr::new(pi2).with_relay_url(relay_url);
1199 static_provider.add_node_info(addr1.clone());
1200 static_provider.add_node_info(addr2.clone());
1201
1202 debug!("----- joining ----- ");
1203 let [sub1, mut sub2, mut sub3] = [
1205 go1.subscribe_and_join(topic, vec![]),
1206 go2.subscribe_and_join(topic, vec![pi1]),
1207 go3.subscribe_and_join(topic, vec![pi2]),
1208 ]
1209 .try_join()
1210 .await
1211 .unwrap();
1212
1213 let (sink1, _stream1) = sub1.split();
1214
1215 let len = 2;
1216
1217 let pub1 = spawn(async move {
1219 for i in 0..len {
1220 let message = format!("hi{i}");
1221 info!("go1 broadcast: {message:?}");
1222 sink1.broadcast(message.into_bytes().into()).await.unwrap();
1223 tokio::time::sleep(Duration::from_micros(1)).await;
1224 }
1225 });
1226
1227 let sub2 = spawn(async move {
1229 let mut recv = vec![];
1230 loop {
1231 let ev = sub2.next().await.unwrap().unwrap();
1232 info!("go2 event: {ev:?}");
1233 if let Event::Received(msg) = ev {
1234 recv.push(msg.content);
1235 }
1236 if recv.len() == len {
1237 return recv;
1238 }
1239 }
1240 });
1241
1242 let sub3 = spawn(async move {
1244 let mut recv = vec![];
1245 loop {
1246 let ev = sub3.next().await.unwrap().unwrap();
1247 info!("go3 event: {ev:?}");
1248 if let Event::Received(msg) = ev {
1249 recv.push(msg.content);
1250 }
1251 if recv.len() == len {
1252 return recv;
1253 }
1254 }
1255 });
1256
1257 timeout(Duration::from_secs(10), pub1)
1258 .await
1259 .unwrap()
1260 .unwrap();
1261 let recv2 = timeout(Duration::from_secs(10), sub2)
1262 .await
1263 .unwrap()
1264 .unwrap();
1265 let recv3 = timeout(Duration::from_secs(10), sub3)
1266 .await
1267 .unwrap()
1268 .unwrap();
1269
1270 let expected: Vec<Bytes> = (0..len)
1271 .map(|i| Bytes::from(format!("hi{i}").into_bytes()))
1272 .collect();
1273 assert_eq!(recv2, expected);
1274 assert_eq!(recv3, expected);
1275
1276 cancel.cancel();
1277 for t in tasks {
1278 timeout(Duration::from_secs(10), t)
1279 .await
1280 .unwrap()
1281 .unwrap()
1282 .unwrap();
1283 }
1284 }
1285
1286 #[tokio::test]
1296 #[traced_test]
1297 async fn subscription_cleanup() -> Result {
1298 let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
1299 let ct = CancellationToken::new();
1300 let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
1301
1302 let (go1, mut actor, ep1_handle) =
1304 Gossip::t_new_with_actor(rng, Default::default(), relay_map.clone(), &ct).await?;
1305
1306 let (go2, ep2, ep2_handle, _test_actor_handle) =
1308 Gossip::t_new(rng, Default::default(), relay_map, &ct).await?;
1309
1310 let node_id1 = actor.endpoint().node_id();
1311 let node_id2 = ep2.node_id();
1312 tracing::info!(
1313 node_1 = %node_id1.fmt_short(),
1314 node_2 = %node_id2.fmt_short(),
1315 "nodes ready"
1316 );
1317
1318 let topic: TopicId = blake3::hash(b"subscription_cleanup").into();
1319 tracing::info!(%topic, "joining");
1320
1321 let ct2 = ct.clone();
1328 let go2_task = async move {
1329 let (_pub_tx, mut sub_rx) = go2.subscribe_and_join(topic, vec![]).await?.split();
1330
1331 let subscribe_fut = async {
1332 while let Some(ev) = sub_rx.try_next().await? {
1333 match ev {
1334 Event::Lagged => tracing::debug!("missed some messages :("),
1335 Event::Received(_) => unreachable!("test does not send messages"),
1336 other => tracing::debug!(?other, "gs event"),
1337 }
1338 }
1339
1340 tracing::debug!("subscribe stream ended");
1341 Ok::<_, n0_snafu::Error>(())
1342 };
1343
1344 tokio::select! {
1345 _ = ct2.cancelled() => Ok(()),
1346 res = subscribe_fut => res,
1347 }
1348 }
1349 .instrument(tracing::debug_span!("node_2", node_id2=%node_id2.fmt_short()));
1350 let go2_handle = task::spawn(go2_task);
1351
1352 let addr2 = NodeAddr::new(node_id2).with_relay_url(relay_url);
1354 let static_provider = StaticProvider::new();
1355 static_provider.add_node_info(addr2);
1356 actor.endpoint().discovery().add(static_provider);
1357 let (tx, mut rx) = mpsc::channel::<()>(1);
1359 let ct1 = ct.clone();
1360 let go1_task = async move {
1361 tracing::info!("subscribing the first time");
1363 let sub_1a = go1.subscribe_and_join(topic, vec![node_id2]).await?;
1364
1365 rx.recv().await.expect("signal for second subscribe");
1367 tracing::info!("subscribing a second time");
1368 let sub_1b = go1.subscribe_and_join(topic, vec![node_id2]).await?;
1369 drop(sub_1a);
1370
1371 rx.recv().await.expect("signal for second subscribe");
1373 tracing::info!("dropping all handles");
1374 drop(sub_1b);
1375
1376 ct1.cancelled().await;
1378 drop(go1);
1379
1380 Ok::<_, n0_snafu::Error>(())
1381 }
1382 .instrument(tracing::debug_span!("node_1", node_id1 = %node_id1.fmt_short()));
1383 let go1_handle = task::spawn(go1_task);
1384
1385 actor.steps(4).await?; tracing::info!("subscribe and join done, should be joined");
1391 let state = actor.topics.get(&topic).expect("get registered topic");
1392 assert!(state.joined());
1393
1394 tx.send(()).await.e()?;
1396 actor.steps(1).await?; let state = actor.topics.get(&topic).expect("get registered topic");
1398 assert!(state.joined());
1399
1400 tx.send(()).await.e()?;
1402 actor.steps(1).await?; assert!(!actor.topics.contains_key(&topic));
1405
1406 ct.cancel();
1408 let wait = Duration::from_secs(5);
1409 timeout(wait, ep1_handle).await.e()?;
1410 timeout(wait, ep2_handle).await.e()?;
1411 timeout(wait, go1_handle).await.e()?.e()??;
1412 timeout(wait, go2_handle).await.e()?.e()??;
1413 timeout(wait, actor.finish()).await.e()?;
1414
1415 Ok(())
1416 }
1417
1418 #[tokio::test(flavor = "multi_thread")]
1425 #[traced_test]
1426 async fn can_reconnect() -> Result {
1427 let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
1428 let ct = CancellationToken::new();
1429 let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
1430
1431 let (go1, ep1, ep1_handle, _test_actor_handle1) =
1432 Gossip::t_new(rng, Default::default(), relay_map.clone(), &ct).await?;
1433
1434 let (go2, ep2, ep2_handle, _test_actor_handle2) =
1435 Gossip::t_new(rng, Default::default(), relay_map, &ct).await?;
1436
1437 let node_id1 = ep1.node_id();
1438 let node_id2 = ep2.node_id();
1439 tracing::info!(
1440 node_1 = %node_id1.fmt_short(),
1441 node_2 = %node_id2.fmt_short(),
1442 "nodes ready"
1443 );
1444
1445 let topic: TopicId = blake3::hash(b"can_reconnect").into();
1446 tracing::info!(%topic, "joining");
1447
1448 let (tx, mut rx) = mpsc::channel::<()>(1);
1450 let addr1 = NodeAddr::new(node_id1).with_relay_url(relay_url.clone());
1451 let static_provider = StaticProvider::new();
1452 static_provider.add_node_info(addr1);
1453 ep2.discovery().add(static_provider.clone());
1454 let go2_task = async move {
1455 let mut sub = go2.subscribe(topic, Vec::new()).await?;
1456 sub.joined().await?;
1457
1458 rx.recv().await.expect("signal to unsubscribe");
1459 tracing::info!("unsubscribing");
1460 drop(sub);
1461
1462 rx.recv().await.expect("signal to subscribe again");
1463 tracing::info!("resubscribing");
1464 let mut sub = go2.subscribe(topic, vec![node_id1]).await?;
1465
1466 sub.joined().await?;
1467
1468 Result::<_, ApiError>::Ok(())
1469 }
1470 .instrument(tracing::debug_span!("node_2", node_id2=%node_id2.fmt_short()));
1471 let go2_handle = task::spawn(go2_task);
1472
1473 let addr2 = NodeAddr::new(node_id2).with_relay_url(relay_url);
1474 static_provider.add_node_info(addr2);
1475 ep1.discovery().add(static_provider);
1476
1477 let mut sub = go1.subscribe(topic, vec![node_id2]).await?;
1478 sub.joined().await?;
1480 info!("go1 joined");
1481
1482 tx.send(()).await.e()?;
1484
1485 info!("wait for neighbor down");
1486 let conn_timeout = Duration::from_millis(1000);
1488 let ev = timeout(conn_timeout, sub.try_next()).await.e()??;
1489 assert_eq!(ev, Some(Event::NeighborDown(node_id2)));
1490 tracing::info!("node 2 left");
1491
1492 tx.send(()).await.e()?;
1494
1495 info!("wait for neighbor up");
1496 let conn_timeout = Duration::from_millis(1000);
1497 let ev = timeout(conn_timeout, sub.try_next()).await.e()??;
1498 assert_eq!(ev, Some(Event::NeighborUp(node_id2)));
1499 tracing::info!("node 2 rejoined!");
1500
1501 let wait = Duration::from_secs(2);
1503 timeout(wait, go2_handle).await.e()?.e()??;
1504 ct.cancel();
1505 timeout(wait, ep1_handle).await.e()?;
1507 timeout(wait, ep2_handle).await.e()?;
1508
1509 Ok(())
1510 }
1511
1512 #[tokio::test]
1513 #[traced_test]
1514 async fn can_die_and_reconnect() -> Result {
1515 fn run_in_thread<T: Send + 'static>(
1518 cancel: CancellationToken,
1519 fut: impl std::future::Future<Output = T> + Send + 'static,
1520 ) -> std::thread::JoinHandle<Option<T>> {
1521 std::thread::spawn(move || {
1522 let rt = tokio::runtime::Builder::new_current_thread()
1523 .enable_all()
1524 .build()
1525 .unwrap();
1526 rt.block_on(async move { cancel.run_until_cancelled(fut).await })
1527 })
1528 }
1529
1530 async fn spawn_gossip(
1532 secret_key: SecretKey,
1533 relay_map: RelayMap,
1534 ) -> Result<(Router, Gossip), BindError> {
1535 let ep = Endpoint::builder()
1536 .secret_key(secret_key)
1537 .relay_mode(RelayMode::Custom(relay_map))
1538 .insecure_skip_relay_cert_verify(true)
1539 .bind()
1540 .await?;
1541 let gossip = Gossip::builder().spawn(ep.clone());
1542 let router = Router::builder(ep).accept(ALPN, gossip.clone()).spawn();
1543 Ok((router, gossip))
1544 }
1545
1546 async fn broadcast_once(
1548 secret_key: SecretKey,
1549 relay_map: RelayMap,
1550 bootstrap_addr: NodeAddr,
1551 topic_id: TopicId,
1552 message: String,
1553 ) -> Result {
1554 let (router, gossip) = spawn_gossip(secret_key, relay_map).await?;
1555 info!(node_id = %router.endpoint().node_id().fmt_short(), "broadcast node spawned");
1556 let bootstrap = vec![bootstrap_addr.node_id];
1557 let static_provider = StaticProvider::new();
1558 static_provider.add_node_info(bootstrap_addr);
1559 router.endpoint().discovery().add(static_provider);
1560 let mut topic = gossip.subscribe_and_join(topic_id, bootstrap).await?;
1561 topic.broadcast(message.as_bytes().to_vec().into()).await?;
1562 std::future::pending::<()>().await;
1563 Ok(())
1564 }
1565
1566 let (relay_map, _relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();
1567 let mut rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
1568 let topic_id = TopicId::from_bytes(rng.random());
1569
1570 let (addr_tx, addr_rx) = tokio::sync::oneshot::channel();
1573 let (msgs_recv_tx, mut msgs_recv_rx) = tokio::sync::mpsc::channel(3);
1574 let recv_task = tokio::task::spawn({
1575 let relay_map = relay_map.clone();
1576 let secret_key = SecretKey::generate(&mut rng);
1577 async move {
1578 let (router, gossip) = spawn_gossip(secret_key, relay_map).await?;
1579 router.endpoint().online().await;
1584 let addr = router.endpoint().node_addr();
1585 info!(node_id = %addr.node_id.fmt_short(), "recv node spawned");
1586 addr_tx.send(addr).unwrap();
1587 let mut topic = gossip.subscribe_and_join(topic_id, vec![]).await?;
1588 while let Some(event) = topic.try_next().await.unwrap() {
1589 if let Event::Received(message) = event {
1590 let message = std::str::from_utf8(&message.content).e()?.to_string();
1591 msgs_recv_tx.send(message).await.e()?;
1592 }
1593 }
1594 Result::<_, n0_snafu::Error>::Ok(())
1595 }
1596 });
1597
1598 let node0_addr = addr_rx.await.e()?;
1599 let max_wait = Duration::from_secs(5);
1600
1601 let cancel = CancellationToken::new();
1604 let secret = SecretKey::generate(&mut rng);
1605 let join_handle_1 = run_in_thread(
1606 cancel.clone(),
1607 broadcast_once(
1608 secret.clone(),
1609 relay_map.clone(),
1610 node0_addr.clone(),
1611 topic_id,
1612 "msg1".to_string(),
1613 ),
1614 );
1615 let msg = timeout(max_wait, msgs_recv_rx.recv()).await.e()?.unwrap();
1617 assert_eq!(&msg, "msg1");
1618 info!("kill broadcast node");
1619 cancel.cancel();
1620
1621 let cancel = CancellationToken::new();
1623 let join_handle_2 = run_in_thread(
1624 cancel.clone(),
1625 broadcast_once(
1626 secret.clone(),
1627 relay_map.clone(),
1628 node0_addr.clone(),
1629 topic_id,
1630 "msg2".to_string(),
1631 ),
1632 );
1633 let msg = timeout(max_wait, msgs_recv_rx.recv()).await.e()?.unwrap();
1636 assert_eq!(&msg, "msg2");
1637 info!("kill broadcast node");
1638 cancel.cancel();
1639
1640 info!("kill recv node");
1641 recv_task.abort();
1642 assert!(join_handle_1.join().unwrap().is_none());
1643 assert!(join_handle_2.join().unwrap().is_none());
1644
1645 Ok(())
1646 }
1647
1648 #[tokio::test]
1649 #[traced_test]
1650 async fn gossip_change_alpn() -> n0_snafu::Result<()> {
1651 let alpn = b"my-gossip-alpn";
1652 let topic_id = TopicId::from([0u8; 32]);
1653
1654 let ep1 = Endpoint::builder().bind().await?;
1655 let ep2 = Endpoint::builder().bind().await?;
1656 let gossip1 = Gossip::builder().alpn(alpn).spawn(ep1.clone());
1657 let gossip2 = Gossip::builder().alpn(alpn).spawn(ep2.clone());
1658 let router1 = Router::builder(ep1).accept(alpn, gossip1.clone()).spawn();
1659 let router2 = Router::builder(ep2).accept(alpn, gossip2.clone()).spawn();
1660
1661 let addr1 = router1.endpoint().node_addr();
1662 let id1 = addr1.node_id;
1663 let static_provider = StaticProvider::new();
1664 static_provider.add_node_info(addr1);
1665 router2.endpoint().discovery().add(static_provider);
1666
1667 let mut topic1 = gossip1.subscribe(topic_id, vec![]).await?;
1668 let mut topic2 = gossip2.subscribe(topic_id, vec![id1]).await?;
1669
1670 timeout(Duration::from_secs(3), topic1.joined())
1671 .await
1672 .e()??;
1673 timeout(Duration::from_secs(3), topic2.joined())
1674 .await
1675 .e()??;
1676 router1.shutdown().await.e()?;
1677 router2.shutdown().await.e()?;
1678 Ok(())
1679 }
1680
1681 #[tokio::test]
1682 #[traced_test]
1683 async fn gossip_rely_on_gossip_discovery() -> n0_snafu::Result<()> {
1684 let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1);
1685
1686 async fn spawn(
1687 rng: &mut impl CryptoRng,
1688 ) -> n0_snafu::Result<(NodeId, Router, Gossip, GossipSender, GossipReceiver)> {
1689 let topic_id = TopicId::from([0u8; 32]);
1690 let ep = Endpoint::builder()
1691 .secret_key(SecretKey::generate(rng))
1692 .relay_mode(RelayMode::Disabled)
1693 .bind()
1694 .await?;
1695 let node_id = ep.node_id();
1696 let gossip = Gossip::builder().spawn(ep.clone());
1697 let router = Router::builder(ep)
1698 .accept(GOSSIP_ALPN, gossip.clone())
1699 .spawn();
1700 let topic = gossip.subscribe(topic_id, vec![]).await?;
1701 let (sender, receiver) = topic.split();
1702 Ok((node_id, router, gossip, sender, receiver))
1703 }
1704
1705 let (n1, r1, _g1, _tx1, mut rx1) = spawn(rng).await?;
1707 let (n2, r2, _g2, tx2, mut rx2) = spawn(rng).await?;
1708 let (n3, r3, _g3, tx3, mut rx3) = spawn(rng).await?;
1709
1710 println!("nodes {:?}", [n1, n2, n3]);
1711
1712 let addr1 = r1.endpoint().node_addr();
1714 let disco = StaticProvider::new();
1715 disco.add_node_info(addr1);
1716
1717 r2.endpoint().discovery().add(disco.clone());
1719 tx2.join_peers(vec![n1]).await?;
1720
1721 timeout(Duration::from_secs(3), rx1.joined()).await.e()??;
1723 timeout(Duration::from_secs(3), rx2.joined()).await.e()??;
1724
1725 r3.endpoint().discovery().add(disco.clone());
1727 tx3.join_peers(vec![n1]).await?;
1728
1729 let ev = timeout(Duration::from_secs(3), rx3.next()).await.e()?;
1732 assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
1733 let ev = timeout(Duration::from_secs(3), rx3.next()).await.e()?;
1734 assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
1735
1736 assert_eq!(sorted(rx3.neighbors()), sorted([n1, n2]));
1737
1738 let ev = timeout(Duration::from_secs(3), rx2.next()).await.e()?;
1739 assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
1740
1741 let ev = timeout(Duration::from_secs(3), rx1.next()).await.e()?;
1742 assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
1743
1744 tokio::try_join!(r1.shutdown(), r2.shutdown(), r3.shutdown()).e()?;
1745 Ok(())
1746 }
1747
/// Collects `input` into a `Vec` in ascending order (stable sort, so equal elements keep their
/// original relative order).
fn sorted<T: Ord>(input: impl IntoIterator<Item = T>) -> Vec<T> {
    let mut items = Vec::from_iter(input);
    items.sort();
    items
}
1753}