iroh_blobs/
provider.rs

1//! The low level server side API
2//!
3//! Note that while using this API directly is fine, the standard way
4//! to provide data is to just register a [`crate::BlobsProtocol`] protocol
5//! handler with an [`iroh::Endpoint`](iroh::protocol::Router).
6use std::{fmt::Debug, future::Future, io, time::Duration};
7
8use anyhow::Result;
9use bao_tree::ChunkRanges;
10use iroh::endpoint::{self, ConnectionError, VarInt};
11use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
12use n0_future::{time::Instant, StreamExt};
13use serde::{Deserialize, Serialize};
14use snafu::Snafu;
15use tokio::select;
16use tracing::{debug, debug_span, Instrument};
17
18use crate::{
19    api::{
20        blobs::{Bitfield, WriteProgress},
21        ExportBaoError, ExportBaoResult, RequestError, Store,
22    },
23    hashseq::HashSeq,
24    protocol::{
25        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
26    },
27    provider::events::{
28        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
29        RequestTracker,
30    },
31    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
32    Hash,
33};
34pub mod events;
35use events::EventSender;
36
/// Default receive stream type: an iroh QUIC receive stream.
type DefaultReader = iroh::endpoint::RecvStream;
/// Default send stream type: an iroh QUIC send stream.
type DefaultWriter = iroh::endpoint::SendStream;
39
/// Statistics about a successful or failed transfer.
///
/// Produced by [`ReaderContext::stats`] and [`WriterContext::stats`] and
/// attached to transfer-completed/aborted events.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// In most cases this is just the request, for push requests this is
    /// request, size header and hash pairs.
    pub other_bytes_read: u64,
    /// Total duration from reading the request to transfer completed.
    pub duration: Duration,
}
57
/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Creation time of the pair; used to compute the total transfer duration.
    t0: Instant,
    /// Stable id of the connection this stream pair belongs to.
    connection_id: u64,
    /// Receive half of the bidirectional stream.
    reader: R,
    /// Send half of the bidirectional stream.
    writer: W,
    /// Number of non-payload bytes read so far (e.g. the request itself).
    other_bytes_read: u64,
    /// Sink for provider events (connect, request, progress).
    events: EventSender,
}
68
69impl StreamPair {
70    pub async fn accept(
71        conn: &endpoint::Connection,
72        events: EventSender,
73    ) -> Result<Self, ConnectionError> {
74        let (writer, reader) = conn.accept_bi().await?;
75        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
76    }
77}
78
79impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
80    pub fn stream_id(&self) -> u64 {
81        self.reader.id()
82    }
83
84    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
85        Self {
86            t0: Instant::now(),
87            connection_id,
88            reader,
89            writer,
90            other_bytes_read: 0,
91            events,
92        }
93    }
94
95    /// Read the request.
96    ///
97    /// Will fail if there is an error while reading, or if no valid request is sent.
98    ///
99    /// This will read exactly the number of bytes needed for the request, and
100    /// leave the rest of the stream for the caller to read.
101    ///
102    /// It is up to the caller do decide if there should be more data.
103    pub async fn read_request(&mut self) -> Result<Request> {
104        let (res, size) = Request::read_async(&mut self.reader).await?;
105        self.other_bytes_read += size as u64;
106        Ok(res)
107    }
108
109    /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
110    pub async fn into_writer(
111        mut self,
112        tracker: RequestTracker,
113    ) -> Result<ProgressWriter<W>, io::Error> {
114        self.reader.expect_eof().await?;
115        drop(self.reader);
116        Ok(ProgressWriter::new(
117            self.writer,
118            WriterContext {
119                t0: self.t0,
120                other_bytes_read: self.other_bytes_read,
121                payload_bytes_written: 0,
122                other_bytes_written: 0,
123                tracker,
124            },
125        ))
126    }
127
128    pub async fn into_reader(
129        mut self,
130        tracker: RequestTracker,
131    ) -> Result<ProgressReader<R>, io::Error> {
132        self.writer.sync().await?;
133        drop(self.writer);
134        Ok(ProgressReader {
135            inner: self.reader,
136            context: ReaderContext {
137                t0: self.t0,
138                other_bytes_read: self.other_bytes_read,
139                tracker,
140            },
141        })
142    }
143
144    pub async fn get_request(
145        &self,
146        f: impl FnOnce() -> GetRequest,
147    ) -> Result<RequestTracker, ProgressError> {
148        self.events
149            .request(f, self.connection_id, self.reader.id())
150            .await
151    }
152
153    pub async fn get_many_request(
154        &self,
155        f: impl FnOnce() -> GetManyRequest,
156    ) -> Result<RequestTracker, ProgressError> {
157        self.events
158            .request(f, self.connection_id, self.reader.id())
159            .await
160    }
161
162    pub async fn push_request(
163        &self,
164        f: impl FnOnce() -> PushRequest,
165    ) -> Result<RequestTracker, ProgressError> {
166        self.events
167            .request(f, self.connection_id, self.reader.id())
168            .await
169    }
170
171    pub async fn observe_request(
172        &self,
173        f: impl FnOnce() -> ObserveRequest,
174    ) -> Result<RequestTracker, ProgressError> {
175        self.events
176            .request(f, self.connection_id, self.reader.id())
177            .await
178    }
179
180    pub fn stats(&self) -> TransferStats {
181        TransferStats {
182            payload_bytes_sent: 0,
183            other_bytes_sent: 0,
184            other_bytes_read: self.other_bytes_read,
185            duration: self.t0.elapsed(),
186        }
187    }
188}
189
/// Accounting context for the read side of a request (push requests).
#[derive(Debug)]
struct ReaderContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// Progress tracking for the request
    tracker: RequestTracker,
}
199
200impl ReaderContext {
201    fn stats(&self) -> TransferStats {
202        TransferStats {
203            payload_bytes_sent: 0,
204            other_bytes_sent: 0,
205            other_bytes_read: self.other_bytes_read,
206            duration: self.t0.elapsed(),
207        }
208    }
209}
210
/// Accounting context for the write side of a request (get/observe requests).
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// The number of payload bytes written to the stream
    payload_bytes_written: u64,
    /// The number of bytes written that are not part of the payload
    other_bytes_written: u64,
    /// Way to report progress
    tracker: RequestTracker,
}
224
225impl WriterContext {
226    fn stats(&self) -> TransferStats {
227        TransferStats {
228            payload_bytes_sent: self.payload_bytes_written,
229            other_bytes_sent: self.other_bytes_written,
230            other_bytes_read: self.other_bytes_read,
231            duration: self.t0.elapsed(),
232        }
233    }
234}
235
236impl WriteProgress for WriterContext {
237    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
238        let len = len as u64;
239        let end_offset = offset + len;
240        self.payload_bytes_written += len;
241        self.tracker.transfer_progress(len, end_offset).await
242    }
243
244    fn log_other_write(&mut self, len: usize) {
245        self.other_bytes_written += len as u64;
246    }
247
248    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
249        self.tracker.transfer_started(index, hash, size).await.ok();
250    }
251}
252
/// Wrapper for a [`SendStream`] with additional per request information.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream to write to
    pub inner: W,
    /// Accounting and progress reporting for this request
    pub(crate) context: WriterContext,
}
260
261impl<W: SendStream> ProgressWriter<W> {
262    fn new(inner: W, context: WriterContext) -> Self {
263        Self { inner, context }
264    }
265
266    async fn transfer_aborted(&self) {
267        self.context
268            .tracker
269            .transfer_aborted(|| Box::new(self.context.stats()))
270            .await
271            .ok();
272    }
273
274    async fn transfer_completed(&self) {
275        self.context
276            .tracker
277            .transfer_completed(|| Box::new(self.context.stats()))
278            .await
279            .ok();
280    }
281}
282
283/// Handle a single connection.
284pub async fn handle_connection(
285    connection: endpoint::Connection,
286    store: Store,
287    progress: EventSender,
288) {
289    let connection_id = connection.stable_id() as u64;
290    let span = debug_span!("connection", connection_id);
291    async move {
292        if let Err(cause) = progress
293            .client_connected(|| ClientConnected {
294                connection_id,
295                endpoint_id: connection.remote_id().ok(),
296            })
297            .await
298        {
299            connection.close(cause.code(), cause.reason());
300            debug!("closing connection: {cause}");
301            return;
302        }
303        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
304            let span = debug_span!("stream", stream_id = %pair.stream_id());
305            let store = store.clone();
306            n0_future::task::spawn(handle_stream(pair, store).instrument(span));
307        }
308        progress
309            .connection_closed(|| ConnectionClosed { connection_id })
310            .await
311            .ok();
312    }
313    .instrument(span)
314    .await
315}
316
/// Describes how to handle errors for a stream.
pub trait ErrorHandler {
    /// The writer half this handler can reset.
    type W: AsyncStreamWriter;
    /// The reader half this handler can stop.
    type R: AsyncStreamReader;
    /// Stop the read side with the given error code.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Reset the write side with the given error code.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
324
325async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
326    pair: &mut StreamPair<R, W>,
327    r: Result<T, E>,
328) -> Result<T, E> {
329    match r {
330        Ok(x) => Ok(x),
331        Err(e) => {
332            pair.writer.reset(e.code()).ok();
333            Err(e)
334        }
335    }
336}
337async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
338    writer: &mut ProgressWriter<W>,
339    r: Result<T, E>,
340) -> Result<T, E> {
341    match r {
342        Ok(x) => {
343            writer.transfer_completed().await;
344            Ok(x)
345        }
346        Err(e) => {
347            writer.inner.reset(e.code()).ok();
348            writer.transfer_aborted().await;
349            Err(e)
350        }
351    }
352}
353async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
354    reader: &mut ProgressReader<R>,
355    r: Result<T, E>,
356) -> Result<T, E> {
357    match r {
358        Ok(x) => {
359            reader.transfer_completed().await;
360            Ok(x)
361        }
362        Err(e) => {
363            reader.inner.stop(e.code()).ok();
364            reader.transfer_aborted().await;
365            Err(e)
366        }
367    }
368}
369
370pub async fn handle_stream<R: RecvStream, W: SendStream>(
371    mut pair: StreamPair<R, W>,
372    store: Store,
373) -> anyhow::Result<()> {
374    let request = pair.read_request().await?;
375    match request {
376        Request::Get(request) => handle_get(pair, store, request).await?,
377        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
378        Request::Observe(request) => handle_observe(pair, store, request).await?,
379        Request::Push(request) => handle_push(pair, store, request).await?,
380        _ => {}
381    }
382    Ok(())
383}
384
/// Errors that can occur while handling a get request.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
    /// An [`ExportBaoError`] from the store.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },
    /// The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,
    /// A child offset did not fit into `usize`.
    InvalidOffset,
}
395
396impl HasErrorCode for HandleGetError {
397    fn code(&self) -> VarInt {
398        match self {
399            HandleGetError::ExportBao {
400                source: ExportBaoError::ClientError { source, .. },
401            } => source.code(),
402            HandleGetError::InvalidHashSeq => ERR_INTERNAL,
403            HandleGetError::InvalidOffset => ERR_INTERNAL,
404            _ => ERR_INTERNAL,
405        }
406    }
407}
408
409/// Handle a single get request.
410///
411/// Requires a database, the request, and a writer.
412async fn handle_get_impl<W: SendStream>(
413    store: Store,
414    request: GetRequest,
415    writer: &mut ProgressWriter<W>,
416) -> Result<(), HandleGetError> {
417    let hash = request.hash;
418    debug!(%hash, "get received request");
419    let mut hash_seq = None;
420    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
421        if offset == 0 {
422            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
423        } else {
424            // todo: this assumes that 1. the hashseq is complete and 2. it is
425            // small enough to fit in memory.
426            //
427            // This should really read the hashseq from the store in chunks,
428            // only where needed, so we can deal with holes and large hashseqs.
429            let hash_seq = match &hash_seq {
430                Some(b) => b,
431                None => {
432                    let bytes = store.get_bytes(hash).await?;
433                    let hs =
434                        HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
435                    hash_seq = Some(hs);
436                    hash_seq.as_ref().unwrap()
437                }
438            };
439            let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
440            let Some(hash) = hash_seq.get(o) else {
441                break;
442            };
443            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
444        }
445    }
446    writer
447        .inner
448        .sync()
449        .await
450        .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
451
452    Ok(())
453}
454
455pub async fn handle_get<R: RecvStream, W: SendStream>(
456    mut pair: StreamPair<R, W>,
457    store: Store,
458    request: GetRequest,
459) -> anyhow::Result<()> {
460    let res = pair.get_request(|| request.clone()).await;
461    let tracker = handle_read_request_result(&mut pair, res).await?;
462    let mut writer = pair.into_writer(tracker).await?;
463    let res = handle_get_impl(store, request, &mut writer).await;
464    handle_write_result(&mut writer, res).await?;
465    Ok(())
466}
467
/// Errors that can occur while handling a get many request.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
    /// An [`ExportBaoError`] from the store.
    #[snafu(transparent)]
    ExportBao { source: ExportBaoError },
}
473
474impl HasErrorCode for HandleGetManyError {
475    fn code(&self) -> VarInt {
476        match self {
477            Self::ExportBao {
478                source: ExportBaoError::ClientError { source, .. },
479            } => source.code(),
480            _ => ERR_INTERNAL,
481        }
482    }
483}
484
485/// Handle a single get request.
486///
487/// Requires a database, the request, and a writer.
488async fn handle_get_many_impl<W: SendStream>(
489    store: Store,
490    request: GetManyRequest,
491    writer: &mut ProgressWriter<W>,
492) -> Result<(), HandleGetManyError> {
493    debug!("get_many received request");
494    let request_ranges = request.ranges.iter_infinite();
495    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
496        if !ranges.is_empty() {
497            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
498        }
499    }
500    Ok(())
501}
502
503pub async fn handle_get_many<R: RecvStream, W: SendStream>(
504    mut pair: StreamPair<R, W>,
505    store: Store,
506    request: GetManyRequest,
507) -> anyhow::Result<()> {
508    let res = pair.get_many_request(|| request.clone()).await;
509    let tracker = handle_read_request_result(&mut pair, res).await?;
510    let mut writer = pair.into_writer(tracker).await?;
511    let res = handle_get_many_impl(store, request, &mut writer).await;
512    handle_write_result(&mut writer, res).await?;
513    Ok(())
514}
515
/// Errors that can occur while handling a push request.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
    /// An [`ExportBaoError`] from the store.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },

    /// The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,

    /// A [`RequestError`] while importing data into the store.
    #[snafu(transparent)]
    Request {
        source: RequestError,
    },
}
530
531impl HasErrorCode for HandlePushError {
532    fn code(&self) -> VarInt {
533        match self {
534            Self::ExportBao {
535                source: ExportBaoError::ClientError { source, .. },
536            } => source.code(),
537            _ => ERR_INTERNAL,
538        }
539    }
540}
541
542/// Handle a single push request.
543///
544/// Requires a database, the request, and a reader.
545async fn handle_push_impl<R: RecvStream>(
546    store: Store,
547    request: PushRequest,
548    reader: &mut ProgressReader<R>,
549) -> Result<(), HandlePushError> {
550    let hash = request.hash;
551    debug!(%hash, "push received request");
552    let mut request_ranges = request.ranges.iter_infinite();
553    let root_ranges = request_ranges.next().expect("infinite iterator");
554    if !root_ranges.is_empty() {
555        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
556        store
557            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
558            .await?;
559    }
560    if request.ranges.is_blob() {
561        debug!("push request complete");
562        return Ok(());
563    }
564    // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
565    let hash_seq = store.get_bytes(hash).await?;
566    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
567    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
568        if child_ranges.is_empty() {
569            continue;
570        }
571        store
572            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
573            .await?;
574    }
575    Ok(())
576}
577
578pub async fn handle_push<R: RecvStream, W: SendStream>(
579    mut pair: StreamPair<R, W>,
580    store: Store,
581    request: PushRequest,
582) -> anyhow::Result<()> {
583    let res = pair.push_request(|| request.clone()).await;
584    let tracker = handle_read_request_result(&mut pair, res).await?;
585    let mut reader = pair.into_reader(tracker).await?;
586    let res = handle_push_impl(store, request, &mut reader).await;
587    handle_read_result(&mut reader, res).await?;
588    Ok(())
589}
590
591/// Send a blob to the client.
592pub(crate) async fn send_blob<W: SendStream>(
593    store: &Store,
594    index: u64,
595    hash: Hash,
596    ranges: ChunkRanges,
597    writer: &mut ProgressWriter<W>,
598) -> ExportBaoResult<()> {
599    store
600        .export_bao(hash, ranges)
601        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
602        .await
603}
604
/// Errors that can occur while handling an observe request.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
    /// The observe stream from the store ended or could not be created.
    ObserveStreamClosed,

    /// Writing an update to the remote failed.
    #[snafu(transparent)]
    RemoteClosed {
        source: io::Error,
    },
}
614
impl HasErrorCode for HandleObserveError {
    // Observe errors carry no client-specific code; always report internal.
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
620
/// Handle a single observe request.
///
/// Requires a database, the request, and a writer. Streams bitfield updates
/// for the requested hash to the remote until it stops the stream.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    // Subscribe to bitfield updates for the requested hash.
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
    let mut old = stream
        .next()
        .await
        .ok_or(HandleObserveError::ObserveStreamClosed)?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
                // only forward the delta between the previous and new bitfield
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            // the remote stopping its receive side ends the observation
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
660
661async fn send_observe_item<W: SendStream>(
662    writer: &mut ProgressWriter<W>,
663    item: &Bitfield,
664) -> io::Result<()> {
665    let item = ObserveItem::from(item);
666    let len = writer.inner.write_length_prefixed(item).await?;
667    writer.context.log_other_write(len);
668    Ok(())
669}
670
671pub async fn handle_observe<R: RecvStream, W: SendStream>(
672    mut pair: StreamPair<R, W>,
673    store: Store,
674    request: ObserveRequest,
675) -> anyhow::Result<()> {
676    let res = pair.observe_request(|| request.clone()).await;
677    let tracker = handle_read_request_result(&mut pair, res).await?;
678    let mut writer = pair.into_writer(tracker).await?;
679    let res = handle_observe_impl(store, request, &mut writer).await;
680    handle_write_result(&mut writer, res).await?;
681    Ok(())
682}
683
/// Wrapper for a [`RecvStream`] with additional per request information.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    /// The underlying receive stream
    inner: R,
    /// Accounting and progress reporting for this request
    context: ReaderContext,
}
688
689impl<R: RecvStream> ProgressReader<R> {
690    async fn transfer_aborted(&self) {
691        self.context
692            .tracker
693            .transfer_aborted(|| Box::new(self.context.stats()))
694            .await
695            .ok();
696    }
697
698    async fn transfer_completed(&self) {
699        self.context
700            .tracker
701            .transfer_completed(|| Box::new(self.context.stats()))
702            .await
703            .ok();
704    }
705}