iroh_blobs/
provider.rs

1//! The low level server side API
2//!
3//! Note that while using this API directly is fine, the standard way
4//! to provide data is to just register a [`crate::BlobsProtocol`] protocol
5//! handler with an [`iroh::Endpoint`](iroh::protocol::Router).
6use std::{fmt::Debug, future::Future, io};
7
8use bao_tree::ChunkRanges;
9use iroh::endpoint::{self, ConnectionError, VarInt};
10use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
11use n0_error::{e, stack_error, Result};
12use n0_future::{
13    time::{Duration, Instant},
14    StreamExt,
15};
16use serde::{Deserialize, Serialize};
17use tokio::select;
18use tracing::{debug, debug_span, Instrument};
19
20use crate::{
21    api::{
22        blobs::{Bitfield, WriteProgress},
23        ExportBaoError, ExportBaoResult, RequestError, Store,
24    },
25    hashseq::HashSeq,
26    protocol::{
27        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
28    },
29    provider::events::{
30        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
31        RequestTracker,
32    },
33    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
34    Hash,
35};
36pub mod events;
37use events::EventSender;
38
/// Default receive stream type when none is specified: the iroh QUIC recv stream.
type DefaultReader = iroh::endpoint::RecvStream;
/// Default send stream type when none is specified: the iroh QUIC send stream.
type DefaultWriter = iroh::endpoint::SendStream;
41
/// Statistics about a successful or failed transfer.
///
/// Reported to the [`RequestTracker`] when a transfer completes or is aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// In most cases this is just the request, for push requests this is
    /// request, size header and hash pairs.
    pub other_bytes_read: u64,
    /// Total duration from reading the request to transfer completed.
    pub duration: Duration,
}
59
/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Start time of the request, used to compute [`TransferStats::duration`].
    t0: Instant,
    /// Stable id of the connection this stream pair belongs to.
    connection_id: u64,
    /// The receive half of the stream.
    reader: R,
    /// The send half of the stream.
    writer: W,
    /// Number of non-payload bytes read so far (e.g. the encoded request).
    other_bytes_read: u64,
    /// Sink for provider events and per-request progress tracking.
    events: EventSender,
}
70
71impl StreamPair {
72    pub async fn accept(
73        conn: &endpoint::Connection,
74        events: EventSender,
75    ) -> Result<Self, ConnectionError> {
76        let (writer, reader) = conn.accept_bi().await?;
77        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
78    }
79}
80
impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
    /// Id of the underlying stream, taken from the receive half.
    pub fn stream_id(&self) -> u64 {
        self.reader.id()
    }

    /// Create a new pair from its parts, starting the transfer timer now.
    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
        Self {
            t0: Instant::now(),
            connection_id,
            reader,
            writer,
            other_bytes_read: 0,
            events,
        }
    }

    /// Read the request.
    ///
    /// Will fail if there is an error while reading, or if no valid request is sent.
    ///
    /// This will read exactly the number of bytes needed for the request, and
    /// leave the rest of the stream for the caller to read.
    ///
    /// It is up to the caller do decide if there should be more data.
    pub async fn read_request(&mut self) -> Result<Request> {
        let (res, size) = Request::read_async(&mut self.reader).await?;
        // The encoded request counts as non-payload bytes read.
        self.other_bytes_read += size as u64;
        Ok(res)
    }

    /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
    ///
    /// Fails if the remote sent trailing data after the request.
    pub async fn into_writer(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressWriter<W>, io::Error> {
        // The read side must be exhausted before we switch to write-only mode.
        self.reader.expect_eof().await?;
        drop(self.reader);
        Ok(ProgressWriter::new(
            self.writer,
            WriterContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                payload_bytes_written: 0,
                other_bytes_written: 0,
                tracker,
            },
        ))
    }

    /// We are done with writing. Return a ProgressReader that carries the
    /// read stats accumulated so far and the request tracker.
    pub async fn into_reader(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressReader<R>, io::Error> {
        // Flush the write side before dropping it.
        self.writer.sync().await?;
        drop(self.writer);
        Ok(ProgressReader {
            inner: self.reader,
            context: ReaderContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                tracker,
            },
        })
    }

    /// Notify the event sender about an incoming get request and obtain a
    /// tracker for per-request progress events.
    pub async fn get_request(
        &self,
        f: impl FnOnce() -> GetRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Notify the event sender about an incoming get_many request and obtain a
    /// tracker for per-request progress events.
    pub async fn get_many_request(
        &self,
        f: impl FnOnce() -> GetManyRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Notify the event sender about an incoming push request and obtain a
    /// tracker for per-request progress events.
    pub async fn push_request(
        &self,
        f: impl FnOnce() -> PushRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Notify the event sender about an incoming observe request and obtain a
    /// tracker for per-request progress events.
    pub async fn observe_request(
        &self,
        f: impl FnOnce() -> ObserveRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Current transfer statistics; nothing has been written at this stage.
    pub fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}
191
/// Per-request context for the read direction (used when handling push requests).
#[derive(Debug)]
struct ReaderContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// Progress tracking for the request
    tracker: RequestTracker,
}
201
202impl ReaderContext {
203    fn stats(&self) -> TransferStats {
204        TransferStats {
205            payload_bytes_sent: 0,
206            other_bytes_sent: 0,
207            other_bytes_read: self.other_bytes_read,
208            duration: self.t0.elapsed(),
209        }
210    }
211}
212
/// Per-request context for the write direction (get, get_many and observe requests).
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// The number of payload bytes written to the stream
    payload_bytes_written: u64,
    /// The number of bytes written that are not part of the payload
    other_bytes_written: u64,
    /// Way to report progress
    tracker: RequestTracker,
}
226
227impl WriterContext {
228    fn stats(&self) -> TransferStats {
229        TransferStats {
230            payload_bytes_sent: self.payload_bytes_written,
231            other_bytes_sent: self.other_bytes_written,
232            other_bytes_read: self.other_bytes_read,
233            duration: self.t0.elapsed(),
234        }
235    }
236}
237
238impl WriteProgress for WriterContext {
239    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
240        let len = len as u64;
241        let end_offset = offset + len;
242        self.payload_bytes_written += len;
243        self.tracker.transfer_progress(len, end_offset).await
244    }
245
246    fn log_other_write(&mut self, len: usize) {
247        self.other_bytes_written += len as u64;
248    }
249
250    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
251        self.tracker.transfer_started(index, hash, size).await.ok();
252    }
253}
254
/// Wrapper for a [`SendStream`] with additional per request information.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream to write to
    pub inner: W,
    /// Statistics and progress tracking for this request.
    pub(crate) context: WriterContext,
}
262
impl<W: SendStream> ProgressWriter<W> {
    /// Create a new progress writer from a stream and its request context.
    fn new(inner: W, context: WriterContext) -> Self {
        Self { inner, context }
    }

    /// Report to the tracker that the transfer was aborted, attaching final stats.
    ///
    /// Errors from the event sender are ignored; reporting is best-effort.
    async fn transfer_aborted(&self) {
        self.context
            .tracker
            .transfer_aborted(|| Box::new(self.context.stats()))
            .await
            .ok();
    }

    /// Report to the tracker that the transfer completed, attaching final stats.
    ///
    /// Errors from the event sender are ignored; reporting is best-effort.
    async fn transfer_completed(&self) {
        self.context
            .tracker
            .transfer_completed(|| Box::new(self.context.stats()))
            .await
            .ok();
    }
}
284
285/// Handle a single connection.
286pub async fn handle_connection(
287    connection: endpoint::Connection,
288    store: Store,
289    progress: EventSender,
290) {
291    let connection_id = connection.stable_id() as u64;
292    let span = debug_span!("connection", connection_id);
293    async move {
294        if let Err(cause) = progress
295            .client_connected(|| ClientConnected {
296                connection_id,
297                endpoint_id: Some(connection.remote_id()),
298            })
299            .await
300        {
301            connection.close(cause.code(), cause.reason());
302            debug!("closing connection: {cause}");
303            return;
304        }
305        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
306            let span = debug_span!("stream", stream_id = %pair.stream_id());
307            let store = store.clone();
308            n0_future::task::spawn(handle_stream(pair, store).instrument(span));
309        }
310        progress
311            .connection_closed(|| ConnectionClosed { connection_id })
312            .await
313            .ok();
314    }
315    .instrument(span)
316    .await
317}
318
/// Describes how to handle errors for a stream.
pub trait ErrorHandler {
    /// The writer type whose send side can be reset.
    type W: AsyncStreamWriter;
    /// The reader type whose receive side can be stopped.
    type R: AsyncStreamReader;
    /// Stop the read side of the stream with the given error code.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Reset the write side of the stream with the given error code.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
326
327async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
328    pair: &mut StreamPair<R, W>,
329    r: Result<T, E>,
330) -> Result<T, E> {
331    match r {
332        Ok(x) => Ok(x),
333        Err(e) => {
334            pair.writer.reset(e.code()).ok();
335            Err(e)
336        }
337    }
338}
339async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
340    writer: &mut ProgressWriter<W>,
341    r: Result<T, E>,
342) -> Result<T, E> {
343    match r {
344        Ok(x) => {
345            writer.transfer_completed().await;
346            Ok(x)
347        }
348        Err(e) => {
349            writer.inner.reset(e.code()).ok();
350            writer.transfer_aborted().await;
351            Err(e)
352        }
353    }
354}
355async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
356    reader: &mut ProgressReader<R>,
357    r: Result<T, E>,
358) -> Result<T, E> {
359    match r {
360        Ok(x) => {
361            reader.transfer_completed().await;
362            Ok(x)
363        }
364        Err(e) => {
365            reader.inner.stop(e.code()).ok();
366            reader.transfer_aborted().await;
367            Err(e)
368        }
369    }
370}
371
372pub async fn handle_stream<R: RecvStream, W: SendStream>(
373    mut pair: StreamPair<R, W>,
374    store: Store,
375) -> n0_error::Result<()> {
376    let request = pair.read_request().await?;
377    match request {
378        Request::Get(request) => handle_get(pair, store, request).await?,
379        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
380        Request::Observe(request) => handle_observe(pair, store, request).await?,
381        Request::Push(request) => handle_push(pair, store, request).await?,
382        _ => {}
383    }
384    Ok(())
385}
386
/// Errors that can occur while handling a get request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao {
        #[error(std_err)]
        source: ExportBaoError,
    },
    /// The blob at the root hash could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},
    /// A child offset did not fit into usize.
    #[error("Invalid offset")]
    InvalidOffset {},
}
399
400impl HasErrorCode for HandleGetError {
401    fn code(&self) -> VarInt {
402        match self {
403            HandleGetError::ExportBao {
404                source: ExportBaoError::ClientError { source, .. },
405                ..
406            } => source.code(),
407            HandleGetError::InvalidHashSeq { .. } => ERR_INTERNAL,
408            HandleGetError::InvalidOffset { .. } => ERR_INTERNAL,
409            _ => ERR_INTERNAL,
410        }
411    }
412}
413
/// Handle a single get request.
///
/// Requires a database, the request, and a writer.
async fn handle_get_impl<W: SendStream>(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetError> {
    let hash = request.hash;
    debug!(%hash, "get received request");
    // Lazily loaded hash sequence; only fetched when a child offset is requested.
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            // Offset 0 refers to the root blob itself.
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // todo: this assumes that 1. the hashseq is complete and 2. it is
            // small enough to fit in memory.
            //
            // This should really read the hashseq from the store in chunks,
            // only where needed, so we can deal with holes and large hashseqs.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs =
                        HashSeq::try_from(bytes).map_err(|_| e!(HandleGetError::InvalidHashSeq))?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            // Child offsets are 1-based: offset n refers to entry n-1 of the hash seq.
            let o = usize::try_from(offset - 1).map_err(|_| e!(HandleGetError::InvalidOffset))?;
            let Some(hash) = hash_seq.get(o) else {
                // Ranges past the end of the hash seq: stop sending.
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }
    // Flush so all data is on the wire before completion is reported.
    writer
        .inner
        .sync()
        .await
        .map_err(|e| e!(HandleGetError::ExportBao, e.into()))?;

    Ok(())
}
459
460pub async fn handle_get<R: RecvStream, W: SendStream>(
461    mut pair: StreamPair<R, W>,
462    store: Store,
463    request: GetRequest,
464) -> n0_error::Result<()> {
465    let res = pair.get_request(|| request.clone()).await;
466    let tracker = handle_read_request_result(&mut pair, res).await?;
467    let mut writer = pair.into_writer(tracker).await?;
468    let res = handle_get_impl(store, request, &mut writer).await;
469    handle_write_result(&mut writer, res).await?;
470    Ok(())
471}
472
/// Errors that can occur while handling a get_many request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandleGetManyError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },
}
478
479impl HasErrorCode for HandleGetManyError {
480    fn code(&self) -> VarInt {
481        match self {
482            Self::ExportBao {
483                source: ExportBaoError::ClientError { source, .. },
484                ..
485            } => source.code(),
486            _ => ERR_INTERNAL,
487        }
488    }
489}
490
491/// Handle a single get request.
492///
493/// Requires a database, the request, and a writer.
494async fn handle_get_many_impl<W: SendStream>(
495    store: Store,
496    request: GetManyRequest,
497    writer: &mut ProgressWriter<W>,
498) -> Result<(), HandleGetManyError> {
499    debug!("get_many received request");
500    let request_ranges = request.ranges.iter_infinite();
501    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
502        if !ranges.is_empty() {
503            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
504        }
505    }
506    Ok(())
507}
508
509pub async fn handle_get_many<R: RecvStream, W: SendStream>(
510    mut pair: StreamPair<R, W>,
511    store: Store,
512    request: GetManyRequest,
513) -> n0_error::Result<()> {
514    let res = pair.get_many_request(|| request.clone()).await;
515    let tracker = handle_read_request_result(&mut pair, res).await?;
516    let mut writer = pair.into_writer(tracker).await?;
517    let res = handle_get_many_impl(store, request, &mut writer).await;
518    handle_write_result(&mut writer, res).await?;
519    Ok(())
520}
521
/// Errors that can occur while handling a push request.
#[stack_error(derive, add_meta, from_sources)]
pub enum HandlePushError {
    /// Exporting bao-encoded data from the store failed.
    #[error(transparent)]
    ExportBao { source: ExportBaoError },

    /// The blob at the root hash could not be parsed as a hash sequence.
    #[error("Invalid hash sequence")]
    InvalidHashSeq {},

    /// A lower-level request error occurred.
    #[error(transparent)]
    Request { source: RequestError },
}
533
534impl HasErrorCode for HandlePushError {
535    fn code(&self) -> VarInt {
536        match self {
537            Self::ExportBao {
538                source: ExportBaoError::ClientError { source, .. },
539                ..
540            } => source.code(),
541            _ => ERR_INTERNAL,
542        }
543    }
544}
545
/// Handle a single push request.
///
/// Requires a database, the request, and a reader.
async fn handle_push_impl<R: RecvStream>(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader<R>,
) -> Result<(), HandlePushError> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    // The first entry of the infinite range iterator is the root blob.
    let mut request_ranges = request.ranges.iter_infinite();
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
        store
            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        // A plain blob push has no children; we are done.
        debug!("push request complete");
        return Ok(());
    }
    // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| e!(HandlePushError::InvalidHashSeq))?;
    // Import the children in order; the remote sends them sequentially on the
    // same stream, so the iteration order here must match the request order.
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}
581
582pub async fn handle_push<R: RecvStream, W: SendStream>(
583    mut pair: StreamPair<R, W>,
584    store: Store,
585    request: PushRequest,
586) -> n0_error::Result<()> {
587    let res = pair.push_request(|| request.clone()).await;
588    let tracker = handle_read_request_result(&mut pair, res).await?;
589    let mut reader = pair.into_reader(tracker).await?;
590    let res = handle_push_impl(store, request, &mut reader).await;
591    handle_read_result(&mut reader, res).await?;
592    Ok(())
593}
594
595/// Send a blob to the client.
596pub(crate) async fn send_blob<W: SendStream>(
597    store: &Store,
598    index: u64,
599    hash: Hash,
600    ranges: ChunkRanges,
601    writer: &mut ProgressWriter<W>,
602) -> ExportBaoResult<()> {
603    store
604        .export_bao(hash, ranges)
605        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
606        .await
607}
608
/// Errors that can occur while handling an observe request.
#[stack_error(derive, add_meta, std_sources, from_sources)]
pub enum HandleObserveError {
    /// The local observe stream ended unexpectedly.
    #[error("observe stream closed")]
    ObserveStreamClosed {},

    /// Writing an update to the remote failed.
    #[error(transparent)]
    RemoteClosed { source: io::Error },
}
617
impl HasErrorCode for HandleObserveError {
    /// Observe failures are always reported to the remote as internal errors.
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
623
/// Handle a single observe request.
///
/// Requires a database, the request, and a writer.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| e!(HandleObserveError::ObserveStreamClosed))?;
    let mut old = stream
        .next()
        .await
        .ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or_else(|| e!(HandleObserveError::ObserveStreamClosed))?;
                // Only send the delta relative to the last sent state.
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            // The remote signalled it no longer wants updates.
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
663
664async fn send_observe_item<W: SendStream>(
665    writer: &mut ProgressWriter<W>,
666    item: &Bitfield,
667) -> io::Result<()> {
668    let item = ObserveItem::from(item);
669    let len = writer.inner.write_length_prefixed(item).await?;
670    writer.context.log_other_write(len);
671    Ok(())
672}
673
674pub async fn handle_observe<R: RecvStream, W: SendStream>(
675    mut pair: StreamPair<R, W>,
676    store: Store,
677    request: ObserveRequest,
678) -> n0_error::Result<()> {
679    let res = pair.observe_request(|| request.clone()).await;
680    let tracker = handle_read_request_result(&mut pair, res).await?;
681    let mut writer = pair.into_writer(tracker).await?;
682    let res = handle_observe_impl(store, request, &mut writer).await;
683    handle_write_result(&mut writer, res).await?;
684    Ok(())
685}
686
/// Wrapper for a [`RecvStream`] with additional per request information.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    /// The underlying receive stream to read from
    inner: R,
    /// Statistics and progress tracking for this request.
    context: ReaderContext,
}
691
impl<R: RecvStream> ProgressReader<R> {
    /// Report to the tracker that the transfer was aborted, attaching final stats.
    ///
    /// Errors from the event sender are ignored; reporting is best-effort.
    async fn transfer_aborted(&self) {
        self.context
            .tracker
            .transfer_aborted(|| Box::new(self.context.stats()))
            .await
            .ok();
    }

    /// Report to the tracker that the transfer completed, attaching final stats.
    ///
    /// Errors from the event sender are ignored; reporting is best-effort.
    async fn transfer_completed(&self) {
        self.context
            .tracker
            .transfer_completed(|| Box::new(self.context.stats()))
            .await
            .ok();
    }
}