1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
//! Utilities for async I/O: wrappers that count the bytes passing through a reader or writer.

use std::{io, pin::Pin, task::Poll};

use iroh_io::AsyncStreamReader;
use tokio::io::AsyncWrite;

/// A reader that tracks the number of bytes read.
///
/// Wraps an inner reader and keeps a running count of the bytes read through
/// this wrapper. The count saturates at `u64::MAX` instead of overflowing.
#[derive(Debug)]
pub struct TrackingReader<R> {
    /// The wrapped reader.
    inner: R,
    /// Total number of bytes read through this wrapper so far.
    read: u64,
}

impl<R> TrackingReader<R> {
    /// Wrap a reader in a tracking reader, starting the byte count at zero.
    pub fn new(inner: R) -> Self {
        Self { inner, read: 0 }
    }

    /// Get the number of bytes read so far.
    // NOTE(review): `allow(dead_code)` presumably because this accessor has no
    // caller inside the crate and the module is not publicly exported — confirm.
    #[allow(dead_code)]
    pub fn bytes_read(&self) -> u64 {
        self.read
    }

    /// Consume the tracking reader and return its parts.
    ///
    /// Returns the inner reader together with the total number of bytes read
    /// through this wrapper.
    pub fn into_parts(self) -> (R, u64) {
        (self.inner, self.read)
    }
}

/// Delegates every read to the inner reader and adds the number of bytes
/// obtained to the running total. Failed reads leave the total unchanged.
impl<R> AsyncStreamReader for TrackingReader<R>
where
    R: AsyncStreamReader,
{
    /// Delegates to the inner reader's `read_bytes(len)`.
    ///
    /// On success the counter grows by the length of the buffer actually
    /// returned, not by the requested `len`.
    async fn read_bytes(&mut self, len: usize) -> io::Result<bytes::Bytes> {
        let outcome = self.inner.read_bytes(len).await;
        if let Ok(chunk) = &outcome {
            self.read = self.read.saturating_add(chunk.len() as u64);
        }
        outcome
    }

    /// Delegates to the inner reader's fixed-size `read::<L>()`.
    ///
    /// On success the counter grows by `L`, the size of the returned array.
    async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
        let outcome = self.inner.read::<L>().await;
        if outcome.is_ok() {
            self.read = self.read.saturating_add(L as u64);
        }
        outcome
    }
}

/// A writer that tracks the number of bytes written.
///
/// Wraps an inner writer and keeps a running count of the bytes the inner
/// writer reports as written. The count saturates at `u64::MAX` instead of
/// overflowing.
#[derive(Debug)]
pub struct TrackingWriter<W> {
    /// The wrapped writer.
    inner: W,
    /// Total number of bytes written through this wrapper so far.
    written: u64,
}

impl<W> TrackingWriter<W> {
    /// Wrap a writer in a tracking writer, starting the byte count at zero.
    pub fn new(inner: W) -> Self {
        Self { inner, written: 0 }
    }

    /// Get the number of bytes written so far.
    // NOTE(review): `allow(dead_code)` presumably because this accessor has no
    // caller inside the crate and the module is not publicly exported — confirm.
    #[allow(dead_code)]
    pub fn bytes_written(&self) -> u64 {
        self.written
    }

    /// Consume the tracking writer and return its parts.
    ///
    /// Returns the inner writer together with the total number of bytes
    /// written through this wrapper.
    pub fn into_parts(self) -> (W, u64) {
        (self.inner, self.written)
    }
}

/// Forwards every operation to the inner writer, adding the number of bytes
/// the inner writer reports as written to the running total.
///
/// `W: Unpin` lets the wrapper re-pin the inner writer with `Pin::new`.
impl<W: AsyncWrite + Unpin> AsyncWrite for TrackingWriter<W> {
    /// Writes `buf` to the inner writer.
    ///
    /// Only a `Ready(Ok(n))` result adds `n` to the count. Pending and error
    /// results leave the count unchanged.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = &mut *self;
        let res = Pin::new(&mut this.inner).poll_write(cx, buf);
        if let Poll::Ready(Ok(size)) = res {
            this.written = this.written.saturating_add(size as u64);
        }
        res
    }

    /// Forwards vectored writes to the inner writer.
    ///
    /// Without this override, the trait's default would write only the first
    /// non-empty buffer, even when the inner writer has a native vectored
    /// implementation. Bytes are counted exactly as in `poll_write`.
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = &mut *self;
        let res = Pin::new(&mut this.inner).poll_write_vectored(cx, bufs);
        if let Poll::Ready(Ok(size)) = res {
            this.written = this.written.saturating_add(size as u64);
        }
        res
    }

    /// Reports vectored-write support exactly as the inner writer does.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Flushes the inner writer. The byte count is not affected.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    /// Shuts down the inner writer. The byte count is not affected.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}