1use std::{fmt::Debug, future::Future, io};
7
8use anyhow::Result;
9use bao_tree::ChunkRanges;
10use iroh::endpoint::{self, ConnectionError, VarInt};
11use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
12use n0_future::{
13 time::{Duration, Instant},
14 StreamExt,
15};
16use serde::{Deserialize, Serialize};
17use snafu::Snafu;
18use tokio::select;
19use tracing::{debug, debug_span, Instrument};
20
21use crate::{
22 api::{
23 blobs::{Bitfield, WriteProgress},
24 ExportBaoError, ExportBaoResult, RequestError, Store,
25 },
26 hashseq::HashSeq,
27 protocol::{
28 GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
29 },
30 provider::events::{
31 ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
32 RequestTracker,
33 },
34 util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
35 Hash,
36};
37pub mod events;
38use events::EventSender;
39
40type DefaultReader = iroh::endpoint::RecvStream;
41type DefaultWriter = iroh::endpoint::SendStream;
42
/// Statistics about a single request/response exchange on one stream.
///
/// Produced by the `stats` methods of [`StreamPair`] and of the reader/writer contexts,
/// and handed to the request tracker when a transfer is completed or aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// Payload bytes sent, as reported through `notify_payload_write`.
    pub payload_bytes_sent: u64,
    /// Non-payload bytes sent, as reported through `log_other_write` (e.g. observe items).
    pub other_bytes_sent: u64,
    /// Non-payload bytes read, i.e. the encoded request.
    pub other_bytes_read: u64,
    /// Time elapsed since the originating [`StreamPair`] was created.
    pub duration: Duration,
}
60
/// A bidirectional stream opened by a client, together with the bookkeeping needed
/// to report request events and transfer statistics.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Creation time, used to compute [`TransferStats::duration`].
    t0: Instant,
    /// Identifier of the connection this stream belongs to.
    connection_id: u64,
    /// Receive half; the request is read from here.
    reader: R,
    /// Send half.
    writer: W,
    /// Non-payload bytes read so far (the encoded request).
    other_bytes_read: u64,
    /// Sender used to notify about request events.
    events: EventSender,
}
71
72impl StreamPair {
73 pub async fn accept(
74 conn: &endpoint::Connection,
75 events: EventSender,
76 ) -> Result<Self, ConnectionError> {
77 let (writer, reader) = conn.accept_bi().await?;
78 Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
79 }
80}
81
82impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
83 pub fn stream_id(&self) -> u64 {
84 self.reader.id()
85 }
86
87 pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
88 Self {
89 t0: Instant::now(),
90 connection_id,
91 reader,
92 writer,
93 other_bytes_read: 0,
94 events,
95 }
96 }
97
98 pub async fn read_request(&mut self) -> Result<Request> {
107 let (res, size) = Request::read_async(&mut self.reader).await?;
108 self.other_bytes_read += size as u64;
109 Ok(res)
110 }
111
112 pub async fn into_writer(
114 mut self,
115 tracker: RequestTracker,
116 ) -> Result<ProgressWriter<W>, io::Error> {
117 self.reader.expect_eof().await?;
118 drop(self.reader);
119 Ok(ProgressWriter::new(
120 self.writer,
121 WriterContext {
122 t0: self.t0,
123 other_bytes_read: self.other_bytes_read,
124 payload_bytes_written: 0,
125 other_bytes_written: 0,
126 tracker,
127 },
128 ))
129 }
130
131 pub async fn into_reader(
132 mut self,
133 tracker: RequestTracker,
134 ) -> Result<ProgressReader<R>, io::Error> {
135 self.writer.sync().await?;
136 drop(self.writer);
137 Ok(ProgressReader {
138 inner: self.reader,
139 context: ReaderContext {
140 t0: self.t0,
141 other_bytes_read: self.other_bytes_read,
142 tracker,
143 },
144 })
145 }
146
147 pub async fn get_request(
148 &self,
149 f: impl FnOnce() -> GetRequest,
150 ) -> Result<RequestTracker, ProgressError> {
151 self.events
152 .request(f, self.connection_id, self.reader.id())
153 .await
154 }
155
156 pub async fn get_many_request(
157 &self,
158 f: impl FnOnce() -> GetManyRequest,
159 ) -> Result<RequestTracker, ProgressError> {
160 self.events
161 .request(f, self.connection_id, self.reader.id())
162 .await
163 }
164
165 pub async fn push_request(
166 &self,
167 f: impl FnOnce() -> PushRequest,
168 ) -> Result<RequestTracker, ProgressError> {
169 self.events
170 .request(f, self.connection_id, self.reader.id())
171 .await
172 }
173
174 pub async fn observe_request(
175 &self,
176 f: impl FnOnce() -> ObserveRequest,
177 ) -> Result<RequestTracker, ProgressError> {
178 self.events
179 .request(f, self.connection_id, self.reader.id())
180 .await
181 }
182
183 pub fn stats(&self) -> TransferStats {
184 TransferStats {
185 payload_bytes_sent: 0,
186 other_bytes_sent: 0,
187 other_bytes_read: self.other_bytes_read,
188 duration: self.t0.elapsed(),
189 }
190 }
191}
192
/// Bookkeeping for the receiving side of a transfer (see `into_reader`).
#[derive(Debug)]
struct ReaderContext {
    /// Creation time of the originating [`StreamPair`].
    t0: Instant,
    /// Non-payload bytes read (the encoded request).
    other_bytes_read: u64,
    /// Tracker notified when the transfer completes or is aborted.
    tracker: RequestTracker,
}
202
203impl ReaderContext {
204 fn stats(&self) -> TransferStats {
205 TransferStats {
206 payload_bytes_sent: 0,
207 other_bytes_sent: 0,
208 other_bytes_read: self.other_bytes_read,
209 duration: self.t0.elapsed(),
210 }
211 }
212}
213
/// Bookkeeping for the sending side of a transfer (see `into_writer`).
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// Creation time of the originating [`StreamPair`].
    t0: Instant,
    /// Non-payload bytes read before switching to writing (the encoded request).
    other_bytes_read: u64,
    /// Payload bytes written so far, as reported through `notify_payload_write`.
    payload_bytes_written: u64,
    /// Non-payload bytes written so far, as reported through `log_other_write`.
    other_bytes_written: u64,
    /// Tracker that receives progress, completion and abort events.
    tracker: RequestTracker,
}
227
228impl WriterContext {
229 fn stats(&self) -> TransferStats {
230 TransferStats {
231 payload_bytes_sent: self.payload_bytes_written,
232 other_bytes_sent: self.other_bytes_written,
233 other_bytes_read: self.other_bytes_read,
234 duration: self.t0.elapsed(),
235 }
236 }
237}
238
239impl WriteProgress for WriterContext {
240 async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
241 let len = len as u64;
242 let end_offset = offset + len;
243 self.payload_bytes_written += len;
244 self.tracker.transfer_progress(len, end_offset).await
245 }
246
247 fn log_other_write(&mut self, len: usize) {
248 self.other_bytes_written += len as u64;
249 }
250
251 async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
252 self.tracker.transfer_started(index, hash, size).await.ok();
253 }
254}
255
/// Send half of a stream together with the context used to report write progress.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream.
    pub inner: W,
    /// Statistics and request tracker for the transfer.
    pub(crate) context: WriterContext,
}
263
264impl<W: SendStream> ProgressWriter<W> {
265 fn new(inner: W, context: WriterContext) -> Self {
266 Self { inner, context }
267 }
268
269 async fn transfer_aborted(&self) {
270 self.context
271 .tracker
272 .transfer_aborted(|| Box::new(self.context.stats()))
273 .await
274 .ok();
275 }
276
277 async fn transfer_completed(&self) {
278 self.context
279 .tracker
280 .transfer_completed(|| Box::new(self.context.stats()))
281 .await
282 .ok();
283 }
284}
285
286pub async fn handle_connection(
288 connection: endpoint::Connection,
289 store: Store,
290 progress: EventSender,
291) {
292 let connection_id = connection.stable_id() as u64;
293 let span = debug_span!("connection", connection_id);
294 async move {
295 if let Err(cause) = progress
296 .client_connected(|| ClientConnected {
297 connection_id,
298 endpoint_id: Some(connection.remote_id()),
299 })
300 .await
301 {
302 connection.close(cause.code(), cause.reason());
303 debug!("closing connection: {cause}");
304 return;
305 }
306 while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
307 let span = debug_span!("stream", stream_id = %pair.stream_id());
308 let store = store.clone();
309 n0_future::task::spawn(handle_stream(pair, store).instrument(span));
310 }
311 progress
312 .connection_closed(|| ConnectionClosed { connection_id })
313 .await
314 .ok();
315 }
316 .instrument(span)
317 .await
318}
319
/// Abstraction over aborting the reading and writing halves of a stream with an
/// error code.
///
/// NOTE(review): not referenced in the visible part of this file — confirm it is still used.
pub trait ErrorHandler {
    /// Writer type passed to [`ErrorHandler::reset`].
    type W: AsyncStreamWriter;
    /// Reader type passed to [`ErrorHandler::stop`].
    type R: AsyncStreamReader;
    /// Called to stop `reader` with error `code`; the exact behavior is up to the implementor.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Called to reset `writer` with error `code`; the exact behavior is up to the implementor.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
327
328async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
329 pair: &mut StreamPair<R, W>,
330 r: Result<T, E>,
331) -> Result<T, E> {
332 match r {
333 Ok(x) => Ok(x),
334 Err(e) => {
335 pair.writer.reset(e.code()).ok();
336 Err(e)
337 }
338 }
339}
340async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
341 writer: &mut ProgressWriter<W>,
342 r: Result<T, E>,
343) -> Result<T, E> {
344 match r {
345 Ok(x) => {
346 writer.transfer_completed().await;
347 Ok(x)
348 }
349 Err(e) => {
350 writer.inner.reset(e.code()).ok();
351 writer.transfer_aborted().await;
352 Err(e)
353 }
354 }
355}
356async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
357 reader: &mut ProgressReader<R>,
358 r: Result<T, E>,
359) -> Result<T, E> {
360 match r {
361 Ok(x) => {
362 reader.transfer_completed().await;
363 Ok(x)
364 }
365 Err(e) => {
366 reader.inner.stop(e.code()).ok();
367 reader.transfer_aborted().await;
368 Err(e)
369 }
370 }
371}
372
373pub async fn handle_stream<R: RecvStream, W: SendStream>(
374 mut pair: StreamPair<R, W>,
375 store: Store,
376) -> anyhow::Result<()> {
377 let request = pair.read_request().await?;
378 match request {
379 Request::Get(request) => handle_get(pair, store, request).await?,
380 Request::GetMany(request) => handle_get_many(pair, store, request).await?,
381 Request::Observe(request) => handle_observe(pair, store, request).await?,
382 Request::Push(request) => handle_push(pair, store, request).await?,
383 _ => {}
384 }
385 Ok(())
386}
387
/// Errors that can occur while serving a get request.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
    // Failure loading/exporting data from the store or syncing the stream.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },
    // The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,
    // A child offset did not fit into a `usize` index.
    InvalidOffset,
}
398
399impl HasErrorCode for HandleGetError {
400 fn code(&self) -> VarInt {
401 match self {
402 HandleGetError::ExportBao {
403 source: ExportBaoError::ClientError { source, .. },
404 } => source.code(),
405 HandleGetError::InvalidHashSeq => ERR_INTERNAL,
406 HandleGetError::InvalidOffset => ERR_INTERNAL,
407 _ => ERR_INTERNAL,
408 }
409 }
410}
411
412async fn handle_get_impl<W: SendStream>(
416 store: Store,
417 request: GetRequest,
418 writer: &mut ProgressWriter<W>,
419) -> Result<(), HandleGetError> {
420 let hash = request.hash;
421 debug!(%hash, "get received request");
422 let mut hash_seq = None;
423 for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
424 if offset == 0 {
425 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
426 } else {
427 let hash_seq = match &hash_seq {
433 Some(b) => b,
434 None => {
435 let bytes = store.get_bytes(hash).await?;
436 let hs =
437 HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
438 hash_seq = Some(hs);
439 hash_seq.as_ref().unwrap()
440 }
441 };
442 let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
443 let Some(hash) = hash_seq.get(o) else {
444 break;
445 };
446 send_blob(&store, offset, hash, ranges.clone(), writer).await?;
447 }
448 }
449 writer
450 .inner
451 .sync()
452 .await
453 .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
454
455 Ok(())
456}
457
458pub async fn handle_get<R: RecvStream, W: SendStream>(
459 mut pair: StreamPair<R, W>,
460 store: Store,
461 request: GetRequest,
462) -> anyhow::Result<()> {
463 let res = pair.get_request(|| request.clone()).await;
464 let tracker = handle_read_request_result(&mut pair, res).await?;
465 let mut writer = pair.into_writer(tracker).await?;
466 let res = handle_get_impl(store, request, &mut writer).await;
467 handle_write_result(&mut writer, res).await?;
468 Ok(())
469}
470
/// Errors that can occur while serving a get-many request.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
    // Failure exporting a blob from the store or writing it to the stream.
    #[snafu(transparent)]
    ExportBao { source: ExportBaoError },
}
476
477impl HasErrorCode for HandleGetManyError {
478 fn code(&self) -> VarInt {
479 match self {
480 Self::ExportBao {
481 source: ExportBaoError::ClientError { source, .. },
482 } => source.code(),
483 _ => ERR_INTERNAL,
484 }
485 }
486}
487
488async fn handle_get_many_impl<W: SendStream>(
492 store: Store,
493 request: GetManyRequest,
494 writer: &mut ProgressWriter<W>,
495) -> Result<(), HandleGetManyError> {
496 debug!("get_many received request");
497 let request_ranges = request.ranges.iter_infinite();
498 for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
499 if !ranges.is_empty() {
500 send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
501 }
502 }
503 Ok(())
504}
505
506pub async fn handle_get_many<R: RecvStream, W: SendStream>(
507 mut pair: StreamPair<R, W>,
508 store: Store,
509 request: GetManyRequest,
510) -> anyhow::Result<()> {
511 let res = pair.get_many_request(|| request.clone()).await;
512 let tracker = handle_read_request_result(&mut pair, res).await?;
513 let mut writer = pair.into_writer(tracker).await?;
514 let res = handle_get_many_impl(store, request, &mut writer).await;
515 handle_write_result(&mut writer, res).await?;
516 Ok(())
517}
518
/// Errors that can occur while serving a push request.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
    // Failure reading from the store, e.g. while loading the root as a hash sequence.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },

    // The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,

    // Failure of a store request; presumably from `import_bao_reader` — confirm.
    #[snafu(transparent)]
    Request {
        source: RequestError,
    },
}
533
534impl HasErrorCode for HandlePushError {
535 fn code(&self) -> VarInt {
536 match self {
537 Self::ExportBao {
538 source: ExportBaoError::ClientError { source, .. },
539 } => source.code(),
540 _ => ERR_INTERNAL,
541 }
542 }
543}
544
545async fn handle_push_impl<R: RecvStream>(
549 store: Store,
550 request: PushRequest,
551 reader: &mut ProgressReader<R>,
552) -> Result<(), HandlePushError> {
553 let hash = request.hash;
554 debug!(%hash, "push received request");
555 let mut request_ranges = request.ranges.iter_infinite();
556 let root_ranges = request_ranges.next().expect("infinite iterator");
557 if !root_ranges.is_empty() {
558 store
560 .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
561 .await?;
562 }
563 if request.ranges.is_blob() {
564 debug!("push request complete");
565 return Ok(());
566 }
567 let hash_seq = store.get_bytes(hash).await?;
569 let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
570 for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
571 if child_ranges.is_empty() {
572 continue;
573 }
574 store
575 .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
576 .await?;
577 }
578 Ok(())
579}
580
581pub async fn handle_push<R: RecvStream, W: SendStream>(
582 mut pair: StreamPair<R, W>,
583 store: Store,
584 request: PushRequest,
585) -> anyhow::Result<()> {
586 let res = pair.push_request(|| request.clone()).await;
587 let tracker = handle_read_request_result(&mut pair, res).await?;
588 let mut reader = pair.into_reader(tracker).await?;
589 let res = handle_push_impl(store, request, &mut reader).await;
590 handle_read_result(&mut reader, res).await?;
591 Ok(())
592}
593
594pub(crate) async fn send_blob<W: SendStream>(
596 store: &Store,
597 index: u64,
598 hash: Hash,
599 ranges: ChunkRanges,
600 writer: &mut ProgressWriter<W>,
601) -> ExportBaoResult<()> {
602 store
603 .export_bao(hash, ranges)
604 .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
605 .await
606}
607
/// Errors that can occur while serving an observe request.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
    // The store's observe stream could not be created, or it ended.
    ObserveStreamClosed,

    // An I/O error while writing an observe item to the stream.
    #[snafu(transparent)]
    RemoteClosed {
        source: io::Error,
    },
}
617
618impl HasErrorCode for HandleObserveError {
619 fn code(&self) -> VarInt {
620 ERR_INTERNAL
621 }
622}
623
/// Streams updates of the bitfield reported by `store.observe(request.hash)` to `writer`.
///
/// The first item sent is the initial bitfield. After that, only the difference to the
/// last sent bitfield is sent, and updates that produce an empty diff are skipped. The
/// function returns `Ok(())` once the peer stops the stream.
///
/// # Errors
///
/// * `ObserveStreamClosed` if the observe stream cannot be created or ends.
/// * `RemoteClosed` if writing an item fails.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
    // The initial state is sent in full.
    let mut old = stream
        .next()
        .await
        .ok_or(HandleObserveError::ObserveStreamClosed)?;
    send_observe_item(writer, &old).await?;
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
                let diff = old.diff(&new);
                // Nothing changed relative to what was last sent.
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            // The peer stopped the stream: end the observation successfully.
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
663
664async fn send_observe_item<W: SendStream>(
665 writer: &mut ProgressWriter<W>,
666 item: &Bitfield,
667) -> io::Result<()> {
668 let item = ObserveItem::from(item);
669 let len = writer.inner.write_length_prefixed(item).await?;
670 writer.context.log_other_write(len);
671 Ok(())
672}
673
674pub async fn handle_observe<R: RecvStream, W: SendStream>(
675 mut pair: StreamPair<R, W>,
676 store: Store,
677 request: ObserveRequest,
678) -> anyhow::Result<()> {
679 let res = pair.observe_request(|| request.clone()).await;
680 let tracker = handle_read_request_result(&mut pair, res).await?;
681 let mut writer = pair.into_writer(tracker).await?;
682 let res = handle_observe_impl(store, request, &mut writer).await;
683 handle_write_result(&mut writer, res).await?;
684 Ok(())
685}
686
687pub struct ProgressReader<R: RecvStream = DefaultReader> {
688 inner: R,
689 context: ReaderContext,
690}
691
692impl<R: RecvStream> ProgressReader<R> {
693 async fn transfer_aborted(&self) {
694 self.context
695 .tracker
696 .transfer_aborted(|| Box::new(self.context.stats()))
697 .await
698 .ok();
699 }
700
701 async fn transfer_completed(&self) {
702 self.context
703 .tracker
704 .transfer_completed(|| Box::new(self.context.stats()))
705 .await
706 .ok();
707 }
708}