#![allow(missing_docs)]

use std::{collections::BTreeSet, fmt::Debug, ops::DerefMut, sync::Arc};

use anyhow::{bail, Result};
use futures_lite::future::Boxed as BoxedFuture;
use futures_util::future::BoxFuture;
use iroh::{endpoint::Connecting, protocol::ProtocolHandler, Endpoint, NodeAddr};
use serde::{Deserialize, Serialize};
use tracing::debug;

use crate::{
    downloader::Downloader,
    provider::EventSender,
    store::GcConfig,
    util::{
        local_pool::{self, LocalPoolHandle},
        SetTagOption,
    },
    BlobFormat, Hash,
};
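
/// A callback that can be registered with [`Blobs::add_protected`] to add hashes to
/// the set of blobs that are protected from garbage collection.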
pub type ProtectCb = Box<dyn Fn(&mut BTreeSet<Hash>) -> BoxFuture<()> + Send + Sync>;
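
/// State of the garbage collector for a [`Blobs`] instance.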
#[derive(derive_more::Debug)]
enum GcState {
    /// Garbage collection has not been started yet; protect callbacks can still be added.
    Initial(#[debug(skip)] Vec<ProtectCb>),
    /// Garbage collection is running; no further protect callbacks can be added.
    Started(#[allow(dead_code)] Option<local_pool::Run<()>>),
}

impl Default for GcState {
    fn default() -> Self {
        Self::Initial(Vec::new())
    }
}
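
/// Shared state behind a [`Blobs`] handle.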
#[derive(Debug)]
pub(crate) struct BlobsInner<S> {
    pub(crate) rt: LocalPoolHandle,
    pub(crate) store: S,
    events: EventSender,
    pub(crate) downloader: Downloader,
    pub(crate) endpoint: Endpoint,
    gc_state: std::sync::Mutex<GcState>,
    #[cfg(feature = "rpc")]
    pub(crate) batches: tokio::sync::Mutex<BlobBatches>,
}
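
/// The blobs protocol handler.
///
/// A cheaply cloneable handle around shared state. It is registered with an iroh
/// [`Endpoint`] through its [`ProtocolHandler`] implementation.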
#[derive(Debug, Clone)]
pub struct Blobs<S> {
    pub(crate) inner: Arc<BlobsInner<S>>,
    #[cfg(feature = "rpc")]
    pub(crate) rpc_handler: Arc<std::sync::OnceLock<crate::rpc::RpcHandler>>,
}
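
/// Keeps track of the currently active blob batches and the temp tags they hold.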
#[cfg(feature = "rpc")]
#[derive(Debug, Default)]
pub(crate) struct BlobBatches {
    /// Currently active batches, keyed by batch id.
    batches: std::collections::BTreeMap<BatchId, BlobBatch>,
    /// Monotonically increasing counter used to allocate new batch ids.
    max: u64,
}
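
/// A single batch, holding the temp tags for all content pinned by it.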
#[cfg(feature = "rpc")]
#[derive(Debug, Default)]
struct BlobBatch {
    tags: std::collections::BTreeMap<crate::HashAndFormat, Vec<crate::TempTag>>,
}

#[cfg(feature = "rpc")]
impl BlobBatches {
    /// Create a new batch and return its id.
    pub fn create(&mut self) -> BatchId {
        let id = self.max;
        self.max += 1;
        BatchId(id)
    }

    /// Store a temp tag in a batch, creating the batch if it does not exist yet.
    pub fn store(&mut self, batch: BatchId, tt: crate::TempTag) {
        let entry = self.batches.entry(batch).or_default();
        entry.tags.entry(tt.hash_and_format()).or_default().push(tt);
    }

    /// Remove a single temp tag for `content` from a batch.
    pub fn remove_one(&mut self, batch: BatchId, content: &crate::HashAndFormat) -> Result<()> {
        if let Some(batch) = self.batches.get_mut(&batch) {
            if let Some(tags) = batch.tags.get_mut(content) {
                tags.pop();
                if tags.is_empty() {
                    batch.tags.remove(content);
                }
                return Ok(());
            }
        }
        anyhow::bail!("tag not found in batch");
    }

    /// Remove an entire batch, dropping all temp tags it holds.
    pub fn remove(&mut self, batch: BatchId) {
        self.batches.remove(&batch);
    }
}
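
/// Builder for the blobs protocol handler.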
#[derive(Debug)]
pub struct Builder<S> {
    store: S,
    events: Option<EventSender>,
}

impl<S: crate::store::Store> Builder<S> {
    /// Set a custom [`EventSender`] for provider events.
    pub fn events(mut self, value: EventSender) -> Self {
        self.events = Some(value);
        self
    }

    /// Build the blobs protocol handler.
    ///
    /// Requires a local pool handle to spawn tasks on and the endpoint to use.
    pub fn build(self, rt: &LocalPoolHandle, endpoint: &Endpoint) -> Blobs<S> {
        let downloader = Downloader::new(self.store.clone(), endpoint.clone(), rt.clone());
        Blobs::new(
            self.store,
            rt.clone(),
            self.events.unwrap_or_default(),
            downloader,
            endpoint.clone(),
        )
    }
}

impl<S> Blobs<S> {
    /// Create a new [`Builder`] for the blobs protocol handler, using the given store.
    pub fn builder(store: S) -> Builder<S> {
        Builder {
            store,
            events: None,
        }
    }
}

impl Blobs<crate::store::mem::Store> {
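    /// Create a blobs protocol handler builder backed by an in-memory store.
    ///
    /// A minimal usage sketch; `rt` (a [`LocalPoolHandle`]) and `endpoint` (an iroh
    /// [`Endpoint`]) are assumed to be set up by the surrounding application:
    ///
    /// ```ignore
    /// let blobs = Blobs::memory().build(&rt, &endpoint);
    /// ```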
    pub fn memory() -> Builder<crate::store::mem::Store> {
        Self::builder(crate::store::mem::Store::new())
    }
}

impl Blobs<crate::store::fs::Store> {
    /// Create a blobs protocol handler builder backed by a persistent,
    /// file-system based store at the given path.
    pub async fn persistent(
        path: impl AsRef<std::path::Path>,
    ) -> anyhow::Result<Builder<crate::store::fs::Store>> {
        Ok(Self::builder(crate::store::fs::Store::load(path).await?))
    }
}

impl<S: crate::store::Store> Blobs<S> {
    /// Create a new blobs protocol handler from its parts.
    pub fn new(
        store: S,
        rt: LocalPoolHandle,
        events: EventSender,
        downloader: Downloader,
        endpoint: Endpoint,
    ) -> Self {
        Self {
            inner: Arc::new(BlobsInner {
                rt,
                store,
                events,
                downloader,
                endpoint,
                #[cfg(feature = "rpc")]
                batches: Default::default(),
                gc_state: Default::default(),
            }),
            #[cfg(feature = "rpc")]
            rpc_handler: Default::default(),
        }
    }

    /// Returns the store used by this protocol handler.
    pub fn store(&self) -> &S {
        &self.inner.store
    }

    /// Returns the sender for provider events.
    pub fn events(&self) -> &EventSender {
        &self.inner.events
    }

    /// Returns the local pool handle used to spawn tasks.
    pub fn rt(&self) -> &LocalPoolHandle {
        &self.inner.rt
    }

    /// Returns the downloader.
    pub fn downloader(&self) -> &Downloader {
        &self.inner.downloader
    }

    /// Returns the endpoint.
    pub fn endpoint(&self) -> &Endpoint {
        &self.inner.endpoint
    }

    /// Add a callback that protects additional hashes from garbage collection.
    ///
    /// All registered callbacks are queried by the garbage collector to build the
    /// set of hashes that must not be deleted. Returns an error if garbage
    /// collection has already been started.
    pub fn add_protected(&self, cb: ProtectCb) -> Result<()> {
        let mut state = self.inner.gc_state.lock().unwrap();
        match &mut *state {
            GcState::Initial(cbs) => {
                cbs.push(cb);
            }
            GcState::Started(_) => {
                anyhow::bail!("cannot add protected blobs after gc has started");
            }
        }
        Ok(())
    }

    /// Start garbage collection with the given configuration.
    ///
    /// Consumes all registered protect callbacks and spawns the gc task on the
    /// local pool. Returns an error if gc has already been started.
    pub fn start_gc(&self, config: GcConfig) -> Result<()> {
        let mut state = self.inner.gc_state.lock().unwrap();
        let protected = match state.deref_mut() {
            GcState::Initial(items) => std::mem::take(items),
            GcState::Started(_) => bail!("gc already started"),
        };
        let protected = Arc::new(protected);
        // Combine all protect callbacks into a single closure that collects the
        // full set of protected hashes for a gc run.
        let protected_cb = move || {
            let protected = protected.clone();
            async move {
                let mut set = BTreeSet::new();
                for cb in protected.iter() {
                    cb(&mut set).await;
                }
                set
            }
        };
        let store = self.store().clone();
        let run = self
            .rt()
            .spawn(move || async move { store.gc_run(config, protected_cb).await });
        *state = GcState::Started(Some(run));
        Ok(())
    }
}

impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
    fn accept(&self, conn: Connecting) -> BoxedFuture<Result<()>> {
        let db = self.store().clone();
        let events = self.events().clone();
        let rt = self.rt().clone();
        Box::pin(async move {
            crate::provider::handle_connection(conn.await?, db, events, rt).await;
            Ok(())
        })
    }

    fn shutdown(&self) -> BoxedFuture<()> {
        let store = self.store().clone();
        Box::pin(async move {
            store.shutdown().await;
        })
    }
}
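
/// A request to download a blob identified by its hash and format from a set of nodes.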
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlobDownloadRequest {
    /// The hash of the content to download.
    pub hash: Hash,
    /// The blob format of the content.
    pub format: BlobFormat,
    /// Nodes to download the content from.
    pub nodes: Vec<NodeAddr>,
    /// Tag to set for the downloaded content.
    pub tag: SetTagOption,
    /// Whether to start the download directly or add it to the download queue.
    pub mode: DownloadMode,
}
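
/// How a download should be executed.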
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DownloadMode {
    /// Download directly from the given nodes.
    Direct,
    /// Add the download to the download queue handled by the downloader.
    Queued,
}
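
/// Identifier of a blob batch.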
#[derive(Debug, PartialEq, Eq, PartialOrd, Serialize, Deserialize, Ord, Clone, Copy, Hash)]
pub struct BatchId(pub u64);