1use std::{
6 collections::{BTreeSet, HashSet},
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use bytes::Bytes;
12use iroh_base::EndpointId;
13use irpc::{channel::mpsc, rpc_requests, Client};
14use n0_error::{e, stack_error};
15use n0_future::{Stream, StreamExt, TryStreamExt};
16use serde::{Deserialize, Serialize};
17
18use crate::proto::{DeliveryScope, TopicId};
19
/// Default capacity of a subscription's event channel; used by
/// [`JoinOptions::with_bootstrap`] for [`JoinOptions::subscription_capacity`].
const TOPIC_EVENTS_DEFAULT_CAP: usize = 2_048;
/// Capacity passed for the command channel when [`GossipApi::subscribe_with_opts`] opens
/// its bidirectional stream.
const TOPIC_COMMANDS_CAP: usize = 64;
24
/// RPC protocol of the gossip API.
///
/// The `rpc_requests` attribute generates the `RpcMessage` type that [`GossipApi`] uses for
/// its in-process channel. `rpc_feature = "rpc"` presumably gates the remote-transport parts
/// of the generated code behind the `rpc` cargo feature (NOTE(review): confirm in irpc docs).
#[rpc_requests(message = RpcMessage, rpc_feature = "rpc")]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) enum Request {
    // Join a topic. This opens a bidirectional stream: the server sends `Event`s back to the
    // client (`tx`) and receives the client's `Command`s (`rx`).
    #[rpc(tx=mpsc::Sender<Event>, rx=mpsc::Receiver<Command>)]
    Join(JoinRequest),
}
32
/// Payload of a [`Request::Join`] request.
///
/// NOTE: this type is serialized for the RPC transport; keep the field order stable.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct JoinRequest {
    /// Topic to join.
    pub topic_id: TopicId,
    /// Peers to bootstrap the topic from (taken from [`JoinOptions::bootstrap`]).
    pub bootstrap: BTreeSet<EndpointId>,
}
38
/// Errors returned by the gossip API.
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum ApiError {
    /// Error from the irpc layer. irpc channel send/receive errors are also converted into
    /// this variant by way of [`irpc::Error`].
    #[error(transparent)]
    Rpc { source: irpc::Error },
    /// The topic's event stream ended, e.g. while [`GossipReceiver::joined`] was waiting for
    /// a neighbor.
    #[error("topic closed")]
    Closed,
}
49
50impl From<irpc::channel::SendError> for ApiError {
51 fn from(value: irpc::channel::SendError) -> Self {
52 irpc::Error::from(value).into()
53 }
54}
55
56impl From<irpc::channel::mpsc::RecvError> for ApiError {
57 fn from(value: irpc::channel::mpsc::RecvError) -> Self {
58 irpc::Error::from(value).into()
59 }
60}
61
62impl From<irpc::channel::oneshot::RecvError> for ApiError {
63 fn from(value: irpc::channel::oneshot::RecvError) -> Self {
64 irpc::Error::from(value).into()
65 }
66}
67
/// Handle to the gossip API.
///
/// Wraps an irpc client for the gossip RPC protocol. The client is either backed by an
/// in-process channel (`GossipApi::local`) or by a remote quinn connection
/// (`GossipApi::connect`).
#[derive(Debug, Clone)]
pub struct GossipApi {
    client: Client<Request>,
}
85
86impl GossipApi {
87 #[cfg(feature = "net")]
88 pub(crate) fn local(tx: tokio::sync::mpsc::Sender<RpcMessage>) -> Self {
89 let local = irpc::LocalSender::<Request>::from(tx);
90 Self {
91 client: local.into(),
92 }
93 }
94
95 #[cfg(feature = "rpc")]
97 pub fn connect(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
98 let inner = irpc::Client::quinn(endpoint, addr);
99 Self { client: inner }
100 }
101
102 #[cfg(all(feature = "rpc", feature = "net"))]
104 pub(crate) async fn listen(&self, endpoint: quinn::Endpoint) {
105 use irpc::rpc::{listen, RemoteService};
106
107 let local = self
108 .client
109 .as_local()
110 .expect("cannot listen on remote client");
111 let handler = Request::remote_handler(local);
112
113 listen::<Request>(endpoint, handler).await
114 }
115
116 pub async fn subscribe_with_opts(
124 &self,
125 topic_id: TopicId,
126 opts: JoinOptions,
127 ) -> Result<GossipTopic, ApiError> {
128 let req = JoinRequest {
129 topic_id,
130 bootstrap: opts.bootstrap,
131 };
132 let (tx, rx) = self
133 .client
134 .bidi_streaming(req, TOPIC_COMMANDS_CAP, opts.subscription_capacity)
135 .await?;
136 Ok(GossipTopic::new(tx, rx))
137 }
138
139 pub async fn subscribe_and_join(
141 &self,
142 topic_id: TopicId,
143 bootstrap: Vec<EndpointId>,
144 ) -> Result<GossipTopic, ApiError> {
145 let mut sub = self
146 .subscribe_with_opts(topic_id, JoinOptions::with_bootstrap(bootstrap))
147 .await?;
148 sub.joined().await?;
149 Ok(sub)
150 }
151
152 pub async fn subscribe(
158 &self,
159 topic_id: TopicId,
160 bootstrap: Vec<EndpointId>,
161 ) -> Result<GossipTopic, ApiError> {
162 let sub = self
163 .subscribe_with_opts(topic_id, JoinOptions::with_bootstrap(bootstrap))
164 .await?;
165
166 Ok(sub)
167 }
168}
169
/// Sending half of a topic subscription.
///
/// Wraps the channel over which [`Command`]s for one topic are sent.
#[derive(Debug, Clone)]
pub struct GossipSender(mpsc::Sender<Command>);
173
174impl GossipSender {
175 pub(crate) fn new(sender: mpsc::Sender<Command>) -> Self {
176 Self(sender)
177 }
178
179 pub async fn broadcast(&self, message: Bytes) -> Result<(), ApiError> {
181 self.send(Command::Broadcast(message)).await?;
182 Ok(())
183 }
184
185 pub async fn broadcast_neighbors(&self, message: Bytes) -> Result<(), ApiError> {
187 self.send(Command::BroadcastNeighbors(message)).await?;
188 Ok(())
189 }
190
191 pub async fn join_peers(&self, peers: Vec<EndpointId>) -> Result<(), ApiError> {
193 self.send(Command::JoinPeers(peers)).await?;
194 Ok(())
195 }
196
197 async fn send(&self, command: Command) -> Result<(), irpc::channel::SendError> {
198 self.0.send(command).await?;
199 Ok(())
200 }
201}
202
/// A subscription to a gossip topic, combining a [`GossipSender`] and a [`GossipReceiver`].
///
/// Implements [`Stream`] over the topic's [`Event`]s. Use [`GossipTopic::split`] to get the
/// two halves as separate values.
#[derive(Debug)]
pub struct GossipTopic {
    sender: GossipSender,
    receiver: GossipReceiver,
}
216
217impl GossipTopic {
218 pub(crate) fn new(sender: mpsc::Sender<Command>, receiver: mpsc::Receiver<Event>) -> Self {
219 let sender = GossipSender::new(sender);
220 Self {
221 sender,
222 receiver: GossipReceiver::new(receiver),
223 }
224 }
225
226 pub fn split(self) -> (GossipSender, GossipReceiver) {
228 (self.sender, self.receiver)
229 }
230
231 pub async fn broadcast(&mut self, message: Bytes) -> Result<(), ApiError> {
233 self.sender.broadcast(message).await
234 }
235
236 pub async fn broadcast_neighbors(&mut self, message: Bytes) -> Result<(), ApiError> {
238 self.sender.broadcast_neighbors(message).await
239 }
240
241 pub async fn joined(&mut self) -> Result<(), ApiError> {
245 self.receiver.joined().await
246 }
247
248 pub fn is_joined(&self) -> bool {
250 self.receiver.is_joined()
251 }
252}
253
254impl Stream for GossipTopic {
255 type Item = Result<Event, ApiError>;
256
257 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
258 Pin::new(&mut self.receiver).poll_next(cx)
259 }
260}
261
/// Receiving half of a topic subscription.
///
/// Implements [`Stream`] over the topic's [`Event`]s. While it is polled, it tracks the
/// current neighbor set from `NeighborUp` / `NeighborDown` events.
#[derive(derive_more::Debug)]
pub struct GossipReceiver {
    /// Events from the subscription channel, with receive errors mapped to [`ApiError`].
    #[debug("BoxStream")]
    stream: Pin<Box<dyn Stream<Item = Result<Event, ApiError>> + Send + Sync + 'static>>,
    /// Endpoints currently reported as neighbors by the events seen so far.
    neighbors: HashSet<EndpointId>,
}
271
272impl GossipReceiver {
273 pub(crate) fn new(events_rx: mpsc::Receiver<Event>) -> Self {
274 let stream = events_rx.into_stream().map_err(ApiError::from);
275 let stream = Box::pin(stream);
276 Self {
277 stream,
278 neighbors: Default::default(),
279 }
280 }
281
282 pub fn neighbors(&self) -> impl Iterator<Item = EndpointId> + '_ {
284 self.neighbors.iter().copied()
285 }
286
287 pub async fn joined(&mut self) -> Result<(), ApiError> {
295 while !self.is_joined() {
296 let _event = self.next().await.ok_or(e!(ApiError::Closed))??;
297 }
298 Ok(())
299 }
300
301 pub fn is_joined(&self) -> bool {
303 !self.neighbors.is_empty()
304 }
305}
306
307impl Stream for GossipReceiver {
308 type Item = Result<Event, ApiError>;
309
310 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
311 let item = std::task::ready!(Pin::new(&mut self.stream).poll_next(cx));
312 if let Some(Ok(item)) = &item {
313 match item {
314 Event::NeighborUp(endpoint_id) => {
315 self.neighbors.insert(*endpoint_id);
316 }
317 Event::NeighborDown(endpoint_id) => {
318 self.neighbors.remove(endpoint_id);
319 }
320 _ => {}
321 }
322 }
323 Poll::Ready(item)
324 }
325}
326
/// Events emitted on a topic subscription.
///
/// NOTE: do not reorder variants. Declaration order defines the derived `PartialOrd`/`Ord`
/// ordering, and serde passes the variant index to index-based wire formats.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub enum Event {
    /// An endpoint became a neighbor; it is added to [`GossipReceiver::neighbors`].
    NeighborUp(EndpointId),
    /// An endpoint is no longer a neighbor; it is removed from [`GossipReceiver::neighbors`].
    NeighborDown(EndpointId),
    /// A gossip message was received on the topic.
    Received(Message),
    /// Presumably signals that the subscriber fell behind and events were dropped.
    /// NOTE(review): not produced by the `From<proto::Event>` conversion; confirm with the
    /// producer of this event.
    Lagged,
}
341
342impl From<crate::proto::Event<EndpointId>> for Event {
343 fn from(event: crate::proto::Event<EndpointId>) -> Self {
344 match event {
345 crate::proto::Event::NeighborUp(endpoint_id) => Self::NeighborUp(endpoint_id),
346 crate::proto::Event::NeighborDown(endpoint_id) => Self::NeighborDown(endpoint_id),
347 crate::proto::Event::Received(message) => Self::Received(Message {
348 content: message.content,
349 scope: message.scope,
350 delivered_from: message.delivered_from,
351 }),
352 }
353 }
354}
355
/// A gossip message received on a topic.
///
/// NOTE: this type is serialized; keep the field order stable.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, derive_more::Debug, Serialize, Deserialize)]
pub struct Message {
    /// Message payload. `Debug` prints only its length.
    #[debug("Bytes({})", self.content.len())]
    pub content: Bytes,
    /// Delivery scope of the message (see [`DeliveryScope`]).
    pub scope: DeliveryScope,
    /// Endpoint the message was delivered from. Presumably this is the neighbor that forwarded
    /// it, not necessarily the original author. TODO confirm.
    pub delivered_from: EndpointId,
}
368
/// Commands a subscriber sends for its topic.
///
/// NOTE: this type is serialized; keep the variant order stable.
#[derive(Serialize, Deserialize, derive_more::Debug, Clone)]
pub enum Command {
    /// Broadcast a message on the topic (sent by [`GossipSender::broadcast`]).
    /// `Debug` prints only the payload length.
    Broadcast(#[debug("Bytes({})", _0.len())] Bytes),
    /// Broadcast a message, presumably only to current neighbors
    /// (sent by [`GossipSender::broadcast_neighbors`]). `Debug` prints only the payload length.
    BroadcastNeighbors(#[debug("Bytes({})", _0.len())] Bytes),
    /// Join the given peers (sent by [`GossipSender::join_peers`]).
    JoinPeers(Vec<EndpointId>),
}
379
/// Options for [`GossipApi::subscribe_with_opts`].
#[derive(Serialize, Deserialize, Debug)]
pub struct JoinOptions {
    /// Peers to bootstrap the topic from; forwarded as the join request's bootstrap set.
    pub bootstrap: BTreeSet<EndpointId>,
    /// Capacity of the subscription's event channel. [`JoinOptions::with_bootstrap`] sets
    /// this to the crate default (2048).
    pub subscription_capacity: usize,
}
394
395impl JoinOptions {
396 pub fn with_bootstrap(endpoints: impl IntoIterator<Item = EndpointId>) -> Self {
399 Self {
400 bootstrap: endpoints.into_iter().collect(),
401 subscription_capacity: TOPIC_EVENTS_DEFAULT_CAP,
402 }
403 }
404}
405
#[cfg(test)]
mod tests {
    use crate::api::GossipTopic;

    /// End-to-end test of the irpc transport. A remote [`crate::api::GossipApi`] client joins
    /// a topic through an RPC server backed by endpoint 1's gossip, and receives a message that
    /// endpoint 2 broadcasts.
    #[cfg(all(feature = "rpc", feature = "net"))]
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_rpc() -> n0_error::Result<()> {
        use iroh::{discovery::static_provider::StaticProvider, protocol::Router, RelayMap};
        use n0_error::{AnyError, Result, StackResultExt, StdResultExt};
        use n0_future::{time::Duration, StreamExt};
        use rand_chacha::rand_core::SeedableRng;

        use crate::{
            api::{Event, GossipApi},
            net::{test::create_endpoint, Gossip},
            proto::TopicId,
            ALPN,
        };

        // Fixed seed so the endpoints created below are deterministic across runs.
        let mut rng = rand_chacha::ChaCha12Rng::seed_from_u64(1);
        let (relay_map, _relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();

        /// Creates an endpoint with a spawned gossip instance and a router that accepts the
        /// gossip ALPN.
        async fn create_gossip_endpoint(
            rng: &mut rand_chacha::ChaCha12Rng,
            relay_map: RelayMap,
        ) -> Result<(Router, Gossip)> {
            let endpoint = create_endpoint(rng, relay_map, None).await?;
            let gossip = Gossip::builder().spawn(endpoint.clone());
            let router = Router::builder(endpoint)
                .accept(ALPN, gossip.clone())
                .spawn();
            Ok((router, gossip))
        }

        let topic_id = TopicId::from_bytes([0u8; 32]);

        // Endpoint 1: its gossip instance is exposed through the RPC server below.
        let (router, gossip) = create_gossip_endpoint(&mut rng, relay_map.clone()).await?;

        // Endpoint 2: subscribes with no bootstrap peers, so `subscribe_and_join` only returns
        // once another endpoint connects to it. It then broadcasts "hello".
        let (endpoint2_id, endpoint2_addr, endpoint2_task) = {
            let (router, gossip) = create_gossip_endpoint(&mut rng, relay_map.clone()).await?;
            let endpoint_addr = router.endpoint().addr();
            let endpoint_id = router.endpoint().id();
            let task = tokio::task::spawn(async move {
                let mut topic = gossip.subscribe_and_join(topic_id, vec![]).await?;
                topic.broadcast(b"hello".to_vec().into()).await?;
                Ok::<_, AnyError>(router)
            });
            (endpoint_id, endpoint_addr, task)
        };

        // Let endpoint 1 resolve endpoint 2's address from its id.
        let static_provider = StaticProvider::new();
        static_provider.add_endpoint_info(endpoint2_addr);

        router.endpoint().discovery().add(static_provider);

        // RPC server on a local port, serving endpoint 1's gossip API.
        let (rpc_server_endpoint, rpc_server_cert) =
            irpc::util::make_server_endpoint("127.0.0.1:0".parse().unwrap())
                .context("make server endpoint")?;
        let rpc_server_addr = rpc_server_endpoint
            .local_addr()
            .std_context("resolve server addr")?;
        let rpc_server_task = tokio::task::spawn(async move {
            gossip.listen(rpc_server_endpoint).await;
        });

        // RPC client that trusts the server's certificate.
        let rpc_client_endpoint =
            irpc::util::make_client_endpoint("127.0.0.1:0".parse().unwrap(), &[&rpc_server_cert])
                .context("make client endpoint")?;
        let rpc_client = GossipApi::connect(rpc_client_endpoint, rpc_server_addr);

        // Through the RPC client, join with endpoint 2 as bootstrap peer and wait for its
        // "hello". Neighbor events are skipped; a `Lagged` event fails the test.
        let recv = async move {
            let mut topic = rpc_client
                .subscribe_and_join(topic_id, vec![endpoint2_id])
                .await?;
            while let Some(event) = topic.try_next().await? {
                match event {
                    Event::Received(message) => {
                        assert_eq!(&message.content[..], b"hello");
                        break;
                    }
                    Event::Lagged => panic!("unexpected lagged event"),
                    _ => {}
                }
            }
            Ok::<_, AnyError>(())
        };

        tokio::time::timeout(Duration::from_secs(10), recv)
            .await
            .std_context("rpc recv timeout")??;

        // Teardown: stop the RPC server, then shut down both routers.
        rpc_server_task.abort();
        router.shutdown().await.std_context("shutdown router")?;
        let router2 = endpoint2_task.await.std_context("join endpoint task")??;
        router2
            .shutdown()
            .await
            .std_context("shutdown second router")?;
        Ok(())
    }

    /// Compile-time check that [`GossipTopic`] is `Sync`. The inner functions are never
    /// called; the test passes as long as this compiles.
    #[test]
    fn ensure_gossip_topic_is_sync() {
        #[allow(unused)]
        fn get() -> GossipTopic {
            unimplemented!()
        }
        #[allow(unused)]
        fn check(_t: impl Sync) {}
        #[allow(unused)]
        fn foo() {
            check(get());
        }
    }
}