use std::{
collections::HashMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use anyhow::anyhow;
use parking_lot::Mutex;
use super::DownloadKind;
use crate::{
get::{db::DownloadProgress, progress::TransferState},
util::progress::{AsyncChannelProgressSender, IdGenerator, ProgressSendError, ProgressSender},
};
/// Sender used to deliver [`DownloadProgress`] events to a single subscriber.
pub type ProgressSubscriber = AsyncChannelProgressSender<DownloadProgress>;
/// Tracks the progress state of running downloads together with their subscribers.
///
/// Every tracked [`DownloadKind`] owns a piece of shared state (transfer state plus
/// subscriber list). That state is updated through the [`BroadcastProgressSender`]
/// returned by [`ProgressTracker::track`].
#[derive(Debug, Default)]
pub struct ProgressTracker {
/// Shared progress state and subscriber list for each tracked download.
running: HashMap<DownloadKind, Shared>,
/// Id counter handed (via `Arc` clone) to every [`BroadcastProgressSender`] this
/// tracker creates, so all of them draw ids from the same sequence.
id_gen: Arc<AtomicU64>,
}
impl ProgressTracker {
pub fn new() -> Self {
Self::default()
}
pub fn track(
&mut self,
kind: DownloadKind,
subscribers: impl IntoIterator<Item = ProgressSubscriber>,
) -> BroadcastProgressSender {
let inner = Inner {
subscribers: subscribers.into_iter().collect(),
state: TransferState::new(kind.hash()),
};
let shared = Arc::new(Mutex::new(inner));
self.running.insert(kind, Arc::clone(&shared));
let id_gen = Arc::clone(&self.id_gen);
BroadcastProgressSender { shared, id_gen }
}
pub async fn subscribe(
&mut self,
kind: DownloadKind,
sender: ProgressSubscriber,
) -> anyhow::Result<()> {
let initial_msg = self
.running
.get_mut(&kind)
.ok_or_else(|| anyhow!("state for download {kind:?} not found"))?
.lock()
.subscribe(sender.clone());
sender.send(initial_msg).await?;
Ok(())
}
pub fn unsubscribe(&mut self, kind: &DownloadKind, sender: &ProgressSubscriber) {
if let Some(shared) = self.running.get_mut(kind) {
shared.lock().unsubscribe(sender)
}
}
pub fn remove(&mut self, kind: &DownloadKind) {
self.running.remove(kind);
}
}
/// Per-download state shared between the [`ProgressTracker`] entry and the
/// [`BroadcastProgressSender`]s created for that download.
type Shared = Arc<Mutex<Inner>>;
/// Mutable progress state of one download, guarded by the mutex in [`Shared`].
#[derive(Debug)]
struct Inner {
/// Subscribers that receive every progress message broadcast for this download.
subscribers: Vec<ProgressSubscriber>,
/// Accumulated transfer state; a clone is sent to each new subscriber as its
/// initial message.
state: TransferState,
}
impl Inner {
    /// Registers `subscriber` and returns the message that carries the current state
    /// to it.
    fn subscribe(&mut self, subscriber: ProgressSubscriber) -> DownloadProgress {
        self.subscribers.push(subscriber);
        DownloadProgress::InitialState(self.state.clone())
    }

    /// Drops every subscriber that refers to the same channel as `sender`.
    fn unsubscribe(&mut self, sender: &ProgressSubscriber) {
        let keep = |candidate: &ProgressSubscriber| !candidate.same_channel(sender);
        self.subscribers.retain(keep);
    }

    /// Folds a progress event into the accumulated transfer state.
    fn on_progress(&mut self, progress: DownloadProgress) {
        let state = &mut self.state;
        state.on_progress(progress);
    }
}
/// Progress sender for one tracked download.
///
/// Each message is recorded in the shared transfer state and then forwarded to all of
/// the download's subscribers. Clones share the same state and id counter.
#[derive(Debug, Clone)]
pub struct BroadcastProgressSender {
/// Shared state of the download this sender reports on.
shared: Shared,
/// Id counter shared with the [`ProgressTracker`] that created this sender.
id_gen: Arc<AtomicU64>,
}
impl IdGenerator for BroadcastProgressSender {
fn new_id(&self) -> u64 {
self.id_gen.fetch_add(1, Ordering::SeqCst)
}
}
impl ProgressSender for BroadcastProgressSender {
    type Msg = DownloadProgress;

    /// Records `msg` in the shared transfer state and sends it to every subscriber,
    /// awaiting all sends concurrently.
    ///
    /// Subscribers whose receiver was dropped are unsubscribed afterwards. A dropped
    /// subscriber is not an error for the broadcaster, so this always returns `Ok(())`.
    async fn send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
        // Build the send futures under the lock, but release it before awaiting so the
        // lock is never held across an await point.
        let futs = {
            let mut inner = self.shared.lock();
            inner.on_progress(msg.clone());
            let futs = inner
                .subscribers
                .iter()
                .map(|sender| {
                    let sender = sender.clone();
                    let msg = msg.clone();
                    async move {
                        match sender.send(msg).await {
                            Ok(()) => None,
                            Err(ProgressSendError::ReceiverDropped) => Some(sender),
                        }
                    }
                })
                .collect::<Vec<_>>();
            drop(inner);
            futs
        };

        let failed_senders = futures_buffered::join_all(futs).await;
        // Remove subscribers whose receiver has been dropped.
        if failed_senders.iter().any(|s| s.is_some()) {
            let mut inner = self.shared.lock();
            for sender in failed_senders.into_iter().flatten() {
                inner.unsubscribe(&sender);
            }
        }
        Ok(())
    }

    /// Records `msg` and hands it to each subscriber with a non-awaiting `try_send`.
    ///
    /// Subscribers whose receiver was dropped are removed immediately. The lock is held
    /// for the whole call, which is acceptable because `try_send` does not wait.
    fn try_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
        let mut inner = self.shared.lock();
        inner.on_progress(msg.clone());
        inner
            .subscribers
            .retain_mut(|sender| match sender.try_send(msg.clone()) {
                Err(ProgressSendError::ReceiverDropped) => false,
                Ok(()) => true,
            });
        Ok(())
    }

    /// Records `msg` and sends it to each subscriber with a blocking send.
    ///
    /// Subscribers whose receiver was dropped are unsubscribed afterwards.
    fn blocking_send(&self, msg: Self::Msg) -> Result<(), ProgressSendError> {
        // A blocking send may wait for a receiver to make room. Holding the shared lock
        // during that wait would stall everything else that needs it, such as
        // `ProgressTracker::subscribe`/`unsubscribe` (possibly on an async executor
        // thread) or other senders. It can even deadlock if the receiving side needs
        // the lock. So snapshot the subscribers, release the lock, and only then send,
        // mirroring the async `send` above.
        let subscribers = {
            let mut inner = self.shared.lock();
            inner.on_progress(msg.clone());
            let subscribers = inner.subscribers.clone();
            drop(inner);
            subscribers
        };

        let mut failed_senders = Vec::new();
        for sender in subscribers {
            match sender.blocking_send(msg.clone()) {
                Ok(()) => {}
                Err(ProgressSendError::ReceiverDropped) => failed_senders.push(sender),
            }
        }

        // Remove subscribers whose receiver has been dropped.
        if !failed_senders.is_empty() {
            let mut inner = self.shared.lock();
            for sender in &failed_senders {
                inner.unsubscribe(sender);
            }
        }
        Ok(())
    }
}