1use std::{
5 collections::BTreeMap,
6 future::{Future, IntoFuture},
7 num::NonZeroU64,
8 sync::Arc,
9};
10
11use bao_tree::{
12 io::{BaoContentItem, Leaf},
13 ChunkNum, ChunkRanges,
14};
15use genawaiter::sync::{Co, Gen};
16use iroh::endpoint::Connection;
17use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
18use n0_error::{e, stack_error, AnyError, Result, StdResultExt};
19use n0_future::{io, Stream, StreamExt};
20use ref_cast::RefCast;
21use tracing::{debug, trace};
22
23use super::blobs::{Bitfield, ExportBaoOptions};
24use crate::{
25 api::{
26 self,
27 blobs::{Blobs, WriteProgress},
28 ApiClient, Store,
29 },
30 get::{
31 fsm::{
32 AtBlobHeader, AtConnected, AtEndBlob, BlobContentNext, ConnectedNext, DecodeError,
33 EndBlobNext,
34 },
35 GetError, GetResult, Stats, StreamPair,
36 },
37 hashseq::{HashSeq, HashSeqIter},
38 protocol::{
39 ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest,
40 Request, RequestType, MAX_MESSAGE_SIZE,
41 },
42 provider::events::{ClientResult, ProgressError},
43 store::IROH_BLOCK_SIZE,
44 util::{
45 sink::{Sink, TokioMpscSenderSink},
46 RecvStream, SendStream,
47 },
48 Hash, HashAndFormat,
49};
50
/// Client for operations that involve a remote peer (fetching, pushing,
/// observing) as well as for inspecting what is already present locally.
///
/// This is a `#[repr(transparent)]` wrapper around an [`ApiClient`], so a
/// `&Remote` can be obtained from a `&ApiClient` without copying via
/// [`RefCast`].
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    /// Client used to talk to the local store.
    client: ApiClient,
}
71
/// An item emitted while a get/fetch operation runs (see [`GetProgress::stream`]).
#[derive(Debug)]
pub enum GetProgressItem {
    /// Progress update carrying the payload byte count reported by the get
    /// state machine (`payload_bytes_read`) at that point.
    Progress(u64),
    /// The operation finished successfully with these statistics.
    Done(Stats),
    /// The operation failed.
    Error(GetError),
}
81
82impl From<GetResult<Stats>> for GetProgressItem {
83 fn from(res: GetResult<Stats>) -> Self {
84 match res {
85 Ok(stats) => GetProgressItem::Done(stats),
86 Err(e) => GetProgressItem::Error(e),
87 }
88 }
89}
90
91impl TryFrom<GetProgressItem> for GetResult<Stats> {
92 type Error = &'static str;
93
94 fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
95 match item {
96 GetProgressItem::Done(stats) => Ok(Ok(stats)),
97 GetProgressItem::Error(e) => Ok(Err(e)),
98 GetProgressItem::Progress(_) => Err("not a final item"),
99 }
100 }
101}
102
/// Handle to a pending get/fetch operation.
///
/// Await it directly (it implements [`IntoFuture`]) to obtain the final result,
/// or call [`GetProgress::stream`] to also observe progress items. The work is
/// done by `fut`, which only makes progress while the handle's stream/future
/// is being polled.
pub struct GetProgress {
    /// Receives progress items and, finally, the outcome produced by `fut`.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    /// The future performing the operation and sending items into `rx`.
    fut: n0_future::boxed::BoxFuture<()>,
}
107
108impl IntoFuture for GetProgress {
109 type Output = GetResult<Stats>;
110 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
111
112 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
113 Box::pin(self.complete())
114 }
115}
116
117impl GetProgress {
118 pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
119 into_stream(self.rx, self.fut)
120 }
121
122 pub async fn complete(self) -> GetResult<Stats> {
123 just_result(self.stream()).await.unwrap_or_else(|| {
124 Err(e!(
125 GetError::LocalFailure,
126 n0_error::anyerr!("stream closed without result")
127 ))
128 })
129 }
130}
131
/// An item emitted while a push operation runs (see [`PushProgress::stream`]).
#[derive(Debug)]
pub enum PushProgressItem {
    /// Progress update carrying the running total of payload bytes written.
    Progress(u64),
    /// The push finished successfully with these statistics.
    Done(Stats),
    /// The push failed.
    Error(AnyError),
}
141
142impl From<Result<Stats>> for PushProgressItem {
143 fn from(res: Result<Stats>) -> Self {
144 match res {
145 Ok(stats) => Self::Done(stats),
146 Err(e) => Self::Error(e),
147 }
148 }
149}
150
151impl TryFrom<PushProgressItem> for Result<Stats> {
152 type Error = &'static str;
153
154 fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
155 match item {
156 PushProgressItem::Done(stats) => Ok(Ok(stats)),
157 PushProgressItem::Error(e) => Ok(Err(e)),
158 PushProgressItem::Progress(_) => Err("not a final item"),
159 }
160 }
161}
162
/// Handle to a pending push operation.
///
/// Await it directly (it implements [`IntoFuture`]) to obtain the final result,
/// or call [`PushProgress::stream`] to also observe progress items. The work is
/// done by `fut`, which only makes progress while the handle's stream/future
/// is being polled.
pub struct PushProgress {
    /// Receives progress items and, finally, the outcome produced by `fut`.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    /// The future performing the push and sending items into `rx`.
    fut: n0_future::boxed::BoxFuture<()>,
}
167
168impl IntoFuture for PushProgress {
169 type Output = Result<Stats>;
170 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
171
172 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
173 Box::pin(self.complete())
174 }
175}
176
177impl PushProgress {
178 pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
179 into_stream(self.rx, self.fut)
180 }
181
182 pub async fn complete(self) -> Result<Stats> {
183 just_result(self.stream())
184 .await
185 .unwrap_or_else(|| Err(n0_error::anyerr!("stream closed without result")))
186 }
187}
188
189async fn just_result<S, R>(stream: S) -> Option<R>
190where
191 S: Stream<Item: std::fmt::Debug>,
192 R: TryFrom<S::Item>,
193{
194 tokio::pin!(stream);
195 while let Some(item) = stream.next().await {
196 if let Ok(res) = R::try_from(item) {
197 return Some(res);
198 }
199 }
200 None
201}
202
/// Merges a receiver and the future that feeds it into a single stream of items.
///
/// `fut` is polled together with `rx`. Items are yielded as they arrive, and
/// `rx` is checked first (`biased`), so buffered items take priority over
/// advancing `fut`. The main loop stops when `fut` completes or `rx` reports
/// that all senders are gone. Any items still buffered in `rx` are then yielded
/// before the stream ends.
fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
where
    F: Future,
{
    Gen::new(move |co| async move {
        tokio::pin!(fut);
        loop {
            tokio::select! {
                biased;
                item = rx.recv() => {
                    if let Some(item) = item {
                        co.yield_(item).await;
                    } else {
                        // every sender is gone, nothing more can arrive
                        break;
                    }
                }
                _ = &mut fut => {
                    // the producer is done; fall through to drain what is left
                    break;
                }
            }
        }
        // yield whatever is still buffered (e.g. the final result item)
        while let Some(item) = rx.recv().await {
            co.yield_(item).await;
        }
    })
}
229
/// Describes which parts of a [`GetRequest`] are already present in the local store.
///
/// Produced by `Remote::local_for_request` and `Remote::local`.
#[derive(Debug)]
pub struct LocalInfo {
    /// The request this information was computed for.
    request: Arc<GetRequest>,
    /// Locally available ranges of the root blob.
    bitfield: Bitfield,
    /// Local information about the children; `None` when the request is for a
    /// raw blob only.
    children: Option<NonRawLocalInfo>,
}
243
impl LocalInfo {
    /// Number of bytes of the requested ranges that are already present locally.
    ///
    /// Adds the locally available bytes within the requested root ranges to those
    /// within the requested ranges of every child whose hash is known locally.
    pub fn local_bytes(&self) -> u64 {
        let Some(root_requested) = self.requested_root_ranges() else {
            // no root ranges in the request
            return 0;
        };
        let mut local = self.bitfield.clone();
        local.ranges.intersection_with(root_requested);
        let mut res = local.total_bytes();
        if let Some(children) = &self.children {
            let Some(max_local_index) = children.hash_seq.keys().next_back() else {
                // no child hash is known locally
                return res;
            };
            for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
                if offset == 0 {
                    // offset 0 is the root, already counted above
                    continue;
                }
                let child = offset - 1;
                if child > *max_local_index {
                    // nothing is known locally beyond the highest child index
                    break;
                }
                let Some(hash) = children.hash_seq.get(&child) else {
                    continue;
                };
                let bitfield = &children.bitfields[hash];
                let mut local = bitfield.clone();
                local.ranges.intersection_with(ranges);
                res += local.total_bytes();
            }
        }
        res
    }

    /// Number of children of the hash sequence, computed from the root's
    /// validated size at 32 bytes per child hash.
    ///
    /// Returns `Some(0)` for raw-blob requests and `None` when the root's
    /// `validated_size` is not available.
    pub fn children(&self) -> Option<u64> {
        if self.children.is_some() {
            self.bitfield.validated_size().map(|x| x / 32)
        } else {
            Some(0)
        }
    }

    /// Ranges requested for the root blob, i.e. the first entry of the request's
    /// range sequence, if there is one.
    fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
        self.request.ranges.iter().next()
    }

    /// Returns `true` if every requested range of the root and of every requested
    /// child is present locally.
    ///
    /// A requested child whose hash is not known locally makes the result `false`,
    /// unless its index is at or beyond the number of children implied by the
    /// root's validated size (i.e. past the end of the hash sequence).
    pub fn is_complete(&self) -> bool {
        let Some(root_requested) = self.requested_root_ranges() else {
            return true;
        };
        if !self.bitfield.ranges.is_superset(root_requested) {
            return false;
        }
        if let Some(children) = self.children.as_ref() {
            let mut iter = self.request.ranges.iter_non_empty_infinite();
            // number of children, if the root size is validated
            let max_child = self.bitfield.validated_size().map(|x| x / 32);
            loop {
                let Some((offset, range)) = iter.next() else {
                    break;
                };
                if offset == 0 {
                    // root, already checked above
                    continue;
                }
                let child = offset - 1;
                if let Some(hash) = children.hash_seq.get(&child) {
                    let bitfield = &children.bitfields[hash];
                    if !bitfield.ranges.is_superset(range) {
                        return false;
                    }
                } else {
                    // unknown child: fine only if it lies past the end of the hash seq
                    if let Some(max_child) = max_child {
                        if child >= max_child {
                            return true;
                        }
                    }
                    return false;
                }
            }
        }
        true
    }

    /// Builds a [`GetRequest`] for the requested ranges that are not present locally.
    ///
    /// Children whose hash is known locally are requested with the difference
    /// between requested and local ranges. Children beyond the locally known part
    /// of the hash sequence are requested in full, stopping at the number of
    /// children implied by the root's validated size when that is known. When it
    /// is not known and the request's range iterator keeps yielding entries after
    /// reaching its end, the result is built with `build_open`.
    pub fn missing(&self) -> GetRequest {
        let Some(root_requested) = self.requested_root_ranges() else {
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        // request offset of the highest locally known child (child index + 1), or 0
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        // number of children, if the root size is validated
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        // phase 1: children up to and including the highest locally known one
        loop {
            let Some((offset, requested)) = iter.next() else {
                break;
            };
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                break;
            }
        }
        // phase 2: children with no local information, requested in full
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                if child >= *max_offset {
                    // past the end of the hash sequence
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                if iter.is_at_end() {
                    // an iterator that still yields after its end repeats the last
                    // ranges, so keep the resulting request open as well
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
}
408
/// Local information about the children of a (non-raw) hash sequence request.
#[derive(Debug)]
struct NonRawLocalInfo {
    /// Child index -> child hash, for requested children whose hash is known locally.
    hash_seq: BTreeMap<u64, Hash>,
    /// Locally available ranges for each hash in `hash_seq`.
    bitfields: BTreeMap<Hash, Bitfield>,
}
417
impl Remote {
    /// Reinterprets a `&ApiClient` as a `&Remote` without copying (via [`RefCast`]).
    pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
        Self::ref_cast(sender)
    }

    /// The local store that this client talks to.
    fn store(&self) -> &Store {
        Store::ref_from_sender(&self.client)
    }

    /// Computes which parts of `request` are already present in the local store.
    ///
    /// For non-blob requests, the locally available part of the root (hash
    /// sequence) is read to find child hashes, and the bitfield of every requested
    /// child whose hash is known is observed.
    ///
    /// # Errors
    /// Fails if observing the root or a child bitfield fails.
    pub async fn local_for_request(
        &self,
        request: impl Into<Arc<GetRequest>>,
    ) -> Result<LocalInfo> {
        let request = request.into();
        let root = request.hash;
        let bitfield = self.store().observe(root).await?;
        let children = if !request.ranges.is_blob() {
            // export only the locally present ranges of the hash sequence
            let opts = ExportBaoOptions {
                hash: root,
                ranges: bitfield.ranges.clone(),
            };
            let bao = self.store().export_bao_with_opts(opts, 32);
            let mut by_index = BTreeMap::new();
            let mut stream = bao.hashes_with_index();
            while let Some(item) = stream.next().await {
                // failed items are skipped; presumably parts of the hash seq may be
                // unavailable locally
                if let Ok((index, hash)) = item {
                    by_index.insert(index, hash);
                }
            }
            let mut bitfields = BTreeMap::new();
            let mut hash_seq = BTreeMap::new();
            let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
            for (index, _) in request.ranges.iter_non_empty_infinite() {
                if index == 0 {
                    // root, handled above
                    continue;
                }
                let child = index - 1;
                if child > max {
                    // no known hashes beyond this point (also ends open requests)
                    break;
                }
                let Some(hash) = by_index.get(&child) else {
                    continue;
                };
                let bitfield = self.store().observe(*hash).await?;
                bitfields.insert(*hash, bitfield);
                hash_seq.insert(child, *hash);
            }
            Some(NonRawLocalInfo {
                hash_seq,
                bitfields,
            })
        } else {
            None
        };
        Ok(LocalInfo {
            request: request.clone(),
            bitfield,
            children,
        })
    }

    /// Like [`Remote::local_for_request`], for the request obtained from
    /// `content` via `GetRequest::from`.
    pub async fn local(&self, content: impl Into<HashAndFormat>) -> Result<LocalInfo> {
        let request = GetRequest::from(content.into());
        self.local_for_request(request).await
    }

    /// Fetches whatever is missing locally for `content` using the stream pair
    /// provided by `sp`.
    ///
    /// Returns a [`GetProgress`] handle; nothing happens until it is awaited or
    /// its stream is polled.
    pub fn fetch(
        &self,
        sp: impl GetStreamPair + 'static,
        content: impl Into<HashAndFormat>,
    ) -> GetProgress {
        let content = content.into();
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        // second sender, used to deliver the final result after the sink is consumed
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.fetch_sink(sp, content, sink).await.into();
            // a send error means the receiver is gone; nobody wants the result
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Fetches the parts of `content` that are missing locally, reporting
    /// progress to `progress`.
    ///
    /// If everything is already present, returns default stats without opening
    /// a stream pair.
    pub(crate) async fn fetch_sink(
        &self,
        sp: impl GetStreamPair,
        content: impl Into<HashAndFormat>,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let content = content.into();
        let local = self
            .local(content)
            .await
            .map_err(|e| e!(GetError::LocalFailure, e))?;
        if local.is_complete() {
            return Ok(Default::default());
        }
        let request = local.missing();
        let stats = self.execute_get_sink(sp, request, progress).await?;
        Ok(stats)
    }

    /// Observes the bitfield of a blob on the remote side of `conn`.
    ///
    /// Yields a [`Bitfield`] for every update received. An error is yielded as
    /// the last item and ends the stream.
    pub fn observe(
        &self,
        conn: Connection,
        request: ObserveRequest,
    ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
        Gen::new(|co| async move {
            if let Err(cause) = Self::observe_impl(conn, request, &co).await {
                co.yield_(Err(cause)).await
            }
        })
    }

    /// Sends the observe request and yields one bitfield per received
    /// [`ObserveItem`]. Only returns when an error occurs.
    async fn observe_impl(
        conn: Connection,
        request: ObserveRequest,
        co: &Co<io::Result<Bitfield>>,
    ) -> io::Result<()> {
        let hash = request.hash;
        debug!(%hash, "observing");
        let (mut send, mut recv) = conn.open_bi().await?;
        write_observe_request(request, &mut send).await?;
        send.finish()?;
        loop {
            let msg = recv
                .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
                .await?;
            co.yield_(Ok(Bitfield::from(&msg))).await;
        }
    }

    /// Pushes the data described by `request` to the remote over `conn`.
    ///
    /// Returns a [`PushProgress`] handle; nothing happens until it is awaited or
    /// its stream is polled.
    pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        // second sender, used to deliver the final result after the sink is consumed
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_push_sink(conn, request, sink).await.into();
            // a send error means the receiver is gone; nobody wants the result
            tx2.send(res).await.ok();
        };
        PushProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Writes the push request followed by the requested ranges of the root and of
    /// each child of its hash sequence, reporting payload bytes to `progress`.
    ///
    /// Returns `Default::default()` stats; the byte count is only reported via
    /// `progress`.
    pub(crate) async fn execute_push_sink(
        &self,
        conn: Connection,
        request: PushRequest,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> Result<Stats> {
        let hash = request.hash;
        debug!(%hash, "pushing");
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        let mut context = StreamContext {
            payload_bytes_sent: 0,
            sender: progress,
        };
        // the receive side is never read in this function
        recv.stop(0u32.into()).anyerr()?;
        let request = write_push_request(request, &mut send).await?;
        let mut request_ranges = request.ranges.iter_infinite();
        let root = request.hash;
        let root_ranges = request_ranges.next().expect("infinite iterator");
        if !root_ranges.is_empty() {
            self.store()
                .export_bao(root, root_ranges.clone())
                .write_with_progress(&mut send, &mut context, &root, 0)
                .await?;
        }
        if request.ranges.is_blob() {
            send.finish().anyerr()?;
            return Ok(Default::default());
        }
        let hash_seq = self.store().get_bytes(root).await?;
        let hash_seq = HashSeq::try_from(hash_seq)?;
        // children are at request offsets 1.., hence `child + 1` below
        for (child, (child_hash, child_ranges)) in
            hash_seq.into_iter().zip(request_ranges).enumerate()
        {
            if !child_ranges.is_empty() {
                self.store()
                    .export_bao(child_hash, child_ranges.clone())
                    .write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64)
                    .await?;
            }
        }
        send.finish().anyerr()?;
        Ok(Default::default())
    }

    /// Executes `request` on the stream pair provided by `conn`.
    ///
    /// Same as [`Remote::execute_get_with_opts`].
    pub fn execute_get(&self, conn: impl GetStreamPair, request: GetRequest) -> GetProgress {
        self.execute_get_with_opts(conn, request)
    }

    /// Executes `request` on the stream pair provided by `conn`, storing the
    /// received data locally.
    ///
    /// Returns a [`GetProgress`] handle; nothing happens until it is awaited or
    /// its stream is polled.
    pub fn execute_get_with_opts(
        &self,
        conn: impl GetStreamPair,
        request: GetRequest,
    ) -> GetProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        // second sender, used to deliver the final result after the sink is consumed
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_get_sink(conn, request, sink).await.into();
            // a send error means the receiver is gone; nobody wants the result
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Drives the get state machine for `request`, importing the root and every
    /// child received into the local store and reporting progress to `progress`.
    ///
    /// Child hashes are looked up in the locally stored root hash sequence, so the
    /// root must be available locally (or be received first) when children arrive.
    ///
    /// # Errors
    /// Fails with `GetError::LocalFailure` if the stream pair cannot be opened or
    /// the root cannot be read locally, `GetError::BadRequest` if the root is not a
    /// valid hash sequence, and with the state machine's errors otherwise.
    pub(crate) async fn execute_get_sink(
        &self,
        conn: impl GetStreamPair,
        request: GetRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let root = request.hash;
        let conn = conn.open_stream_pair().await.map_err(|e| {
            e!(
                GetError::LocalFailure,
                n0_error::anyerr!("failed to open stream pair: {e}")
            )
        })?;
        let connected =
            AtConnected::new(conn.t0, conn.recv, conn.send, request, Default::default());
        trace!("Getting header");
        // Ok(at_start_child) if children follow, Err(at_closing) if the transfer ends
        let next_child = match connected
            .next()
            .await
            .map_err(|e| e!(GetError::ConnectedNext, e))?
        {
            ConnectedNext::StartRoot(at_start_root) => {
                let header = at_start_root.next();
                let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
                match end.next() {
                    EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                    EndBlobNext::Closing(at_closing) => Err(at_closing),
                }
            }
            ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
            ConnectedNext::Closing(at_closing) => Err(at_closing),
        };
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                // child hashes come from the locally stored root
                let hash_seq = HashSeq::try_from(
                    store
                        .get_bytes(root)
                        .await
                        .map_err(|e| e!(GetError::LocalFailure, e.into()))?,
                )
                .map_err(|e| e!(GetError::BadRequest, e))?;
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    // request offset 1 is child 0
                    let offset = at_start_child.offset() - 1;
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        // past the end of the hash sequence
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        let stats = at_closing
            .next()
            .await
            .map_err(|e| e!(GetError::AtClosingNext, e))?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }

    /// Executes a get-many request over `conn`, storing the received blobs locally.
    ///
    /// Returns a [`GetProgress`] handle; nothing happens until it is awaited or
    /// its stream is polled.
    pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        // second sender, used to deliver the final result after the sink is consumed
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_get_many_sink(conn, request, sink).await.into();
            // a send error means the receiver is gone; nobody wants the result
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Drives a get-many request, importing every received blob into the local
    /// store and reporting progress to `progress`.
    ///
    /// Unlike [`Remote::execute_get_sink`], offsets index directly into the
    /// request's list of hashes (there is no root blob).
    pub async fn execute_get_many_sink(
        &self,
        conn: Connection,
        request: GetManyRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
        let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    let offset = at_start_child.offset();
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        // past the end of the requested hashes
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        let stats = at_closing
            .next()
            .await
            .map_err(|e| e!(GetError::AtClosingNext, e))?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }
}
817
/// Failures that can occur while executing a request against a remote, one
/// variant per failing step (see each variant's message).
#[allow(missing_docs)]
#[non_exhaustive]
#[stack_error(derive, add_meta)]
pub enum ExecuteError {
    /// Opening a bidirectional stream on the connection failed.
    #[error("Unable to open bidi stream")]
    Connection {
        #[error(std_err)]
        source: iroh::endpoint::ConnectionError,
    },
    /// Reading from the remote failed.
    #[error("Unable to read from the remote")]
    Read {
        #[error(std_err)]
        source: iroh::endpoint::ReadError,
    },
    /// Sending the request failed.
    #[error("Error sending the request")]
    Send {
        #[error(std_err)]
        source: crate::get::fsm::ConnectedNextError,
    },
    /// Reading the blob size header failed.
    #[error("Unable to read size")]
    Size {
        #[error(std_err)]
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    /// Decoding the received data failed.
    #[error("Error while decoding the data")]
    Decode {
        #[error(std_err)]
        source: crate::get::fsm::DecodeError,
    },
    /// Exporting the hash sequence from the local store failed.
    #[error("Internal error while reading the hash sequence")]
    ExportBao { source: api::ExportBaoError },
    /// The hash sequence could not be parsed.
    #[error("Hash sequence has an invalid length")]
    InvalidHashSeq { source: AnyError },
    /// Importing the received data into the local store failed.
    #[error("Internal error importing the data")]
    ImportBao { source: crate::api::RequestError },
    /// Reporting download progress failed because the receiver is closed.
    #[error("Error sending download progress - receiver closed")]
    SendDownloadProgress { source: irpc::channel::SendError },
    /// Forwarding received content to the importer failed.
    #[error("Internal error importing the data")]
    MpscSend {
        #[error(std_err)]
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
863
/// Source of the bidirectional stream pair used to execute a get request.
///
/// Implemented for an already open [`StreamPair`] (returned unchanged) and for an
/// iroh [`Connection`] (which opens a new bidirectional stream).
pub trait GetStreamPair: Send + 'static {
    /// Consumes `self` and resolves to the stream pair to use.
    fn open_stream_pair(
        self,
    ) -> impl Future<Output = io::Result<StreamPair<impl RecvStream, impl SendStream>>> + Send + 'static;
}
869
impl<R: RecvStream + 'static, W: SendStream + 'static> GetStreamPair for StreamPair<R, W> {
    /// An existing stream pair is used as-is.
    async fn open_stream_pair(self) -> io::Result<StreamPair<impl RecvStream, impl SendStream>> {
        Ok(self)
    }
}
875
876impl GetStreamPair for Connection {
877 async fn open_stream_pair(
878 self,
879 ) -> io::Result<StreamPair<impl crate::util::RecvStream, impl crate::util::SendStream>> {
880 let connection_id = self.stable_id() as u64;
881 let (send, recv) = self.open_bi().await?;
882 Ok(StreamPair::new(connection_id, recv, send))
883 }
884}
885
886fn get_buffer_size(size: NonZeroU64) -> usize {
887 (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
888}
889
/// Receives one blob via `header` and imports it into `store` under `hash`.
///
/// After each decoded item, the state machine's `payload_bytes_read` is sent to
/// `progress`. A blob whose announced size is 0 is accepted only when `hash` is
/// [`Hash::EMPTY`]; otherwise a leaf hash mismatch decode error is returned.
///
/// Returns the state machine positioned at the end of the blob.
///
/// # Errors
/// Fails on header/decoding errors, on progress or import send failures, and if
/// the import itself reports an error.
async fn get_blob_ranges_impl<R: RecvStream>(
    header: AtBlobHeader<R>,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
) -> GetResult<AtEndBlob<R>> {
    let (mut content, size) = header
        .next()
        .await
        .map_err(|e| e!(GetError::AtBlobHeaderNext, e))?;
    let Some(size) = NonZeroU64::new(size) else {
        return if hash == Hash::EMPTY {
            let end = content.drain().await.map_err(|e| e!(GetError::Decode, e))?;
            Ok(end)
        } else {
            Err(e!(
                GetError::Decode,
                DecodeError::leaf_hash_mismatch(ChunkNum(0))
            ))
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| e!(GetError::LocalFailure, e.into()))?;
    // network side: decode items, report progress, forward them to the import
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res.map_err(|e| e!(GetError::Decode, e))?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| e!(GetError::LocalFailure, e.into()))?;
                    handle
                        .tx
                        .send(item)
                        .await
                        .map_err(|e| e!(GetError::IrpcSend, e))?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // drop the sender, presumably so the importer sees end of input
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    // import side: wait for the store to report the import's outcome
    let complete = async move {
        handle.rx.await.map_err(|e| {
            e!(
                GetError::LocalFailure,
                n0_error::anyerr!("error reading from import stream: {e}")
            )
        })
    };
    // run both concurrently; returns as soon as either fails
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
951
/// Reads hashes from a hash sequence blob lazily, one leaf at a time.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    /// Store to read the hash sequence from.
    blobs: Blobs,
    /// Hash of the hash sequence blob itself.
    hash: Hash,
    /// Most recently loaded leaf, kept to answer subsequent lookups.
    current_chunk: Option<HashSeqChunk>,
}
958
/// A contiguous part of a hash sequence, decoded from a single leaf.
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// Byte offset of this part within the hash sequence blob (32 bytes per hash).
    offset: u64,
    /// The hashes contained in this part.
    chunk: HashSeq,
}
966
967impl TryFrom<Leaf> for HashSeqChunk {
968 type Error = AnyError;
969
970 fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
971 let offset = leaf.offset;
972 let chunk = HashSeq::try_from(leaf.data)?;
973 Ok(Self { offset, chunk })
974 }
975}
976
977impl IntoIterator for HashSeqChunk {
978 type Item = Hash;
979 type IntoIter = HashSeqIter;
980
981 fn into_iter(self) -> Self::IntoIter {
982 self.chunk.into_iter()
983 }
984}
985
986impl HashSeqChunk {
987 pub fn base(&self) -> u64 {
988 self.offset / 32
989 }
990
991 #[allow(dead_code)]
992 fn get(&self, offset: u64) -> Option<Hash> {
993 let start = self.offset;
994 let end = start + self.chunk.len() as u64;
995 if offset >= start && offset < end {
996 let o = (offset - start) as usize;
997 self.chunk.get(o)
998 } else {
999 None
1000 }
1001 }
1002}
1003
1004impl LazyHashSeq {
1005 #[allow(dead_code)]
1006 pub fn new(blobs: Blobs, hash: Hash) -> Self {
1007 Self {
1008 blobs,
1009 hash,
1010 current_chunk: None,
1011 }
1012 }
1013
1014 #[allow(dead_code)]
1015 pub async fn get_from_offset(&mut self, offset: u64) -> Result<Option<Hash>> {
1016 if offset == 0 {
1017 Ok(Some(self.hash))
1018 } else {
1019 self.get(offset - 1).await
1020 }
1021 }
1022
1023 #[allow(dead_code)]
1024 pub async fn get(&mut self, child_offset: u64) -> Result<Option<Hash>> {
1025 if let Some(chunk) = &self.current_chunk {
1027 if let Some(hash) = chunk.get(child_offset) {
1028 return Ok(Some(hash));
1029 }
1030 }
1031 let leaf = self
1033 .blobs
1034 .export_chunk(self.hash, child_offset * 32)
1035 .await?;
1036 let hs = HashSeqChunk::try_from(leaf)?;
1038 Ok(hs.get(child_offset).inspect(|_hash| {
1039 self.current_chunk = Some(hs);
1040 }))
1041 }
1042}
1043
1044async fn write_push_request(
1045 request: PushRequest,
1046 stream: &mut impl SendStream,
1047) -> Result<PushRequest> {
1048 let mut request_bytes = Vec::new();
1049 request_bytes.push(RequestType::Push as u8);
1050 request_bytes.write_length_prefixed(&request).unwrap();
1051 stream.send_bytes(request_bytes.into()).await?;
1052 Ok(request)
1053}
1054
1055async fn write_observe_request(
1056 request: ObserveRequest,
1057 stream: &mut impl SendStream,
1058) -> io::Result<()> {
1059 let request = Request::Observe(request);
1060 let request_bytes = postcard::to_allocvec(&request)
1061 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1062 stream.send_bytes(request_bytes.into()).await?;
1063 Ok(())
1064}
1065
/// Progress bookkeeping for a push: counts payload bytes written and forwards
/// the running total to `sender`.
struct StreamContext<S> {
    /// Total payload bytes written so far.
    payload_bytes_sent: u64,
    /// Receives the running total after every payload write.
    sender: S,
}
1070
1071impl<S> WriteProgress for StreamContext<S>
1072where
1073 S: Sink<u64, Error = irpc::channel::SendError>,
1074{
1075 async fn notify_payload_write(
1076 &mut self,
1077 _index: u64,
1078 _offset: u64,
1079 len: usize,
1080 ) -> ClientResult {
1081 self.payload_bytes_sent += len as u64;
1082 self.sender
1083 .send(self.payload_bytes_sent)
1084 .await
1085 .map_err(|e| n0_error::e!(ProgressError::Internal, e.into()))?;
1086 Ok(())
1087 }
1088
1089 fn log_other_write(&mut self, _len: usize) {}
1090
1091 async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
1092}
1093
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use bao_tree::{ChunkNum, ChunkRanges};
    use testresult::TestResult;

    use crate::{
        api::blobs::Blobs,
        protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
        store::{
            fs::{
                tests::{test_data, INTERESTING_SIZES},
                FsStore,
            },
            mem::MemStore,
            util::tests::create_n0_bao,
        },
        tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
    };

    /// A fully stored raw blob is complete and its missing request is empty.
    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }

    /// A large hash sequence where only root chunks 16..32 are present, so only
    /// the hashes of children 512..1024 are known locally (32 hashes per chunk).
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // children whose hashes are in root chunks 16..32
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // root: chunks 16..32; every child: first chunk only
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // 16 root chunks of 1024 bytes plus the locally present child bytes
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }

    /// Importing only chunk 0 of a blob is observed as exactly that range, or as
    /// the full blob when the blob is no larger than one chunk.
    async fn test_observe_partial(blobs: &Blobs) -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        for size in sizes {
            let data = test_data(size);
            let ranges = ChunkRanges::chunk(0);
            let (hash, bao) = create_n0_bao(&data, &ranges)?;
            blobs.import_bao_bytes(hash, ranges.clone(), bao).await?;
            let bitfield = blobs.observe(hash).await?;
            if size > 1024 {
                assert_eq!(bitfield.ranges, ranges);
            } else {
                assert_eq!(bitfield.ranges, ChunkRanges::all());
            }
        }
        Ok(())
    }

    /// [`test_observe_partial`] against the in-memory store.
    #[tokio::test]
    async fn test_observe_partial_mem() -> TestResult<()> {
        let store = MemStore::new();
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    /// [`test_observe_partial`] against the file system store.
    #[tokio::test]
    async fn test_observe_partial_fs() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path()).await?;
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    /// Local info for a hash sequence with no children, with the first chunk of
    /// each child, and with everything present.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // only the root is present
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::all(), ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            // root plus the first chunk of every child
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::empty(), ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            // everything present
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }

    /// Completeness and local byte counts for requests that cover only some
    /// children, including an open request, when every child has chunks 0..2 present.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        // root fully present; every child has its first two chunks
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // root only
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // root and the first child
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // root and the first two children
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // open request: chunk 0 of every child
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
}