1use std::{
10 future::{Future, IntoFuture},
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14};
15
16use bao_tree::{io::BaoContentItem, ChunkNum, ChunkRanges};
17use bytes::Bytes;
18use genawaiter::sync::{Co, Gen};
19use iroh::endpoint::Connection;
20use n0_error::e;
21use n0_future::{Stream, StreamExt};
22use nested_enum_utils::enum_conversions;
23use rand::Rng;
24use tokio::sync::mpsc;
25
26use super::{fsm, GetError, GetResult, Stats};
27use crate::{
28 hashseq::HashSeq,
29 protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
30 Hash, HashAndFormat,
31};
32
/// Handle returned by [`get_blob`] for downloading a single blob.
///
/// It can be consumed incrementally as a [`Stream`] of [`GetBlobItem`]s, or awaited
/// directly (it implements [`IntoFuture`]) to collect the whole blob as [`Bytes`].
pub struct GetBlobResult {
    /// Boxed stream of items produced by the download (built from a generator in
    /// [`get_blob`]).
    rx: n0_future::stream::Boxed<GetBlobItem>,
}
40
41impl IntoFuture for GetBlobResult {
42 type Output = GetResult<Bytes>;
43 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
44
45 fn into_future(self) -> Self::IntoFuture {
46 Box::pin(self.bytes())
47 }
48}
49
50impl GetBlobResult {
51 pub async fn bytes(self) -> GetResult<Bytes> {
52 let (bytes, _) = self.bytes_and_stats().await?;
53 Ok(bytes)
54 }
55
56 pub async fn bytes_and_stats(mut self) -> GetResult<(Bytes, Stats)> {
57 let mut parts = Vec::new();
58 let stats = loop {
59 let Some(item) = self.next().await else {
60 return Err(e!(
61 GetError::LocalFailure,
62 n0_error::anyerr!("unexpected end")
63 ));
64 };
65 match item {
66 GetBlobItem::Item(item) => {
67 if let BaoContentItem::Leaf(leaf) = item {
68 parts.push(leaf.data);
69 }
70 }
71 GetBlobItem::Done(stats) => {
72 break stats;
73 }
74 GetBlobItem::Error(cause) => {
75 return Err(cause);
76 }
77 }
78 };
79 let bytes = if parts.len() == 1 {
80 parts.pop().unwrap()
81 } else {
82 let mut bytes = Vec::new();
83 for part in parts {
84 bytes.extend_from_slice(&part);
85 }
86 bytes.into()
87 };
88 Ok((bytes, stats))
89 }
90}
91
92impl Stream for GetBlobResult {
93 type Item = GetBlobItem;
94
95 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
96 self.rx.poll_next(cx)
97 }
98}
99
/// One item of the stream produced by [`get_blob`].
///
/// A download yields zero or more `Item`s and then ends with exactly one terminal
/// item: `Done` on success or `Error` on failure.
#[derive(Debug)]
// Presumably generates `From` conversions for the variant payloads; `get_blob_impl`
// builds `Item` and `Done` values via `.into()`.
#[enum_conversions()]
pub enum GetBlobItem {
    /// A content item read from the response; `Leaf` items carry the blob data.
    Item(BaoContentItem),
    /// Transfer statistics; the last item of a successful download.
    Done(Stats),
    /// The error that ended the download; always the last item.
    Error(GetError),
}
111
112pub fn get_blob(connection: Connection, hash: Hash) -> GetBlobResult {
113 let generator = Gen::new(|co| async move {
114 if let Err(cause) = get_blob_impl(&connection, &hash, &co).await {
115 co.yield_(GetBlobItem::Error(cause)).await;
116 }
117 });
118 GetBlobResult {
119 rx: Box::pin(generator),
120 }
121}
122
123async fn get_blob_impl(
124 connection: &Connection,
125 hash: &Hash,
126 co: &Co<GetBlobItem>,
127) -> GetResult<()> {
128 let request = GetRequest::blob(*hash);
129 let request = fsm::start(connection.clone(), request, Default::default());
130 let connected = request.next().await?;
131 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
132 unreachable!("expected start root");
133 };
134 let header = start.next();
135 let (mut curr, _size) = header.next().await?;
136 let end = loop {
137 match curr.next().await {
138 fsm::BlobContentNext::More((next, res)) => {
139 co.yield_(res?.into()).await;
140 curr = next;
141 }
142 fsm::BlobContentNext::Done(end) => {
143 break end;
144 }
145 }
146 };
147 let fsm::EndBlobNext::Closing(closing) = end.next() else {
148 unreachable!("expected closing");
149 };
150 let stats = closing.next().await?;
151 co.yield_(stats.into()).await;
152 Ok(())
153}
154
155pub async fn get_unverified_size(connection: &Connection, hash: &Hash) -> GetResult<(u64, Stats)> {
160 let request = GetRequest::new(
161 *hash,
162 ChunkRangesSeq::from_ranges(vec![ChunkRanges::last_chunk()]),
163 );
164 let request = fsm::start(connection.clone(), request, Default::default());
165 let connected = request.next().await?;
166 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
167 unreachable!("expected start root");
168 };
169 let at_blob_header = start.next();
170 let (curr, size) = at_blob_header.next().await?;
171 let stats = curr.finish().next().await?;
172 Ok((size, stats))
173}
174
175pub async fn get_verified_size(connection: &Connection, hash: &Hash) -> GetResult<(u64, Stats)> {
180 tracing::trace!("Getting verified size of {}", hash.to_hex());
181 let request = GetRequest::new(
182 *hash,
183 ChunkRangesSeq::from_ranges(vec![ChunkRanges::last_chunk()]),
184 );
185 let request = fsm::start(connection.clone(), request, Default::default());
186 let connected = request.next().await?;
187 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
188 unreachable!("expected start root");
189 };
190 let header = start.next();
191 let (mut curr, size) = header.next().await?;
192 let end = loop {
193 match curr.next().await {
194 fsm::BlobContentNext::More((next, res)) => {
195 let _ = res?;
196 curr = next;
197 }
198 fsm::BlobContentNext::Done(end) => {
199 break end;
200 }
201 }
202 };
203 let fsm::EndBlobNext::Closing(closing) = end.next() else {
204 unreachable!("expected closing");
205 };
206 let stats = closing.next().await?;
207 tracing::trace!(
208 "Got verified size of {}, {:.6}s",
209 hash.to_hex(),
210 stats.elapsed.as_secs_f64()
211 );
212 Ok((size, stats))
213}
214
215pub async fn get_hash_seq_and_sizes(
223 connection: &Connection,
224 hash: &Hash,
225 max_size: u64,
226 _progress: Option<mpsc::Sender<u64>>,
227) -> GetResult<(HashSeq, Arc<[u64]>)> {
228 let content = HashAndFormat::hash_seq(*hash);
229 tracing::debug!("Getting hash seq and children sizes of {}", content);
230 let request = GetRequest::new(
231 *hash,
232 ChunkRangesSeq::from_ranges_infinite([ChunkRanges::all(), ChunkRanges::last_chunk()]),
233 );
234 let at_start = fsm::start(connection.clone(), request, Default::default());
235 let at_connected = at_start.next().await?;
236 let fsm::ConnectedNext::StartRoot(start) = at_connected.next().await? else {
237 unreachable!("query includes root");
238 };
239 let at_start_root = start.next();
240 let (at_blob_content, size) = at_start_root.next().await?;
241 if size > max_size {
243 return Err(e!(
244 GetError::BadRequest,
245 n0_error::anyerr!("size too large")
246 ));
247 }
248 let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?;
249 let hash_seq =
250 HashSeq::try_from(Bytes::from(hash_seq)).map_err(|e| e!(GetError::BadRequest, e))?;
251 let mut sizes = Vec::with_capacity(hash_seq.len());
252 let closing = loop {
253 match curr.next() {
254 fsm::EndBlobNext::MoreChildren(more) => {
255 let hash = match hash_seq.get(sizes.len()) {
256 Some(hash) => hash,
257 None => break more.finish(),
258 };
259 let at_header = more.next(hash);
260 let (at_content, size) = at_header.next().await?;
261 let next = at_content.drain().await?;
262 sizes.push(size);
263 curr = next;
264 }
265 fsm::EndBlobNext::Closing(closing) => break closing,
266 }
267 };
268 let _stats = closing.next().await?;
269 tracing::debug!(
270 "Got hash seq and children sizes of {}: {:?}",
271 content,
272 sizes
273 );
274 Ok((hash_seq, sizes.into()))
275}
276
277pub async fn get_chunk_probe(
289 connection: &Connection,
290 hash: &Hash,
291 chunk: ChunkNum,
292) -> GetResult<Stats> {
293 let ranges = ChunkRanges::from(chunk..chunk + 1);
294 let ranges = ChunkRangesSeq::from_ranges([ranges]);
295 let request = GetRequest::new(*hash, ranges);
296 let request = fsm::start(connection.clone(), request, Default::default());
297 let connected = request.next().await?;
298 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
299 unreachable!("query includes root");
300 };
301 let header = start.next();
302 let (mut curr, _size) = header.next().await?;
303 let end = loop {
304 match curr.next().await {
305 fsm::BlobContentNext::More((next, res)) => {
306 res?;
307 curr = next;
308 }
309 fsm::BlobContentNext::Done(end) => {
310 break end;
311 }
312 }
313 };
314 let fsm::EndBlobNext::Closing(closing) = end.next() else {
315 unreachable!("query contains only one blob");
316 };
317 let stats = closing.next().await?;
318 Ok(stats)
319}
320
321pub fn random_hash_seq_ranges(sizes: &[u64], mut rng: impl Rng) -> ChunkRangesSeq {
327 let total_chunks = sizes
328 .iter()
329 .map(|size| ChunkNum::full_chunks(*size).0)
330 .sum::<u64>();
331 let random_chunk = rng.random_range(0..total_chunks);
332 let mut remaining = random_chunk;
333 let mut ranges = vec![];
334 ranges.push(ChunkRanges::empty());
335 for size in sizes.iter() {
336 let chunks = ChunkNum::full_chunks(*size).0;
337 if remaining < chunks {
338 ranges.push(ChunkRanges::from(
339 ChunkNum(remaining)..ChunkNum(remaining + 1),
340 ));
341 break;
342 } else {
343 remaining -= chunks;
344 ranges.push(ChunkRanges::empty());
345 }
346 }
347 ChunkRangesSeq::from_ranges(ranges)
348}