1use std::{
5 collections::BTreeMap,
6 future::{Future, IntoFuture},
7 num::NonZeroU64,
8 sync::Arc,
9};
10
11use bao_tree::{
12 io::{BaoContentItem, Leaf},
13 ChunkNum, ChunkRanges,
14};
15use genawaiter::sync::{Co, Gen};
16use iroh::endpoint::Connection;
17use irpc::util::{AsyncReadVarintExt, WriteVarintExt};
18use n0_error::{e, stack_error, AnyError, Result, StdResultExt};
19use n0_future::{io, Stream, StreamExt};
20use ref_cast::RefCast;
21use tracing::{debug, trace};
22
23use super::blobs::{Bitfield, ExportBaoOptions};
24use crate::{
25 api::{
26 self,
27 blobs::{Blobs, WriteProgress},
28 ApiClient, Store,
29 },
30 get::{
31 fsm::{
32 AtBlobHeader, AtConnected, AtEndBlob, BlobContentNext, ConnectedNext, DecodeError,
33 EndBlobNext,
34 },
35 GetError, GetResult, Stats, StreamPair,
36 },
37 hashseq::{HashSeq, HashSeqIter},
38 protocol::{
39 ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest,
40 Request, RequestType, MAX_MESSAGE_SIZE,
41 },
42 provider::events::{ClientResult, ProgressError},
43 store::IROH_BLOCK_SIZE,
44 util::{
45 sink::{Sink, TokioMpscSenderSink},
46 RecvStream, SendStream,
47 },
48 Hash, HashAndFormat,
49};
50
/// Operations that move blob data between the local store and remote nodes:
/// fetching, executing get / get-many / push requests, and observing.
///
/// This is a view over the store's [`ApiClient`]. `#[repr(transparent)]`
/// together with the [`RefCast`] derive lets a `&ApiClient` be reinterpreted
/// as a `&Remote` without copying (see `Remote::ref_from_sender`).
#[derive(Debug, Clone, RefCast)]
#[repr(transparent)]
pub struct Remote {
    client: ApiClient,
}
71
/// An item emitted by the progress stream of a get / fetch operation.
#[derive(Debug)]
pub enum GetProgressItem {
    /// Progress update: the `payload_bytes_read` value reported by the get
    /// state machine's stats (presumably cumulative for the request).
    Progress(u64),
    /// The operation finished successfully with these statistics.
    Done(Stats),
    /// The operation failed.
    Error(GetError),
}
81
82impl From<GetResult<Stats>> for GetProgressItem {
83 fn from(res: GetResult<Stats>) -> Self {
84 match res {
85 Ok(stats) => GetProgressItem::Done(stats),
86 Err(e) => GetProgressItem::Error(e),
87 }
88 }
89}
90
91impl TryFrom<GetProgressItem> for GetResult<Stats> {
92 type Error = &'static str;
93
94 fn try_from(item: GetProgressItem) -> Result<Self, Self::Error> {
95 match item {
96 GetProgressItem::Done(stats) => Ok(Ok(stats)),
97 GetProgressItem::Error(e) => Ok(Err(e)),
98 GetProgressItem::Progress(_) => Err("not a final item"),
99 }
100 }
101}
102
/// Handle to a get / fetch operation that has not been driven yet.
///
/// Either turn it into a stream of [`GetProgressItem`]s with
/// [`GetProgress::stream`], or `.await` it (via [`IntoFuture`]) to get only
/// the final result.
pub struct GetProgress {
    /// Receives progress items and, last, the final result sent by `fut`.
    rx: tokio::sync::mpsc::Receiver<GetProgressItem>,
    /// The future doing the actual work; like any future it only runs when polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
107
108impl IntoFuture for GetProgress {
109 type Output = GetResult<Stats>;
110 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
111
112 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
113 Box::pin(self.complete())
114 }
115}
116
117impl GetProgress {
118 pub fn stream(self) -> impl Stream<Item = GetProgressItem> {
119 into_stream(self.rx, self.fut)
120 }
121
122 pub async fn complete(self) -> GetResult<Stats> {
123 just_result(self.stream()).await.unwrap_or_else(|| {
124 Err(e!(
125 GetError::LocalFailure,
126 n0_error::anyerr!("stream closed without result")
127 ))
128 })
129 }
130}
131
/// An item emitted by the progress stream of a push operation.
#[derive(Debug)]
pub enum PushProgressItem {
    /// Progress update: total payload bytes written so far.
    Progress(u64),
    /// The push finished successfully with these statistics.
    Done(Stats),
    /// The push failed.
    Error(AnyError),
}
141
142impl From<Result<Stats>> for PushProgressItem {
143 fn from(res: Result<Stats>) -> Self {
144 match res {
145 Ok(stats) => Self::Done(stats),
146 Err(e) => Self::Error(e),
147 }
148 }
149}
150
151impl TryFrom<PushProgressItem> for Result<Stats> {
152 type Error = &'static str;
153
154 fn try_from(item: PushProgressItem) -> Result<Self, Self::Error> {
155 match item {
156 PushProgressItem::Done(stats) => Ok(Ok(stats)),
157 PushProgressItem::Error(e) => Ok(Err(e)),
158 PushProgressItem::Progress(_) => Err("not a final item"),
159 }
160 }
161}
162
/// Handle to a push operation that has not been driven yet.
///
/// Either turn it into a stream of [`PushProgressItem`]s with
/// [`PushProgress::stream`], or `.await` it (via [`IntoFuture`]) to get only
/// the final result.
pub struct PushProgress {
    /// Receives progress items and, last, the final result sent by `fut`.
    rx: tokio::sync::mpsc::Receiver<PushProgressItem>,
    /// The future doing the actual work; like any future it only runs when polled.
    fut: n0_future::boxed::BoxFuture<()>,
}
167
168impl IntoFuture for PushProgress {
169 type Output = Result<Stats>;
170 type IntoFuture = n0_future::boxed::BoxFuture<Self::Output>;
171
172 fn into_future(self) -> n0_future::boxed::BoxFuture<Self::Output> {
173 Box::pin(self.complete())
174 }
175}
176
177impl PushProgress {
178 pub fn stream(self) -> impl Stream<Item = PushProgressItem> {
179 into_stream(self.rx, self.fut)
180 }
181
182 pub async fn complete(self) -> Result<Stats> {
183 just_result(self.stream())
184 .await
185 .unwrap_or_else(|| Err(n0_error::anyerr!("stream closed without result")))
186 }
187}
188
189async fn just_result<S, R>(stream: S) -> Option<R>
190where
191 S: Stream<Item: std::fmt::Debug>,
192 R: TryFrom<S::Item>,
193{
194 tokio::pin!(stream);
195 while let Some(item) = stream.next().await {
196 if let Ok(res) = R::try_from(item) {
197 return Some(res);
198 }
199 }
200 None
201}
202
203fn into_stream<T, F>(mut rx: tokio::sync::mpsc::Receiver<T>, fut: F) -> impl Stream<Item = T>
204where
205 F: Future,
206{
207 Gen::new(move |co| async move {
208 tokio::pin!(fut);
209 loop {
210 tokio::select! {
211 biased;
212 item = rx.recv() => {
213 if let Some(item) = item {
214 co.yield_(item).await;
215 } else {
216 break;
217 }
218 }
219 _ = &mut fut => {
220 break;
221 }
222 }
223 }
224 while let Some(item) = rx.recv().await {
225 co.yield_(item).await;
226 }
227 })
228}
229
/// Snapshot of what the local store already holds for a [`GetRequest`].
///
/// Produced by `Remote::local_for_request` / `Remote::local`.
#[derive(Debug)]
pub struct LocalInfo {
    /// The request this snapshot was computed for.
    request: Arc<GetRequest>,
    /// Locally available ranges of the root blob.
    bitfield: Bitfield,
    /// Local state of the requested children; `None` when the request only
    /// covers the root blob (`request.ranges.is_blob()`).
    children: Option<NonRawLocalInfo>,
}
243
impl LocalInfo {
    /// Number of bytes of the requested ranges that are already present locally.
    ///
    /// Sums the root's locally present bytes within its requested ranges and,
    /// for each requested child whose hash is known locally, that child's
    /// present bytes within its requested ranges.
    pub fn local_bytes(&self) -> u64 {
        let Some(root_requested) = self.requested_root_ranges() else {
            return 0;
        };
        let mut local = self.bitfield.clone();
        local.ranges.intersection_with(root_requested);
        let mut res = local.total_bytes();
        if let Some(children) = &self.children {
            let Some(max_local_index) = children.hash_seq.keys().next_back() else {
                return res;
            };
            // The requested ranges may be open-ended (infinite iterator), so stop
            // once we are past the highest child index with a locally known hash.
            for (offset, ranges) in self.request.ranges.iter_non_empty_infinite() {
                if offset == 0 {
                    // offset 0 is the root, already counted above
                    continue;
                }
                let child = offset - 1;
                if child > *max_local_index {
                    break;
                }
                let Some(hash) = children.hash_seq.get(&child) else {
                    continue;
                };
                let bitfield = &children.bitfields[hash];
                let mut local = bitfield.clone();
                local.ranges.intersection_with(ranges);
                res += local.total_bytes();
            }
        }
        res
    }

    /// Number of children of the root hash sequence, if known.
    ///
    /// Derived from the root's validated size at 32 bytes per hash, so
    /// presumably `None` while that size is not validated yet. Always
    /// `Some(0)` when the request only covers the root blob.
    pub fn children(&self) -> Option<u64> {
        if self.children.is_some() {
            self.bitfield.validated_size().map(|x| x / 32)
        } else {
            Some(0)
        }
    }

    /// Ranges requested for the root blob: the first entry of the request's
    /// range sequence, or `None` if the sequence yields no entry.
    fn requested_root_ranges(&self) -> Option<&ChunkRanges> {
        self.request.ranges.iter().next()
    }

    /// Returns `true` if every requested range is already present locally.
    pub fn is_complete(&self) -> bool {
        let Some(root_requested) = self.requested_root_ranges() else {
            return true;
        };
        if !self.bitfield.ranges.is_superset(root_requested) {
            return false;
        }
        if let Some(children) = self.children.as_ref() {
            let iter = self.request.ranges.iter_non_empty_infinite();
            // Number of children, if the root's size is known.
            let max_child = self.bitfield.validated_size().map(|x| x / 32);
            for (offset, range) in iter {
                if offset == 0 {
                    // the root was checked above
                    continue;
                }
                let child = offset - 1;
                if let Some(hash) = children.hash_seq.get(&child) {
                    let bitfield = &children.bitfields[hash];
                    if !bitfield.ranges.is_superset(range) {
                        return false;
                    }
                } else {
                    // The child's hash is not known locally. If the index lies
                    // beyond the end of the hash sequence, there is nothing left
                    // to fetch (this also ends an open-ended iteration);
                    // otherwise the child is missing.
                    if let Some(max_child) = max_child {
                        if child >= max_child {
                            return true;
                        }
                    }
                    return false;
                }
            }
        }
        true
    }

    /// Builds a [`GetRequest`] for the requested ranges that are not present locally.
    pub fn missing(&self) -> GetRequest {
        let Some(root_requested) = self.requested_root_ranges() else {
            return GetRequest::new(self.request.hash, ChunkRangesSeq::empty());
        };
        let mut builder = GetRequest::builder().root(root_requested - &self.bitfield.ranges);

        let Some(children) = self.children.as_ref() else {
            return builder.build(self.request.hash);
        };
        let mut iter = self.request.ranges.iter_non_empty_infinite();
        // One past the highest child index with a locally known hash.
        let max_local = children
            .hash_seq
            .keys()
            .next_back()
            .map(|x| *x + 1)
            .unwrap_or_default();
        // Number of children, if the root's size is known.
        let max_offset = self.bitfield.validated_size().map(|x| x / 32);
        // Phase 1: children up to and including the last one with a locally
        // known hash. Subtract what we have; unknown children are requested as-is.
        for (offset, requested) in iter.by_ref() {
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            let missing = match children.hash_seq.get(&child) {
                Some(hash) => requested.difference(&children.bitfields[hash].ranges),
                None => requested.clone(),
            };
            builder = builder.child(child, missing);
            if offset >= max_local {
                break;
            }
        }
        // Phase 2: children past the last locally known hash are requested in
        // full, stopping at the end of the hash sequence if its length is known.
        loop {
            let Some((offset, requested)) = iter.next() else {
                return builder.build(self.request.hash);
            };
            if offset == 0 {
                continue;
            }
            let child = offset - 1;
            if let Some(max_offset) = &max_offset {
                if child >= *max_offset {
                    return builder.build(self.request.hash);
                }
                builder = builder.child(child, requested.clone());
            } else {
                builder = builder.child(child, requested.clone());
                // NOTE(review): with an unknown hash-seq length, an open-ended
                // request appears to be kept open via `build_open` once the
                // iterator reports `is_at_end` — confirm against the
                // `ChunkRangesSeq` iterator semantics.
                if iter.is_at_end() {
                    if iter.next().is_none() {
                        return builder.build(self.request.hash);
                    } else {
                        return builder.build_open(self.request.hash);
                    }
                }
            }
        }
    }
}
402
/// Local state of the children of a hash-sequence request.
#[derive(Debug)]
struct NonRawLocalInfo {
    /// Child index -> child hash, for requested children whose hash could be
    /// read from the local copy of the root.
    hash_seq: BTreeMap<u64, Hash>,
    /// Locally available ranges of each child blob listed in `hash_seq`.
    bitfields: BTreeMap<Hash, Bitfield>,
}
411
impl Remote {
    /// Reinterprets a `&ApiClient` as a `&Remote` without copying.
    ///
    /// Relies on `Remote` being `#[repr(transparent)]` over `ApiClient` with a
    /// [`RefCast`] derive.
    pub(crate) fn ref_from_sender(sender: &ApiClient) -> &Self {
        Self::ref_cast(sender)
    }

    /// The [`Store`] view over the same underlying client.
    fn store(&self) -> &Store {
        Store::ref_from_sender(&self.client)
    }

    /// Computes which parts of `request` are already present in the local store.
    ///
    /// Observes the root's bitfield. Unless the request only covers the root
    /// blob, the locally present parts of the root are also read as a hash
    /// sequence, and the bitfield of every requested child whose hash could be
    /// read is observed.
    ///
    /// # Errors
    ///
    /// Fails if observing the root or a child bitfield fails. Errors while
    /// reading hashes from the root are ignored; the affected children are
    /// treated as unknown.
    pub async fn local_for_request(
        &self,
        request: impl Into<Arc<GetRequest>>,
    ) -> Result<LocalInfo> {
        let request = request.into();
        let root = request.hash;
        let bitfield = self.store().observe(root).await?;
        let children = if !request.ranges.is_blob() {
            // Only read the parts of the root that are present locally.
            let opts = ExportBaoOptions {
                hash: root,
                ranges: bitfield.ranges.clone(),
            };
            let bao = self.store().export_bao_with_opts(opts, 32);
            let mut by_index = BTreeMap::new();
            let mut stream = bao.hashes_with_index();
            while let Some(item) = stream.next().await {
                // Best effort: unreadable entries are skipped.
                if let Ok((index, hash)) = item {
                    by_index.insert(index, hash);
                }
            }
            let mut bitfields = BTreeMap::new();
            let mut hash_seq = BTreeMap::new();
            let max = by_index.last_key_value().map(|(k, _)| *k + 1).unwrap_or(0);
            // The requested ranges may be open-ended (infinite iterator); stop
            // once we are past the last child hash we could read.
            for (index, _) in request.ranges.iter_non_empty_infinite() {
                if index == 0 {
                    // index 0 is the root itself
                    continue;
                }
                let child = index - 1;
                if child > max {
                    break;
                }
                let Some(hash) = by_index.get(&child) else {
                    continue;
                };
                let bitfield = self.store().observe(*hash).await?;
                bitfields.insert(*hash, bitfield);
                hash_seq.insert(child, *hash);
            }
            Some(NonRawLocalInfo {
                hash_seq,
                bitfields,
            })
        } else {
            None
        };
        Ok(LocalInfo {
            request: request.clone(),
            bitfield,
            children,
        })
    }

    /// Like [`Remote::local_for_request`], for the request produced by
    /// `GetRequest::from(content)`.
    pub async fn local(&self, content: impl Into<HashAndFormat>) -> Result<LocalInfo> {
        let request = GetRequest::from(content.into());
        self.local_for_request(request).await
    }

    /// Fetches `content` from a remote node, requesting only what is missing locally.
    ///
    /// `sp` provides the stream pair to run the request on (see [`GetStreamPair`]).
    /// Nothing happens until the returned [`GetProgress`] is awaited or its
    /// stream is polled; the final result is sent after all progress items.
    pub fn fetch(
        &self,
        sp: impl GetStreamPair + 'static,
        content: impl Into<HashAndFormat>,
    ) -> GetProgress {
        let content = content.into();
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.fetch_sink(sp, content, sink).await.into();
            // If the receiver is gone nobody wants the result; ignore the error.
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Fetches `content`, reporting progress to `progress`.
    ///
    /// Computes the local state first. If everything is already present,
    /// returns default [`Stats`] without opening a stream; otherwise executes
    /// a get for the missing ranges only.
    ///
    /// # Errors
    ///
    /// `GetError::LocalFailure` if the local state cannot be computed, otherwise
    /// whatever [`Remote::execute_get_sink`] returns.
    pub(crate) async fn fetch_sink(
        &self,
        sp: impl GetStreamPair,
        content: impl Into<HashAndFormat>,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let content = content.into();
        let local = self
            .local(content)
            .await
            .map_err(|e| e!(GetError::LocalFailure, e))?;
        if local.is_complete() {
            return Ok(Default::default());
        }
        let request = local.missing();
        let stats = self.execute_get_sink(sp, request, progress).await?;
        Ok(stats)
    }

    /// Observes the bitfield of a blob on a remote node.
    ///
    /// Sends an observe request over a new bidirectional stream and yields a
    /// [`Bitfield`] for every message received. The first I/O error
    /// (presumably including the remote closing the stream) is yielded as the
    /// last item.
    pub fn observe(
        &self,
        conn: Connection,
        request: ObserveRequest,
    ) -> impl Stream<Item = io::Result<Bitfield>> + 'static {
        Gen::new(|co| async move {
            if let Err(cause) = Self::observe_impl(conn, request, &co).await {
                co.yield_(Err(cause)).await
            }
        })
    }

    /// Body of [`Remote::observe`]. The read loop has no exit other than an
    /// error, so this only ever returns `Err`.
    async fn observe_impl(
        conn: Connection,
        request: ObserveRequest,
        co: &Co<io::Result<Bitfield>>,
    ) -> io::Result<()> {
        let hash = request.hash;
        debug!(%hash, "observing");
        let (mut send, mut recv) = conn.open_bi().await?;
        write_observe_request(request, &mut send).await?;
        // Nothing more to send; only responses are read from now on.
        send.finish()?;
        loop {
            let msg = recv
                .read_length_prefixed::<ObserveItem>(MAX_MESSAGE_SIZE)
                .await?;
            co.yield_(Ok(Bitfield::from(&msg))).await;
        }
    }

    /// Pushes data to a remote node as described by `request`.
    ///
    /// Nothing happens until the returned [`PushProgress`] is awaited or its
    /// stream is polled; the final result is sent after all progress items.
    pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_push_sink(conn, request, sink).await.into();
            // If the receiver is gone nobody wants the result; ignore the error.
            tx2.send(res).await.ok();
        };
        PushProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Pushes data to a remote node, sending the running total of payload
    /// bytes written to `progress`.
    ///
    /// Writes the push request, then the requested ranges of the root and —
    /// for hash-sequence requests — of each child, reading the child hashes
    /// from the local copy of the root. The receive side is stopped right away
    /// because no response is read. On success returns default [`Stats`].
    pub(crate) async fn execute_push_sink(
        &self,
        conn: Connection,
        request: PushRequest,
        progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> Result<Stats> {
        let hash = request.hash;
        debug!(%hash, "pushing");
        let (mut send, mut recv) = conn.open_bi().await.anyerr()?;
        let mut context = StreamContext {
            payload_bytes_sent: 0,
            sender: progress,
        };
        // No response is expected, so stop the receive side immediately.
        recv.stop(0u32.into()).anyerr()?;
        let request = write_push_request(request, &mut send).await?;
        let mut request_ranges = request.ranges.iter_infinite();
        let root = request.hash;
        let root_ranges = request_ranges.next().expect("infinite iterator");
        if !root_ranges.is_empty() {
            self.store()
                .export_bao(root, root_ranges.clone())
                .write_with_progress(&mut send, &mut context, &root, 0)
                .await?;
        }
        if request.ranges.is_blob() {
            send.finish().anyerr()?;
            return Ok(Default::default());
        }
        let hash_seq = self.store().get_bytes(root).await?;
        let hash_seq = HashSeq::try_from(hash_seq)?;
        for (child, (child_hash, child_ranges)) in
            hash_seq.into_iter().zip(request_ranges).enumerate()
        {
            if !child_ranges.is_empty() {
                // Children are at request offset `child + 1`; offset 0 is the root.
                self.store()
                    .export_bao(child_hash, child_ranges.clone())
                    .write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64)
                    .await?;
            }
        }
        send.finish().anyerr()?;
        Ok(Default::default())
    }

    /// Executes `request` against a remote node as-is, without checking what
    /// is already present locally.
    pub fn execute_get(&self, conn: impl GetStreamPair, request: GetRequest) -> GetProgress {
        self.execute_get_with_opts(conn, request)
    }

    /// Currently identical to [`Remote::execute_get`]; takes no extra options.
    pub fn execute_get_with_opts(
        &self,
        conn: impl GetStreamPair,
        request: GetRequest,
    ) -> GetProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_get_sink(conn, request, sink).await.into();
            // If the receiver is gone nobody wants the result; ignore the error.
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Executes `request` over a stream pair obtained from `conn`, importing
    /// the received data into the local store and reporting progress to
    /// `progress`.
    ///
    /// Child hashes are looked up in the local copy of the root hash sequence,
    /// so when children are requested the root must be available locally —
    /// either already present or received as part of this request.
    ///
    /// # Errors
    ///
    /// Protocol / decoding errors from the get state machine, local store
    /// failures, and a root that is not a valid hash sequence
    /// (`GetError::BadRequest`).
    pub(crate) async fn execute_get_sink(
        &self,
        conn: impl GetStreamPair,
        request: GetRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let root = request.hash;
        let conn = conn.open_stream_pair().await.map_err(|e| {
            e!(
                GetError::LocalFailure,
                n0_error::anyerr!("failed to open stream pair: {e}")
            )
        })?;
        let connected =
            AtConnected::new(conn.t0, conn.recv, conn.send, request, Default::default());
        trace!("Getting header");
        // Receive the root (if requested) and find out whether children follow.
        let next_child = match connected
            .next()
            .await
            .map_err(|e| e!(GetError::ConnectedNext, e))?
        {
            ConnectedNext::StartRoot(at_start_root) => {
                let header = at_start_root.next();
                let end = get_blob_ranges_impl(header, root, store, &mut progress).await?;
                match end.next() {
                    EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                    EndBlobNext::Closing(at_closing) => Err(at_closing),
                }
            }
            ConnectedNext::StartChild(at_start_child) => Ok(at_start_child),
            ConnectedNext::Closing(at_closing) => Err(at_closing),
        };
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                let hash_seq = HashSeq::try_from(
                    store
                        .get_bytes(root)
                        .await
                        .map_err(|e| e!(GetError::LocalFailure, e.into()))?,
                )
                .map_err(|e| e!(GetError::BadRequest, e))?;
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    // Offset 0 is the root; child offsets start at 1.
                    let offset = at_start_child.offset() - 1;
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        // Past the end of the hash sequence: nothing more to receive.
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        let stats = at_closing
            .next()
            .await
            .map_err(|e| e!(GetError::AtClosingNext, e))?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }

    /// Executes a get-many request against a remote node.
    ///
    /// Nothing happens until the returned [`GetProgress`] is awaited or its
    /// stream is polled; the final result is sent after all progress items.
    pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        let tx2 = tx.clone();
        let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress);
        let this = self.clone();
        let fut = async move {
            let res = this.execute_get_many_sink(conn, request, sink).await.into();
            // If the receiver is gone nobody wants the result; ignore the error.
            tx2.send(res).await.ok();
        };
        GetProgress {
            rx,
            fut: Box::pin(fut),
        }
    }

    /// Executes a get-many request over `conn`, importing the received blobs
    /// into the local store and reporting progress to `progress`.
    ///
    /// Unlike [`Remote::execute_get_sink`] there is no root: offsets index
    /// directly into `request.hashes`.
    pub async fn execute_get_many_sink(
        &self,
        conn: Connection,
        request: GetManyRequest,
        mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
    ) -> GetResult<Stats> {
        let store = self.store();
        let hash_seq = request.hashes.iter().copied().collect::<HashSeq>();
        let next_child = crate::get::fsm::start_get_many(conn, request, Default::default()).await?;
        let at_closing = match next_child {
            Ok(at_start_child) => {
                let mut next_child = Ok(at_start_child);
                loop {
                    let at_start_child = match next_child {
                        Ok(at_start_child) => at_start_child,
                        Err(at_closing) => break at_closing,
                    };
                    let offset = at_start_child.offset();
                    let Some(hash) = hash_seq.get(offset as usize) else {
                        break at_start_child.finish();
                    };
                    trace!("getting child {offset} {}", hash.fmt_short());
                    let header = at_start_child.next(hash);
                    let end = get_blob_ranges_impl(header, hash, store, &mut progress).await?;
                    next_child = match end.next() {
                        EndBlobNext::MoreChildren(at_start_child) => Ok(at_start_child),
                        EndBlobNext::Closing(at_closing) => Err(at_closing),
                    }
                }
            }
            Err(at_closing) => at_closing,
        };
        let stats = at_closing
            .next()
            .await
            .map_err(|e| e!(GetError::AtClosingNext, e))?;
        trace!(?stats, "get hash seq done");
        Ok(stats)
    }
}
811
/// Errors that can occur while executing a request against a remote node.
#[allow(missing_docs)]
#[non_exhaustive]
#[stack_error(derive, add_meta)]
pub enum ExecuteError {
    #[error("Unable to open bidi stream")]
    Connection {
        #[error(std_err)]
        source: iroh::endpoint::ConnectionError,
    },
    #[error("Unable to read from the remote")]
    Read {
        #[error(std_err)]
        source: iroh::endpoint::ReadError,
    },
    #[error("Error sending the request")]
    Send {
        #[error(std_err)]
        source: crate::get::fsm::ConnectedNextError,
    },
    #[error("Unable to read size")]
    Size {
        #[error(std_err)]
        source: crate::get::fsm::AtBlobHeaderNextError,
    },
    #[error("Error while decoding the data")]
    Decode {
        #[error(std_err)]
        source: crate::get::fsm::DecodeError,
    },
    #[error("Internal error while reading the hash sequence")]
    ExportBao { source: api::ExportBaoError },
    #[error("Hash sequence has an invalid length")]
    InvalidHashSeq { source: AnyError },
    #[error("Internal error importing the data")]
    ImportBao { source: crate::api::RequestError },
    #[error("Error sending download progress - receiver closed")]
    SendDownloadProgress { source: irpc::channel::SendError },
    // Shares its display message with `ImportBao`; the source type distinguishes them.
    #[error("Internal error importing the data")]
    MpscSend {
        #[error(std_err)]
        source: tokio::sync::mpsc::error::SendError<BaoContentItem>,
    },
}
857
/// Something that can provide the bidirectional stream pair a get request runs on.
///
/// Implemented for an already-open [`StreamPair`] (returned unchanged) and for
/// an iroh [`Connection`] (which opens a new bidirectional stream).
pub trait GetStreamPair: Send + 'static {
    /// Consumes `self` and returns the stream pair to run the request on.
    fn open_stream_pair(
        self,
    ) -> impl Future<Output = io::Result<StreamPair<impl RecvStream, impl SendStream>>> + Send + 'static;
}
863
864impl<R: RecvStream + 'static, W: SendStream + 'static> GetStreamPair for StreamPair<R, W> {
865 async fn open_stream_pair(self) -> io::Result<StreamPair<impl RecvStream, impl SendStream>> {
866 Ok(self)
867 }
868}
869
870impl GetStreamPair for Connection {
871 async fn open_stream_pair(
872 self,
873 ) -> io::Result<StreamPair<impl crate::util::RecvStream, impl crate::util::SendStream>> {
874 let connection_id = self.stable_id() as u64;
875 let (send, recv) = self.open_bi().await?;
876 Ok(StreamPair::new(connection_id, recv, send))
877 }
878}
879
880fn get_buffer_size(size: NonZeroU64) -> usize {
881 (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize
882}
883
/// Receives one blob from `header` and imports it into `store` under `hash`.
///
/// For every received item, sends the state machine's `payload_bytes_read`
/// to `progress`. A zero-sized blob is accepted only if `hash` is the hash of
/// the empty blob. Returns the state machine positioned at the end of the blob.
///
/// # Errors
///
/// Header and decoding errors, failures of the local import, and progress
/// sink errors (reported as `GetError::LocalFailure`).
async fn get_blob_ranges_impl<R: RecvStream>(
    header: AtBlobHeader<R>,
    hash: Hash,
    store: &Store,
    mut progress: impl Sink<u64, Error = irpc::channel::SendError>,
) -> GetResult<AtEndBlob<R>> {
    let (mut content, size) = header
        .next()
        .await
        .map_err(|e| e!(GetError::AtBlobHeaderNext, e))?;
    let Some(size) = NonZeroU64::new(size) else {
        // A size of 0 is only valid for the empty blob.
        return if hash == Hash::EMPTY {
            let end = content.drain().await.map_err(|e| e!(GetError::Decode, e))?;
            Ok(end)
        } else {
            Err(e!(
                GetError::Decode,
                DecodeError::leaf_hash_mismatch(ChunkNum(0))
            ))
        };
    };
    let buffer_size = get_buffer_size(size);
    trace!(%size, %buffer_size, "get blob");
    let handle = store
        .import_bao(hash, size, buffer_size)
        .await
        .map_err(|e| e!(GetError::LocalFailure, e.into()))?;
    // Forward decoded items to the import (`write`) while concurrently waiting
    // for the import's outcome (`complete`); `try_join!` returns on the first error.
    let write = async move {
        GetResult::Ok(loop {
            match content.next().await {
                BlobContentNext::More((next, res)) => {
                    let item = res.map_err(|e| e!(GetError::Decode, e))?;
                    progress
                        .send(next.stats().payload_bytes_read)
                        .await
                        .map_err(|e| e!(GetError::LocalFailure, e.into()))?;
                    handle
                        .tx
                        .send(item)
                        .await
                        .map_err(|e| e!(GetError::IrpcSend, e))?;
                    content = next;
                }
                BlobContentNext::Done(end) => {
                    // Dropping the sender presumably tells the import that no
                    // more data is coming.
                    drop(handle.tx);
                    break end;
                }
            }
        })
    };
    let complete = async move {
        handle.rx.await.map_err(|e| {
            e!(
                GetError::LocalFailure,
                n0_error::anyerr!("error reading from import stream: {e}")
            )
        })
    };
    let (_, end) = tokio::try_join!(complete, write)?;
    Ok(end)
}
945
/// Reads hashes from a hash sequence blob on demand, caching the most recently
/// loaded leaf that contained a requested hash.
#[derive(Debug)]
pub(crate) struct LazyHashSeq {
    blobs: Blobs,
    /// Hash of the hash sequence blob itself.
    hash: Hash,
    /// The last loaded leaf that contained a requested hash, if any.
    current_chunk: Option<HashSeqChunk>,
}
952
/// A contiguous run of hashes decoded from one leaf of a hash sequence blob.
#[derive(Debug)]
pub(crate) struct HashSeqChunk {
    /// Offset of the leaf within the hash sequence blob, in bytes (the leaf's
    /// `offset`); divide by 32 for the index of the first hash.
    offset: u64,
    /// The hashes contained in the leaf.
    chunk: HashSeq,
}
960
961impl TryFrom<Leaf> for HashSeqChunk {
962 type Error = AnyError;
963
964 fn try_from(leaf: Leaf) -> Result<Self, Self::Error> {
965 let offset = leaf.offset;
966 let chunk = HashSeq::try_from(leaf.data)?;
967 Ok(Self { offset, chunk })
968 }
969}
970
971impl IntoIterator for HashSeqChunk {
972 type Item = Hash;
973 type IntoIter = HashSeqIter;
974
975 fn into_iter(self) -> Self::IntoIter {
976 self.chunk.into_iter()
977 }
978}
979
980impl HashSeqChunk {
981 pub fn base(&self) -> u64 {
982 self.offset / 32
983 }
984
985 #[allow(dead_code)]
986 fn get(&self, offset: u64) -> Option<Hash> {
987 let start = self.offset;
988 let end = start + self.chunk.len() as u64;
989 if offset >= start && offset < end {
990 let o = (offset - start) as usize;
991 self.chunk.get(o)
992 } else {
993 None
994 }
995 }
996}
997
998impl LazyHashSeq {
999 #[allow(dead_code)]
1000 pub fn new(blobs: Blobs, hash: Hash) -> Self {
1001 Self {
1002 blobs,
1003 hash,
1004 current_chunk: None,
1005 }
1006 }
1007
1008 #[allow(dead_code)]
1009 pub async fn get_from_offset(&mut self, offset: u64) -> Result<Option<Hash>> {
1010 if offset == 0 {
1011 Ok(Some(self.hash))
1012 } else {
1013 self.get(offset - 1).await
1014 }
1015 }
1016
1017 #[allow(dead_code)]
1018 pub async fn get(&mut self, child_offset: u64) -> Result<Option<Hash>> {
1019 if let Some(chunk) = &self.current_chunk {
1021 if let Some(hash) = chunk.get(child_offset) {
1022 return Ok(Some(hash));
1023 }
1024 }
1025 let leaf = self
1027 .blobs
1028 .export_chunk(self.hash, child_offset * 32)
1029 .await?;
1030 let hs = HashSeqChunk::try_from(leaf)?;
1032 Ok(hs.get(child_offset).inspect(|_hash| {
1033 self.current_chunk = Some(hs);
1034 }))
1035 }
1036}
1037
1038async fn write_push_request(
1039 request: PushRequest,
1040 stream: &mut impl SendStream,
1041) -> Result<PushRequest> {
1042 let mut request_bytes = Vec::new();
1043 request_bytes.push(RequestType::Push as u8);
1044 request_bytes.write_length_prefixed(&request).unwrap();
1045 stream.send_bytes(request_bytes.into()).await?;
1046 Ok(request)
1047}
1048
1049async fn write_observe_request(
1050 request: ObserveRequest,
1051 stream: &mut impl SendStream,
1052) -> io::Result<()> {
1053 let request = Request::Observe(request);
1054 let request_bytes = postcard::to_allocvec(&request)
1055 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
1056 stream.send_bytes(request_bytes.into()).await?;
1057 Ok(())
1058}
1059
/// Progress adapter used while pushing: accumulates the payload bytes written
/// and forwards the running total to `sender`.
struct StreamContext<S> {
    /// Total payload bytes written so far.
    payload_bytes_sent: u64,
    /// Sink that receives the running total after every payload write.
    sender: S,
}
1064
1065impl<S> WriteProgress for StreamContext<S>
1066where
1067 S: Sink<u64, Error = irpc::channel::SendError>,
1068{
1069 async fn notify_payload_write(
1070 &mut self,
1071 _index: u64,
1072 _offset: u64,
1073 len: usize,
1074 ) -> ClientResult {
1075 self.payload_bytes_sent += len as u64;
1076 self.sender
1077 .send(self.payload_bytes_sent)
1078 .await
1079 .map_err(|e| n0_error::e!(ProgressError::Internal, e.into()))?;
1080 Ok(())
1081 }
1082
1083 fn log_other_write(&mut self, _len: usize) {}
1084
1085 async fn send_transfer_started(&mut self, _index: u64, _hash: &Hash, _size: u64) {}
1086}
1087
#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use bao_tree::{ChunkNum, ChunkRanges};
    use testresult::TestResult;

    use crate::{
        api::blobs::Blobs,
        protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
        store::{
            fs::{
                tests::{test_data, INTERESTING_SIZES},
                FsStore,
            },
            mem::MemStore,
            util::tests::create_n0_bao,
        },
        tests::{add_test_hash_seq, add_test_hash_seq_incomplete},
    };

    /// A fully present raw blob: complete, all 4 bytes local, nothing missing.
    #[tokio::test]
    async fn test_local_info_raw() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let blobs = store.blobs();
        let tt = blobs.add_slice(b"test").temp_tag().await?;
        let hash = tt.hash();
        let info = store.remote().local(hash).await?;
        assert_eq!(info.bitfield.ranges, ChunkRanges::all());
        assert_eq!(info.local_bytes(), 4);
        assert!(info.is_complete());
        assert_eq!(
            info.missing(),
            GetRequest::new(hash, ChunkRangesSeq::empty())
        );
        Ok(())
    }

    /// A large hash sequence of which only root chunks 16..32 and the first
    /// chunk of every child are present locally.
    #[tokio::test]
    async fn test_local_info_hash_seq_large() -> TestResult<()> {
        let sizes = (0..1024 + 5).collect::<Vec<_>>();
        // Root chunks 16..32 hold the hashes at indices 32*16..32*32 (32 hashes
        // of 32 bytes per 1024-byte chunk). Those children are at most 1024
        // bytes, so their first chunk is all of their data.
        let relevant_sizes = sizes[32 * 16..32 * 32]
            .iter()
            .map(|x| *x as u64)
            .sum::<u64>();
        let td = tempfile::tempdir()?;
        let hash_seq_ranges = ChunkRanges::chunks(16..32);
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // i == 0 is the root hash sequence; other indices are children.
            let present = |i| {
                if i == 0 {
                    hash_seq_ranges.clone()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, hash_seq_ranges);
            assert!(!info.is_complete());
            // children's bytes plus the 16 present root chunks
            assert_eq!(info.local_bytes(), relevant_sizes + 16 * 1024);
        }

        Ok(())
    }

    /// Importing only chunk 0: blobs larger than one chunk report exactly
    /// chunk 0 as present, smaller ones are complete.
    async fn test_observe_partial(blobs: &Blobs) -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        for size in sizes {
            let data = test_data(size);
            let ranges = ChunkRanges::chunk(0);
            let (hash, bao) = create_n0_bao(&data, &ranges)?;
            blobs.import_bao_bytes(hash, ranges.clone(), bao).await?;
            let bitfield = blobs.observe(hash).await?;
            if size > 1024 {
                assert_eq!(bitfield.ranges, ranges);
            } else {
                assert_eq!(bitfield.ranges, ChunkRanges::all());
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_observe_partial_mem() -> TestResult<()> {
        let store = MemStore::new();
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_observe_partial_fs() -> TestResult<()> {
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path()).await?;
        test_observe_partial(store.blobs()).await?;
        Ok(())
    }

    /// Hash sequence with the root fully present and children absent,
    /// partially present (first chunk only), or fully present.
    #[tokio::test]
    async fn test_local_info_hash_seq() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let total_size = sizes.iter().map(|x| *x as u64).sum::<u64>();
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        {
            // Root present, children absent.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::empty()
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        // root
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                        ChunkRanges::all(),
                    ])
                )
            );
            store.tags().delete_all().await?;
        }
        {
            // Root present, first chunk of every child present.
            let present = |i| {
                if i == 0 {
                    ChunkRanges::all()
                } else {
                    ChunkRanges::from(..ChunkNum(1))
                }
            };
            let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
            let info = store.remote().local(content).await?;
            let first_chunk_size = sizes.into_iter().map(|x| x.min(1024) as u64).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + first_chunk_size);
            assert!(!info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(
                    content.hash,
                    ChunkRangesSeq::from_ranges([
                        // root
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::empty(),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                        ChunkRanges::chunks(1..),
                    ])
                )
            );
        }
        {
            // Everything present.
            let content = add_test_hash_seq(&store, sizes).await?;
            let info = store.remote().local(content).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), total_size + hash_seq_size);
            assert!(info.is_complete());
            assert_eq!(
                info.missing(),
                GetRequest::new(content.hash, ChunkRangesSeq::empty())
            );
        }
        Ok(())
    }

    /// Requests that cover only part of a hash sequence whose children have
    /// their first two chunks present.
    #[tokio::test]
    async fn test_local_info_complex_request() -> TestResult<()> {
        let sizes = INTERESTING_SIZES;
        let hash_seq_size = (sizes.len() as u64) * 32;
        let td = tempfile::tempdir()?;
        let store = FsStore::load(td.path().join("blobs.db")).await?;
        let present = |i| {
            if i == 0 {
                ChunkRanges::all()
            } else {
                ChunkRanges::chunks(..2)
            }
        };
        let content = add_test_hash_seq_incomplete(&store, sizes, present).await?;
        {
            // Root only.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size);
            assert!(info.is_complete());
        }
        {
            // Root plus the first child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(1)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Root plus the first two children.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::all())
                .next(ChunkRanges::all())
                .build(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes
                .into_iter()
                .take(2)
                .map(|x| 1024.min(x as u64))
                .sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        {
            // Open-ended request: chunk 0 of every child.
            let request: GetRequest = GetRequest::builder()
                .root(ChunkRanges::all())
                .next(ChunkRanges::chunk(0))
                .build_open(content.hash);
            let info = store.remote().local_for_request(request).await?;
            let expected_child_sizes = sizes.into_iter().map(|x| 1024.min(x as u64)).sum::<u64>();
            assert_eq!(info.bitfield.ranges, ChunkRanges::all());
            assert_eq!(info.local_bytes(), hash_seq_size + expected_child_sizes);
            assert!(info.is_complete());
        }
        Ok(())
    }
}