use std::{fmt::Debug, io, ops::Deref};
use irpc::{
channel::{mpsc, none::NoSender, oneshot},
rpc_requests, Channels, WithChannels,
};
use serde::{Deserialize, Serialize};
use snafu::Snafu;
use crate::{
protocol::{
GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
ERR_PERMISSION,
},
provider::{events::irpc_ext::IrpcClientExt, TransferStats},
Hash,
};
/// How a newly connected client is reported to the event handler (the `connected` field of
/// [`EventMask`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ConnectMode {
    /// Connections are not reported. This is the default.
    #[default]
    None = 0,
    /// The handler is notified of the connection; no reply is expected.
    Notify = 1,
    /// The handler is asked and may reject the connection with an [`AbortReason`].
    Request = 2,
}
/// Event mode for observe requests (the `observe` field of [`EventMask`]).
///
/// NOTE(review): the variants appear to parallel [`ConnectMode`] (not reported / notify only /
/// ask the handler) — confirm against the code that consumes `EventMask::observe`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ObserveMode {
    /// The default mode.
    #[default]
    None = 0,
    /// Notification-style reporting.
    Notify = 1,
    /// Request-style reporting.
    Request = 2,
}
/// How an incoming request is reported to the event handler.
///
/// The `*Log` variants additionally forward per-request [`RequestUpdate`]s (transfer started,
/// progress, completed, aborted) to the handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum RequestMode {
    /// The request is not reported. This is the default.
    #[default]
    None = 0,
    /// The handler is notified; it cannot reject the request and receives no updates.
    Notify = 1,
    /// The handler is asked and may reject the request; it receives no updates.
    Request = 2,
    /// The handler is notified and receives the request's updates.
    NotifyLog = 3,
    /// The handler is asked, may reject the request, and receives the request's updates.
    RequestLog = 4,
    /// The request is refused with [`ProgressError::Permission`].
    Disabled = 5,
}
/// Whether transfer progress is gated by the event handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ThrottleMode {
    /// No throttling. This is the default.
    #[default]
    None = 0,
    /// Every progress step sends a [`Throttle`] request and waits for the handler's verdict.
    Throttle = 1,
}
/// Reason an event handler gives when it rejects a connection, a request, or a throttled
/// transfer step (the error half of [`EventResult`]).
///
/// Converts into [`ProgressError::Limit`] / [`ProgressError::Permission`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbortReason {
    /// The operation exceeded a limit imposed by the handler.
    RateLimited,
    /// The handler does not permit the operation.
    Permission,
}
/// Error produced on the provider side when an event handler aborts an operation or when
/// communicating with the handler fails.
// Variant docs below are `//` comments on purpose: snafu uses `///` variant docs as the
// Display text, which would change the error messages.
#[derive(Debug, Snafu)]
pub enum ProgressError {
    // The handler answered with `AbortReason::RateLimited`.
    Limit,
    // The handler answered with `AbortReason::Permission`, or the request type is disabled.
    Permission,
    // Sending to or receiving from the handler failed; Display/source are forwarded.
    #[snafu(transparent)]
    Internal {
        source: irpc::Error,
    },
}
/// Maps a [`ProgressError`] onto the closest [`io::ErrorKind`]; internal RPC failures use
/// `irpc::Error`'s own conversion into `io::Error`.
impl From<ProgressError> for io::Error {
    fn from(value: ProgressError) -> Self {
        let kind = match value {
            ProgressError::Limit => io::ErrorKind::QuotaExceeded,
            ProgressError::Permission => io::ErrorKind::PermissionDenied,
            ProgressError::Internal { source } => return source.into(),
        };
        kind.into()
    }
}
pub trait HasErrorCode {
fn code(&self) -> quinn::VarInt;
}
impl HasErrorCode for ProgressError {
fn code(&self) -> quinn::VarInt {
match self {
ProgressError::Limit => ERR_LIMIT,
ProgressError::Permission => ERR_PERMISSION,
ProgressError::Internal { .. } => ERR_INTERNAL,
}
}
}
impl ProgressError {
pub fn reason(&self) -> &'static [u8] {
match self {
ProgressError::Limit => b"limit",
ProgressError::Permission => b"permission",
ProgressError::Internal { .. } => b"internal",
}
}
}
impl From<AbortReason> for ProgressError {
fn from(value: AbortReason) -> Self {
match value {
AbortReason::RateLimited => ProgressError::Limit,
AbortReason::Permission => ProgressError::Permission,
}
}
}
impl From<irpc::channel::RecvError> for ProgressError {
fn from(value: irpc::channel::RecvError) -> Self {
ProgressError::Internal {
source: value.into(),
}
}
}
impl From<irpc::channel::SendError> for ProgressError {
fn from(value: irpc::channel::SendError) -> Self {
ProgressError::Internal {
source: value.into(),
}
}
}
/// Reply a handler sends for a request-style event: `Ok(())` to proceed, `Err` to abort.
pub type EventResult = Result<(), AbortReason>;
/// Result of emitting an event from the provider side.
pub type ClientResult = Result<(), ProgressError>;
/// Selects, per event kind, whether and how the provider reports events to the handler.
///
/// An [`EventSender`] consults this mask before emitting each event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EventMask {
    /// Mode for client connection events.
    pub connected: ConnectMode,
    /// Mode for get requests.
    pub get: RequestMode,
    /// Mode for get-many requests.
    pub get_many: RequestMode,
    /// Mode for push requests.
    pub push: RequestMode,
    /// Mode for observe requests.
    pub observe: ObserveMode,
    /// Whether transfer progress is throttled through the handler.
    pub throttle: ThrottleMode,
}

impl Default for EventMask {
    /// Returns [`EventMask::DEFAULT`].
    fn default() -> Self {
        EventMask::DEFAULT
    }
}

impl EventMask {
    /// Reports nothing and does not throttle; `push` is [`RequestMode::Disabled`].
    pub const DEFAULT: Self = Self {
        connected: ConnectMode::None,
        get: RequestMode::None,
        get_many: RequestMode::None,
        push: RequestMode::Disabled,
        observe: ObserveMode::None,
        throttle: ThrottleMode::None,
    };

    /// Asks the handler about connections, get, get-many and observe requests, forwards
    /// per-request updates for get and get-many, enables throttling, and keeps `push`
    /// [`RequestMode::Disabled`].
    pub const ALL_READONLY: Self = Self {
        connected: ConnectMode::Request,
        get: RequestMode::RequestLog,
        get_many: RequestMode::RequestLog,
        push: RequestMode::Disabled,
        observe: ObserveMode::Request,
        throttle: ThrottleMode::Throttle,
    };
}
/// Wrapper that marks a message as a notification: the corresponding [`ProviderProto`]
/// variants carry no reply channel, so the handler cannot reject it.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notify<T>(T);

impl<T> Deref for Notify<T> {
    type Target = T;

    /// Gives read access to the wrapped message.
    fn deref(&self) -> &T {
        let Notify(inner) = self;
        inner
    }
}
/// Provider-side handle for emitting events to an event handler.
///
/// With `inner: None` (e.g. the derived `Default`) no events are emitted at all.
#[derive(Debug, Default, Clone)]
pub struct EventSender {
    /// Which events are reported, and how.
    mask: EventMask,
    /// Connection to the handler, if any.
    inner: Option<irpc::Client<ProviderProto>>,
}

/// Per-request update channel state held by a [`RequestTracker`].
#[derive(Debug, Default)]
enum RequestUpdates {
    /// No update channel; updates are not sent.
    #[default]
    None,
    /// Updates are sent on this channel.
    Active(mpsc::Sender<RequestUpdate>),
    /// A channel exists but updates are not sent on it.
    // The sender is never read; NOTE(review): it is presumably held only to keep the
    // handler's update stream open for the lifetime of the request — confirm.
    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
}

/// Tracks a single request and reports its transfer progress to the event handler.
#[derive(Debug)]
pub struct RequestTracker {
    /// Where transfer updates go (if anywhere).
    updates: RequestUpdates,
    /// Throttle target as `(client, connection_id, request_id)`, if throttling is enabled.
    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
}
impl RequestTracker {
    /// Creates a tracker from its update channel state and optional throttle target
    /// `(client, connection_id, request_id)`.
    fn new(
        updates: RequestUpdates,
        throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
    ) -> Self {
        Self { updates, throttle }
    }

    /// A tracker that reports nothing and never throttles.
    pub const NONE: Self = Self {
        updates: RequestUpdates::None,
        throttle: None,
    };

    /// Reports that the transfer of blob `hash` (size `size`, position `index` within the
    /// request) has started.
    ///
    /// Only sent when updates are active; otherwise this is a no-op.
    ///
    /// # Errors
    /// Returns an error if the update channel is closed.
    pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
        if let RequestUpdates::Active(tx) = &self.updates {
            tx.send(
                TransferStarted {
                    index,
                    hash: *hash,
                    size,
                }
                .into(),
            )
            .await?;
        }
        Ok(())
    }

    /// Reports transfer progress and, if throttling is enabled, asks the handler whether the
    /// transfer of the next `len` may proceed.
    ///
    /// The progress update uses `try_send`. NOTE(review): presumably the update is dropped,
    /// rather than waited for, when the channel is full; confirm with irpc's docs.
    ///
    /// # Errors
    /// [`ProgressError::Limit`] / [`ProgressError::Permission`] if the throttle handler aborts,
    /// [`ProgressError::Internal`] if talking to the handler fails.
    pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
        if let RequestUpdates::Active(tx) = &mut self.updates {
            tx.try_send(TransferProgress { end_offset }.into()).await?;
        }
        if let Some((throttle, connection_id, request_id)) = &self.throttle {
            // Outer `?`: RPC failure; inner `?`: the handler's AbortReason.
            throttle
                .rpc(Throttle {
                    connection_id: *connection_id,
                    request_id: *request_id,
                    size: len,
                })
                .await??;
        }
        Ok(())
    }

    /// Reports successful completion. `f` builds the stats and is called at most once, and
    /// only when updates are active.
    ///
    /// # Errors
    /// Returns an error if the update channel is closed.
    pub async fn transfer_completed(
        &self,
        f: impl FnOnce() -> Box<TransferStats>,
    ) -> irpc::Result<()> {
        if let RequestUpdates::Active(tx) = &self.updates {
            tx.send(TransferCompleted { stats: f() }.into()).await?;
        }
        Ok(())
    }

    /// Reports that the transfer was aborted. `f` builds the stats and is called at most
    /// once, and only when updates are active.
    ///
    /// # Errors
    /// Returns an error if the update channel is closed.
    pub async fn transfer_aborted(
        &self,
        f: impl FnOnce() -> Box<TransferStats>,
    ) -> irpc::Result<()> {
        if let RequestUpdates::Active(tx) = &self.updates {
            tx.send(TransferAborted { stats: f() }.into()).await?;
        }
        Ok(())
    }
}
impl EventSender {
pub const DEFAULT: Self = Self {
mask: EventMask::DEFAULT,
inner: None,
};
pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
Self {
mask,
inner: Some(irpc::Client::from(client)),
}
}
pub fn channel(
capacity: usize,
mask: EventMask,
) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
(Self::new(tx, mask), rx)
}
pub fn tracing(&self, mask: EventMask) -> Self {
use tracing::trace;
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
n0_future::task::spawn(async move {
fn log_request_events(
mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
connection_id: u64,
request_id: u64,
) {
n0_future::task::spawn(async move {
while let Ok(Some(update)) = rx.recv().await {
trace!(%connection_id, %request_id, "{update:?}");
}
});
}
while let Some(msg) = rx.recv().await {
match msg {
ProviderMessage::ClientConnected(msg) => {
trace!("{:?}", msg.inner);
msg.tx.send(Ok(())).await.ok();
}
ProviderMessage::ClientConnectedNotify(msg) => {
trace!("{:?}", msg.inner);
}
ProviderMessage::ConnectionClosed(msg) => {
trace!("{:?}", msg.inner);
}
ProviderMessage::GetRequestReceived(msg) => {
trace!("{:?}", msg.inner);
msg.tx.send(Ok(())).await.ok();
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::GetRequestReceivedNotify(msg) => {
trace!("{:?}", msg.inner);
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::GetManyRequestReceived(msg) => {
trace!("{:?}", msg.inner);
msg.tx.send(Ok(())).await.ok();
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::GetManyRequestReceivedNotify(msg) => {
trace!("{:?}", msg.inner);
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::PushRequestReceived(msg) => {
trace!("{:?}", msg.inner);
msg.tx.send(Ok(())).await.ok();
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::PushRequestReceivedNotify(msg) => {
trace!("{:?}", msg.inner);
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::ObserveRequestReceived(msg) => {
trace!("{:?}", msg.inner);
msg.tx.send(Ok(())).await.ok();
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::ObserveRequestReceivedNotify(msg) => {
trace!("{:?}", msg.inner);
log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
}
ProviderMessage::Throttle(msg) => {
trace!("{:?}", msg.inner);
msg.tx.send(Ok(())).await.ok();
}
}
}
});
Self {
mask,
inner: Some(irpc::Client::from(tx)),
}
}
pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
if let Some(client) = &self.inner {
match self.mask.connected {
ConnectMode::None => {}
ConnectMode::Notify => client.notify(Notify(f())).await?,
ConnectMode::Request => client.rpc(f()).await??,
}
};
Ok(())
}
pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
if let Some(client) = &self.inner {
client.notify(f()).await?;
};
Ok(())
}
pub(crate) async fn request<Req>(
&self,
f: impl FnOnce() -> Req,
connection_id: u64,
request_id: u64,
) -> Result<RequestTracker, ProgressError>
where
ProviderProto: From<RequestReceived<Req>>,
ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
RequestReceived<Req>: Channels<
ProviderProto,
Tx = oneshot::Sender<EventResult>,
Rx = mpsc::Receiver<RequestUpdate>,
>,
ProviderProto: From<Notify<RequestReceived<Req>>>,
ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
Notify<RequestReceived<Req>>:
Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
{
let client = self.inner.as_ref();
Ok(self.create_tracker((
match self.mask.get {
RequestMode::None => RequestUpdates::None,
RequestMode::Notify if client.is_some() => {
let msg = RequestReceived {
request: f(),
connection_id,
request_id,
};
RequestUpdates::Disabled(
client.unwrap().notify_streaming(Notify(msg), 32).await?,
)
}
RequestMode::Request if client.is_some() => {
let msg = RequestReceived {
request: f(),
connection_id,
request_id,
};
let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
rx.await??;
RequestUpdates::Disabled(tx)
}
RequestMode::NotifyLog if client.is_some() => {
let msg = RequestReceived {
request: f(),
connection_id,
request_id,
};
RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
}
RequestMode::RequestLog if client.is_some() => {
let msg = RequestReceived {
request: f(),
connection_id,
request_id,
};
let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
rx.await??;
RequestUpdates::Active(tx)
}
RequestMode::Disabled => {
return Err(ProgressError::Permission);
}
_ => RequestUpdates::None,
},
connection_id,
request_id,
)))
}
fn create_tracker(
&self,
(updates, connection_id, request_id): (RequestUpdates, u64, u64),
) -> RequestTracker {
let throttle = match self.mask.throttle {
ThrottleMode::None => None,
ThrottleMode::Throttle => self
.inner
.clone()
.map(|client| (client, connection_id, request_id)),
};
RequestTracker::new(updates, throttle)
}
}
/// Protocol for events sent from the provider to an event handler.
///
/// Variants with a `oneshot::Sender<EventResult>` reply channel let the handler reject the
/// event; `*Notify` variants (and `ConnectionClosed`) carry no reply channel. Request variants
/// also carry an update stream of [`RequestUpdate`]s.
#[rpc_requests(message = ProviderMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub enum ProviderProto {
    /// A client connected; the handler replies whether to accept it.
    #[rpc(tx = oneshot::Sender<EventResult>)]
    ClientConnected(ClientConnected),
    /// A client connected; notification only.
    #[rpc(tx = NoSender)]
    ClientConnectedNotify(Notify<ClientConnected>),
    /// A connection was closed; notification only.
    #[rpc(tx = NoSender)]
    ConnectionClosed(ConnectionClosed),
    /// A get request was received; the handler replies whether to serve it.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    GetRequestReceived(RequestReceived<GetRequest>),
    /// A get request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),
    /// A get-many request was received; the handler replies whether to serve it.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    GetManyRequestReceived(RequestReceived<GetManyRequest>),
    /// A get-many request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),
    /// A push request was received; the handler replies whether to accept it.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    PushRequestReceived(RequestReceived<PushRequest>),
    /// A push request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),
    /// An observe request was received; the handler replies whether to serve it.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    ObserveRequestReceived(RequestReceived<ObserveRequest>),
    /// An observe request was received; notification only.
    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),
    /// A transfer step is about to happen; the handler replies whether it may proceed.
    #[rpc(tx = oneshot::Sender<EventResult>)]
    Throttle(Throttle),
}
/// Message payloads of [`ProviderProto`] and of the per-request update stream.
mod proto {
    use iroh::NodeId;
    use serde::{Deserialize, Serialize};
    use crate::{provider::TransferStats, Hash};

    /// Event payload for a newly connected client.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ClientConnected {
        /// Identifier of the connection; the same id appears in later events for it.
        pub connection_id: u64,
        /// Node id associated with the connection (presumably the remote client's).
        pub node_id: NodeId,
    }

    /// Event payload for a closed connection.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ConnectionClosed {
        /// Identifier of the connection that closed.
        pub connection_id: u64,
    }

    /// Event payload for an incoming request of type `R`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RequestReceived<R> {
        /// Connection the request arrived on.
        pub connection_id: u64,
        /// Identifier of the request.
        pub request_id: u64,
        /// The request itself.
        pub request: R,
    }

    /// Throttle query sent by `RequestTracker::transfer_progress`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Throttle {
        /// Connection the transfer belongs to.
        pub connection_id: u64,
        /// Request the transfer belongs to.
        pub request_id: u64,
        /// The `len` passed to `transfer_progress`.
        pub size: u64,
    }

    /// Progress update for a transfer.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferProgress {
        /// End offset reported by the provider (presumably how far the current blob has been
        /// sent — confirm at the call site).
        pub end_offset: u64,
    }

    /// Update sent when the transfer of a blob starts.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferStarted {
        /// Position of the blob within the request (presumably; confirm at the call site).
        pub index: u64,
        /// Hash of the blob being transferred.
        pub hash: Hash,
        /// Size of the blob.
        pub size: u64,
    }

    /// Update sent when a transfer completes.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferCompleted {
        /// Statistics for the finished transfer.
        pub stats: Box<TransferStats>,
    }

    /// Update sent when a transfer is aborted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferAborted {
        /// Statistics for the transfer up to the abort.
        pub stats: Box<TransferStats>,
    }

    /// Updates streamed to the handler for a single request.
    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
    pub enum RequestUpdate {
        /// See [`TransferStarted`].
        Started(TransferStarted),
        /// See [`TransferProgress`].
        Progress(TransferProgress),
        /// See [`TransferCompleted`].
        Completed(TransferCompleted),
        /// See [`TransferAborted`].
        Aborted(TransferAborted),
    }
}
pub use proto::*;
/// Extension trait adding a streaming-notification helper to [`irpc::Client`].
mod irpc_ext {
    use std::future::Future;
    use irpc::{
        channel::{mpsc, none::NoSender},
        Channels, RpcMessage, Service, WithChannels,
    };

    /// Additional request patterns for an irpc client of service `S`.
    pub trait IrpcClientExt<S: Service> {
        /// Sends `msg` as a notification (no reply channel) and returns a sender for streaming
        /// `Update`s to the handler.
        ///
        /// `local_update_cap` is the capacity of the update channel when the handler is local.
        fn notify_streaming<Req, Update>(
            &self,
            msg: Req,
            local_update_cap: usize,
        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
        where
            S: From<Req>,
            S::Message: From<WithChannels<Req, S>>,
            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
            Update: RpcMessage;
    }

    impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
        fn notify_streaming<Req, Update>(
            &self,
            msg: Req,
            local_update_cap: usize,
        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
        where
            S: From<Req>,
            S::Message: From<WithChannels<Req, S>>,
            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
            Update: RpcMessage,
        {
            // Clone so the returned future does not borrow `self`.
            let client = self.clone();
            async move {
                let request = client.request().await?;
                match request {
                    irpc::Request::Local(local) => {
                        // Local handler: hand it the receiving half of a fresh bounded channel.
                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
                        local
                            .send((msg, NoSender, req_rx))
                            .await
                            .map_err(irpc::Error::from)?;
                        Ok(req_tx)
                    }
                    irpc::Request::Remote(remote) => {
                        // Remote handler: write the message, then reuse the send stream for
                        // updates. The receive half is dropped because a notification has no
                        // reply (`Tx = NoSender`).
                        let (s, _) = remote.write(msg).await?;
                        Ok(s.into())
                    }
                }
            }
        }
    }
}