1use std::{fmt::Debug, io, ops::Deref};
2
3use iroh::endpoint::VarInt;
4use irpc::{
5 channel::{mpsc, none::NoSender, oneshot},
6 rpc_requests, Channels, WithChannels,
7};
8use serde::{Deserialize, Serialize};
9use snafu::Snafu;
10
11use crate::{
12 protocol::{
13 GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
14 ERR_PERMISSION,
15 },
16 provider::{events::irpc_ext::IrpcClientExt, TransferStats},
17 Hash,
18};
19
/// How [`EventSender::client_connected`] reports new client connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ConnectMode {
    /// Do not report new connections.
    #[default]
    None,
    /// Send a notification without a response channel; the handler cannot
    /// reject the connection.
    Notify,
    /// Send an RPC and wait for the handler's [`EventResult`]; an `Err` is
    /// returned to the caller as a [`ProgressError`].
    Intercept,
}
32
/// How observe requests are reported to the event handler
/// (configured via `EventMask::observe`).
///
/// NOTE(review): the variants presumably mirror the same-named [`RequestMode`]
/// variants (no update logging, no `Disabled`) — confirm against the code that
/// reads `EventMask::observe`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ObserveMode {
    /// Do not report observe requests.
    #[default]
    None,
    /// Presumably: notify the handler without waiting for a decision.
    Notify,
    /// Presumably: ask the handler and wait for its decision.
    Intercept,
}
45
/// How an incoming request is reported to the event handler, and whether its
/// transfer updates are forwarded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum RequestMode {
    /// Send nothing for the request.
    #[default]
    None,
    /// Send a notification without a response channel; transfer updates are
    /// not forwarded.
    Notify,
    /// Ask the handler and wait for its [`EventResult`] before proceeding;
    /// transfer updates are not forwarded.
    Intercept,
    /// Like `Notify`, but transfer updates are forwarded to the handler.
    NotifyLog,
    /// Like `Intercept`, but transfer updates are forwarded to the handler.
    InterceptLog,
    /// Reject the request with [`ProgressError::Permission`], whether or not a
    /// handler is attached.
    Disabled,
}
68
/// Whether transfer progress is throttled through the event handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum ThrottleMode {
    /// No throttling.
    #[default]
    None,
    /// Send a [`Throttle`] RPC on every `RequestTracker::transfer_progress` call
    /// and wait for the handler's answer; an `Err` answer is returned from
    /// `transfer_progress` as a [`ProgressError`].
    Intercept,
}
79
/// Reason an event handler gives when it rejects an intercepted event
/// (connection, request, or throttle step).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbortReason {
    /// A limit was exceeded; converted to [`ProgressError::Limit`].
    RateLimited,
    /// The operation is not permitted; converted to [`ProgressError::Permission`].
    Permission,
}
87
/// Error from reporting an event: either the handler rejected the operation or
/// communicating with the handler failed.
#[derive(Debug, Snafu)]
pub enum ProgressError {
    /// Rejected with [`AbortReason::RateLimited`].
    Limit,
    /// Rejected with [`AbortReason::Permission`], or the request type is
    /// configured as [`RequestMode::Disabled`].
    Permission,
    /// Sending to or receiving from the event handler failed.
    #[snafu(transparent)]
    Internal {
        source: irpc::Error,
    },
}
98
99impl From<ProgressError> for io::Error {
100 fn from(value: ProgressError) -> Self {
101 match value {
102 ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
103 ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
104 ProgressError::Internal { source } => source.into(),
105 }
106 }
107}
108
/// Types that can be mapped to a protocol error code (a [`VarInt`], such as
/// `ERR_LIMIT` or `ERR_PERMISSION`).
pub trait HasErrorCode {
    /// Returns the error code that represents this value.
    fn code(&self) -> VarInt;
}
112
113impl HasErrorCode for ProgressError {
114 fn code(&self) -> VarInt {
115 match self {
116 ProgressError::Limit => ERR_LIMIT,
117 ProgressError::Permission => ERR_PERMISSION,
118 ProgressError::Internal { .. } => ERR_INTERNAL,
119 }
120 }
121}
122
123impl ProgressError {
124 pub fn reason(&self) -> &'static [u8] {
125 match self {
126 ProgressError::Limit => b"limit",
127 ProgressError::Permission => b"permission",
128 ProgressError::Internal { .. } => b"internal",
129 }
130 }
131}
132
133impl From<AbortReason> for ProgressError {
134 fn from(value: AbortReason) -> Self {
135 match value {
136 AbortReason::RateLimited => ProgressError::Limit,
137 AbortReason::Permission => ProgressError::Permission,
138 }
139 }
140}
141
142impl From<irpc::channel::mpsc::RecvError> for ProgressError {
143 fn from(value: irpc::channel::mpsc::RecvError) -> Self {
144 ProgressError::Internal {
145 source: value.into(),
146 }
147 }
148}
149
150impl From<irpc::channel::oneshot::RecvError> for ProgressError {
151 fn from(value: irpc::channel::oneshot::RecvError) -> Self {
152 ProgressError::Internal {
153 source: value.into(),
154 }
155 }
156}
157
158impl From<irpc::channel::SendError> for ProgressError {
159 fn from(value: irpc::channel::SendError) -> Self {
160 ProgressError::Internal {
161 source: value.into(),
162 }
163 }
164}
165
/// Answer an event handler sends for an intercepted event: `Ok(())` to allow
/// it, or an [`AbortReason`] to reject it.
pub type EventResult = Result<(), AbortReason>;
/// Result of reporting an event from the provider side.
pub type ClientResult = Result<(), ProgressError>;
168
/// Per-event-kind configuration of what an [`EventSender`] reports and how.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EventMask {
    /// Mode configured for client connection events.
    pub connected: ConnectMode,
    /// Mode configured for get requests.
    pub get: RequestMode,
    /// Mode configured for get-many requests.
    pub get_many: RequestMode,
    /// Mode configured for push requests.
    pub push: RequestMode,
    /// Mode configured for observe requests.
    pub observe: ObserveMode,
    /// Throttling mode for transfer progress.
    pub throttle: ThrottleMode,
}
188
189impl Default for EventMask {
190 fn default() -> Self {
191 Self::DEFAULT
192 }
193}
194
195impl EventMask {
196 pub const DEFAULT: Self = Self {
198 connected: ConnectMode::None,
199 get: RequestMode::None,
200 get_many: RequestMode::None,
201 push: RequestMode::Disabled,
202 throttle: ThrottleMode::None,
203 observe: ObserveMode::None,
204 };
205
206 pub const ALL_READONLY: Self = Self {
212 connected: ConnectMode::Intercept,
213 get: RequestMode::InterceptLog,
214 get_many: RequestMode::InterceptLog,
215 push: RequestMode::Disabled,
216 throttle: ThrottleMode::Intercept,
217 observe: ObserveMode::Intercept,
218 };
219}
220
/// Wrapper marking a message as a notification.
///
/// In [`ProviderProto`], `Notify<T>` variants carry no response channel
/// (`tx = NoSender`). The plain `T` variants carry a `oneshot` channel for an
/// [`EventResult`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Notify<T>(T);
224
225impl<T> Deref for Notify<T> {
226 type Target = T;
227
228 fn deref(&self) -> &Self::Target {
229 &self.0
230 }
231}
232
/// Reports provider events to an optional event handler, as configured by an
/// [`EventMask`].
#[derive(Debug, Default, Clone)]
pub struct EventSender {
    /// Which events are reported, and how.
    mask: EventMask,
    /// Client for the event handler; `None` means no handler is attached and
    /// no events are sent.
    inner: Option<irpc::Client<ProviderProto>>,
}
238
/// What a [`RequestTracker`] does with transfer updates for its request.
#[derive(Debug, Default)]
enum RequestUpdates {
    /// No update channel; updates are not sent.
    #[default]
    None,
    /// Updates are sent to the handler over this channel.
    Active(mpsc::Sender<RequestUpdate>),
    /// An update channel exists, but updates are not sent on it; the sender is
    /// only held (hence `dead_code`). NOTE(review): presumably kept so the
    /// handler's receiving side stays open for the request's lifetime — confirm.
    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
}
250
/// Per-request handle used to report transfer updates and apply throttling.
#[derive(Debug)]
pub struct RequestTracker {
    /// Where (and whether) transfer updates are sent.
    updates: RequestUpdates,
    /// Handler client plus `(connection_id, request_id)` for throttle RPCs;
    /// `None` when throttling is off.
    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
}
256
257impl RequestTracker {
258 fn new(
259 updates: RequestUpdates,
260 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
261 ) -> Self {
262 Self { updates, throttle }
263 }
264
265 pub const NONE: Self = Self {
267 updates: RequestUpdates::None,
268 throttle: None,
269 };
270
271 pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
273 if let RequestUpdates::Active(tx) = &self.updates {
274 tx.send(
275 TransferStarted {
276 index,
277 hash: *hash,
278 size,
279 }
280 .into(),
281 )
282 .await?;
283 }
284 Ok(())
285 }
286
287 pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
289 if let RequestUpdates::Active(tx) = &mut self.updates {
290 tx.try_send(TransferProgress { end_offset }.into()).await?;
291 }
292 if let Some((throttle, connection_id, request_id)) = &self.throttle {
293 throttle
294 .rpc(Throttle {
295 connection_id: *connection_id,
296 request_id: *request_id,
297 size: len,
298 })
299 .await??;
300 }
301 Ok(())
302 }
303
304 pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
306 if let RequestUpdates::Active(tx) = &self.updates {
307 tx.send(TransferCompleted { stats: f() }.into()).await?;
308 }
309 Ok(())
310 }
311
312 pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
314 if let RequestUpdates::Active(tx) = &self.updates {
315 tx.send(TransferAborted { stats: f() }.into()).await?;
316 }
317 Ok(())
318 }
319}
320
321impl EventSender {
326 pub const DEFAULT: Self = Self {
328 mask: EventMask::DEFAULT,
329 inner: None,
330 };
331
332 pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
333 Self {
334 mask,
335 inner: Some(irpc::Client::from(client)),
336 }
337 }
338
339 pub fn channel(
340 capacity: usize,
341 mask: EventMask,
342 ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
343 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
344 (Self::new(tx, mask), rx)
345 }
346
347 pub fn tracing(&self, mask: EventMask) -> Self {
349 use tracing::trace;
350 let (tx, mut rx) = tokio::sync::mpsc::channel(32);
351 n0_future::task::spawn(async move {
352 fn log_request_events(
353 mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
354 connection_id: u64,
355 request_id: u64,
356 ) {
357 n0_future::task::spawn(async move {
358 while let Ok(Some(update)) = rx.recv().await {
359 trace!(%connection_id, %request_id, "{update:?}");
360 }
361 });
362 }
363 while let Some(msg) = rx.recv().await {
364 match msg {
365 ProviderMessage::ClientConnected(msg) => {
366 trace!("{:?}", msg.inner);
367 msg.tx.send(Ok(())).await.ok();
368 }
369 ProviderMessage::ClientConnectedNotify(msg) => {
370 trace!("{:?}", msg.inner);
371 }
372 ProviderMessage::ConnectionClosed(msg) => {
373 trace!("{:?}", msg.inner);
374 }
375 ProviderMessage::GetRequestReceived(msg) => {
376 trace!("{:?}", msg.inner);
377 msg.tx.send(Ok(())).await.ok();
378 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
379 }
380 ProviderMessage::GetRequestReceivedNotify(msg) => {
381 trace!("{:?}", msg.inner);
382 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
383 }
384 ProviderMessage::GetManyRequestReceived(msg) => {
385 trace!("{:?}", msg.inner);
386 msg.tx.send(Ok(())).await.ok();
387 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
388 }
389 ProviderMessage::GetManyRequestReceivedNotify(msg) => {
390 trace!("{:?}", msg.inner);
391 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
392 }
393 ProviderMessage::PushRequestReceived(msg) => {
394 trace!("{:?}", msg.inner);
395 msg.tx.send(Ok(())).await.ok();
396 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
397 }
398 ProviderMessage::PushRequestReceivedNotify(msg) => {
399 trace!("{:?}", msg.inner);
400 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
401 }
402 ProviderMessage::ObserveRequestReceived(msg) => {
403 trace!("{:?}", msg.inner);
404 msg.tx.send(Ok(())).await.ok();
405 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
406 }
407 ProviderMessage::ObserveRequestReceivedNotify(msg) => {
408 trace!("{:?}", msg.inner);
409 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
410 }
411 ProviderMessage::Throttle(msg) => {
412 trace!("{:?}", msg.inner);
413 msg.tx.send(Ok(())).await.ok();
414 }
415 }
416 }
417 });
418 Self {
419 mask,
420 inner: Some(irpc::Client::from(tx)),
421 }
422 }
423
424 pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
426 if let Some(client) = &self.inner {
427 match self.mask.connected {
428 ConnectMode::None => {}
429 ConnectMode::Notify => client.notify(Notify(f())).await?,
430 ConnectMode::Intercept => client.rpc(f()).await??,
431 }
432 };
433 Ok(())
434 }
435
436 pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
438 if let Some(client) = &self.inner {
439 client.notify(f()).await?;
440 };
441 Ok(())
442 }
443
444 pub(crate) async fn request<Req>(
448 &self,
449 f: impl FnOnce() -> Req,
450 connection_id: u64,
451 request_id: u64,
452 ) -> Result<RequestTracker, ProgressError>
453 where
454 ProviderProto: From<RequestReceived<Req>>,
455 ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
456 RequestReceived<Req>: Channels<
457 ProviderProto,
458 Tx = oneshot::Sender<EventResult>,
459 Rx = mpsc::Receiver<RequestUpdate>,
460 >,
461 ProviderProto: From<Notify<RequestReceived<Req>>>,
462 ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
463 Notify<RequestReceived<Req>>:
464 Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
465 {
466 let client = self.inner.as_ref();
467 Ok(self.create_tracker((
468 match self.mask.get {
469 RequestMode::None => RequestUpdates::None,
470 RequestMode::Notify if client.is_some() => {
471 let msg = RequestReceived {
472 request: f(),
473 connection_id,
474 request_id,
475 };
476 RequestUpdates::Disabled(
477 client.unwrap().notify_streaming(Notify(msg), 32).await?,
478 )
479 }
480 RequestMode::Intercept if client.is_some() => {
481 let msg = RequestReceived {
482 request: f(),
483 connection_id,
484 request_id,
485 };
486 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
487 rx.await??;
489 RequestUpdates::Disabled(tx)
490 }
491 RequestMode::NotifyLog if client.is_some() => {
492 let msg = RequestReceived {
493 request: f(),
494 connection_id,
495 request_id,
496 };
497 RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
498 }
499 RequestMode::InterceptLog if client.is_some() => {
500 let msg = RequestReceived {
501 request: f(),
502 connection_id,
503 request_id,
504 };
505 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
506 rx.await??;
508 RequestUpdates::Active(tx)
509 }
510 RequestMode::Disabled => {
511 return Err(ProgressError::Permission);
512 }
513 _ => RequestUpdates::None,
514 },
515 connection_id,
516 request_id,
517 )))
518 }
519
520 fn create_tracker(
521 &self,
522 (updates, connection_id, request_id): (RequestUpdates, u64, u64),
523 ) -> RequestTracker {
524 let throttle = match self.mask.throttle {
525 ThrottleMode::None => None,
526 ThrottleMode::Intercept => self
527 .inner
528 .clone()
529 .map(|client| (client, connection_id, request_id)),
530 };
531 RequestTracker::new(updates, throttle)
532 }
533}
534
/// Messages the provider sends to an event handler.
///
/// Variants with `tx = oneshot::Sender<EventResult>` ask the handler for a
/// decision; variants with `tx = NoSender` expect no answer. Request variants
/// also carry an `mpsc::Receiver<RequestUpdate>`, on which the provider may
/// stream [`RequestUpdate`]s for that request.
#[rpc_requests(message = ProviderMessage, rpc_feature = "rpc")]
#[derive(Debug, Serialize, Deserialize)]
pub enum ProviderProto {
    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// A client connected (sent in `ConnectMode::Intercept`); the handler
    /// answers with an [`EventResult`].
    ClientConnected(ClientConnected),

    #[rpc(tx = NoSender)]
    /// A client connected (sent in `ConnectMode::Notify`); no answer expected.
    ClientConnectedNotify(Notify<ClientConnected>),

    #[rpc(tx = NoSender)]
    /// A connection was closed; no answer expected.
    ConnectionClosed(ConnectionClosed),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get request arrived and awaits the handler's decision.
    GetRequestReceived(RequestReceived<GetRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get request arrived; no decision expected.
    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get-many request arrived and awaits the handler's decision.
    GetManyRequestReceived(RequestReceived<GetManyRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get-many request arrived; no decision expected.
    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A push request arrived and awaits the handler's decision.
    PushRequestReceived(RequestReceived<PushRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A push request arrived; no decision expected.
    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// An observe request arrived and awaits the handler's decision.
    ObserveRequestReceived(RequestReceived<ObserveRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// An observe request arrived; no decision expected.
    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),

    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// Throttle check sent by `RequestTracker::transfer_progress` when
    /// throttling is enabled; the handler answers with an [`EventResult`].
    Throttle(Throttle),
}
586
/// Payload types carried by `ProviderProto` messages and request update streams.
mod proto {
    use iroh::EndpointId;
    use serde::{Deserialize, Serialize};

    use crate::{provider::TransferStats, Hash};

    /// A client connected to the provider.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ClientConnected {
        /// Id of the new connection.
        pub connection_id: u64,
        /// Id of the remote endpoint, or `None` if it is not available.
        pub endpoint_id: Option<EndpointId>,
    }

    /// A connection was closed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ConnectionClosed {
        /// Id of the closed connection.
        pub connection_id: u64,
    }

    /// A request of type `R` was received.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RequestReceived<R> {
        /// Id of the connection the request arrived on.
        pub connection_id: u64,
        /// Id of the request.
        pub request_id: u64,
        /// The request itself.
        pub request: R,
    }

    /// Throttle check for one progress step of a request.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Throttle {
        /// Id of the connection the request belongs to.
        pub connection_id: u64,
        /// Id of the request.
        pub request_id: u64,
        /// Size of this step (the `len` argument of
        /// `RequestTracker::transfer_progress`).
        pub size: u64,
    }

    /// Progress update for an ongoing transfer.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferProgress {
        /// The `end_offset` passed to `RequestTracker::transfer_progress`.
        pub end_offset: u64,
    }

    /// A transfer started.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferStarted {
        /// Position of the transfer within the request.
        pub index: u64,
        /// Hash of the content being transferred.
        pub hash: Hash,
        /// Size reported for the content.
        pub size: u64,
    }

    /// A transfer completed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferCompleted {
        /// Statistics for the transfer.
        pub stats: Box<TransferStats>,
    }

    /// A transfer was aborted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferAborted {
        /// Statistics for the transfer up to the abort.
        pub stats: Box<TransferStats>,
    }

    /// Update streamed to the handler for a single request.
    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
    pub enum RequestUpdate {
        /// See [`TransferStarted`].
        Started(TransferStarted),
        /// See [`TransferProgress`].
        Progress(TransferProgress),
        /// See [`TransferCompleted`].
        Completed(TransferCompleted),
        /// See [`TransferAborted`].
        Aborted(TransferAborted),
    }
}
662pub use proto::*;
663
/// Extension methods for [`irpc::Client`].
mod irpc_ext {
    use std::future::Future;

    use irpc::{
        channel::{mpsc, none::NoSender},
        Channels, RpcMessage, Service, WithChannels,
    };

    /// Adds a "notification followed by an update stream" call to irpc clients.
    pub trait IrpcClientExt<S: Service> {
        /// Sends `msg`, which has no response channel, and returns a sender for
        /// streaming `Update`s to the handler afterwards.
        ///
        /// `local_update_cap` is the capacity of the update channel created
        /// when the handler is local; it is not used for remote handlers.
        fn notify_streaming<Req, Update>(
            &self,
            msg: Req,
            local_update_cap: usize,
        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
        where
            S: From<Req>,
            S::Message: From<WithChannels<Req, S>>,
            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
            Update: RpcMessage;
    }

    impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
        fn notify_streaming<Req, Update>(
            &self,
            msg: Req,
            local_update_cap: usize,
        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
        where
            S: From<Req>,
            S::Message: From<WithChannels<Req, S>>,
            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
            Update: RpcMessage,
        {
            // Clone so the returned future does not borrow `self`.
            let client = self.clone();
            async move {
                let request = client.request().await?;
                match request {
                    irpc::Request::Local(local) => {
                        // Local handler: create the update channel here and hand
                        // its receiving end over together with the message.
                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
                        local
                            .send((msg, NoSender, req_rx))
                            .await
                            .map_err(irpc::Error::from)?;
                        Ok(req_tx)
                    }
                    #[cfg(feature = "rpc")]
                    irpc::Request::Remote(remote) => {
                        // Remote handler: write the message; the first returned
                        // value becomes the update sender. The second value is
                        // unused (presumably the receive side, not needed because
                        // there is no response — confirm against irpc).
                        let (s, _) = remote.write(msg).await?;
                        Ok(s.into())
                    }
                    #[cfg(not(feature = "rpc"))]
                    irpc::Request::Remote(_) => {
                        // Asserts that remote requests cannot occur without the
                        // `rpc` feature.
                        unreachable!()
                    }
                }
            }
        }
    }
}