1use std::{collections::BTreeMap, net::SocketAddr};
2
3use anyhow::Result;
4use bytes::Bytes;
5use iroh::{NodeAddr, NodeId};
6use iroh_metrics::encoding::Update;
7use irpc::{WithChannels, channel::oneshot, rpc_requests, util::make_insecure_client_endpoint};
8use irpc_iroh::IrohRemoteConnection;
9use n0_future::{IterExt, StreamExt};
10use serde::{Deserialize, Serialize};
11use time::OffsetDateTime as DateTime;
12use tracing::debug;
13use uuid::Uuid;
14
15use super::{ENV_TRACE_SERVER, ENV_TRACE_SESSION_ID};
16
/// ALPN protocol identifier for the trace RPC service.
///
/// [`TraceClient::connect_iroh_endpoint`] hands these bytes to
/// `IrohRemoteConnection::new` when setting up a connection over iroh.
pub const ALPN: &[u8] = b"/iroh/n0des-sim/1";
18
/// Result alias whose error type is [`RemoteError`].
///
/// Every reply channel declared in [`TraceProtocol`] carries a `RemoteResult`.
pub type RemoteResult<T> = Result<T, RemoteError>;
20
/// Serializable error carried in RPC replies (see [`RemoteResult`]).
#[derive(Serialize, Deserialize, thiserror::Error, Debug)]
pub enum RemoteError {
    /// Free-form error; the contained message is displayed verbatim.
    #[error("{0}")]
    Other(String),
    /// The request referenced a trace id that is not known.
    ///
    /// The in-process actor in this file replies with this for `WaitCheckpoint`
    /// and `WaitStart` requests naming an unknown trace.
    #[error("Trace not found")]
    TraceNotFound,
}
28
29impl RemoteError {
30 pub fn other(s: impl ToString) -> Self {
31 Self::Other(s.to_string())
32 }
33}
34
35impl From<anyhow::Error> for RemoteError {
36 fn from(value: anyhow::Error) -> Self {
37 Self::other(value)
38 }
39}
40
/// Request set of the trace RPC service.
///
/// Each variant wraps its request struct; the `#[rpc(tx = ...)]` attribute
/// names the oneshot sender type on which the reply is delivered. The
/// `rpc_requests` attribute is configured with `message = TraceMessage`, which
/// is the message type consumed by the in-process `LocalActor` in this file.
#[derive(Debug, Serialize, Deserialize)]
#[rpc_requests(message = TraceMessage)]
// NOTE(review): presumably required because some request payloads are much
// larger than others — confirm before removing.
#[allow(clippy::large_enum_variant)]
pub enum TraceProtocol {
    // Look up a session; the reply may be `None`.
    #[rpc(tx=oneshot::Sender<RemoteResult<Option<GetSessionResponse>>>)]
    GetSession(GetSession),
    // Register a new trace; the reply is the id assigned to it.
    #[rpc(tx=oneshot::Sender<RemoteResult<Uuid>>)]
    InitTrace(InitTrace),
    // Look up a trace by session id and name.
    #[rpc(tx=oneshot::Sender<RemoteResult<GetTraceResponse>>)]
    GetTrace(GetTrace),
    // Announce that a node of a trace has started.
    #[rpc(tx=oneshot::Sender<RemoteResult<StartNodeResponse>>)]
    StartNode(StartNode),
    // Announce that a node of a trace has ended, with its outcome.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    EndNode(EndNode),
    // Close a trace, with its overall outcome.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    CloseTrace(CloseTrace),
    // Record a checkpoint reached by a node.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    PutCheckpoint(PutCheckpoint),
    // Submit log lines for a trace.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    PutLogs(PutLogs),
    // Submit a metrics update for a node.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    PutMetrics(PutMetrics),
    // Wait on a checkpoint; the reply arrives when the server releases it.
    #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
    WaitCheckpoint(WaitCheckpoint),
    // Wait for the start of a trace; the reply lists node infos.
    #[rpc(tx=oneshot::Sender<RemoteResult<WaitStartResponse>>)]
    WaitStart(WaitStart),
}
68
/// Request payload of [`TraceProtocol::GetSession`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetSession {
    /// Id of the session to look up.
    pub session_id: Uuid,
}
73
/// Reply payload of [`TraceProtocol::GetSession`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetSessionResponse {
    /// Details of the traces reported for the session.
    pub traces: Vec<TraceDetails>,
}
78
/// Per-trace entry of a [`GetSessionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TraceDetails {
    /// Id of the trace.
    pub trace_id: Uuid,
    /// Static description of the trace.
    pub info: TraceInfo,
    /// Whether the trace is finished.
    // NOTE(review): the exact meaning of "finished" is decided by the server
    // and is not visible in this file.
    pub finished: bool,
}
85
/// Request payload of [`TraceProtocol::InitTrace`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct InitTrace {
    /// Session the new trace belongs to.
    pub session_id: Uuid,
    /// Static description of the trace.
    pub info: TraceInfo,
    /// Start time of the trace; serialized as an RFC 3339 timestamp.
    #[serde(with = "time::serde::rfc3339")]
    pub start_time: DateTime,
    /// Optional opaque bytes stored with the trace; the in-process actor
    /// returns them unchanged in [`GetTraceResponse::setup_data`].
    pub setup_data: Option<Bytes>,
}
94
/// Request payload of [`TraceProtocol::CloseTrace`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CloseTrace {
    /// Id of the trace to close.
    pub trace_id: Uuid,
    /// End time of the trace; serialized as an RFC 3339 timestamp.
    #[serde(with = "time::serde::rfc3339")]
    pub end_time: DateTime,
    /// Outcome of the trace; `Err` carries an error message.
    pub result: Result<(), String>,
}
102
/// Static description of a trace.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TraceInfo {
    /// Name of the trace; the in-process actor indexes traces by this name.
    pub name: String,
    /// Number of nodes in the trace; the in-process actor uses it as the
    /// number of waiters required to release a start or checkpoint barrier.
    pub node_count: u32,
    /// Optional number of expected checkpoints ([`TraceInfo::new`] sets `None`).
    // NOTE(review): how the server uses this value is not visible in this file.
    pub expected_checkpoints: Option<u64>,
}
109
110impl TraceInfo {
111 pub fn new(name: impl ToString, node_count: u32) -> Self {
112 Self {
113 name: name.to_string(),
114 node_count,
115 expected_checkpoints: None,
116 }
117 }
118}
119
/// Request payload of [`TraceProtocol::GetTrace`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetTrace {
    /// Session the trace belongs to (ignored by the in-process actor).
    pub session_id: Uuid,
    /// Name of the trace to look up.
    pub name: String,
}
125
/// Reply payload of [`TraceProtocol::GetTrace`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GetTraceResponse {
    /// Id of the trace that was found.
    pub trace_id: Uuid,
    /// Static description of the trace.
    pub info: TraceInfo,
    /// Setup bytes stored with the trace, if any.
    pub setup_data: Option<Bytes>,
}
132
/// Request payload of [`TraceProtocol::StartNode`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StartNode {
    /// Trace the node belongs to.
    pub trace_id: Uuid,
    /// Description of the starting node.
    pub node_info: NodeInfo,
    /// Start time of the node; serialized as an RFC 3339 timestamp.
    #[serde(with = "time::serde::rfc3339")]
    pub start_time: DateTime,
}
140
/// Reply payload of [`TraceProtocol::StartNode`]; currently has no fields.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StartNodeResponse {
}
146
/// Index identifying a node within a trace (see [`NodeInfo::idx`]).
pub type NodeIdx = u32;
148
/// Scope attached to a [`PutLogs`] request.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum Scope {
    /// Not bound to a single node index.
    // NOTE(review): presumably logs captured for the whole run rather than
    // per node — confirm with the producer of these logs.
    Integrated,
    /// Bound to the node with the given index; `ActiveTrace::put_logs` always
    /// uses this variant with its own node index.
    Isolated(NodeIdx),
}
154
/// Request payload of [`TraceProtocol::EndNode`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EndNode {
    /// Trace the node belongs to.
    pub trace_id: Uuid,
    /// Index of the node that ended.
    pub node_idx: NodeIdx,
    /// End time of the node; serialized as an RFC 3339 timestamp.
    #[serde(with = "time::serde::rfc3339")]
    pub end_time: DateTime,
    /// Outcome of the node; `Err` carries an error message.
    pub result: Result<(), String>,
}
163
/// Description of a single node of a trace.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct NodeInfo {
    /// Index of the node within its trace.
    pub idx: NodeIdx,
    /// iroh node id, if one is known.
    pub node_id: Option<NodeId>,
    /// Optional human-readable label.
    pub label: Option<String>,
}
170
171impl NodeInfo {
172 pub fn new_empty(idx: NodeIdx) -> Self {
173 Self {
174 idx,
175 node_id: None,
176 label: None,
177 }
178 }
179}
180
/// A [`NodeInfo`] paired with an optional iroh address.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct NodeInfoWithAddr {
    /// Description of the node.
    pub info: NodeInfo,
    /// Address of the node, if one is known.
    pub addr: Option<NodeAddr>,
}
186
/// Request payload of [`TraceProtocol::PutLogs`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PutLogs {
    /// Trace the logs belong to.
    pub trace_id: Uuid,
    /// Scope the logs are attributed to.
    pub scope: Scope,
    /// Log lines to store.
    // NOTE(review): presumably one JSON document per entry — confirm the
    // expected format with the server.
    pub json_lines: Vec<String>,
}
193
/// Request payload of [`TraceProtocol::PutMetrics`].
///
/// Unlike the other request structs in this file it does not derive `Clone`.
#[derive(Debug, Serialize, Deserialize)]
pub struct PutMetrics {
    /// Trace the metrics belong to.
    pub trace_id: Uuid,
    /// iroh node id the metrics are attributed to.
    pub node_id: NodeId,
    /// Checkpoint the metrics are associated with, if any.
    pub checkpoint_id: Option<CheckpointId>,
    /// Time of the update; serialized as an RFC 3339 timestamp.
    #[serde(with = "time::serde::rfc3339")]
    pub time: DateTime,
    /// The encoded metrics update.
    pub metrics: iroh_metrics::encoding::Update,
}
203
/// Identifier of a checkpoint within a trace.
pub type CheckpointId = u64;
205
/// Request payload of [`TraceProtocol::PutCheckpoint`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PutCheckpoint {
    /// Trace the checkpoint belongs to.
    pub trace_id: Uuid,
    /// Id of the checkpoint being reported.
    pub checkpoint_id: CheckpointId,
    /// Index of the reporting node.
    pub node_idx: NodeIdx,
    /// Optional human-readable label for the checkpoint.
    pub label: Option<String>,
    /// Time of the report; serialized as an RFC 3339 timestamp.
    #[serde(with = "time::serde::rfc3339")]
    pub time: DateTime,
    /// Outcome at the checkpoint; `Err` carries an error message.
    pub result: Result<(), String>,
}
216
/// Request payload of [`TraceProtocol::WaitCheckpoint`].
///
/// The in-process actor answers all pending requests for a checkpoint once
/// as many requests as the trace's `node_count` have arrived for it.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WaitCheckpoint {
    /// Trace the checkpoint belongs to.
    pub trace_id: Uuid,
    /// Checkpoint to wait on.
    pub checkpoint_id: CheckpointId,
}
222
/// Request payload of [`TraceProtocol::WaitStart`].
///
/// The in-process actor collects the submitted infos and answers all pending
/// requests once as many requests as the trace's `node_count` have arrived.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WaitStart {
    /// Trace to wait on.
    pub trace_id: Uuid,
    /// Info of the waiting node; included in the eventual reply.
    pub info: NodeInfoWithAddr,
}
228
/// Reply payload of [`TraceProtocol::WaitStart`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct WaitStartResponse {
    /// Infos of the nodes collected for the trace.
    pub infos: Vec<NodeInfoWithAddr>,
}
233
/// Client for the trace RPC service, bound to one session id.
#[derive(Debug, Clone)]
pub struct TraceClient {
    /// Underlying RPC client speaking [`TraceProtocol`].
    client: irpc::Client<TraceProtocol>,
    /// Session id sent with `InitTrace` and `GetTrace` requests.
    session_id: Uuid,
}
239
240impl TraceClient {
241 pub fn from_env() -> Result<Option<Self>> {
242 if let Ok(addr) = std::env::var(ENV_TRACE_SERVER) {
243 let addr: SocketAddr = addr.parse()?;
244 let session_id: Uuid = match std::env::var(ENV_TRACE_SESSION_ID) {
245 Ok(id) => id.parse()?,
246 Err(_) => Uuid::now_v7(),
247 };
248 Ok(Some(Self::connect_quinn_insecure(addr, session_id)?))
249 } else {
250 Ok(None)
251 }
252 }
253
254 pub fn from_env_or_local() -> Result<Self> {
255 Ok(Self::from_env()?.unwrap_or_else(Self::local))
256 }
257
258 pub fn local() -> Self {
259 let (tx, rx) = tokio::sync::mpsc::channel(8);
260 LocalActor::spawn(rx);
261 let session_id = Uuid::now_v7();
262 Self {
263 client: irpc::Client::from(tx),
264 session_id,
265 }
266 }
267
268 pub fn new(client: irpc::Client<TraceProtocol>, session_id: Uuid) -> Self {
269 Self { client, session_id }
270 }
271
272 pub fn connect_quinn_insecure(remote: SocketAddr, session_id: Uuid) -> Result<Self> {
273 let addr_localhost = "127.0.0.1:0".parse().unwrap();
274 let endpoint = make_insecure_client_endpoint(addr_localhost)?;
275 Ok(Self::connect_quinn_endpoint(endpoint, remote, session_id))
276 }
277
278 pub fn connect_quinn_endpoint(
279 endpoint: quinn::Endpoint,
280 remote: SocketAddr,
281 session_id: Uuid,
282 ) -> Self {
283 let client = irpc::Client::quinn(endpoint, remote);
284 Self { client, session_id }
285 }
286
287 pub async fn connect_iroh(remote: iroh::NodeId, session_id: Uuid) -> Result<Self> {
288 let endpoint = iroh::Endpoint::builder().bind().await?;
289 Ok(Self::connect_iroh_endpoint(endpoint, remote, session_id))
290 }
291
292 pub fn connect_iroh_endpoint(
293 endpoint: iroh::Endpoint,
294 remote: impl Into<iroh::NodeAddr>,
295 session_id: Uuid,
296 ) -> Self {
297 let conn = IrohRemoteConnection::new(endpoint, remote.into(), ALPN.to_vec());
298 let client = irpc::Client::boxed(conn);
299 Self { client, session_id }
300 }
301
302 pub async fn init_and_start_trace(&self, name: &str) -> Result<ActiveTrace> {
303 let trace_info = TraceInfo::new(name, 1);
304 let trace_id = self.init_trace(trace_info, None).await?;
305 let node_info = NodeInfo::new_empty(0);
306 let client = self.start_node(trace_id, node_info).await?;
307 Ok(client)
308 }
309
310 pub async fn init_trace(&self, info: TraceInfo, setup_data: Option<Bytes>) -> Result<Uuid> {
311 debug!("init trace {info:?}");
312 let trace_id = self
313 .client
314 .rpc(InitTrace {
315 info,
316 session_id: self.session_id,
317 start_time: DateTime::now_utc(),
318 setup_data,
319 })
320 .await??;
321 debug!("init trace {trace_id}: OK");
322 Ok(trace_id)
323 }
324
325 pub async fn get_trace(&self, name: String) -> Result<GetTraceResponse> {
326 debug!("get trace {name}");
327 let data = self
328 .client
329 .rpc(GetTrace {
330 session_id: self.session_id,
331 name,
332 })
333 .await??;
334 debug!(?data, "get trace: OK");
335 Ok(data)
336 }
337
338 pub async fn close_trace(&self, trace_id: Uuid, result: Result<(), String>) -> Result<()> {
339 let end_time = DateTime::now_utc();
340 debug!(%trace_id, ?result, "close trace");
341 self.client
342 .rpc(CloseTrace {
343 trace_id,
344 end_time,
345 result,
346 })
347 .await??;
348 debug!(%trace_id, "close trace: OK");
349 Ok(())
350 }
351
352 pub async fn get_session(&self, session_id: Uuid) -> Result<Option<GetSessionResponse>> {
353 let res = self.client.rpc(GetSession { session_id }).await??;
354 Ok(res)
355 }
356
357 pub async fn put_logs(
358 &self,
359 trace_id: Uuid,
360 scope: Scope,
361 json_lines: Vec<String>,
362 ) -> Result<()> {
363 self.client
364 .rpc(PutLogs {
365 scope,
366 trace_id,
367 json_lines,
368 })
369 .await??;
370 Ok(())
371 }
372
373 pub async fn start_node(&self, trace_id: Uuid, node_info: NodeInfo) -> Result<ActiveTrace> {
374 let start_time = DateTime::now_utc();
375 let node_idx = node_info.idx;
376 debug!(%trace_id, node_idx, "start node");
377 let res = self
378 .client
379 .rpc(StartNode {
380 trace_id,
381 node_info,
382 start_time,
383 })
384 .await??;
385 let StartNodeResponse {} = res;
386 debug!(%trace_id, node_idx, "start node: OK");
387 Ok(ActiveTrace {
388 client: self.client.clone(),
389 trace_id,
390 node_idx,
391 })
392 }
393}
394
/// Handle for one started node of a trace, returned by
/// `TraceClient::start_node`.
#[derive(Debug, Clone)]
pub struct ActiveTrace {
    /// RPC client cloned from the [`TraceClient`] that created this handle.
    client: irpc::Client<TraceProtocol>,
    /// Trace this handle reports to.
    trace_id: Uuid,
    /// Index of the node this handle reports for.
    node_idx: u32,
}
401
402impl ActiveTrace {
403 pub async fn put_checkpoint(
404 &self,
405 id: CheckpointId,
406 label: Option<String>,
407 result: Result<(), String>,
408 ) -> Result<()> {
409 debug!(id, "put checkpoint");
410 let time = DateTime::now_utc();
411 self.client
412 .rpc(PutCheckpoint {
413 trace_id: self.trace_id,
414 checkpoint_id: id,
415 node_idx: self.node_idx,
416 time,
417 label,
418 result,
419 })
420 .await??;
421 Ok(())
422 }
423
424 pub async fn put_metrics(
425 &self,
426 node_id: NodeId,
427 checkpoint_id: Option<CheckpointId>,
428 metrics: Update,
429 ) -> Result<()> {
430 let time = DateTime::now_utc();
431 debug!(count = metrics.values.items.len(), "put metrics");
432 self.client
433 .rpc(PutMetrics {
434 trace_id: self.trace_id,
435 node_id,
436 checkpoint_id,
437 time,
438 metrics,
439 })
440 .await??;
441 Ok(())
442 }
443
444 pub async fn put_logs(&self, json_lines: Vec<String>) -> Result<()> {
445 self.client
446 .rpc(PutLogs {
447 scope: Scope::Isolated(self.node_idx),
448 trace_id: self.trace_id,
449 json_lines,
450 })
451 .await??;
452 Ok(())
453 }
454
455 pub async fn wait_start(&self, info: NodeInfoWithAddr) -> Result<Vec<NodeInfoWithAddr>> {
456 debug!("waiting for start...");
457 let res = self
458 .client
459 .rpc(WaitStart {
460 info,
461 trace_id: self.trace_id,
462 })
463 .await??;
464 debug!("start confirmed");
465 Ok(res.infos)
466 }
467
468 pub async fn wait_checkpoint(&self, checkpoint_id: CheckpointId) -> Result<()> {
469 debug!(?checkpoint_id, "waiting for checkpoint...");
470 let res = self
471 .client
472 .rpc(WaitCheckpoint {
473 checkpoint_id,
474 trace_id: self.trace_id,
475 })
476 .await;
477 res??;
478 debug!(?checkpoint_id, "checkpoint confirmed");
479 Ok(())
480 }
481
482 pub async fn end(&self, result: Result<(), String>) -> Result<()> {
483 debug!("end node");
484 let end_time = DateTime::now_utc();
485 self.client
486 .rpc(EndNode {
487 trace_id: self.trace_id,
488 node_idx: self.node_idx,
489 end_time,
490 result,
491 })
492 .await??;
493 Ok(())
494 }
495}
496
/// In-process stand-in for the trace server, used by `TraceClient::local`.
///
/// Keeps trace state in memory only.
#[derive(Default)]
struct LocalActor {
    /// Trace id for each trace name seen in an `InitTrace` request.
    traces_by_name: BTreeMap<String, Uuid>,
    /// State of each initialized trace, keyed by trace id.
    traces: BTreeMap<Uuid, TraceState>,
}
502
/// In-memory state the [`LocalActor`] keeps for one trace.
struct TraceState {
    /// The request that initialized the trace.
    init: InitTrace,
    /// Node infos collected from `WaitStart` requests.
    nodes: Vec<NodeInfoWithAddr>,
    /// Reply channels of `WaitStart` requests that have not been answered yet.
    barrier_start: Vec<oneshot::Sender<RemoteResult<WaitStartResponse>>>,
    /// Reply channels of unanswered `WaitCheckpoint` requests, per checkpoint.
    barrier_checkpoint: BTreeMap<CheckpointId, Vec<oneshot::Sender<RemoteResult<()>>>>,
}
509
510impl TraceState {
511 fn node_count(&self) -> usize {
512 self.init.info.node_count as usize
513 }
514}
515impl LocalActor {
516 pub fn spawn(rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
517 let actor = Self::default();
518 tokio::task::spawn(actor.run(rx));
519 }
520
521 pub async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
522 while let Some(message) = rx.recv().await {
523 self.handle_message(message).await;
524 }
525 }
526 async fn handle_message(&mut self, message: TraceMessage) {
527 match message {
528 TraceMessage::GetSession(_msg) => {}
529 TraceMessage::InitTrace(msg) => {
530 let WithChannels { inner, tx, .. } = msg;
531 if self.traces_by_name.contains_key(&inner.info.name) {
532 return send_err(tx, RemoteError::other("Trace already initialized")).await;
533 }
534 let uuid = Uuid::now_v7();
535 self.traces_by_name.insert(inner.info.name.clone(), uuid);
536 self.traces.insert(
537 uuid,
538 TraceState {
539 init: inner,
540 nodes: Default::default(),
541 barrier_start: Default::default(),
542 barrier_checkpoint: Default::default(),
543 },
544 );
545 tx.send(Ok(uuid)).await.ok();
546 }
547 TraceMessage::GetTrace(msg) => {
548 let WithChannels { inner, tx, .. } = msg;
549 let GetTrace {
550 session_id: _,
551 name,
552 } = inner;
553 let Some((trace_id, info)) = self
554 .traces_by_name
555 .get(&name)
556 .and_then(|trace_id| self.traces.get_key_value(trace_id))
557 else {
558 return send_err(tx, RemoteError::other("Trace not initialized")).await;
559 };
560 tx.send(Ok(GetTraceResponse {
561 trace_id: *trace_id,
562 info: info.init.info.clone(),
563 setup_data: info.init.setup_data.clone(),
564 }))
565 .await
566 .ok();
567 }
568 TraceMessage::StartNode(msg) => {
569 let WithChannels { inner, tx, .. } = msg;
570 if self.traces.contains_key(&inner.trace_id) {
571 tx.send(Ok(StartNodeResponse {})).await.ok();
572 } else {
573 send_err(tx, RemoteError::other("Trace not initialized")).await;
574 }
575 }
576 TraceMessage::EndNode(msg) => {
577 msg.tx.send(Ok(())).await.ok();
579 }
580 TraceMessage::CloseTrace(msg) => {
581 msg.tx.send(Ok(())).await.ok();
583 }
584 TraceMessage::PutCheckpoint(msg) => {
585 msg.tx.send(Ok(())).await.ok();
587 }
588 TraceMessage::PutLogs(msg) => {
589 msg.tx.send(Ok(())).await.ok();
591 }
592 TraceMessage::PutMetrics(msg) => {
593 msg.tx.send(Ok(())).await.ok();
595 }
596 TraceMessage::WaitCheckpoint(msg) => {
597 let WithChannels { inner, tx, .. } = msg;
598 let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
599 return send_err(tx, RemoteError::TraceNotFound).await;
600 };
601 let node_count = trace.node_count();
602 let barrier = trace
603 .barrier_checkpoint
604 .entry(inner.checkpoint_id)
605 .or_default();
606 barrier.push(tx);
607 debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, count=barrier.len(), total=node_count, "wait checkpoint");
608 if barrier.len() == node_count {
609 debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, "release");
610 barrier
611 .drain(..)
612 .map(|tx| tx.send(Ok(())))
613 .into_unordered_stream()
614 .count()
615 .await;
616 }
617 }
618 TraceMessage::WaitStart(msg) => {
619 let WithChannels { inner, tx, .. } = msg;
620 let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
621 return send_err(tx, RemoteError::TraceNotFound).await;
622 };
623 trace.nodes.push(inner.info);
624 trace.barrier_start.push(tx);
625 if trace.barrier_start.len() == trace.init.info.node_count as usize {
626 let data = WaitStartResponse {
627 infos: trace.nodes.clone(),
628 };
629 for tx in trace.barrier_start.drain(..) {
630 tx.send(Ok(data.clone())).await.ok();
631 }
632 }
633 }
634 }
635 }
636}
637
638async fn send_err<T>(tx: oneshot::Sender<RemoteResult<T>>, err: RemoteError) {
639 tx.send(Err(err)).await.ok();
640}