iroh_n0des/
client.rs

1use std::{
2    path::Path,
3    sync::{Arc, RwLock},
4    time::Duration,
5};
6
7use anyhow::{Result, anyhow, ensure};
8use iroh::{Endpoint, NodeAddr, NodeId};
9use iroh_metrics::{Registry, encoding::Encoder};
10use irpc_iroh::IrohRemoteConnection;
11use n0_future::task::AbortOnDropHandle;
12use rcan::Rcan;
13use tracing::warn;
14use uuid::Uuid;
15
16use crate::{
17    caps::Caps,
18    protocol::{ALPN, Auth, N0desClient, Ping, PutMetrics, RemoteError},
19};
20
/// A client for a remote service node, authenticated during construction.
///
/// Created with [`Client::builder`] followed by [`ClientBuilder::build`].
#[derive(Debug)]
pub struct Client {
    /// RPC handle to the remote node.
    client: N0desClient,
    /// Background metrics push task, if metrics are enabled. Held only so that the
    /// task is aborted when this client is dropped.
    _metrics_task: Option<AbortOnDropHandle<()>>,
}
26
27/// Constructs an IPS client
28pub struct ClientBuilder {
29    cap_expiry: Duration,
30    cap: Option<Rcan<Caps>>,
31    endpoint: Endpoint,
32    enable_metrics: Option<Duration>,
33}
34
/// Default lifetime of capabilities created from an SSH key: 30 days.
const DEFAULT_CAP_EXPIRY: Duration = Duration::from_secs(30 * 24 * 60 * 60);
36
37impl ClientBuilder {
38    pub fn new(endpoint: &Endpoint) -> Self {
39        Self {
40            cap: None,
41            cap_expiry: DEFAULT_CAP_EXPIRY,
42            endpoint: endpoint.clone(),
43            enable_metrics: Some(Duration::from_secs(60)),
44        }
45    }
46
47    /// Set the metrics collection interval
48    ///
49    /// Defaults to enabled, every 60 seconds.
50    pub fn metrics_interval(mut self, interval: Duration) -> Self {
51        self.enable_metrics = Some(interval);
52        self
53    }
54
55    /// Disable metrics collection.
56    pub fn disable_metrics(mut self) -> Self {
57        self.enable_metrics = None;
58        self
59    }
60
61    /// Loads the private ssh key from the given path, and creates the needed capability.
62    pub async fn ssh_key_from_file<P: AsRef<Path>>(self, path: P) -> Result<Self> {
63        let file_content = tokio::fs::read_to_string(path).await?;
64        let private_key = ssh_key::PrivateKey::from_openssh(&file_content)?;
65
66        self.ssh_key(&private_key)
67    }
68
69    /// Creates the capability from the provided private ssh key.
70    pub fn ssh_key(mut self, key: &ssh_key::PrivateKey) -> Result<Self> {
71        let local_node = self.endpoint.node_id();
72        let rcan = crate::caps::create_api_token(key, local_node, self.cap_expiry, Caps::all())?;
73        self.cap.replace(rcan);
74
75        Ok(self)
76    }
77
78    /// Sets the rcan directly.
79    pub fn rcan(mut self, cap: Rcan<Caps>) -> Result<Self> {
80        ensure!(
81            NodeId::from(*cap.audience()) == self.endpoint.node_id(),
82            "invalid audience"
83        );
84        self.cap.replace(cap);
85        Ok(self)
86    }
87
88    /// Create a new client, connected to the provide service node
89    pub async fn build(self, remote: impl Into<NodeAddr>) -> Result<Client, BuildError> {
90        let cap = self.cap.ok_or(BuildError::MissingCapability)?;
91        let conn = IrohRemoteConnection::new(self.endpoint.clone(), remote.into(), ALPN.to_vec());
92        let client = N0desClient::boxed(conn);
93
94        // If auth fails, the connection is aborted.
95        let () = client.rpc(Auth { caps: cap }).await?;
96
97        let metrics_task = self.enable_metrics.map(|interval| {
98            AbortOnDropHandle::new(n0_future::task::spawn(
99                MetricsTask {
100                    client: client.clone(),
101                    session_id: Uuid::new_v4(),
102                    endpoint: self.endpoint.clone(),
103                }
104                .run(interval),
105            ))
106        });
107
108        Ok(Client {
109            client,
110            _metrics_task: metrics_task,
111        })
112    }
113}
114
115#[derive(thiserror::Error, Debug)]
116pub enum BuildError {
117    #[error("Missing capability")]
118    MissingCapability,
119    #[error("Unauthorized")]
120    Unauthorized,
121    #[error("Remote error: {0}")]
122    Remote(#[from] RemoteError),
123    #[error("Connection error: {0}")]
124    Rpc(irpc::Error),
125}
126
127impl From<irpc::Error> for BuildError {
128    fn from(value: irpc::Error) -> Self {
129        match value {
130            irpc::Error::Request(irpc::RequestError::Connection(
131                iroh::endpoint::ConnectionError::ApplicationClosed(frame),
132            )) if frame.error_code == 401u32.into() => Self::Unauthorized,
133            value => Self::Rpc(value),
134        }
135    }
136}
137
138#[derive(thiserror::Error, Debug)]
139pub enum Error {
140    #[error("Remote error: {0}")]
141    Remote(#[from] RemoteError),
142    #[error("Connection error: {0}")]
143    Rpc(#[from] irpc::Error),
144    #[error(transparent)]
145    Other(#[from] anyhow::Error),
146}
147
148impl Client {
149    pub fn builder(endpoint: &Endpoint) -> ClientBuilder {
150        ClientBuilder::new(endpoint)
151    }
152
153    /// Pings the remote node.
154    pub async fn ping(&mut self) -> Result<(), Error> {
155        let req = rand::random();
156        let pong = self.client.rpc(Ping { req }).await?;
157        if pong.req == req {
158            Ok(())
159        } else {
160            Err(Error::Other(anyhow!("unexpected pong response")))
161        }
162    }
163}
164
165struct MetricsTask {
166    client: N0desClient,
167    session_id: Uuid,
168    endpoint: Endpoint,
169}
170
171impl MetricsTask {
172    async fn run(self, interval: Duration) {
173        let mut registry = Registry::default();
174        registry.register_all(self.endpoint.metrics());
175        let registry = Arc::new(RwLock::new(registry));
176        let mut encoder = Encoder::new(registry);
177
178        let mut metrics_timer = tokio::time::interval(interval);
179
180        loop {
181            metrics_timer.tick().await;
182            if let Err(err) = self.send_metrics(&mut encoder).await {
183                warn!("failed to push metrics: {:#?}", err);
184            }
185        }
186    }
187
188    async fn send_metrics(&self, encoder: &mut Encoder) -> Result<()> {
189        let update = encoder.export();
190        let req = PutMetrics {
191            session_id: self.session_id,
192            update,
193        };
194        self.client.rpc(req).await??;
195        Ok(())
196    }
197}