use std::{collections::HashMap, num::NonZeroU64};
use serde::{Deserialize, Serialize};
use tracing::warn;
use super::db::{BlobId, DownloadProgress};
use crate::{protocol::RangeSpec, store::BaoBlobSize, Hash};
/// Identifier that a `Found` event attaches to a blob's transfer; later `Progress` and
/// `Done` events refer back to that blob through this id.
pub type ProgressId = u64;
/// Aggregated view of a single download, built up by feeding [`DownloadProgress`]
/// events into [`TransferState::on_progress`].
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct TransferState {
    /// State of the root blob, i.e. the blob whose hash the transfer was created with.
    pub root: BlobState,
    /// Set to `true` once a `Connected` event has been handled.
    pub connected: bool,
    /// State of child blobs, keyed by the id carried in `BlobId::Child`.
    pub children: HashMap<NonZeroU64, BlobState>,
    /// The blob named by the most recently handled `Found` event, if any.
    ///
    /// Note: this is not reset when that blob's `Done` event arrives.
    pub current: Option<BlobId>,
    /// Maps progress ids announced by `Found` events to the blob they refer to; an entry
    /// is removed when the matching `Done` event is handled.
    pub progress_id_to_blob: HashMap<ProgressId, BlobId>,
}
impl TransferState {
pub fn new(root_hash: Hash) -> Self {
Self {
root: BlobState::new(root_hash),
connected: false,
children: Default::default(),
current: None,
progress_id_to_blob: Default::default(),
}
}
}
/// What is known about one blob (the root or a child) of a transfer.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct BlobState {
    /// Hash of the blob.
    pub hash: Hash,
    /// Size of the blob, if reported. A `FoundLocal` event stores the reported size as-is;
    /// a `Found` event stores an unverified size unless a verified one is already present.
    pub size: Option<BaoBlobSize>,
    /// Transfer progress of this blob.
    pub progress: BlobProgress,
    /// Ranges reported as valid by a `FoundLocal` event, if any.
    pub local_ranges: Option<RangeSpec>,
    /// Number of children; `TransferState::on_progress` sets this only on the root,
    /// from a `FoundHashSeq` event.
    pub child_count: Option<u64>,
}
/// Transfer progress of a single blob.
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub enum BlobProgress {
    /// Initial state (the default), before a `Found` event has been handled for the blob.
    #[default]
    Pending,
    /// Transfer in progress; holds the `offset` from the latest `Progress` event
    /// (`0` right after `Found`). NOTE(review): presumably a byte offset — confirm.
    Progressing(u64),
    /// A `Done` event has been handled for this blob.
    Done,
}
impl BlobState {
    /// Creates the state for the blob identified by `hash`.
    ///
    /// Nothing else is known yet: no size, no local ranges, no child count, and the
    /// progress is [`BlobProgress::Pending`].
    pub fn new(hash: Hash) -> Self {
        Self {
            hash,
            progress: BlobProgress::Pending,
            size: None,
            local_ranges: None,
            child_count: None,
        }
    }
}
impl TransferState {
    /// Returns the state of the root blob.
    pub fn root(&self) -> &BlobState {
        &self.root
    }

    /// Looks up the state of the blob identified by `blob_id`.
    ///
    /// The root is always available; a child yields `None` until it has been registered.
    pub fn get_blob(&self, blob_id: &BlobId) -> Option<&BlobState> {
        match blob_id {
            BlobId::Child(child_id) => self.children.get(child_id),
            BlobId::Root => Some(&self.root),
        }
    }

    /// Returns the state of the blob named by the most recent `Found` event, if any.
    pub fn get_current(&self) -> Option<&BlobState> {
        let current_id = self.current.as_ref()?;
        self.get_blob(current_id)
    }

    /// Returns the mutable state of `blob_id`, creating a fresh [`BlobState`] for `hash`
    /// if the child is not tracked yet. For the root, `hash` is not consulted.
    fn get_or_insert_blob(&mut self, blob_id: BlobId, hash: Hash) -> &mut BlobState {
        match blob_id {
            BlobId::Child(child_id) => {
                let children = &mut self.children;
                children
                    .entry(child_id)
                    .or_insert_with(|| BlobState::new(hash))
            }
            BlobId::Root => &mut self.root,
        }
    }

    /// Mutable counterpart of [`Self::get_blob`].
    fn get_blob_mut(&mut self, blob_id: &BlobId) -> Option<&mut BlobState> {
        match blob_id {
            BlobId::Child(child_id) => self.children.get_mut(child_id),
            BlobId::Root => Some(&mut self.root),
        }
    }

    /// Resolves `progress_id` to the blob it was registered for and returns its state.
    fn get_by_progress_id(&mut self, progress_id: ProgressId) -> Option<&mut BlobState> {
        let blob_id = self.progress_id_to_blob.get(&progress_id).copied()?;
        self.get_blob_mut(&blob_id)
    }

    /// Applies a single [`DownloadProgress`] event to this state.
    ///
    /// * `InitialState` replaces the whole state.
    /// * `FoundLocal` stores the reported size and valid ranges, creating the entry if needed.
    /// * `Connected` sets `connected`.
    /// * `Found` registers the progress id, makes the blob current and resets its progress to
    ///   `Progressing(0)`; an already verified size is kept, otherwise the reported size is
    ///   stored as unverified.
    /// * `FoundHashSeq` records the child count, but only when the hash is the root hash.
    /// * `Progress` / `Done` update the blob registered under the progress id (`Done` also
    ///   drops the registration); unknown ids are logged with `warn!` and otherwise ignored.
    /// * `AllDone` and `Abort` are ignored.
    pub fn on_progress(&mut self, event: DownloadProgress) {
        match event {
            DownloadProgress::InitialState(state) => *self = state,
            DownloadProgress::FoundLocal {
                child,
                hash,
                size,
                valid_ranges,
            } => {
                let entry = self.get_or_insert_blob(child, hash);
                entry.local_ranges = Some(valid_ranges);
                entry.size = Some(size);
            }
            DownloadProgress::Connected => {
                self.connected = true;
            }
            DownloadProgress::Found {
                id: progress_id,
                child: blob_id,
                hash,
                size,
            } => {
                let entry = self.get_or_insert_blob(blob_id, hash);
                // A verified size is authoritative; only an absent or unverified one is replaced.
                let already_verified = matches!(entry.size, Some(BaoBlobSize::Verified(_)));
                if !already_verified {
                    entry.size = Some(BaoBlobSize::Unverified(size));
                }
                entry.progress = BlobProgress::Progressing(0);
                self.progress_id_to_blob.insert(progress_id, blob_id);
                self.current = Some(blob_id);
            }
            DownloadProgress::FoundHashSeq { hash, children } => {
                if hash != self.root.hash {
                    warn!("Received `FoundHashSeq` event for a hash which is not the download's root hash.");
                } else {
                    self.root.child_count = Some(children);
                }
            }
            DownloadProgress::Progress { id, offset } => match self.get_by_progress_id(id) {
                Some(entry) => entry.progress = BlobProgress::Progressing(offset),
                None => warn!(%id, "Received `Progress` event for unknown progress id."),
            },
            DownloadProgress::Done { id } => match self.get_by_progress_id(id) {
                Some(entry) => {
                    entry.progress = BlobProgress::Done;
                    self.progress_id_to_blob.remove(&id);
                }
                None => warn!(%id, "Received `Done` event for unknown progress id."),
            },
            DownloadProgress::AllDone(_) | DownloadProgress::Abort(_) => {}
        }
    }
}