use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{Context, Result};
use bytes::Bytes;
use quinn::{
crypto::rustls::QuicClientConfig, Connection, Endpoint, RecvStream, SendStream, TransportConfig,
};
use rustls::{
pki_types::{CertificateDer, PrivateKeyDer},
RootCertStore,
};
use tracing::{trace, warn};
use crate::{
client_handler, stats::TransferResult, ClientStats, ConnectionSelector, EndpointSelector, Opt,
};
/// Byte-string identifier `n0/quinn-bench/0`.
///
/// NOTE(review): the name suggests this is the TLS ALPN protocol value, but
/// nothing in the visible part of this file references it — confirm where it
/// is consumed.
pub const ALPN: &[u8] = b"n0/quinn-bench/0";
/// Creates a QUIC server endpoint bound to an OS-assigned port on IPv4 localhost.
///
/// The server certificate chain consists of `cert` alone, paired with `key`.
/// Transport settings come from [`transport_config`], fed with
/// `opt.max_streams` and `opt.initial_mtu`. The endpoint itself is created
/// while the runtime `rt` is entered.
///
/// Returns the locally bound socket address together with the endpoint.
///
/// # Panics
///
/// Panics if the certificate/key pair is rejected, if the endpoint cannot be
/// created, or if the local address cannot be queried.
pub fn server_endpoint(
    rt: &tokio::runtime::Runtime,
    cert: CertificateDer<'static>,
    key: PrivateKeyDer<'static>,
    opt: &Opt,
) -> (SocketAddr, quinn::Endpoint) {
    let mut config = quinn::ServerConfig::with_single_cert(vec![cert], key).unwrap();
    let transport = transport_config(opt.max_streams, opt.initial_mtu);
    config.transport = Arc::new(transport);

    // Port 0 lets the OS pick a free port; the real address is read back below.
    let bind_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0);

    // The runtime only needs to be entered while the endpoint is constructed.
    let runtime_guard = rt.enter();
    let endpoint = quinn::Endpoint::server(config, bind_addr).unwrap();
    drop(runtime_guard);

    let bound_addr = endpoint.local_addr().unwrap();
    (bound_addr, endpoint)
}
pub async fn client(
server_addr: SocketAddr,
server_cert: CertificateDer<'static>,
opt: Opt,
) -> Result<ClientStats> {
let client_start = std::time::Instant::now();
let (endpoint, connection) = connect_client(server_addr, server_cert, opt).await?;
let client_connect_time = client_start.elapsed();
let mut res = client_handler(
EndpointSelector::Quinn(endpoint),
ConnectionSelector::Quinn(connection),
opt,
)
.await?;
res.connect_time = client_connect_time;
Ok(res)
}
/// Creates a client endpoint on IPv4 localhost and connects it to `server_addr`.
///
/// `server_cert` is installed as the only trusted root, the TLS configuration
/// is restricted to TLS 1.3, and the connection is made under the server name
/// `"localhost"`. Transport settings come from [`transport_config`], fed with
/// `opt.max_streams` and `opt.initial_mtu`.
///
/// Returns the endpoint together with the established connection.
///
/// # Errors
///
/// Returns an error if the local socket cannot be bound, if the certificate
/// cannot be added to the root store, if the TLS/QUIC client configuration
/// cannot be built, or if the connection attempt fails. None of these paths
/// panic.
pub async fn connect_client(
    server_addr: SocketAddr,
    server_cert: CertificateDer<'_>,
    opt: Opt,
) -> Result<(::quinn::Endpoint, Connection)> {
    // Port 0 lets the OS pick a free local port.
    let endpoint = quinn::Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
        .context("failed to bind client endpoint")?;

    let mut roots = RootCertStore::empty();
    roots.add(server_cert)?;

    let provider = rustls::crypto::ring::default_provider();
    let crypto = rustls::ClientConfig::builder_with_provider(provider.into())
        .with_protocol_versions(&[&rustls::version::TLS13])
        .context("failed to select TLS 1.3 for the client configuration")?
        .with_root_certificates(roots)
        .with_no_client_auth();

    let mut client_config = quinn::ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto)?));
    client_config.transport_config(Arc::new(transport_config(opt.max_streams, opt.initial_mtu)));

    let connection = endpoint
        .connect_with(client_config, server_addr, "localhost")
        .context("failed to start connection attempt")?
        .await
        .context("unable to connect")?;
    trace!("connected");
    Ok((endpoint, connection))
}
/// Builds a [`TransportConfig`] from the defaults with two overrides: the
/// limit on concurrent unidirectional streams and the initial MTU.
///
/// NOTE(review): this sets the *unidirectional* stream limit, while the stream
/// helpers in this file open and accept bidirectional streams (`open_bi` /
/// `accept_bi`) — confirm this is the intended limit.
///
/// # Panics
///
/// Panics if `max_streams` cannot be converted into the stream-limit type.
pub fn transport_config(max_streams: usize, initial_mtu: u16) -> TransportConfig {
    let mut transport = TransportConfig::default();
    transport
        .max_concurrent_uni_streams(max_streams.try_into().unwrap())
        .initial_mtu(initial_mtu);
    transport
}
async fn drain_stream(
stream: &mut RecvStream,
read_unordered: bool,
) -> Result<(usize, Duration, u64)> {
let mut read = 0;
let download_start = Instant::now();
let mut first_byte = true;
let mut ttfb = download_start.elapsed();
let mut num_chunks: u64 = 0;
if read_unordered {
while let Some(chunk) = stream.read_chunk(usize::MAX, false).await? {
if first_byte {
ttfb = download_start.elapsed();
first_byte = false;
}
read += chunk.bytes.len();
num_chunks += 1;
}
} else {
#[rustfmt::skip]
let mut bufs = [
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(),
];
while let Some(n) = stream.read_chunks(&mut bufs[..]).await? {
if first_byte {
ttfb = download_start.elapsed();
first_byte = false;
}
read += bufs.iter().take(n).map(|buf| buf.len()).sum::<usize>();
num_chunks += 1;
}
}
Ok((read, ttfb, num_chunks))
}
/// Writes exactly `stream_size` bytes of filler data (`0xAB`) to `stream`,
/// finishes the stream, and then waits on `stopped()`.
///
/// Data is sent in blocks of 1 MiB taken from a static buffer, followed by one
/// shorter block for any remainder.
///
/// # Errors
///
/// Returns an error if a write fails, if finishing the stream fails, or if
/// waiting for the stream to be stopped fails.
async fn send_data_on_stream(stream: &mut SendStream, stream_size: u64) -> Result<()> {
    const DATA: &[u8] = &[0xAB; 1024 * 1024];
    const BLOCK_LEN: u64 = DATA.len() as u64;

    // Cloning `Bytes` built from a static slice shares the buffer, it does not copy it.
    let block = Bytes::from_static(DATA);

    let mut left = stream_size;
    while left >= BLOCK_LEN {
        stream
            .write_chunk(block.clone())
            .await
            .context("failed sending data")?;
        left -= BLOCK_LEN;
    }

    if left > 0 {
        // `left` is below 1 MiB here, so the narrowing cast cannot truncate.
        let tail = block.slice(0..left as usize);
        stream
            .write_chunk(tail)
            .await
            .context("failed sending data")?;
    }

    stream.finish().context("failed finishing stream")?;
    stream
        .stopped()
        .await
        .context("failed to wait for stream to be stopped")?;
    Ok(())
}
/// Runs one upload/download round trip on a fresh bidirectional stream.
///
/// The client opens the stream, uploads `upload_size` bytes, and then drains
/// whatever the peer sends back. `read_unordered` is forwarded to
/// [`drain_stream`].
///
/// Returns `(upload, download)` results. The upload duration includes opening
/// the stream; its time-to-first-byte and chunk count are reported as zero.
///
/// # Errors
///
/// Returns an error if the stream cannot be opened, or if sending or
/// receiving fails.
pub async fn handle_client_stream(
    connection: &Connection,
    upload_size: u64,
    read_unordered: bool,
) -> Result<(TransferResult, TransferResult)> {
    // Upload phase.
    let upload_timer = Instant::now();
    let (mut tx, mut rx) = connection
        .open_bi()
        .await
        .context("failed to open stream")?;
    send_data_on_stream(&mut tx, upload_size).await?;
    let upload_result =
        TransferResult::new(upload_timer.elapsed(), upload_size, Duration::default(), 0);

    // Download phase.
    let download_timer = Instant::now();
    let (received, ttfb, num_chunks) = drain_stream(&mut rx, read_unordered).await?;
    let download_result =
        TransferResult::new(download_timer.elapsed(), received as u64, ttfb, num_chunks);

    Ok((upload_result, download_result))
}
/// Serves up to `opt.clients` incoming connections on `endpoint`.
///
/// Connections are accepted one after another. For each established
/// connection a task is spawned that accepts bidirectional streams until
/// accepting fails; every stream gets its own task which first drains the
/// incoming data and then sends `opt.download_size` bytes back. When a
/// connection's accept loop ends and `opt.stats` is set, the connection
/// statistics are printed. The function returns once all connection tasks
/// have finished.
///
/// An incoming connection that fails at `incoming.accept()` is logged and
/// skipped, but still uses up one of the `opt.clients` iterations.
///
/// # Errors
///
/// Returns an error if the endpoint stops yielding incoming connections before
/// `opt.clients` have been seen, or if a handshake fails. In both cases the
/// function returns without waiting for connection tasks already spawned.
pub async fn server(endpoint: Endpoint, opt: Opt) -> Result<()> {
    let mut server_tasks = Vec::new();
    for _ in 0..opt.clients {
        // `accept()` yields `None` once the endpoint is closed; report that as
        // an error instead of panicking.
        let incoming = endpoint
            .accept()
            .await
            .context("endpoint closed before all clients connected")?;
        let connecting = match incoming.accept() {
            Ok(connecting) => connecting,
            Err(err) => {
                warn!("incoming connection failed: {err:#}");
                continue;
            }
        };
        let connection = connecting.await.context("handshake failed")?;
        server_tasks.push(tokio::spawn(async move {
            loop {
                let (mut send_stream, mut recv_stream) = match connection.accept_bi().await {
                    // The peer closing the connection is the regular way out.
                    Err(::quinn::ConnectionError::ApplicationClosed(_)) => break,
                    Err(e) => {
                        eprintln!("accepting stream failed: {e:?}");
                        break;
                    }
                    Ok(stream) => stream,
                };
                trace!("stream established");
                // The join handle is dropped, so the outcome of this per-stream
                // task is not observed here.
                tokio::spawn(async move {
                    drain_stream(&mut recv_stream, opt.read_unordered).await?;
                    send_data_on_stream(&mut send_stream, opt.download_size).await?;
                    Ok::<_, anyhow::Error>(())
                });
            }
            if opt.stats {
                println!("\nServer connection stats:\n{:#?}", connection.stats());
            }
        }));
    }
    for handle in server_tasks {
        if let Err(e) = handle.await {
            eprintln!("Server task error: {e:?}");
        };
    }
    Ok(())
}