1use std::{fmt::Debug, io, ops::Deref};
2
3use irpc::{
4 channel::{mpsc, none::NoSender, oneshot},
5 rpc_requests, Channels, WithChannels,
6};
7use serde::{Deserialize, Serialize};
8use snafu::Snafu;
9
10use crate::{
11 protocol::{
12 GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
13 ERR_PERMISSION,
14 },
15 provider::{events::irpc_ext::IrpcClientExt, TransferStats},
16 Hash,
17};
18
/// How the provider reports a newly accepted client connection to the event
/// handler (consulted by `EventSender::client_connected`).
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ConnectMode {
    /// Send nothing.
    #[default]
    None,
    /// Send a fire-and-forget notification; no answer is awaited.
    Notify,
    /// Send an RPC and wait for the handler's answer; an `Err` answer is
    /// returned to the provider as a `ProgressError`.
    Intercept,
}
31
/// How the provider reports incoming observe requests (selected through the
/// `observe` field of `EventMask`).
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ObserveMode {
    /// Send nothing.
    #[default]
    None,
    /// Send a notification only.
    Notify,
    /// Send an RPC the handler has to answer.
    Intercept,
}
44
/// How the provider reports an incoming request to the event handler.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RequestMode {
    /// Send nothing; the request is served.
    #[default]
    None,
    /// Notify the handler; transfer updates are not forwarded.
    Notify,
    /// Ask the handler and wait for its verdict before serving; transfer
    /// updates are not forwarded.
    Intercept,
    /// Like `Notify`, but transfer updates are forwarded as well.
    NotifyLog,
    /// Like `Intercept`, but transfer updates are forwarded as well.
    InterceptLog,
    /// Refuse the request with a permission error without contacting the
    /// handler.
    Disabled,
}
67
/// Whether transfer progress is gated by the event handler.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleMode {
    /// Never ask the handler; transfers run unthrottled.
    #[default]
    None,
    /// On every progress step send a `Throttle` RPC and wait for the answer;
    /// an `Err` answer is returned to the provider as a `ProgressError`.
    Intercept,
}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
80pub enum AbortReason {
81 RateLimited,
83 Permission,
85}
86
87#[derive(Debug, Snafu)]
89pub enum ProgressError {
90 Limit,
91 Permission,
92 #[snafu(transparent)]
93 Internal {
94 source: irpc::Error,
95 },
96}
97
98impl From<ProgressError> for io::Error {
99 fn from(value: ProgressError) -> Self {
100 match value {
101 ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
102 ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
103 ProgressError::Internal { source } => source.into(),
104 }
105 }
106}
107
108pub trait HasErrorCode {
109 fn code(&self) -> quinn::VarInt;
110}
111
112impl HasErrorCode for ProgressError {
113 fn code(&self) -> quinn::VarInt {
114 match self {
115 ProgressError::Limit => ERR_LIMIT,
116 ProgressError::Permission => ERR_PERMISSION,
117 ProgressError::Internal { .. } => ERR_INTERNAL,
118 }
119 }
120}
121
122impl ProgressError {
123 pub fn reason(&self) -> &'static [u8] {
124 match self {
125 ProgressError::Limit => b"limit",
126 ProgressError::Permission => b"permission",
127 ProgressError::Internal { .. } => b"internal",
128 }
129 }
130}
131
132impl From<AbortReason> for ProgressError {
133 fn from(value: AbortReason) -> Self {
134 match value {
135 AbortReason::RateLimited => ProgressError::Limit,
136 AbortReason::Permission => ProgressError::Permission,
137 }
138 }
139}
140
141impl From<irpc::channel::RecvError> for ProgressError {
142 fn from(value: irpc::channel::RecvError) -> Self {
143 ProgressError::Internal {
144 source: value.into(),
145 }
146 }
147}
148
149impl From<irpc::channel::SendError> for ProgressError {
150 fn from(value: irpc::channel::SendError) -> Self {
151 ProgressError::Internal {
152 source: value.into(),
153 }
154 }
155}
156
157pub type EventResult = Result<(), AbortReason>;
158pub type ClientResult = Result<(), ProgressError>;
159
160#[derive(Debug, Clone, Copy, PartialEq, Eq)]
165pub struct EventMask {
166 pub connected: ConnectMode,
168 pub get: RequestMode,
170 pub get_many: RequestMode,
172 pub push: RequestMode,
174 pub observe: ObserveMode,
176 pub throttle: ThrottleMode,
178}
179
180impl Default for EventMask {
181 fn default() -> Self {
182 Self::DEFAULT
183 }
184}
185
186impl EventMask {
187 pub const DEFAULT: Self = Self {
189 connected: ConnectMode::None,
190 get: RequestMode::None,
191 get_many: RequestMode::None,
192 push: RequestMode::Disabled,
193 throttle: ThrottleMode::None,
194 observe: ObserveMode::None,
195 };
196
197 pub const ALL_READONLY: Self = Self {
203 connected: ConnectMode::Intercept,
204 get: RequestMode::InterceptLog,
205 get_many: RequestMode::InterceptLog,
206 push: RequestMode::Disabled,
207 throttle: ThrottleMode::Intercept,
208 observe: ObserveMode::Intercept,
209 };
210}
211
212#[derive(Debug, Serialize, Deserialize)]
214pub struct Notify<T>(T);
215
216impl<T> Deref for Notify<T> {
217 type Target = T;
218
219 fn deref(&self) -> &Self::Target {
220 &self.0
221 }
222}
223
224#[derive(Debug, Default, Clone)]
225pub struct EventSender {
226 mask: EventMask,
227 inner: Option<irpc::Client<ProviderProto>>,
228}
229
230#[derive(Debug, Default)]
231enum RequestUpdates {
232 #[default]
234 None,
235 Active(mpsc::Sender<RequestUpdate>),
237 Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
240}
241
242#[derive(Debug)]
243pub struct RequestTracker {
244 updates: RequestUpdates,
245 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
246}
247
248impl RequestTracker {
249 fn new(
250 updates: RequestUpdates,
251 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
252 ) -> Self {
253 Self { updates, throttle }
254 }
255
256 pub const NONE: Self = Self {
258 updates: RequestUpdates::None,
259 throttle: None,
260 };
261
262 pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
264 if let RequestUpdates::Active(tx) = &self.updates {
265 tx.send(
266 TransferStarted {
267 index,
268 hash: *hash,
269 size,
270 }
271 .into(),
272 )
273 .await?;
274 }
275 Ok(())
276 }
277
278 pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
280 if let RequestUpdates::Active(tx) = &mut self.updates {
281 tx.try_send(TransferProgress { end_offset }.into()).await?;
282 }
283 if let Some((throttle, connection_id, request_id)) = &self.throttle {
284 throttle
285 .rpc(Throttle {
286 connection_id: *connection_id,
287 request_id: *request_id,
288 size: len,
289 })
290 .await??;
291 }
292 Ok(())
293 }
294
295 pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
297 if let RequestUpdates::Active(tx) = &self.updates {
298 tx.send(TransferCompleted { stats: f() }.into()).await?;
299 }
300 Ok(())
301 }
302
303 pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
305 if let RequestUpdates::Active(tx) = &self.updates {
306 tx.send(TransferAborted { stats: f() }.into()).await?;
307 }
308 Ok(())
309 }
310}
311
312impl EventSender {
317 pub const DEFAULT: Self = Self {
319 mask: EventMask::DEFAULT,
320 inner: None,
321 };
322
323 pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
324 Self {
325 mask,
326 inner: Some(irpc::Client::from(client)),
327 }
328 }
329
330 pub fn channel(
331 capacity: usize,
332 mask: EventMask,
333 ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
334 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
335 (Self::new(tx, mask), rx)
336 }
337
338 pub fn tracing(&self, mask: EventMask) -> Self {
340 use tracing::trace;
341 let (tx, mut rx) = tokio::sync::mpsc::channel(32);
342 n0_future::task::spawn(async move {
343 fn log_request_events(
344 mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
345 connection_id: u64,
346 request_id: u64,
347 ) {
348 n0_future::task::spawn(async move {
349 while let Ok(Some(update)) = rx.recv().await {
350 trace!(%connection_id, %request_id, "{update:?}");
351 }
352 });
353 }
354 while let Some(msg) = rx.recv().await {
355 match msg {
356 ProviderMessage::ClientConnected(msg) => {
357 trace!("{:?}", msg.inner);
358 msg.tx.send(Ok(())).await.ok();
359 }
360 ProviderMessage::ClientConnectedNotify(msg) => {
361 trace!("{:?}", msg.inner);
362 }
363 ProviderMessage::ConnectionClosed(msg) => {
364 trace!("{:?}", msg.inner);
365 }
366 ProviderMessage::GetRequestReceived(msg) => {
367 trace!("{:?}", msg.inner);
368 msg.tx.send(Ok(())).await.ok();
369 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
370 }
371 ProviderMessage::GetRequestReceivedNotify(msg) => {
372 trace!("{:?}", msg.inner);
373 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
374 }
375 ProviderMessage::GetManyRequestReceived(msg) => {
376 trace!("{:?}", msg.inner);
377 msg.tx.send(Ok(())).await.ok();
378 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
379 }
380 ProviderMessage::GetManyRequestReceivedNotify(msg) => {
381 trace!("{:?}", msg.inner);
382 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
383 }
384 ProviderMessage::PushRequestReceived(msg) => {
385 trace!("{:?}", msg.inner);
386 msg.tx.send(Ok(())).await.ok();
387 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
388 }
389 ProviderMessage::PushRequestReceivedNotify(msg) => {
390 trace!("{:?}", msg.inner);
391 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
392 }
393 ProviderMessage::ObserveRequestReceived(msg) => {
394 trace!("{:?}", msg.inner);
395 msg.tx.send(Ok(())).await.ok();
396 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
397 }
398 ProviderMessage::ObserveRequestReceivedNotify(msg) => {
399 trace!("{:?}", msg.inner);
400 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
401 }
402 ProviderMessage::Throttle(msg) => {
403 trace!("{:?}", msg.inner);
404 msg.tx.send(Ok(())).await.ok();
405 }
406 }
407 }
408 });
409 Self {
410 mask,
411 inner: Some(irpc::Client::from(tx)),
412 }
413 }
414
415 pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
417 if let Some(client) = &self.inner {
418 match self.mask.connected {
419 ConnectMode::None => {}
420 ConnectMode::Notify => client.notify(Notify(f())).await?,
421 ConnectMode::Intercept => client.rpc(f()).await??,
422 }
423 };
424 Ok(())
425 }
426
427 pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
429 if let Some(client) = &self.inner {
430 client.notify(f()).await?;
431 };
432 Ok(())
433 }
434
435 pub(crate) async fn request<Req>(
439 &self,
440 f: impl FnOnce() -> Req,
441 connection_id: u64,
442 request_id: u64,
443 ) -> Result<RequestTracker, ProgressError>
444 where
445 ProviderProto: From<RequestReceived<Req>>,
446 ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
447 RequestReceived<Req>: Channels<
448 ProviderProto,
449 Tx = oneshot::Sender<EventResult>,
450 Rx = mpsc::Receiver<RequestUpdate>,
451 >,
452 ProviderProto: From<Notify<RequestReceived<Req>>>,
453 ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
454 Notify<RequestReceived<Req>>:
455 Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
456 {
457 let client = self.inner.as_ref();
458 Ok(self.create_tracker((
459 match self.mask.get {
460 RequestMode::None => RequestUpdates::None,
461 RequestMode::Notify if client.is_some() => {
462 let msg = RequestReceived {
463 request: f(),
464 connection_id,
465 request_id,
466 };
467 RequestUpdates::Disabled(
468 client.unwrap().notify_streaming(Notify(msg), 32).await?,
469 )
470 }
471 RequestMode::Intercept if client.is_some() => {
472 let msg = RequestReceived {
473 request: f(),
474 connection_id,
475 request_id,
476 };
477 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
478 rx.await??;
480 RequestUpdates::Disabled(tx)
481 }
482 RequestMode::NotifyLog if client.is_some() => {
483 let msg = RequestReceived {
484 request: f(),
485 connection_id,
486 request_id,
487 };
488 RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
489 }
490 RequestMode::InterceptLog if client.is_some() => {
491 let msg = RequestReceived {
492 request: f(),
493 connection_id,
494 request_id,
495 };
496 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
497 rx.await??;
499 RequestUpdates::Active(tx)
500 }
501 RequestMode::Disabled => {
502 return Err(ProgressError::Permission);
503 }
504 _ => RequestUpdates::None,
505 },
506 connection_id,
507 request_id,
508 )))
509 }
510
511 fn create_tracker(
512 &self,
513 (updates, connection_id, request_id): (RequestUpdates, u64, u64),
514 ) -> RequestTracker {
515 let throttle = match self.mask.throttle {
516 ThrottleMode::None => None,
517 ThrottleMode::Intercept => self
518 .inner
519 .clone()
520 .map(|client| (client, connection_id, request_id)),
521 };
522 RequestTracker::new(updates, throttle)
523 }
524}
525
526#[rpc_requests(message = ProviderMessage)]
527#[derive(Debug, Serialize, Deserialize)]
528pub enum ProviderProto {
529 #[rpc(tx = oneshot::Sender<EventResult>)]
531 ClientConnected(ClientConnected),
532
533 #[rpc(tx = NoSender)]
535 ClientConnectedNotify(Notify<ClientConnected>),
536
537 #[rpc(tx = NoSender)]
539 ConnectionClosed(ConnectionClosed),
540
541 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
543 GetRequestReceived(RequestReceived<GetRequest>),
544
545 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
547 GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),
548
549 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
551 GetManyRequestReceived(RequestReceived<GetManyRequest>),
552
553 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
555 GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),
556
557 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
559 PushRequestReceived(RequestReceived<PushRequest>),
560
561 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
563 PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),
564
565 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
567 ObserveRequestReceived(RequestReceived<ObserveRequest>),
568
569 #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
571 ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),
572
573 #[rpc(tx = oneshot::Sender<EventResult>)]
575 Throttle(Throttle),
576}
577
578mod proto {
579 use iroh::NodeId;
580 use serde::{Deserialize, Serialize};
581
582 use crate::{provider::TransferStats, Hash};
583
584 #[derive(Debug, Serialize, Deserialize)]
585 pub struct ClientConnected {
586 pub connection_id: u64,
587 pub node_id: Option<NodeId>,
588 }
589
590 #[derive(Debug, Serialize, Deserialize)]
591 pub struct ConnectionClosed {
592 pub connection_id: u64,
593 }
594
595 #[derive(Debug, Serialize, Deserialize)]
597 pub struct RequestReceived<R> {
598 pub connection_id: u64,
600 pub request_id: u64,
602 pub request: R,
604 }
605
606 #[derive(Debug, Serialize, Deserialize)]
608 pub struct Throttle {
609 pub connection_id: u64,
611 pub request_id: u64,
613 pub size: u64,
615 }
616
617 #[derive(Debug, Serialize, Deserialize)]
618 pub struct TransferProgress {
619 pub end_offset: u64,
621 }
622
623 #[derive(Debug, Serialize, Deserialize)]
624 pub struct TransferStarted {
625 pub index: u64,
626 pub hash: Hash,
627 pub size: u64,
628 }
629
630 #[derive(Debug, Serialize, Deserialize)]
631 pub struct TransferCompleted {
632 pub stats: Box<TransferStats>,
633 }
634
635 #[derive(Debug, Serialize, Deserialize)]
636 pub struct TransferAborted {
637 pub stats: Box<TransferStats>,
638 }
639
640 #[derive(Debug, Serialize, Deserialize, derive_more::From)]
642 pub enum RequestUpdate {
643 Started(TransferStarted),
645 Progress(TransferProgress),
647 Completed(TransferCompleted),
649 Aborted(TransferAborted),
651 }
652}
653pub use proto::*;
654
655mod irpc_ext {
656 use std::future::Future;
657
658 use irpc::{
659 channel::{mpsc, none::NoSender},
660 Channels, RpcMessage, Service, WithChannels,
661 };
662
663 pub trait IrpcClientExt<S: Service> {
664 fn notify_streaming<Req, Update>(
665 &self,
666 msg: Req,
667 local_update_cap: usize,
668 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
669 where
670 S: From<Req>,
671 S::Message: From<WithChannels<Req, S>>,
672 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
673 Update: RpcMessage;
674 }
675
676 impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
677 fn notify_streaming<Req, Update>(
678 &self,
679 msg: Req,
680 local_update_cap: usize,
681 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
682 where
683 S: From<Req>,
684 S::Message: From<WithChannels<Req, S>>,
685 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
686 Update: RpcMessage,
687 {
688 let client = self.clone();
689 async move {
690 let request = client.request().await?;
691 match request {
692 irpc::Request::Local(local) => {
693 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
694 local
695 .send((msg, NoSender, req_rx))
696 .await
697 .map_err(irpc::Error::from)?;
698 Ok(req_tx)
699 }
700 irpc::Request::Remote(remote) => {
701 let (s, _) = remote.write(msg).await?;
702 Ok(s.into())
703 }
704 }
705 }
706 }
707 }
708}