iroh_blobs/api/downloader.rs

//! API for downloads from multiple nodes.
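//!
//! The [`Downloader`] resolves content from multiple providers, found via a
//! pluggable [`ContentDiscovery`] strategy, and can split large requests into
//! parallel sub-requests via [`SplitStrategy`].
//!
//! A minimal sketch of a download (not compiled here; assumes a `Store`, an
//! `Endpoint`, and the `EndpointId` of a provider already exist):
//!
//! ```ignore
//! let downloader = Downloader::new(&store, &endpoint);
//! downloader.download(hash, vec![provider_id]).await?;
//! ```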
use std::{
    collections::{HashMap, HashSet},
    fmt::Debug,
    future::{Future, IntoFuture},
    sync::Arc,
};

use genawaiter::sync::Gen;
use iroh::{Endpoint, EndpointId};
use irpc::{channel::mpsc, rpc_requests};
use n0_error::{anyerr, Result};
use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt};
use rand::seq::SliceRandom;
use serde::{de::Error, Deserialize, Serialize};
use tokio::task::JoinSet;
use tracing::instrument::Instrument;

use super::Store;
use crate::{
    protocol::{GetManyRequest, GetRequest},
    util::{
        connection_pool::ConnectionPool,
        sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink},
    },
    BlobFormat, Hash, HashAndFormat,
};

/// Client for downloading blobs from multiple providers.
///
/// Cheap to clone; all clones share a single background actor.
#[derive(Debug, Clone)]
pub struct Downloader {
    client: irpc::Client<SwarmProtocol>,
}

#[rpc_requests(message = SwarmMsg, alias = "Msg")]
#[derive(Debug, Serialize, Deserialize)]
enum SwarmProtocol {
    #[rpc(tx = mpsc::Sender<DownloadProgressItem>)]
    Download(DownloadRequest),
}

struct DownloaderActor {
    store: Store,
    pool: ConnectionPool,
    tasks: JoinSet<()>,
    running: HashSet<tokio::task::Id>,
}

/// Progress items emitted while a download is running.
#[derive(Debug, Serialize, Deserialize)]
pub enum DownloadProgressItem {
    /// A fatal error. Not serializable, so it only surfaces in-process.
    #[serde(skip)]
    Error(n0_error::AnyError),
    /// The given provider is about to be tried for the request.
    TryProvider {
        id: EndpointId,
        request: Arc<GetRequest>,
    },
    /// The provider could not be connected to, or failed to serve the request.
    ProviderFailed {
        id: EndpointId,
        request: Arc<GetRequest>,
    },
    /// One (sub-)request finished successfully.
    PartComplete {
        request: Arc<GetRequest>,
    },
    /// Cumulative number of bytes for this download so far.
    Progress(u64),
    /// A sub-request of a split download failed.
    DownloadError,
}

impl DownloaderActor {
    fn new(store: Store, endpoint: Endpoint) -> Self {
        Self {
            store,
            pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()),
            tasks: JoinSet::new(),
            running: HashSet::new(),
        }
    }

    async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<SwarmMsg>) {
        while let Some(msg) = rx.recv().await {
            match msg {
                SwarmMsg::Download(request) => {
                    self.spawn(handle_download(
                        self.store.clone(),
                        self.pool.clone(),
                        request,
                    ));
                }
            }
        }
    }

    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
        let span = tracing::Span::current();
        let id = self.tasks.spawn(fut.instrument(span)).id();
        self.running.insert(id);
    }
}

async fn handle_download(store: Store, pool: ConnectionPool, msg: DownloadMsg) {
    let DownloadMsg { inner, mut tx, .. } = msg;
    if let Err(cause) = handle_download_impl(store, pool, inner, &mut tx).await {
        tx.send(DownloadProgressItem::Error(cause)).await.ok();
    }
}

async fn handle_download_impl(
    store: Store,
    pool: ConnectionPool,
    request: DownloadRequest,
    tx: &mut mpsc::Sender<DownloadProgressItem>,
) -> Result<()> {
    match request.strategy {
        SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?,
        SplitStrategy::None => match request.request {
            FiniteRequest::Get(get) => {
                let sink = IrpcSenderRefSink(tx);
                execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?;
            }
            FiniteRequest::GetMany(_) => {
                handle_download_split_impl(store, pool, request, tx).await?
            }
        },
    }
    Ok(())
}

async fn handle_download_split_impl(
    store: Store,
    pool: ConnectionPool,
    request: DownloadRequest,
    tx: &mut mpsc::Sender<DownloadProgressItem>,
) -> Result<()> {
    let providers = request.providers;
    let requests = split_request(&request.request, &providers, &pool, &store, Drain).await?;
    let (progress_tx, progress_rx) = tokio::sync::mpsc::channel(32);
    let mut futs = stream::iter(requests.into_iter().enumerate())
        .map(|(id, request)| {
            let pool = pool.clone();
            let providers = providers.clone();
            let store = store.clone();
            let progress_tx = progress_tx.clone();
            async move {
                let hash = request.hash;
                let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgressItem)>(16);
                progress_tx.send(rx).await.ok();
                let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x));
                let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await;
                (hash, res)
            }
        })
        .buffered_unordered(32);
    // Merge the per-request progress streams into a single stream, converting
    // each sub-request's byte count into a global total across all sub-requests.
    let mut progress_stream = {
        let mut offsets = HashMap::new();
        let mut total = 0;
        into_stream(progress_rx)
            .flat_map(into_stream)
            .map(move |(id, item)| match item {
                DownloadProgressItem::Progress(offset) => {
                    total += offset;
                    if let Some(prev) = offsets.insert(id, offset) {
                        total -= prev;
                    }
                    DownloadProgressItem::Progress(total)
                }
                x => x,
            })
    };
    loop {
        tokio::select! {
            Some(item) = progress_stream.next() => {
                tx.send(item).await?;
            },
            res = futs.next() => {
                match res {
                    Some((_hash, Ok(()))) => {}
                    Some((_hash, Err(_e))) => {
                        tx.send(DownloadProgressItem::DownloadError).await?;
                    }
                    None => break,
                }
            }
            _ = tx.closed() => {
                // The receiver side of the progress channel was dropped, so
                // nobody is listening anymore; stop processing.
                break;
            }
        }
    }
    Ok(())
}

/// Convert a tokio mpsc receiver into a stream.
fn into_stream<T>(mut recv: tokio::sync::mpsc::Receiver<T>) -> impl Stream<Item = T> {
    Gen::new(|co| async move {
        while let Some(item) = recv.recv().await {
            co.yield_(item).await;
        }
    })
}

/// A finite download request: either a [`GetRequest`] or a [`GetManyRequest`].
#[derive(Debug, Serialize, Deserialize, derive_more::From)]
pub enum FiniteRequest {
    Get(GetRequest),
    GetMany(GetManyRequest),
}

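/// Trait for requests that can be converted into a [`FiniteRequest`] and thus
/// handed to the downloader.
///
/// A minimal sketch of the accepted types (not compiled here; assumes `hash`,
/// `hashes`, and `providers` already exist):
///
/// ```ignore
/// downloader.download(hash, providers.clone());                  // Hash
/// downloader.download(GetRequest::all(hash), providers.clone()); // GetRequest
/// downloader.download(hashes, providers.clone());                // iterator of hashes
/// ```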
pub trait SupportedRequest {
    fn into_request(self) -> FiniteRequest;
}

impl<I: Into<Hash>, T: IntoIterator<Item = I>> SupportedRequest for T {
    fn into_request(self) -> FiniteRequest {
        let hashes = self.into_iter().map(Into::into).collect::<GetManyRequest>();
        FiniteRequest::GetMany(hashes)
    }
}

impl SupportedRequest for GetRequest {
    fn into_request(self) -> FiniteRequest {
        self.into()
    }
}

impl SupportedRequest for GetManyRequest {
    fn into_request(self) -> FiniteRequest {
        self.into()
    }
}

impl SupportedRequest for Hash {
    fn into_request(self) -> FiniteRequest {
        GetRequest::blob(self).into()
    }
}

impl SupportedRequest for HashAndFormat {
    fn into_request(self) -> FiniteRequest {
        (match self.format {
            BlobFormat::Raw => GetRequest::blob(self.hash),
            BlobFormat::HashSeq => GetRequest::all(self.hash),
        })
        .into()
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub struct AddProviderRequest {
    pub hash: Hash,
    pub providers: Vec<EndpointId>,
}

/// A download request: what to fetch, where to look for it, and whether to
/// split it into parallel sub-requests.
#[derive(Debug)]
pub struct DownloadRequest {
    pub request: FiniteRequest,
    pub providers: Arc<dyn ContentDiscovery>,
    pub strategy: SplitStrategy,
}

impl DownloadRequest {
    pub fn new(
        request: impl SupportedRequest,
        providers: impl ContentDiscovery,
        strategy: SplitStrategy,
    ) -> Self {
        Self {
            request: request.into_request(),
            providers: Arc::new(providers),
            strategy,
        }
    }
}

/// Strategy for executing a request that covers multiple blobs.
#[derive(Debug, Serialize, Deserialize)]
pub enum SplitStrategy {
    /// Execute the request as a single unit.
    None,
    /// Split the request into one sub-request per blob and execute them
    /// concurrently.
    Split,
}

// `DownloadRequest` holds an `Arc<dyn ContentDiscovery>`, so it cannot cross
// process boundaries; serialization and deserialization always fail.
impl Serialize for DownloadRequest {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        Err(serde::ser::Error::custom(
            "cannot serialize DownloadRequest",
        ))
    }
}

impl<'de> Deserialize<'de> for DownloadRequest {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        Err(D::Error::custom("cannot deserialize DownloadRequest"))
    }
}

pub type DownloadOptions = DownloadRequest;

/// Handle to an in-progress download.
///
/// Await it to wait for completion, or call [`Self::stream`] to observe
/// individual progress items.
pub struct DownloadProgress {
    fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>,
}

impl DownloadProgress {
    fn new(fut: future::Boxed<irpc::Result<mpsc::Receiver<DownloadProgressItem>>>) -> Self {
        Self { fut }
    }

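    /// Turn this handle into a stream of [`DownloadProgressItem`]s.
    ///
    /// A minimal sketch of consuming it (not compiled here; assumes `progress`
    /// was returned by [`Downloader::download`]):
    ///
    /// ```ignore
    /// let mut stream = progress.stream().await?;
    /// while let Some(item) = stream.next().await {
    ///     if let DownloadProgressItem::Progress(bytes) = item {
    ///         println!("{bytes} bytes so far");
    ///     }
    /// }
    /// ```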
    pub async fn stream(self) -> irpc::Result<impl Stream<Item = DownloadProgressItem> + Unpin> {
        let rx = self.fut.await?;
        Ok(Box::pin(rx.into_stream().map(|item| match item {
            Ok(item) => item,
            Err(e) => DownloadProgressItem::Error(e.into()),
        })))
    }

    async fn complete(self) -> Result<()> {
        let rx = self.fut.await?;
        let stream = rx.into_stream();
        tokio::pin!(stream);
        while let Some(item) = stream.next().await {
            match item? {
                DownloadProgressItem::Error(e) => Err(e)?,
                DownloadProgressItem::DownloadError => {
                    n0_error::bail_any!("Download error");
                }
                _ => {}
            }
        }
        Ok(())
    }
}

impl IntoFuture for DownloadProgress {
    type Output = Result<()>;
    type IntoFuture = future::Boxed<Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.complete())
    }
}

impl Downloader {
    /// Create a new downloader that stores data in `store` and dials
    /// providers via `endpoint`.
    pub fn new(store: &Store, endpoint: &Endpoint) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel::<SwarmMsg>(32);
        let actor = DownloaderActor::new(store.clone(), endpoint.clone());
        tokio::spawn(actor.run(rx));
        Self { client: tx.into() }
    }

    /// Download `request` from `providers` as a single, unsplit request.
    pub fn download(
        &self,
        request: impl SupportedRequest,
        providers: impl ContentDiscovery,
    ) -> DownloadProgress {
        let request = request.into_request();
        let providers = Arc::new(providers);
        self.download_with_opts(DownloadOptions {
            request,
            providers,
            strategy: SplitStrategy::None,
        })
    }

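    /// Download with explicit [`DownloadOptions`].
    ///
    /// A minimal sketch of a split download (not compiled here; assumes
    /// `store`, `endpoint`, `request`, and the provider ids already exist):
    ///
    /// ```ignore
    /// let downloader = Downloader::new(&store, &endpoint);
    /// let opts = DownloadOptions::new(request, [node1_id, node2_id], SplitStrategy::Split);
    /// downloader.download_with_opts(opts).await?;
    /// ```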
    pub fn download_with_opts(&self, options: DownloadOptions) -> DownloadProgress {
        let fut = self.client.server_streaming(options, 32);
        DownloadProgress::new(Box::pin(fut))
    }
}

/// Split a request into multiple requests that can be run in parallel.
///
/// For a [`FiniteRequest::Get`], this first downloads the root blob (the hash
/// sequence) so the number of children is known, then yields one request per
/// child with a non-empty range. For a [`FiniteRequest::GetMany`], it yields
/// one request per hash.
async fn split_request<'a>(
    request: &'a FiniteRequest,
    providers: &Arc<dyn ContentDiscovery>,
    pool: &ConnectionPool,
    store: &Store,
    progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> Result<Box<dyn Iterator<Item = GetRequest> + Send + 'a>> {
    Ok(match request {
        FiniteRequest::Get(req) => {
            let Some(_first) = req.ranges.iter_infinite().next() else {
                return Ok(Box::new(std::iter::empty()));
            };
            let first = GetRequest::blob(req.hash);
            execute_get(pool, Arc::new(first), providers, store, progress).await?;
            let size = store.observe(req.hash).await?.size();
            n0_error::ensure_any!(size % 32 == 0, "Size is not a multiple of 32");
            let n = size / 32;
            Box::new(
                req.ranges
                    .iter_infinite()
                    .take(n as usize + 1)
                    .enumerate()
                    .filter_map(|(i, ranges)| {
                        if i != 0 && !ranges.is_empty() {
                            Some(
                                GetRequest::builder()
                                    .offset(i as u64, ranges.clone())
                                    .build(req.hash),
                            )
                        } else {
                            None
                        }
                    }),
            )
        }
        FiniteRequest::GetMany(req) => Box::new(
            req.hashes
                .iter()
                .enumerate()
                .map(|(i, hash)| GetRequest::blob_ranges(*hash, req.ranges[i as u64].clone())),
        ),
    })
}

/// Execute a get request sequentially for multiple providers.
///
/// This tries each provider in order until it finds one that can fulfill the
/// request. When trying a new provider, it takes the progress from previous
/// providers into account, so if the first provider delivered the first 10%
/// of the data, the next provider is only asked for the remaining 90%.
///
/// This is fully sequential, so there is only ever one request in flight at a
/// time.
///
/// If the request is not complete after trying all providers, an error is
/// returned. If the provider stream never ends, the download is retried
/// indefinitely.
async fn execute_get(
    pool: &ConnectionPool,
    request: Arc<GetRequest>,
    providers: &Arc<dyn ContentDiscovery>,
    store: &Store,
    mut progress: impl Sink<DownloadProgressItem, Error = irpc::channel::SendError>,
) -> Result<()> {
    let remote = store.remote();
    let mut providers = providers.find_providers(request.content());
    while let Some(provider) = providers.next().await {
        progress
            .send(DownloadProgressItem::TryProvider {
                id: provider,
                request: request.clone(),
            })
            .await?;
        let conn = pool.get_or_connect(provider);
        let local = remote.local_for_request(request.clone()).await?;
        if local.is_complete() {
            return Ok(());
        }
        let local_bytes = local.local_bytes();
        let Ok(conn) = conn.await else {
            progress
                .send(DownloadProgressItem::ProviderFailed {
                    id: provider,
                    request: request.clone(),
                })
                .await?;
            continue;
        };
        match remote
            .execute_get_sink(
                conn.clone(),
                local.missing(),
                (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)),
            )
            .await
        {
            Ok(_stats) => {
                progress
                    .send(DownloadProgressItem::PartComplete {
                        request: request.clone(),
                    })
                    .await?;
                return Ok(());
            }
            Err(_cause) => {
                progress
                    .send(DownloadProgressItem::ProviderFailed {
                        id: provider,
                        request: request.clone(),
                    })
                    .await?;
                continue;
            }
        }
    }
    Err(anyerr!("Unable to download {}", request.hash))
}

/// Trait for pluggable content discovery strategies.
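///
/// A minimal sketch of a custom strategy (the `StaticDiscovery` type is
/// hypothetical, not part of this crate):
///
/// ```ignore
/// #[derive(Debug)]
/// struct StaticDiscovery(HashMap<Hash, Vec<EndpointId>>);
///
/// impl ContentDiscovery for StaticDiscovery {
///     fn find_providers(&self, content: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
///         // Look up the static provider list for this hash; empty if unknown.
///         let providers = self.0.get(&content.hash).cloned().unwrap_or_default();
///         n0_future::stream::iter(providers).boxed()
///     }
/// }
/// ```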
pub trait ContentDiscovery: Debug + Send + Sync + 'static {
    fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed<EndpointId>;
}

// Any cloneable collection of endpoint ids is a content-independent discovery
// strategy that always yields the same providers, in order.
impl<C, I> ContentDiscovery for C
where
    C: Debug + Clone + IntoIterator<Item = I> + Send + Sync + 'static,
    C::IntoIter: Send + Sync + 'static,
    I: Into<EndpointId> + Send + Sync + 'static,
{
    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
        let providers = self.clone();
        n0_future::stream::iter(providers.into_iter().map(Into::into)).boxed()
    }
}

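/// Content discovery that returns a fixed set of nodes in random order.
///
/// A minimal sketch of using it (not compiled here; assumes `a` and `b` are
/// known `EndpointId`s):
///
/// ```ignore
/// let providers = Shuffled::new(vec![a, b]);
/// downloader.download(hash, providers).await?;
/// ```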
#[derive(derive_more::Debug)]
pub struct Shuffled {
    nodes: Vec<EndpointId>,
}

impl Shuffled {
    pub fn new(nodes: Vec<EndpointId>) -> Self {
        Self { nodes }
    }
}

impl ContentDiscovery for Shuffled {
    fn find_providers(&self, _: HashAndFormat) -> n0_future::stream::Boxed<EndpointId> {
        let mut nodes = self.nodes.clone();
        nodes.shuffle(&mut rand::rng());
        n0_future::stream::iter(nodes).boxed()
    }
}

#[cfg(test)]
#[cfg(feature = "fs-store")]
mod tests {
    use std::ops::Deref;

    use bao_tree::ChunkRanges;
    use n0_future::StreamExt;
    use testresult::TestResult;

    use crate::{
        api::{
            blobs::AddBytesOptions,
            downloader::{DownloadOptions, Downloader, Shuffled, SplitStrategy},
        },
        hashseq::HashSeq,
        protocol::{GetManyRequest, GetRequest},
        tests::node_test_setup_fs,
    };

    #[tokio::test]
    #[ignore = "todo"]
    async fn downloader_get_many_smoke() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice("hello world").await?;
        let tt2 = store2.add_slice("hello world 2").await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetManyRequest::builder()
            .hash(tt1.hash, ChunkRanges::all())
            .hash(tt2.hash, ChunkRanges::all())
            .build();
        let mut progress = swarm
            .download(request, Shuffled::new(vec![node1_id, node2_id]))
            .stream()
            .await?;
        while progress.next().await.is_some() {}
        assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
        assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
        Ok(())
    }

    #[tokio::test]
    async fn downloader_get_smoke() -> TestResult<()> {
        // tracing_subscriber::fmt::try_init().ok();
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10_000_000]).await?;
        let tt2 = store2.add_slice(vec![2; 10_000_000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::builder()
            .root(ChunkRanges::all())
            .next(ChunkRanges::all())
            .next(ChunkRanges::all())
            .build(root.hash);
        if true {
            let mut progress = swarm
                .download_with_opts(DownloadOptions::new(
                    request,
                    [node1_id, node2_id],
                    SplitStrategy::Split,
                ))
                .stream()
                .await?;
            while progress.next().await.is_some() {}
        }
        if false {
            let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
            let remote = store3.remote();
            let _rh = remote
                .execute_get(
                    conn.clone(),
                    GetRequest::builder()
                        .root(ChunkRanges::all())
                        .build(root.hash),
                )
                .await?;
            let h1 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(0, ChunkRanges::all())
                    .build(root.hash),
            );
            let h2 = remote.execute_get(
                conn.clone(),
                GetRequest::builder()
                    .child(1, ChunkRanges::all())
                    .build(root.hash),
            );
            h1.await?;
            h2.await?;
        }
        Ok(())
    }

    #[tokio::test]
    async fn downloader_get_all() -> TestResult<()> {
        let testdir = tempfile::tempdir()?;
        let (r1, store1, _, _) = node_test_setup_fs(testdir.path().join("a")).await?;
        let (r2, store2, _, _) = node_test_setup_fs(testdir.path().join("b")).await?;
        let (r3, store3, _, sp3) = node_test_setup_fs(testdir.path().join("c")).await?;
        let tt1 = store1.add_slice(vec![1; 10_000_000]).await?;
        let tt2 = store2.add_slice(vec![2; 10_000_000]).await?;
        let hs = [tt1.hash, tt2.hash].into_iter().collect::<HashSeq>();
        let root = store1
            .add_bytes_with_opts(AddBytesOptions {
                data: hs.clone().into(),
                format: crate::BlobFormat::HashSeq,
            })
            .await?;
        let node1_addr = r1.endpoint().addr();
        let node1_id = node1_addr.id;
        let node2_addr = r2.endpoint().addr();
        let node2_id = node2_addr.id;
        let swarm = Downloader::new(&store3, r3.endpoint());
        sp3.add_endpoint_info(node1_addr.clone());
        sp3.add_endpoint_info(node2_addr.clone());
        let request = GetRequest::all(root.hash);
        let mut progress = swarm
            .download_with_opts(DownloadOptions::new(
                request,
                [node1_id, node2_id],
                SplitStrategy::Split,
            ))
            .stream()
            .await?;
        while progress.next().await.is_some() {}
        Ok(())
    }
}