iroh_blobs/
provider.rs

1//! The low level server side API
2//!
3//! Note that while using this API directly is fine, the standard way
4//! to provide data is to just register a [`crate::BlobsProtocol`] protocol
//! handler with an [`iroh::protocol::Router`].
6use std::{fmt::Debug, future::Future, io};
7
8use anyhow::Result;
9use bao_tree::ChunkRanges;
10use iroh::endpoint::{self, ConnectionError, VarInt};
11use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
12use n0_future::{
13    time::{Duration, Instant},
14    StreamExt,
15};
16use serde::{Deserialize, Serialize};
17use snafu::Snafu;
18use tokio::select;
19use tracing::{debug, debug_span, Instrument};
20
21use crate::{
22    api::{
23        blobs::{Bitfield, WriteProgress},
24        ExportBaoError, ExportBaoResult, RequestError, Store,
25    },
26    hashseq::HashSeq,
27    protocol::{
28        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
29    },
30    provider::events::{
31        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
32        RequestTracker,
33    },
34    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
35    Hash,
36};
37pub mod events;
38use events::EventSender;
39
/// Default receive stream type when no custom stream implementation is used.
type DefaultReader = iroh::endpoint::RecvStream;
/// Default send stream type when no custom stream implementation is used.
type DefaultWriter = iroh::endpoint::SendStream;
42
/// Statistics about a successful or failed transfer.
///
/// A snapshot of these stats is handed to the request tracker when a
/// transfer completes or is aborted.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
    /// The number of bytes sent that are part of the payload.
    pub payload_bytes_sent: u64,
    /// The number of bytes sent that are not part of the payload.
    ///
    /// Hash pairs and the initial size header.
    pub other_bytes_sent: u64,
    /// The number of bytes read from the stream.
    ///
    /// In most cases this is just the request, for push requests this is
    /// request, size header and hash pairs.
    pub other_bytes_read: u64,
    /// Total duration from reading the request to transfer completed.
    pub duration: Duration,
}
60
/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
    /// Creation time of the pair; used as the transfer start time in stats.
    t0: Instant,
    /// Stable id of the connection this stream pair belongs to.
    connection_id: u64,
    /// Receive side of the stream pair.
    reader: R,
    /// Send side of the stream pair.
    writer: W,
    /// Number of non-payload bytes read so far (e.g. the request itself).
    other_bytes_read: u64,
    /// Event sender used to register incoming requests for tracking.
    events: EventSender,
}
71
72impl StreamPair {
73    pub async fn accept(
74        conn: &endpoint::Connection,
75        events: EventSender,
76    ) -> Result<Self, ConnectionError> {
77        let (writer, reader) = conn.accept_bi().await?;
78        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
79    }
80}
81
impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
    /// The id of the underlying receive stream.
    pub fn stream_id(&self) -> u64 {
        self.reader.id()
    }

    /// Create a new stream pair, recording the current time as the start of
    /// the transfer for statistics purposes.
    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
        Self {
            t0: Instant::now(),
            connection_id,
            reader,
            writer,
            other_bytes_read: 0,
            events,
        }
    }

    /// Read the request.
    ///
    /// Will fail if there is an error while reading, or if no valid request is sent.
    ///
    /// This will read exactly the number of bytes needed for the request, and
    /// leave the rest of the stream for the caller to read.
    ///
    /// It is up to the caller to decide if there should be more data.
    pub async fn read_request(&mut self) -> Result<Request> {
        let (res, size) = Request::read_async(&mut self.reader).await?;
        // the request itself counts as non-payload traffic
        self.other_bytes_read += size as u64;
        Ok(res)
    }

    /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
    ///
    /// Fails if the reader is not at end-of-stream, i.e. if the client sent
    /// more data than the request needed.
    pub async fn into_writer(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressWriter<W>, io::Error> {
        self.reader.expect_eof().await?;
        drop(self.reader);
        Ok(ProgressWriter::new(
            self.writer,
            WriterContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                payload_bytes_written: 0,
                other_bytes_written: 0,
                tracker,
            },
        ))
    }

    /// We are done with writing. Return a [`ProgressReader`] that carries the
    /// read stats and the request tracker.
    ///
    /// Flushes the writer before dropping it.
    pub async fn into_reader(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressReader<R>, io::Error> {
        self.writer.sync().await?;
        drop(self.writer);
        Ok(ProgressReader {
            inner: self.reader,
            context: ReaderContext {
                t0: self.t0,
                other_bytes_read: self.other_bytes_read,
                tracker,
            },
        })
    }

    /// Register a get request with the event system and obtain a tracker for it.
    pub async fn get_request(
        &self,
        f: impl FnOnce() -> GetRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Register a get many request with the event system and obtain a tracker for it.
    pub async fn get_many_request(
        &self,
        f: impl FnOnce() -> GetManyRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Register a push request with the event system and obtain a tracker for it.
    pub async fn push_request(
        &self,
        f: impl FnOnce() -> PushRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Register an observe request with the event system and obtain a tracker for it.
    pub async fn observe_request(
        &self,
        f: impl FnOnce() -> ObserveRequest,
    ) -> Result<RequestTracker, ProgressError> {
        self.events
            .request(f, self.connection_id, self.reader.id())
            .await
    }

    /// Statistics for this stream pair.
    ///
    /// Only read bytes are tracked here; written bytes are accounted for in
    /// [`WriterContext`], so the sent counters are zero.
    pub fn stats(&self) -> TransferStats {
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration: self.t0.elapsed(),
        }
    }
}
192
/// Per-request context for the read side of a transfer (push requests).
#[derive(Debug)]
struct ReaderContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// Progress tracking for the request
    tracker: RequestTracker,
}
202
203impl ReaderContext {
204    fn stats(&self) -> TransferStats {
205        TransferStats {
206            payload_bytes_sent: 0,
207            other_bytes_sent: 0,
208            other_bytes_read: self.other_bytes_read,
209            duration: self.t0.elapsed(),
210        }
211    }
212}
213
/// Per-request context for the write side of a transfer (get requests).
#[derive(Debug)]
pub(crate) struct WriterContext {
    /// The start time of the transfer
    t0: Instant,
    /// The number of bytes read from the stream
    other_bytes_read: u64,
    /// The number of payload bytes written to the stream
    payload_bytes_written: u64,
    /// The number of bytes written that are not part of the payload
    other_bytes_written: u64,
    /// Way to report progress
    tracker: RequestTracker,
}
227
228impl WriterContext {
229    fn stats(&self) -> TransferStats {
230        TransferStats {
231            payload_bytes_sent: self.payload_bytes_written,
232            other_bytes_sent: self.other_bytes_written,
233            other_bytes_read: self.other_bytes_read,
234            duration: self.t0.elapsed(),
235        }
236    }
237}
238
239impl WriteProgress for WriterContext {
240    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
241        let len = len as u64;
242        let end_offset = offset + len;
243        self.payload_bytes_written += len;
244        self.tracker.transfer_progress(len, end_offset).await
245    }
246
247    fn log_other_write(&mut self, len: usize) {
248        self.other_bytes_written += len as u64;
249    }
250
251    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
252        self.tracker.transfer_started(index, hash, size).await.ok();
253    }
254}
255
/// Wrapper for a [`quinn::SendStream`] with additional per request information.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
    /// The underlying send stream to write to
    pub inner: W,
    /// Stats and progress tracking for the request
    pub(crate) context: WriterContext,
}
263
264impl<W: SendStream> ProgressWriter<W> {
265    fn new(inner: W, context: WriterContext) -> Self {
266        Self { inner, context }
267    }
268
269    async fn transfer_aborted(&self) {
270        self.context
271            .tracker
272            .transfer_aborted(|| Box::new(self.context.stats()))
273            .await
274            .ok();
275    }
276
277    async fn transfer_completed(&self) {
278        self.context
279            .tracker
280            .transfer_completed(|| Box::new(self.context.stats()))
281            .await
282            .ok();
283    }
284}
285
/// Handle a single connection.
///
/// Notifies the event sender that a client connected, then accepts
/// bidirectional streams in a loop, spawning one task per stream, until
/// accepting fails (e.g. the connection was closed).
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    async move {
        // give the application a chance to reject the client before
        // accepting any streams
        if let Err(cause) = progress
            .client_connected(|| ClientConnected {
                connection_id,
                endpoint_id: Some(connection.remote_id()),
            })
            .await
        {
            connection.close(cause.code(), cause.reason());
            debug!("closing connection: {cause}");
            return;
        }
        // one spawned task per accepted stream, each with its own span
        while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
            let span = debug_span!("stream", stream_id = %pair.stream_id());
            let store = store.clone();
            n0_future::task::spawn(handle_stream(pair, store).instrument(span));
        }
        // best-effort notification; errors are ignored on shutdown
        progress
            .connection_closed(|| ConnectionClosed { connection_id })
            .await
            .ok();
    }
    .instrument(span)
    .await
}
319
/// Describes how to handle errors for a stream.
pub trait ErrorHandler {
    /// The writer type errors are reported on.
    type W: AsyncStreamWriter;
    /// The reader type errors are reported on.
    type R: AsyncStreamReader;
    /// Stop the read side with the given error code.
    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
    /// Reset the write side with the given error code.
    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
327
328async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
329    pair: &mut StreamPair<R, W>,
330    r: Result<T, E>,
331) -> Result<T, E> {
332    match r {
333        Ok(x) => Ok(x),
334        Err(e) => {
335            pair.writer.reset(e.code()).ok();
336            Err(e)
337        }
338    }
339}
340async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
341    writer: &mut ProgressWriter<W>,
342    r: Result<T, E>,
343) -> Result<T, E> {
344    match r {
345        Ok(x) => {
346            writer.transfer_completed().await;
347            Ok(x)
348        }
349        Err(e) => {
350            writer.inner.reset(e.code()).ok();
351            writer.transfer_aborted().await;
352            Err(e)
353        }
354    }
355}
356async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
357    reader: &mut ProgressReader<R>,
358    r: Result<T, E>,
359) -> Result<T, E> {
360    match r {
361        Ok(x) => {
362            reader.transfer_completed().await;
363            Ok(x)
364        }
365        Err(e) => {
366            reader.inner.stop(e.code()).ok();
367            reader.transfer_aborted().await;
368            Err(e)
369        }
370    }
371}
372
373pub async fn handle_stream<R: RecvStream, W: SendStream>(
374    mut pair: StreamPair<R, W>,
375    store: Store,
376) -> anyhow::Result<()> {
377    let request = pair.read_request().await?;
378    match request {
379        Request::Get(request) => handle_get(pair, store, request).await?,
380        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
381        Request::Observe(request) => handle_observe(pair, store, request).await?,
382        Request::Push(request) => handle_push(pair, store, request).await?,
383        _ => {}
384    }
385    Ok(())
386}
387
/// Errors that can occur while handling a get request.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
    /// Exporting bao-encoded data from the store failed.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },
    /// The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,
    /// A child offset did not fit into `usize`.
    InvalidOffset,
}
398
399impl HasErrorCode for HandleGetError {
400    fn code(&self) -> VarInt {
401        match self {
402            HandleGetError::ExportBao {
403                source: ExportBaoError::ClientError { source, .. },
404            } => source.code(),
405            HandleGetError::InvalidHashSeq => ERR_INTERNAL,
406            HandleGetError::InvalidOffset => ERR_INTERNAL,
407            _ => ERR_INTERNAL,
408        }
409    }
410}
411
/// Handle a single get request.
///
/// Requires a database, the request, and a writer.
async fn handle_get_impl<W: SendStream>(
    store: Store,
    request: GetRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetError> {
    let hash = request.hash;
    debug!(%hash, "get received request");
    // lazily loaded hash sequence; only needed when child offsets > 0 occur
    let mut hash_seq = None;
    for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
        if offset == 0 {
            // offset 0 refers to the root blob itself
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        } else {
            // todo: this assumes that 1. the hashseq is complete and 2. it is
            // small enough to fit in memory.
            //
            // This should really read the hashseq from the store in chunks,
            // only where needed, so we can deal with holes and large hashseqs.
            let hash_seq = match &hash_seq {
                Some(b) => b,
                None => {
                    let bytes = store.get_bytes(hash).await?;
                    let hs =
                        HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
                    hash_seq = Some(hs);
                    hash_seq.as_ref().unwrap()
                }
            };
            // offset n refers to child n - 1 of the hash sequence
            let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
            let Some(hash) = hash_seq.get(o) else {
                // the request extends past the end of the hash seq; stop here
                break;
            };
            send_blob(&store, offset, hash, ranges.clone(), writer).await?;
        }
    }
    // flush all pending data before reporting success
    writer
        .inner
        .sync()
        .await
        .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;

    Ok(())
}
457
/// Handle a get request on a stream pair, from registration to completion.
pub async fn handle_get<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: GetRequest,
) -> anyhow::Result<()> {
    // register the request first; on rejection the stream is reset with
    // the error code before we bail out
    let res = pair.get_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    // after the request, a get is write-only
    let mut writer = pair.into_writer(tracker).await?;
    let res = handle_get_impl(store, request, &mut writer).await;
    // report completion or abort to the tracker
    handle_write_result(&mut writer, res).await?;
    Ok(())
}
470
/// Errors that can occur while handling a get many request.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
    /// Exporting bao-encoded data from the store failed.
    #[snafu(transparent)]
    ExportBao { source: ExportBaoError },
}
476
477impl HasErrorCode for HandleGetManyError {
478    fn code(&self) -> VarInt {
479        match self {
480            Self::ExportBao {
481                source: ExportBaoError::ClientError { source, .. },
482            } => source.code(),
483            _ => ERR_INTERNAL,
484        }
485    }
486}
487
488/// Handle a single get request.
489///
490/// Requires a database, the request, and a writer.
491async fn handle_get_many_impl<W: SendStream>(
492    store: Store,
493    request: GetManyRequest,
494    writer: &mut ProgressWriter<W>,
495) -> Result<(), HandleGetManyError> {
496    debug!("get_many received request");
497    let request_ranges = request.ranges.iter_infinite();
498    for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
499        if !ranges.is_empty() {
500            send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
501        }
502    }
503    Ok(())
504}
505
/// Handle a get many request on a stream pair, from registration to completion.
pub async fn handle_get_many<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: GetManyRequest,
) -> anyhow::Result<()> {
    // register the request first; on rejection the stream is reset with
    // the error code before we bail out
    let res = pair.get_many_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    // after the request, a get many is write-only
    let mut writer = pair.into_writer(tracker).await?;
    let res = handle_get_many_impl(store, request, &mut writer).await;
    // report completion or abort to the tracker
    handle_write_result(&mut writer, res).await?;
    Ok(())
}
518
/// Errors that can occur while handling a push request.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
    /// Exporting bao-encoded data from the store failed.
    #[snafu(transparent)]
    ExportBao {
        source: ExportBaoError,
    },

    /// The root blob could not be parsed as a hash sequence.
    InvalidHashSeq,

    /// A store request failed.
    #[snafu(transparent)]
    Request {
        source: RequestError,
    },
}
533
534impl HasErrorCode for HandlePushError {
535    fn code(&self) -> VarInt {
536        match self {
537            Self::ExportBao {
538                source: ExportBaoError::ClientError { source, .. },
539            } => source.code(),
540            _ => ERR_INTERNAL,
541        }
542    }
543}
544
/// Handle a single push request.
///
/// Requires a database, the request, and a reader.
async fn handle_push_impl<R: RecvStream>(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader<R>,
) -> Result<(), HandlePushError> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    let mut request_ranges = request.ranges.iter_infinite();
    // the first entry of the ranges refers to the root blob
    let root_ranges = request_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
        store
            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        // a blob request has no children, so we are done
        debug!("push request complete");
        return Ok(());
    }
    // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
    let hash_seq = store.get_bytes(hash).await?;
    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
    // import each child blob for which the client sent non-empty ranges
    for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
        if child_ranges.is_empty() {
            continue;
        }
        store
            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
            .await?;
    }
    Ok(())
}
580
/// Handle a push request on a stream pair, from registration to completion.
pub async fn handle_push<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: PushRequest,
) -> anyhow::Result<()> {
    // register the request first; on rejection the stream is reset with
    // the error code before we bail out
    let res = pair.push_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    // after the request, a push is read-only
    let mut reader = pair.into_reader(tracker).await?;
    let res = handle_push_impl(store, request, &mut reader).await;
    // report completion or abort to the tracker
    handle_read_result(&mut reader, res).await?;
    Ok(())
}
593
594/// Send a blob to the client.
595pub(crate) async fn send_blob<W: SendStream>(
596    store: &Store,
597    index: u64,
598    hash: Hash,
599    ranges: ChunkRanges,
600    writer: &mut ProgressWriter<W>,
601) -> ExportBaoResult<()> {
602    store
603        .export_bao(hash, ranges)
604        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
605        .await
606}
607
/// Errors that can occur while handling an observe request.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
    /// The local observe stream ended unexpectedly.
    ObserveStreamClosed,

    /// Writing to the remote failed.
    #[snafu(transparent)]
    RemoteClosed {
        source: io::Error,
    },
}
617
impl HasErrorCode for HandleObserveError {
    /// All observe errors are reported as internal errors.
    fn code(&self) -> VarInt {
        ERR_INTERNAL
    }
}
623
/// Handle a single observe request.
///
/// Requires a database, the request, and a writer. Streams bitfield updates
/// for the observed hash to the client until either side stops.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut stream = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
    let mut old = stream
        .next()
        .await
        .ok_or(HandleObserveError::ObserveStreamClosed)?;
    // send the initial bitfield
    send_observe_item(writer, &old).await?;
    // send updates until the remote loses interest
    loop {
        select! {
            new = stream.next() => {
                let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
                // only send the difference to the previously sent bitfield
                let diff = old.diff(&new);
                if diff.is_empty() {
                    continue;
                }
                send_observe_item(writer, &diff).await?;
                old = new;
            }
            // the remote stopping the stream ends the observation cleanly
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
663
664async fn send_observe_item<W: SendStream>(
665    writer: &mut ProgressWriter<W>,
666    item: &Bitfield,
667) -> io::Result<()> {
668    let item = ObserveItem::from(item);
669    let len = writer.inner.write_length_prefixed(item).await?;
670    writer.context.log_other_write(len);
671    Ok(())
672}
673
/// Handle an observe request on a stream pair, from registration to completion.
pub async fn handle_observe<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: ObserveRequest,
) -> anyhow::Result<()> {
    // register the request first; on rejection the stream is reset with
    // the error code before we bail out
    let res = pair.observe_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, res).await?;
    // after the request, an observe is write-only
    let mut writer = pair.into_writer(tracker).await?;
    let res = handle_observe_impl(store, request, &mut writer).await;
    // report completion or abort to the tracker
    handle_write_result(&mut writer, res).await?;
    Ok(())
}
686
/// Wrapper for a [`RecvStream`] with additional per request information.
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    /// The underlying receive stream to read from
    inner: R,
    /// Stats and progress tracking for the request
    context: ReaderContext,
}
691
692impl<R: RecvStream> ProgressReader<R> {
693    async fn transfer_aborted(&self) {
694        self.context
695            .tracker
696            .transfer_aborted(|| Box::new(self.context.stats()))
697            .await
698            .ok();
699    }
700
701    async fn transfer_completed(&self) {
702        self.context
703            .tracker
704            .transfer_completed(|| Box::new(self.context.stats()))
705            .await
706            .ok();
707    }
708}