1use std::{fmt::Debug, future::Future, io};
7
8use bao_tree::ChunkRanges;
9use iroh::endpoint::{self, ConnectionError, VarInt};
10use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
11use n0_error::{e, stack_error, Result};
12use n0_future::{
13 time::{Duration, Instant},
14 StreamExt,
15};
16use serde::{Deserialize, Serialize};
17use tokio::select;
18use tracing::{debug, debug_span, Instrument};
19
20use crate::{
21 api::{
22 blobs::{Bitfield, WriteProgress},
23 ExportBaoError, ExportBaoResult, RequestError, Store,
24 },
25 hashseq::HashSeq,
26 protocol::{
27 GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
28 },
29 provider::events::{
30 ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
31 RequestTracker,
32 },
33 util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
34 Hash,
35};
36pub mod events;
37use events::EventSender;
38
/// Default receive-stream type used by [`StreamPair`] and [`ProgressReader`].
type DefaultReader = iroh::endpoint::RecvStream;
/// Default send-stream type used by [`StreamPair`] and [`ProgressWriter`].
type DefaultWriter = iroh::endpoint::SendStream;
41
/// Statistics about a single transfer on one stream.
///
/// Produced by [`StreamPair::stats`] and by the reader/writer contexts, and handed to the
/// request tracker when a transfer is reported as completed or aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// Blob payload bytes written (accumulated in `notify_payload_write`).
    pub payload_bytes_sent: u64,
    /// Non-payload bytes written (accumulated in `log_other_write`, e.g. observe items).
    pub other_bytes_sent: u64,
    /// Bytes consumed while reading the request (see `StreamPair::read_request`).
    pub other_bytes_read: u64,
    /// Time elapsed since the originating [`StreamPair`] was created.
    pub duration: Duration,
}
59
/// A receive/send stream pair on which a single request is served.
///
/// Records its creation time and the number of request bytes read so that
/// [`TransferStats`] can be produced after the pair is turned into a
/// [`ProgressWriter`] or [`ProgressReader`].
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Creation time; the reference point for `TransferStats::duration`.
    t0: Instant,
    /// Id of the connection this pair belongs to (the connection's `stable_id` in `accept`).
    connection_id: u64,
    /// Receive half of the stream.
    reader: R,
    /// Send half of the stream.
    writer: W,
    /// Bytes consumed by `read_request`.
    other_bytes_read: u64,
    /// Destination for request events.
    events: EventSender,
}
70
71impl StreamPair {
72 pub async fn accept(
73 conn: &endpoint::Connection,
74 events: EventSender,
75 ) -> Result<Self, ConnectionError> {
76 let (writer, reader) = conn.accept_bi().await?;
77 Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
78 }
79}
80
81impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
82 pub fn stream_id(&self) -> u64 {
83 self.reader.id()
84 }
85
86 pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
87 Self {
88 t0: Instant::now(),
89 connection_id,
90 reader,
91 writer,
92 other_bytes_read: 0,
93 events,
94 }
95 }
96
97 pub async fn read_request(&mut self) -> Result<Request> {
106 let (res, size) = Request::read_async(&mut self.reader).await?;
107 self.other_bytes_read += size as u64;
108 Ok(res)
109 }
110
111 pub async fn into_writer(
113 mut self,
114 tracker: RequestTracker,
115 ) -> Result<ProgressWriter<W>, io::Error> {
116 self.reader.expect_eof().await?;
117 drop(self.reader);
118 Ok(ProgressWriter::new(
119 self.writer,
120 WriterContext {
121 t0: self.t0,
122 other_bytes_read: self.other_bytes_read,
123 payload_bytes_written: 0,
124 other_bytes_written: 0,
125 tracker,
126 },
127 ))
128 }
129
130 pub async fn into_reader(
131 mut self,
132 tracker: RequestTracker,
133 ) -> Result<ProgressReader<R>, io::Error> {
134 self.writer.sync().await?;
135 drop(self.writer);
136 Ok(ProgressReader {
137 inner: self.reader,
138 context: ReaderContext {
139 t0: self.t0,
140 other_bytes_read: self.other_bytes_read,
141 tracker,
142 },
143 })
144 }
145
146 pub async fn get_request(
147 &self,
148 f: impl FnOnce() -> GetRequest,
149 ) -> Result<RequestTracker, ProgressError> {
150 self.events
151 .request(f, self.connection_id, self.reader.id())
152 .await
153 }
154
155 pub async fn get_many_request(
156 &self,
157 f: impl FnOnce() -> GetManyRequest,
158 ) -> Result<RequestTracker, ProgressError> {
159 self.events
160 .request(f, self.connection_id, self.reader.id())
161 .await
162 }
163
164 pub async fn push_request(
165 &self,
166 f: impl FnOnce() -> PushRequest,
167 ) -> Result<RequestTracker, ProgressError> {
168 self.events
169 .request(f, self.connection_id, self.reader.id())
170 .await
171 }
172
173 pub async fn observe_request(
174 &self,
175 f: impl FnOnce() -> ObserveRequest,
176 ) -> Result<RequestTracker, ProgressError> {
177 self.events
178 .request(f, self.connection_id, self.reader.id())
179 .await
180 }
181
182 pub fn stats(&self) -> TransferStats {
183 TransferStats {
184 payload_bytes_sent: 0,
185 other_bytes_sent: 0,
186 other_bytes_read: self.other_bytes_read,
187 duration: self.t0.elapsed(),
188 }
189 }
190}
191
/// Per-transfer bookkeeping for a [`ProgressReader`].
#[derive(Debug)]
struct ReaderContext {
    /// Start time of the transfer, inherited from the [`StreamPair`].
    t0: Instant,
    /// Bytes consumed while reading the request.
    other_bytes_read: u64,
    /// Tracker notified when the transfer completes or is aborted.
    tracker: RequestTracker,
}
201
202impl ReaderContext {
203 fn stats(&self) -> TransferStats {
204 TransferStats {
205 payload_bytes_sent: 0,
206 other_bytes_sent: 0,
207 other_bytes_read: self.other_bytes_read,
208 duration: self.t0.elapsed(),
209 }
210 }
211}
212
/// Per-transfer bookkeeping for a [`ProgressWriter`]: byte counters and the request tracker.
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// Start time of the transfer, inherited from the [`StreamPair`].
    t0: Instant,
    /// Bytes consumed while reading the request.
    other_bytes_read: u64,
    /// Blob payload bytes written so far (see `notify_payload_write`).
    payload_bytes_written: u64,
    /// Non-payload bytes written so far (see `log_other_write`).
    other_bytes_written: u64,
    /// Tracker that receives progress, completion and abort notifications.
    tracker: RequestTracker,
}
226
227impl WriterContext {
228 fn stats(&self) -> TransferStats {
229 TransferStats {
230 payload_bytes_sent: self.payload_bytes_written,
231 other_bytes_sent: self.other_bytes_written,
232 other_bytes_read: self.other_bytes_read,
233 duration: self.t0.elapsed(),
234 }
235 }
236}
237
238impl WriteProgress for WriterContext {
239 async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
240 let len = len as u64;
241 let end_offset = offset + len;
242 self.payload_bytes_written += len;
243 self.tracker.transfer_progress(len, end_offset).await
244 }
245
246 fn log_other_write(&mut self, len: usize) {
247 self.other_bytes_written += len as u64;
248 }
249
250 async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
251 self.tracker.transfer_started(index, hash, size).await.ok();
252 }
253}
254
/// A send stream combined with the bookkeeping needed to report transfer progress.
///
/// Created from a [`StreamPair`] via [`StreamPair::into_writer`].
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream.
    pub inner: W,
    /// Byte counters and request tracker for this transfer.
    pub(crate) context: WriterContext,
}
262
263impl<W: SendStream> ProgressWriter<W> {
264 fn new(inner: W, context: WriterContext) -> Self {
265 Self { inner, context }
266 }
267
268 async fn transfer_aborted(&self) {
269 self.context
270 .tracker
271 .transfer_aborted(|| Box::new(self.context.stats()))
272 .await
273 .ok();
274 }
275
276 async fn transfer_completed(&self) {
277 self.context
278 .tracker
279 .transfer_completed(|| Box::new(self.context.stats()))
280 .await
281 .ok();
282 }
283}
284
285pub async fn handle_connection(
287 connection: endpoint::Connection,
288 store: Store,
289 progress: EventSender,
290) {
291 let connection_id = connection.stable_id() as u64;
292 let span = debug_span!("connection", connection_id);
293 async move {
294 if let Err(cause) = progress
295 .client_connected(|| ClientConnected {
296 connection_id,
297 endpoint_id: Some(connection.remote_id()),
298 })
299 .await
300 {
301 connection.close(cause.code(), cause.reason());
302 debug!("closing connection: {cause}");
303 return;
304 }
305 while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
306 let span = debug_span!("stream", stream_id = %pair.stream_id());
307 let store = store.clone();
308 n0_future::task::spawn(handle_stream(pair, store).instrument(span));
309 }
310 progress
311 .connection_closed(|| ConnectionClosed { connection_id })
312 .await
313 .ok();
314 }
315 .instrument(span)
316 .await
317}
318
/// Hooks for terminating either half of a stream with an error code.
///
/// NOTE(review): not used by the handlers in this section; confirm where it is implemented.
pub trait ErrorHandler {
    /// Writer type accepted by [`ErrorHandler::reset`].
    type W: AsyncStreamWriter;
    /// Reader type accepted by [`ErrorHandler::stop`].
    type R: AsyncStreamReader;
    /// Stops `reader` with `code` (presumably asking the peer to stop sending — confirm).
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Resets `writer` with `code` (presumably abandoning unsent data — confirm).
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
326
327async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
328 pair: &mut StreamPair<R, W>,
329 r: Result<T, E>,
330) -> Result<T, E> {
331 match r {
332 Ok(x) => Ok(x),
333 Err(e) => {
334 pair.writer.reset(e.code()).ok();
335 Err(e)
336 }
337 }
338}
339async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
340 writer: &mut ProgressWriter<W>,
341 r: Result<T, E>,
342) -> Result<T, E> {
343 match r {
344 Ok(x) => {
345 writer.transfer_completed().await;
346 Ok(x)
347 }
348 Err(e) => {
349 writer.inner.reset(e.code()).ok();
350 writer.transfer_aborted().await;
351 Err(e)
352 }
353 }
354}
355async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
356 reader: &mut ProgressReader<R>,
357 r: Result<T, E>,
358) -> Result<T, E> {
359 match r {
360 Ok(x) => {
361 reader.transfer_completed().await;
362 Ok(x)
363 }
364 Err(e) => {
365 reader.inner.stop(e.code()).ok();
366 reader.transfer_aborted().await;
367 Err(e)
368 }
369 }
370}
371
372pub async fn handle_stream<R: RecvStream, W: SendStream>(
373 mut pair: StreamPair<R, W>,
374 store: Store,
375) -> n0_error::Result<()> {
376 let request = pair.read_request().await?;
377 match request {
378 Request::Get(request) => handle_get(pair, store, request).await?,
379 Request::GetMany(request) => handle_get_many(pair, store, request).await?,
380 Request::Observe(request) => handle_observe(pair, store, request).await?,
381 Request::Push(request) => handle_push(pair, store, request).await?,
382 _ => {}
383 }
384 Ok(())
385}
386
/// Errors that can occur while serving a [`GetRequest`].
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetError {
    /// Exporting data from the store failed.
    ///
    /// Also used for a failed `sync` of the send stream at the end of a get.
    #[error(transparent)]
    ExportBao {
        #[error(std_err)]
        source: ExportBaoError,
    },
    /// The root blob could not be parsed as a [`HashSeq`].
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},
    /// A child offset did not fit into a `usize` index.
    #[error("Invalid offset")]
    InvalidOffset {},
}
399
400impl HasErrorCode for HandleGetError {
401 fn code(&self) -> VarInt {
402 match self {
403 HandleGetError::ExportBao {
404 source: ExportBaoError::ClientError { source, .. },
405 ..
406 } => source.code(),
407 HandleGetError::InvalidHashSeq { .. } => ERR_INTERNAL,
408 HandleGetError::InvalidOffset { .. } => ERR_INTERNAL,
409 _ => ERR_INTERNAL,
410 }
411 }
412}
413
414async fn handle_get_impl<W: SendStream>(
418 store: Store,
419 request: GetRequest,
420 writer: &mut ProgressWriter<W>,
421) -> Result<(), HandleGetError> {
422 let hash = request.hash;
423 debug!(%hash, "get received request");
424 let mut hash_seq = None;
425 for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
426 if offset == 0 {
427 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
428 } else {
429 let hash_seq = match &hash_seq {
435 Some(b) => b,
436 None => {
437 let bytes = store.get_bytes(hash).await?;
438 let hs =
439 HashSeq::try_from(bytes).map_err(|_| e!(HandleGetError::InvalidHashSeq))?;
440 hash_seq = Some(hs);
441 hash_seq.as_ref().unwrap()
442 }
443 };
444 let o = usize::try_from(offset - 1).map_err(|_| e!(HandleGetError::InvalidOffset))?;
445 let Some(hash) = hash_seq.get(o) else {
446 break;
447 };
448 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
449 }
450 }
451 writer
452 .inner
453 .sync()
454 .await
455 .map_err(|e| e!(HandleGetError::ExportBao, e.into()))?;
456
457 Ok(())
458}
459
460pub async fn handle_get<R: RecvStream, W: SendStream>(
461 mut pair: StreamPair<R, W>,
462 store: Store,
463 request: GetRequest,
464) -> n0_error::Result<()> {
465 let res = pair.get_request(|| request.clone()).await;
466 let tracker = handle_read_request_result(&mut pair, res).await?;
467 let mut writer = pair.into_writer(tracker).await?;
468 let res = handle_get_impl(store, request, &mut writer).await;
469 handle_write_result(&mut writer, res).await?;
470 Ok(())
471}
472
/// Errors that can occur while serving a [`GetManyRequest`].
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetManyError {
    /// Exporting one of the requested blobs failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },
}
478
479impl HasErrorCode for HandleGetManyError {
480 fn code(&self) -> VarInt {
481 match self {
482 Self::ExportBao {
483 source: ExportBaoError::ClientError { source, .. },
484 ..
485 } => source.code(),
486 _ => ERR_INTERNAL,
487 }
488 }
489}
490
491async fn handle_get_many_impl<W: SendStream>(
495 store: Store,
496 request: GetManyRequest,
497 writer: &mut ProgressWriter<W>,
498) -> Result<(), HandleGetManyError> {
499 debug!("get_many received request");
500 let request_ranges = request.ranges.iter_infinite();
501 for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
502 if !ranges.is_empty() {
503 send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
504 }
505 }
506 Ok(())
507}
508
509pub async fn handle_get_many<R: RecvStream, W: SendStream>(
510 mut pair: StreamPair<R, W>,
511 store: Store,
512 request: GetManyRequest,
513) -> n0_error::Result<()> {
514 let res = pair.get_many_request(|| request.clone()).await;
515 let tracker = handle_read_request_result(&mut pair, res).await?;
516 let mut writer = pair.into_writer(tracker).await?;
517 let res = handle_get_many_impl(store, request, &mut writer).await;
518 handle_write_result(&mut writer, res).await?;
519 Ok(())
520}
521
/// Errors that can occur while serving a [`PushRequest`].
#[stack_error(derive, add_meta, from_sources)]
pub enum HandlePushError {
    /// A store export/read failed (presumably reading back the root via `get_bytes` — confirm).
    #[error(transparent)]
    ExportBao { source: ExportBaoError },

    /// The pushed root could not be parsed as a [`HashSeq`].
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},

    /// A store request failed (presumably `import_bao_reader` — confirm).
    #[error(transparent)]
    Request { source: RequestError },
}
533
534impl HasErrorCode for HandlePushError {
535 fn code(&self) -> VarInt {
536 match self {
537 Self::ExportBao {
538 source: ExportBaoError::ClientError { source, .. },
539 ..
540 } => source.code(),
541 _ => ERR_INTERNAL,
542 }
543 }
544}
545
546async fn handle_push_impl<R: RecvStream>(
550 store: Store,
551 request: PushRequest,
552 reader: &mut ProgressReader<R>,
553) -> Result<(), HandlePushError> {
554 let hash = request.hash;
555 debug!(%hash, "push received request");
556 let mut request_ranges = request.ranges.iter_infinite();
557 let root_ranges = request_ranges.next().expect("infinite iterator");
558 if !root_ranges.is_empty() {
559 store
561 .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
562 .await?;
563 }
564 if request.ranges.is_blob() {
565 debug!("push request complete");
566 return Ok(());
567 }
568 let hash_seq = store.get_bytes(hash).await?;
570 let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| e!(HandlePushError::InvalidHashSeq))?;
571 for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
572 if child_ranges.is_empty() {
573 continue;
574 }
575 store
576 .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
577 .await?;
578 }
579 Ok(())
580}
581
582pub async fn handle_push<R: RecvStream, W: SendStream>(
583 mut pair: StreamPair<R, W>,
584 store: Store,
585 request: PushRequest,
586) -> n0_error::Result<()> {
587 let res = pair.push_request(|| request.clone()).await;
588 let tracker = handle_read_request_result(&mut pair, res).await?;
589 let mut reader = pair.into_reader(tracker).await?;
590 let res = handle_push_impl(store, request, &mut reader).await;
591 handle_read_result(&mut reader, res).await?;
592 Ok(())
593}
594
595pub(crate) async fn send_blob<W: SendStream>(
597 store: &Store,
598 index: u64,
599 hash: Hash,
600 ranges: ChunkRanges,
601 writer: &mut ProgressWriter<W>,
602) -> ExportBaoResult<()> {
603 store
604 .export_bao(hash, ranges)
605 .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
606 .await
607}
608
/// Errors that can occur while serving an [`ObserveRequest`].
#[stack_error(derive, add_meta, std_sources, from_sources)]
pub enum HandleObserveError {
    /// The store's observe stream could not be opened or ended without yielding an item.
    #[error("observe stream closed")]
    ObserveStreamClosed {},

    /// Writing an observe item to the send stream failed.
    #[error(transparent)]
    RemoteClosed { source: io::Error },
}
617
impl HasErrorCode for HandleObserveError {
    /// Every observe failure maps to [`ERR_INTERNAL`].
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
623
624async fn handle_observe_impl<W: SendStream>(
628 store: Store,
629 request: ObserveRequest,
630 writer: &mut ProgressWriter<W>,
631) -> std::result::Result<(), HandleObserveError> {
632 let mut stream = store
633 .observe(request.hash)
634 .stream()
635 .await
636 .map_err(|_| e!(HandleObserveError::ObserveStreamClosed))?;
637 let mut old = stream
638 .next()
639 .await
640 .ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
641 send_observe_item(writer, &old).await?;
643 loop {
645 select! {
646 new = stream.next() => {
647 let new = new.ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
648 let diff = old.diff(&new);
649 if diff.is_empty() {
650 continue;
651 }
652 send_observe_item(writer, &diff).await?;
653 old = new;
654 }
655 _ = writer.inner.stopped() => {
656 debug!("observer closed");
657 break;
658 }
659 }
660 }
661 Ok(())
662}
663
664async fn send_observe_item<W: SendStream>(
665 writer: &mut ProgressWriter<W>,
666 item: &Bitfield,
667) -> io::Result<()> {
668 let item = ObserveItem::from(item);
669 let len = writer.inner.write_length_prefixed(item).await?;
670 writer.context.log_other_write(len);
671 Ok(())
672}
673
674pub async fn handle_observe<R: RecvStream, W: SendStream>(
675 mut pair: StreamPair<R, W>,
676 store: Store,
677 request: ObserveRequest,
678) -> n0_error::Result<()> {
679 let res = pair.observe_request(|| request.clone()).await;
680 let tracker = handle_read_request_result(&mut pair, res).await?;
681 let mut writer = pair.into_writer(tracker).await?;
682 let res = handle_observe_impl(store, request, &mut writer).await;
683 handle_write_result(&mut writer, res).await?;
684 Ok(())
685}
686
/// A receive stream combined with the bookkeeping needed to report transfer completion.
///
/// Created from a [`StreamPair`] via [`StreamPair::into_reader`].
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    /// The underlying receive stream.
    inner: R,
    /// Start time, request byte count and request tracker for this transfer.
    context: ReaderContext,
}
691
692impl<R: RecvStream> ProgressReader<R> {
693 async fn transfer_aborted(&self) {
694 self.context
695 .tracker
696 .transfer_aborted(|| Box::new(self.context.stats()))
697 .await
698 .ok();
699 }
700
701 async fn transfer_completed(&self) {
702 self.context
703 .tracker
704 .transfer_completed(|| Box::new(self.context.stats()))
705 .await
706 .ok();
707 }
708}