iroh_blobs/
provider.rs

1//! The low level server side API
2//!
3//! Note that while using this API directly is fine, the standard way
4//! to provide data is to just register a [`crate::BlobsProtocol`] protocol
5//! handler with an [`iroh::Endpoint`](iroh::protocol::Router).
6use std::{
7    fmt::Debug,
8    future::Future,
9    io,
10    time::{Duration, Instant},
11};
12
13use bao_tree::ChunkRanges;
14use iroh::endpoint::{self, VarInt};
15use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
16use n0_error::{e, stack_error, Result};
17use n0_future::StreamExt;
18use quinn::ConnectionError;
19use serde::{Deserialize, Serialize};
20use tokio::select;
21use tracing::{debug, debug_span, Instrument};
22
23use crate::{
24    api::{
25        blobs::{Bitfield, WriteProgress},
26        ExportBaoError, ExportBaoResult, RequestError, Store,
27    },
28    hashseq::HashSeq,
29    protocol::{
30        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
31    },
32    provider::events::{
33        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
34        RequestTracker,
35    },
36    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
37    Hash,
38};
39pub mod events;
40use events::EventSender;
41
42type DefaultReader = iroh::endpoint::RecvStream;
43type DefaultWriter = iroh::endpoint::SendStream;
44
/// Statistics about a successful or failed transfer.
///
/// Serializable (serde) so it can be reported alongside transfer events.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// In most cases this is just the request, for push requests this is
    /// request, size header and hash pairs.
    pub other_bytes_read: u64,
    /// Total duration from reading the request to transfer completed.
    pub duration: Duration,
}
62
/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    // Creation time of the pair; used as the transfer start time in stats.
    t0: Instant,
    // Stable id of the connection this stream pair belongs to.
    connection_id: u64,
    reader: R,
    writer: W,
    // Non-payload bytes read so far (the request itself).
    other_bytes_read: u64,
    // Sink for provider events and per-request progress notifications.
    events: EventSender,
}
73
74impl StreamPair {
75    pub async fn accept(
76        conn: &endpoint::Connection,
77        events: EventSender,
78    ) -> Result<Self, ConnectionError> {
79        let (writer, reader) = conn.accept_bi().await?;
80        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
81    }
82}
83
84impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
85    pub fn stream_id(&self) -> u64 {
86        self.reader.id()
87    }
88
89    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
90        Self {
91            t0: Instant::now(),
92            connection_id,
93            reader,
94            writer,
95            other_bytes_read: 0,
96            events,
97        }
98    }
99
100    /// Read the request.
101    ///
102    /// Will fail if there is an error while reading, or if no valid request is sent.
103    ///
104    /// This will read exactly the number of bytes needed for the request, and
105    /// leave the rest of the stream for the caller to read.
106    ///
107    /// It is up to the caller do decide if there should be more data.
108    pub async fn read_request(&mut self) -> Result<Request> {
109        let (res, size) = Request::read_async(&mut self.reader).await?;
110        self.other_bytes_read += size as u64;
111        Ok(res)
112    }
113
114    /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
115    pub async fn into_writer(
116        mut self,
117        tracker: RequestTracker,
118    ) -> Result<ProgressWriter<W>, io::Error> {
119        self.reader.expect_eof().await?;
120        drop(self.reader);
121        Ok(ProgressWriter::new(
122            self.writer,
123            WriterContext {
124                t0: self.t0,
125                other_bytes_read: self.other_bytes_read,
126                payload_bytes_written: 0,
127                other_bytes_written: 0,
128                tracker,
129            },
130        ))
131    }
132
133    pub async fn into_reader(
134        mut self,
135        tracker: RequestTracker,
136    ) -> Result<ProgressReader<R>, io::Error> {
137        self.writer.sync().await?;
138        drop(self.writer);
139        Ok(ProgressReader {
140            inner: self.reader,
141            context: ReaderContext {
142                t0: self.t0,
143                other_bytes_read: self.other_bytes_read,
144                tracker,
145            },
146        })
147    }
148
149    pub async fn get_request(
150        &self,
151        f: impl FnOnce() -> GetRequest,
152    ) -> Result<RequestTracker, ProgressError> {
153        self.events
154            .request(f, self.connection_id, self.reader.id())
155            .await
156    }
157
158    pub async fn get_many_request(
159        &self,
160        f: impl FnOnce() -> GetManyRequest,
161    ) -> Result<RequestTracker, ProgressError> {
162        self.events
163            .request(f, self.connection_id, self.reader.id())
164            .await
165    }
166
167    pub async fn push_request(
168        &self,
169        f: impl FnOnce() -> PushRequest,
170    ) -> Result<RequestTracker, ProgressError> {
171        self.events
172            .request(f, self.connection_id, self.reader.id())
173            .await
174    }
175
176    pub async fn observe_request(
177        &self,
178        f: impl FnOnce() -> ObserveRequest,
179    ) -> Result<RequestTracker, ProgressError> {
180        self.events
181            .request(f, self.connection_id, self.reader.id())
182            .await
183    }
184
185    pub fn stats(&self) -> TransferStats {
186        TransferStats {
187            payload_bytes_sent: 0,
188            other_bytes_sent: 0,
189            other_bytes_read: self.other_bytes_read,
190            duration: self.t0.elapsed(),
191        }
192    }
193}
194
/// Read-side statistics and progress tracking for a single request.
#[derive(Debug)]
struct ReaderContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// Progress tracking for the request
    tracker: RequestTracker,
}
204
205impl ReaderContext {
206    fn stats(&self) -> TransferStats {
207        TransferStats {
208            payload_bytes_sent: 0,
209            other_bytes_sent: 0,
210            other_bytes_read: self.other_bytes_read,
211            duration: self.t0.elapsed(),
212        }
213    }
214}
215
/// Write-side statistics and progress tracking for a single request.
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// The number of payload bytes written to the stream
    payload_bytes_written: u64,
    /// The number of bytes written that are not part of the payload
    other_bytes_written: u64,
    /// Way to report progress
    tracker: RequestTracker,
}
229
230impl WriterContext {
231    fn stats(&self) -> TransferStats {
232        TransferStats {
233            payload_bytes_sent: self.payload_bytes_written,
234            other_bytes_sent: self.other_bytes_written,
235            other_bytes_read: self.other_bytes_read,
236            duration: self.t0.elapsed(),
237        }
238    }
239}
240
241impl WriteProgress for WriterContext {
242    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
243        let len = len as u64;
244        let end_offset = offset + len;
245        self.payload_bytes_written += len;
246        self.tracker.transfer_progress(len, end_offset).await
247    }
248
249    fn log_other_write(&mut self, len: usize) {
250        self.other_bytes_written += len as u64;
251    }
252
253    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
254        self.tracker.transfer_started(index, hash, size).await.ok();
255    }
256}
257
/// Wrapper for a send stream with additional per request information.
///
/// Generic over the writer; defaults to [`DefaultWriter`].
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream to write to
    pub inner: W,
    /// Timing, byte counters and progress tracking for this request
    pub(crate) context: WriterContext,
}
265
266impl<W: SendStream> ProgressWriter<W> {
267    fn new(inner: W, context: WriterContext) -> Self {
268        Self { inner, context }
269    }
270
271    async fn transfer_aborted(&self) {
272        self.context
273            .tracker
274            .transfer_aborted(|| Box::new(self.context.stats()))
275            .await
276            .ok();
277    }
278
279    async fn transfer_completed(&self) {
280        self.context
281            .tracker
282            .transfer_completed(|| Box::new(self.context.stats()))
283            .await
284            .ok();
285    }
286}
287
288/// Handle a single connection.
289pub async fn handle_connection(
290    connection: endpoint::Connection,
291    store: Store,
292    progress: EventSender,
293) {
294    let connection_id = connection.stable_id() as u64;
295    let span = debug_span!("connection", connection_id);
296    async move {
297        if let Err(cause) = progress
298            .client_connected(|| ClientConnected {
299                connection_id,
300                endpoint_id: connection.remote_id().ok(),
301            })
302            .await
303        {
304            connection.close(cause.code(), cause.reason());
305            debug!("closing connection: {cause}");
306            return;
307        }
308        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
309            let span = debug_span!("stream", stream_id = %pair.stream_id());
310            let store = store.clone();
311            tokio::spawn(handle_stream(pair, store).instrument(span));
312        }
313        progress
314            .connection_closed(|| ConnectionClosed { connection_id })
315            .await
316            .ok();
317    }
318    .instrument(span)
319    .await
320}
321
/// Describes how to handle errors for a stream.
pub trait ErrorHandler {
    /// Writer half of the stream.
    type W: AsyncStreamWriter;
    /// Reader half of the stream.
    type R: AsyncStreamReader;
    /// Ask the peer to stop sending on `reader` with the given error `code`.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Abruptly terminate `writer` with the given error `code`.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
329
330async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
331    pair: &mut StreamPair<R, W>,
332    r: Result<T, E>,
333) -> Result<T, E> {
334    match r {
335        Ok(x) => Ok(x),
336        Err(e) => {
337            pair.writer.reset(e.code()).ok();
338            Err(e)
339        }
340    }
341}
342async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
343    writer: &mut ProgressWriter<W>,
344    r: Result<T, E>,
345) -> Result<T, E> {
346    match r {
347        Ok(x) => {
348            writer.transfer_completed().await;
349            Ok(x)
350        }
351        Err(e) => {
352            writer.inner.reset(e.code()).ok();
353            writer.transfer_aborted().await;
354            Err(e)
355        }
356    }
357}
358async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
359    reader: &mut ProgressReader<R>,
360    r: Result<T, E>,
361) -> Result<T, E> {
362    match r {
363        Ok(x) => {
364            reader.transfer_completed().await;
365            Ok(x)
366        }
367        Err(e) => {
368            reader.inner.stop(e.code()).ok();
369            reader.transfer_aborted().await;
370            Err(e)
371        }
372    }
373}
374
375pub async fn handle_stream<R: RecvStream, W: SendStream>(
376    mut pair: StreamPair<R, W>,
377    store: Store,
378) -> n0_error::Result<()> {
379    let request = pair.read_request().await?;
380    match request {
381        Request::Get(request) => handle_get(pair, store, request).await?,
382        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
383        Request::Observe(request) => handle_observe(pair, store, request).await?,
384        Request::Push(request) => handle_push(pair, store, request).await?,
385        _ => {}
386    }
387    Ok(())
388}
389
/// Errors that can occur while handling a get request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao {
        #[error(std_err)]
        source: ExportBaoError,
    },
    /// The root blob could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},
    /// The requested child offset does not fit into `usize`.
    #[error("Invalid offset")]
    InvalidOffset {},
}
402
403impl HasErrorCode for HandleGetError {
404    fn code(&self) -> VarInt {
405        match self {
406            HandleGetError::ExportBao {
407                source: ExportBaoError::ClientError { source, .. },
408                ..
409            } => source.code(),
410            HandleGetError::InvalidHashSeq { .. } => ERR_INTERNAL,
411            HandleGetError::InvalidOffset { .. } => ERR_INTERNAL,
412            _ => ERR_INTERNAL,
413        }
414    }
415}
416
/// Handle a single get request.
///
/// Requires a database, the request, and a writer.
async fn handle_get_impl<W: SendStream>(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetError> {
    let hash = request.hash;
    debug!(%hash, "get received request");
    // Lazily loaded hash sequence; only fetched if the request touches children.
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            // Offset 0 is the root blob itself.
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // todo: this assumes that 1. the hashseq is complete and 2. it is
            // small enough to fit in memory.
            //
            // This should really read the hashseq from the store in chunks,
            // only where needed, so we can deal with holes and large hashseqs.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    // First child access: load and parse the hash sequence once.
                    let bytes = store.get_bytes(hash).await?;
                    let hs =
                        HashSeq::try_from(bytes).map_err(|_| e!(HandleGetError::InvalidHashSeq))?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            // Child n of the sequence lives at request offset n + 1.
            let o = usize::try_from(offset - 1).map_err(|_| e!(HandleGetError::InvalidOffset))?;
            let Some(hash) = hash_seq.get(o) else {
                // Requested offset is past the end of the sequence: stop sending.
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }
    // Flush so the remote sees all data before the stream is finished.
    writer
        .inner
        .sync()
        .await
        .map_err(|e| e!(HandleGetError::ExportBao, e.into()))?;

    Ok(())
}
462
463pub async fn handle_get<R: RecvStream, W: SendStream>(
464    mut pair: StreamPair<R, W>,
465    store: Store,
466    request: GetRequest,
467) -> n0_error::Result<()> {
468    let res = pair.get_request(|| request.clone()).await;
469    let tracker = handle_read_request_result(&mut pair, res).await?;
470    let mut writer = pair.into_writer(tracker).await?;
471    let res = handle_get_impl(store, request, &mut writer).await;
472    handle_write_result(&mut writer, res).await?;
473    Ok(())
474}
475
/// Errors that can occur while handling a get_many request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetManyError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },
}
481
482impl HasErrorCode for HandleGetManyError {
483    fn code(&self) -> VarInt {
484        match self {
485            Self::ExportBao {
486                source: ExportBaoError::ClientError { source, .. },
487                ..
488            } => source.code(),
489            _ => ERR_INTERNAL,
490        }
491    }
492}
493
494/// Handle a single get request.
495///
496/// Requires a database, the request, and a writer.
497async fn handle_get_many_impl<W: SendStream>(
498    store: Store,
499    request: GetManyRequest,
500    writer: &mut ProgressWriter<W>,
501) -> Result<(), HandleGetManyError> {
502    debug!("get_many received request");
503    let request_ranges = request.ranges.iter_infinite();
504    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
505        if !ranges.is_empty() {
506            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
507        }
508    }
509    Ok(())
510}
511
512pub async fn handle_get_many<R: RecvStream, W: SendStream>(
513    mut pair: StreamPair<R, W>,
514    store: Store,
515    request: GetManyRequest,
516) -> n0_error::Result<()> {
517    let res = pair.get_many_request(|| request.clone()).await;
518    let tracker = handle_read_request_result(&mut pair, res).await?;
519    let mut writer = pair.into_writer(tracker).await?;
520    let res = handle_get_many_impl(store, request, &mut writer).await;
521    handle_write_result(&mut writer, res).await?;
522    Ok(())
523}
524
/// Errors that can occur while handling a push request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandlePushError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },

    /// The root blob could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},

    /// The underlying request failed.
    #[error(transparent)]
    Request { source: RequestError },
}
536
537impl HasErrorCode for HandlePushError {
538    fn code(&self) -> VarInt {
539        match self {
540            Self::ExportBao {
541                source: ExportBaoError::ClientError { source, .. },
542                ..
543            } => source.code(),
544            _ => ERR_INTERNAL,
545        }
546    }
547}
548
/// Handle a single push request.
///
/// Requires a database, the request, and a reader.
async fn handle_push_impl<R: RecvStream>(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader<R>,
) -> Result<(), HandlePushError> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    // The ranges iterator is infinite; the first entry addresses the root blob.
    let mut request_ranges = request.ranges.iter_infinite();
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
        store
            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        // Only the root blob was requested, so there is nothing more to read.
        debug!("push request complete");
        return Ok(());
    }
    // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| e!(HandlePushError::InvalidHashSeq))?;
    // Children arrive on the same stream in hash sequence order; entries
    // with empty ranges are skipped and send no data.
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}
584
585pub async fn handle_push<R: RecvStream, W: SendStream>(
586    mut pair: StreamPair<R, W>,
587    store: Store,
588    request: PushRequest,
589) -> n0_error::Result<()> {
590    let res = pair.push_request(|| request.clone()).await;
591    let tracker = handle_read_request_result(&mut pair, res).await?;
592    let mut reader = pair.into_reader(tracker).await?;
593    let res = handle_push_impl(store, request, &mut reader).await;
594    handle_read_result(&mut reader, res).await?;
595    Ok(())
596}
597
598/// Send a blob to the client.
599pub(crate) async fn send_blob<W: SendStream>(
600    store: &Store,
601    index: u64,
602    hash: Hash,
603    ranges: ChunkRanges,
604    writer: &mut ProgressWriter<W>,
605) -> ExportBaoResult<()> {
606    store
607        .export_bao(hash, ranges)
608        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
609        .await
610}
611
/// Errors that can occur while handling an observe request.
#[stack_error(derive, add_meta, std_sources, from_sources)]
pub enum HandleObserveError {
    /// The local observe stream ended unexpectedly.
    #[error("observe stream closed")]
    ObserveStreamClosed {},

    /// The remote side closed the connection.
    #[error(transparent)]
    RemoteClosed { source: io::Error },
}
620
621impl HasErrorCode for HandleObserveError {
622    fn code(&self) -> VarInt {
623        ERR_INTERNAL
624    }
625}
626
/// Handle a single observe request.
///
/// Requires a database, the request, and a writer.
///
/// Sends the current bitfield for the observed hash, then streams diffs
/// until either the observe stream ends or the remote stops the stream.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| e!(HandleObserveError::ObserveStreamClosed))?;
    let mut old = stream
        .next()
        .await
        .ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
                let diff = old.diff(&new);
                // Skip no-op updates; only changes are sent to the remote.
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            _ = writer.inner.stopped() => {
                // Remote stopped the stream: it is no longer interested.
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
666
667async fn send_observe_item<W: SendStream>(
668    writer: &mut ProgressWriter<W>,
669    item: &Bitfield,
670) -> io::Result<()> {
671    let item = ObserveItem::from(item);
672    let len = writer.inner.write_length_prefixed(item).await?;
673    writer.context.log_other_write(len);
674    Ok(())
675}
676
677pub async fn handle_observe<R: RecvStream, W: SendStream>(
678    mut pair: StreamPair<R, W>,
679    store: Store,
680    request: ObserveRequest,
681) -> n0_error::Result<()> {
682    let res = pair.observe_request(|| request.clone()).await;
683    let tracker = handle_read_request_result(&mut pair, res).await?;
684    let mut writer = pair.into_writer(tracker).await?;
685    let res = handle_observe_impl(store, request, &mut writer).await;
686    handle_write_result(&mut writer, res).await?;
687    Ok(())
688}
689
/// Wrapper for a receive stream with additional per request information.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    // The stream to read from
    inner: R,
    // Timing and progress tracking for this request
    context: ReaderContext,
}
694
695impl<R: RecvStream> ProgressReader<R> {
696    async fn transfer_aborted(&self) {
697        self.context
698            .tracker
699            .transfer_aborted(|| Box::new(self.context.stats()))
700            .await
701            .ok();
702    }
703
704    async fn transfer_completed(&self) {
705        self.context
706            .tracker
707            .transfer_completed(|| Box::new(self.context.stats()))
708            .await
709            .ok();
710    }
711}