iroh_n0des/simulation/
proto.rs

1use std::{collections::BTreeMap, net::SocketAddr};
2
3use anyhow::Result;
4use bytes::Bytes;
5use iroh::{EndpointAddr, EndpointId};
6use iroh_metrics::encoding::Update;
7use irpc::{WithChannels, channel::oneshot, rpc_requests, util::make_insecure_client_endpoint};
8use irpc_iroh::IrohRemoteConnection;
9use n0_future::{IterExt, StreamExt};
10use serde::{Deserialize, Serialize};
11use time::OffsetDateTime as DateTime;
12use tracing::debug;
13use uuid::Uuid;
14
15use super::{ENV_TRACE_SERVER, ENV_TRACE_SESSION_ID};
16
/// ALPN identifier of the simulation trace RPC protocol.
///
/// Passed to `endpoint.connect` in `TraceClient::connect_iroh_endpoint`.
pub const ALPN: &[u8] = b"/iroh/n0des-sim/1";

/// Result type carried inside RPC replies; the error side is the serializable [`RemoteError`].
pub type RemoteResult<T> = Result<T, RemoteError>;
20
/// Error sent back to the client inside an RPC reply.
///
/// Derives `Serialize`/`Deserialize`, so it travels over the wire and only
/// carries plain data (a message string or a unit variant).
#[derive(Serialize, Deserialize, thiserror::Error, Debug)]
pub enum RemoteError {
    /// Any other failure, described by a human-readable message.
    #[error("{0}")]
    Other(String),
    /// The trace id named in the request is unknown to the server.
    #[error("Trace not found")]
    TraceNotFound,
}
28
29impl RemoteError {
30    pub fn other(s: impl ToString) -> Self {
31        Self::Other(s.to_string())
32    }
33}
34
35impl From<anyhow::Error> for RemoteError {
36    fn from(value: anyhow::Error) -> Self {
37        Self::other(value)
38    }
39}
40
/// The trace RPC protocol: one variant per request type.
///
/// The reply channel of each request is declared by its `#[rpc(tx = ...)]`
/// attribute. `#[rpc_requests(message = TraceMessage)]` generates the
/// `TraceMessage` enum (request + channels) that `LocalActor` matches on.
///
/// NOTE(review): the serialized form presumably encodes variants by position —
/// append new variants at the end (or bump [`ALPN`]) instead of reordering.
#[derive(Debug, Serialize, Deserialize)]
#[rpc_requests(message = TraceMessage)]
// NOTE(review): presumably allowed because some requests (e.g. `PutMetrics`,
// which carries a full metrics `Update`) are much larger than the others — confirm.
#[allow(clippy::large_enum_variant)]
pub enum TraceProtocol {
    /// Look up the traces recorded under a session.
    #[rpc(tx=oneshot::Sender<RemoteResult<Option<GetSessionResponse>>>)]
    GetSession(GetSession),
    /// Register a new trace; replies with the new trace id.
    #[rpc(tx=oneshot::Sender<RemoteResult<Uuid>>)]
    InitTrace(InitTrace),
    /// Fetch an already-registered trace by name.
    #[rpc(tx=oneshot::Sender<RemoteResult<GetTraceResponse>>)]
    GetTrace(GetTrace),
    /// Announce that a node joined a trace.
    #[rpc(tx=oneshot::Sender<RemoteResult<StartNodeResponse>>)]
    StartNode(StartNode),
    /// Announce that a node finished, with its result.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    EndNode(EndNode),
    /// Close a whole trace, with its overall result.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    CloseTrace(CloseTrace),
    /// Report that a node reached a checkpoint.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    PutCheckpoint(PutCheckpoint),
    /// Upload a batch of log lines.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    PutLogs(PutLogs),
    /// Upload a metrics update.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    PutMetrics(PutMetrics),
    /// Block until the checkpoint barrier is released.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    WaitCheckpoint(WaitCheckpoint),
    /// Block until the start barrier is released; replies with all nodes' info.
    #[rpc(tx=oneshot::Sender<RemoteResult<WaitStartResponse>>)]
    WaitStart(WaitStart),
}
68
/// Request for the traces recorded under `session_id`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetSession {
    /// Session to look up.
    pub session_id: Uuid,
}
73
/// Reply to [`GetSession`] (sent wrapped in an `Option`; `None` presumably
/// means the session is unknown — confirm against the server).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetSessionResponse {
    /// The traces that belong to the session.
    pub traces: Vec<TraceDetails>,
}
78
/// Summary of a single trace, as listed in [`GetSessionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TraceDetails {
    /// Id of the trace.
    pub trace_id: Uuid,
    /// Metadata the trace was initialized with.
    pub info: TraceInfo,
    /// Presumably `true` once the trace has been closed via `CloseTrace` — confirm.
    pub finished: bool,
}
85
/// Request to register a new trace within a session.
///
/// Answered with the id of the new trace. The local actor rejects a name that
/// is already registered.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct InitTrace {
    /// Session the trace belongs to.
    pub session_id: Uuid,
    /// Name, node count and other metadata of the trace.
    pub info: TraceInfo,
    /// When the trace started; serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub start_time: DateTime,
    /// Opaque bytes handed back unchanged to later `GetTrace` callers.
    pub setup_data: Option<Bytes>,
}
94
/// Request to close a trace, recording its end time and overall outcome.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CloseTrace {
    /// Trace to close.
    pub trace_id: Uuid,
    /// When the trace ended; serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub end_time: DateTime,
    /// Overall outcome; `Err` carries a failure description.
    pub result: Result<(), String>,
}
102
/// Static metadata describing a trace.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TraceInfo {
    /// Trace name; the local actor uses it as the lookup key for `GetTrace`.
    pub name: String,
    /// Number of participating nodes. The local actor releases its start and
    /// checkpoint barriers once this many nodes have arrived.
    pub node_count: u32,
    /// Not interpreted in this file; presumably the number of checkpoints the
    /// trace is expected to reach — confirm.
    pub expected_checkpoints: Option<u64>,
}
109
110impl TraceInfo {
111    pub fn new(name: impl ToString, node_count: u32) -> Self {
112        Self {
113            name: name.to_string(),
114            node_count,
115            expected_checkpoints: None,
116        }
117    }
118}
119
/// Request to look up an already-registered trace by name.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetTrace {
    /// Session the trace belongs to (ignored by the local actor).
    pub session_id: Uuid,
    /// Name the trace was registered with.
    pub name: String,
}
125
/// Reply to [`GetTrace`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetTraceResponse {
    /// Id of the trace.
    pub trace_id: Uuid,
    /// Metadata the trace was initialized with.
    pub info: TraceInfo,
    /// The setup data passed to `InitTrace`, if any.
    pub setup_data: Option<Bytes>,
}
132
/// Request announcing that a node has started participating in a trace.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StartNode {
    /// Trace the node joins.
    pub trace_id: Uuid,
    /// Index and identity of the node.
    pub node_info: EndpointInfo,
    /// When the node started; serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub start_time: DateTime,
}
140
/// Reply to [`StartNode`].
///
/// Currently carries no data; `TraceClient::start_node` destructures it
/// exhaustively, so any field added here must be handled there.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StartNodeResponse {}
146
/// Index of a node within a trace (the single node of
/// `TraceClient::init_and_start_trace` uses index 0).
pub type EndpointIdx = u32;

/// Which log stream a [`PutLogs`] batch belongs to.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum Scope {
    /// Not tied to a single node; presumably trace-wide logs — confirm.
    Integrated,
    /// Logs of the node with this index (used by `ActiveTrace::put_logs`).
    Isolated(EndpointIdx),
}
154
/// Request announcing that node `idx` of a trace has finished.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EndNode {
    /// Trace the node belongs to.
    pub trace_id: Uuid,
    /// Index of the node that finished.
    pub idx: EndpointIdx,
    /// When the node finished; serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub end_time: DateTime,
    /// Outcome of the node; `Err` carries a failure description.
    pub result: Result<(), String>,
}
163
/// Identity of a node participating in a trace.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EndpointInfo {
    /// Index of the node within the trace.
    pub idx: EndpointIdx,
    /// The node's iroh endpoint id, if known.
    pub endpoint_id: Option<EndpointId>,
    /// Optional label for the node.
    pub label: Option<String>,
}
170
171impl EndpointInfo {
172    pub fn new_empty(idx: EndpointIdx) -> Self {
173        Self {
174            idx,
175            endpoint_id: None,
176            label: None,
177        }
178    }
179}
180
/// Node identity plus its dialing address.
///
/// Exchanged via `WaitStart`: each node submits its own and receives those of
/// every node once the start barrier is released.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EndpointInfoWithAddr {
    /// Index and identity of the node.
    pub info: EndpointInfo,
    /// Address other nodes can use to reach it, if available.
    pub addr: Option<EndpointAddr>,
}
186
/// Request uploading a batch of log lines for a trace.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PutLogs {
    /// Trace the logs belong to.
    pub trace_id: Uuid,
    /// Which log stream the lines belong to.
    pub scope: Scope,
    /// The log records; presumably one JSON document per string — confirm.
    pub json_lines: Vec<String>,
}
193
/// Request uploading a metrics update from a node.
#[derive(Debug, Serialize, Deserialize)]
pub struct PutMetrics {
    /// Trace the metrics belong to.
    pub trace_id: Uuid,
    /// Endpoint id of the node that produced the metrics.
    pub id: EndpointId,
    /// Checkpoint the snapshot is associated with, if any.
    pub checkpoint_id: Option<CheckpointId>,
    /// When the update was produced; serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub time: DateTime,
    /// The encoded metric values.
    pub metrics: iroh_metrics::encoding::Update,
}
203
/// Identifier of a checkpoint within a trace.
pub type CheckpointId = u64;

/// Request reporting that node `idx` reached checkpoint `checkpoint_id`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PutCheckpoint {
    /// Trace the checkpoint belongs to.
    pub trace_id: Uuid,
    /// Checkpoint that was reached.
    pub checkpoint_id: CheckpointId,
    /// Index of the reporting node.
    pub idx: EndpointIdx,
    /// Optional label for the checkpoint.
    pub label: Option<String>,
    /// When the checkpoint was reached; serialized as RFC 3339.
    #[serde(with = "time::serde::rfc3339")]
    pub time: DateTime,
    /// Outcome at this checkpoint; `Err` carries a failure description.
    pub result: Result<(), String>,
}
216
/// Request to wait on the barrier of checkpoint `checkpoint_id`.
///
/// The local actor replies to all waiters once `node_count` requests for the
/// same checkpoint have arrived.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WaitCheckpoint {
    /// Trace the checkpoint belongs to.
    pub trace_id: Uuid,
    /// Checkpoint whose barrier to wait on.
    pub checkpoint_id: CheckpointId,
}
222
/// Request to wait on the start barrier of a trace, contributing this node's info.
///
/// The local actor replies to all waiters once `node_count` nodes have sent it.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WaitStart {
    /// Trace to wait on.
    pub trace_id: Uuid,
    /// This node's identity and address, shared with the other nodes.
    pub info: EndpointInfoWithAddr,
}
228
/// Reply to [`WaitStart`] once the start barrier is released.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WaitStartResponse {
    /// Info of every node that waited (in arrival order for the local actor).
    pub infos: Vec<EndpointInfoWithAddr>,
}
233
/// Client for the trace RPC protocol, bound to one session id.
///
/// Talks to a remote server (over quinn or iroh) or to an in-process actor
/// (see `TraceClient::local`).
#[derive(Debug, Clone)]
pub struct TraceClient {
    /// Underlying irpc client.
    client: irpc::Client<TraceProtocol>,
    /// Session new traces are registered under.
    session_id: Uuid,
}
239
240impl TraceClient {
241    pub fn from_env() -> Result<Option<Self>> {
242        if let Ok(addr) = std::env::var(ENV_TRACE_SERVER) {
243            let addr: SocketAddr = addr.parse()?;
244            let session_id: Uuid = match std::env::var(ENV_TRACE_SESSION_ID) {
245                Ok(id) => id.parse()?,
246                Err(_) => Uuid::now_v7(),
247            };
248            Ok(Some(Self::connect_quinn_insecure(addr, session_id)?))
249        } else {
250            Ok(None)
251        }
252    }
253
254    pub fn from_env_or_local() -> Result<Self> {
255        Ok(Self::from_env()?.unwrap_or_else(Self::local))
256    }
257
258    pub fn local() -> Self {
259        let (tx, rx) = tokio::sync::mpsc::channel(8);
260        LocalActor::spawn(rx);
261        let session_id = Uuid::now_v7();
262        Self {
263            client: irpc::Client::from(tx),
264            session_id,
265        }
266    }
267
268    pub fn new(client: irpc::Client<TraceProtocol>, session_id: Uuid) -> Self {
269        Self { client, session_id }
270    }
271
272    pub fn connect_quinn_insecure(remote: SocketAddr, session_id: Uuid) -> Result<Self> {
273        let addr_localhost = "127.0.0.1:0".parse().unwrap();
274        let endpoint = make_insecure_client_endpoint(addr_localhost)?;
275        Ok(Self::connect_quinn_endpoint(endpoint, remote, session_id))
276    }
277
278    pub fn connect_quinn_endpoint(
279        endpoint: quinn::Endpoint,
280        remote: SocketAddr,
281        session_id: Uuid,
282    ) -> Self {
283        let client = irpc::Client::quinn(endpoint, remote);
284        Self { client, session_id }
285    }
286
287    pub async fn connect_iroh(remote: iroh::EndpointId, session_id: Uuid) -> Result<Self> {
288        let endpoint = iroh::Endpoint::builder().bind().await?;
289        Self::connect_iroh_endpoint(endpoint, remote, session_id).await
290    }
291
292    pub async fn connect_iroh_endpoint(
293        endpoint: iroh::Endpoint,
294        remote: impl Into<iroh::EndpointAddr>,
295        session_id: Uuid,
296    ) -> Result<Self> {
297        let conn = endpoint.connect(remote.into(), ALPN).await?;
298        let conn = IrohRemoteConnection::new(conn);
299        let client = irpc::Client::boxed(conn);
300        Ok(Self { client, session_id })
301    }
302
303    pub async fn init_and_start_trace(&self, name: &str) -> Result<ActiveTrace> {
304        let trace_info = TraceInfo::new(name, 1);
305        let trace_id = self.init_trace(trace_info, None).await?;
306        let endpoint_info = EndpointInfo::new_empty(0);
307        let client = self.start_node(trace_id, endpoint_info).await?;
308        Ok(client)
309    }
310
311    pub async fn init_trace(&self, info: TraceInfo, setup_data: Option<Bytes>) -> Result<Uuid> {
312        debug!("init trace {info:?}");
313        let trace_id = self
314            .client
315            .rpc(InitTrace {
316                info,
317                session_id: self.session_id,
318                start_time: DateTime::now_utc(),
319                setup_data,
320            })
321            .await??;
322        debug!("init trace {trace_id}: OK");
323        Ok(trace_id)
324    }
325
326    pub async fn get_trace(&self, name: String) -> Result<GetTraceResponse> {
327        debug!("get trace {name}");
328        let data = self
329            .client
330            .rpc(GetTrace {
331                session_id: self.session_id,
332                name,
333            })
334            .await??;
335        debug!(?data, "get trace: OK");
336        Ok(data)
337    }
338
339    pub async fn close_trace(&self, trace_id: Uuid, result: Result<(), String>) -> Result<()> {
340        let end_time = DateTime::now_utc();
341        debug!(%trace_id, ?result, "close trace");
342        self.client
343            .rpc(CloseTrace {
344                trace_id,
345                end_time,
346                result,
347            })
348            .await??;
349        debug!(%trace_id, "close trace: OK");
350        Ok(())
351    }
352
353    pub async fn get_session(&self, session_id: Uuid) -> Result<Option<GetSessionResponse>> {
354        let res = self.client.rpc(GetSession { session_id }).await??;
355        Ok(res)
356    }
357
358    pub async fn put_logs(
359        &self,
360        trace_id: Uuid,
361        scope: Scope,
362        json_lines: Vec<String>,
363    ) -> Result<()> {
364        self.client
365            .rpc(PutLogs {
366                scope,
367                trace_id,
368                json_lines,
369            })
370            .await??;
371        Ok(())
372    }
373
374    pub async fn start_node(
375        &self,
376        trace_id: Uuid,
377        endpoint_info: EndpointInfo,
378    ) -> Result<ActiveTrace> {
379        let start_time = DateTime::now_utc();
380        let idx = endpoint_info.idx;
381        debug!(%trace_id, idx, "start node");
382        let res = self
383            .client
384            .rpc(StartNode {
385                trace_id,
386                node_info: endpoint_info,
387                start_time,
388            })
389            .await??;
390        let StartNodeResponse {} = res;
391        debug!(%trace_id, idx, "start node: OK");
392        Ok(ActiveTrace {
393            client: self.client.clone(),
394            trace_id,
395            idx,
396        })
397    }
398}
399
400#[derive(Debug, Clone)]
401pub struct ActiveTrace {
402    client: irpc::Client<TraceProtocol>,
403    trace_id: Uuid,
404    idx: u32,
405}
406
407impl ActiveTrace {
408    pub async fn put_checkpoint(
409        &self,
410        id: CheckpointId,
411        label: Option<String>,
412        result: Result<(), String>,
413    ) -> Result<()> {
414        debug!(id, "put checkpoint");
415        let time = DateTime::now_utc();
416        self.client
417            .rpc(PutCheckpoint {
418                trace_id: self.trace_id,
419                checkpoint_id: id,
420                idx: self.idx,
421                time,
422                label,
423                result,
424            })
425            .await??;
426        Ok(())
427    }
428
429    pub async fn put_metrics(
430        &self,
431        id: EndpointId,
432        checkpoint_id: Option<CheckpointId>,
433        metrics: Update,
434    ) -> Result<()> {
435        let time = DateTime::now_utc();
436        debug!(count = metrics.values.items.len(), "put metrics");
437        self.client
438            .rpc(PutMetrics {
439                trace_id: self.trace_id,
440                id,
441                checkpoint_id,
442                time,
443                metrics,
444            })
445            .await??;
446        Ok(())
447    }
448
449    pub async fn put_logs(&self, json_lines: Vec<String>) -> Result<()> {
450        self.client
451            .rpc(PutLogs {
452                scope: Scope::Isolated(self.idx),
453                trace_id: self.trace_id,
454                json_lines,
455            })
456            .await??;
457        Ok(())
458    }
459
460    pub async fn wait_start(
461        &self,
462        info: EndpointInfoWithAddr,
463    ) -> Result<Vec<EndpointInfoWithAddr>> {
464        debug!("waiting for start...");
465        let res = self
466            .client
467            .rpc(WaitStart {
468                info,
469                trace_id: self.trace_id,
470            })
471            .await??;
472        debug!("start confirmed");
473        Ok(res.infos)
474    }
475
476    pub async fn wait_checkpoint(&self, checkpoint_id: CheckpointId) -> Result<()> {
477        debug!(?checkpoint_id, "waiting for checkpoint...");
478        let res = self
479            .client
480            .rpc(WaitCheckpoint {
481                checkpoint_id,
482                trace_id: self.trace_id,
483            })
484            .await;
485        res??;
486        debug!(?checkpoint_id, "checkpoint confirmed");
487        Ok(())
488    }
489
490    pub async fn end(&self, result: Result<(), String>) -> Result<()> {
491        debug!("end node");
492        let end_time = DateTime::now_utc();
493        self.client
494            .rpc(EndNode {
495                trace_id: self.trace_id,
496                idx: self.idx,
497                end_time,
498                result,
499            })
500            .await??;
501        Ok(())
502    }
503}
504
505#[derive(Default)]
506struct LocalActor {
507    traces_by_name: BTreeMap<String, Uuid>,
508    traces: BTreeMap<Uuid, TraceState>,
509}
510
511struct TraceState {
512    init: InitTrace,
513    nodes: Vec<EndpointInfoWithAddr>,
514    barrier_start: Vec<oneshot::Sender<RemoteResult<WaitStartResponse>>>,
515    barrier_checkpoint: BTreeMap<CheckpointId, Vec<oneshot::Sender<RemoteResult<()>>>>,
516}
517
518impl TraceState {
519    fn node_count(&self) -> usize {
520        self.init.info.node_count as usize
521    }
522}
523impl LocalActor {
524    pub fn spawn(rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
525        let actor = Self::default();
526        tokio::task::spawn(actor.run(rx));
527    }
528
529    pub async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
530        while let Some(message) = rx.recv().await {
531            self.handle_message(message).await;
532        }
533    }
534    async fn handle_message(&mut self, message: TraceMessage) {
535        match message {
536            TraceMessage::GetSession(_msg) => {}
537            TraceMessage::InitTrace(msg) => {
538                let WithChannels { inner, tx, .. } = msg;
539                if self.traces_by_name.contains_key(&inner.info.name) {
540                    return send_err(tx, RemoteError::other("Trace already initialized")).await;
541                }
542                let uuid = Uuid::now_v7();
543                self.traces_by_name.insert(inner.info.name.clone(), uuid);
544                self.traces.insert(
545                    uuid,
546                    TraceState {
547                        init: inner,
548                        nodes: Default::default(),
549                        barrier_start: Default::default(),
550                        barrier_checkpoint: Default::default(),
551                    },
552                );
553                tx.send(Ok(uuid)).await.ok();
554            }
555            TraceMessage::GetTrace(msg) => {
556                let WithChannels { inner, tx, .. } = msg;
557                let GetTrace {
558                    session_id: _,
559                    name,
560                } = inner;
561                let Some((trace_id, info)) = self
562                    .traces_by_name
563                    .get(&name)
564                    .and_then(|trace_id| self.traces.get_key_value(trace_id))
565                else {
566                    return send_err(tx, RemoteError::other("Trace not initialized")).await;
567                };
568                tx.send(Ok(GetTraceResponse {
569                    trace_id: *trace_id,
570                    info: info.init.info.clone(),
571                    setup_data: info.init.setup_data.clone(),
572                }))
573                .await
574                .ok();
575            }
576            TraceMessage::StartNode(msg) => {
577                let WithChannels { inner, tx, .. } = msg;
578                if self.traces.contains_key(&inner.trace_id) {
579                    tx.send(Ok(StartNodeResponse {})).await.ok();
580                } else {
581                    send_err(tx, RemoteError::other("Trace not initialized")).await;
582                }
583            }
584            TraceMessage::EndNode(msg) => {
585                // noop
586                msg.tx.send(Ok(())).await.ok();
587            }
588            TraceMessage::CloseTrace(msg) => {
589                // noop
590                msg.tx.send(Ok(())).await.ok();
591            }
592            TraceMessage::PutCheckpoint(msg) => {
593                // noop
594                msg.tx.send(Ok(())).await.ok();
595            }
596            TraceMessage::PutLogs(msg) => {
597                // noop
598                msg.tx.send(Ok(())).await.ok();
599            }
600            TraceMessage::PutMetrics(msg) => {
601                // noop
602                msg.tx.send(Ok(())).await.ok();
603            }
604            TraceMessage::WaitCheckpoint(msg) => {
605                let WithChannels { inner, tx, .. } = msg;
606                let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
607                    return send_err(tx, RemoteError::TraceNotFound).await;
608                };
609                let node_count = trace.node_count();
610                let barrier = trace
611                    .barrier_checkpoint
612                    .entry(inner.checkpoint_id)
613                    .or_default();
614                barrier.push(tx);
615                debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, count=barrier.len(), total=node_count, "wait checkpoint");
616                if barrier.len() == node_count {
617                    debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, "release");
618                    barrier
619                        .drain(..)
620                        .map(|tx| tx.send(Ok(())))
621                        .into_unordered_stream()
622                        .count()
623                        .await;
624                }
625            }
626            TraceMessage::WaitStart(msg) => {
627                let WithChannels { inner, tx, .. } = msg;
628                let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
629                    return send_err(tx, RemoteError::TraceNotFound).await;
630                };
631                trace.nodes.push(inner.info);
632                trace.barrier_start.push(tx);
633                if trace.barrier_start.len() == trace.init.info.node_count as usize {
634                    let data = WaitStartResponse {
635                        infos: trace.nodes.clone(),
636                    };
637                    for tx in trace.barrier_start.drain(..) {
638                        tx.send(Ok(data.clone())).await.ok();
639                    }
640                }
641            }
642        }
643    }
644}
645
646async fn send_err<T>(tx: oneshot::Sender<RemoteResult<T>>, err: RemoteError) {
647    tx.send(Err(err)).await.ok();
648}