1use std::{fmt::Debug, io, ops::Deref};
2
3use irpc::{
4 channel::{mpsc, none::NoSender, oneshot},
5 rpc_requests, Channels, WithChannels,
6};
7use n0_error::{e, stack_error};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 protocol::{
12 GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
13 ERR_PERMISSION,
14 },
15 provider::{events::irpc_ext::IrpcClientExt, TransferStats},
16 Hash,
17};
18
/// How client connect events are reported to the event consumer.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ConnectMode {
    /// Connect events are not reported at all.
    #[default]
    None = 0,
    /// A connect notification is sent; no answer is awaited.
    Notify = 1,
    /// A connect request is sent and its answer awaited; a rejection is returned as an
    /// error from `EventSender::client_connected`.
    Intercept = 2,
}
31
/// How observe requests are reported to the event consumer.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ObserveMode {
    /// Observe requests are not reported.
    #[default]
    None = 0,
    /// A notification is sent for each observe request; no answer is awaited.
    Notify = 1,
    /// A request is sent for each observe request and the consumer's answer is awaited.
    Intercept = 2,
}
44
/// How a kind of request (get, get-many, push, ...) is reported to the event consumer.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RequestMode {
    /// Requests are served without sending any event.
    #[default]
    None = 0,
    /// A notification is sent; transfer updates are not forwarded.
    Notify = 1,
    /// A request is sent and the consumer's answer awaited; transfer updates are not
    /// forwarded.
    Intercept = 2,
    /// A notification is sent and transfer updates are forwarded.
    NotifyLog = 3,
    /// A request is sent and the consumer's answer awaited; transfer updates are
    /// forwarded.
    InterceptLog = 4,
    /// Requests of this kind are rejected with a permission error.
    Disabled = 5,
}
67
/// Whether transfer progress is subject to throttling by the event consumer.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleMode {
    /// No throttle requests are sent.
    #[default]
    None = 0,
    /// Every progress report sends a throttle request and awaits the consumer's answer.
    Intercept = 1,
}
78
/// Reason an event consumer gives when it rejects an intercepted event.
///
/// Converted into the matching [`ProgressError`] variant by its `From` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AbortReason {
    /// Maps to [`ProgressError::Limit`].
    RateLimited,
    /// Maps to [`ProgressError::Permission`].
    Permission,
}
86
/// Error that stops the provider from continuing with a connection or request.
#[stack_error(derive, add_meta, from_sources)]
pub enum ProgressError {
    /// A limit was hit, e.g. the consumer answered [`AbortReason::RateLimited`].
    #[error("limit")]
    Limit {},
    /// The operation is not permitted, e.g. the consumer answered
    /// [`AbortReason::Permission`] or the request kind is disabled.
    #[error("permission")]
    Permission {},
    /// Communicating with the event consumer over irpc failed.
    #[error(transparent)]
    Internal { source: irpc::Error },
}
97
98impl From<ProgressError> for io::Error {
99 fn from(value: ProgressError) -> Self {
100 match value {
101 ProgressError::Limit { .. } => io::ErrorKind::QuotaExceeded.into(),
102 ProgressError::Permission { .. } => io::ErrorKind::PermissionDenied.into(),
103 ProgressError::Internal { source, .. } => source.into(),
104 }
105 }
106}
107
/// Types that can be mapped to a protocol error code.
pub trait HasErrorCode {
    /// Returns the error code for this value.
    fn code(&self) -> quinn::VarInt;
}
111
112impl HasErrorCode for ProgressError {
113 fn code(&self) -> quinn::VarInt {
114 match self {
115 ProgressError::Limit { .. } => ERR_LIMIT,
116 ProgressError::Permission { .. } => ERR_PERMISSION,
117 ProgressError::Internal { .. } => ERR_INTERNAL,
118 }
119 }
120}
121
122impl ProgressError {
123 pub fn reason(&self) -> &'static [u8] {
124 match self {
125 ProgressError::Limit { .. } => b"limit",
126 ProgressError::Permission { .. } => b"permission",
127 ProgressError::Internal { .. } => b"internal",
128 }
129 }
130}
131
132impl From<AbortReason> for ProgressError {
133 fn from(value: AbortReason) -> Self {
134 match value {
135 AbortReason::RateLimited => n0_error::e!(ProgressError::Limit),
136 AbortReason::Permission => n0_error::e!(ProgressError::Permission),
137 }
138 }
139}
140
141impl From<irpc::channel::mpsc::RecvError> for ProgressError {
142 fn from(value: irpc::channel::mpsc::RecvError) -> Self {
143 n0_error::e!(ProgressError::Internal, value.into())
144 }
145}
146
147impl From<irpc::channel::oneshot::RecvError> for ProgressError {
148 fn from(value: irpc::channel::oneshot::RecvError) -> Self {
149 n0_error::e!(ProgressError::Internal, value.into())
150 }
151}
152
153impl From<irpc::channel::SendError> for ProgressError {
154 fn from(value: irpc::channel::SendError) -> Self {
155 n0_error::e!(ProgressError::Internal, value.into())
156 }
157}
158
/// Answer of the event consumer to an intercepted event: `Ok(())` to let it proceed,
/// or the reason for rejecting it.
pub type EventResult = Result<(), AbortReason>;
/// Result of reporting an event on the provider side.
pub type ClientResult = Result<(), ProgressError>;
161
/// Selects which provider events are emitted, and whether they are notifications or
/// intercepted requests the consumer must answer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EventMask {
    /// Configured mode for client connect events.
    pub connected: ConnectMode,
    /// Configured mode for get requests.
    pub get: RequestMode,
    /// Configured mode for get-many requests.
    pub get_many: RequestMode,
    /// Configured mode for push requests.
    pub push: RequestMode,
    /// Configured mode for observe requests.
    pub observe: ObserveMode,
    /// Configured throttling of transfer progress.
    pub throttle: ThrottleMode,
}
181
182impl Default for EventMask {
183 fn default() -> Self {
184 Self::DEFAULT
185 }
186}
187
188impl EventMask {
189 pub const DEFAULT: Self = Self {
191 connected: ConnectMode::None,
192 get: RequestMode::None,
193 get_many: RequestMode::None,
194 push: RequestMode::Disabled,
195 throttle: ThrottleMode::None,
196 observe: ObserveMode::None,
197 };
198
199 pub const ALL_READONLY: Self = Self {
205 connected: ConnectMode::Intercept,
206 get: RequestMode::InterceptLog,
207 get_many: RequestMode::InterceptLog,
208 push: RequestMode::Disabled,
209 throttle: ThrottleMode::Intercept,
210 observe: ObserveMode::Intercept,
211 };
212}
213
/// Wrapper marking a message as a notification: the protocol variants that carry it
/// use `tx = NoSender`, so no answer is sent back.
#[derive(Debug, Serialize, Deserialize)]
pub struct Notify<T>(T);
217
218impl<T> Deref for Notify<T> {
219 type Target = T;
220
221 fn deref(&self) -> &Self::Target {
222 &self.0
223 }
224}
225
/// Sends provider events to an optional event consumer, filtered by an [`EventMask`].
#[derive(Debug, Default, Clone)]
pub struct EventSender {
    /// Which events are sent, and how.
    mask: EventMask,
    /// Client for the event consumer; when `None`, no events are sent.
    inner: Option<irpc::Client<ProviderProto>>,
}
231
/// The per-request channel for transfer updates, if any.
#[derive(Debug, Default)]
enum RequestUpdates {
    /// No update channel exists for this request.
    #[default]
    None,
    /// Transfer updates are sent on this channel.
    Active(mpsc::Sender<RequestUpdate>),
    /// An update channel exists, but no updates are sent on it.
    // The sender is only held, never used, hence `dead_code`. Presumably it is kept so
    // the consumer's update stream stays open for the request's lifetime — TODO confirm.
    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
}
243
/// Per-request handle used while serving a request to forward transfer updates and to
/// perform throttling.
#[derive(Debug)]
pub struct RequestTracker {
    /// Where transfer updates go, if anywhere.
    updates: RequestUpdates,
    /// Client plus `(connection_id, request_id)` used for throttle requests; `None` when
    /// throttling is not intercepted.
    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
}
249
250impl RequestTracker {
251 fn new(
252 updates: RequestUpdates,
253 throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
254 ) -> Self {
255 Self { updates, throttle }
256 }
257
258 pub const NONE: Self = Self {
260 updates: RequestUpdates::None,
261 throttle: None,
262 };
263
264 pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
266 if let RequestUpdates::Active(tx) = &self.updates {
267 tx.send(
268 TransferStarted {
269 index,
270 hash: *hash,
271 size,
272 }
273 .into(),
274 )
275 .await?;
276 }
277 Ok(())
278 }
279
280 pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
282 if let RequestUpdates::Active(tx) = &mut self.updates {
283 tx.try_send(TransferProgress { end_offset }.into()).await?;
284 }
285 if let Some((throttle, connection_id, request_id)) = &self.throttle {
286 throttle
287 .rpc(Throttle {
288 connection_id: *connection_id,
289 request_id: *request_id,
290 size: len,
291 })
292 .await??;
293 }
294 Ok(())
295 }
296
297 pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
299 if let RequestUpdates::Active(tx) = &self.updates {
300 tx.send(TransferCompleted { stats: f() }.into()).await?;
301 }
302 Ok(())
303 }
304
305 pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
307 if let RequestUpdates::Active(tx) = &self.updates {
308 tx.send(TransferAborted { stats: f() }.into()).await?;
309 }
310 Ok(())
311 }
312}
313
314impl EventSender {
319 pub const DEFAULT: Self = Self {
321 mask: EventMask::DEFAULT,
322 inner: None,
323 };
324
325 pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
326 Self {
327 mask,
328 inner: Some(irpc::Client::from(client)),
329 }
330 }
331
332 pub fn channel(
333 capacity: usize,
334 mask: EventMask,
335 ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
336 let (tx, rx) = tokio::sync::mpsc::channel(capacity);
337 (Self::new(tx, mask), rx)
338 }
339
340 pub fn tracing(&self, mask: EventMask) -> Self {
342 use tracing::trace;
343 let (tx, mut rx) = tokio::sync::mpsc::channel(32);
344 n0_future::task::spawn(async move {
345 fn log_request_events(
346 mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
347 connection_id: u64,
348 request_id: u64,
349 ) {
350 n0_future::task::spawn(async move {
351 while let Ok(Some(update)) = rx.recv().await {
352 trace!(%connection_id, %request_id, "{update:?}");
353 }
354 });
355 }
356 while let Some(msg) = rx.recv().await {
357 match msg {
358 ProviderMessage::ClientConnected(msg) => {
359 trace!("{:?}", msg.inner);
360 msg.tx.send(Ok(())).await.ok();
361 }
362 ProviderMessage::ClientConnectedNotify(msg) => {
363 trace!("{:?}", msg.inner);
364 }
365 ProviderMessage::ConnectionClosed(msg) => {
366 trace!("{:?}", msg.inner);
367 }
368 ProviderMessage::GetRequestReceived(msg) => {
369 trace!("{:?}", msg.inner);
370 msg.tx.send(Ok(())).await.ok();
371 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
372 }
373 ProviderMessage::GetRequestReceivedNotify(msg) => {
374 trace!("{:?}", msg.inner);
375 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
376 }
377 ProviderMessage::GetManyRequestReceived(msg) => {
378 trace!("{:?}", msg.inner);
379 msg.tx.send(Ok(())).await.ok();
380 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
381 }
382 ProviderMessage::GetManyRequestReceivedNotify(msg) => {
383 trace!("{:?}", msg.inner);
384 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
385 }
386 ProviderMessage::PushRequestReceived(msg) => {
387 trace!("{:?}", msg.inner);
388 msg.tx.send(Ok(())).await.ok();
389 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
390 }
391 ProviderMessage::PushRequestReceivedNotify(msg) => {
392 trace!("{:?}", msg.inner);
393 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
394 }
395 ProviderMessage::ObserveRequestReceived(msg) => {
396 trace!("{:?}", msg.inner);
397 msg.tx.send(Ok(())).await.ok();
398 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
399 }
400 ProviderMessage::ObserveRequestReceivedNotify(msg) => {
401 trace!("{:?}", msg.inner);
402 log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
403 }
404 ProviderMessage::Throttle(msg) => {
405 trace!("{:?}", msg.inner);
406 msg.tx.send(Ok(())).await.ok();
407 }
408 }
409 }
410 });
411 Self {
412 mask,
413 inner: Some(irpc::Client::from(tx)),
414 }
415 }
416
417 pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
419 if let Some(client) = &self.inner {
420 match self.mask.connected {
421 ConnectMode::None => {}
422 ConnectMode::Notify => client.notify(Notify(f())).await?,
423 ConnectMode::Intercept => client.rpc(f()).await??,
424 }
425 };
426 Ok(())
427 }
428
429 pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
431 if let Some(client) = &self.inner {
432 client.notify(f()).await?;
433 };
434 Ok(())
435 }
436
437 pub(crate) async fn request<Req>(
441 &self,
442 f: impl FnOnce() -> Req,
443 connection_id: u64,
444 request_id: u64,
445 ) -> Result<RequestTracker, ProgressError>
446 where
447 ProviderProto: From<RequestReceived<Req>>,
448 ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
449 RequestReceived<Req>: Channels<
450 ProviderProto,
451 Tx = oneshot::Sender<EventResult>,
452 Rx = mpsc::Receiver<RequestUpdate>,
453 >,
454 ProviderProto: From<Notify<RequestReceived<Req>>>,
455 ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
456 Notify<RequestReceived<Req>>:
457 Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
458 {
459 let client = self.inner.as_ref();
460 Ok(self.create_tracker((
461 match self.mask.get {
462 RequestMode::None => RequestUpdates::None,
463 RequestMode::Notify if client.is_some() => {
464 let msg = RequestReceived {
465 request: f(),
466 connection_id,
467 request_id,
468 };
469 RequestUpdates::Disabled(
470 client.unwrap().notify_streaming(Notify(msg), 32).await?,
471 )
472 }
473 RequestMode::Intercept if client.is_some() => {
474 let msg = RequestReceived {
475 request: f(),
476 connection_id,
477 request_id,
478 };
479 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
480 rx.await??;
482 RequestUpdates::Disabled(tx)
483 }
484 RequestMode::NotifyLog if client.is_some() => {
485 let msg = RequestReceived {
486 request: f(),
487 connection_id,
488 request_id,
489 };
490 RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
491 }
492 RequestMode::InterceptLog if client.is_some() => {
493 let msg = RequestReceived {
494 request: f(),
495 connection_id,
496 request_id,
497 };
498 let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
499 rx.await??;
501 RequestUpdates::Active(tx)
502 }
503 RequestMode::Disabled => {
504 return Err(e!(ProgressError::Permission));
505 }
506 _ => RequestUpdates::None,
507 },
508 connection_id,
509 request_id,
510 )))
511 }
512
513 fn create_tracker(
514 &self,
515 (updates, connection_id, request_id): (RequestUpdates, u64, u64),
516 ) -> RequestTracker {
517 let throttle = match self.mask.throttle {
518 ThrottleMode::None => None,
519 ThrottleMode::Intercept => self
520 .inner
521 .clone()
522 .map(|client| (client, connection_id, request_id)),
523 };
524 RequestTracker::new(updates, throttle)
525 }
526}
527
/// Protocol of the events an [`EventSender`] sends to an event consumer.
#[rpc_requests(message = ProviderMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub enum ProviderProto {
    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// A client connected; the consumer answers with an [`EventResult`].
    ClientConnected(ClientConnected),

    #[rpc(tx = NoSender)]
    /// A client connected (notification, no answer).
    ClientConnectedNotify(Notify<ClientConnected>),

    #[rpc(tx = NoSender)]
    /// A connection was closed (notification, no answer).
    ConnectionClosed(ConnectionClosed),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get request was received; answered with an [`EventResult`], with a stream of
    /// [`RequestUpdate`]s for the request.
    GetRequestReceived(RequestReceived<GetRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get request was received (notification), with a stream of [`RequestUpdate`]s.
    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A get-many request was received; answered with an [`EventResult`], with a stream
    /// of [`RequestUpdate`]s for the request.
    GetManyRequestReceived(RequestReceived<GetManyRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A get-many request was received (notification), with a stream of
    /// [`RequestUpdate`]s.
    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// A push request was received; answered with an [`EventResult`], with a stream of
    /// [`RequestUpdate`]s for the request.
    PushRequestReceived(RequestReceived<PushRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// A push request was received (notification), with a stream of [`RequestUpdate`]s.
    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
    /// An observe request was received; answered with an [`EventResult`], with a stream
    /// of [`RequestUpdate`]s for the request.
    ObserveRequestReceived(RequestReceived<ObserveRequest>),

    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
    /// An observe request was received (notification), with a stream of
    /// [`RequestUpdate`]s.
    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),

    #[rpc(tx = oneshot::Sender<EventResult>)]
    /// Throttling check during a transfer; the consumer answers with an [`EventResult`].
    Throttle(Throttle),
}
579
/// Payload types of the provider event protocol.
mod proto {
    use iroh::EndpointId;
    use serde::{Deserialize, Serialize};

    use crate::{provider::TransferStats, Hash};

    /// A client connected to the provider.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ClientConnected {
        /// Id of the new connection.
        pub connection_id: u64,
        /// Endpoint id of the remote, if available.
        pub endpoint_id: Option<EndpointId>,
    }

    /// A connection was closed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct ConnectionClosed {
        /// Id of the closed connection.
        pub connection_id: u64,
    }

    /// A request of type `R` was received.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct RequestReceived<R> {
        /// Id of the connection the request arrived on.
        pub connection_id: u64,
        /// Id of the request.
        pub request_id: u64,
        /// The request itself.
        pub request: R,
    }

    /// Throttling check sent while a request is being served.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct Throttle {
        /// Id of the connection the request belongs to.
        pub connection_id: u64,
        /// Id of the request.
        pub request_id: u64,
        /// The `len` passed to `RequestTracker::transfer_progress`.
        pub size: u64,
    }

    /// Progress of a transfer.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferProgress {
        /// The `end_offset` passed to `RequestTracker::transfer_progress` (presumably a
        /// byte offset — TODO confirm).
        pub end_offset: u64,
    }

    /// A transfer started; fields are the values passed to
    /// `RequestTracker::transfer_started`.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferStarted {
        pub index: u64,
        pub hash: Hash,
        pub size: u64,
    }

    /// A transfer completed.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferCompleted {
        /// Statistics of the finished transfer.
        pub stats: Box<TransferStats>,
    }

    /// A transfer was aborted.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct TransferAborted {
        /// Statistics of the aborted transfer.
        pub stats: Box<TransferStats>,
    }

    /// An update sent on a request's update stream.
    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
    pub enum RequestUpdate {
        /// See [`TransferStarted`].
        Started(TransferStarted),
        /// See [`TransferProgress`].
        Progress(TransferProgress),
        /// See [`TransferCompleted`].
        Completed(TransferCompleted),
        /// See [`TransferAborted`].
        Aborted(TransferAborted),
    }
}
655pub use proto::*;
656
657mod irpc_ext {
658 use std::future::Future;
659
660 use irpc::{
661 channel::{mpsc, none::NoSender},
662 Channels, RpcMessage, Service, WithChannels,
663 };
664
665 pub trait IrpcClientExt<S: Service> {
666 fn notify_streaming<Req, Update>(
667 &self,
668 msg: Req,
669 local_update_cap: usize,
670 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
671 where
672 S: From<Req>,
673 S::Message: From<WithChannels<Req, S>>,
674 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
675 Update: RpcMessage;
676 }
677
678 impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
679 fn notify_streaming<Req, Update>(
680 &self,
681 msg: Req,
682 local_update_cap: usize,
683 ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
684 where
685 S: From<Req>,
686 S::Message: From<WithChannels<Req, S>>,
687 Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
688 Update: RpcMessage,
689 {
690 let client = self.clone();
691 async move {
692 let request = client.request().await?;
693 match request {
694 irpc::Request::Local(local) => {
695 let (req_tx, req_rx) = mpsc::channel(local_update_cap);
696 local
697 .send((msg, NoSender, req_rx))
698 .await
699 .map_err(irpc::Error::from)?;
700 Ok(req_tx)
701 }
702 irpc::Request::Remote(remote) => {
703 let (s, _) = remote.write(msg).await?;
704 Ok(s.into())
705 }
706 }
707 }
708 }
709 }
710}