1use std::{fmt::Debug, future::Future, io, time::Duration};
7
8use anyhow::Result;
9use bao_tree::ChunkRanges;
10use iroh::endpoint::{self, ConnectionError, VarInt};
11use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
12use n0_future::{time::Instant, StreamExt};
13use serde::{Deserialize, Serialize};
14use snafu::Snafu;
15use tokio::select;
16use tracing::{debug, debug_span, Instrument};
17
18use crate::{
19 api::{
20 blobs::{Bitfield, WriteProgress},
21 ExportBaoError, ExportBaoResult, RequestError, Store,
22 },
23 hashseq::HashSeq,
24 protocol::{
25 GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
26 },
27 provider::events::{
28 ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
29 RequestTracker,
30 },
31 util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
32 Hash,
33};
34pub mod events;
35use events::EventSender;
36
/// Receive half used by `StreamPair` and `ProgressReader` when no stream type is given.
type DefaultReader = iroh::endpoint::RecvStream;
/// Send half used by `StreamPair` and `ProgressWriter` when no stream type is given.
type DefaultWriter = iroh::endpoint::SendStream;
39
/// Statistics for a single request transfer.
///
/// A boxed snapshot is handed to the request tracker when a transfer completes or is aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// Blob payload bytes sent to the client (summed from payload write notifications).
    pub payload_bytes_sent: u64,
    /// Bytes sent that are not blob payload, e.g. length-prefixed observe items.
    pub other_bytes_sent: u64,
    /// Bytes consumed while reading the request from the client.
    pub other_bytes_read: u64,
    /// Time elapsed since the stream pair for this request was created.
    pub duration: Duration,
}
57
/// Both halves of a bidirectional stream opened by a client, plus the bookkeeping needed to
/// turn it into a tracked request transfer.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Time the pair was created; transfer durations are measured from here.
    t0: Instant,
    /// Id of the connection this stream belongs to.
    connection_id: u64,
    /// Receive half of the stream.
    reader: R,
    /// Send half of the stream.
    writer: W,
    /// Bytes consumed so far while reading the request.
    other_bytes_read: u64,
    /// Event sender used to announce the request and obtain its tracker.
    events: EventSender,
}
68
69impl StreamPair {
70 pub async fn accept(
71 conn: &endpoint::Connection,
72 events: EventSender,
73 ) -> Result<Self, ConnectionError> {
74 let (writer, reader) = conn.accept_bi().await?;
75 Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
76 }
77}
78
79impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
80 pub fn stream_id(&self) -> u64 {
81 self.reader.id()
82 }
83
84 pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
85 Self {
86 t0: Instant::now(),
87 connection_id,
88 reader,
89 writer,
90 other_bytes_read: 0,
91 events,
92 }
93 }
94
95 pub async fn read_request(&mut self) -> Result<Request> {
104 let (res, size) = Request::read_async(&mut self.reader).await?;
105 self.other_bytes_read += size as u64;
106 Ok(res)
107 }
108
109 pub async fn into_writer(
111 mut self,
112 tracker: RequestTracker,
113 ) -> Result<ProgressWriter<W>, io::Error> {
114 self.reader.expect_eof().await?;
115 drop(self.reader);
116 Ok(ProgressWriter::new(
117 self.writer,
118 WriterContext {
119 t0: self.t0,
120 other_bytes_read: self.other_bytes_read,
121 payload_bytes_written: 0,
122 other_bytes_written: 0,
123 tracker,
124 },
125 ))
126 }
127
128 pub async fn into_reader(
129 mut self,
130 tracker: RequestTracker,
131 ) -> Result<ProgressReader<R>, io::Error> {
132 self.writer.sync().await?;
133 drop(self.writer);
134 Ok(ProgressReader {
135 inner: self.reader,
136 context: ReaderContext {
137 t0: self.t0,
138 other_bytes_read: self.other_bytes_read,
139 tracker,
140 },
141 })
142 }
143
144 pub async fn get_request(
145 &self,
146 f: impl FnOnce() -> GetRequest,
147 ) -> Result<RequestTracker, ProgressError> {
148 self.events
149 .request(f, self.connection_id, self.reader.id())
150 .await
151 }
152
153 pub async fn get_many_request(
154 &self,
155 f: impl FnOnce() -> GetManyRequest,
156 ) -> Result<RequestTracker, ProgressError> {
157 self.events
158 .request(f, self.connection_id, self.reader.id())
159 .await
160 }
161
162 pub async fn push_request(
163 &self,
164 f: impl FnOnce() -> PushRequest,
165 ) -> Result<RequestTracker, ProgressError> {
166 self.events
167 .request(f, self.connection_id, self.reader.id())
168 .await
169 }
170
171 pub async fn observe_request(
172 &self,
173 f: impl FnOnce() -> ObserveRequest,
174 ) -> Result<RequestTracker, ProgressError> {
175 self.events
176 .request(f, self.connection_id, self.reader.id())
177 .await
178 }
179
180 pub fn stats(&self) -> TransferStats {
181 TransferStats {
182 payload_bytes_sent: 0,
183 other_bytes_sent: 0,
184 other_bytes_read: self.other_bytes_read,
185 duration: self.t0.elapsed(),
186 }
187 }
188}
189
/// Bookkeeping for a transfer in which the provider reads data from the client.
#[derive(Debug)]
struct ReaderContext {
    /// Creation time of the originating stream pair.
    t0: Instant,
    /// Bytes consumed while reading the request.
    other_bytes_read: u64,
    /// Tracker that receives completion/abort notifications for this request.
    tracker: RequestTracker,
}
199
200impl ReaderContext {
201 fn stats(&self) -> TransferStats {
202 TransferStats {
203 payload_bytes_sent: 0,
204 other_bytes_sent: 0,
205 other_bytes_read: self.other_bytes_read,
206 duration: self.t0.elapsed(),
207 }
208 }
209}
210
/// Bookkeeping for a transfer in which the provider writes data to the client.
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// Creation time of the originating stream pair.
    t0: Instant,
    /// Bytes consumed while reading the request.
    other_bytes_read: u64,
    /// Blob payload bytes written so far.
    payload_bytes_written: u64,
    /// Non-payload bytes written so far (e.g. observe items).
    other_bytes_written: u64,
    /// Tracker that receives progress/completion/abort notifications for this request.
    tracker: RequestTracker,
}
224
225impl WriterContext {
226 fn stats(&self) -> TransferStats {
227 TransferStats {
228 payload_bytes_sent: self.payload_bytes_written,
229 other_bytes_sent: self.other_bytes_written,
230 other_bytes_read: self.other_bytes_read,
231 duration: self.t0.elapsed(),
232 }
233 }
234}
235
236impl WriteProgress for WriterContext {
237 async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
238 let len = len as u64;
239 let end_offset = offset + len;
240 self.payload_bytes_written += len;
241 self.tracker.transfer_progress(len, end_offset).await
242 }
243
244 fn log_other_write(&mut self, len: usize) {
245 self.other_bytes_written += len as u64;
246 }
247
248 async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
249 self.tracker.transfer_started(index, hash, size).await.ok();
250 }
251}
252
/// Send half of a stream being written as part of a tracked transfer, together with the
/// statistics and tracker for that transfer.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream.
    pub inner: W,
    /// Counters, start time and request tracker for this transfer.
    pub(crate) context: WriterContext,
}
260
261impl<W: SendStream> ProgressWriter<W> {
262 fn new(inner: W, context: WriterContext) -> Self {
263 Self { inner, context }
264 }
265
266 async fn transfer_aborted(&self) {
267 self.context
268 .tracker
269 .transfer_aborted(|| Box::new(self.context.stats()))
270 .await
271 .ok();
272 }
273
274 async fn transfer_completed(&self) {
275 self.context
276 .tracker
277 .transfer_completed(|| Box::new(self.context.stats()))
278 .await
279 .ok();
280 }
281}
282
283pub async fn handle_connection(
285 connection: endpoint::Connection,
286 store: Store,
287 progress: EventSender,
288) {
289 let connection_id = connection.stable_id() as u64;
290 let span = debug_span!("connection", connection_id);
291 async move {
292 if let Err(cause) = progress
293 .client_connected(|| ClientConnected {
294 connection_id,
295 endpoint_id: connection.remote_id().ok(),
296 })
297 .await
298 {
299 connection.close(cause.code(), cause.reason());
300 debug!("closing connection: {cause}");
301 return;
302 }
303 while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
304 let span = debug_span!("stream", stream_id = %pair.stream_id());
305 let store = store.clone();
306 n0_future::task::spawn(handle_stream(pair, store).instrument(span));
307 }
308 progress
309 .connection_closed(|| ConnectionClosed { connection_id })
310 .await
311 .ok();
312 }
313 .instrument(span)
314 .await
315}
316
/// Hook for aborting a stream with an application error code: `stop` acts on the receive
/// side, `reset` on the send side.
///
/// NOTE(review): no implementor or caller is visible in this module; the handlers here call
/// `stop`/`reset` on the stream types directly.
pub trait ErrorHandler {
    /// Writer type handled by [`ErrorHandler::reset`].
    type W: AsyncStreamWriter;
    /// Reader type handled by [`ErrorHandler::stop`].
    type R: AsyncStreamReader;
    /// Stops `reader` with error `code`.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Resets `writer` with error `code`.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
324
325async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
326 pair: &mut StreamPair<R, W>,
327 r: Result<T, E>,
328) -> Result<T, E> {
329 match r {
330 Ok(x) => Ok(x),
331 Err(e) => {
332 pair.writer.reset(e.code()).ok();
333 Err(e)
334 }
335 }
336}
337async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
338 writer: &mut ProgressWriter<W>,
339 r: Result<T, E>,
340) -> Result<T, E> {
341 match r {
342 Ok(x) => {
343 writer.transfer_completed().await;
344 Ok(x)
345 }
346 Err(e) => {
347 writer.inner.reset(e.code()).ok();
348 writer.transfer_aborted().await;
349 Err(e)
350 }
351 }
352}
353async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
354 reader: &mut ProgressReader<R>,
355 r: Result<T, E>,
356) -> Result<T, E> {
357 match r {
358 Ok(x) => {
359 reader.transfer_completed().await;
360 Ok(x)
361 }
362 Err(e) => {
363 reader.inner.stop(e.code()).ok();
364 reader.transfer_aborted().await;
365 Err(e)
366 }
367 }
368}
369
370pub async fn handle_stream<R: RecvStream, W: SendStream>(
371 mut pair: StreamPair<R, W>,
372 store: Store,
373) -> anyhow::Result<()> {
374 let request = pair.read_request().await?;
375 match request {
376 Request::Get(request) => handle_get(pair, store, request).await?,
377 Request::GetMany(request) => handle_get_many(pair, store, request).await?,
378 Request::Observe(request) => handle_observe(pair, store, request).await?,
379 Request::Push(request) => handle_push(pair, store, request).await?,
380 _ => {}
381 }
382 Ok(())
383}
384
/// Errors that can occur while serving a get request.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
    // Failure exporting blob data (or syncing the stream) to the client.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },
    // The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,
    // A child offset did not fit into `usize`.
    InvalidOffset,
}
395
396impl HasErrorCode for HandleGetError {
397 fn code(&self) -> VarInt {
398 match self {
399 HandleGetError::ExportBao {
400 source: ExportBaoError::ClientError { source, .. },
401 } => source.code(),
402 HandleGetError::InvalidHashSeq => ERR_INTERNAL,
403 HandleGetError::InvalidOffset => ERR_INTERNAL,
404 _ => ERR_INTERNAL,
405 }
406 }
407}
408
409async fn handle_get_impl<W: SendStream>(
413 store: Store,
414 request: GetRequest,
415 writer: &mut ProgressWriter<W>,
416) -> Result<(), HandleGetError> {
417 let hash = request.hash;
418 debug!(%hash, "get received request");
419 let mut hash_seq = None;
420 for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
421 if offset == 0 {
422 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
423 } else {
424 let hash_seq = match &hash_seq {
430 Some(b) => b,
431 None => {
432 let bytes = store.get_bytes(hash).await?;
433 let hs =
434 HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
435 hash_seq = Some(hs);
436 hash_seq.as_ref().unwrap()
437 }
438 };
439 let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
440 let Some(hash) = hash_seq.get(o) else {
441 break;
442 };
443 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
444 }
445 }
446 writer
447 .inner
448 .sync()
449 .await
450 .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
451
452 Ok(())
453}
454
455pub async fn handle_get<R: RecvStream, W: SendStream>(
456 mut pair: StreamPair<R, W>,
457 store: Store,
458 request: GetRequest,
459) -> anyhow::Result<()> {
460 let res = pair.get_request(|| request.clone()).await;
461 let tracker = handle_read_request_result(&mut pair, res).await?;
462 let mut writer = pair.into_writer(tracker).await?;
463 let res = handle_get_impl(store, request, &mut writer).await;
464 handle_write_result(&mut writer, res).await?;
465 Ok(())
466}
467
/// Errors that can occur while serving a get-many request.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
    // Failure exporting blob data to the client.
    #[snafu(transparent)]
    ExportBao { source: ExportBaoError },
}
473
474impl HasErrorCode for HandleGetManyError {
475 fn code(&self) -> VarInt {
476 match self {
477 Self::ExportBao {
478 source: ExportBaoError::ClientError { source, .. },
479 } => source.code(),
480 _ => ERR_INTERNAL,
481 }
482 }
483}
484
485async fn handle_get_many_impl<W: SendStream>(
489 store: Store,
490 request: GetManyRequest,
491 writer: &mut ProgressWriter<W>,
492) -> Result<(), HandleGetManyError> {
493 debug!("get_many received request");
494 let request_ranges = request.ranges.iter_infinite();
495 for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
496 if !ranges.is_empty() {
497 send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
498 }
499 }
500 Ok(())
501}
502
503pub async fn handle_get_many<R: RecvStream, W: SendStream>(
504 mut pair: StreamPair<R, W>,
505 store: Store,
506 request: GetManyRequest,
507) -> anyhow::Result<()> {
508 let res = pair.get_many_request(|| request.clone()).await;
509 let tracker = handle_read_request_result(&mut pair, res).await?;
510 let mut writer = pair.into_writer(tracker).await?;
511 let res = handle_get_many_impl(store, request, &mut writer).await;
512 handle_write_result(&mut writer, res).await?;
513 Ok(())
514}
515
/// Errors that can occur while serving a push request.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
    // Failure while reading blob data from the store.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },

    // The pushed root blob could not be parsed as a hash sequence.
    InvalidHashSeq,

    // Failure of a store request, e.g. while importing pushed data.
    #[snafu(transparent)]
    Request {
        source: RequestError,
    },
}
530
531impl HasErrorCode for HandlePushError {
532 fn code(&self) -> VarInt {
533 match self {
534 Self::ExportBao {
535 source: ExportBaoError::ClientError { source, .. },
536 } => source.code(),
537 _ => ERR_INTERNAL,
538 }
539 }
540}
541
542async fn handle_push_impl<R: RecvStream>(
546 store: Store,
547 request: PushRequest,
548 reader: &mut ProgressReader<R>,
549) -> Result<(), HandlePushError> {
550 let hash = request.hash;
551 debug!(%hash, "push received request");
552 let mut request_ranges = request.ranges.iter_infinite();
553 let root_ranges = request_ranges.next().expect("infinite iterator");
554 if !root_ranges.is_empty() {
555 store
557 .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
558 .await?;
559 }
560 if request.ranges.is_blob() {
561 debug!("push request complete");
562 return Ok(());
563 }
564 let hash_seq = store.get_bytes(hash).await?;
566 let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
567 for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
568 if child_ranges.is_empty() {
569 continue;
570 }
571 store
572 .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
573 .await?;
574 }
575 Ok(())
576}
577
578pub async fn handle_push<R: RecvStream, W: SendStream>(
579 mut pair: StreamPair<R, W>,
580 store: Store,
581 request: PushRequest,
582) -> anyhow::Result<()> {
583 let res = pair.push_request(|| request.clone()).await;
584 let tracker = handle_read_request_result(&mut pair, res).await?;
585 let mut reader = pair.into_reader(tracker).await?;
586 let res = handle_push_impl(store, request, &mut reader).await;
587 handle_read_result(&mut reader, res).await?;
588 Ok(())
589}
590
591pub(crate) async fn send_blob<W: SendStream>(
593 store: &Store,
594 index: u64,
595 hash: Hash,
596 ranges: ChunkRanges,
597 writer: &mut ProgressWriter<W>,
598) -> ExportBaoResult<()> {
599 store
600 .export_bao(hash, ranges)
601 .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
602 .await
603}
604
/// Errors that can occur while serving an observe request.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
    // The store's observe stream could not be opened or ended unexpectedly.
    ObserveStreamClosed,

    // Writing an observe item to the client failed.
    #[snafu(transparent)]
    RemoteClosed {
        source: io::Error,
    },
}
614
615impl HasErrorCode for HandleObserveError {
616 fn code(&self) -> VarInt {
617 ERR_INTERNAL
618 }
619}
620
621async fn handle_observe_impl<W: SendStream>(
625 store: Store,
626 request: ObserveRequest,
627 writer: &mut ProgressWriter<W>,
628) -> std::result::Result<(), HandleObserveError> {
629 let mut stream = store
630 .observe(request.hash)
631 .stream()
632 .await
633 .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
634 let mut old = stream
635 .next()
636 .await
637 .ok_or(HandleObserveError::ObserveStreamClosed)?;
638 send_observe_item(writer, &old).await?;
640 loop {
642 select! {
643 new = stream.next() => {
644 let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
645 let diff = old.diff(&new);
646 if diff.is_empty() {
647 continue;
648 }
649 send_observe_item(writer, &diff).await?;
650 old = new;
651 }
652 _ = writer.inner.stopped() => {
653 debug!("observer closed");
654 break;
655 }
656 }
657 }
658 Ok(())
659}
660
661async fn send_observe_item<W: SendStream>(
662 writer: &mut ProgressWriter<W>,
663 item: &Bitfield,
664) -> io::Result<()> {
665 let item = ObserveItem::from(item);
666 let len = writer.inner.write_length_prefixed(item).await?;
667 writer.context.log_other_write(len);
668 Ok(())
669}
670
671pub async fn handle_observe<R: RecvStream, W: SendStream>(
672 mut pair: StreamPair<R, W>,
673 store: Store,
674 request: ObserveRequest,
675) -> anyhow::Result<()> {
676 let res = pair.observe_request(|| request.clone()).await;
677 let tracker = handle_read_request_result(&mut pair, res).await?;
678 let mut writer = pair.into_writer(tracker).await?;
679 let res = handle_observe_impl(store, request, &mut writer).await;
680 handle_write_result(&mut writer, res).await?;
681 Ok(())
682}
683
684pub struct ProgressReader<R: RecvStream = DefaultReader> {
685 inner: R,
686 context: ReaderContext,
687}
688
689impl<R: RecvStream> ProgressReader<R> {
690 async fn transfer_aborted(&self) {
691 self.context
692 .tracker
693 .transfer_aborted(|| Box::new(self.context.stats()))
694 .await
695 .ok();
696 }
697
698 async fn transfer_completed(&self) {
699 self.context
700 .tracker
701 .transfer_completed(|| Box::new(self.context.stats()))
702 .await
703 .ok();
704 }
705}