use std::{
    fmt::Debug,
    future::Future,
    io,
    time::{Duration, Instant},
};

use bao_tree::ChunkRanges;
use iroh::endpoint::{self, VarInt};
use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
use n0_error::{e, stack_error, Result};
use n0_future::StreamExt;
use quinn::ConnectionError;
use serde::{Deserialize, Serialize};
use tokio::select;
use tracing::{debug, debug_span, Instrument};

use crate::{
    api::{
        blobs::{Bitfield, WriteProgress},
        ExportBaoError, ExportBaoResult, RequestError, Store,
    },
    hashseq::HashSeq,
    protocol::{
        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
    },
    provider::events::{
        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
        RequestTracker,
    },
    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
    Hash,
};
pub mod events;
use events::EventSender;

type DefaultReader = iroh::endpoint::RecvStream;
type DefaultWriter = iroh::endpoint::SendStream;

/// Statistics about a single transfer.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// Payload bytes sent to the client.
    pub payload_bytes_sent: u64,
    /// Non-payload bytes sent (protocol overhead).
    pub other_bytes_sent: u64,
    /// Non-payload bytes read, such as the encoded request.
    pub other_bytes_read: u64,
    /// Duration of the transfer, measured from the creation of the stream pair.
    pub duration: Duration,
}

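/// A pair of receive/send streams for serving a single request on a connection.
///
/// Rough lifecycle sketch (hypothetical wiring; `connection` and `events` are
/// assumed to exist, see `handle_connection` and `handle_stream` below for the
/// actual flow):
///
/// ```ignore
/// let mut pair = StreamPair::accept(&connection, events.clone()).await?;
/// let request = pair.read_request().await?;
/// if let Request::Get(request) = request {
///     // Register the request with the event sender, then switch to the send side.
///     let tracker = pair.get_request(|| request.clone()).await?;
///     let mut writer = pair.into_writer(tracker).await?;
///     // ... write the response, e.g. via `send_blob` ...
/// }
/// ```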
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    t0: Instant,
    connection_id: u64,
    reader: R,
    writer: W,
    other_bytes_read: u64,
    events: EventSender,
}

impl StreamPair {
    /// Accept the next bidirectional stream on the connection and wrap it in a `StreamPair`.
    pub async fn accept(
        conn: &endpoint::Connection,
        events: EventSender,
    ) -> Result<Self, ConnectionError> {
        let (writer, reader) = conn.accept_bi().await?;
        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
    }
}

impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
    pub fn stream_id(&self) -> u64 {
        self.reader.id()
    }

    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
        Self {
            t0: Instant::now(),
            connection_id,
            reader,
            writer,
            other_bytes_read: 0,
            events,
        }
    }

    /// Read the request from the receive stream, accounting for the bytes read.
    pub async fn read_request(&mut self) -> Result<Request> {
        let (res, size) = Request::read_async(&mut self.reader).await?;
        self.other_bytes_read += size as u64;
        Ok(res)
    }

    /// Expect the end of the receive stream, then turn this pair into a writer
    /// for sending the response.
    pub async fn into_writer(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressWriter<W>, io::Error> {
        self.reader.expect_eof().await?;
        drop(self.reader);
        Ok(ProgressWriter::new(
            self.writer,
            WriterContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                payload_bytes_written: 0,
                other_bytes_written: 0,
                tracker,
            },
        ))
    }

    /// Flush the send stream, then turn this pair into a reader for receiving
    /// the incoming data of a push request.
    pub async fn into_reader(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressReader<R>, io::Error> {
        self.writer.sync().await?;
        drop(self.writer);
        Ok(ProgressReader {
            inner: self.reader,
            context: ReaderContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                tracker,
            },
        })
    }

    pub async fn get_request(
        &self,
        f: impl FnOnce() -> GetRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    pub async fn get_many_request(
        &self,
        f: impl FnOnce() -> GetManyRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    pub async fn push_request(
        &self,
        f: impl FnOnce() -> PushRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    pub async fn observe_request(
        &self,
        f: impl FnOnce() -> ObserveRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    pub fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}

#[derive(Debug)]
struct ReaderContext {
    t0: Instant,
    other_bytes_read: u64,
    tracker: RequestTracker,
}

impl ReaderContext {
    fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}

#[derive(Debug)]
pub(crate) struct WriterContext {
    t0: Instant,
    other_bytes_read: u64,
    payload_bytes_written: u64,
    other_bytes_written: u64,
    tracker: RequestTracker,
}

impl WriterContext {
    fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: self.payload_bytes_written,
            other_bytes_sent: self.other_bytes_written,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}

impl WriteProgress for WriterContext {
    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
        let len = len as u64;
        let end_offset = offset + len;
        self.payload_bytes_written += len;
        self.tracker.transfer_progress(len, end_offset).await
    }

    fn log_other_write(&mut self, len: usize) {
        self.other_bytes_written += len as u64;
    }

    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
        self.tracker.transfer_started(index, hash, size).await.ok();
    }
}

/// Wrapper around a send stream that tracks written bytes and reports
/// completion or abort of the transfer to the request tracker.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    pub inner: W,
    pub(crate) context: WriterContext,
}

impl<W: SendStream> ProgressWriter<W> {
    fn new(inner: W, context: WriterContext) -> Self {
        Self { inner, context }
    }

    async fn transfer_aborted(&self) {
        self.context
            .tracker
            .transfer_aborted(|| Box::new(self.context.stats()))
            .await
            .ok();
    }

    async fn transfer_completed(&self) {
        self.context
            .tracker
            .transfer_completed(|| Box::new(self.context.stats()))
            .await
            .ok();
    }
}

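/// Handle a single client connection: announce the client to the event sender,
/// then accept bidirectional streams and spawn a request handler for each one
/// until the connection closes.
///
/// Sketch of how this might be driven from an accept loop (hypothetical wiring;
/// assumes an already bound iroh `Endpoint` with a matching ALPN, plus a `Store`
/// and `EventSender`):
///
/// ```ignore
/// while let Some(incoming) = endpoint.accept().await {
///     let connection = incoming.await?;
///     tokio::spawn(handle_connection(connection, store.clone(), events.clone()));
/// }
/// ```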
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        if let Err(cause) = progress
            .client_connected(|| ClientConnected {
                connection_id,
                endpoint_id: connection.remote_id().ok(),
            })
            .await
        {
            connection.close(cause.code(), cause.reason());
            debug!("closing connection: {cause}");
            return;
        }
        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
            let span = debug_span!("stream", stream_id = %pair.stream_id());
            let store = store.clone();
            tokio::spawn(handle_stream(pair, store).instrument(span));
        }
        progress
            .connection_closed(|| ConnectionClosed { connection_id })
            .await
            .ok();
    }
    .instrument(span)
    .await
}

pub trait ErrorHandler {
    type W: AsyncStreamWriter;
    type R: AsyncStreamReader;
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}

/// On error, reset the send stream with the error's code before returning it.
async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
    pair: &mut StreamPair<R, W>,
    r: Result<T, E>,
) -> Result<T, E> {
    match r {
        Ok(x) => Ok(x),
        Err(e) => {
            pair.writer.reset(e.code()).ok();
            Err(e)
        }
    }
}

/// Report completion on success; on error, reset the send stream and report the aborted transfer.
async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
    writer: &mut ProgressWriter<W>,
    r: Result<T, E>,
) -> Result<T, E> {
    match r {
        Ok(x) => {
            writer.transfer_completed().await;
            Ok(x)
        }
        Err(e) => {
            writer.inner.reset(e.code()).ok();
            writer.transfer_aborted().await;
            Err(e)
        }
    }
}

/// Report completion on success; on error, stop the receive stream and report the aborted transfer.
async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
    reader: &mut ProgressReader<R>,
    r: Result<T, E>,
) -> Result<T, E> {
    match r {
        Ok(x) => {
            reader.transfer_completed().await;
            Ok(x)
        }
        Err(e) => {
            reader.inner.stop(e.code()).ok();
            reader.transfer_aborted().await;
            Err(e)
        }
    }
}

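/// Read a single request from the stream pair and dispatch it to the matching handler.
///
/// This is what `handle_connection` spawns for every accepted stream; driving it
/// manually looks roughly like this (sketch, assuming `connection`, `events` and
/// `store` are already set up):
///
/// ```ignore
/// let pair = StreamPair::accept(&connection, events.clone()).await?;
/// tokio::spawn(handle_stream(pair, store.clone()));
/// ```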
pub async fn handle_stream<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
) -> n0_error::Result<()> {
    let request = pair.read_request().await?;
    match request {
        Request::Get(request) => handle_get(pair, store, request).await?,
        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
        Request::Observe(request) => handle_observe(pair, store, request).await?,
        Request::Push(request) => handle_push(pair, store, request).await?,
        // Unknown or unsupported request variants are ignored.
        _ => {}
    }
    Ok(())
}

#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetError {
    #[error(transparent)]
    ExportBao {
        #[error(std_err)]
        source: ExportBaoError,
    },
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},
    #[error("Invalid offset")]
    InvalidOffset {},
}

impl HasErrorCode for HandleGetError {
    fn code(&self) -> VarInt {
        match self {
            HandleGetError::ExportBao {
                source: ExportBaoError::ClientError { source, .. },
                ..
            } => source.code(),
            HandleGetError::InvalidHashSeq { .. } => ERR_INTERNAL,
            HandleGetError::InvalidOffset { .. } => ERR_INTERNAL,
            _ => ERR_INTERNAL,
        }
    }
}

async fn handle_get_impl<W: SendStream>(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetError> {
    let hash = request.hash;
    debug!(%hash, "get received request");
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            // Offset 0 addresses the root blob itself.
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // Offsets >= 1 address children of the hash sequence; load it lazily
            // on the first child request.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs =
                        HashSeq::try_from(bytes).map_err(|_| e!(HandleGetError::InvalidHashSeq))?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            let o = usize::try_from(offset - 1).map_err(|_| e!(HandleGetError::InvalidOffset))?;
            let Some(hash) = hash_seq.get(o) else {
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }
    writer
        .inner
        .sync()
        .await
        .map_err(|e| e!(HandleGetError::ExportBao, e.into()))?;

    Ok(())
}

pub async fn handle_get<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: GetRequest,
) -> n0_error::Result<()> {
    let res = pair.get_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    let mut writer = pair.into_writer(tracker).await?;
    let res = handle_get_impl(store, request, &mut writer).await;
    handle_write_result(&mut writer, res).await?;
    Ok(())
}

#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetManyError {
    #[error(transparent)]
    ExportBao { source: ExportBaoError },
}

impl HasErrorCode for HandleGetManyError {
    fn code(&self) -> VarInt {
        match self {
            Self::ExportBao {
                source: ExportBaoError::ClientError { source, .. },
                ..
            } => source.code(),
            _ => ERR_INTERNAL,
        }
    }
}

async fn handle_get_many_impl<W: SendStream>(
    store: Store,
    request: GetManyRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetManyError> {
    debug!("get_many received request");
    let request_ranges = request.ranges.iter_infinite();
    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
        if !ranges.is_empty() {
            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
        }
    }
    Ok(())
}

pub async fn handle_get_many<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: GetManyRequest,
) -> n0_error::Result<()> {
    let res = pair.get_many_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    let mut writer = pair.into_writer(tracker).await?;
    let res = handle_get_many_impl(store, request, &mut writer).await;
    handle_write_result(&mut writer, res).await?;
    Ok(())
}

#[stack_error(derive, add_meta, from_sources)]
pub enum HandlePushError {
    #[error(transparent)]
    ExportBao { source: ExportBaoError },

    #[error("Invalid hash sequence")]
    InvalidHashSeq {},

    #[error(transparent)]
    Request { source: RequestError },
}

impl HasErrorCode for HandlePushError {
    fn code(&self) -> VarInt {
        match self {
            Self::ExportBao {
                source: ExportBaoError::ClientError { source, .. },
                ..
            } => source.code(),
            _ => ERR_INTERNAL,
        }
    }
}

async fn handle_push_impl<R: RecvStream>(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader<R>,
) -> Result<(), HandlePushError> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    let mut request_ranges = request.ranges.iter_infinite();
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // Import the root blob first.
        store
            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        debug!("push request complete");
        return Ok(());
    }
    // The request also addresses children, so the root must be a hash sequence.
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| e!(HandlePushError::InvalidHashSeq))?;
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}

pub async fn handle_push<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: PushRequest,
) -> n0_error::Result<()> {
    let res = pair.push_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    let mut reader = pair.into_reader(tracker).await?;
    let res = handle_push_impl(store, request, &mut reader).await;
    handle_read_result(&mut reader, res).await?;
    Ok(())
}

/// Export the requested ranges of `hash` from the store and write them to the
/// client, reporting progress via the writer context.
pub(crate) async fn send_blob<W: SendStream>(
    store: &Store,
    index: u64,
    hash: Hash,
    ranges: ChunkRanges,
    writer: &mut ProgressWriter<W>,
) -> ExportBaoResult<()> {
    store
        .export_bao(hash, ranges)
        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
        .await
}

#[stack_error(derive, add_meta, std_sources, from_sources)]
pub enum HandleObserveError {
    #[error("observe stream closed")]
    ObserveStreamClosed {},

    #[error(transparent)]
    RemoteClosed { source: io::Error },
}

impl HasErrorCode for HandleObserveError {
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}

async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| e!(HandleObserveError::ObserveStreamClosed))?;
    // Send the initial bitfield, then keep sending diffs as it changes.
    let mut old = stream
        .next()
        .await
        .ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
    send_observe_item(writer, &old).await?;
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            _ = writer.inner.stopped() => {
                // The client stopped the stream, so there is nobody left to observe.
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}

async fn send_observe_item<W: SendStream>(
    writer: &mut ProgressWriter<W>,
    item: &Bitfield,
) -> io::Result<()> {
    let item = ObserveItem::from(item);
    let len = writer.inner.write_length_prefixed(item).await?;
    writer.context.log_other_write(len);
    Ok(())
}

pub async fn handle_observe<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: ObserveRequest,
) -> n0_error::Result<()> {
    let res = pair.observe_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    let mut writer = pair.into_writer(tracker).await?;
    let res = handle_observe_impl(store, request, &mut writer).await;
    handle_write_result(&mut writer, res).await?;
    Ok(())
}

/// Wrapper around a receive stream that reports completion or abort of the
/// transfer to the request tracker.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    inner: R,
    context: ReaderContext,
}

impl<R: RecvStream> ProgressReader<R> {
    async fn transfer_aborted(&self) {
        self.context
            .tracker
            .transfer_aborted(|| Box::new(self.context.stats()))
            .await
            .ok();
    }

    async fn transfer_completed(&self) {
        self.context
            .tracker
            .transfer_completed(|| Box::new(self.context.stats()))
            .await
            .ok();
    }
}