use std::{future::Future, num::NonZeroU32, pin::Pin, sync::Arc, task::Poll, time::Duration};
use anyhow::{Context, Result};
use bytes::Bytes;
use futures_lite::FutureExt;
use futures_sink::Sink;
use futures_util::{SinkExt, Stream, StreamExt};
use iroh_base::key::NodeId;
use iroh_metrics::{inc, inc_by};
use tokio::sync::mpsc;
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{error, info, instrument, trace, warn, Instrument};
use crate::{
protos::{
disco,
relay::{write_frame, Frame, KEEP_ALIVE},
},
server::{
actor::{self, Packet},
metrics::Metrics,
streams::RelayedStream,
ClientConnRateLimit,
},
};
#[derive(Debug)]
/// Configuration for a [`ClientConn`].
///
/// Consumed by [`ClientConn::new`], which destructures every field.
pub(super) struct ClientConnConfig {
/// Public key of the connected client; becomes the connection's `key`.
pub(super) node_id: NodeId,
/// Stream to the client over which frames are read and written.
pub(super) stream: RelayedStream,
/// Timeout handed to `write_frame` for every frame written to the client.
pub(super) write_timeout: Duration,
/// Capacity of each of the three per-connection mpsc queues
/// (packets, disco packets and peer-gone notifications).
pub(super) channel_capacity: usize,
/// Optional rate limit for frames read from the client; `None` disables limiting.
pub(super) rate_limit: Option<ClientConnRateLimit>,
/// Channel used to send messages to the server actor.
pub(super) server_channel: mpsc::Sender<actor::Message>,
}
#[derive(Debug)]
/// Handle to a single client connection that is serviced by a spawned actor task.
///
/// The senders are the producing ends of the queues which the actor drains and
/// writes to the client.
pub(super) struct ClientConn {
/// Connection number supplied by the caller of [`ClientConn::new`]; reported
/// together with `key` when the actor asks the server to remove this client.
pub(super) conn_num: usize,
/// Public key of the connected client.
pub(super) key: NodeId,
/// Token cancelled by [`ClientConn::shutdown`] to ask the actor loop to exit.
done: CancellationToken,
/// Handle of the spawned actor task; the task is aborted when this is dropped.
handle: AbortOnDropHandle<()>,
/// Queue of packets the actor forwards to the client.
pub(super) send_queue: mpsc::Sender<Packet>,
/// Queue of disco packets the actor forwards to the client.
pub(super) disco_send_queue: mpsc::Sender<Packet>,
/// Queue of node ids for which the actor sends the client a `NodeGone` frame.
pub(super) peer_gone: mpsc::Sender<NodeId>,
}
impl ClientConn {
pub fn new(config: ClientConnConfig, conn_num: usize) -> ClientConn {
let ClientConnConfig {
node_id: key,
stream: io,
write_timeout,
channel_capacity,
rate_limit,
server_channel,
} = config;
let stream = match rate_limit {
Some(cfg) => {
let mut quota = governor::Quota::per_second(cfg.bytes_per_second);
if let Some(max_burst) = cfg.max_burst_bytes {
quota = quota.allow_burst(max_burst);
}
let limiter = governor::RateLimiter::direct(quota);
RateLimitedRelayedStream::new(io, limiter)
}
None => RateLimitedRelayedStream::unlimited(io),
};
let done = CancellationToken::new();
let client_id = (key, conn_num);
let (send_queue_s, send_queue_r) = mpsc::channel(channel_capacity);
let (disco_send_queue_s, disco_send_queue_r) = mpsc::channel(channel_capacity);
let (peer_gone_s, peer_gone_r) = mpsc::channel(channel_capacity);
let actor = Actor {
stream,
timeout: write_timeout,
send_queue: send_queue_r,
disco_send_queue: disco_send_queue_r,
node_gone: peer_gone_r,
key,
preferred: false,
server_channel: server_channel.clone(),
};
let io_done = done.clone();
let io_client_id = client_id;
let handle = tokio::task::spawn(
async move {
let (key, conn_num) = io_client_id;
let res = actor.run(io_done).await;
let _ = server_channel
.send(actor::Message::RemoveClient {
node_id: key,
conn_num,
})
.await;
match res {
Err(e) => {
warn!(
"connection manager for {key:?} {conn_num}: writer closed in error {e}"
);
}
Ok(()) => {
info!("connection manager for {key:?} {conn_num}: writer closed");
}
}
}
.instrument(tracing::info_span!("client_conn_actor")),
);
ClientConn {
conn_num,
key,
handle: AbortOnDropHandle::new(handle),
done,
send_queue: send_queue_s,
disco_send_queue: disco_send_queue_s,
peer_gone: peer_gone_s,
}
}
pub async fn shutdown(self) {
self.done.cancel();
if let Err(e) = self.handle.await {
warn!(
"error closing actor loop for client connection {:?} {}: {e:?}",
self.key, self.conn_num
);
};
}
}
#[derive(Debug)]
/// The per-connection actor: owns the client stream and the receiving ends of the
/// queues held by [`ClientConn`].
struct Actor {
/// Stream to the client, used both to read incoming frames and to write outgoing ones.
stream: RateLimitedRelayedStream,
/// Timeout applied to every frame written via `write_frame`.
timeout: Duration,
/// Packets to be delivered to the client.
send_queue: mpsc::Receiver<Packet>,
/// Disco packets to be delivered to the client.
disco_send_queue: mpsc::Receiver<Packet>,
/// Node ids to report to the client with a `NodeGone` frame.
node_gone: mpsc::Receiver<NodeId>,
/// Public key of this client; used as the `src` of messages forwarded to the server.
key: NodeId,
/// Channel on which packets received from the client are forwarded to the server actor.
server_channel: mpsc::Sender<actor::Message>,
/// Last value received in a `NotePreferred` frame; only written within this file.
preferred: bool,
}
impl Actor {
async fn run(mut self, done: CancellationToken) -> Result<()> {
let jitter = Duration::from_secs(5);
let mut keep_alive = tokio::time::interval(KEEP_ALIVE + jitter);
keep_alive.tick().await;
loop {
tokio::select! {
biased;
_ = done.cancelled() => {
trace!("actor loop cancelled, exiting");
self.stream.flush().await.context("flush")?;
break;
}
read_res = self.stream.next() => {
trace!("handle frame");
match read_res {
Some(Ok(frame)) => {
self.handle_frame(frame).await.context("handle_read")?;
}
Some(Err(err)) => {
return Err(err);
}
None => {
return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "read stream ended").into());
}
}
}
node_id = self.node_gone.recv() => {
let node_id = node_id.context("Server.node_gone dropped")?;
trace!("node_id gone: {:?}", node_id);
self.write_frame(Frame::NodeGone { node_id }).await?;
}
packet = self.send_queue.recv() => {
let packet = packet.context("Server.send_queue dropped")?;
trace!("send packet");
self.send_packet(packet).await.context("send packet")?;
}
packet = self.disco_send_queue.recv() => {
let packet = packet.context("Server.disco_send_queue dropped")?;
trace!("send disco packet");
self.send_packet(packet).await.context("send packet")?;
}
_ = keep_alive.tick() => {
trace!("keep alive");
self.write_frame(Frame::KeepAlive).await?;
}
}
self.stream.flush().await.context("tick flush")?;
}
Ok(())
}
async fn write_frame(&mut self, frame: Frame) -> Result<()> {
write_frame(&mut self.stream, frame, Some(self.timeout)).await
}
async fn send_packet(&mut self, packet: Packet) -> Result<()> {
let src_key = packet.src;
let content = packet.data;
if let Ok(len) = content.len().try_into() {
inc_by!(Metrics, bytes_sent, len);
}
self.write_frame(Frame::RecvPacket { src_key, content })
.await
}
async fn handle_frame(&mut self, frame: Frame) -> Result<()> {
match frame {
Frame::NotePreferred { preferred } => {
self.preferred = preferred;
inc!(Metrics, other_packets_recv);
}
Frame::SendPacket { dst_key, packet } => {
let packet_len = packet.len();
self.handle_frame_send_packet(dst_key, packet).await?;
inc_by!(Metrics, bytes_recv, packet_len as u64);
}
Frame::Ping { data } => {
inc!(Metrics, got_ping);
self.write_frame(Frame::Pong { data }).await?;
inc!(Metrics, sent_pong);
}
Frame::Health { .. } => {
inc!(Metrics, other_packets_recv);
}
_ => {
inc!(Metrics, unknown_frames);
}
}
Ok(())
}
async fn handle_frame_send_packet(&self, dst_key: NodeId, data: Bytes) -> Result<()> {
let message = if disco::looks_like_disco_wrapper(&data) {
inc!(Metrics, disco_packets_recv);
actor::Message::SendDiscoPacket {
dst: dst_key,
src: self.key,
data,
}
} else {
inc!(Metrics, send_packets_recv);
actor::Message::SendPacket {
dst: dst_key,
src: self.key,
data,
}
};
self.server_channel
.send(message)
.await
.map_err(|_| anyhow::anyhow!("server gone"))?;
Ok(())
}
}
#[derive(Debug)]
/// A [`RelayedStream`] whose read side can be throttled by a rate limiter.
///
/// Only frames read through the `Stream` impl are limited; the `Sink` impl forwards
/// to the inner stream unchanged.
struct RateLimitedRelayedStream {
/// The wrapped stream.
inner: RelayedStream,
/// The limiter for received frames; `None` means reads are not limited.
limiter: Option<Arc<governor::DefaultDirectRateLimiter>>,
/// Whether a received item is currently being held back by the limiter.
state: State,
/// Set once this stream has been counted in `conns_rx_ratelimited_total`.
limited_once: bool,
}
#[derive(derive_more::Debug)]
/// Read-side state of [`RateLimitedRelayedStream`].
enum State {
/// An item was read from the inner stream but is held back by the rate limiter.
#[debug("Blocked")]
Blocked {
/// Future that completes once the held item may be released.
delay: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
/// The held item; yielded when `delay` completes.
item: anyhow::Result<Frame>,
},
/// No item is held; the next poll reads from the inner stream.
Ready,
}
impl RateLimitedRelayedStream {
    /// Wraps `inner` so that frames read from it are throttled by `limiter`.
    fn new(inner: RelayedStream, limiter: governor::DefaultDirectRateLimiter) -> Self {
        // Identical to an unlimited stream apart from the installed limiter.
        let mut stream = Self::unlimited(inner);
        stream.limiter = Some(Arc::new(limiter));
        stream
    }

    /// Wraps `inner` without any rate limiting.
    fn unlimited(inner: RelayedStream) -> Self {
        let (limiter, state, limited_once) = (None, State::Ready, false);
        Self {
            inner,
            limiter,
            state,
            limited_once,
        }
    }
}
impl RateLimitedRelayedStream {
    /// Records that a received frame was rate-limited.
    ///
    /// Counts every limited frame, and counts the connection itself only the first
    /// time one of its frames is limited.
    fn record_rate_limited(&mut self) {
        inc!(Metrics, frames_rx_ratelimited_total);
        // Set the flag unconditionally and act on its previous value.
        let already_counted = std::mem::replace(&mut self.limited_once, true);
        if !already_counted {
            inc!(Metrics, conns_rx_ratelimited_total);
        }
    }
}
impl Stream for RateLimitedRelayedStream {
type Item = anyhow::Result<Frame>;
/// Yields the next item of the inner stream, delaying `Ok` frames that exceed the
/// rate limit.
///
/// Without a limiter the inner stream is polled directly. With a limiter, each
/// `Ok` frame is charged its `len_with_header()`; a frame the limiter rejects for
/// now is parked in `State::Blocked` and yielded once its delay future completes.
/// Errors and end-of-stream from the inner stream are passed through unthrottled.
#[instrument(name = "rate_limited_relayed_stream", skip_all)]
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
// No limiter configured: behave exactly like the inner stream.
let Some(ref limiter) = self.limiter else {
return Pin::new(&mut self.inner).poll_next(cx);
};
// Clone the `Arc` so the limiter is not borrowed from `self` while `self` is
// mutated below.
let limiter = limiter.clone();
loop {
match &mut self.state {
State::Ready => {
// Nothing is held back: poll the inner stream for a new item.
match Pin::new(&mut self.inner).poll_next(cx) {
Poll::Ready(Some(item)) => {
match &item {
Ok(frame) => {
// Cost of this frame: usize -> u32 -> NonZeroU32.
let Ok(frame_len) =
TryInto::<u32>::try_into(frame.len_with_header())
.and_then(TryInto::<NonZeroU32>::try_into)
else {
error!("frame len not NonZeroU32, is MAX_FRAME_SIZE too large?");
// The frame cannot be charged; let it through rather than dropping it.
return Poll::Ready(Some(item));
};
match limiter.check_n(frame_len) {
// Within the limit: yield right away.
Ok(Ok(_)) => return Poll::Ready(Some(item)),
// Over the limit for now: park the item behind a delay future.
Ok(Err(_)) => {
self.record_rate_limited();
let delay = Box::pin({
let limiter = limiter.clone();
async move {
// Wait until the limiter admits `frame_len`; the result is discarded.
limiter.until_n_ready(frame_len).await.ok();
}
});
self.state = State::Blocked { delay, item };
// Loop again so the new delay future is polled with `cx`.
continue;
}
// The limiter reports the frame can never be admitted: log it and let
// the frame through instead of blocking.
Err(_insufficient_capacity) => {
error!(
"frame larger than bucket capacity: \
configuration error: \
max_burst_bytes < MAX_FRAME_SIZE?"
);
return Poll::Ready(Some(item));
}
}
}
// Errors from the inner stream are not rate-limited.
Err(_) => {
return Poll::Ready(Some(item));
}
}
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
}
}
State::Blocked { delay, .. } => {
// An item is being held back: it is released once the delay completes.
match delay.poll(cx) {
Poll::Ready(_) => {
// Take the held item out and switch back to `Ready`.
match std::mem::replace(&mut self.state, State::Ready) {
State::Ready => unreachable!(),
State::Blocked { item, .. } => {
return Poll::Ready(Some(item));
}
}
}
Poll::Pending => return Poll::Pending,
}
}
}
}
}
}
impl Sink<Frame> for RateLimitedRelayedStream {
type Error = std::io::Error;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: Frame) -> std::result::Result<(), Self::Error> {
Pin::new(&mut self.inner).start_send(item)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx)
}
}
#[cfg(test)]
mod tests {
// Tests for the connection actor and the rate-limited stream wrapper. All of them
// run over an in-memory `tokio::io::duplex` pipe instead of a real socket.
use anyhow::bail;
use bytes::Bytes;
use iroh_base::key::SecretKey;
use testresult::TestResult;
use tokio_util::codec::Framed;
use super::*;
use crate::{
client::conn,
protos::relay::{recv_frame, DerpCodec, FrameType},
server::streams::MaybeTlsStream,
};
/// Drives an actor end to end: items queued on each channel must come out as
/// frames on the client side, and frames written by the client must be answered
/// or forwarded to the server channel.
#[tokio::test]
async fn test_client_actor_basic() -> Result<()> {
let (send_queue_s, send_queue_r) = mpsc::channel(10);
let (disco_send_queue_s, disco_send_queue_r) = mpsc::channel(10);
let (peer_gone_s, peer_gone_r) = mpsc::channel(10);
let key = SecretKey::generate().public();
// `io` is the actor's end of the pipe, `io_rw` plays the client.
let (io, io_rw) = tokio::io::duplex(1024);
let mut io_rw = Framed::new(io_rw, DerpCodec);
let (server_channel_s, mut server_channel_r) = mpsc::channel(10);
let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec));
let actor = Actor {
stream: RateLimitedRelayedStream::unlimited(stream),
timeout: Duration::from_secs(1),
send_queue: send_queue_r,
disco_send_queue: disco_send_queue_r,
node_gone: peer_gone_r,
key,
server_channel: server_channel_s,
preferred: true,
};
let done = CancellationToken::new();
let io_done = done.clone();
let handle = tokio::task::spawn(async move { actor.run(io_done).await });
// Write direction: things queued for the client arrive as frames.
println!("-- write");
let data = b"hello world!";
println!(" send packet");
let packet = Packet {
src: key,
data: Bytes::from(&data[..]),
};
send_queue_s.send(packet.clone()).await?;
let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?;
assert_eq!(
frame,
Frame::RecvPacket {
src_key: key,
content: data.to_vec().into()
}
);
// Packets from the disco queue are delivered as `RecvPacket` frames too.
println!(" send disco packet");
disco_send_queue_s.send(packet.clone()).await?;
let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?;
assert_eq!(
frame,
Frame::RecvPacket {
src_key: key,
content: data.to_vec().into()
}
);
println!("send peer gone");
peer_gone_s.send(key).await?;
let frame = recv_frame(FrameType::PeerGone, &mut io_rw).await?;
assert_eq!(frame, Frame::NodeGone { node_id: key });
// Read direction: frames written by the client are handled by the actor.
println!("--read");
let data = b"pingpong";
write_frame(&mut io_rw, Frame::Ping { data: *data }, None).await?;
println!(" recv pong");
let frame = recv_frame(FrameType::Pong, &mut io_rw).await?;
assert_eq!(frame, Frame::Pong { data: *data });
// `NotePreferred` produces no reply; these writes only check it is accepted.
println!(" preferred: false");
write_frame(&mut io_rw, Frame::NotePreferred { preferred: false }, None).await?;
println!(" preferred: true");
write_frame(&mut io_rw, Frame::NotePreferred { preferred: true }, None).await?;
let target = SecretKey::generate().public();
println!(" send packet");
let data = b"hello world!";
conn::send_packet(&mut io_rw, target, Bytes::from_static(data)).await?;
// A regular payload must reach the server as `SendPacket`.
let msg = server_channel_r.recv().await.unwrap();
match msg {
actor::Message::SendPacket {
dst: got_target,
data: got_data,
src: got_src,
} => {
assert_eq!(target, got_target);
assert_eq!(key, got_src);
assert_eq!(&data[..], &got_data);
}
m => {
bail!("expected ServerMessage::SendPacket, got {m:?}");
}
}
println!(" send disco packet");
// Build a payload starting with the disco magic so it is classified as disco.
let mut disco_data = disco::MAGIC.as_bytes().to_vec();
disco_data.extend_from_slice(target.as_bytes());
disco_data.extend_from_slice(data);
conn::send_packet(&mut io_rw, target, disco_data.clone().into()).await?;
let msg = server_channel_r.recv().await.unwrap();
match msg {
actor::Message::SendDiscoPacket {
dst: got_target,
src: got_src,
data: got_data,
} => {
assert_eq!(target, got_target);
assert_eq!(key, got_src);
assert_eq!(&disco_data[..], &got_data);
}
m => {
bail!("expected ServerMessage::SendDiscoPacket, got {m:?}");
}
}
// Cancelling must make the actor loop exit cleanly.
done.cancel();
handle.await??;
Ok(())
}
/// Checks that the actor loop ends with an `UnexpectedEof` I/O error when the
/// client side of the connection is dropped.
#[tokio::test]
async fn test_client_conn_read_err() -> Result<()> {
// The senders are kept alive (but unused) so the actor does not fail on a
// closed queue before the stream is dropped.
let (_send_queue_s, send_queue_r) = mpsc::channel(10);
let (_disco_send_queue_s, disco_send_queue_r) = mpsc::channel(10);
let (_peer_gone_s, peer_gone_r) = mpsc::channel(10);
let key = SecretKey::generate().public();
let (io, io_rw) = tokio::io::duplex(1024);
let mut io_rw = Framed::new(io_rw, DerpCodec);
let (server_channel_s, mut server_channel_r) = mpsc::channel(10);
let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec));
println!("-- create client conn");
let actor = Actor {
stream: RateLimitedRelayedStream::unlimited(stream),
timeout: Duration::from_secs(1),
send_queue: send_queue_r,
disco_send_queue: disco_send_queue_r,
node_gone: peer_gone_r,
key,
server_channel: server_channel_s,
preferred: true,
};
let done = CancellationToken::new();
let io_done = done.clone();
println!("-- run client conn");
let handle = tokio::task::spawn(async move { actor.run(io_done).await });
// First make sure the connection works before breaking it.
println!(" send packet");
let data = b"hello world!";
let target = SecretKey::generate().public();
conn::send_packet(&mut io_rw, target, Bytes::from_static(data)).await?;
let msg = server_channel_r.recv().await.unwrap();
match msg {
actor::Message::SendPacket {
dst: got_target,
src: got_src,
data: got_data,
} => {
assert_eq!(target, got_target);
assert_eq!(key, got_src);
assert_eq!(&data[..], &got_data);
println!(" send packet success");
}
m => {
bail!("expected ServerMessage::SendPacket, got {m:?}");
}
}
// Dropping the client end closes the pipe; the actor must notice within a second.
println!("-- drop io");
drop(io_rw);
if let Err(err) = tokio::time::timeout(Duration::from_secs(1), handle).await?? {
if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
println!(" task closed successfully with `UnexpectedEof` error");
} else {
bail!("expected `UnexpectedEof` error, got unknown error: {io_err:?}");
}
} else {
bail!("expected `std::io::Error`, got `None`");
}
} else {
bail!("expected task to finish in `UnexpectedEof` error, got `Ok(())`");
}
Ok(())
}
/// Checks that a second frame is held back once the per-second quota is used up
/// and is released after waiting.
#[tokio::test]
async fn test_rate_limit() -> TestResult {
let _logging = iroh_test::logging::setup();
// Quota in bytes per second; the test frame below is asserted to be exactly this long.
const LIMIT: u32 = 50;
// Only used to size the duplex buffer.
const MAX_FRAMES: u32 = 100;
let quota = governor::Quota::per_second(NonZeroU32::try_from(LIMIT)?);
let limiter = governor::RateLimiter::direct(quota);
let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _);
let mut frame_writer = Framed::new(io_write, DerpCodec);
let stream = RelayedStream::Derp(Framed::new(MaybeTlsStream::Test(io_read), DerpCodec));
let mut stream = RateLimitedRelayedStream::new(stream, limiter);
let data = Bytes::from_static(b"hello world!!");
let target = SecretKey::generate().public();
let frame = Frame::SendPacket {
dst_key: target,
packet: data.clone(),
};
let frame_len = frame.len_with_header();
assert_eq!(frame_len, LIMIT as usize);
// The first frame fits the quota and must arrive promptly.
info!("-- send packet");
frame_writer.send(frame.clone()).await?;
frame_writer.flush().await?;
let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
.await
.expect("timeout")
.expect("option")
.expect("ok");
assert_eq!(recv_frame, frame);
// The second frame exceeds the quota, so reading it must not complete in 100ms.
info!("-- send packet");
frame_writer.send(frame.clone()).await?;
frame_writer.flush().await?;
let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await;
assert!(res.is_err(), "expecting a timeout");
info!("-- timeout happened");
// After waiting a second the held frame must be released.
info!("-- sleep");
tokio::time::sleep(Duration::from_secs(1)).await;
let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
.await
.expect("timeout")
.expect("option")
.expect("ok");
assert_eq!(recv_frame, frame);
Ok(())
}
}