use std::{
fmt::Debug,
future::Future,
io,
time::{Duration, Instant},
};
use anyhow::Result;
use bao_tree::ChunkRanges;
use iroh::endpoint;
use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
use n0_future::StreamExt;
use quinn::{ConnectionError, VarInt};
use serde::{Deserialize, Serialize};
use snafu::Snafu;
use tokio::select;
use tracing::{debug, debug_span, warn, Instrument};
use crate::{
api::{
blobs::{Bitfield, WriteProgress},
ExportBaoError, ExportBaoResult, RequestError, Store,
},
hashseq::HashSeq,
protocol::{
GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
},
provider::events::{
ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
RequestTracker,
},
util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
Hash,
};
pub mod events;
use events::EventSender;
/// Writer half used when `StreamPair` / `ProgressWriter` are not given an explicit
/// type parameter: an iroh endpoint send stream.
type DefaultWriter = iroh::endpoint::SendStream;
/// Reader half used when `StreamPair` / `ProgressReader` are not given an explicit
/// type parameter: an iroh endpoint receive stream.
type DefaultReader = iroh::endpoint::RecvStream;
/// Byte counters and elapsed time collected while handling one request stream.
///
/// Values are produced by the `stats()` helpers of `StreamPair`, `ReaderContext`
/// and `WriterContext` in this module.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferStats {
/// Payload bytes reported through `notify_payload_write`; always 0 in stats
/// produced by `StreamPair` and `ReaderContext`.
pub payload_bytes_sent: u64,
/// Non-payload bytes reported through `log_other_write`; always 0 in stats
/// produced by `StreamPair` and `ReaderContext`.
pub other_bytes_sent: u64,
/// Bytes consumed while reading the request from the stream.
pub other_bytes_read: u64,
/// Time elapsed since the `StreamPair` for this stream was created.
pub duration: Duration,
}
/// Both halves of one request stream plus the bookkeeping needed to report on it.
///
/// A pair is first used to read the request, then converted into either a
/// `ProgressWriter` (`into_writer`) or a `ProgressReader` (`into_reader`).
#[derive(Debug)]
pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
// Creation time of the pair; basis for `TransferStats::duration`.
t0: Instant,
// Id of the connection this stream belongs to (passed to `EventSender::request`).
connection_id: u64,
// Receiving half; also the source of the stream id.
reader: R,
// Sending half.
writer: W,
// Number of bytes consumed so far by `read_request`.
other_bytes_read: u64,
// Event sender used to announce requests made on this stream.
events: EventSender,
}
impl StreamPair {
    /// Accepts the next bidirectional stream on `conn` and wraps it in a `StreamPair`.
    ///
    /// The connection's `stable_id()` is stored as the pair's connection id and
    /// `events` is kept for announcing the request made on this stream.
    ///
    /// # Errors
    ///
    /// Returns the [`ConnectionError`] produced by `accept_bi`.
    pub async fn accept(
        conn: &endpoint::Connection,
        events: EventSender,
    ) -> Result<Self, ConnectionError> {
        let (send, recv) = conn.accept_bi().await?;
        let connection_id = conn.stable_id() as u64;
        Ok(Self::new(connection_id, recv, send, events))
    }
}
impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
    /// Id of this stream, as reported by the reader half.
    pub fn stream_id(&self) -> u64 {
        self.reader.id()
    }

    /// Assembles a pair from its parts.
    ///
    /// The start time is taken now and the read counter starts at zero.
    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
        let t0 = Instant::now();
        Self {
            t0,
            connection_id,
            reader,
            writer,
            other_bytes_read: 0,
            events,
        }
    }

    /// Reads one [`Request`] from the reader half and adds the number of bytes it
    /// occupied to the `other_bytes_read` counter.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Request::read_async`.
    pub async fn read_request(&mut self) -> Result<Request> {
        let (request, request_len) = Request::read_async(&mut self.reader).await?;
        self.other_bytes_read += request_len as u64;
        Ok(request)
    }

    /// Finishes the reading phase and turns the pair into a [`ProgressWriter`].
    ///
    /// Calls `expect_eof()` on the reader, drops it, and carries the start time and
    /// read counter over into a fresh `WriterContext` with zeroed write counters.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` from `expect_eof()`.
    pub async fn into_writer(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressWriter<W>, io::Error> {
        self.reader.expect_eof().await?;
        drop(self.reader);
        let context = WriterContext {
            t0: self.t0,
            other_bytes_read: self.other_bytes_read,
            payload_bytes_written: 0,
            other_bytes_written: 0,
            tracker,
        };
        Ok(ProgressWriter::new(self.writer, context))
    }

    /// Finishes the writing side and turns the pair into a [`ProgressReader`].
    ///
    /// Calls `sync()` on the writer, drops it, and carries the start time and read
    /// counter over into a `ReaderContext`.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` from `sync()`.
    pub async fn into_reader(
        mut self,
        tracker: RequestTracker,
    ) -> Result<ProgressReader<R>, io::Error> {
        self.writer.sync().await?;
        drop(self.writer);
        let context = ReaderContext {
            t0: self.t0,
            other_bytes_read: self.other_bytes_read,
            tracker,
        };
        Ok(ProgressReader {
            inner: self.reader,
            context,
        })
    }

    /// Announces a get request for this connection/stream on the event sender.
    ///
    /// `f` supplies the request; the returned [`RequestTracker`] is used for
    /// progress reporting, a [`ProgressError`] means the request was not accepted.
    pub async fn get_request(
        &self,
        f: impl FnOnce() -> GetRequest,
    ) -> Result<RequestTracker, ProgressError> {
        let stream_id = self.stream_id();
        self.events.request(f, self.connection_id, stream_id).await
    }

    /// Announces a get-many request for this connection/stream on the event sender.
    ///
    /// Same contract as `get_request`, for [`GetManyRequest`].
    pub async fn get_many_request(
        &self,
        f: impl FnOnce() -> GetManyRequest,
    ) -> Result<RequestTracker, ProgressError> {
        let stream_id = self.stream_id();
        self.events.request(f, self.connection_id, stream_id).await
    }

    /// Announces a push request for this connection/stream on the event sender.
    ///
    /// Same contract as `get_request`, for [`PushRequest`].
    pub async fn push_request(
        &self,
        f: impl FnOnce() -> PushRequest,
    ) -> Result<RequestTracker, ProgressError> {
        let stream_id = self.stream_id();
        self.events.request(f, self.connection_id, stream_id).await
    }

    /// Announces an observe request for this connection/stream on the event sender.
    ///
    /// Same contract as `get_request`, for [`ObserveRequest`].
    pub async fn observe_request(
        &self,
        f: impl FnOnce() -> ObserveRequest,
    ) -> Result<RequestTracker, ProgressError> {
        let stream_id = self.stream_id();
        self.events.request(f, self.connection_id, stream_id).await
    }

    /// Statistics for a stream on which nothing has been sent yet: the send counters
    /// are zero, the read counter and elapsed time come from this pair.
    pub fn stats(&self) -> TransferStats {
        let duration = self.t0.elapsed();
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: self.other_bytes_read,
            duration,
        }
    }
}
/// Bookkeeping carried by a `ProgressReader` (created in `StreamPair::into_reader`).
#[derive(Debug)]
struct ReaderContext {
// Creation time of the originating `StreamPair`.
t0: Instant,
// Bytes consumed while reading the request.
other_bytes_read: u64,
// Tracker used to report completion or abort of the transfer.
tracker: RequestTracker,
}
impl ReaderContext {
    /// Statistics for the reading side: both send counters are reported as zero,
    /// the read counter and elapsed time come from this context.
    fn stats(&self) -> TransferStats {
        let Self {
            t0,
            other_bytes_read,
            ..
        } = self;
        TransferStats {
            payload_bytes_sent: 0,
            other_bytes_sent: 0,
            other_bytes_read: *other_bytes_read,
            duration: t0.elapsed(),
        }
    }
}
/// Bookkeeping carried by a `ProgressWriter` (created in `StreamPair::into_writer`).
///
/// Implements `WriteProgress`, which is how the write counters get updated.
#[derive(Debug)]
pub(crate) struct WriterContext {
// Creation time of the originating `StreamPair`.
t0: Instant,
// Bytes consumed while reading the request.
other_bytes_read: u64,
// Sum of the lengths passed to `notify_payload_write`.
payload_bytes_written: u64,
// Sum of the lengths passed to `log_other_write`.
other_bytes_written: u64,
// Tracker used to report start, progress, completion or abort of the transfer.
tracker: RequestTracker,
}
impl WriterContext {
    /// Snapshot of all counters of this context together with the time elapsed
    /// since the stream pair was created.
    fn stats(&self) -> TransferStats {
        let Self {
            t0,
            other_bytes_read,
            payload_bytes_written,
            other_bytes_written,
            ..
        } = self;
        TransferStats {
            payload_bytes_sent: *payload_bytes_written,
            other_bytes_sent: *other_bytes_written,
            other_bytes_read: *other_bytes_read,
            duration: t0.elapsed(),
        }
    }
}
impl WriteProgress for WriterContext {
    /// Records `len` payload bytes and forwards the progress to the tracker.
    ///
    /// The tracker receives the chunk length and the end offset (`offset + len`);
    /// its result is returned unchanged. `_index` is not used.
    async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult {
        let written = len as u64;
        let end = offset + written;
        self.payload_bytes_written += written;
        self.tracker.transfer_progress(written, end).await
    }

    /// Records `len` non-payload bytes; nothing is forwarded to the tracker.
    fn log_other_write(&mut self, len: usize) {
        self.other_bytes_written += len as u64;
    }

    /// Tells the tracker that the transfer of blob `index` has started.
    ///
    /// Best effort: the tracker's result is deliberately ignored.
    async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) {
        let _ = self.tracker.transfer_started(index, hash, size).await;
    }
}
/// Sending half of a request stream together with the context that counts what is
/// written and reports progress.
#[derive(Debug)]
pub struct ProgressWriter<W: SendStream = DefaultWriter> {
/// The underlying send stream.
pub inner: W,
// Counters and tracker; also passed as the `WriteProgress` sink in `send_blob`.
pub(crate) context: WriterContext,
}
impl<W: SendStream> ProgressWriter<W> {
fn new(inner: W, context: WriterContext) -> Self {
Self { inner, context }
}
async fn transfer_aborted(&self) {
self.context
.tracker
.transfer_aborted(|| Box::new(self.context.stats()))
.await
.ok();
}
async fn transfer_completed(&self) {
self.context
.tracker
.transfer_completed(|| Box::new(self.context.stats()))
.await
.ok();
}
}
/// Serves a single incoming connection until no more streams can be accepted.
///
/// Steps, all executed inside a `"connection"` debug span:
/// 1. Resolve the remote node id; if that fails, log a warning and return.
/// 2. Report `ClientConnected` via `progress`; on error the connection is closed
///    with the error's code and reason, and the function returns.
/// 3. For every accepted bidirectional stream, spawn a task running
///    [`handle_stream`]. The join handle is dropped, so the handler's result is
///    not observed here.
/// 4. Once accepting fails, report `ConnectionClosed` (result ignored).
pub async fn handle_connection(
    connection: endpoint::Connection,
    store: Store,
    progress: EventSender,
) {
    let connection_id = connection.stable_id() as u64;
    let span = debug_span!("connection", connection_id);
    let serve = async move {
        let node_id = match connection.remote_node_id() {
            Ok(node_id) => node_id,
            Err(_) => {
                warn!("failed to get node id");
                return;
            }
        };
        if let Err(cause) = progress
            .client_connected(|| ClientConnected {
                connection_id,
                node_id,
            })
            .await
        {
            connection.close(cause.code(), cause.reason());
            debug!("closing connection: {cause}");
            return;
        }
        loop {
            let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await else {
                break;
            };
            let stream_span = debug_span!("stream", stream_id = %pair.stream_id());
            tokio::spawn(handle_stream(pair, store.clone()).instrument(stream_span));
        }
        progress
            .connection_closed(|| ConnectionClosed { connection_id })
            .await
            .ok();
    };
    serve.instrument(span).await
}
/// Hooks for signalling an error code on the reader or writer half of a stream.
///
/// NOTE(review): no implementor or caller is visible in this part of the file; the
/// method notes below are inferred from the signatures only and should be confirmed.
pub trait ErrorHandler {
/// Writer half type.
type W: AsyncStreamWriter;
/// Reader half type.
type R: AsyncStreamReader;
/// Presumably stops `reader` with the given error `code` — TODO confirm.
fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
/// Presumably resets `writer` with the given error `code` — TODO confirm.
fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
}
/// Passes `r` through unchanged; if it is an error, first resets the pair's writer
/// with the error's code (the outcome of the reset is ignored).
///
/// Used right after announcing a request, so a rejected request is signalled to the
/// peer on the send stream.
async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
    pair: &mut StreamPair<R, W>,
    r: Result<T, E>,
) -> Result<T, E> {
    if let Err(e) = &r {
        pair.writer.reset(e.code()).ok();
    }
    r
}
/// Finalises a sending request and passes `r` through unchanged.
///
/// * `Ok`: reports the transfer as completed.
/// * `Err`: resets the send stream with the error's code (outcome ignored), then
///   reports the transfer as aborted.
async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
    writer: &mut ProgressWriter<W>,
    r: Result<T, E>,
) -> Result<T, E> {
    let err = match r {
        Ok(value) => {
            writer.transfer_completed().await;
            return Ok(value);
        }
        Err(err) => err,
    };
    writer.inner.reset(err.code()).ok();
    writer.transfer_aborted().await;
    Err(err)
}
/// Finalises a receiving request and passes `r` through unchanged.
///
/// * `Ok`: reports the transfer as completed.
/// * `Err`: stops the receive stream with the error's code (outcome ignored), then
///   reports the transfer as aborted.
async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
    reader: &mut ProgressReader<R>,
    r: Result<T, E>,
) -> Result<T, E> {
    let err = match r {
        Ok(value) => {
            reader.transfer_completed().await;
            return Ok(value);
        }
        Err(err) => err,
    };
    reader.inner.stop(err.code()).ok();
    reader.transfer_aborted().await;
    Err(err)
}
/// Reads one request from the stream pair and dispatches it to the matching handler.
///
/// Get, get-many, observe and push requests are handled; any other request variant
/// is ignored and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the error from reading the request or from the selected handler.
pub async fn handle_stream<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
) -> anyhow::Result<()> {
    let request = pair.read_request().await?;
    match request {
        Request::Get(get) => handle_get(pair, store, get).await,
        Request::GetMany(get_many) => handle_get_many(pair, store, get_many).await,
        Request::Observe(observe) => handle_observe(pair, store, observe).await,
        Request::Push(push) => handle_push(pair, store, push).await,
        _ => Ok(()),
    }
}
/// Errors produced while serving a get request (see `handle_get_impl`).
// NOTE: the per-variant notes are plain `//` comments on purpose. SNAFU can use a
// variant's doc comment as its `Display` text, so `///` there could change runtime
// output.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum HandleGetError {
// Export/store error, forwarded transparently; also used for a failed final
// `sync()` of the writer.
#[snafu(transparent)]
ExportBao {
source: ExportBaoError,
},
// The bytes of the root blob could not be converted into a `HashSeq`.
InvalidHashSeq,
// A child offset (offset - 1) did not fit into `usize`.
InvalidOffset,
}
impl HasErrorCode for HandleGetError {
    /// Error code to signal on the stream for this error.
    ///
    /// An `ExportBaoError::Progress` uses the code of its source; every other
    /// error maps to `ERR_INTERNAL`.
    fn code(&self) -> VarInt {
        if let Self::ExportBao {
            source: ExportBaoError::Progress { source, .. },
        } = self
        {
            source.code()
        } else {
            ERR_INTERNAL
        }
    }
}
async fn handle_get_impl<W: SendStream>(
store: Store,
request: GetRequest,
writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetError> {
let hash = request.hash;
debug!(%hash, "get received request");
let mut hash_seq = None;
for (offset, ranges) in request.ranges.iter_non_empty_infinite() {
if offset == 0 {
send_blob(&store, offset, hash, ranges.clone(), writer).await?;
} else {
let hash_seq = match &hash_seq {
Some(b) => b,
None => {
let bytes = store.get_bytes(hash).await?;
let hs =
HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
hash_seq = Some(hs);
hash_seq.as_ref().unwrap()
}
};
let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
let Some(hash) = hash_seq.get(o) else {
break;
};
send_blob(&store, offset, hash, ranges.clone(), writer).await?;
}
}
writer
.inner
.sync()
.await
.map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
Ok(())
}
/// Handles a get request end to end.
///
/// Announces the request (resetting the writer if it is rejected), switches the pair
/// into writing mode, sends the data and finally reports completion or abort.
///
/// # Errors
///
/// Returns the first error of any of these steps, converted into `anyhow::Error`.
pub async fn handle_get<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: GetRequest,
) -> anyhow::Result<()> {
    let admission = pair.get_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, admission).await?;
    let mut writer = pair.into_writer(tracker).await?;
    let outcome = handle_get_impl(store, request, &mut writer).await;
    handle_write_result(&mut writer, outcome).await?;
    Ok(())
}
/// Errors produced while serving a get-many request (see `handle_get_many_impl`).
// NOTE: the variant note is a plain `//` comment on purpose. SNAFU can use a
// variant's doc comment as its `Display` text, so `///` there could change runtime
// output.
#[derive(Debug, Snafu)]
pub enum HandleGetManyError {
// Export/store error, forwarded transparently.
#[snafu(transparent)]
ExportBao { source: ExportBaoError },
}
impl HasErrorCode for HandleGetManyError {
    /// Error code to signal on the stream for this error.
    ///
    /// An `ExportBaoError::Progress` uses the code of its source; every other
    /// error maps to `ERR_INTERNAL`.
    fn code(&self) -> VarInt {
        if let Self::ExportBao {
            source: ExportBaoError::Progress { source, .. },
        } = self
        {
            source.code()
        } else {
            ERR_INTERNAL
        }
    }
}
/// Sends the requested ranges of every hash listed in a get-many request.
///
/// The hashes are paired positionally with the request's range sets; the position is
/// passed to `send_blob` as the blob index. Entries with an empty range set are
/// skipped.
///
/// NOTE(review): unlike `handle_get_impl`, no `writer.inner.sync()` is issued after
/// the last blob — confirm that this is intentional.
///
/// # Errors
///
/// Returns the first error produced by `send_blob`.
async fn handle_get_many_impl<W: SendStream>(
    store: Store,
    request: GetManyRequest,
    writer: &mut ProgressWriter<W>,
) -> Result<(), HandleGetManyError> {
    debug!("get_many received request");
    let per_child = request
        .hashes
        .iter()
        .zip(request.ranges.iter_infinite())
        .enumerate();
    for (child, (hash, ranges)) in per_child {
        if ranges.is_empty() {
            continue;
        }
        send_blob(&store, child as u64, *hash, ranges.clone(), writer).await?;
    }
    Ok(())
}
/// Handles a get-many request end to end.
///
/// Announces the request (resetting the writer if it is rejected), switches the pair
/// into writing mode, sends the data and finally reports completion or abort.
///
/// # Errors
///
/// Returns the first error of any of these steps, converted into `anyhow::Error`.
pub async fn handle_get_many<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: GetManyRequest,
) -> anyhow::Result<()> {
    let admission = pair.get_many_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, admission).await?;
    let mut writer = pair.into_writer(tracker).await?;
    let outcome = handle_get_many_impl(store, request, &mut writer).await;
    handle_write_result(&mut writer, outcome).await?;
    Ok(())
}
/// Errors produced while serving a push request (see `handle_push_impl`).
// NOTE: the per-variant notes are plain `//` comments on purpose. SNAFU can use a
// variant's doc comment as its `Display` text, so `///` there could change runtime
// output.
#[derive(Debug, Snafu)]
pub enum HandlePushError {
// Wraps an `ExportBaoError`, forwarded transparently.
#[snafu(transparent)]
ExportBao {
source: ExportBaoError,
},
// The bytes stored under the pushed root hash could not be converted into a `HashSeq`.
InvalidHashSeq,
// Wraps a `RequestError`, forwarded transparently.
#[snafu(transparent)]
Request {
source: RequestError,
},
}
impl HasErrorCode for HandlePushError {
    /// Error code to signal on the stream for this error.
    ///
    /// An `ExportBaoError::Progress` uses the code of its source; every other
    /// error maps to `ERR_INTERNAL`.
    fn code(&self) -> VarInt {
        if let Self::ExportBao {
            source: ExportBaoError::Progress { source, .. },
        } = self
        {
            source.code()
        } else {
            ERR_INTERNAL
        }
    }
}
/// Imports the data of a push request from `reader` into the store.
///
/// The first range set belongs to the root blob and is imported if non-empty. If the
/// request only covers a single blob, that is all. Otherwise the root blob is read
/// back from the store as a hash sequence, and each child with a non-empty range set
/// is imported in order.
///
/// # Errors
///
/// * `InvalidHashSeq` if the stored root bytes are not a valid `HashSeq`.
/// * Any error of the store calls, converted into `HandlePushError`.
///
/// # Panics
///
/// Panics if the request's range iterator yields no first element; the `expect`
/// message states that it is an infinite iterator.
async fn handle_push_impl<R: RecvStream>(
    store: Store,
    request: PushRequest,
    reader: &mut ProgressReader<R>,
) -> Result<(), HandlePushError> {
    let hash = request.hash;
    debug!(%hash, "push received request");
    let mut remaining_ranges = request.ranges.iter_infinite();
    let root_ranges = remaining_ranges.next().expect("infinite iterator");
    if !root_ranges.is_empty() {
        store
            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
            .await?;
    }
    if request.ranges.is_blob() {
        debug!("push request complete");
        return Ok(());
    }
    // The root has to be readable from the store here in order to find the children.
    let root_bytes = store.get_bytes(hash).await?;
    let children = HashSeq::try_from(root_bytes).map_err(|_| HandlePushError::InvalidHashSeq)?;
    for (child_hash, child_ranges) in children.into_iter().zip(remaining_ranges) {
        if !child_ranges.is_empty() {
            store
                .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
                .await?;
        }
    }
    Ok(())
}
/// Handles a push request end to end.
///
/// Announces the request (resetting the writer if it is rejected), switches the pair
/// into reading mode, imports the data and finally reports completion or abort.
///
/// # Errors
///
/// Returns the first error of any of these steps, converted into `anyhow::Error`.
pub async fn handle_push<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: PushRequest,
) -> anyhow::Result<()> {
    let admission = pair.push_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, admission).await?;
    let mut reader = pair.into_reader(tracker).await?;
    let outcome = handle_push_impl(store, request, &mut reader).await;
    handle_read_result(&mut reader, outcome).await?;
    Ok(())
}
pub(crate) async fn send_blob<W: SendStream>(
store: &Store,
index: u64,
hash: Hash,
ranges: ChunkRanges,
writer: &mut ProgressWriter<W>,
) -> ExportBaoResult<()> {
store
.export_bao(hash, ranges)
.write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
.await
}
/// Errors produced while serving an observe request (see `handle_observe_impl`).
// NOTE: the per-variant notes are plain `//` comments on purpose. SNAFU can use a
// variant's doc comment as its `Display` text, so `///` there could change runtime
// output.
#[derive(Debug, Snafu)]
pub enum HandleObserveError {
// The store's observe stream could not be created or ended.
ObserveStreamClosed,
// I/O error, forwarded transparently (e.g. from writing an observe item).
#[snafu(transparent)]
RemoteClosed {
source: io::Error,
},
}
impl HasErrorCode for HandleObserveError {
    /// Every observe error is signalled as `ERR_INTERNAL`.
    fn code(&self) -> VarInt {
        match self {
            Self::ObserveStreamClosed | Self::RemoteClosed { .. } => ERR_INTERNAL,
        }
    }
}
/// Streams bitfield updates for `request.hash` to the peer until it stops the stream.
///
/// The first item from the store's observe stream is sent as is. Afterwards only the
/// difference to the last accepted state is sent; updates with an empty difference
/// are skipped and do not replace that state. The loop ends with `Ok(())` once
/// `writer.inner.stopped()` completes.
///
/// # Errors
///
/// * `ObserveStreamClosed` if the observe stream cannot be created or ends.
/// * An I/O error from writing an item, converted into `HandleObserveError`.
async fn handle_observe_impl<W: SendStream>(
    store: Store,
    request: ObserveRequest,
    writer: &mut ProgressWriter<W>,
) -> std::result::Result<(), HandleObserveError> {
    let mut updates = store
        .observe(request.hash)
        .stream()
        .await
        .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
    let Some(mut last_sent) = updates.next().await else {
        return Err(HandleObserveError::ObserveStreamClosed);
    };
    // The peer gets the full initial state first.
    send_observe_item(writer, &last_sent).await?;
    loop {
        select! {
            update = updates.next() => {
                let Some(current) = update else {
                    return Err(HandleObserveError::ObserveStreamClosed);
                };
                let diff = last_sent.diff(&current);
                if !diff.is_empty() {
                    send_observe_item(writer, &diff).await?;
                    last_sent = current;
                }
            }
            _ = writer.inner.stopped() => {
                debug!("observer closed");
                break;
            }
        }
    }
    Ok(())
}
/// Converts `item` into an `ObserveItem`, writes it length-prefixed to the send
/// stream and adds the returned length to the non-payload byte counter.
///
/// # Errors
///
/// Returns the I/O error from the write; the counter is not touched in that case.
async fn send_observe_item<W: SendStream>(
    writer: &mut ProgressWriter<W>,
    item: &Bitfield,
) -> io::Result<()> {
    let written = writer
        .inner
        .write_length_prefixed(ObserveItem::from(item))
        .await?;
    writer.context.log_other_write(written);
    Ok(())
}
/// Handles an observe request end to end.
///
/// Announces the request (resetting the writer if it is rejected), switches the pair
/// into writing mode, streams the updates and finally reports completion or abort.
///
/// # Errors
///
/// Returns the first error of any of these steps, converted into `anyhow::Error`.
pub async fn handle_observe<R: RecvStream, W: SendStream>(
    mut pair: StreamPair<R, W>,
    store: Store,
    request: ObserveRequest,
) -> anyhow::Result<()> {
    let admission = pair.observe_request(|| request.clone()).await;
    let tracker = handle_read_request_result(&mut pair, admission).await?;
    let mut writer = pair.into_writer(tracker).await?;
    let outcome = handle_observe_impl(store, request, &mut writer).await;
    handle_write_result(&mut writer, outcome).await?;
    Ok(())
}
/// Receiving half of a request stream together with the context used to report the
/// outcome of the transfer.
///
/// Created by `StreamPair::into_reader`. Derives `Debug` like the sibling
/// `ProgressWriter`, so the public type can be logged and embedded in `Debug` types.
#[derive(Debug)]
pub struct ProgressReader<R: RecvStream = DefaultReader> {
    // The underlying receive stream.
    inner: R,
    // Start time, read counter and tracker carried over from the stream pair.
    context: ReaderContext,
}
impl<R: RecvStream> ProgressReader<R> {
    /// Reports the transfer as aborted, handing the tracker a closure that yields
    /// the current stats. Best effort: the tracker's result is ignored.
    async fn transfer_aborted(&self) {
        let ctx = &self.context;
        let _ = ctx.tracker.transfer_aborted(|| Box::new(ctx.stats())).await;
    }

    /// Reports the transfer as completed, handing the tracker a closure that yields
    /// the current stats. Best effort: the tracker's result is ignored.
    async fn transfer_completed(&self) {
        let ctx = &self.context;
        let _ = ctx
            .tracker
            .transfer_completed(|| Box::new(ctx.stats()))
            .await;
    }
}