iroh_blobs/util/connection_pool.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
//! A simple iroh connection pool
//!
//! Entry point is [`ConnectionPool`]. You create a connection pool for a specific
//! ALPN and [`Options`]. Then the pool will manage connections for you.
//!
//! Access to connections is via the [`ConnectionPool::get_or_connect`] method, which
//! gives you access to a connection via a [`ConnectionRef`] if possible.
//!
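//! It is important that you keep the [`ConnectionRef`] alive while you are using
//! the connection: once no [`ConnectionRef`] for a connection is alive, the pool
//! considers that connection idle and may close it after
//! [`Options::idle_timeout`] or evict it to make room for a new connection.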
use std::{
collections::{HashMap, VecDeque},
ops::Deref,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use iroh::{endpoint::ConnectError, Endpoint, NodeId};
use n0_future::{
future::{self},
FuturesUnordered, MaybeFuture, Stream, StreamExt,
};
use snafu::Snafu;
use tokio::sync::{
mpsc::{self, error::SendError as TokioSendError},
oneshot, Notify,
};
use tokio_util::time::FutureExt as TimeFutureExt;
use tracing::{debug, error, trace};
/// Configuration options for the connection pool
#[derive(Debug, Clone, Copy)]
pub struct Options {
    /// How long a connection may stay without any live [`ConnectionRef`]
    /// before the pool closes it.
    pub idle_timeout: Duration,
    /// Upper bound on the time spent establishing a new connection.
    pub connect_timeout: Duration,
    /// Maximum number of connections the pool keeps at the same time.
    pub max_connections: usize,
}

impl Default for Options {
    /// 5 s idle timeout, 1 s connect timeout, at most 1024 connections.
    fn default() -> Self {
        let idle_timeout = Duration::from_secs(5);
        let connect_timeout = Duration::from_secs(1);
        let max_connections = 1024;
        Options {
            idle_timeout,
            connect_timeout,
            max_connections,
        }
    }
}
/// A reference to a connection that is owned by a connection pool.
///
/// Dereferences to the underlying [`iroh::endpoint::Connection`]. While this
/// handle is alive it counts as one active user of the connection; the
/// connection is only reported idle once no such handles remain.
#[derive(Debug)]
pub struct ConnectionRef {
    connection: iroh::endpoint::Connection,
    _permit: OneConnection,
}

impl ConnectionRef {
    /// Pairs a pooled connection with the guard that marks it as in use.
    fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self {
        Self {
            _permit: counter,
            connection,
        }
    }
}

impl Deref for ConnectionRef {
    type Target = iroh::endpoint::Connection;

    fn deref(&self) -> &iroh::endpoint::Connection {
        &self.connection
    }
}
/// Error when a connection can not be acquired
///
/// This includes the normal iroh connection errors as well as pool specific
/// errors such as timeouts and connection limits.
// NOTE(review): no `#[snafu(display)]` attributes are given, so snafu presumably
// uses each variant's `///` doc comment as its `Display` text. Editing those
// comments would then change user-visible error messages; confirm before
// touching them.
#[derive(Debug, Clone, Snafu)]
#[snafu(module)]
pub enum PoolConnectError {
    // Returned by `ConnectionPool::get_or_connect` when the pool actor cannot
    // be reached or drops the reply channel without answering.
    /// Connection pool is shut down
    Shutdown,
    // The connect attempt did not finish within `Options::connect_timeout`.
    /// Timeout during connect
    Timeout,
    // `max_connections` connection actors exist and none of them is idle, so
    // nothing can be evicted to make room.
    /// Too many connections
    TooManyConnections,
    // Wrapped in `Arc` so the error can be `Clone` (derived above) and the same
    // failure can be handed to every requester queued on that connection.
    /// Error during connect
    ConnectError { source: Arc<ConnectError> },
}
impl From<ConnectError> for PoolConnectError {
fn from(e: ConnectError) -> Self {
PoolConnectError::ConnectError {
source: Arc::new(e),
}
}
}
/// Error when calling a fn on the [`ConnectionPool`].
///
/// The only thing that can go wrong is that the connection pool is shut down.
// NOTE(review): no `#[snafu(display)]` is given, so snafu presumably uses the
// variant doc comment below as the `Display` text. Changing it would then
// change a user-visible error message; confirm before editing.
#[derive(Debug, Snafu)]
#[snafu(module)]
pub enum ConnectionPoolError {
    // Returned by `ConnectionPool::close` / `ConnectionPool::idle` when sending
    // to the pool actor's channel fails.
    /// The connection pool has been shut down
    Shutdown,
}
/// Messages processed by the pool actor.
enum ActorMessage {
    /// A caller wants a [`ConnectionRef`] for a node; sent by
    /// `ConnectionPool::get_or_connect`.
    RequestRef(RequestRef),
    /// A connection actor reports that its connection has no active users;
    /// sent via `ConnectionPool::idle`.
    ConnectionIdle { id: NodeId },
    /// Drop the connection actor for `id`; sent via `ConnectionPool::close`,
    /// both by users and by connection actors themselves.
    ConnectionShutdown { id: NodeId },
}
/// A request for a connection to node `id`; the result is delivered on `tx`.
struct RequestRef {
    id: NodeId,
    tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
}
struct Context {
options: Options,
endpoint: Endpoint,
owner: ConnectionPool,
alpn: Vec<u8>,
}
impl Context {
async fn run_connection_actor(
self: Arc<Self>,
node_id: NodeId,
mut rx: mpsc::Receiver<RequestRef>,
) {
let context = self;
// Connect to the node
let state = context
.endpoint
.connect(node_id, &context.alpn)
.timeout(context.options.connect_timeout)
.await
.map_err(|_| PoolConnectError::Timeout)
.and_then(|r| r.map_err(PoolConnectError::from));
if let Err(e) = &state {
debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
if context.owner.close(node_id).await.is_err() {
return;
}
}
let counter = ConnectionCounter::new();
let idle_timer = MaybeFuture::default();
let idle_stream = counter.clone().idle_stream();
tokio::pin!(idle_timer, idle_stream);
loop {
tokio::select! {
biased;
// Handle new work
handler = rx.recv() => {
match handler {
Some(RequestRef { id, tx }) => {
assert!(id == node_id, "Not for me!");
match &state {
Ok(state) => {
let res = ConnectionRef::new(state.clone(), counter.get_one());
// clear the idle timer
idle_timer.as_mut().set_none();
tx.send(Ok(res)).ok();
}
Err(cause) => {
tx.send(Err(cause.clone())).ok();
}
}
}
None => {
// Channel closed - finish remaining tasks and exit
break;
}
}
}
_ = idle_stream.next() => {
if !counter.is_idle() {
continue;
};
// notify the pool that we are idle.
trace!(%node_id, "Idle");
if context.owner.idle(node_id).await.is_err() {
// If we can't notify the pool, we are shutting down
break;
}
// set the idle timer
idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
}
// Idle timeout - request shutdown
_ = &mut idle_timer => {
trace!(%node_id, "Idle timer expired, requesting shutdown");
context.owner.close(node_id).await.ok();
// Don't break here - wait for main actor to close our channel
}
}
}
if let Ok(connection) = state {
let reason = if counter.is_idle() { b"idle" } else { b"drop" };
connection.close(0u32.into(), reason);
}
trace!(%node_id, "Connection actor shutting down");
}
}
struct Actor {
rx: mpsc::Receiver<ActorMessage>,
connections: HashMap<NodeId, mpsc::Sender<RequestRef>>,
context: Arc<Context>,
// idle set (most recent last)
// todo: use a better data structure if this becomes a performance issue
idle: VecDeque<NodeId>,
// per connection tasks
tasks: FuturesUnordered<future::Boxed<()>>,
}
impl Actor {
pub fn new(
endpoint: Endpoint,
alpn: &[u8],
options: Options,
) -> (Self, mpsc::Sender<ActorMessage>) {
let (tx, rx) = mpsc::channel(100);
(
Self {
rx,
connections: HashMap::new(),
idle: VecDeque::new(),
context: Arc::new(Context {
options,
alpn: alpn.to_vec(),
endpoint,
owner: ConnectionPool { tx: tx.clone() },
}),
tasks: FuturesUnordered::new(),
},
tx,
)
}
fn add_idle(&mut self, id: NodeId) {
self.remove_idle(id);
self.idle.push_back(id);
}
fn remove_idle(&mut self, id: NodeId) {
self.idle.retain(|&x| x != id);
}
fn pop_oldest_idle(&mut self) -> Option<NodeId> {
self.idle.pop_front()
}
fn remove_connection(&mut self, id: NodeId) {
self.connections.remove(&id);
self.remove_idle(id);
}
async fn handle_msg(&mut self, msg: ActorMessage) {
match msg {
ActorMessage::RequestRef(mut msg) => {
let id = msg.id;
self.remove_idle(id);
// Try to send to existing connection actor
if let Some(conn_tx) = self.connections.get(&id) {
if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
msg = e;
} else {
return;
}
// Connection actor died, remove it
self.remove_connection(id);
}
// No connection actor or it died - check limits
if self.connections.len() >= self.context.options.max_connections {
if let Some(idle) = self.pop_oldest_idle() {
// remove the oldest idle connection to make room for one more
trace!("removing oldest idle connection {}", idle);
self.connections.remove(&idle);
} else {
msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok();
return;
}
}
let (conn_tx, conn_rx) = mpsc::channel(100);
self.connections.insert(id, conn_tx.clone());
let context = self.context.clone();
self.tasks
.push(Box::pin(context.run_connection_actor(id, conn_rx)));
// Send the handler to the new actor
if conn_tx.send(msg).await.is_err() {
error!(%id, "Failed to send handler to new connection actor");
self.connections.remove(&id);
}
}
ActorMessage::ConnectionIdle { id } => {
self.add_idle(id);
trace!(%id, "connection idle");
}
ActorMessage::ConnectionShutdown { id } => {
// Remove the connection from our map - this closes the channel
self.remove_connection(id);
trace!(%id, "removed connection");
}
}
}
pub async fn run(mut self) {
loop {
tokio::select! {
biased;
msg = self.rx.recv() => {
if let Some(msg) = msg {
self.handle_msg(msg).await;
} else {
break;
}
}
_ = self.tasks.next(), if !self.tasks.is_empty() => {}
}
}
}
}
/// A connection pool
#[derive(Debug, Clone)]
pub struct ConnectionPool {
tx: mpsc::Sender<ActorMessage>,
}
impl ConnectionPool {
pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self {
let (actor, tx) = Actor::new(endpoint, alpn, options);
// Spawn the main actor
tokio::spawn(actor.run());
Self { tx }
}
/// Returns either a fresh connection or a reference to an existing one.
///
/// This is guaranteed to return after approximately [Options::connect_timeout]
/// with either an error or a connection.
pub async fn get_or_connect(
&self,
id: NodeId,
) -> std::result::Result<ConnectionRef, PoolConnectError> {
let (tx, rx) = oneshot::channel();
self.tx
.send(ActorMessage::RequestRef(RequestRef { id, tx }))
.await
.map_err(|_| PoolConnectError::Shutdown)?;
rx.await.map_err(|_| PoolConnectError::Shutdown)?
}
/// Close an existing connection, if it exists
///
/// This will finish pending tasks and close the connection. New tasks will
/// get a new connection if they are submitted after this call
pub async fn close(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
self.tx
.send(ActorMessage::ConnectionShutdown { id })
.await
.map_err(|_| ConnectionPoolError::Shutdown)?;
Ok(())
}
/// Notify the connection pool that a connection is idle.
///
/// Should only be called from connection handlers.
pub(crate) async fn idle(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> {
self.tx
.send(ActorMessage::ConnectionIdle { id })
.await
.map_err(|_| ConnectionPoolError::Shutdown)?;
Ok(())
}
}
/// Shared state behind a [`ConnectionCounter`]: the number of live guards and
/// the notifier used to signal that this number may have dropped to zero.
#[derive(Debug)]
struct ConnectionCounterInner {
    count: AtomicUsize,
    notify: Notify,
}

/// Counts the live [`OneConnection`] guards handed out for one connection.
#[derive(Debug, Clone)]
struct ConnectionCounter {
    inner: Arc<ConnectionCounterInner>,
}

impl ConnectionCounter {
    /// Creates a counter with no active guards, i.e. an idle one.
    fn new() -> Self {
        let inner = ConnectionCounterInner {
            count: AtomicUsize::new(0),
            notify: Notify::new(),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Registers one more active user and returns the guard that
    /// unregisters it again when dropped.
    fn get_one(&self) -> OneConnection {
        let inner = Arc::clone(&self.inner);
        inner.count.fetch_add(1, Ordering::SeqCst);
        OneConnection { inner }
    }

    /// Whether no guards are currently alive.
    fn is_idle(&self) -> bool {
        let active = self.inner.count.load(Ordering::SeqCst);
        active == 0
    }

    /// Turns the counter into an endless stream that yields each time the
    /// notifier fires, which happens when dropping a [`OneConnection`] brings
    /// the count to zero.
    ///
    /// A yielded item is only a hint: callers must re-check
    /// [`Self::is_idle`], since a new guard may have been handed out meanwhile.
    ///
    /// Because notifications come only from guard drops, the stream does not
    /// yield at startup even though a fresh counter is already idle.
    fn idle_stream(self) -> impl Stream<Item = ()> {
        n0_future::stream::unfold(self, |counter| async move {
            counter.inner.notify.notified().await;
            Some(((), counter))
        })
    }
}
/// Guard for one active user of a connection.
///
/// Created by [`ConnectionCounter::get_one`]. Dropping it decrements the shared
/// count and, if it was the last guard, wakes the counter's idle stream.
#[derive(Debug)]
struct OneConnection {
    inner: Arc<ConnectionCounterInner>,
}

impl Drop for OneConnection {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value, so 1 means we were the last guard.
        if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 {
            // `notify_one` rather than `notify_waiters`: `notify_waiters` only
            // wakes tasks that are already waiting. A guard dropped while the
            // idle stream is between polls (e.g. right after the connection
            // actor handed out a `ConnectionRef`) would then be lost, and the
            // connection would never be reported idle. `notify_one` stores a
            // permit in that case. This relies on a single idle-stream consumer
            // per counter, which is how it is used in this file; spurious
            // wake-ups are harmless because that consumer re-checks `is_idle`.
            self.inner.notify.notify_one();
        }
    }
}