1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
use std::{
    num::ParseIntError,
    str::FromStr,
    sync::{Arc, Mutex},
    time::Instant,
};

use anyhow::Result;
use clap::Parser;
use stats::Stats;
use tokio::{
    runtime::{Builder, Runtime},
    sync::Semaphore,
};
use tracing::info;

pub mod iroh;
#[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
pub mod quinn;
pub mod s2n;
pub mod stats;

// Top-level command line: one subcommand per benchmarked QUIC stack.
//
// NOTE: plain `//` comments are used deliberately. With the clap derive, `///` doc
// comments on this enum or its variants would become `--help` text. Variant order is
// also the order the subcommands are listed in the help output.
#[derive(Parser, Debug, Clone, Copy)]
#[clap(name = "iroh-net-bench")]
pub enum Commands {
    // `iroh` subcommand, configured by the shared `Opt` options.
    Iroh(Opt),
    // `quinn` subcommand, configured by the shared `Opt` options.
    // Compiled out on FreeBSD, OpenBSD and NetBSD.
    #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
    Quinn(Opt),
    // `s2n` subcommand; takes its own option set (`s2n::Opt`) instead of the shared `Opt`.
    S2n(s2n::Opt),
}

// Options shared by the `iroh` and `quinn` subcommands.
//
// The `///` comments on the fields below are rendered by clap as `--help` text, so they
// are user-facing strings.
#[derive(Parser, Debug, Clone, Copy)]
#[clap(name = "options")]
pub struct Opt {
    /// The total number of clients which should be created
    #[clap(long = "clients", short = 'c', default_value = "1")]
    pub clients: usize,
    /// The total number of streams which should be created
    #[clap(long = "streams", short = 'n', default_value = "1")]
    pub streams: usize,
    /// The maximum number of concurrent streams which should be used
    #[clap(long = "max_streams", short = 'm', default_value = "1")]
    pub max_streams: usize,
    /// Number of bytes to transmit from server to client
    ///
    /// Accepts a binary size suffix (k, M, G or T, each a power of 1024).
    /// E.g. 1M will transfer 1MiB, 10G will transfer 10GiB.
    #[clap(long, default_value = "1G", value_parser = parse_byte_size)]
    pub download_size: u64,
    /// Number of bytes to transmit from client to server
    ///
    /// Accepts a binary size suffix (k, M, G or T, each a power of 1024).
    /// E.g. 1M will transfer 1MiB, 10G will transfer 10GiB.
    #[clap(long, default_value = "0", value_parser = parse_byte_size)]
    pub upload_size: u64,
    /// Show connection stats at the end of the benchmark
    #[clap(long = "stats")]
    pub stats: bool,
    /// Show iroh library counter metrics at the end of the benchmark
    ///
    /// These metrics are process-wide, so contain metrics for
    /// clients and the server all summed up.
    #[clap(long)]
    pub metrics: bool,
    /// Whether to use the unordered read API
    #[clap(long = "unordered")]
    pub read_unordered: bool,
    /// Starting guess for maximum UDP payload size
    #[clap(long, default_value = "1200")]
    pub initial_mtu: u16,
    /// Whether to run a local relay and have the server and clients connect to that.
    ///
    /// Can be combined with the `DEV_RELAY_ONLY` environment variable (at compile time)
    /// to test throughput for relay-only traffic locally.
    /// (e.g. `DEV_RELAY_ONLY=true cargo run --release -- iroh --with-relay`)
    #[clap(long, default_value_t = false)]
    pub with_relay: bool,
}

/// An endpoint from whichever QUIC stack the benchmark was started with.
pub enum EndpointSelector {
    /// An iroh endpoint.
    Iroh(::iroh::Endpoint),
    /// A plain quinn endpoint (not compiled on FreeBSD, OpenBSD or NetBSD).
    #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
    Quinn(::quinn::Endpoint),
}

impl EndpointSelector {
    /// Shuts the endpoint down, consuming it.
    ///
    /// The iroh variant awaits the endpoint's own `close` and propagates its error. The
    /// quinn variant calls `close` with error code 0 and an empty reason; it does not wait
    /// for the endpoint to become idle.
    pub async fn close(self) -> Result<()> {
        match self {
            Self::Iroh(ep) => {
                ep.close().await?;
            }
            #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
            Self::Quinn(ep) => {
                let code = 0u32.into();
                ep.close(code, b"");
            }
        }
        Ok(())
    }
}

/// A connection from whichever QUIC stack the benchmark was started with.
pub enum ConnectionSelector {
    /// A connection made through an iroh endpoint.
    Iroh(::iroh::endpoint::Connection),
    /// A plain quinn connection (not compiled on FreeBSD, OpenBSD or NetBSD).
    #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
    Quinn(::quinn::Connection),
}

impl ConnectionSelector {
    /// Pretty-prints (`{:#?}`) the underlying connection's stats to stdout.
    ///
    /// Nothing is returned: the output is written directly by this method.
    pub fn stats(&self) {
        match self {
            Self::Iroh(conn) => println!("{:#?}", conn.stats()),
            #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
            Self::Quinn(conn) => println!("{:#?}", conn.stats()),
        }
    }

    /// Closes the connection with the given application error code and reason bytes.
    pub fn close(&self, error_code: u32, reason: &[u8]) {
        match self {
            Self::Iroh(conn) => {
                conn.close(error_code.into(), reason);
            }
            #[cfg(not(any(target_os = "freebsd", target_os = "openbsd", target_os = "netbsd")))]
            Self::Quinn(conn) => {
                conn.close(error_code.into(), reason);
            }
        }
    }
}

/// Installs a process-wide `tracing` fmt subscriber whose filter comes from the
/// environment (`EnvFilter::from_default_env`).
///
/// # Panics
///
/// Panics if `set_global_default` fails, e.g. because a global subscriber was already set.
pub fn configure_tracing_subscriber() {
    let filter = tracing_subscriber::EnvFilter::from_default_env();
    let subscriber = tracing_subscriber::FmtSubscriber::builder()
        .with_env_filter(filter)
        .finish();
    tracing::subscriber::set_global_default(subscriber).unwrap();
}

/// Builds a current-thread tokio runtime with all drivers enabled (`enable_all`).
///
/// # Panics
///
/// Panics if the runtime cannot be built.
pub fn rt() -> Runtime {
    let mut builder = Builder::new_current_thread();
    builder.enable_all();
    builder.build().unwrap()
}

/// Parses a byte count with an optional binary size suffix.
///
/// The suffixes `k`/`K`, `M`, `G` and `T` multiply the number by 1024, 1024², 1024³ and
/// 1024⁴ respectively (so `1M` is 1 MiB). Without a suffix the value is taken as bytes.
/// Surrounding whitespace, and whitespace between the number and the suffix, is ignored.
///
/// # Errors
///
/// Returns the [`ParseIntError`] from parsing the numeric part. Returns an error of kind
/// `PosOverflow` when the scaled value does not fit in a `u64`, instead of panicking
/// (debug) or silently wrapping (release).
fn parse_byte_size(s: &str) -> Result<u64, ParseIntError> {
    let s = s.trim();

    // Every suffix is a single ASCII byte, so when one matches, `s.len() - 1` is a char
    // boundary and the slice below cannot panic.
    let multiplier: u64 = match s.as_bytes().last() {
        Some(b'T') => 1 << 40,
        Some(b'G') => 1 << 30,
        Some(b'M') => 1 << 20,
        Some(b'k' | b'K') => 1 << 10,
        _ => 1,
    };

    let number = if multiplier == 1 {
        s
    } else {
        s[..s.len() - 1].trim_end()
    };

    let base = u64::from_str(number)?;

    base.checked_mul(multiplier)
        .ok_or_else(byte_size_overflow_error)
}

/// Builds a `ParseIntError` of kind `PosOverflow` for an out-of-range byte size.
///
/// `ParseIntError` has no public constructor. The error is obtained by parsing
/// `u64::MAX + 1`, which can never succeed.
fn byte_size_overflow_error() -> ParseIntError {
    u64::from_str("18446744073709551616").unwrap_err()
}

/// Results gathered for a single benchmark client.
#[derive(Default)]
pub struct ClientStats {
    // Transfer statistics for data sent from client to server.
    upload_stats: Stats,
    // Transfer statistics for data sent from server to client.
    download_stats: Stats,
    // Time taken to establish the connection (filled in outside this chunk).
    connect_time: std::time::Duration,
}

impl ClientStats {
    /// Prints a human-readable summary for client `client_id` to stdout.
    ///
    /// The connect time is shown in milliseconds. The upload and download sections are
    /// skipped when no bytes were recorded in that direction.
    pub fn print(&self, client_id: usize) {
        let ct = self.connect_time.as_nanos() as f64 / 1_000_000.0;

        println!();
        println!("Client {client_id} stats:");
        println!("Connect time: {ct}ms");

        let directions = [
            (&self.upload_stats, "upload"),
            (&self.download_stats, "download"),
        ];
        for (direction, label) in directions {
            if direction.total_size != 0 {
                direction.print(label);
            }
        }
    }
}

/// Take the provided endpoint and run the client benchmark
pub async fn client_handler(
    endpoint: EndpointSelector,
    connection: ConnectionSelector,
    opt: Opt,
) -> Result<ClientStats> {
    let start = Instant::now();

    let connection = Arc::new(connection);

    let mut stats = ClientStats::default();
    let mut first_error = None;

    let sem = Arc::new(Semaphore::new(opt.max_streams));
    let results = Arc::new(Mutex::new(Vec::new()));
    for _ in 0..opt.streams {
        let permit = sem.clone().acquire_owned().await.unwrap();
        let results = results.clone();
        let connection = connection.clone();
        tokio::spawn(async move {
            let result = match &*connection {
                ConnectionSelector::Iroh(connection) => {
                    iroh::handle_client_stream(connection, opt.upload_size, opt.read_unordered)
                        .await
                }
                #[cfg(not(any(
                    target_os = "freebsd",
                    target_os = "openbsd",
                    target_os = "netbsd"
                )))]
                ConnectionSelector::Quinn(connection) => {
                    quinn::handle_client_stream(connection, opt.upload_size, opt.read_unordered)
                        .await
                }
            };
            // handle_client_stream(connection, opt.upload_size, opt.read_unordered).await;
            info!("stream finished: {:?}", result);
            results.lock().unwrap().push(result);
            drop(permit);
        });
    }

    // Wait for remaining streams to finish
    let _ = sem.acquire_many(opt.max_streams as u32).await.unwrap();

    stats.upload_stats.total_duration = start.elapsed();
    stats.download_stats.total_duration = start.elapsed();

    for result in results.lock().unwrap().drain(..) {
        match result {
            Ok((upload_result, download_result)) => {
                stats.upload_stats.stream_finished(upload_result);
                stats.download_stats.stream_finished(download_result);
            }
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
    }

    // Explicit close of the connection, since handles can still be around due
    // to `Arc`ing them
    connection.close(0u32, b"Benchmark done");

    endpoint.close().await?;

    if opt.stats {
        println!("\nClient connection stats:\n{:#?}", connection.stats());
    }

    match first_error {
        None => Ok(stats),
        Some(e) => Err(e),
    }
}