1use std::{collections::BTreeMap, net::SocketAddr};
2
3use anyhow::Result;
4use bytes::Bytes;
5use iroh::{EndpointAddr, EndpointId};
6use iroh_metrics::encoding::Update;
7use irpc::{WithChannels, channel::oneshot, rpc_requests, util::make_insecure_client_endpoint};
8use irpc_iroh::IrohRemoteConnection;
9use n0_future::{IterExt, StreamExt};
10use serde::{Deserialize, Serialize};
11use time::OffsetDateTime as DateTime;
12use tracing::debug;
13use uuid::Uuid;
14
15use super::{ENV_TRACE_SERVER, ENV_TRACE_SESSION_ID};
16
/// ALPN identifier passed to `iroh::Endpoint::connect` by
/// [`TraceClient::connect_iroh_endpoint`].
pub const ALPN: &[u8] = b"/iroh/n0des-sim/1";
18
19pub type RemoteResult<T> = Result<T, RemoteError>;
20
21#[derive(Serialize, Deserialize, thiserror::Error, Debug)]
22pub enum RemoteError {
23 #[error("{0}")]
24 Other(String),
25 #[error("Trace not found")]
26 TraceNotFound,
27}
28
29impl RemoteError {
30 pub fn other(s: impl ToString) -> Self {
31 Self::Other(s.to_string())
32 }
33}
34
35impl From<anyhow::Error> for RemoteError {
36 fn from(value: anyhow::Error) -> Self {
37 Self::other(value)
38 }
39}
40
41#[derive(Debug, Serialize, Deserialize)]
42#[rpc_requests(message = TraceMessage)]
43#[allow(clippy::large_enum_variant)]
44pub enum TraceProtocol {
45 #[rpc(tx=oneshot::Sender<RemoteResult<Option<GetSessionResponse>>>)]
46 GetSession(GetSession),
47 #[rpc(tx=oneshot::Sender<RemoteResult<Uuid>>)]
48 InitTrace(InitTrace),
49 #[rpc(tx=oneshot::Sender<RemoteResult<GetTraceResponse>>)]
50 GetTrace(GetTrace),
51 #[rpc(tx=oneshot::Sender<RemoteResult<StartNodeResponse>>)]
52 StartNode(StartNode),
53 #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
54 EndNode(EndNode),
55 #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
56 CloseTrace(CloseTrace),
57 #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
58 PutCheckpoint(PutCheckpoint),
59 #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
60 PutLogs(PutLogs),
61 #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
62 PutMetrics(PutMetrics),
63 #[rpc(tx=oneshot::Sender<RemoteResult<()>>)]
64 WaitCheckpoint(WaitCheckpoint),
65 #[rpc(tx=oneshot::Sender<RemoteResult<WaitStartResponse>>)]
66 WaitStart(WaitStart),
67}
68
69#[derive(Debug, Serialize, Deserialize, Clone)]
70pub struct GetSession {
71 pub session_id: Uuid,
72}
73
74#[derive(Debug, Serialize, Deserialize, Clone)]
75pub struct GetSessionResponse {
76 pub traces: Vec<TraceDetails>,
77}
78
79#[derive(Debug, Serialize, Deserialize, Clone)]
80pub struct TraceDetails {
81 pub trace_id: Uuid,
82 pub info: TraceInfo,
83 pub finished: bool,
84}
85
86#[derive(Debug, Serialize, Deserialize, Clone)]
87pub struct InitTrace {
88 pub session_id: Uuid,
89 pub info: TraceInfo,
90 #[serde(with = "time::serde::rfc3339")]
91 pub start_time: DateTime,
92 pub setup_data: Option<Bytes>,
93}
94
95#[derive(Debug, Serialize, Deserialize, Clone)]
96pub struct CloseTrace {
97 pub trace_id: Uuid,
98 #[serde(with = "time::serde::rfc3339")]
99 pub end_time: DateTime,
100 pub result: Result<(), String>,
101}
102
103#[derive(Debug, Serialize, Deserialize, Clone)]
104pub struct TraceInfo {
105 pub name: String,
106 pub node_count: u32,
107 pub expected_checkpoints: Option<u64>,
108}
109
110impl TraceInfo {
111 pub fn new(name: impl ToString, node_count: u32) -> Self {
112 Self {
113 name: name.to_string(),
114 node_count,
115 expected_checkpoints: None,
116 }
117 }
118}
119
120#[derive(Debug, Serialize, Deserialize, Clone)]
121pub struct GetTrace {
122 pub session_id: Uuid,
123 pub name: String,
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127pub struct GetTraceResponse {
128 pub trace_id: Uuid,
129 pub info: TraceInfo,
130 pub setup_data: Option<Bytes>,
131}
132
133#[derive(Debug, Serialize, Deserialize, Clone)]
134pub struct StartNode {
135 pub trace_id: Uuid,
136 pub node_info: EndpointInfo,
137 #[serde(with = "time::serde::rfc3339")]
138 pub start_time: DateTime,
139}
140
141#[derive(Debug, Serialize, Deserialize, Clone)]
142pub struct StartNodeResponse {
143 }
146
/// Index type that identifies a node within a trace.
pub type EndpointIdx = u32;
148
149#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
150pub enum Scope {
151 Integrated,
152 Isolated(EndpointIdx),
153}
154
155#[derive(Debug, Serialize, Deserialize, Clone)]
156pub struct EndNode {
157 pub trace_id: Uuid,
158 pub idx: EndpointIdx,
159 #[serde(with = "time::serde::rfc3339")]
160 pub end_time: DateTime,
161 pub result: Result<(), String>,
162}
163
164#[derive(Debug, Serialize, Deserialize, Clone)]
165pub struct EndpointInfo {
166 pub idx: EndpointIdx,
167 pub endpoint_id: Option<EndpointId>,
168 pub label: Option<String>,
169}
170
171impl EndpointInfo {
172 pub fn new_empty(idx: EndpointIdx) -> Self {
173 Self {
174 idx,
175 endpoint_id: None,
176 label: None,
177 }
178 }
179}
180
181#[derive(Debug, Serialize, Deserialize, Clone)]
182pub struct EndpointInfoWithAddr {
183 pub info: EndpointInfo,
184 pub addr: Option<EndpointAddr>,
185}
186
187#[derive(Debug, Serialize, Deserialize, Clone)]
188pub struct PutLogs {
189 pub trace_id: Uuid,
190 pub scope: Scope,
191 pub json_lines: Vec<String>,
192}
193
194#[derive(Debug, Serialize, Deserialize)]
195pub struct PutMetrics {
196 pub trace_id: Uuid,
197 pub id: EndpointId,
198 pub checkpoint_id: Option<CheckpointId>,
199 #[serde(with = "time::serde::rfc3339")]
200 pub time: DateTime,
201 pub metrics: iroh_metrics::encoding::Update,
202}
203
/// Identifier type of a checkpoint within a trace.
pub type CheckpointId = u64;
205
206#[derive(Debug, Serialize, Deserialize, Clone)]
207pub struct PutCheckpoint {
208 pub trace_id: Uuid,
209 pub checkpoint_id: CheckpointId,
210 pub idx: EndpointIdx,
211 pub label: Option<String>,
212 #[serde(with = "time::serde::rfc3339")]
213 pub time: DateTime,
214 pub result: Result<(), String>,
215}
216
217#[derive(Debug, Serialize, Deserialize, Clone)]
218pub struct WaitCheckpoint {
219 pub trace_id: Uuid,
220 pub checkpoint_id: CheckpointId,
221}
222
223#[derive(Debug, Serialize, Deserialize, Clone)]
224pub struct WaitStart {
225 pub trace_id: Uuid,
226 pub info: EndpointInfoWithAddr,
227}
228
229#[derive(Debug, Serialize, Deserialize, Clone)]
230pub struct WaitStartResponse {
231 pub infos: Vec<EndpointInfoWithAddr>,
232}
233
234#[derive(Debug, Clone)]
235pub struct TraceClient {
236 client: irpc::Client<TraceProtocol>,
237 session_id: Uuid,
238}
239
240impl TraceClient {
241 pub fn from_env() -> Result<Option<Self>> {
242 if let Ok(addr) = std::env::var(ENV_TRACE_SERVER) {
243 let addr: SocketAddr = addr.parse()?;
244 let session_id: Uuid = match std::env::var(ENV_TRACE_SESSION_ID) {
245 Ok(id) => id.parse()?,
246 Err(_) => Uuid::now_v7(),
247 };
248 Ok(Some(Self::connect_quinn_insecure(addr, session_id)?))
249 } else {
250 Ok(None)
251 }
252 }
253
254 pub fn from_env_or_local() -> Result<Self> {
255 Ok(Self::from_env()?.unwrap_or_else(Self::local))
256 }
257
258 pub fn local() -> Self {
259 let (tx, rx) = tokio::sync::mpsc::channel(8);
260 LocalActor::spawn(rx);
261 let session_id = Uuid::now_v7();
262 Self {
263 client: irpc::Client::from(tx),
264 session_id,
265 }
266 }
267
268 pub fn new(client: irpc::Client<TraceProtocol>, session_id: Uuid) -> Self {
269 Self { client, session_id }
270 }
271
272 pub fn connect_quinn_insecure(remote: SocketAddr, session_id: Uuid) -> Result<Self> {
273 let addr_localhost = "127.0.0.1:0".parse().unwrap();
274 let endpoint = make_insecure_client_endpoint(addr_localhost)?;
275 Ok(Self::connect_quinn_endpoint(endpoint, remote, session_id))
276 }
277
278 pub fn connect_quinn_endpoint(
279 endpoint: quinn::Endpoint,
280 remote: SocketAddr,
281 session_id: Uuid,
282 ) -> Self {
283 let client = irpc::Client::quinn(endpoint, remote);
284 Self { client, session_id }
285 }
286
287 pub async fn connect_iroh(remote: iroh::EndpointId, session_id: Uuid) -> Result<Self> {
288 let endpoint = iroh::Endpoint::builder().bind().await?;
289 Self::connect_iroh_endpoint(endpoint, remote, session_id).await
290 }
291
292 pub async fn connect_iroh_endpoint(
293 endpoint: iroh::Endpoint,
294 remote: impl Into<iroh::EndpointAddr>,
295 session_id: Uuid,
296 ) -> Result<Self> {
297 let conn = endpoint.connect(remote.into(), ALPN).await?;
298 let conn = IrohRemoteConnection::new(conn);
299 let client = irpc::Client::boxed(conn);
300 Ok(Self { client, session_id })
301 }
302
303 pub async fn init_and_start_trace(&self, name: &str) -> Result<ActiveTrace> {
304 let trace_info = TraceInfo::new(name, 1);
305 let trace_id = self.init_trace(trace_info, None).await?;
306 let endpoint_info = EndpointInfo::new_empty(0);
307 let client = self.start_node(trace_id, endpoint_info).await?;
308 Ok(client)
309 }
310
311 pub async fn init_trace(&self, info: TraceInfo, setup_data: Option<Bytes>) -> Result<Uuid> {
312 debug!("init trace {info:?}");
313 let trace_id = self
314 .client
315 .rpc(InitTrace {
316 info,
317 session_id: self.session_id,
318 start_time: DateTime::now_utc(),
319 setup_data,
320 })
321 .await??;
322 debug!("init trace {trace_id}: OK");
323 Ok(trace_id)
324 }
325
326 pub async fn get_trace(&self, name: String) -> Result<GetTraceResponse> {
327 debug!("get trace {name}");
328 let data = self
329 .client
330 .rpc(GetTrace {
331 session_id: self.session_id,
332 name,
333 })
334 .await??;
335 debug!(?data, "get trace: OK");
336 Ok(data)
337 }
338
339 pub async fn close_trace(&self, trace_id: Uuid, result: Result<(), String>) -> Result<()> {
340 let end_time = DateTime::now_utc();
341 debug!(%trace_id, ?result, "close trace");
342 self.client
343 .rpc(CloseTrace {
344 trace_id,
345 end_time,
346 result,
347 })
348 .await??;
349 debug!(%trace_id, "close trace: OK");
350 Ok(())
351 }
352
353 pub async fn get_session(&self, session_id: Uuid) -> Result<Option<GetSessionResponse>> {
354 let res = self.client.rpc(GetSession { session_id }).await??;
355 Ok(res)
356 }
357
358 pub async fn put_logs(
359 &self,
360 trace_id: Uuid,
361 scope: Scope,
362 json_lines: Vec<String>,
363 ) -> Result<()> {
364 self.client
365 .rpc(PutLogs {
366 scope,
367 trace_id,
368 json_lines,
369 })
370 .await??;
371 Ok(())
372 }
373
374 pub async fn start_node(
375 &self,
376 trace_id: Uuid,
377 endpoint_info: EndpointInfo,
378 ) -> Result<ActiveTrace> {
379 let start_time = DateTime::now_utc();
380 let idx = endpoint_info.idx;
381 debug!(%trace_id, idx, "start node");
382 let res = self
383 .client
384 .rpc(StartNode {
385 trace_id,
386 node_info: endpoint_info,
387 start_time,
388 })
389 .await??;
390 let StartNodeResponse {} = res;
391 debug!(%trace_id, idx, "start node: OK");
392 Ok(ActiveTrace {
393 client: self.client.clone(),
394 trace_id,
395 idx,
396 })
397 }
398}
399
400#[derive(Debug, Clone)]
401pub struct ActiveTrace {
402 client: irpc::Client<TraceProtocol>,
403 trace_id: Uuid,
404 idx: u32,
405}
406
407impl ActiveTrace {
408 pub async fn put_checkpoint(
409 &self,
410 id: CheckpointId,
411 label: Option<String>,
412 result: Result<(), String>,
413 ) -> Result<()> {
414 debug!(id, "put checkpoint");
415 let time = DateTime::now_utc();
416 self.client
417 .rpc(PutCheckpoint {
418 trace_id: self.trace_id,
419 checkpoint_id: id,
420 idx: self.idx,
421 time,
422 label,
423 result,
424 })
425 .await??;
426 Ok(())
427 }
428
429 pub async fn put_metrics(
430 &self,
431 id: EndpointId,
432 checkpoint_id: Option<CheckpointId>,
433 metrics: Update,
434 ) -> Result<()> {
435 let time = DateTime::now_utc();
436 debug!(count = metrics.values.items.len(), "put metrics");
437 self.client
438 .rpc(PutMetrics {
439 trace_id: self.trace_id,
440 id,
441 checkpoint_id,
442 time,
443 metrics,
444 })
445 .await??;
446 Ok(())
447 }
448
449 pub async fn put_logs(&self, json_lines: Vec<String>) -> Result<()> {
450 self.client
451 .rpc(PutLogs {
452 scope: Scope::Isolated(self.idx),
453 trace_id: self.trace_id,
454 json_lines,
455 })
456 .await??;
457 Ok(())
458 }
459
460 pub async fn wait_start(
461 &self,
462 info: EndpointInfoWithAddr,
463 ) -> Result<Vec<EndpointInfoWithAddr>> {
464 debug!("waiting for start...");
465 let res = self
466 .client
467 .rpc(WaitStart {
468 info,
469 trace_id: self.trace_id,
470 })
471 .await??;
472 debug!("start confirmed");
473 Ok(res.infos)
474 }
475
476 pub async fn wait_checkpoint(&self, checkpoint_id: CheckpointId) -> Result<()> {
477 debug!(?checkpoint_id, "waiting for checkpoint...");
478 let res = self
479 .client
480 .rpc(WaitCheckpoint {
481 checkpoint_id,
482 trace_id: self.trace_id,
483 })
484 .await;
485 res??;
486 debug!(?checkpoint_id, "checkpoint confirmed");
487 Ok(())
488 }
489
490 pub async fn end(&self, result: Result<(), String>) -> Result<()> {
491 debug!("end node");
492 let end_time = DateTime::now_utc();
493 self.client
494 .rpc(EndNode {
495 trace_id: self.trace_id,
496 idx: self.idx,
497 end_time,
498 result,
499 })
500 .await??;
501 Ok(())
502 }
503}
504
505#[derive(Default)]
506struct LocalActor {
507 traces_by_name: BTreeMap<String, Uuid>,
508 traces: BTreeMap<Uuid, TraceState>,
509}
510
511struct TraceState {
512 init: InitTrace,
513 nodes: Vec<EndpointInfoWithAddr>,
514 barrier_start: Vec<oneshot::Sender<RemoteResult<WaitStartResponse>>>,
515 barrier_checkpoint: BTreeMap<CheckpointId, Vec<oneshot::Sender<RemoteResult<()>>>>,
516}
517
518impl TraceState {
519 fn node_count(&self) -> usize {
520 self.init.info.node_count as usize
521 }
522}
523impl LocalActor {
524 pub fn spawn(rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
525 let actor = Self::default();
526 tokio::task::spawn(actor.run(rx));
527 }
528
529 pub async fn run(mut self, mut rx: tokio::sync::mpsc::Receiver<TraceMessage>) {
530 while let Some(message) = rx.recv().await {
531 self.handle_message(message).await;
532 }
533 }
534 async fn handle_message(&mut self, message: TraceMessage) {
535 match message {
536 TraceMessage::GetSession(_msg) => {}
537 TraceMessage::InitTrace(msg) => {
538 let WithChannels { inner, tx, .. } = msg;
539 if self.traces_by_name.contains_key(&inner.info.name) {
540 return send_err(tx, RemoteError::other("Trace already initialized")).await;
541 }
542 let uuid = Uuid::now_v7();
543 self.traces_by_name.insert(inner.info.name.clone(), uuid);
544 self.traces.insert(
545 uuid,
546 TraceState {
547 init: inner,
548 nodes: Default::default(),
549 barrier_start: Default::default(),
550 barrier_checkpoint: Default::default(),
551 },
552 );
553 tx.send(Ok(uuid)).await.ok();
554 }
555 TraceMessage::GetTrace(msg) => {
556 let WithChannels { inner, tx, .. } = msg;
557 let GetTrace {
558 session_id: _,
559 name,
560 } = inner;
561 let Some((trace_id, info)) = self
562 .traces_by_name
563 .get(&name)
564 .and_then(|trace_id| self.traces.get_key_value(trace_id))
565 else {
566 return send_err(tx, RemoteError::other("Trace not initialized")).await;
567 };
568 tx.send(Ok(GetTraceResponse {
569 trace_id: *trace_id,
570 info: info.init.info.clone(),
571 setup_data: info.init.setup_data.clone(),
572 }))
573 .await
574 .ok();
575 }
576 TraceMessage::StartNode(msg) => {
577 let WithChannels { inner, tx, .. } = msg;
578 if self.traces.contains_key(&inner.trace_id) {
579 tx.send(Ok(StartNodeResponse {})).await.ok();
580 } else {
581 send_err(tx, RemoteError::other("Trace not initialized")).await;
582 }
583 }
584 TraceMessage::EndNode(msg) => {
585 msg.tx.send(Ok(())).await.ok();
587 }
588 TraceMessage::CloseTrace(msg) => {
589 msg.tx.send(Ok(())).await.ok();
591 }
592 TraceMessage::PutCheckpoint(msg) => {
593 msg.tx.send(Ok(())).await.ok();
595 }
596 TraceMessage::PutLogs(msg) => {
597 msg.tx.send(Ok(())).await.ok();
599 }
600 TraceMessage::PutMetrics(msg) => {
601 msg.tx.send(Ok(())).await.ok();
603 }
604 TraceMessage::WaitCheckpoint(msg) => {
605 let WithChannels { inner, tx, .. } = msg;
606 let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
607 return send_err(tx, RemoteError::TraceNotFound).await;
608 };
609 let node_count = trace.node_count();
610 let barrier = trace
611 .barrier_checkpoint
612 .entry(inner.checkpoint_id)
613 .or_default();
614 barrier.push(tx);
615 debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, count=barrier.len(), total=node_count, "wait checkpoint");
616 if barrier.len() == node_count {
617 debug!(trace_id=%inner.trace_id, checkpoint=inner.checkpoint_id, "release");
618 barrier
619 .drain(..)
620 .map(|tx| tx.send(Ok(())))
621 .into_unordered_stream()
622 .count()
623 .await;
624 }
625 }
626 TraceMessage::WaitStart(msg) => {
627 let WithChannels { inner, tx, .. } = msg;
628 let Some(trace) = self.traces.get_mut(&inner.trace_id) else {
629 return send_err(tx, RemoteError::TraceNotFound).await;
630 };
631 trace.nodes.push(inner.info);
632 trace.barrier_start.push(tx);
633 if trace.barrier_start.len() == trace.init.info.node_count as usize {
634 let data = WaitStartResponse {
635 infos: trace.nodes.clone(),
636 };
637 for tx in trace.barrier_start.drain(..) {
638 tx.send(Ok(data.clone())).await.ok();
639 }
640 }
641 }
642 }
643 }
644}
645
646async fn send_err<T>(tx: oneshot::Sender<RemoteResult<T>>, err: RemoteError) {
647 tx.send(Err(err)).await.ok();
648}