iroh_n0des/simulation/proto.rs

1use std::{collections::BTreeMap, net::SocketAddr};
2
3use anyhow::Result;
4use bytes::Bytes;
5use iroh::{NodeAddr, NodeId};
6use iroh_metrics::encoding::Update;
7use irpc::{WithChannels, channel::oneshot, rpc_requests, util::make_insecure_client_endpoint};
8use irpc_iroh::IrohRemoteConnection;
9use n0_future::{IterExt, StreamExt};
10use serde::{Deserialize, Serialize};
11use time::OffsetDateTime as DateTime;
12use tracing::debug;
13use uuid::Uuid;
14
15use super::{ENV_TRACE_SERVER, ENV_TRACE_SESSION_ID};
16
17pub const ALPN: &[u8] = b"/iroh/n0des-sim/1";
18
19pub type RemoteResult<T> = Result<T, RemoteError>;
20
21#[derive(Serialize, Deserialize, thiserror::Error, Debug)]
22pub enum RemoteError {
23    #[error("{0}")]
24    Other(String),
25    #[error("Trace not found")]
26    TraceNotFound,
27}
28
29impl RemoteError {
30    pub fn other(s: impl ToString) -> Self {
31        Self::Other(s.to_string())
32    }
33}
34
35impl From<anyhow::Error> for RemoteError {
36    fn from(value: anyhow::Error) -> Self {
37        Self::other(value)
38    }
39}
40
41#[derive(Debug, Serialize, Deserialize)]
42#[rpc_requests(message = TraceMessage)]
43#[allow(clippy::large_enum_variant)]
44pub enum TraceProtocol {
45    #[rpc(tx=oneshot::Sender<RemoteResult<Option<GetSessionResponse>>>)]
46    GetSession(GetSession),
47    #[rpc(tx=oneshot::Sender<RemoteResult<Uuid>>)]
48    InitTrace(InitTrace),
49    #[rpc(tx=oneshot::Sender<RemoteResult<GetTraceResponse>>)]
50    GetTrace(GetTrace),
51    #[rpc(tx=oneshot::Sender<RemoteResult<StartNodeResponse>>)]
52    StartNode(StartNode),
53    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
54    EndNode(EndNode),
55    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
56    CloseTrace(CloseTrace),
57    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
58    PutCheckpoint(PutCheckpoint),
59    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
60    PutLogs(PutLogs),
61    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
62    PutMetrics(PutMetrics),
63    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
64    WaitCheckpoint(WaitCheckpoint),
65    #[rpc(tx=oneshot::Sender<RemoteResult<WaitStartResponse>>)]
66    WaitStart(WaitStart),
67}
68
69#[derive(Debug, Serialize, Deserialize, Clone)]
70pub struct GetSession {
71    pub session_id: Uuid,
72}
73
74#[derive(Debug, Serialize, Deserialize, Clone)]
75pub struct GetSessionResponse {
76    pub traces: Vec<TraceDetails>,
77}
78
79#[derive(Debug, Serialize, Deserialize, Clone)]
80pub struct TraceDetails {
81    pub trace_id: Uuid,
82    pub info: TraceInfo,
83    pub finished: bool,
84}
85
86#[derive(Debug, Serialize, Deserialize, Clone)]
87pub struct InitTrace {
88    pub session_id: Uuid,
89    pub info: TraceInfo,
90    #[serde(with = "time::serde::rfc3339")]
91    pub start_time: DateTime,
92    pub setup_data: Option<Bytes>,
93}
94
95#[derive(Debug, Serialize, Deserialize, Clone)]
96pub struct CloseTrace {
97    pub trace_id: Uuid,
98    #[serde(with = "time::serde::rfc3339")]
99    pub end_time: DateTime,
100    pub result: Result<(), String>,
101}
102
103#[derive(Debug, Serialize, Deserialize, Clone)]
104pub struct TraceInfo {
105    pub name: String,
106    pub node_count: u32,
107    pub expected_checkpoints: Option<u64>,
108}
109
110impl TraceInfo {
111    pub fn new(name: impl ToString, node_count: u32) -> Self {
112        Self {
113            name: name.to_string(),
114            node_count,
115            expected_checkpoints: None,
116        }
117    }
118}
119
120#[derive(Debug, Serialize, Deserialize, Clone)]
121pub struct GetTrace {
122    pub session_id: Uuid,
123    pub name: String,
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127pub struct GetTraceResponse {
128    pub trace_id: Uuid,
129    pub info: TraceInfo,
130    pub setup_data: Option<Bytes>,
131}
132
133#[derive(Debug, Serialize, Deserialize, Clone)]
134pub struct StartNode {
135    pub trace_id: Uuid,
136    pub node_info: NodeInfo,
137    #[serde(with = "time::serde::rfc3339")]
138    pub start_time: DateTime,
139}
140
141#[derive(Debug, Serialize, Deserialize, Clone)]
142pub struct StartNodeResponse {
143    // pub trace_id: Uuid,
144    // pub setup_data: Option<Bytes>,
145}
146
147pub type NodeIdx = u32;
148
149#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
150pub enum Scope {
151    Integrated,
152    Isolated(NodeIdx),
153}
154
155#[derive(Debug, Serialize, Deserialize, Clone)]
156pub struct EndNode {
157    pub trace_id: Uuid,
158    pub node_idx: NodeIdx,
159    #[serde(with = "time::serde::rfc3339")]
160    pub end_time: DateTime,
161    pub result: Result<(), String>,
162}
163
164#[derive(Debug, Serialize, Deserialize, Clone)]
165pub struct NodeInfo {
166    pub idx: NodeIdx,
167    pub node_id: Option<NodeId>,
168    pub label: Option<String>,
169}
170
171impl NodeInfo {
172    pub fn new_empty(idx: NodeIdx) -> Self {
173        Self {
174            idx,
175            node_id: None,
176            label: None,
177        }
178    }
179}
180
181#[derive(Debug, Serialize, Deserialize, Clone)]
182pub struct NodeInfoWithAddr {
183    pub info: NodeInfo,
184    pub addr: Option<NodeAddr>,
185}
186
187#[derive(Debug, Serialize, Deserialize, Clone)]
188pub struct PutLogs {
189    pub trace_id: Uuid,
190    pub scope: Scope,
191    pub json_lines: Vec<String>,
192}
193
194#[derive(Debug, Serialize, Deserialize)]
195pub struct PutMetrics {
196    pub trace_id: Uuid,
197    pub node_id: NodeId,
198    pub checkpoint_id: Option<CheckpointId>,
199    #[serde(with = "time::serde::rfc3339")]
200    pub time: DateTime,
201    pub metrics: iroh_metrics::encoding::Update,
202}
203
204pub type CheckpointId = u64;
205
206#[derive(Debug, Serialize, Deserialize, Clone)]
207pub struct PutCheckpoint {
208    pub trace_id: Uuid,
209    pub checkpoint_id: CheckpointId,
210    pub node_idx: NodeIdx,
211    pub label: Option<String>,
212    #[serde(with = "time::serde::rfc3339")]
213    pub time: DateTime,
214    pub result: Result<(), String>,
215}
216
217#[derive(Debug, Serialize, Deserialize, Clone)]
218pub struct WaitCheckpoint {
219    pub trace_id: Uuid,
220    pub checkpoint_id: CheckpointId,
221}
222
223#[derive(Debug, Serialize, Deserialize, Clone)]
224pub struct WaitStart {
225    pub trace_id: Uuid,
226    pub info: NodeInfoWithAddr,
227}
228
229#[derive(Debug, Serialize, Deserialize, Clone)]
230pub struct WaitStartResponse {
231    pub infos: Vec<NodeInfoWithAddr>,
232}
233
234#[derive(Debug, Clone)]
235pub struct TraceClient {
236    client: irpc::Client<TraceProtocol>,
237    session_id: Uuid,
238}
239
240impl TraceClient {
241    pub fn from_env() -> Result<Option<Self>> {
242        if let Ok(addr) = std::env::var(ENV_TRACE_SERVER) {
243            let addr: SocketAddr = addr.parse()?;
244            let session_id: Uuid = match std::env::var(ENV_TRACE_SESSION_ID) {
245                Ok(id) => id.parse()?,
246                Err(_) => Uuid::now_v7(),
247            };
248            Ok(Some(Self::connect_quinn_insecure(addr, session_id)?))
249        } else {
250            Ok(None)
251        }
252    }
253
254    pub fn from_env_or_local() -> Result<Self> {
255        Ok(Self::from_env()?.unwrap_or_else(Self::local))
256    }
257
258    pub fn local() -> Self {
259        let (tx, rx) = tokio::sync::mpsc::channel(8);
260        LocalActor::spawn(rx);
261        let session_id = Uuid::now_v7();
262        Self {
263            client: irpc::Client::from(tx),
264            session_id,
265        }
266    }
267
268    pub fn new(client: irpc::Client<TraceProtocol>, session_id: Uuid) -> Self {
269        Self { client, session_id }
270    }
271
272    pub fn connect_quinn_insecure(remote: SocketAddr, session_id: Uuid) -> Result<Self> {
273        let addr_localhost = "127.0.0.1:0".parse().unwrap();
274        let endpoint = make_insecure_client_endpoint(addr_localhost)?;
275        Ok(Self::connect_quinn_endpoint(endpoint, remote, session_id))
276    }
277
278    pub fn connect_quinn_endpoint(
279        endpoint: quinn::Endpoint,
280        remote: SocketAddr,
281        session_id: Uuid,
282    ) -> Self {
283        let client = irpc::Client::quinn(endpoint, remote);
284        Self { client, session_id }
285    }
286
287    pub async fn connect_iroh(remote: iroh::NodeId, session_id: Uuid) -> Result<Self> {
288        let endpoint = iroh::Endpoint::builder().bind().await?;
289        Ok(Self::connect_iroh_endpoint(endpoint, remote, session_id))
290    }
291
292    pub fn connect_iroh_endpoint(
293        endpoint: iroh::Endpoint,
294        remote: impl Into<iroh::NodeAddr>,
295        session_id: Uuid,
296    ) -> Self {
297        let conn = IrohRemoteConnection::new(endpoint, remote.into(), ALPN.to_vec());
298        let client = irpc::Client::boxed(conn);
299        Self { client, session_id }
300    }
301
302    pub async fn init_and_start_trace(&self, name: &str) -> Result<ActiveTrace> {
303        let trace_info = TraceInfo::new(name, 1);
304        let trace_id = self.init_trace(trace_info, None).await?;
305        let node_info = NodeInfo::new_empty(0);
306        let client = self.start_node(trace_id, node_info).await?;
307        Ok(client)
308    }
309
310    pub async fn init_trace(&self, info: TraceInfo, setup_data: Option<Bytes>) -> Result<Uuid> {
311        debug!("init trace {info:?}");
312        let trace_id = self
313            .client
314            .rpc(InitTrace {
315                info,
316                session_id: self.session_id,
317                start_time: DateTime::now_utc(),
318                setup_data,
319            })
320            .await??;
321        debug!("init trace {trace_id}: OK");
322        Ok(trace_id)
323    }
324
325    pub async fn get_trace(&self, name: String) -> Result<GetTraceResponse> {
326        debug!("get trace {name}");
327        let data = self
328            .client
329            .rpc(GetTrace {
330                session_id: self.session_id,
331                name,
332            })
333            .await??;
334        debug!(?data, "get trace: OK");
335        Ok(data)
336    }
337
338    pub async fn close_trace(&self, trace_id: Uuid, result: Result<(), String>) -> Result<()> {
339        let end_time = DateTime::now_utc();
340        debug!(%trace_id, ?result, "close trace");
341        self.client
342            .rpc(CloseTrace {
343                trace_id,
344                end_time,
345                result,
346            })
347            .await??;
348        debug!(%trace_id, "close trace: OK");
349        Ok(())
350    }
351
352    pub async fn get_session(&self, session_id: Uuid) -> Result<Option<GetSessionResponse>> {
353        let res = self.client.rpc(GetSession { session_id }).await??;
354        Ok(res)
355    }
356
357    pub async fn put_logs(
358        &self,
359        trace_id: Uuid,
360        scope: Scope,
361        json_lines: Vec<String>,
362    ) -> Result<()> {
363        self.client
364            .rpc(PutLogs {
365                scope,
366                trace_id,
367                json_lines,
368            })
369            .await??;
370        Ok(())
371    }
372
373    pub async fn start_node(&self, trace_id: Uuid, node_info: NodeInfo) -> Result<ActiveTrace> {
374        let start_time = DateTime::now_utc();
375        let node_idx = node_info.idx;
376        debug!(%trace_id, node_idx, "start node");
377        let res = self
378            .client
379            .rpc(StartNode {
380                trace_id,
381                node_info,
382                start_time,
383            })
384            .await??;
385        let StartNodeResponse {} = res;
386        debug!(%trace_id, node_idx, "start node: OK");
387        Ok(ActiveTrace {
388            client: self.client.clone(),
389            trace_id,
390            node_idx,
391        })
392    }
393}
394
395#[derive(Debug, Clone)]
396pub struct ActiveTrace {
397    client: irpc::Client<TraceProtocol>,
398    trace_id: Uuid,
399    node_idx: u32,
400}
401
402impl ActiveTrace {
403    pub async fn put_checkpoint(
404        &self,
405        id: CheckpointId,
406        label: Option<String>,
407        result: Result<(), String>,
408    ) -> Result<()> {
409        debug!(id, "put checkpoint");
410        let time = DateTime::now_utc();
411        self.client
412            .rpc(PutCheckpoint {
413                trace_id: self.trace_id,
414                checkpoint_id: id,
415                node_idx: self.node_idx,
416                time,
417                label,
418                result,
419            })
420            .await??;
421        Ok(())
422    }
423
424    pub async fn put_metrics(
425        &self,
426        node_id: NodeId,
427        checkpoint_id: Option<CheckpointId>,
428        metrics: Update,
429    ) -> Result<()> {
430        let time = DateTime::now_utc();
431        debug!(count = metrics.values.items.len(), "put metrics");
432        self.client
433            .rpc(PutMetrics {
434                trace_id: self.trace_id,
435                node_id,
436                checkpoint_id,
437                time,
438                metrics,
439            })
440            .await??;
441        Ok(())
442    }
443
444    pub async fn put_logs(&self, json_lines: Vec<String>) -> Result<()> {
445        self.client
446            .rpc(PutLogs {
447                scope: Scope::Isolated(self.node_idx),
448                trace_id: self.trace_id,
449                json_lines,
450            })
451            .await??;
452        Ok(())
453    }
454
455    pub async fn wait_start(&self, info: NodeInfoWithAddr) -> Result<Vec<NodeInfoWithAddr>> {
456        debug!("waiting for start...");
457        let res = self
458            .client
459            .rpc(WaitStart {
460                info,
461                trace_id: self.trace_id,
462            })
463            .await??;
464        debug!("start confirmed");
465        Ok(res.infos)
466    }
467
468    pub async fn wait_checkpoint(&self, checkpoint_id: CheckpointId) -> Result<()> {
469        debug!(?checkpoint_id, "waiting for checkpoint...");
470        let res = self
471            .client
472            .rpc(WaitCheckpoint {
473                checkpoint_id,
474                trace_id: self.trace_id,
475            })
476            .await;
477        res??;
478        debug!(?checkpoint_id, "checkpoint confirmed");
479        Ok(())
480    }
481
482    pub async fn end(&self, result: Result<(), String>) -> Result<()> {
483        debug!("end node");
484        let end_time = DateTime::now_utc();
485        self.client
486            .rpc(EndNode {
487                trace_id: self.trace_id,
488                node_idx: self.node_idx,
489                end_time,
490                result,
491            })
492            .await??;
493        Ok(())
494    }
495}
496
497#[derive(Default)]
498struct LocalActor {
499    traces_by_name: BTreeMap<String, Uuid>,
500    traces: BTreeMap<Uuid, TraceState>,
501}
502
503struct TraceState {
504    init: InitTrace,
505    nodes: Vec<NodeInfoWithAddr>,
506    barrier_start: Vec<oneshot::Sender<RemoteResult<WaitStartResponse>>>,
507    barrier_checkpoint: BTreeMap<CheckpointId, Vec<oneshot::Sender<RemoteResult<()>>>>,
508}
509
510impl TraceState {
511    fn node_count(&self) -> usize {
512        self.init.info.node_count as usize
513    }
514}
515impl LocalActor {
516    pub fn spawn(rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
517        let actor = Self::default();
518        tokio::task::spawn(actor.run(rx));
519    }
520
521    pub async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
522        while let Some(message) = rx.recv().await {
523            self.handle_message(message).await;
524        }
525    }
526    async fn handle_message(&mut self, message: TraceMessage) {
527        match message {
528            TraceMessage::GetSession(_msg) => {}
529            TraceMessage::InitTrace(msg) => {
530                let WithChannels { inner, tx, .. } = msg;
531                if self.traces_by_name.contains_key(&inner.info.name) {
532                    return send_err(tx, RemoteError::other("Trace already initialized")).await;
533                }
534                let uuid = Uuid::now_v7();
535                self.traces_by_name.insert(inner.info.name.clone(), uuid);
536                self.traces.insert(
537                    uuid,
538                    TraceState {
539                        init: inner,
540                        nodes: Default::default(),
541                        barrier_start: Default::default(),
542                        barrier_checkpoint: Default::default(),
543                    },
544                );
545                tx.send(Ok(uuid)).await.ok();
546            }
547            TraceMessage::GetTrace(msg) => {
548                let WithChannels { inner, tx, .. } = msg;
549                let GetTrace {
550                    session_id: _,
551                    name,
552                } = inner;
553                let Some((trace_id, info)) = self
554                    .traces_by_name
555                    .get(&name)
556                    .and_then(|trace_id| self.traces.get_key_value(trace_id))
557                else {
558                    return send_err(tx, RemoteError::other("Trace not initialized")).await;
559                };
560                tx.send(Ok(GetTraceResponse {
561                    trace_id: *trace_id,
562                    info: info.init.info.clone(),
563                    setup_data: info.init.setup_data.clone(),
564                }))
565                .await
566                .ok();
567            }
568            TraceMessage::StartNode(msg) => {
569                let WithChannels { inner, tx, .. } = msg;
570                if self.traces.contains_key(&inner.trace_id) {
571                    tx.send(Ok(StartNodeResponse {})).await.ok();
572                } else {
573                    send_err(tx, RemoteError::other("Trace not initialized")).await;
574                }
575            }
576            TraceMessage::EndNode(msg) => {
577                // noop
578                msg.tx.send(Ok(())).await.ok();
579            }
580            TraceMessage::CloseTrace(msg) => {
581                // noop
582                msg.tx.send(Ok(())).await.ok();
583            }
584            TraceMessage::PutCheckpoint(msg) => {
585                // noop
586                msg.tx.send(Ok(())).await.ok();
587            }
588            TraceMessage::PutLogs(msg) => {
589                // noop
590                msg.tx.send(Ok(())).await.ok();
591            }
592            TraceMessage::PutMetrics(msg) => {
593                // noop
594                msg.tx.send(Ok(())).await.ok();
595            }
596            TraceMessage::WaitCheckpoint(msg) => {
597                let WithChannels { inner, tx, .. } = msg;
598                let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
599                    return send_err(tx, RemoteError::TraceNotFound).await;
600                };
601                let node_count = trace.node_count();
602                let barrier = trace
603                    .barrier_checkpoint
604                    .entry(inner.checkpoint_id)
605                    .or_default();
606                barrier.push(tx);
607                debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, count=barrier.len(), total=node_count, "wait checkpoint");
608                if barrier.len() == node_count {
609                    debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, "release");
610                    barrier
611                        .drain(..)
612                        .map(|tx| tx.send(Ok(())))
613                        .into_unordered_stream()
614                        .count()
615                        .await;
616                }
617            }
618            TraceMessage::WaitStart(msg) => {
619                let WithChannels { inner, tx, .. } = msg;
620                let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
621                    return send_err(tx, RemoteError::TraceNotFound).await;
622                };
623                trace.nodes.push(inner.info);
624                trace.barrier_start.push(tx);
625                if trace.barrier_start.len() == trace.init.info.node_count as usize {
626                    let data = WaitStartResponse {
627                        infos: trace.nodes.clone(),
628                    };
629                    for tx in trace.barrier_start.drain(..) {
630                        tx.send(Ok(data.clone())).await.ok();
631                    }
632                }
633            }
634        }
635    }
636}
637
638async fn send_err<T>(tx: oneshot::Sender<RemoteResult<T>>, err: RemoteError) {
639    tx.send(Err(err)).await.ok();
640}