// iroh_relay/server/resolver.rs

1use std::sync::Arc;
2
3use n0_future::{
4    task::{self, AbortOnDropHandle},
5    time::{self, Duration},
6};
7use reloadable_state::Reloadable;
8use rustls::{
9    server::{ClientHello, ResolvesServerCert},
10    sign::CertifiedKey,
11};
12use tokio_util::sync::CancellationToken;
13
/// The default certificate reload interval: once every 24 hours.
pub const DEFAULT_CERT_RELOAD_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
16
17/// A Certificate resolver that reloads the certificate every interval
18#[derive(Debug)]
19pub struct ReloadingResolver<Loader: Send + 'static> {
20    /// The inner reloadable value.
21    reloadable: Arc<Reloadable<CertifiedKey, Loader>>,
22    /// The handle to the task that reloads the certificate.
23    _handle: AbortOnDropHandle<()>,
24    /// Cancel token to shutdown the resolver.
25    cancel_token: CancellationToken,
26}
27
28impl<Loader> ReloadingResolver<Loader>
29where
30    Loader: Send + reloadable_state::core::Loader<Value = CertifiedKey> + 'static,
31{
32    /// Perform the initial load and construct the [`ReloadingResolver`].
33    pub async fn init(loader: Loader, interval: Duration) -> Result<Self, Loader::Error> {
34        let (reloadable, _) = Reloadable::init_load(loader).await?;
35        let reloadable = Arc::new(reloadable);
36
37        let cancel_token = CancellationToken::new();
38
39        // Spawn a task to reload the certificate every interval.
40        let _reloadable = reloadable.clone();
41        let _cancel_token = cancel_token.clone();
42        let _handle = task::spawn(async move {
43            let mut interval = time::interval(interval);
44            loop {
45                tokio::select! {
46                    _ = interval.tick() => {
47                        let _ = _reloadable.reload().await;
48                        tracing::info!("Reloaded the certificate");
49                    },
50                    _ = _cancel_token.cancelled() => {
51                        tracing::trace!("shutting down");
52                        break;
53                    }
54                }
55            }
56        });
57        let _handle = AbortOnDropHandle::new(_handle);
58
59        Ok(Self {
60            reloadable,
61            _handle,
62            cancel_token,
63        })
64    }
65
66    /// Shutdown the resolver.
67    pub fn shutdown(self) {
68        self.cancel_token.cancel();
69    }
70
71    /// Reload the certificate.
72    pub async fn reload(&self) {
73        let _ = self.reloadable.reload().await;
74    }
75}
76
77impl<Loader> ResolvesServerCert for ReloadingResolver<Loader>
78where
79    Loader: reloadable_state::core::Loader<Value = CertifiedKey>,
80    Loader: Send,
81    Loader: std::fmt::Debug,
82{
83    fn resolve(&self, _client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
84        Some(self.reloadable.get())
85    }
86}
87
88impl<Loader: Send> std::ops::Deref for ReloadingResolver<Loader> {
89    type Target = Reloadable<CertifiedKey, Loader>;
90
91    fn deref(&self) -> &Self::Target {
92        &self.reloadable
93    }
94}