use std::{collections::HashMap, time::Duration};
use anyhow::{bail, Result};
use bytes::Bytes;
use iroh_base::key::NodeId;
use iroh_metrics::{inc, inc_by};
use time::{Date, OffsetDateTime};
use tokio::sync::mpsc;
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{info, info_span, trace, warn, Instrument};
use crate::{
defaults::timeouts::SERVER_WRITE_TIMEOUT as WRITE_TIMEOUT,
protos::relay::SERVER_CHANNEL_SIZE,
server::{client_conn::ClientConnConfig, clients::Clients, metrics::Metrics},
};
/// Commands processed by the server [`Actor`] loop.
///
/// Each variant is handled in `Actor::handle_message`.
#[derive(Debug)]
pub(super) enum Message {
/// Forward a regular data packet from `src` to the connected client `dst`.
SendPacket {
dst: NodeId,
data: Bytes,
src: NodeId,
},
/// Forward a disco packet from `src` to the connected client `dst`.
SendDiscoPacket {
dst: NodeId,
data: Bytes,
src: NodeId,
},
/// Register a new client connection described by the given config.
CreateClient(ClientConnConfig),
/// Unregister `node_id`, but only if the actor still has a client for that
/// node with the given connection number (`conn_num`).
RemoveClient {
node_id: NodeId,
conn_num: usize,
},
}
/// A packet handed to a client connection for delivery, tagged with its sender.
#[derive(Debug, Clone)]
pub(super) struct Packet {
/// Node that sent the packet.
pub(super) src: NodeId,
/// Raw packet payload.
pub(super) data: Bytes,
}
/// Handle to the spawned server actor task.
#[derive(Debug)]
pub(super) struct ServerActorTask {
/// Write timeout; initialized from `SERVER_WRITE_TIMEOUT` in `spawn`.
/// NOTE(review): consumers of this field are outside this file — confirm usage.
pub(super) write_timeout: Duration,
/// Sender side of the channel the actor receives its [`Message`]s on.
pub(super) server_channel: mpsc::Sender<Message>,
/// Join handle of the actor loop; the task is aborted if this handle is dropped.
loop_handler: AbortOnDropHandle<Result<()>>,
/// Token that `close` cancels to make the actor loop exit.
cancel: CancellationToken,
}
impl ServerActorTask {
    /// Spawns the server actor on the tokio runtime and returns a handle to it.
    ///
    /// The actor runs inside a `relay.server` tracing span and consumes messages
    /// sent on [`ServerActorTask::server_channel`].
    pub(super) fn spawn() -> Self {
        let (tx, rx) = mpsc::channel(SERVER_CHANNEL_SIZE);
        let actor = Actor::new(rx);
        let cancel = CancellationToken::new();
        let actor_cancel = cancel.clone();
        let actor_fut =
            async move { actor.run(actor_cancel).await }.instrument(info_span!("relay.server"));
        let loop_handler = AbortOnDropHandle::new(tokio::spawn(actor_fut));
        Self {
            write_timeout: WRITE_TIMEOUT,
            server_channel: tx,
            loop_handler,
            cancel,
        }
    }

    /// Cancels the actor loop and waits for it to finish.
    ///
    /// Errors returned by the loop, or a failure to join the task, are logged
    /// as warnings rather than propagated.
    pub(super) async fn close(self) {
        self.cancel.cancel();
        let outcome = self.loop_handler.await;
        match outcome {
            Err(e) => warn!("error waiting for the server process to close: {e:?}"),
            Ok(Err(e)) => warn!("error shutting down server: {e:#}"),
            Ok(Ok(())) => {}
        }
    }
}
/// State owned by the server actor loop.
struct Actor {
/// Incoming commands from the rest of the server.
receiver: mpsc::Receiver<Message>,
/// Currently registered client connections.
clients: Clients,
/// Tracks which nodes have connected since the last daily reset, feeding the
/// `unique_client_keys` metric.
client_counter: ClientCounter,
}
impl Actor {
fn new(receiver: mpsc::Receiver<Message>) -> Self {
Self {
receiver,
clients: Clients::default(),
client_counter: ClientCounter::default(),
}
}
async fn run(mut self, done: CancellationToken) -> Result<()> {
loop {
tokio::select! {
biased;
_ = done.cancelled() => {
info!("server actor loop cancelled, closing loop");
self.clients.shutdown().await;
return Ok(());
}
msg = self.receiver.recv() => match msg {
Some(msg) => {
self.handle_message(msg).await;
}
None => {
warn!("unexpected actor error: receiver gone, shutting down actor loop");
self.clients.shutdown().await;
bail!("unexpected actor error, closed client connections, and shutting down actor loop");
}
}
}
}
}
async fn handle_message(&mut self, msg: Message) {
match msg {
Message::SendPacket { dst, data, src } => {
trace!(?src, ?dst, len = data.len(), "send packet");
if self.clients.contains_key(&dst) {
match self.clients.send_packet(&dst, Packet { data, src }).await {
Ok(()) => {
self.clients.record_send(&src, dst);
inc!(Metrics, send_packets_sent);
}
Err(err) => {
trace!(?dst, "failed to send packet: {err:#}");
inc!(Metrics, send_packets_dropped);
}
}
} else {
warn!(?dst, "no way to reach client, dropped packet");
inc!(Metrics, send_packets_dropped);
}
}
Message::SendDiscoPacket { dst, data, src } => {
trace!(?src, ?dst, len = data.len(), "send disco packet");
if self.clients.contains_key(&dst) {
match self
.clients
.send_disco_packet(&dst, Packet { data, src })
.await
{
Ok(()) => {
self.clients.record_send(&src, dst);
inc!(Metrics, disco_packets_sent);
}
Err(err) => {
trace!(?dst, "failed to send disco packet: {err:#}");
inc!(Metrics, disco_packets_dropped);
}
}
} else {
warn!(?dst, "disco: no way to reach client, dropped packet");
inc!(Metrics, disco_packets_dropped);
}
}
Message::CreateClient(client_builder) => {
inc!(Metrics, accepts);
let node_id = client_builder.node_id;
trace!(node_id = node_id.fmt_short(), "create client");
self.clients.register(client_builder).await;
let nc = self.client_counter.update(node_id);
inc_by!(Metrics, unique_client_keys, nc);
}
Message::RemoveClient { node_id, conn_num } => {
inc!(Metrics, disconnects);
trace!(node_id = %node_id.fmt_short(), "remove client");
if self.clients.has_client(&node_id, conn_num) {
self.clients.unregister(&node_id).await;
}
}
}
}
}
/// Counts connections per node, resetting whenever the UTC date changes.
struct ClientCounter {
/// Number of connections seen per node since the last reset.
clients: HashMap<NodeId, usize>,
/// UTC date on which `clients` was last cleared (or the counter was created).
last_clear_date: Date,
}
impl Default for ClientCounter {
    /// Creates an empty counter whose reset date is today's UTC date.
    fn default() -> Self {
        let last_clear_date = OffsetDateTime::now_utc().date();
        Self {
            clients: HashMap::default(),
            last_clear_date,
        }
    }
}
impl ClientCounter {
    /// Clears all per-node counts if the current UTC date differs from the date
    /// of the last clear, so uniqueness is tracked per calendar day.
    fn check_and_clear(&mut self) {
        let today = OffsetDateTime::now_utc().date();
        if today != self.last_clear_date {
            self.clients.clear();
            self.last_clear_date = today;
        }
    }

    /// Records a connection from `client`.
    ///
    /// Returns `1` if this is the first connection from `client` since the last
    /// daily reset, `0` otherwise. The result is suitable for feeding an
    /// "unique clients" counter.
    fn update(&mut self, client: NodeId) -> u64 {
        self.check_and_clear();
        // Single hash lookup: a freshly inserted entry starts at 0, so after the
        // increment a count of exactly 1 means this node had not been seen yet.
        let count = self.clients.entry(client).or_insert(0);
        *count += 1;
        u64::from(*count == 1)
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use iroh_base::key::SecretKey;
    use tokio::io::DuplexStream;
    use tokio_util::codec::Framed;

    use super::*;
    use crate::{
        protos::relay::{recv_frame, DerpCodec, Frame, FrameType},
        server::{
            client_conn::ClientConnConfig,
            streams::{MaybeTlsStream, RelayedStream},
        },
    };

    /// Builds a [`ClientConnConfig`] backed by an in-memory duplex pipe.
    ///
    /// Returns the config, which owns the server end of the pipe, together with
    /// a framed handle to the other end that the test drives as the client.
    fn test_client_builder(
        node_id: NodeId,
        server_channel: mpsc::Sender<Message>,
    ) -> (ClientConnConfig, Framed<DuplexStream, DerpCodec>) {
        let (client_end, server_end) = tokio::io::duplex(1024);
        let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(server_end), DerpCodec));
        let config = ClientConnConfig {
            node_id,
            stream,
            write_timeout: Duration::from_secs(1),
            channel_capacity: 10,
            rate_limit: None,
            server_channel,
        };
        (config, Framed::new(client_end, DerpCodec))
    }

    /// Sends `msg` to the actor, turning a closed channel into an error.
    async fn send_to_actor(chan: &mpsc::Sender<Message>, msg: Message) -> Result<()> {
        chan.send(msg)
            .await
            .map_err(|_| anyhow::anyhow!("server gone"))
    }

    #[tokio::test]
    async fn test_server_actor() -> Result<()> {
        // Run the actor loop on its own task; the test talks to it via the channel.
        let (server_channel, server_channel_r) = mpsc::channel(20);
        let server_actor: Actor = Actor::new(server_channel_r);
        let done = CancellationToken::new();
        let server_done = done.clone();
        let server_task = tokio::spawn(
            async move { server_actor.run(server_done).await }
                .instrument(info_span!("relay.server")),
        );

        // Register client A.
        let node_id_a = SecretKey::generate().public();
        let (client_a, mut a_io) = test_client_builder(node_id_a, server_channel.clone());
        send_to_actor(&server_channel, Message::CreateClient(client_a)).await?;

        // Register client B.
        let node_id_b = SecretKey::generate().public();
        let (client_b, mut b_io) = test_client_builder(node_id_b, server_channel.clone());
        send_to_actor(&server_channel, Message::CreateClient(client_b)).await?;

        // B -> A: the packet must reach A tagged with B as its source.
        let msg = b"hello world!";
        crate::client::conn::send_packet(&mut b_io, node_id_a, Bytes::from_static(msg)).await?;
        let received = recv_frame(FrameType::RecvPacket, &mut a_io).await?;
        let expected = Frame::RecvPacket {
            src_key: node_id_b,
            content: msg.to_vec().into(),
        };
        assert_eq!(received, expected);

        // Remove B. conn_num 1 is assumed to match B's registration; A is
        // expected to be told that B is gone.
        let remove_b = Message::RemoveClient {
            node_id: node_id_b,
            conn_num: 1,
        };
        send_to_actor(&server_channel, remove_b).await?;
        let gone = recv_frame(FrameType::PeerGone, &mut a_io).await?;
        assert_eq!(Frame::NodeGone { node_id: node_id_b }, gone);

        // Shut the actor down cleanly and surface any loop error.
        done.cancel();
        server_task.await??;
        Ok(())
    }
}