use std::{future::Future, io, pin::Pin};
use anyhow::anyhow;
use bao_tree::{ChunkNum, ChunkRanges};
use futures_lite::StreamExt;
use genawaiter::{
rc::{Co, Gen},
GeneratorState,
};
use iroh::endpoint::Connection;
use iroh_io::AsyncSliceReader;
use tokio::sync::oneshot;
use tracing::trace;
use crate::{
get::{
self,
fsm::{AtBlobHeader, AtEndBlob, ConnectedNext, EndBlobNext},
progress::{BlobId, DownloadProgressEvent},
Error, Stats,
},
hashseq::parse_hash_seq,
protocol::{GetRequest, RangeSpec, RangeSpecSeq},
store::{
BaoBatchWriter, BaoBlobSize, FallibleProgressBatchWriter, MapEntry, MapEntryMut, MapMut,
Store as BaoStore,
},
util::progress::{IdGenerator, ProgressSender},
BlobFormat, Hash, HashAndFormat,
};
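/// A generator that yields at most once (to ask the caller for a connection)
/// and then runs the download to completion.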
type GetGenerator = Gen<Yield, (), Pin<Box<dyn Future<Output = Result<Stats, Error>>>>>;
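/// The boxed future that is driven inside a [`GetGenerator`].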
type GetFuture = Pin<Box<dyn Future<Output = Result<Stats, Error>> + 'static>>;
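/// Get a blob or collection into a store.
///
/// Data that is already present locally is taken into account: only the
/// missing ranges are requested, and `get_conn` is only invoked if a network
/// round trip is actually needed.
///
/// Progress is reported as [`DownloadProgressEvent`] through the given
/// [`ProgressSender`].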
pub async fn get_to_db<
D: BaoStore,
C: FnOnce() -> F,
F: Future<Output = anyhow::Result<Connection>>,
>(
db: &D,
get_conn: C,
hash_and_format: &HashAndFormat,
progress_sender: impl ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
) -> Result<Stats, Error> {
match get_to_db_in_steps(db.clone(), *hash_and_format, progress_sender).await? {
FetchState::Complete(res) => Ok(res),
FetchState::NeedsConn(state) => {
let conn = get_conn().await.map_err(Error::Io)?;
state.proceed(conn).await
}
}
}
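/// Get a blob or collection into a store, in resumable steps.
///
/// This checks the request against the local store first and returns
/// [`FetchState::Complete`] if everything is already present. Otherwise it
/// returns [`FetchState::NeedsConn`], and the caller decides whether and how
/// to open a connection before resuming.
///
/// A minimal sketch of the two-step flow, assuming `db`, `haf`, `progress`
/// and a hypothetical `make_conn()` dialer are in scope:
///
/// ```ignore
/// let stats = match get_to_db_in_steps(db, haf, progress).await? {
///     // everything was already available locally
///     FetchState::Complete(stats) => stats,
///     // data is missing: dial only now, then resume the download
///     FetchState::NeedsConn(state) => state.proceed(make_conn().await?).await?,
/// };
/// ```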
pub async fn get_to_db_in_steps<
D: BaoStore,
P: ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
>(
db: D,
hash_and_format: HashAndFormat,
progress_sender: P,
) -> Result<FetchState, Error> {
let mut gen: GetGenerator = genawaiter::rc::Gen::new(move |co| {
let fut = async move { producer(co, &db, &hash_and_format, progress_sender).await };
let fut: GetFuture = Box::pin(fut);
fut
});
match gen.async_resume().await {
GeneratorState::Yielded(Yield::NeedConn(reply)) => {
Ok(FetchState::NeedsConn(FetchStateNeedsConn(gen, reply)))
}
GeneratorState::Complete(res) => res.map(FetchState::Complete),
}
}
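/// Intermediary state returned from [`get_to_db_in_steps`] for a download
/// request that needs a connection to proceed.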
#[derive(derive_more::Debug)]
#[debug("FetchStateNeedsConn")]
pub struct FetchStateNeedsConn(GetGenerator, oneshot::Sender<Connection>);
impl FetchStateNeedsConn {
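    /// Proceed with the download once a connection is available.
    ///
    /// The generator yields `NeedConn` at most once, so after sending the
    /// connection the only remaining outcome is completion.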
pub async fn proceed(mut self, conn: Connection) -> Result<Stats, Error> {
self.1.send(conn).expect("receiver is not dropped");
match self.0.async_resume().await {
GeneratorState::Yielded(y) => match y {
Yield::NeedConn(_) => panic!("NeedsConn may only be yielded once"),
},
GeneratorState::Complete(res) => res,
}
}
}
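/// The outcome of checking a download request against the local store.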
#[derive(Debug)]
pub enum FetchState {
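    /// The requested data is already completely available locally.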
Complete(Stats),
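    /// The requested data is not fully available locally; a connection is
    /// needed to proceed.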
NeedsConn(FetchStateNeedsConn),
}
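/// Wrapper around the generator handle that turns the yield/resume pair into
/// an async request for a connection.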
struct GetCo(Co<Yield>);
impl GetCo {
async fn get_conn(&self) -> Connection {
let (tx, rx) = oneshot::channel();
self.0.yield_(Yield::NeedConn(tx)).await;
rx.await.expect("sender may not be dropped")
}
}
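/// The only message the producer yields: a request for a connection, answered
/// through the enclosed oneshot sender.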
enum Yield {
NeedConn(oneshot::Sender<Connection>),
}
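/// The actual implementation of the get operation, run inside the generator
/// so it can suspend while the caller provides a connection.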
async fn producer<D: BaoStore>(
co: Co<Yield, ()>,
db: &D,
hash_and_format: &HashAndFormat,
progress: impl ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
) -> Result<Stats, Error> {
let HashAndFormat { hash, format } = hash_and_format;
let co = GetCo(co);
match format {
BlobFormat::Raw => get_blob(db, co, hash, progress).await,
BlobFormat::HashSeq => get_hash_seq(db, co, hash, progress).await,
}
}
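/// Get a single blob by hash, taking locally available ranges into account.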
async fn get_blob<D: BaoStore>(
db: &D,
co: GetCo,
hash: &Hash,
progress: impl ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
) -> Result<Stats, Error> {
let end = match db.get_mut(hash).await? {
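        // we already have the complete blob, just report it and return early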
Some(entry) if entry.is_complete() => {
tracing::info!("already got entire blob");
progress
.send(DownloadProgressEvent::FoundLocal {
child: BlobId::Root,
hash: *hash,
size: entry.size(),
valid_ranges: RangeSpec::all(),
})
.await?;
return Ok(Stats::default());
}
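        // we have a partial entry, so request only the missing ranges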
Some(entry) => {
trace!("got partial data for {}", hash);
let valid_ranges = valid_ranges::<D>(&entry)
.await
.ok()
.unwrap_or_else(ChunkRanges::all);
progress
.send(DownloadProgressEvent::FoundLocal {
child: BlobId::Root,
hash: *hash,
size: entry.size(),
valid_ranges: RangeSpec::new(&valid_ranges),
})
.await?;
let required_ranges: ChunkRanges = ChunkRanges::all().difference(&valid_ranges);
let request = GetRequest::new(*hash, RangeSpecSeq::from_ranges([required_ranges]));
let conn = co.get_conn().await;
let request = get::fsm::start(conn, request);
let connected = request.next().await?;
let ConnectedNext::StartRoot(start) = connected.next().await? else {
return Err(Error::NoncompliantNode(anyhow!("expected StartRoot")));
};
let header = start.next();
get_blob_inner_partial(db, header, entry, progress).await?
}
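        // we have nothing local, so request the entire blob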
None => {
let conn = co.get_conn().await;
let request = get::fsm::start(conn, GetRequest::single(*hash));
let connected = request.next().await?;
let ConnectedNext::StartRoot(start) = connected.next().await? else {
return Err(Error::NoncompliantNode(anyhow!("expected StartRoot")));
};
let header = start.next();
get_blob_inner(db, header, progress).await?
}
};
    // we requested a single blob, so the only valid next state is Closing
    let EndBlobNext::Closing(end) = end.next() else {
        return Err(Error::NoncompliantNode(anyhow!("expected Closing")));
};
let stats = end.next().await?;
Ok(stats)
}
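/// Compute the valid ranges of a partial entry, based on the data and the
/// outboard that are present locally.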
pub async fn valid_ranges<D: MapMut>(entry: &D::EntryMut) -> anyhow::Result<ChunkRanges> {
use tracing::trace as log;
let mut data_reader = entry.data_reader().await?;
let data_size = data_reader.size().await?;
let valid_from_data = ChunkRanges::from(..ChunkNum::full_chunks(data_size));
let mut outboard = entry.outboard().await?;
let all = ChunkRanges::all();
let mut stream = bao_tree::io::fsm::valid_outboard_ranges(&mut outboard, &all);
let mut valid_from_outboard = ChunkRanges::empty();
while let Some(range) = stream.next().await {
valid_from_outboard |= ChunkRanges::from(range?);
}
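    // a chunk is only valid if it is covered by both the data and the outboard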
let valid: ChunkRanges = valid_from_data.intersection(&valid_from_outboard);
log!("valid_from_data: {:?}", valid_from_data);
log!("valid_from_outboard: {:?}", valid_from_data);
Ok(valid)
}
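/// Download a blob that we have no local data for, continuing a request that
/// is already past the blob header.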
async fn get_blob_inner<D: BaoStore>(
db: &D,
at_header: AtBlobHeader,
sender: impl ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
) -> Result<AtEndBlob, Error> {
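    // read the size; it is not yet verified, but the tree traversal below
    // reads at most `size` bytes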
let (at_content, size) = at_header.next().await?;
let hash = at_content.hash();
let child_offset = at_content.offset();
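    // get or create the partial entry for this blob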
let entry = db.get_or_create(hash, size).await?;
let bw = entry.batch_writer().await?;
let id = sender.new_id();
sender
.send(DownloadProgressEvent::Found {
id,
hash,
size,
child: BlobId::from_offset(child_offset),
})
.await?;
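    // forward write progress to the progress sender; if the receiver is gone,
    // abort the download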
let sender2 = sender.clone();
let on_write = move |offset: u64, _length: usize| {
sender2
.try_send(DownloadProgressEvent::Progress { id, offset })
.inspect_err(|_| {
tracing::info!("aborting download of {}", hash);
})?;
Ok(())
};
let mut bw = FallibleProgressBatchWriter::new(bw, on_write);
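    // write everything received for this blob through the batch writer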
let end = at_content.write_all_batch(&mut bw).await?;
bw.sync().await?;
drop(bw);
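    // all data has arrived and verified, so mark the entry as complete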
db.insert_complete(entry).await?;
sender.send(DownloadProgressEvent::Done { id }).await?;
Ok(end)
}
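/// Download the missing parts of a blob that we already have partially,
/// continuing a request that is already past the blob header.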
async fn get_blob_inner_partial<D: BaoStore>(
db: &D,
at_header: AtBlobHeader,
entry: D::EntryMut,
sender: impl ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
) -> Result<AtEndBlob, Error> {
let (at_content, size) = at_header.next().await?;
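    // reuse the batch writer of the existing partial entry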
let bw = entry.batch_writer().await?;
let id = sender.new_id();
let hash = at_content.hash();
let child_offset = at_content.offset();
sender
.send(DownloadProgressEvent::Found {
id,
hash,
size,
child: BlobId::from_offset(child_offset),
})
.await?;
let sender2 = sender.clone();
let on_write = move |offset: u64, _length: usize| {
sender2
.try_send(DownloadProgressEvent::Progress { id, offset })
.inspect_err(|_| {
tracing::info!("aborting download of {}", hash);
})?;
Ok(())
};
let mut bw = FallibleProgressBatchWriter::new(bw, on_write);
let at_end = at_content.write_all_batch(&mut bw).await?;
bw.sync().await?;
drop(bw);
db.insert_complete(entry).await?;
sender.send(DownloadProgressEvent::Done { id }).await?;
Ok(at_end)
}
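/// Compute the local availability of a single blob: complete, partial with
/// its valid ranges, or missing.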
pub async fn blob_info<D: BaoStore>(db: &D, hash: &Hash) -> io::Result<BlobInfo<D>> {
io::Result::Ok(match db.get_mut(hash).await? {
Some(entry) if entry.is_complete() => BlobInfo::Complete {
size: entry.size().value(),
},
Some(entry) => {
let valid_ranges = valid_ranges::<D>(&entry)
.await
.ok()
.unwrap_or_else(ChunkRanges::all);
BlobInfo::Partial {
entry,
valid_ranges,
}
}
None => BlobInfo::Missing,
})
}
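/// Compute the local availability of a sequence of blobs.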
async fn blob_infos<D: BaoStore>(db: &D, hash_seq: &[Hash]) -> io::Result<Vec<BlobInfo<D>>> {
let items = futures_lite::stream::iter(hash_seq)
.then(|hash| blob_info(db, hash))
.collect::<Vec<_>>();
items.await.into_iter().collect()
}
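/// Get a hash seq and all of its children, taking locally available data into
/// account.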
async fn get_hash_seq<D: BaoStore>(
db: &D,
co: GetCo,
root_hash: &Hash,
sender: impl ProgressSender<Msg = DownloadProgressEvent> + IdGenerator,
) -> Result<Stats, Error> {
use tracing::info as log;
let finishing = match db.get_mut(root_hash).await? {
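        // we have the complete hash seq itself, so we only need to fetch the
        // missing parts of its children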
Some(entry) if entry.is_complete() => {
log!("already got collection - doing partial download");
sender
.send(DownloadProgressEvent::FoundLocal {
child: BlobId::Root,
hash: *root_hash,
size: entry.size(),
valid_ranges: RangeSpec::all(),
})
.await?;
let reader = entry.data_reader().await?;
let (mut hash_seq, children) = parse_hash_seq(reader).await.map_err(|err| {
Error::NoncompliantNode(anyhow!("Failed to parse downloaded HashSeq: {err}"))
})?;
sender
.send(DownloadProgressEvent::FoundHashSeq {
hash: *root_hash,
children,
})
.await?;
let mut children: Vec<Hash> = vec![];
while let Some(hash) = hash_seq.next().await? {
children.push(hash);
}
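            // compute the local availability of every child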
let missing_info = blob_infos(db, &children).await?;
for (i, info) in missing_info.iter().enumerate() {
if let Some(size) = info.size() {
sender
.send(DownloadProgressEvent::FoundLocal {
child: BlobId::from_offset((i as u64) + 1),
hash: children[i],
size,
valid_ranges: RangeSpec::new(info.valid_ranges()),
})
.await?;
}
}
if missing_info
.iter()
.all(|x| matches!(x, BlobInfo::Complete { .. }))
{
log!("nothing to do");
return Ok(Stats::default());
}
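            // request nothing for the root (we already have it), and only the
            // missing ranges for each child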
let missing_iter = std::iter::once(ChunkRanges::empty())
.chain(missing_info.iter().map(|x| x.missing_ranges()))
.collect::<Vec<_>>();
log!("requesting chunks {:?}", missing_iter);
let request = GetRequest::new(*root_hash, RangeSpecSeq::from_ranges(missing_iter));
let conn = co.get_conn().await;
let request = get::fsm::start(conn, request);
let connected = request.next().await?;
log!("connected");
let ConnectedNext::StartChild(start) = connected.next().await? else {
return Err(Error::NoncompliantNode(anyhow!("expected StartChild")));
};
let mut next = EndBlobNext::MoreChildren(start);
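            // read the requested children as they arrive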
loop {
let start = match next {
EndBlobNext::MoreChildren(start) => start,
EndBlobNext::Closing(finish) => break finish,
};
let child_offset = usize::try_from(start.child_offset())
.map_err(|_| Error::NoncompliantNode(anyhow!("child offset too large")))?;
let (child_hash, info) =
match (children.get(child_offset), missing_info.get(child_offset)) {
(Some(blob), Some(info)) => (*blob, info),
_ => break start.finish(),
};
tracing::info!(
"requesting child {} {:?}",
child_hash,
info.missing_ranges()
);
let header = start.next(child_hash);
let end_blob = match info {
BlobInfo::Missing => get_blob_inner(db, header, sender.clone()).await?,
BlobInfo::Partial { entry, .. } => {
get_blob_inner_partial(db, header, entry.clone(), sender.clone()).await?
}
BlobInfo::Complete { .. } => {
return Err(Error::NoncompliantNode(anyhow!(
"got data we have not requested"
)));
}
};
next = end_blob.next();
}
}
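        // we don't have a complete hash seq, so do a full download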
_ => {
tracing::debug!("don't have collection - doing full download");
let conn = co.get_conn().await;
let request = get::fsm::start(conn, GetRequest::all(*root_hash));
let connected = request.next().await?;
let ConnectedNext::StartRoot(start) = connected.next().await? else {
return Err(Error::NoncompliantNode(anyhow!("expected StartRoot")));
};
let header = start.next();
let end_root = get_blob_inner(db, header, sender.clone()).await?;
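            // the root is now in the store; read it back to learn the children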
let entry = db
.get(root_hash)
.await?
.ok_or_else(|| Error::LocalFailure(anyhow!("just downloaded but not in db")))?;
let reader = entry.data_reader().await?;
let (mut collection, count) = parse_hash_seq(reader).await.map_err(|err| {
Error::NoncompliantNode(anyhow!("Failed to parse downloaded HashSeq: {err}"))
})?;
sender
.send(DownloadProgressEvent::FoundHashSeq {
hash: *root_hash,
children: count,
})
.await?;
let mut children = vec![];
while let Some(hash) = collection.next().await? {
children.push(hash);
}
let mut next = end_root.next();
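            // read all the children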
loop {
let start = match next {
EndBlobNext::MoreChildren(start) => start,
EndBlobNext::Closing(finish) => break finish,
};
let child_offset = usize::try_from(start.child_offset())
.map_err(|_| Error::NoncompliantNode(anyhow!("child offset too large")))?;
let child_hash = match children.get(child_offset) {
Some(blob) => *blob,
None => break start.finish(),
};
let header = start.next(child_hash);
let end_blob = get_blob_inner(db, header, sender.clone()).await?;
next = end_blob.next();
}
}
};
let stats = finishing.next().await?;
Ok(stats)
}
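/// Information about the local availability of a blob.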
#[derive(Debug, Clone)]
pub enum BlobInfo<D: BaoStore> {
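    /// The blob is completely available locally.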
Complete {
size: u64,
},
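    /// The blob is partially available locally.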
Partial {
entry: D::EntryMut,
valid_ranges: ChunkRanges,
},
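    /// The blob is not available locally at all.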
Missing,
}
impl<D: BaoStore> BlobInfo<D> {
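    /// The locally known size of the blob, if any. The size is verified for
    /// complete blobs and possibly unverified for partial ones.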
pub fn size(&self) -> Option<BaoBlobSize> {
match self {
BlobInfo::Complete { size } => Some(BaoBlobSize::Verified(*size)),
BlobInfo::Partial { entry, .. } => Some(entry.size()),
BlobInfo::Missing => None,
}
}
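    /// The chunk ranges that are valid locally.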
pub fn valid_ranges(&self) -> ChunkRanges {
match self {
BlobInfo::Complete { .. } => ChunkRanges::all(),
BlobInfo::Partial { valid_ranges, .. } => valid_ranges.clone(),
BlobInfo::Missing => ChunkRanges::empty(),
}
}
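    /// The chunk ranges that still need to be downloaded.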
pub fn missing_ranges(&self) -> ChunkRanges {
match self {
BlobInfo::Complete { .. } => ChunkRanges::empty(),
BlobInfo::Partial { valid_ranges, .. } => ChunkRanges::all().difference(valid_ranges),
BlobInfo::Missing => ChunkRanges::all(),
}
}
}