use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
time::Instant,
};
use anyhow::{bail, Context, Result};
use axum::{
extract::{ConnectInfo, Request},
handler::Handler,
http::Method,
middleware::{self, Next},
response::IntoResponse,
routing::get,
Router,
};
use iroh_metrics::{inc, inc_by};
use serde::{Deserialize, Serialize};
use tokio::{net::TcpListener, task::JoinSet};
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
};
use tracing::{info, span, warn, Level};
mod doh;
mod error;
mod pkarr;
mod rate_limiting;
mod tls;
pub use self::{rate_limiting::RateLimitConfig, tls::CertMode};
use crate::{config::Config, metrics::Metrics, state::AppState};
/// Configuration for the plain HTTP server.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HttpConfig {
/// Port to bind to.
pub port: u16,
/// Address to bind to; the unspecified IPv4 address (`0.0.0.0`) is used when unset.
pub bind_addr: Option<IpAddr>,
}
/// Configuration for the HTTPS server.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct HttpsConfig {
/// Port to bind to.
pub port: u16,
/// Address to bind to; the unspecified IPv4 address (`0.0.0.0`) is used when unset.
pub bind_addr: Option<IpAddr>,
/// Domain names handed to `CertMode::build`.
// NOTE(review): presumably the names the certificate is issued for — confirm in `tls`.
pub domains: Vec<String>,
/// Selects how the TLS acceptor is built; its string form also names the
/// certificate cache subdirectory.
pub cert_mode: CertMode,
/// Contact handed to `CertMode::build`.
// NOTE(review): presumably only used by a Let's Encrypt cert mode — confirm in `tls`.
pub letsencrypt_contact: Option<String>,
/// Flag handed to `CertMode::build`; treated as `false` when unset.
// NOTE(review): presumably switches between Let's Encrypt production and staging — confirm.
pub letsencrypt_prod: Option<bool>,
}
/// Handle to the spawned HTTP and/or HTTPS server tasks.
pub struct HttpServer {
/// The serving tasks; one is spawned per configured server.
tasks: JoinSet<std::io::Result<()>>,
/// Local address of the HTTP listener, if an HTTP server was configured.
http_addr: Option<SocketAddr>,
/// Local address of the HTTPS listener, if an HTTPS server was configured.
https_addr: Option<SocketAddr>,
}
impl HttpServer {
    /// Spawns the HTTP and/or HTTPS servers described by the given configs.
    ///
    /// Both servers share one router built by [`create_app`]. Each configured server binds
    /// its listener before this function returns and is then served on a task owned by the
    /// returned [`HttpServer`].
    ///
    /// # Errors
    ///
    /// Returns an error if neither config is set, if a listener cannot be bound, or if the
    /// certificate cache directory or the TLS acceptor cannot be set up.
    pub async fn spawn(
        http_config: Option<HttpConfig>,
        https_config: Option<HttpsConfig>,
        rate_limit_config: RateLimitConfig,
        state: AppState,
    ) -> Result<HttpServer> {
        if http_config.is_none() && https_config.is_none() {
            bail!("Either http or https config is required");
        }

        let app = create_app(state, &rate_limit_config);
        let mut tasks = JoinSet::new();

        // Launch the plain HTTP server.
        let http_addr = if let Some(config) = http_config {
            let bind_addr = SocketAddr::new(
                config.bind_addr.unwrap_or(Ipv4Addr::UNSPECIFIED.into()),
                config.port,
            );
            // The router is needed again below for the HTTPS server.
            let app = app.clone();
            let listener = TcpListener::bind(bind_addr)
                .await
                .with_context(|| format!("failed to bind HTTP listener on {bind_addr}"))?
                .into_std()?;
            let bound_addr = listener.local_addr()?;
            let fut = axum_server::from_tcp(listener)
                .serve(app.into_make_service_with_connect_info::<SocketAddr>());
            // Log the address actually bound: with port 0 it differs from `bind_addr`.
            info!("HTTP server listening on {bound_addr}");
            tasks.spawn(fut);
            Some(bound_addr)
        } else {
            None
        };

        // Launch the HTTPS server.
        let https_addr = if let Some(config) = https_config {
            let bind_addr = SocketAddr::new(
                config.bind_addr.unwrap_or(Ipv4Addr::UNSPECIFIED.into()),
                config.port,
            );
            let acceptor = {
                // Certificates are cached per cert mode below the data directory.
                let cache_path = Config::data_dir()?
                    .join("cert_cache")
                    .join(config.cert_mode.to_string());
                tokio::fs::create_dir_all(&cache_path)
                    .await
                    .with_context(|| {
                        format!("failed to create cert cache dir at {cache_path:?}")
                    })?;
                config
                    .cert_mode
                    .build(
                        config.domains,
                        cache_path,
                        config.letsencrypt_contact,
                        config.letsencrypt_prod.unwrap_or(false),
                    )
                    .await?
            };
            let listener = TcpListener::bind(bind_addr)
                .await
                .with_context(|| format!("failed to bind HTTPS listener on {bind_addr}"))?
                .into_std()?;
            let bound_addr = listener.local_addr()?;
            let fut = axum_server::from_tcp(listener)
                .acceptor(acceptor)
                .serve(app.into_make_service_with_connect_info::<SocketAddr>());
            // Log the address actually bound: with port 0 it differs from `bind_addr`.
            info!("HTTPS server listening on {bound_addr}");
            tasks.spawn(fut);
            Some(bound_addr)
        } else {
            None
        };

        Ok(HttpServer {
            tasks,
            http_addr,
            https_addr,
        })
    }

    /// Returns the local address of the HTTP listener, if an HTTP server was spawned.
    pub fn http_addr(&self) -> Option<SocketAddr> {
        self.http_addr
    }

    /// Returns the local address of the HTTPS listener, if an HTTPS server was spawned.
    pub fn https_addr(&self) -> Option<SocketAddr> {
        self.https_addr
    }

    /// Aborts all server tasks and waits for them to finish.
    ///
    /// # Errors
    ///
    /// Cancellation caused by the abort is not an error; an error is only returned for a
    /// task that failed or panicked (see [`Self::run_until_done`]).
    pub async fn shutdown(mut self) -> Result<()> {
        self.tasks.abort_all();
        self.run_until_done().await
    }

    /// Waits until all server tasks have finished.
    ///
    /// # Errors
    ///
    /// Every failed or panicked task is logged as a warning. Cancelled tasks are ignored.
    /// If several tasks fail, the error of the one that finished last is returned.
    pub async fn run_until_done(mut self) -> Result<()> {
        let mut final_res: anyhow::Result<()> = Ok(());
        while let Some(res) = self.tasks.join_next().await {
            match res {
                Ok(Ok(())) => {}
                // Cancellation is the expected outcome after `shutdown` aborted the tasks.
                Err(err) if err.is_cancelled() => {}
                Ok(Err(err)) => {
                    warn!(?err, "task failed");
                    final_res = Err(anyhow::Error::from(err));
                }
                Err(err) => {
                    warn!(?err, "task panicked");
                    final_res = Err(err.into());
                }
            }
        }
        final_res
    }
}
/// Builds the router shared by the HTTP and HTTPS servers.
///
/// Routes: `/dns-query` (GET, POST), `/pkarr/:key` (GET, PUT), `/healthcheck` and `/`.
/// When `rate_limiting::create` yields a layer for `rate_limit_config`, it is applied to
/// the pkarr PUT handler only. CORS and request tracing wrap the whole router; the metrics
/// middleware is installed as a route layer.
///
/// # Panics
///
/// The tracing span builder panics on a request that carries no
/// `ConnectInfo<SocketAddr>` extension.
pub(crate) fn create_app(state: AppState, rate_limit_config: &RateLimitConfig) -> Router {
    // Any origin may call, restricted to the methods the routes below accept.
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST, Method::PUT])
        .allow_origin(cors::Any);

    // One debug-level span per request, tagged with method, URI and peer address.
    let trace_layer = TraceLayer::new_for_http().make_span_with(|request: &http::Request<_>| {
        let peer = request
            .extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .expect("connectinfo extension to be present")
            .0;
        span!(
            Level::DEBUG,
            "http_request",
            method = ?request.method(),
            uri = ?request.uri(),
            src = %peer,
        )
    });

    // Only the PUT handler is wrapped in the rate limiter, and only if one was created.
    let pkarr_routes = match rate_limiting::create(rate_limit_config) {
        Some(limiter) => get(pkarr::get).put(pkarr::put.layer(limiter)),
        None => get(pkarr::get).put(pkarr::put),
    };

    Router::new()
        .route("/dns-query", get(doh::get).post(doh::post))
        .route("/pkarr/:key", pkarr_routes)
        .route("/healthcheck", get(|| async { "OK" }))
        .route("/", get(|| async { "Hi!" }))
        .with_state(state)
        .layer(cors_layer)
        .layer(trace_layer)
        .route_layer(middleware::from_fn(metrics_middleware))
}
/// Middleware that records request metrics.
///
/// For each request passing through, the time spent in the inner service (in milliseconds)
/// is added to `http_requests_duration_ms` and `http_requests` is incremented. Depending on
/// whether the response status is a success status (`StatusCode::is_success`), either
/// `http_requests_success` or `http_requests_error` is incremented. The response itself is
/// passed through unchanged.
async fn metrics_middleware(req: Request, next: Next) -> impl IntoResponse {
    let start = Instant::now();
    let response = next.run(req).await;
    // `as_millis` yields a `u128`; saturate instead of silently truncating with `as`.
    let latency_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

    inc_by!(Metrics, http_requests_duration_ms, latency_ms);
    inc!(Metrics, http_requests);
    if response.status().is_success() {
        inc!(Metrics, http_requests_success);
    } else {
        inc!(Metrics, http_requests_error);
    }
    response
}