use std::{
path::{Path, PathBuf},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use anyhow::{anyhow, Context as _, Result};
use bytes::Bytes;
use derive_more::{Display, FromStr};
use futures_lite::{Stream, StreamExt};
use iroh::NodeAddr;
use iroh_blobs::{export::ExportProgress, store::ExportMode, Hash};
use portable_atomic::{AtomicBool, Ordering};
use quic_rpc::{
client::BoxedConnector, message::RpcMsg, transport::flume::FlumeConnector, Connector,
};
use serde::{Deserialize, Serialize};
use super::{authors, flatten};
use crate::{
actor::OpenState,
rpc::{
proto::{
CloseRequest, CreateRequest, DelRequest, DelResponse, DocListRequest,
DocSubscribeRequest, DropRequest, ExportFileRequest, GetDownloadPolicyRequest,
GetExactRequest, GetManyRequest, GetSyncPeersRequest, ImportFileRequest, ImportRequest,
LeaveRequest, OpenRequest, RpcService, SetDownloadPolicyRequest, SetHashRequest,
SetRequest, ShareRequest, StartSyncRequest, StatusRequest,
},
AddrInfoOptions,
},
store::{DownloadPolicy, Query},
AuthorId, Capability, CapabilityKind, DocTicket, NamespaceId, PeerIdBytes,
};
#[doc(inline)]
pub use crate::{
engine::{LiveEvent, Origin, SyncEvent, SyncReason},
Entry,
};
/// Docs client whose RPC transport is an in-memory [`FlumeConnector`] channel.
pub type MemClient =
    Client<FlumeConnector<crate::rpc::proto::Response, crate::rpc::proto::Request>>;
/// Client for the docs RPC service.
///
/// `C` is the RPC transport connector; it defaults to a [`BoxedConnector`].
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct Client<C = BoxedConnector<RpcService>> {
    /// Underlying RPC client; visible to the parent module.
    pub(super) rpc: quic_rpc::RpcClient<RpcService, C>,
}
impl<C: Connector<RpcService>> Client<C> {
    /// Creates a docs client from an RPC client.
    pub fn new(rpc: quic_rpc::RpcClient<RpcService, C>) -> Self {
        Self { rpc }
    }

    /// Returns an authors client built on a clone of this client's RPC handle.
    pub fn authors(&self) -> authors::Client<C> {
        authors::Client::new(self.rpc.clone())
    }

    /// Creates a new document and returns a handle to it.
    pub async fn create(&self) -> Result<Doc<C>> {
        let res = self.rpc.rpc(CreateRequest {}).await??;
        Ok(Doc::new(self.rpc.clone(), res.id))
    }

    /// Sends a drop request for the document `doc_id` to the node.
    ///
    /// NOTE(review): presumably this deletes the document's data on the node —
    /// confirm against the `DropRequest` handler.
    pub async fn drop_doc(&self, doc_id: NamespaceId) -> Result<()> {
        self.rpc.rpc(DropRequest { doc_id }).await??;
        Ok(())
    }

    /// Imports a document from a namespace capability and returns a handle to it.
    ///
    /// No sync is started; see [`Client::import`] for that.
    pub async fn import_namespace(&self, capability: Capability) -> Result<Doc<C>> {
        let res = self.rpc.rpc(ImportRequest { capability }).await??;
        Ok(Doc::new(self.rpc.clone(), res.doc_id))
    }

    /// Imports a document from a ticket and starts syncing with the ticket's nodes.
    pub async fn import(&self, ticket: DocTicket) -> Result<Doc<C>> {
        let DocTicket { capability, nodes } = ticket;
        let doc = self.import_namespace(capability).await?;
        doc.start_sync(nodes).await?;
        Ok(doc)
    }

    /// Like [`Client::import`], but also returns a stream of the document's live events.
    ///
    /// The subscription is created before sync is started, so it is already in
    /// place once syncing begins.
    pub async fn import_and_subscribe(
        &self,
        ticket: DocTicket,
    ) -> Result<(Doc<C>, impl Stream<Item = anyhow::Result<LiveEvent>>)> {
        let DocTicket { capability, nodes } = ticket;
        // Reuse the shared import path instead of duplicating the RPC call.
        let doc = self.import_namespace(capability).await?;
        let events = doc.subscribe().await?;
        doc.start_sync(nodes).await?;
        Ok((doc, events))
    }

    /// Streams the documents known to the node together with the kind of
    /// capability held for each.
    pub async fn list(&self) -> Result<impl Stream<Item = Result<(NamespaceId, CapabilityKind)>>> {
        let stream = self.rpc.server_streaming(DocListRequest {}).await?;
        Ok(flatten(stream).map(|res| res.map(|res| (res.id, res.capability))))
    }

    /// Opens the document `id` on the node and returns a handle to it.
    ///
    /// Errors reported by the node are returned as `Err`; on success this
    /// implementation always returns `Some`.
    pub async fn open(&self, id: NamespaceId) -> Result<Option<Doc<C>>> {
        self.rpc.rpc(OpenRequest { doc_id: id }).await??;
        Ok(Some(Doc::new(self.rpc.clone(), id)))
    }
}
/// Handle to a single document on the node.
///
/// Clones share one underlying state behind an [`Arc`], so calling `close` on one
/// clone marks every clone as closed.
#[derive(Debug, Clone)]
pub struct Doc<C: Connector<RpcService> = BoxedConnector<RpcService>>(Arc<DocInner<C>>);
/// Two handles are equal when they refer to the same document id.
impl<C: Connector<RpcService>> PartialEq for Doc<C> {
    fn eq(&self, other: &Self) -> bool {
        let this_id = self.0.id;
        let that_id = other.0.id;
        this_id == that_id
    }
}

impl<C: Connector<RpcService>> Eq for Doc<C> {}
/// Shared state behind a [`Doc`] handle and all of its clones.
#[derive(Debug)]
struct DocInner<C: Connector<RpcService> = BoxedConnector<RpcService>> {
    /// Id of the document (namespace) this handle refers to.
    id: NamespaceId,
    /// RPC client used for every request about this document.
    rpc: quic_rpc::RpcClient<RpcService, C>,
    /// Set once the handle is closed, either explicitly or on drop; most
    /// operations check it to reject use after close.
    closed: AtomicBool,
    /// Runtime captured at construction; the close request sent on drop is spawned onto it.
    rt: tokio::runtime::Handle,
}
impl<C> Drop for DocInner<C>
where
    C: quic_rpc::Connector<RpcService>,
{
    /// Sends a best-effort close request if the document was not closed explicitly.
    ///
    /// `drop` cannot await, so the request is spawned onto the runtime captured at
    /// construction. Its result is ignored because there is no caller to report it to.
    fn drop(&mut self) {
        if self.closed.swap(true, Ordering::Relaxed) {
            // Already closed, so nothing is sent and the client need not be cloned.
            return;
        }
        let doc_id = self.id;
        let rpc = self.rpc.clone();
        self.rt.spawn(async move {
            rpc.rpc(CloseRequest { doc_id }).await.ok();
        });
    }
}
impl<C: Connector<RpcService>> Doc<C> {
fn new(rpc: quic_rpc::RpcClient<RpcService, C>, id: NamespaceId) -> Self {
Self(Arc::new(DocInner {
rpc,
id,
closed: AtomicBool::new(false),
rt: tokio::runtime::Handle::current(),
}))
}
async fn rpc<M>(&self, msg: M) -> Result<M::Response>
where
M: RpcMsg<RpcService>,
{
let res = self.0.rpc.rpc(msg).await?;
Ok(res)
}
pub fn id(&self) -> NamespaceId {
self.0.id
}
pub async fn close(&self) -> Result<()> {
if !self.0.closed.swap(true, Ordering::Relaxed) {
self.rpc(CloseRequest { doc_id: self.id() }).await??;
}
Ok(())
}
fn ensure_open(&self) -> Result<()> {
if self.0.closed.load(Ordering::Relaxed) {
Err(anyhow!("document is closed"))
} else {
Ok(())
}
}
pub async fn set_bytes(
&self,
author_id: AuthorId,
key: impl Into<Bytes>,
value: impl Into<Bytes>,
) -> Result<Hash> {
self.ensure_open()?;
let res = self
.rpc(SetRequest {
doc_id: self.id(),
author_id,
key: key.into(),
value: value.into(),
})
.await??;
Ok(res.entry.content_hash())
}
pub async fn set_hash(
&self,
author_id: AuthorId,
key: impl Into<Bytes>,
hash: Hash,
size: u64,
) -> Result<()> {
self.ensure_open()?;
self.rpc(SetHashRequest {
doc_id: self.id(),
author_id,
key: key.into(),
hash,
size,
})
.await??;
Ok(())
}
pub async fn import_file(
&self,
author: AuthorId,
key: Bytes,
path: impl AsRef<Path>,
in_place: bool,
) -> Result<ImportFileProgress> {
self.ensure_open()?;
let stream = self
.0
.rpc
.server_streaming(ImportFileRequest {
doc_id: self.id(),
author_id: author,
path: path.as_ref().into(),
key,
in_place,
})
.await?;
Ok(ImportFileProgress::new(stream))
}
pub async fn export_file(
&self,
entry: Entry,
path: impl AsRef<Path>,
mode: ExportMode,
) -> Result<ExportFileProgress> {
self.ensure_open()?;
let stream = self
.0
.rpc
.server_streaming(ExportFileRequest {
entry,
path: path.as_ref().into(),
mode,
})
.await?;
Ok(ExportFileProgress::new(stream))
}
pub async fn del(&self, author_id: AuthorId, prefix: impl Into<Bytes>) -> Result<usize> {
self.ensure_open()?;
let res = self
.rpc(DelRequest {
doc_id: self.id(),
author_id,
prefix: prefix.into(),
})
.await??;
let DelResponse { removed } = res;
Ok(removed)
}
pub async fn get_exact(
&self,
author: AuthorId,
key: impl AsRef<[u8]>,
include_empty: bool,
) -> Result<Option<Entry>> {
self.ensure_open()?;
let res = self
.rpc(GetExactRequest {
author,
key: key.as_ref().to_vec().into(),
doc_id: self.id(),
include_empty,
})
.await??;
Ok(res.entry.map(|entry| entry.into()))
}
pub async fn get_many(
&self,
query: impl Into<Query>,
) -> Result<impl Stream<Item = Result<Entry>>> {
self.ensure_open()?;
let stream = self
.0
.rpc
.server_streaming(GetManyRequest {
doc_id: self.id(),
query: query.into(),
})
.await?;
Ok(flatten(stream).map(|res| res.map(|res| res.entry.into())))
}
pub async fn get_one(&self, query: impl Into<Query>) -> Result<Option<Entry>> {
self.get_many(query).await?.next().await.transpose()
}
pub async fn share(
&self,
mode: ShareMode,
addr_options: AddrInfoOptions,
) -> anyhow::Result<DocTicket> {
self.ensure_open()?;
let res = self
.rpc(ShareRequest {
doc_id: self.id(),
mode,
addr_options,
})
.await??;
Ok(res.0)
}
pub async fn start_sync(&self, peers: Vec<NodeAddr>) -> Result<()> {
self.ensure_open()?;
let _res = self
.rpc(StartSyncRequest {
doc_id: self.id(),
peers,
})
.await??;
Ok(())
}
pub async fn leave(&self) -> Result<()> {
self.ensure_open()?;
let _res = self.rpc(LeaveRequest { doc_id: self.id() }).await??;
Ok(())
}
pub async fn subscribe(&self) -> anyhow::Result<impl Stream<Item = anyhow::Result<LiveEvent>>> {
self.ensure_open()?;
let stream = self
.0
.rpc
.try_server_streaming(DocSubscribeRequest { doc_id: self.id() })
.await?;
Ok(stream.map(|res| match res {
Ok(res) => Ok(res.event),
Err(err) => Err(err.into()),
}))
}
pub async fn status(&self) -> anyhow::Result<OpenState> {
self.ensure_open()?;
let res = self.rpc(StatusRequest { doc_id: self.id() }).await??;
Ok(res.status)
}
pub async fn set_download_policy(&self, policy: DownloadPolicy) -> Result<()> {
self.rpc(SetDownloadPolicyRequest {
doc_id: self.id(),
policy,
})
.await??;
Ok(())
}
pub async fn get_download_policy(&self) -> Result<DownloadPolicy> {
let res = self
.rpc(GetDownloadPolicyRequest { doc_id: self.id() })
.await??;
Ok(res.policy)
}
pub async fn get_sync_peers(&self) -> Result<Option<Vec<PeerIdBytes>>> {
let res = self
.rpc(GetSyncPeersRequest { doc_id: self.id() })
.await??;
Ok(res.peers)
}
}
impl<'a, C> From<&'a Doc<C>> for &'a quic_rpc::RpcClient<RpcService, C>
where
C: quic_rpc::Connector<RpcService>,
{
fn from(doc: &'a Doc<C>) -> &'a quic_rpc::RpcClient<RpcService, C> {
&doc.0.rpc
}
}
/// Progress events emitted while importing a file into a document
/// (see [`Doc::import_file`] and [`ImportFileProgress`]).
#[derive(Debug, Serialize, Deserialize)]
pub enum ImportProgress {
    /// An item to import was found.
    Found {
        /// Item id; NOTE(review): presumably matched by later `Progress`/`IngestDone` events.
        id: u64,
        /// Name of the item (presumably derived from the source path; confirm).
        name: String,
        /// Size of the item; [`ImportFileProgress::finish`] reports it as the entry size.
        size: u64,
    },
    /// Ingestion of an item made progress.
    Progress {
        /// Id of the item.
        id: u64,
        /// Progress offset (presumably bytes processed so far; confirm).
        offset: u64,
    },
    /// Ingestion of an item finished.
    IngestDone {
        /// Id of the item.
        id: u64,
        /// Hash of the ingested content.
        hash: Hash,
    },
    /// The whole import finished.
    AllDone {
        /// Key of the imported entry.
        key: Bytes,
    },
    /// The import was aborted with this error.
    Abort(serde_error::Error),
}
/// Access level requested when sharing a document via [`Doc::share`].
#[derive(Serialize, Deserialize, Debug, Clone, Display, FromStr)]
pub enum ShareMode {
    /// Share with read access (presumably a read-only capability; confirm).
    Read,
    /// Share with write access.
    Write,
}
/// Progress stream returned by [`Doc::import_file`].
///
/// Poll it as a [`Stream`] of [`ImportProgress`] events, or call
/// [`ImportFileProgress::finish`] to wait for the final outcome.
#[derive(derive_more::Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct ImportFileProgress {
    // Skipped in `Debug`: the boxed stream does not implement `Debug`.
    #[debug(skip)]
    stream: Pin<Box<dyn Stream<Item = Result<ImportProgress>> + Send + Unpin + 'static>>,
}
impl ImportFileProgress {
    /// Wraps an RPC response stream. Each item is converted into an [`ImportProgress`]
    /// and each error into an [`anyhow::Error`].
    fn new(
        stream: (impl Stream<Item = Result<impl Into<ImportProgress>, impl Into<anyhow::Error>>>
             + Send
             + Unpin
             + 'static),
    ) -> Self {
        let stream = stream.map(|item| match item {
            Ok(item) => Ok(item.into()),
            Err(err) => Err(err.into()),
        });
        Self {
            stream: Box::pin(stream),
        }
    }

    /// Drives the import to completion and returns its outcome.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the stream yields an error or an [`ImportProgress::Abort`];
    /// - [`ImportProgress::AllDone`] arrives before any [`ImportProgress::IngestDone`];
    /// - the stream ends before [`ImportProgress::AllDone`].
    pub async fn finish(mut self) -> Result<ImportFileOutcome> {
        let mut entry_size = 0;
        let mut entry_hash = None;
        while let Some(msg) = self.next().await {
            match msg? {
                ImportProgress::Found { size, .. } => {
                    entry_size = size;
                }
                ImportProgress::Progress { .. } => {}
                ImportProgress::IngestDone { hash, .. } => {
                    entry_hash = Some(hash);
                }
                ImportProgress::AllDone { key } => {
                    // Name the actual enum variant; the old message referred to a
                    // type (`DocImportProgress`) that does not exist here.
                    let hash = entry_hash
                        .context("expected ImportProgress::IngestDone event to occur")?;
                    return Ok(ImportFileOutcome {
                        hash,
                        key,
                        size: entry_size,
                    });
                }
                ImportProgress::Abort(err) => return Err(err.into()),
            }
        }
        Err(anyhow!("Response stream ended prematurely"))
    }
}
/// Outcome of a completed file import, as returned by [`ImportFileProgress::finish`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImportFileOutcome {
    /// Hash from the last [`ImportProgress::IngestDone`] event.
    pub hash: Hash,
    /// Size from the last [`ImportProgress::Found`] event (0 if none was seen).
    pub size: u64,
    /// Key from the [`ImportProgress::AllDone`] event.
    pub key: Bytes,
}
/// Yields the import progress events produced by the node.
impl Stream for ImportFileProgress {
    type Item = Result<ImportProgress>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The only field is already a `Pin<Box<_>>`, so `Self: Unpin` and we may
        // safely take a plain `&mut` to it.
        let this = self.get_mut();
        this.stream.as_mut().poll_next(cx)
    }
}
/// Progress stream returned by [`Doc::export_file`].
///
/// Poll it as a [`Stream`] of [`ExportProgress`] events, or call
/// [`ExportFileProgress::finish`] to wait for the final outcome.
// `must_use` matches `ImportFileProgress`: discarding the stream unpolled is a bug.
#[derive(derive_more::Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct ExportFileProgress {
    // Skipped in `Debug`: the boxed stream does not implement `Debug`.
    #[debug(skip)]
    stream: Pin<Box<dyn Stream<Item = Result<ExportProgress>> + Send + Unpin + 'static>>,
}
impl ExportFileProgress {
    /// Adapts an RPC response stream. Items become [`ExportProgress`] events and
    /// errors become [`anyhow::Error`]s.
    fn new(
        stream: (impl Stream<Item = Result<impl Into<ExportProgress>, impl Into<anyhow::Error>>>
             + Send
             + Unpin
             + 'static),
    ) -> Self {
        let converted = stream.map(|item| match item {
            Ok(progress) => Ok(progress.into()),
            Err(error) => Err(error.into()),
        });
        Self {
            stream: Box::pin(converted),
        }
    }

    /// Consumes the stream until the export is done and reports where it went.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the stream yields an error or an [`ExportProgress::Abort`];
    /// - [`ExportProgress::AllDone`] arrives without a preceding [`ExportProgress::Found`];
    /// - the stream ends before [`ExportProgress::AllDone`].
    pub async fn finish(mut self) -> Result<ExportFileOutcome> {
        // Size and output path of the most recent `Found` event.
        let mut found = None;
        while let Some(event) = self.next().await {
            match event? {
                ExportProgress::Found { size, outpath, .. } => {
                    found = Some((size.value(), outpath));
                }
                ExportProgress::Progress { .. } | ExportProgress::Done { .. } => {}
                ExportProgress::Abort(err) => return Err(anyhow!(err)),
                ExportProgress::AllDone => {
                    let (size, path) =
                        found.context("expected ExportProgress::Found event to occur")?;
                    return Ok(ExportFileOutcome { size, path });
                }
            }
        }
        Err(anyhow!("Response stream ended prematurely"))
    }
}
/// Outcome of a completed file export, as returned by [`ExportFileProgress::finish`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportFileOutcome {
    /// Size reported by the last [`ExportProgress::Found`] event.
    pub size: u64,
    /// Output path reported by the last [`ExportProgress::Found`] event.
    pub path: PathBuf,
}
/// Yields the export progress events produced by the node.
impl Stream for ExportFileProgress {
    type Item = Result<ExportProgress>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The only field is already a `Pin<Box<_>>`, so `Self: Unpin` and we may
        // safely take a plain `&mut` to it.
        let this = self.get_mut();
        this.stream.as_mut().poll_next(cx)
    }
}