use std::{
collections::{BTreeSet, HashSet},
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use iroh_base::NodeId;
use irpc::{channel::spsc, Client};
use irpc_derive::rpc_requests;
use n0_future::{Stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use crate::proto::{DeliveryScope, TopicId};
/// Default capacity of a subscription's event channel.
///
/// Used as `JoinOptions::subscription_capacity` by `JoinOptions::with_bootstrap`.
const TOPIC_EVENTS_DEFAULT_CAP: usize = 2_048;
/// Capacity of a subscription's command channel, as passed to `bidi_streaming`.
const TOPIC_COMMANDS_CAP: usize = 64;
/// Marker type naming this module's `irpc` service.
///
/// It is the service parameter of the `Request` protocol and of `GossipApi`'s client.
#[derive(Clone, Copy, Debug)]
pub(super) struct Service;

impl irpc::Service for Service {}
// RPC requests accepted by the gossip `Service`.
// `rpc_requests` turns this enum into the protocol definition; the
// `message = RpcMessage` argument presumably names the generated message
// type used by `GossipApi` (confirm against irpc_derive).
#[rpc_requests(Service, message = RpcMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) enum Request {
    // Join a topic. From the serving side, `Event`s are sent back over `tx`
    // and `Command`s from the subscriber arrive over `rx`. This matches the
    // client side, where `bidi_streaming` yields a `Sender<Command>` and a
    // `Receiver<Event>`.
    #[rpc(tx=spsc::Sender<Event>, rx=spsc::Receiver<Command>)]
    Join(JoinRequest),
}
/// Payload of `Request::Join`, built by `GossipApi::subscribe_with_opts`.
///
/// Field order is part of the serialized form; do not reorder.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct JoinRequest {
    /// The topic to join.
    pub topic_id: TopicId,
    /// Nodes to bootstrap the topic with (copied from `JoinOptions::bootstrap`).
    pub bootstrap: BTreeSet<NodeId>,
}
#[derive(Debug, thiserror::Error)]
pub enum ApiError {
#[error(transparent)]
Rpc(#[from] irpc::Error),
#[error("topic closed")]
Closed,
}
impl From<irpc::channel::SendError> for ApiError {
fn from(value: irpc::channel::SendError) -> Self {
Self::Rpc(value.into())
}
}
impl From<irpc::channel::RecvError> for ApiError {
fn from(value: irpc::channel::RecvError) -> Self {
Self::Rpc(value.into())
}
}
#[derive(Debug, Clone)]
pub struct GossipApi {
client: Client<RpcMessage, Request, Service>,
}
impl GossipApi {
#[cfg(feature = "net")]
pub(crate) fn local(tx: tokio::sync::mpsc::Sender<RpcMessage>) -> Self {
let local = irpc::LocalSender::<RpcMessage, Service>::from(tx);
Self {
client: local.into(),
}
}
#[cfg(feature = "rpc")]
pub fn connect(endpoint: quinn::Endpoint, addr: std::net::SocketAddr) -> Self {
let inner = irpc::Client::quinn(endpoint, addr);
Self { client: inner }
}
#[cfg(all(feature = "rpc", feature = "net"))]
pub(crate) async fn listen(&self, endpoint: quinn::Endpoint) {
use std::sync::Arc;
use irpc::rpc::{listen, Handler};
let local = self
.client
.local()
.expect("cannot listen on remote client")
.clone();
let handler: Handler<Request> = Arc::new(move |req, rx, tx| {
let local = local.clone();
Box::pin({
match req {
Request::Join(msg) => local.send((msg, tx, rx)),
}
})
});
listen::<Request>(endpoint, handler).await
}
pub async fn subscribe_with_opts(
&self,
topic_id: TopicId,
opts: JoinOptions,
) -> Result<GossipTopic, ApiError> {
let req = JoinRequest {
topic_id,
bootstrap: opts.bootstrap,
};
let (tx, rx) = self
.client
.bidi_streaming(req, TOPIC_COMMANDS_CAP, opts.subscription_capacity)
.await?;
Ok(GossipTopic::new(tx, rx))
}
pub async fn subscribe_and_join(
&self,
topic_id: TopicId,
bootstrap: Vec<NodeId>,
) -> Result<GossipTopic, ApiError> {
let mut sub = self
.subscribe_with_opts(topic_id, JoinOptions::with_bootstrap(bootstrap))
.await?;
sub.joined().await?;
Ok(sub)
}
pub async fn subscribe(
&self,
topic_id: TopicId,
bootstrap: Vec<NodeId>,
) -> Result<GossipTopic, ApiError> {
let sub = self
.subscribe_with_opts(topic_id, JoinOptions::with_bootstrap(bootstrap))
.await?;
Ok(sub)
}
}
#[derive(Debug)]
pub struct GossipSender(spsc::Sender<Command>);
impl GossipSender {
pub(crate) fn new(sender: spsc::Sender<Command>) -> Self {
Self(sender)
}
pub async fn broadcast(&mut self, message: Bytes) -> Result<(), ApiError> {
self.send(Command::Broadcast(message)).await?;
Ok(())
}
pub async fn broadcast_neighbors(&mut self, message: Bytes) -> Result<(), ApiError> {
self.send(Command::BroadcastNeighbors(message)).await?;
Ok(())
}
pub async fn join_peers(&mut self, peers: Vec<NodeId>) -> Result<(), ApiError> {
self.send(Command::JoinPeers(peers)).await?;
Ok(())
}
async fn send(&mut self, command: Command) -> Result<(), irpc::channel::SendError> {
self.0.send(command).await?;
Ok(())
}
}
#[derive(Debug)]
pub struct GossipTopic {
sender: GossipSender,
receiver: GossipReceiver,
}
impl GossipTopic {
pub(crate) fn new(sender: spsc::Sender<Command>, receiver: spsc::Receiver<Event>) -> Self {
let sender = GossipSender::new(sender);
Self {
sender,
receiver: GossipReceiver::new(receiver),
}
}
pub fn split(self) -> (GossipSender, GossipReceiver) {
(self.sender, self.receiver)
}
pub async fn broadcast(&mut self, message: Bytes) -> Result<(), ApiError> {
self.sender.broadcast(message).await
}
pub async fn broadcast_neighbors(&mut self, message: Bytes) -> Result<(), ApiError> {
self.sender.broadcast_neighbors(message).await
}
pub async fn joined(&mut self) -> Result<(), ApiError> {
self.receiver.joined().await
}
pub fn is_joined(&self) -> bool {
self.receiver.is_joined()
}
}
/// Yields the events of the inner [`GossipReceiver`], including its neighbor tracking.
impl Stream for GossipTopic {
    type Item = Result<Event, ApiError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // All fields are `Unpin`, so unwrapping the pin is sound.
        let this = self.get_mut();
        Pin::new(&mut this.receiver).poll_next(cx)
    }
}
/// The receiving half of a topic subscription.
///
/// This is a [`Stream`] of [`Event`]s. As a side effect of being polled, it keeps
/// track of neighbors reported by `NeighborUp` / `NeighborDown` events.
#[derive(derive_more::Debug)]
pub struct GossipReceiver {
    /// Event stream, with channel receive errors mapped into [`ApiError`].
    #[debug("BoxStream")]
    stream: Pin<Box<dyn Stream<Item = Result<Event, ApiError>> + Send + 'static>>,
    /// Neighbors currently tracked as up; updated while polling the stream.
    neighbors: HashSet<NodeId>,
}

impl GossipReceiver {
    /// Wraps the raw event channel of a subscription, starting with no known neighbors.
    pub(crate) fn new(events_rx: spsc::Receiver<Event>) -> Self {
        Self {
            stream: Box::pin(events_rx.into_stream().map_err(ApiError::from)),
            neighbors: HashSet::new(),
        }
    }

    /// Returns the neighbors currently tracked as up.
    pub fn neighbors(&self) -> impl Iterator<Item = NodeId> + '_ {
        self.neighbors.iter().copied()
    }

    /// Polls the stream until at least one neighbor is up.
    ///
    /// Every event pulled while waiting is consumed and will not be yielded again.
    ///
    /// # Errors
    ///
    /// Returns the first error the stream yields. Returns [`ApiError::Closed`] if the
    /// stream ends before a neighbor comes up.
    pub async fn joined(&mut self) -> Result<(), ApiError> {
        loop {
            if self.is_joined() {
                return Ok(());
            }
            match self.next().await {
                Some(Ok(_consumed)) => {}
                Some(Err(err)) => return Err(err),
                None => return Err(ApiError::Closed),
            }
        }
    }

    /// Returns `true` if at least one neighbor is currently tracked as up.
    pub fn is_joined(&self) -> bool {
        !self.neighbors.is_empty()
    }
}
/// Yields subscription events unchanged, updating the tracked neighbor set from
/// `NeighborUp` / `NeighborDown` events as they pass through.
impl Stream for GossipReceiver {
    type Item = Result<Event, ApiError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // All fields are `Unpin`, so unwrapping the pin is sound.
        let this = self.get_mut();
        let next = match this.stream.as_mut().poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(next) => next,
        };
        match &next {
            Some(Ok(Event::Gossip(GossipEvent::NeighborUp(peer)))) => {
                this.neighbors.insert(*peer);
            }
            Some(Ok(Event::Gossip(GossipEvent::NeighborDown(peer)))) => {
                this.neighbors.remove(peer);
            }
            _ => {}
        }
        Poll::Ready(next)
    }
}
/// An item yielded by a topic subscription stream.
///
/// `Clone` is derived (every payload, [`GossipEvent`], is already `Clone`) so that
/// consumers can fan events out to several handlers.
///
/// Variant order is part of the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum Event {
    /// A gossip-level event for the topic.
    Gossip(GossipEvent),
    /// Presumably signals that the subscriber fell behind and some events were
    /// dropped. Confirm against the producing side.
    Lagged,
}
/// A gossip-level event on a topic, converted from the protocol's `proto::Event`.
///
/// Variant order is significant for the derived `Ord` and for serde's variant
/// indices; do not reorder.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)]
pub enum GossipEvent {
    /// A neighbor came up; `GossipReceiver` adds it to its tracked neighbor set.
    NeighborUp(NodeId),
    /// A neighbor went down; `GossipReceiver` removes it from its tracked neighbor set.
    NeighborDown(NodeId),
    /// A message was received on the topic.
    Received(Message),
}
impl From<crate::proto::Event<NodeId>> for GossipEvent {
fn from(event: crate::proto::Event<NodeId>) -> Self {
match event {
crate::proto::Event::NeighborUp(node_id) => Self::NeighborUp(node_id),
crate::proto::Event::NeighborDown(node_id) => Self::NeighborDown(node_id),
crate::proto::Event::Received(message) => Self::Received(Message {
content: message.content,
scope: message.scope,
delivered_from: message.delivered_from,
}),
}
}
}
/// A message received on a topic.
///
/// Field order is significant for the derived `Ord` and for serialization; do not reorder.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, derive_more::Debug, Serialize, Deserialize)]
pub struct Message {
    /// The message payload. `Debug` prints only its length, not the bytes.
    #[debug("Bytes({})", self.content.len())]
    pub content: Bytes,
    /// The delivery scope reported by the protocol (see `DeliveryScope`).
    pub scope: DeliveryScope,
    /// The node this message was delivered from. Presumably this is the forwarding
    /// peer, which is not necessarily the original author; confirm in `proto`.
    pub delivered_from: NodeId,
}
/// A command sent from a topic subscription to the gossip service.
///
/// Variant order is part of the serialized form; do not reorder.
#[derive(Serialize, Deserialize, derive_more::Debug)]
pub enum Command {
    /// Broadcast a message on the topic (sent by `GossipSender::broadcast`).
    /// `Debug` prints only the payload length.
    Broadcast(#[debug("Bytes({})", _0.len())] Bytes),
    /// Broadcast a message to neighbors (sent by `GossipSender::broadcast_neighbors`).
    /// Presumably it is not forwarded beyond them; confirm in `proto`.
    /// `Debug` prints only the payload length.
    BroadcastNeighbors(#[debug("Bytes({})", _0.len())] Bytes),
    /// Request joining the listed peers (sent by `GossipSender::join_peers`).
    JoinPeers(Vec<NodeId>),
}
/// Options for [`GossipApi::subscribe_with_opts`].
///
/// Field order is part of the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Debug)]
pub struct JoinOptions {
    /// Nodes to bootstrap the topic with.
    pub bootstrap: BTreeSet<NodeId>,
    /// Capacity of the subscription's event channel.
    ///
    /// NOTE(review): a capacity of 0 may be rejected by the underlying channel
    /// implementation; confirm with irpc before allowing it.
    pub subscription_capacity: usize,
}

impl JoinOptions {
    /// Builds options with the given bootstrap `nodes` and the default event capacity
    /// (`TOPIC_EVENTS_DEFAULT_CAP`).
    ///
    /// Nodes are collected into a set, so duplicates collapse.
    pub fn with_bootstrap(nodes: impl IntoIterator<Item = NodeId>) -> Self {
        let bootstrap: BTreeSet<NodeId> = nodes.into_iter().collect();
        Self {
            bootstrap,
            subscription_capacity: TOPIC_EVENTS_DEFAULT_CAP,
        }
    }
}
#[cfg(test)]
mod tests {
    // End-to-end check of the RPC path. Node 1 runs gossip and exposes it through an
    // RPC server. A remote `GossipApi` client subscribes through that server and must
    // receive a message broadcast by node 2.
    #[cfg(all(feature = "rpc", feature = "net"))]
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_rpc() -> testresult::TestResult {
        use iroh::{protocol::Router, RelayMap};
        use n0_future::{time::Duration, StreamExt};
        use rand::SeedableRng;

        use crate::{
            api::{Event, GossipApi, GossipEvent},
            net::{test::create_endpoint, Gossip},
            proto::TopicId,
            ALPN,
        };

        // Fixed-seed RNG handed to `create_endpoint`, so endpoint setup is reproducible.
        let mut rng = rand_chacha::ChaCha12Rng::seed_from_u64(1);
        // Local test relay. `_guard` is bound to a name (not `_`) so it lives until the
        // end of the test; presumably it keeps the relay server running.
        let (relay_map, _relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap();

        // Creates an endpoint with a gossip instance registered under `ALPN` on a router.
        async fn create_gossip_endpoint(
            rng: &mut rand_chacha::ChaCha12Rng,
            relay_map: RelayMap,
        ) -> anyhow::Result<(Router, Gossip)> {
            let endpoint = create_endpoint(rng, relay_map).await?;
            let gossip = Gossip::builder().spawn(endpoint.clone()).await?;
            let router = Router::builder(endpoint)
                .accept(ALPN, gossip.clone())
                .spawn();
            Ok((router, gossip))
        }

        let topic_id = TopicId::from_bytes([0u8; 32]);
        // Node 1: its gossip instance is the one served over RPC below.
        let (router, gossip) = create_gossip_endpoint(&mut rng, relay_map.clone()).await?;
        // Node 2: subscribes with no bootstrap nodes, waits until a neighbor shows up
        // (node 1, once the RPC client joins through it), then broadcasts "hello".
        let (node2_id, node2_addr, node2_task) = {
            let (router, gossip) = create_gossip_endpoint(&mut rng, relay_map.clone()).await?;
            let node_addr = router.endpoint().node_addr().await?;
            let node_id = router.endpoint().node_id();
            let task = tokio::task::spawn(async move {
                let mut topic = gossip.subscribe_and_join(topic_id, vec![]).await?;
                topic.broadcast(b"hello".to_vec().into()).await?;
                anyhow::Ok(router)
            });
            (node_id, node_addr, task)
        };
        // Tell node 1 how to reach node 2, so bootstrapping by node id can connect.
        router.endpoint().add_node_addr(node2_addr)?;

        // RPC server on localhost that forwards requests to node 1's gossip instance.
        let (rpc_server_endpoint, rpc_server_cert) =
            irpc::util::make_server_endpoint("127.0.0.1:0".parse().unwrap())?;
        let rpc_server_addr = rpc_server_endpoint.local_addr()?;
        let rpc_server_task = tokio::task::spawn(async move {
            gossip.listen(rpc_server_endpoint).await;
        });
        // RPC client that trusts the server's certificate.
        let rpc_client_endpoint =
            irpc::util::make_client_endpoint("127.0.0.1:0".parse().unwrap(), &[&rpc_server_cert])?;
        let rpc_client = GossipApi::connect(rpc_client_endpoint, rpc_server_addr);

        // Subscribe over RPC, bootstrapping from node 2, and wait for its "hello".
        // `subscribe_and_join` consumes events until the first neighbor is up, so this
        // relies on the broadcast arriving after that point.
        let recv = async move {
            let mut topic = rpc_client
                .subscribe_and_join(topic_id, vec![node2_id])
                .await?;
            while let Some(event) = topic.try_next().await? {
                match event {
                    Event::Gossip(GossipEvent::Received(message)) => {
                        assert_eq!(&message.content[..], b"hello");
                        break;
                    }
                    Event::Lagged => panic!("unexpected lagged event"),
                    _ => {}
                }
            }
            anyhow::Ok(())
        };
        // Fail the test instead of hanging if the message never arrives.
        tokio::time::timeout(Duration::from_secs(10), recv).await??;

        // Teardown: stop the RPC server, then shut down both routers.
        rpc_server_task.abort();
        router.shutdown().await?;
        let router2 = node2_task.await??;
        router2.shutdown().await?;
        Ok(())
    }
}