1use std::{
2 any::Any,
3 fmt::Debug,
4 marker::PhantomData,
5 pin::Pin,
6 sync::{Arc, RwLock},
7 time::Duration,
8};
9
10use anyhow::{Context, Result};
11use bytes::Bytes;
12use iroh::{Endpoint, NodeAddr, NodeId, SecretKey};
13use iroh_metrics::encoding::Encoder;
14use iroh_n0des::{
15 Registry,
16 simulation::proto::{ActiveTrace, NodeInfo, TraceClient, TraceInfo},
17};
18use n0_future::IterExt;
19use proto::{GetTraceResponse, NodeInfoWithAddr, Scope};
20use serde::{Serialize, de::DeserializeOwned};
21use tokio::sync::Semaphore;
22use tokio_util::sync::CancellationToken;
23use tracing::{Instrument, debug, error_span, info, warn};
24use uuid::Uuid;
25
26pub mod events;
27pub mod proto;
28pub mod trace;
29
/// Environment variable selecting isolated mode.
///
/// When set, `RunMode::from_env` parses its value as a number and runs only the
/// node with that index (unless [`ENV_TRACE_INIT_ONLY`] is also set, which takes
/// precedence).
pub const ENV_TRACE_ISOLATED: &str = "N0DES_TRACE_ISOLATED";
/// Environment variable selecting init-only mode.
///
/// `RunMode::from_env` only checks whether the variable can be read; its value is
/// not inspected.
pub const ENV_TRACE_INIT_ONLY: &str = "N0DES_TRACE_INIT_ONLY";
/// Name of the trace-server environment variable.
///
/// NOTE(review): not read in the visible part of this module; presumably consumed
/// by the trace client setup — confirm.
pub const ENV_TRACE_SERVER: &str = "N0DES_TRACE_SERVER";
/// Name of the session-id environment variable.
///
/// NOTE(review): not read in the visible part of this module; presumably consumed
/// by the trace client setup — confirm.
pub const ENV_TRACE_SESSION_ID: &str = "N0DES_SESSION_ID";
38
/// Boxed one-shot setup function; calling it yields a future resolving to the
/// setup data `D`.
type BoxedSetupFn<D> = Box<dyn 'static + Send + Sync + FnOnce() -> BoxFuture<'static, Result<D>>>;

/// Shared, type-erased spawn function: turns a [`SpawnContext`] into a boxed node.
type BoxedSpawnFn<D> = Arc<
    dyn 'static
        + Send
        + Sync
        + for<'a> Fn(&'a mut SpawnContext<'a, D>) -> BoxFuture<'a, Result<BoxNode>>,
>;
/// Shared, type-erased round function.
///
/// The returned `bool` tells the runner whether to continue: `false` ends the
/// node's round loop early without an error.
type BoxedRoundFn<D> = Arc<
    dyn 'static
        + Send
        + Sync
        + for<'a> Fn(&'a mut BoxNode, &'a RoundContext<'a, D>) -> BoxFuture<'a, Result<bool>>,
>;

/// Shared, type-erased synchronous check function, run after a successful round.
///
/// Unlike the other aliases, this one carries no `Send`/`Sync` bound.
type BoxedCheckFn<D> = Arc<dyn Fn(&BoxNode, &RoundContext<'_, D>) -> Result<()>>;
55
/// Helper trait naming the future type of an async callback.
///
/// A callback takes `&'a mut A1` and `&'a A2` and returns a `Send` future that
/// resolves to `T`. Used with `for<'a>` bounds so that round functions can borrow
/// their arguments for the duration of the returned future.
pub trait AsyncCallback<'a, A1: 'a, A2: 'a, T: 'a>:
    'static + Send + Sync + Fn(&'a mut A1, &'a A2) -> Self::Fut
{
    /// The future returned by the callback.
    type Fut: Future<Output = T> + Send;
}

/// Blanket implementation: any `'static + Send + Sync` function or closure with the
/// matching signature is an [`AsyncCallback`].
impl<'a, A1: 'a, A2: 'a, T: 'a, Out, F> AsyncCallback<'a, A1, A2, T> for F
where
    Out: Send + Future<Output = T>,
    F: 'static + Sync + Send + Fn(&'a mut A1, &'a A2) -> Out,
{
    type Fut = Out;
}
75
/// Marker trait for the data produced by a simulation's setup function.
///
/// The data must be serializable because `Builder::build` encodes it with
/// `postcard` when initializing a trace and decodes it again in isolated mode.
pub trait SetupData: Serialize + DeserializeOwned + Send + Sync + Clone + Debug + 'static {}
/// Blanket implementation for every type that satisfies the bounds.
impl<T> SetupData for T where T: Serialize + DeserializeOwned + Send + Sync + Clone + Debug + 'static
{}
83
/// Context handed to a node's spawn function.
pub struct SpawnContext<'a, D = ()> {
    /// Secret key generated for this node before it is spawned.
    secret_key: SecretKey,
    /// Index of this node within the simulation.
    node_idx: u32,
    /// Setup data of the simulation.
    setup_data: &'a D,
    /// Metrics registry of this node.
    registry: &'a mut Registry,
}
94
95impl<'a, D: SetupData> SpawnContext<'a, D> {
96 pub fn node_index(&self) -> u32 {
98 self.node_idx
99 }
100
101 pub fn setup_data(&self) -> &D {
103 self.setup_data
104 }
105
106 pub fn metrics_registry(&mut self) -> &mut Registry {
110 self.registry
111 }
112
113 pub fn secret_key(&self) -> SecretKey {
115 self.secret_key.clone()
116 }
117
118 pub fn node_id(&self) -> NodeId {
120 self.secret_key.public()
121 }
122
123 pub async fn bind_endpoint(&self) -> Result<Endpoint> {
129 let ep = Endpoint::builder()
130 .discovery_n0()
131 .secret_key(self.secret_key())
132 .bind()
133 .await?;
134 Ok(ep)
135 }
136}
137
/// Context handed to a node's round and check functions.
pub struct RoundContext<'a, D = ()> {
    /// Current round number, starting at 0.
    round: u32,
    /// Index of the node this context was created for.
    node_index: u32,
    /// Setup data of the simulation.
    setup_data: &'a D,
    /// Info and (optional) address of the nodes returned when the node started.
    all_nodes: &'a Vec<NodeInfoWithAddr>,
}
148
149impl<'a, D> RoundContext<'a, D> {
150 pub fn round(&self) -> u32 {
152 self.round
153 }
154
155 pub fn node_index(&self) -> u32 {
157 self.node_index
158 }
159
160 pub fn setup_data(&self) -> &D {
162 self.setup_data
163 }
164
165 pub fn all_other_nodes(&self, me: NodeId) -> impl Iterator<Item = &NodeAddr> + '_ {
167 self.all_nodes
168 .iter()
169 .filter(move |n| n.info.node_id != Some(me))
170 .flat_map(|n| &n.addr)
171 }
172
173 pub fn addr(&self, idx: u32) -> Result<NodeAddr> {
179 self.all_nodes
180 .iter()
181 .find(|n| n.info.idx == idx)
182 .cloned()
183 .context("node not found")?
184 .addr
185 .context("node has no address")
186 }
187
188 pub fn self_addr(&self) -> Option<&NodeAddr> {
190 self.all_nodes
191 .iter()
192 .find(|n| n.info.idx == self.node_index)
193 .and_then(|info| info.addr.as_ref())
194 }
195
196 pub fn try_self_addr(&self) -> Result<&NodeAddr> {
197 self.self_addr().context("missing node address")
198 }
199
200 pub fn node_count(&self) -> usize {
202 self.all_nodes.len()
203 }
204}
205
/// A [`Node`] that can be created from a [`SpawnContext`].
pub trait Spawn<D: SetupData = ()>: Node + 'static {
    /// Creates the node from the given spawn context.
    fn spawn(context: &mut SpawnContext<'_, D>) -> impl Future<Output = Result<Self>> + Send
    where
        Self: Sized;

    /// Type-erased variant of [`Spawn::spawn`]: spawns the node and boxes it as a
    /// [`DynNode`].
    fn spawn_dyn<'a>(context: &'a mut SpawnContext<'a, D>) -> BoxFuture<'a, Result<BoxNode>>
    where
        Self: Sized,
    {
        Box::pin(async move {
            let node = Self::spawn(context).await?;
            let node: Box<dyn DynNode> = Box::new(node);
            anyhow::Ok(node)
        })
    }

    /// Creates a [`NodeBuilder`] for this node type from a round function.
    ///
    /// Shorthand for [`NodeBuilder::new`].
    fn builder(
        round_fn: impl for<'a> AsyncCallback<'a, Self, RoundContext<'a, D>, Result<bool>>,
    ) -> NodeBuilder<Self, D>
    where
        Self: Sized,
    {
        NodeBuilder::new(round_fn)
    }
}
255
/// A node taking part in a simulation.
///
/// Both methods have default implementations, so an empty `impl` is valid.
pub trait Node: Send + 'static {
    /// Returns the node's endpoint, if it has one.
    ///
    /// When an endpoint is returned, the runner registers its metrics and reports
    /// its address to the other nodes. Defaults to `None`.
    fn endpoint(&self) -> Option<&Endpoint> {
        None
    }

    /// Shuts the node down after its rounds have finished.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    fn shutdown(&mut self) -> impl Future<Output = Result<()>> + Send + '_ {
        async { Ok(()) }
    }
}
282
/// A pinned, boxed, `Send` future with lifetime `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A type-erased, boxed node.
pub type BoxNode = Box<dyn DynNode>;
287
/// Dyn-compatible counterpart of [`Node`].
///
/// Implemented for every [`Node`] through a blanket impl; the `as_any` methods let
/// callers downcast back to the concrete node type.
pub trait DynNode: Send + Any + 'static {
    /// Shuts the node down. The default implementation returns `Ok(())`.
    fn shutdown(&mut self) -> BoxFuture<'_, Result<()>> {
        Box::pin(async { Ok(()) })
    }

    /// Returns the node's endpoint, if any. Defaults to `None`.
    fn endpoint(&self) -> Option<&Endpoint> {
        None
    }

    /// Returns `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn Any;

    /// Returns `self` as `&mut dyn Any` for downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
317
318impl<T: Node + Sized> DynNode for T {
319 fn shutdown(&mut self) -> BoxFuture<'_, Result<()>> {
320 Box::pin(<Self as Node>::shutdown(self))
321 }
322
323 fn endpoint(&self) -> Option<&Endpoint> {
324 <Self as Node>::endpoint(self)
325 }
326
327 fn as_any(&self) -> &dyn Any {
328 self
329 }
330
331 fn as_any_mut(&mut self) -> &mut dyn Any {
332 self
333 }
334}
335
/// Builder for a [`Simulation`].
// NOTE(review): the derive list below is empty, so the attribute has no effect.
#[derive()]
pub struct Builder<D = ()> {
    /// Produces the setup data; only called when the trace is initialized.
    setup_fn: BoxedSetupFn<D>,
    /// Registered node builders, each with the number of nodes to spawn from it.
    node_builders: Vec<NodeBuilderWithCount<D>>,
    /// Maximum number of rounds each node runs.
    rounds: u32,
}

/// Typed builder for nodes of type `N`: spawn function, round function and an
/// optional check function.
#[derive(Clone)]
pub struct NodeBuilder<N, D> {
    /// Ties the builder to the node type without storing a value of it.
    phantom: PhantomData<N>,
    spawn_fn: BoxedSpawnFn<D>,
    round_fn: BoxedRoundFn<D>,
    check_fn: Option<BoxedCheckFn<D>>,
}

/// A [`NodeBuilder`] with the node type parameter erased, so builders for
/// different node types can be stored together.
#[derive(Clone)]
struct ErasedNodeBuilder<D> {
    spawn_fn: BoxedSpawnFn<D>,
    round_fn: BoxedRoundFn<D>,
    check_fn: Option<BoxedCheckFn<D>>,
}
365
366impl<T, N: Spawn<D>, D: SetupData> From<T> for NodeBuilder<N, D>
367where
368 T: for<'a> AsyncCallback<'a, N, RoundContext<'a, D>, Result<bool>>,
369{
370 fn from(value: T) -> Self {
371 Self::new(value)
372 }
373}
374
impl<N: Spawn<D>, D: SetupData> NodeBuilder<N, D> {
    /// Creates a builder for nodes of type `N` driven by `round_fn`.
    ///
    /// The node is spawned through [`Spawn::spawn_dyn`]. `round_fn` is wrapped so
    /// that it can be called with a type-erased [`BoxNode`].
    pub fn new(
        round_fn: impl for<'a> AsyncCallback<'a, N, RoundContext<'a, D>, Result<bool>>,
    ) -> Self {
        let spawn_fn: BoxedSpawnFn<D> = Arc::new(N::spawn_dyn);
        let round_fn: BoxedRoundFn<D> = Arc::new(move |node, context| {
            // The node was created by `N::spawn_dyn`, so the downcast to `N` succeeds.
            let node = node
                .as_any_mut()
                .downcast_mut::<N>()
                .expect("unreachable: type is statically guaranteed");
            Box::pin(round_fn(node, context))
        });
        Self {
            phantom: PhantomData,
            spawn_fn,
            round_fn,
            check_fn: None,
        }
    }

    /// Sets a check function, replacing any previously set one.
    ///
    /// The check function is called after a round whose round function returned
    /// `Ok`; an error from it fails the node.
    pub fn check(
        mut self,
        check_fn: impl 'static + for<'a> Fn(&'a N, &RoundContext<'a, D>) -> Result<()>,
    ) -> Self {
        let check_fn: BoxedCheckFn<D> = Arc::new(move |node, context| {
            // The node was created by `N::spawn_dyn`, so the downcast to `N` succeeds.
            let node = node
                .as_any()
                .downcast_ref::<N>()
                .expect("unreachable: type is statically guaranteed");
            check_fn(node, context)
        });
        self.check_fn = Some(check_fn);
        self
    }

    /// Drops the node type parameter, keeping only the type-erased functions.
    fn erase(self) -> ErasedNodeBuilder<D> {
        ErasedNodeBuilder {
            spawn_fn: self.spawn_fn,
            round_fn: self.round_fn,
            check_fn: self.check_fn,
        }
    }
}
430
/// Runtime state of a single spawned node while its rounds are being run.
struct SimNode<D> {
    /// The spawned, type-erased node.
    node: BoxNode,
    /// Id of the trace this node reports to.
    trace_id: Uuid,
    /// Index of the node within the simulation.
    idx: u32,
    round_fn: BoxedRoundFn<D>,
    check_fn: Option<BoxedCheckFn<D>>,
    /// Number of the round that runs next; starts at 0.
    round: u32,
    /// Node info sent to the trace server.
    info: NodeInfo,
    /// Encoder used to export the node's metrics after each round.
    metrics_encoder: Encoder,
    /// Node list received from the trace server at start; empty until then.
    all_nodes: Vec<NodeInfoWithAddr>,
}
442
443impl<D: SetupData> SimNode<D> {
444 async fn spawn_and_run(
445 builder: NodeBuilderWithIdx<D>,
446 client: TraceClient,
447 trace_id: Uuid,
448 setup_data: &D,
449 rounds: u32,
450 ) -> Result<()> {
451 let secret_key = SecretKey::generate(&mut rand::rng());
452 let NodeBuilderWithIdx { node_idx, builder } = builder;
453 let info = NodeInfo {
454 node_id: Some(secret_key.public()),
456 idx: node_idx,
457 label: None,
458 };
459 let mut registry = Registry::default();
460 let mut context = SpawnContext {
461 setup_data,
462 node_idx,
463 secret_key,
464 registry: &mut registry,
465 };
466 let node = (builder.spawn_fn)(&mut context).await?;
467
468 if let Some(endpoint) = node.endpoint() {
469 registry.register_all(endpoint.metrics());
470 }
471
472 let mut node = Self {
473 node,
474 trace_id,
475 idx: node_idx,
476 info,
477 round: 0,
478 round_fn: builder.round_fn,
479 check_fn: builder.check_fn,
480 metrics_encoder: Encoder::new(Arc::new(RwLock::new(registry))),
481 all_nodes: Default::default(),
482 };
483
484 let res = node
485 .run(&client, setup_data, rounds)
486 .await
487 .with_context(|| format!("node {} failed", node.idx));
488 if let Err(err) = &res {
489 warn!("node failed: {err:#}");
490 }
491 res
492 }
493
494 async fn run(&mut self, client: &TraceClient, setup_data: &D, rounds: u32) -> Result<()> {
495 let client = client.start_node(self.trace_id, self.info.clone()).await?;
496
497 info!(idx = self.idx, "start");
498
499 let info = NodeInfoWithAddr {
500 addr: self.my_addr().await,
501 info: self.info.clone(),
502 };
503 let all_nodes = client.wait_start(info).await?;
504 self.all_nodes = all_nodes;
505
506 let result = self.run_rounds(&client, setup_data, rounds).await;
507
508 if let Err(err) = self.node.shutdown().await {
509 warn!("failure during node shutdown: {err:#}");
510 }
511
512 client.end(to_str_err(&result)).await?;
513
514 result
515 }
516
517 async fn run_rounds(
518 &mut self,
519 client: &ActiveTrace,
520 setup_data: &D,
521 rounds: u32,
522 ) -> Result<()> {
523 while self.round < rounds {
524 if !self
525 .run_round(client, setup_data)
526 .await
527 .with_context(|| format!("failed at round {}", self.round))?
528 {
529 return Ok(());
530 }
531 self.round += 1;
532 }
533 Ok(())
534 }
535
536 #[tracing::instrument(name="round", skip_all, fields(round=self.round))]
537 async fn run_round(&mut self, client: &ActiveTrace, setup_data: &D) -> Result<bool> {
538 info!("start round");
539 let context = RoundContext {
540 round: self.round,
541 node_index: self.idx,
542 setup_data,
543 all_nodes: &self.all_nodes,
544 };
545
546 let result = (self.round_fn)(&mut self.node, &context)
547 .await
548 .context("round function failed");
549
550 let checkpoint = (context.round + 1) as u64;
551 let label = format!("Round {} end", context.round);
552 client
553 .put_checkpoint(checkpoint, Some(label), to_str_err(&result))
554 .await?;
555
556 if let Some(node_id) = self.node_id() {
558 client
559 .put_metrics(node_id, Some(checkpoint), self.metrics_encoder.export())
560 .await?;
561 }
562
563 client.wait_checkpoint(checkpoint).await?;
564
565 match result {
566 Ok(out) => {
567 if let Some(check_fn) = self.check_fn.as_ref() {
568 (check_fn)(&self.node, &context).context("check function failed")?;
569 }
570 Ok(out)
571 }
572 Err(err) => Err(err),
573 }
574 }
575
576 fn node_id(&self) -> Option<NodeId> {
577 self.info.node_id
578 }
579
580 async fn my_addr(&self) -> Option<NodeAddr> {
581 if let Some(endpoint) = self.node.endpoint() {
582 Some(node_addr(endpoint).await)
583 } else {
584 None
585 }
586 }
587}
588
/// Waits for `endpoint.online()` to complete, then returns the endpoint's node
/// address.
async fn node_addr(endpoint: &Endpoint) -> NodeAddr {
    endpoint.online().await;
    endpoint.node_addr()
}
593
594impl Default for Builder<()> {
595 fn default() -> Self {
596 Self::new()
597 }
598}
599
600impl Builder<()> {
601 pub fn new() -> Builder<()> {
603 let setup_fn: BoxedSetupFn<()> = Box::new(move || Box::pin(async move { Ok(()) }));
604 Builder {
605 node_builders: Vec::new(),
606 setup_fn,
607 rounds: 0,
608 }
609 }
610}
611impl<D: SetupData> Builder<D> {
612 pub fn with_setup<F, Fut>(setup_fn: F) -> Builder<D>
625 where
626 F: 'static + Send + Sync + FnOnce() -> Fut,
627 Fut: 'static + Send + Future<Output = Result<D>>,
628 {
629 let setup_fn: BoxedSetupFn<D> = Box::new(move || Box::pin(setup_fn()));
630 Builder {
631 node_builders: Vec::new(),
632 setup_fn,
633 rounds: 0,
634 }
635 }
636
637 pub fn rounds(mut self, rounds: u32) -> Self {
639 self.rounds = rounds;
640 self
641 }
642
643 pub fn spawn<N: Spawn<D>>(
651 mut self,
652 node_count: u32,
653 node_builder: impl Into<NodeBuilder<N, D>>,
654 ) -> Self {
655 let node_builder = node_builder.into();
656 self.node_builders.push(NodeBuilderWithCount {
657 count: node_count,
658 builder: node_builder.erase(),
659 });
660 self
661 }
662
663 pub async fn build(self, name: &str) -> Result<Simulation<D>> {
673 let client = TraceClient::from_env_or_local()?;
674 let run_mode = RunMode::from_env()?;
675
676 debug!(%name, ?run_mode, "build simulation run");
677
678 let (trace_id, setup_data) = if matches!(run_mode, RunMode::InitOnly | RunMode::Integrated)
679 {
680 let setup_data = (self.setup_fn)().await?;
681 let encoded_setup_data = Bytes::from(postcard::to_stdvec(&setup_data)?);
682 let node_count = self.node_builders.iter().map(|builder| builder.count).sum();
683 let trace_id = client
684 .init_trace(
685 TraceInfo {
686 name: name.to_string(),
687 node_count,
688 expected_checkpoints: Some(self.rounds as u64),
689 },
690 Some(encoded_setup_data),
691 )
692 .await?;
693 info!(%name, node_count, %trace_id, "init simulation");
694
695 (trace_id, setup_data)
696 } else {
697 let info = client.get_trace(name.to_string()).await?;
698 let GetTraceResponse {
699 trace_id,
700 info,
701 setup_data,
702 } = info;
703 info!(%name, node_count=info.node_count, %trace_id, "get simulation");
704 let setup_data = setup_data.context("expected setup data to be set")?;
705 let setup_data: D =
706 postcard::from_bytes(&setup_data).context("failed to decode setup data")?;
707 (trace_id, setup_data)
708 };
709
710 let mut node_builders = self
711 .node_builders
712 .into_iter()
713 .flat_map(|builder| (0..builder.count).map(move |_| builder.builder.clone()))
714 .enumerate()
715 .map(|(node_idx, builder)| NodeBuilderWithIdx {
716 node_idx: node_idx as u32,
717 builder,
718 });
719
720 let node_builders: Vec<_> = match run_mode {
721 RunMode::InitOnly => vec![],
722 RunMode::Integrated => node_builders.collect(),
723 RunMode::Isolated(idx) => vec![
724 node_builders
725 .nth(idx as usize)
726 .context("invalid isolated index")?,
727 ],
728 };
729
730 Ok(Simulation {
731 run_mode,
732 max_rounds: self.rounds,
733 node_builders,
734 client,
735 trace_id,
736 setup_data,
737 })
738 }
739}
740
/// A type-erased node builder together with the number of nodes to spawn from it.
struct NodeBuilderWithCount<D> {
    count: u32,
    builder: ErasedNodeBuilder<D>,
}

/// A type-erased node builder together with the index assigned to its node.
struct NodeBuilderWithIdx<D> {
    node_idx: u32,
    builder: ErasedNodeBuilder<D>,
}
750
/// A simulation that is ready to run, created by `Builder::build`.
pub struct Simulation<D> {
    /// Id of the trace this run reports to.
    trace_id: Uuid,
    /// Run mode read from the environment when the simulation was built.
    run_mode: RunMode,
    client: TraceClient,
    /// Setup data shared by all nodes.
    setup_data: D,
    /// Maximum number of rounds each node runs.
    max_rounds: u32,
    /// The nodes this process runs; which ones depends on the run mode.
    node_builders: Vec<NodeBuilderWithIdx<D>>,
}
763
764impl<D: SetupData> Simulation<D> {
765 pub async fn run(self) -> Result<()> {
774 let cancel_token = CancellationToken::new();
775
776 let logs_scope = match self.run_mode {
778 RunMode::Isolated(idx) => Some(Scope::Isolated(idx)),
779 RunMode::Integrated => Some(Scope::Integrated),
780 RunMode::InitOnly => None,
782 };
783 let logs_task = if let Some(scope) = logs_scope {
784 Some(spawn_logs_task(
785 self.client.clone(),
786 self.trace_id,
787 scope,
788 cancel_token.clone(),
789 ))
790 } else {
791 None
792 };
793
794 let result = self
796 .node_builders
797 .into_iter()
798 .map(async |builder| {
799 let span = error_span!("sim-node", idx = builder.node_idx);
800 SimNode::spawn_and_run(
801 builder,
802 self.client.clone(),
803 self.trace_id,
804 &self.setup_data,
805 self.max_rounds,
806 )
807 .instrument(span)
808 .await
809 })
810 .try_join_all()
811 .await
812 .map(|_list| ());
813
814 cancel_token.cancel();
815 if let Some(join_handle) = logs_task {
816 join_handle.await?;
817 }
818
819 if matches!(self.run_mode, RunMode::Integrated) {
820 self.client
821 .close_trace(self.trace_id, to_str_err(&result))
822 .await?;
823 }
824
825 result
826 }
827}
828
829#[derive(Debug, Copy, Clone)]
830enum RunMode {
831 InitOnly,
832 Integrated,
833 Isolated(u32),
834}
835
836impl RunMode {
837 fn from_env() -> Result<Self> {
838 if std::env::var(ENV_TRACE_INIT_ONLY).is_ok() {
839 Ok(Self::InitOnly)
840 } else {
841 match std::env::var(ENV_TRACE_ISOLATED) {
842 Err(_) => Ok(Self::Integrated),
843 Ok(s) => {
844 let idx = s.parse().with_context(|| {
845 format!("Failed to parse env var `{ENV_TRACE_ISOLATED}` as number")
846 })?;
847 Ok(Self::Isolated(idx))
848 }
849 }
850 }
851 }
852}
853
854fn spawn_logs_task(
857 client: TraceClient,
858 trace_id: Uuid,
859 scope: Scope,
860 cancel_token: CancellationToken,
861) -> tokio::task::JoinHandle<()> {
862 tokio::task::spawn(async move {
863 loop {
864 if cancel_token
865 .run_until_cancelled(tokio::time::sleep(Duration::from_secs(1)))
866 .await
867 .is_none()
868 {
869 break;
870 }
871 let lines = self::trace::get_logs();
872 if lines.is_empty() {
873 continue;
874 }
875 for lines_chunk in lines.chunks(500) {
878 if let Err(e) = client.put_logs(trace_id, scope, lines_chunk.to_vec()).await {
879 eprintln!(
880 "warning: failed to submit logs due to error, stopping log submission now: {e:?}"
881 );
882 break;
883 }
884 }
885 if cancel_token.is_cancelled() {
886 break;
887 }
888 }
889 })
890}
891
/// Single-permit semaphore: `run_sim_fn` holds the permit for the whole run, so
/// simulations within one process run one at a time.
static PERMIT: Semaphore = Semaphore::const_new(1);
893
894#[doc(hidden)]
904pub async fn run_sim_fn<F, Fut, D, E>(name: &str, sim_fn: F) -> anyhow::Result<()>
905where
906 F: Fn() -> Fut,
907 Fut: Future<Output = Result<Builder<D>, E>>,
908 D: SetupData,
909 anyhow::Error: From<E>,
910{
911 let permit = PERMIT.acquire().await.expect("semaphore closed");
913
914 self::trace::init();
916 self::trace::global_writer().clear();
918
919 eprintln!("running simulation: {name}");
920 let result = sim_fn()
921 .await
922 .map_err(anyhow::Error::from)
923 .with_context(|| format!("simulation builder function `{name}` failed"))?
924 .build(name)
925 .await
926 .with_context(|| format!("simulation `{name}` failed to start"))?
927 .run()
928 .await
929 .with_context(|| format!("simulation `{name}` failed to complete"));
930
931 match &result {
932 Ok(()) => eprintln!("simulation `{name}` passed"),
933 Err(err) => eprintln!("simulation `{name}` failed: {err:#}"),
934 };
935
936 drop(permit);
937
938 result
939}
940
941fn to_str_err<T>(res: &Result<T, anyhow::Error>) -> Result<(), String> {
942 if let Some(err) = res.as_ref().err() {
943 Err(format!("{err:?}"))
944 } else {
945 Ok(())
946 }
947}