1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
//! Types for managing the progress state of get (download) transfers.

use std::{collections::HashMap, num::NonZeroU64};

use serde::{Deserialize, Serialize};
use tracing::warn;

use super::db::{BlobId, DownloadProgress};
use crate::{protocol::RangeSpec, store::BaoBlobSize, Hash};

/// The identifier for progress events.
///
/// Used as the key of [`TransferState::progress_id_to_blob`] to associate `Found`, `Progress`
/// and `Done` events with the blob they refer to.
pub type ProgressId = u64;

/// Accumulated progress state of a transfer.
///
/// Built up incrementally by feeding [`DownloadProgress`] events into
/// [`TransferState::on_progress`].
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct TransferState {
    /// The root blob of this transfer (may be a hash seq).
    pub root: BlobState,
    /// Whether we are connected to a node.
    ///
    /// Set to `true` when a `Connected` event is received.
    pub connected: bool,
    /// Children if the root blob is a hash seq, empty for raw blobs.
    ///
    /// Keyed by the id carried in [`BlobId::Child`]. Entries are created lazily the first time a
    /// `FoundLocal` or `Found` event mentions the child.
    pub children: HashMap<NonZeroU64, BlobState>,
    /// Blob (root or child) most recently reported by a `Found` event.
    ///
    /// Note: this is not cleared when that blob later reports `Done`.
    pub current: Option<BlobId>,
    /// Progress ids for individual blobs.
    ///
    /// An entry is added on a `Found` event and removed again on the matching `Done` event.
    pub progress_id_to_blob: HashMap<ProgressId, BlobId>,
}

impl TransferState {
    /// Create a new, empty transfer state.
    pub fn new(root_hash: Hash) -> Self {
        Self {
            root: BlobState::new(root_hash),
            connected: false,
            children: Default::default(),
            current: None,
            progress_id_to_blob: Default::default(),
        }
    }
}

/// State of a single blob in a transfer.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct BlobState {
    /// The hash of this blob.
    pub hash: Hash,
    /// The size of this blob. Only known if the blob is partially present locally, or after having
    /// received the size from the remote.
    ///
    /// A size reported by the remote is stored as unverified and never replaces a size that is
    /// already verified.
    pub size: Option<BaoBlobSize>,
    /// The current state of the blob transfer.
    pub progress: BlobProgress,
    /// Ranges already available locally at the time of starting the transfer.
    pub local_ranges: Option<RangeSpec>,
    /// Number of children (only applies to hashseqs, None for raw blobs).
    ///
    /// Only ever filled in for the root blob, from a `FoundHashSeq` event.
    pub child_count: Option<u64>,
}

/// Progress state for a single blob.
#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub enum BlobProgress {
    /// Download is pending (the initial state of every blob).
    #[default]
    Pending,
    /// Download is in progress; holds the most recently reported offset
    /// (presumably in bytes — NOTE(review): confirm against the `Progress` event producer).
    Progressing(u64),
    /// Download has finished.
    Done,
}

impl BlobState {
    /// Creates the state for a blob about which only its `hash` is known.
    ///
    /// Size, local ranges and child count are unknown, and progress starts out as
    /// [`BlobProgress::Pending`].
    pub fn new(hash: Hash) -> Self {
        Self {
            hash,
            progress: BlobProgress::Pending,
            size: None,
            local_ranges: None,
            child_count: None,
        }
    }
}

impl TransferState {
    /// Returns the state of the root blob of this transfer.
    pub fn root(&self) -> &BlobState {
        &self.root
    }

    /// Looks up a blob's state by its [`BlobId`] within this transfer.
    ///
    /// The root is always present; a child yields `None` until some event has mentioned it.
    pub fn get_blob(&self, blob_id: &BlobId) -> Option<&BlobState> {
        match blob_id {
            BlobId::Child(child_id) => self.children.get(child_id),
            BlobId::Root => Some(&self.root),
        }
    }

    /// Returns the state of the blob recorded in `current`, if any.
    pub fn get_current(&self) -> Option<&BlobState> {
        let current_id = self.current.as_ref()?;
        self.get_blob(current_id)
    }

    /// Returns a mutable handle to the blob identified by `blob_id`.
    ///
    /// A child that is not known yet is created with `hash`. For the root, and for children that
    /// already exist, `hash` is ignored.
    fn get_or_insert_blob(&mut self, blob_id: BlobId, hash: Hash) -> &mut BlobState {
        match blob_id {
            BlobId::Child(child_id) => {
                let slot = self.children.entry(child_id);
                slot.or_insert_with(|| BlobState::new(hash))
            }
            BlobId::Root => &mut self.root,
        }
    }

    /// Mutable counterpart of [`TransferState::get_blob`].
    fn get_blob_mut(&mut self, blob_id: &BlobId) -> Option<&mut BlobState> {
        match blob_id {
            BlobId::Child(child_id) => self.children.get_mut(child_id),
            BlobId::Root => Some(&mut self.root),
        }
    }

    /// Resolves a progress id to the mutable state of the blob it was registered for.
    ///
    /// Returns `None` if `progress_id` is not registered (ids are unregistered on `Done`).
    fn get_by_progress_id(&mut self, progress_id: ProgressId) -> Option<&mut BlobState> {
        let blob_id = self.progress_id_to_blob.get(&progress_id).copied()?;
        self.get_blob_mut(&blob_id)
    }

    /// Folds a single [`DownloadProgress`] event into this transfer's state.
    ///
    /// `InitialState` replaces the whole state. All other events update the affected blob (or
    /// flag) in place. Events for unknown progress ids, and `FoundHashSeq` events for a hash
    /// other than the root, are logged as warnings and otherwise ignored.
    pub fn on_progress(&mut self, event: DownloadProgress) {
        match event {
            DownloadProgress::InitialState(state) => *self = state,
            DownloadProgress::Connected => self.connected = true,
            DownloadProgress::FoundLocal {
                child,
                hash,
                size,
                valid_ranges,
            } => {
                let blob = self.get_or_insert_blob(child, hash);
                blob.local_ranges = Some(valid_ranges);
                blob.size = Some(size);
            }
            DownloadProgress::FoundHashSeq { hash, children } => {
                // Presumably the protocol only emits `FoundHashSeq` for the download's root;
                // anything else is unexpected and merely logged.
                if hash != self.root.hash {
                    warn!("Received `FoundHashSeq` event for a hash which is not the download's root hash.");
                } else {
                    self.root.child_count = Some(children);
                }
            }
            DownloadProgress::Found {
                id,
                child,
                hash,
                size,
            } => {
                let found = self.get_or_insert_blob(child, hash);
                // A verified size always wins; otherwise record the remote's claim as unverified.
                let has_verified_size = matches!(found.size, Some(BaoBlobSize::Verified(_)));
                if !has_verified_size {
                    found.size = Some(BaoBlobSize::Unverified(size));
                }
                found.progress = BlobProgress::Progressing(0);
                self.current = Some(child);
                self.progress_id_to_blob.insert(id, child);
            }
            DownloadProgress::Progress { id, offset } => match self.get_by_progress_id(id) {
                Some(blob) => blob.progress = BlobProgress::Progressing(offset),
                None => warn!(%id, "Received `Progress` event for unknown progress id."),
            },
            DownloadProgress::Done { id } => match self.get_by_progress_id(id) {
                Some(blob) => {
                    blob.progress = BlobProgress::Done;
                    self.progress_id_to_blob.remove(&id);
                }
                None => warn!(%id, "Received `Done` event for unknown progress id."),
            },
            DownloadProgress::AllDone(_) | DownloadProgress::Abort(_) => {}
        }
    }
}