1use std::{
6 collections::{BTreeSet, HashSet},
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use bytes::Bytes;
12use iroh_base::EndpointId;
13use irpc::{channel::mpsc, rpc_requests, Client};
14use n0_error::{e, stack_error};
15use n0_future::{Stream, StreamExt, TryStreamExt};
16use serde::{Deserialize, Serialize};
17
18use crate::proto::{DeliveryScope, TopicId};
19
/// Capacity of the per-topic command channel (client -> actor) opened by
/// `GossipApi::subscribe_with_opts`.
const TOPIC_COMMANDS_CAP: usize = 64;
/// Default capacity of a topic's event subscription channel, used by
/// `JoinOptions::with_bootstrap`.
const TOPIC_EVENTS_DEFAULT_CAP: usize = 2_048;
24
/// Requests handled by the gossip actor.
///
/// The `rpc_requests` attribute derives the `RpcMessage` type that [`GossipApi`]
/// uses to talk to the actor, locally or (with the `rpc` feature) remotely.
#[rpc_requests(message = RpcMessage, rpc_feature = "rpc")]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) enum Request {
    // Join a topic. Bidirectional stream: the client sends `Command`s to the
    // actor and receives `Event`s back.
    #[rpc(tx=mpsc::Sender<Event>, rx=mpsc::Receiver<Command>)]
    Join(JoinRequest),
}

/// Payload of [`Request::Join`].
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct JoinRequest {
    /// The topic to join.
    pub topic_id: TopicId,
    /// Initial peers to join the topic through.
    pub bootstrap: BTreeSet<EndpointId>,
}
38
/// Errors returned by the gossip API ([`GossipApi`], [`GossipTopic`] and its halves).
#[allow(missing_docs)]
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum ApiError {
    /// The underlying irpc transport failed, e.g. sending a command or
    /// receiving an event on a topic channel.
    #[error(transparent)]
    Rpc { source: irpc::Error },
    /// The topic's event stream ended, e.g. while waiting for the topic to be joined.
    #[error("topic closed")]
    Closed,
}
49
50impl From<irpc::channel::SendError> for ApiError {
51 fn from(value: irpc::channel::SendError) -> Self {
52 irpc::Error::from(value).into()
53 }
54}
55
56impl From<irpc::channel::mpsc::RecvError> for ApiError {
57 fn from(value: irpc::channel::mpsc::RecvError) -> Self {
58 irpc::Error::from(value).into()
59 }
60}
61
62impl From<irpc::channel::oneshot::RecvError> for ApiError {
63 fn from(value: irpc::channel::oneshot::RecvError) -> Self {
64 irpc::Error::from(value).into()
65 }
66}
67
68#[derive(Debug, Clone)]
82pub struct GossipApi {
83 client: Client<Request>,
84}
85
86impl GossipApi {
87 #[cfg(feature = "net")]
88 pub(crate) fn local(tx: tokio::sync::mpsc::Sender<RpcMessage>) -> Self {
89 let local = irpc::LocalSender::<Request>::from(tx);
90 Self {
91 client: local.into(),
92 }
93 }
94
95 #[cfg(feature = "rpc")]
97 pub fn connect(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
98 let inner = irpc::Client::quinn(endpoint, addr);
99 Self { client: inner }
100 }
101
102 #[cfg(all(feature = "rpc", feature = "net"))]
104 pub(crate) async fn listen(&self, endpoint: quinn::Endpoint) {
105 use irpc::rpc::{listen, RemoteService};
106
107 let local = self
108 .client
109 .as_local()
110 .expect("cannot listen on remote client");
111 let handler = Request::remote_handler(local);
112
113 listen::<Request>(endpoint, handler).await
114 }
115
116 pub async fn subscribe_with_opts(
124 &self,
125 topic_id: TopicId,
126 opts: JoinOptions,
127 ) -> Result<GossipTopic, ApiError> {
128 let req = JoinRequest {
129 topic_id,
130 bootstrap: opts.bootstrap,
131 };
132 let (tx, rx) = self
133 .client
134 .bidi_streaming(req, TOPIC_COMMANDS_CAP, opts.subscription_capacity)
135 .await?;
136 Ok(GossipTopic::new(tx, rx))
137 }
138
139 pub async fn subscribe_and_join(
141 &self,
142 topic_id: TopicId,
143 bootstrap: Vec<EndpointId>,
144 ) -> Result<GossipTopic, ApiError> {
145 let mut sub = self
146 .subscribe_with_opts(topic_id, JoinOptions::with_bootstrap(bootstrap))
147 .await?;
148 sub.joined().await?;
149 Ok(sub)
150 }
151
152 pub async fn subscribe(
158 &self,
159 topic_id: TopicId,
160 bootstrap: Vec<EndpointId>,
161 ) -> Result<GossipTopic, ApiError> {
162 let sub = self
163 .subscribe_with_opts(topic_id, JoinOptions::with_bootstrap(bootstrap))
164 .await?;
165
166 Ok(sub)
167 }
168}
169
170#[derive(Debug, Clone)]
172pub struct GossipSender(mpsc::Sender<Command>);
173
174impl GossipSender {
175 pub(crate) fn new(sender: mpsc::Sender<Command>) -> Self {
176 Self(sender)
177 }
178
179 pub async fn broadcast(&self, message: Bytes) -> Result<(), ApiError> {
181 self.send(Command::Broadcast(message)).await?;
182 Ok(())
183 }
184
185 pub async fn broadcast_neighbors(&self, message: Bytes) -> Result<(), ApiError> {
187 self.send(Command::BroadcastNeighbors(message)).await?;
188 Ok(())
189 }
190
191 pub async fn join_peers(&self, peers: Vec<EndpointId>) -> Result<(), ApiError> {
193 self.send(Command::JoinPeers(peers)).await?;
194 Ok(())
195 }
196
197 async fn send(&self, command: Command) -> Result<(), irpc::channel::SendError> {
198 self.0.send(command).await?;
199 Ok(())
200 }
201}
202
203#[derive(Debug)]
212pub struct GossipTopic {
213 sender: GossipSender,
214 receiver: GossipReceiver,
215}
216
217impl GossipTopic {
218 pub(crate) fn new(sender: mpsc::Sender<Command>, receiver: mpsc::Receiver<Event>) -> Self {
219 let sender = GossipSender::new(sender);
220 Self {
221 sender,
222 receiver: GossipReceiver::new(receiver),
223 }
224 }
225
226 pub fn split(self) -> (GossipSender, GossipReceiver) {
228 (self.sender, self.receiver)
229 }
230
231 pub async fn broadcast(&mut self, message: Bytes) -> Result<(), ApiError> {
233 self.sender.broadcast(message).await
234 }
235
236 pub async fn broadcast_neighbors(&mut self, message: Bytes) -> Result<(), ApiError> {
238 self.sender.broadcast_neighbors(message).await
239 }
240
241 pub fn neighbors(&self) -> impl Iterator<Item = EndpointId> + '_ {
243 self.receiver.neighbors()
244 }
245
246 pub async fn joined(&mut self) -> Result<(), ApiError> {
250 self.receiver.joined().await
251 }
252
253 pub fn is_joined(&self) -> bool {
255 self.receiver.is_joined()
256 }
257}
258
259impl Stream for GossipTopic {
260 type Item = Result<Event, ApiError>;
261
262 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
263 Pin::new(&mut self.receiver).poll_next(cx)
264 }
265}
266
267#[derive(derive_more::Debug)]
271pub struct GossipReceiver {
272 #[debug("BoxStream")]
273 stream: Pin<Box<dyn Stream<Item = Result<Event, ApiError>> + Send + Sync + 'static>>,
274 neighbors: HashSet<EndpointId>,
275}
276
277impl GossipReceiver {
278 pub(crate) fn new(events_rx: mpsc::Receiver<Event>) -> Self {
279 let stream = events_rx.into_stream().map_err(ApiError::from);
280 let stream = Box::pin(stream);
281 Self {
282 stream,
283 neighbors: Default::default(),
284 }
285 }
286
287 pub fn neighbors(&self) -> impl Iterator<Item = EndpointId> + '_ {
289 self.neighbors.iter().copied()
290 }
291
292 pub async fn joined(&mut self) -> Result<(), ApiError> {
300 while !self.is_joined() {
301 let _event = self.next().await.ok_or(e!(ApiError::Closed))??;
302 }
303 Ok(())
304 }
305
306 pub fn is_joined(&self) -> bool {
308 !self.neighbors.is_empty()
309 }
310}
311
312impl Stream for GossipReceiver {
313 type Item = Result<Event, ApiError>;
314
315 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
316 let item = std::task::ready!(Pin::new(&mut self.stream).poll_next(cx));
317 if let Some(Ok(item)) = &item {
318 match item {
319 Event::NeighborUp(endpoint_id) => {
320 self.neighbors.insert(*endpoint_id);
321 }
322 Event::NeighborDown(endpoint_id) => {
323 self.neighbors.remove(endpoint_id);
324 }
325 _ => {}
326 }
327 }
328 Poll::Ready(item)
329 }
330}
331
/// Events emitted on a subscribed topic by [`GossipReceiver`] / [`GossipTopic`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub enum Event {
    /// A peer became a direct neighbor on this topic; the receiver adds it to its neighbor set.
    NeighborUp(EndpointId),
    /// A peer is no longer a direct neighbor; the receiver removes it from its neighbor set.
    NeighborDown(EndpointId),
    /// A gossip message was received on this topic.
    Received(Message),
    /// The subscriber lagged behind.
    ///
    /// This variant is not produced by the conversion from protocol events.
    /// NOTE(review): presumably emitted when the subscription buffer overflowed;
    /// confirm against the actor.
    Lagged,
}
346
347impl From<crate::proto::Event<EndpointId>> for Event {
348 fn from(event: crate::proto::Event<EndpointId>) -> Self {
349 match event {
350 crate::proto::Event::NeighborUp(endpoint_id) => Self::NeighborUp(endpoint_id),
351 crate::proto::Event::NeighborDown(endpoint_id) => Self::NeighborDown(endpoint_id),
352 crate::proto::Event::Received(message) => Self::Received(Message {
353 content: message.content,
354 scope: message.scope,
355 delivered_from: message.delivered_from,
356 }),
357 }
358 }
359}
360
/// A gossip message received on a topic (carried by [`Event::Received`]).
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, derive_more::Debug, Serialize, Deserialize)]
pub struct Message {
    // Debug prints only the payload length, not the bytes.
    #[debug("Bytes({})", self.content.len())]
    /// The message payload.
    pub content: Bytes,
    /// The delivery scope of the message (see [`DeliveryScope`]).
    pub scope: DeliveryScope,
    /// The endpoint this message was delivered from.
    ///
    /// NOTE(review): presumably the forwarding peer, not necessarily the original
    /// author; confirm against the protocol.
    pub delivered_from: EndpointId,
}
373
/// Commands sent to a subscribed topic, via [`GossipSender`].
#[derive(Serialize, Deserialize, derive_more::Debug, Clone)]
pub enum Command {
    /// Broadcast a message on the topic; sent by [`GossipSender::broadcast`].
    Broadcast(#[debug("Bytes({})", _0.len())] Bytes),
    /// Broadcast a message to direct neighbors; sent by [`GossipSender::broadcast_neighbors`].
    BroadcastNeighbors(#[debug("Bytes({})", _0.len())] Bytes),
    /// Join additional peers; sent by [`GossipSender::join_peers`].
    JoinPeers(Vec<EndpointId>),
}
384
/// Options for joining a topic with [`GossipApi::subscribe_with_opts`].
#[derive(Serialize, Deserialize, Debug)]
pub struct JoinOptions {
    /// Initial peers to join the topic through.
    pub bootstrap: BTreeSet<EndpointId>,
    /// Capacity of this subscription's event channel.
    ///
    /// Passed to the RPC client as the response-channel capacity.
    /// [`JoinOptions::with_bootstrap`] sets it to 2048.
    pub subscription_capacity: usize,
}
399
400impl JoinOptions {
401 pub fn with_bootstrap(endpoints: impl IntoIterator<Item = EndpointId>) -> Self {
404 Self {
405 bootstrap: endpoints.into_iter().collect(),
406 subscription_capacity: TOPIC_EVENTS_DEFAULT_CAP,
407 }
408 }
409}
410
#[cfg(test)]
mod tests {
    use crate::api::GossipTopic;

    // End-to-end check of the RPC transport: an RPC client joins a topic through
    // a gossip node served over irpc/quinn and must receive a broadcast sent by
    // a second gossip node.
    #[cfg(all(feature = "rpc", feature = "net"))]
    #[tokio::test]
    #[n0_tracing_test::traced_test]
    async fn test_rpc() -> n0_error::Result<()> {
        use iroh::{discovery::static_provider::StaticProvider, protocol::Router, RelayMap};
        use n0_error::{AnyError, Result, StackResultExt, StdResultExt};
        use n0_future::{time::Duration, StreamExt};
        use rand_chacha::rand_core::SeedableRng;

        use crate::{
            api::{Event, GossipApi},
            net::{test::create_endpoint, Gossip},
            proto::TopicId,
            ALPN,
        };

        let mut rng = rand_chacha::ChaCha12Rng::seed_from_u64(1);
        let (relay_map, _relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();

        // Creates an endpoint with gossip spawned on it and mounted on a router under `ALPN`.
        async fn create_gossip_endpoint(
            rng: &mut rand_chacha::ChaCha12Rng,
            relay_map: RelayMap,
        ) -> Result<(Router, Gossip)> {
            let endpoint = create_endpoint(rng, relay_map, None).await?;
            let gossip = Gossip::builder().spawn(endpoint.clone());
            let router = Router::builder(endpoint)
                .accept(ALPN, gossip.clone())
                .spawn();
            Ok((router, gossip))
        }

        let topic_id = TopicId::from_bytes([0u8; 32]);

        // Endpoint 1: its gossip instance is later served over RPC.
        let (router, gossip) = create_gossip_endpoint(&mut rng, relay_map.clone()).await?;

        // Endpoint 2: joins the topic without bootstrap peers, waits for a
        // neighbor, then broadcasts "hello".
        let (endpoint2_id, endpoint2_addr, endpoint2_task) = {
            let (router, gossip) = create_gossip_endpoint(&mut rng, relay_map.clone()).await?;
            let endpoint_addr = router.endpoint().addr();
            let endpoint_id = router.endpoint().id();
            let task = tokio::task::spawn(async move {
                let mut topic = gossip.subscribe_and_join(topic_id, vec![]).await?;
                topic.broadcast(b"hello".to_vec().into()).await?;
                Ok::<_, AnyError>(router)
            });
            (endpoint_id, endpoint_addr, task)
        };

        // Let endpoint 1 resolve endpoint 2's address from its id.
        let static_provider = StaticProvider::new();
        static_provider.add_endpoint_info(endpoint2_addr);

        router.endpoint().discovery().add(static_provider);

        // Serve endpoint 1's gossip over irpc on a local quinn endpoint.
        let (rpc_server_endpoint, rpc_server_cert) =
            irpc::util::make_server_endpoint("127.0.0.1:0".parse().unwrap())
                .context("make server endpoint")?;
        let rpc_server_addr = rpc_server_endpoint
            .local_addr()
            .std_context("resolve server addr")?;
        let rpc_server_task = tokio::task::spawn(async move {
            gossip.listen(rpc_server_endpoint).await;
        });

        let rpc_client_endpoint =
            irpc::util::make_client_endpoint("127.0.0.1:0".parse().unwrap(), &[&rpc_server_cert])
                .context("make client endpoint")?;
        let rpc_client = GossipApi::connect(rpc_client_endpoint, rpc_server_addr);

        let recv = async move {
            // Join through the RPC client, bootstrapping from endpoint 2.
            let mut topic = rpc_client
                .subscribe_and_join(topic_id, vec![endpoint2_id])
                .await?;
            let mut received = false;
            while let Some(event) = topic.try_next().await? {
                match event {
                    Event::Received(message) => {
                        assert_eq!(&message.content[..], b"hello");
                        received = true;
                        break;
                    }
                    Event::Lagged => panic!("unexpected lagged event"),
                    _ => {}
                }
            }
            // A stream that ends early must fail the test, not let it pass
            // without ever seeing the broadcast.
            assert!(
                received,
                "topic stream ended before the broadcast message was received"
            );
            Ok::<_, AnyError>(())
        };

        tokio::time::timeout(Duration::from_secs(10), recv)
            .await
            .std_context("rpc recv timeout")??;

        rpc_server_task.abort();
        router.shutdown().await.std_context("shutdown router")?;
        let router2 = endpoint2_task.await.std_context("join endpoint task")??;
        router2
            .shutdown()
            .await
            .std_context("shutdown second router")?;
        Ok(())
    }

    // Compile-time assertion: this stops building if `GossipTopic` is no longer `Sync`.
    #[test]
    fn ensure_gossip_topic_is_sync() {
        #[allow(unused)]
        fn get() -> GossipTopic {
            unimplemented!()
        }
        #[allow(unused)]
        fn check(_t: impl Sync) {}
        #[allow(unused)]
        fn foo() {
            check(get());
        }
    }
}