1use std::{
2 any::Any,
3 fmt::Debug,
4 marker::PhantomData,
5 pin::Pin,
6 sync::{Arc, RwLock},
7 time::Duration,
8};
9
10use anyhow::{Context, Result};
11use bytes::Bytes;
12use iroh::{Endpoint, EndpointAddr, EndpointId, SecretKey};
13use iroh_metrics::encoding::Encoder;
14use iroh_n0des::{
15 Registry,
16 simulation::proto::{ActiveTrace, EndpointInfo, TraceClient, TraceInfo},
17};
18use n0_future::IterExt;
19use proto::{EndpointInfoWithAddr, GetTraceResponse, Scope};
20use serde::{Serialize, de::DeserializeOwned};
21use tokio::sync::Semaphore;
22use tokio_util::sync::CancellationToken;
23use tracing::{Instrument, debug, error_span, info, warn};
24use uuid::Uuid;
25
26pub mod events;
27pub mod proto;
28pub mod trace;
29
/// Name of the environment variable that selects isolated mode.
///
/// When it is set (and [`ENV_TRACE_INIT_ONLY`] is not), its value is parsed as the numeric
/// index of the single node to run.
pub const ENV_TRACE_ISOLATED: &str = "N0DES_TRACE_ISOLATED";
/// Name of the environment variable that selects init-only mode.
///
/// Only whether the variable can be read is checked; its value is ignored.
pub const ENV_TRACE_INIT_ONLY: &str = "N0DES_TRACE_INIT_ONLY";
/// Name of an environment variable that is not read anywhere in this file.
// NOTE(review): presumably consumed by the trace client to locate the server — confirm.
pub const ENV_TRACE_SERVER: &str = "N0DES_TRACE_SERVER";
/// Name of an environment variable that is not read anywhere in this file.
// NOTE(review): presumably identifies the session on the trace server — confirm.
pub const ENV_TRACE_SESSION_ID: &str = "N0DES_SESSION_ID";
38
/// Boxed one-shot setup callback: called once, returns a future resolving to the setup data.
type BoxedSetupFn<D> = Box<dyn 'static + Send + Sync + FnOnce() -> BoxFuture<'static, Result<D>>>;

/// Shared, type-erased spawn callback: builds a node from a [`SpawnContext`].
type BoxedSpawnFn<D> = Arc<
    dyn 'static
        + Send
        + Sync
        + for<'a> Fn(&'a mut SpawnContext<'a, D>) -> BoxFuture<'a, Result<BoxNode>>,
>;
/// Shared, type-erased round callback.
///
/// The returned `bool` decides whether the node keeps running: the round loop stops early
/// (successfully) when it is `false`.
type BoxedRoundFn<D> = Arc<
    dyn 'static
        + Send
        + Sync
        + for<'a> Fn(&'a mut BoxNode, &'a RoundContext<'a, D>) -> BoxFuture<'a, Result<bool>>,
>;

/// Shared, type-erased synchronous check callback.
///
/// Unlike the other callback aliases this one carries no `Send`/`Sync` bound.
type BoxedCheckFn<D> = Arc<dyn Fn(&BoxNode, &RoundContext<'_, D>) -> Result<()>>;
55
/// Helper trait for async callbacks of the shape `Fn(&'a mut A1, &'a A2) -> impl Future`.
///
/// It names the returned future as an associated type so that callers can write a
/// higher-ranked bound (`for<'a> AsyncCallback<'a, ..>`) over a callback whose future borrows
/// from its arguments. The future must be `Send` and resolve to `T`.
pub trait AsyncCallback<'a, A1: 'a, A2: 'a, T: 'a>:
    'static + Send + Sync + Fn(&'a mut A1, &'a A2) -> Self::Fut
{
    /// The future returned by the callback.
    type Fut: Future<Output = T> + Send;
}

// Blanket impl: every matching `Fn` is an `AsyncCallback`, so users never implement it by hand.
impl<'a, A1: 'a, A2: 'a, T: 'a, Out, F> AsyncCallback<'a, A1, A2, T> for F
where
    Out: Send + Future<Output = T>,
    F: 'static + Sync + Send + Fn(&'a mut A1, &'a A2) -> Out,
{
    type Fut = Out;
}
75
/// Marker trait for data produced by the setup function and shared with every node.
///
/// It is a pure alias for its supertraits: the data must be serializable and deserializable,
/// cloneable, debuggable, thread-safe and `'static`.
pub trait SetupData: Serialize + DeserializeOwned + Send + Sync + Clone + Debug + 'static {}
// Blanket impl: any type satisfying the supertraits is `SetupData`.
impl<T> SetupData for T where T: Serialize + DeserializeOwned + Send + Sync + Clone + Debug + 'static
{}
83
/// Context handed to a node's spawn function.
pub struct SpawnContext<'a, D = ()> {
    /// Secret key generated for this node.
    secret_key: SecretKey,
    /// Index of this node within the simulation.
    idx: u32,
    /// Setup data shared by all nodes.
    setup_data: &'a D,
    /// Metrics registry owned by the caller; the node may register metrics into it.
    registry: &'a mut Registry,
}
94
95impl<'a, D: SetupData> SpawnContext<'a, D> {
96 pub fn node_index(&self) -> u32 {
98 self.idx
99 }
100
101 pub fn setup_data(&self) -> &D {
103 self.setup_data
104 }
105
106 pub fn metrics_registry(&mut self) -> &mut Registry {
110 self.registry
111 }
112
113 pub fn secret_key(&self) -> SecretKey {
115 self.secret_key.clone()
116 }
117
118 pub fn id(&self) -> EndpointId {
120 self.secret_key.public()
121 }
122
123 pub async fn bind_endpoint(&self) -> Result<Endpoint> {
129 let ep = Endpoint::builder()
130 .secret_key(self.secret_key())
131 .bind()
132 .await?;
133 Ok(ep)
134 }
135}
136
/// Context handed to a node's round and check functions.
pub struct RoundContext<'a, D = ()> {
    /// Zero-based number of the current round.
    round: u32,
    /// Index of the node this context was created for.
    node_index: u32,
    /// Setup data shared by all nodes.
    setup_data: &'a D,
    /// Info (and, when available, address) of every node participating in the simulation.
    all_nodes: &'a Vec<EndpointInfoWithAddr>,
}
147
148impl<'a, D> RoundContext<'a, D> {
149 pub fn round(&self) -> u32 {
151 self.round
152 }
153
154 pub fn node_index(&self) -> u32 {
156 self.node_index
157 }
158
159 pub fn setup_data(&self) -> &D {
161 self.setup_data
162 }
163
164 pub fn all_other_nodes(&self, me: EndpointId) -> impl Iterator<Item = &EndpointAddr> + '_ {
166 self.all_nodes
167 .iter()
168 .filter(move |n| n.info.endpoint_id != Some(me))
169 .flat_map(|n| &n.addr)
170 }
171
172 pub fn addr(&self, idx: u32) -> Result<EndpointAddr> {
178 self.all_nodes
179 .iter()
180 .find(|n| n.info.idx == idx)
181 .cloned()
182 .context("node not found")?
183 .addr
184 .context("node has no address")
185 }
186
187 pub fn self_addr(&self) -> Option<&EndpointAddr> {
189 self.all_nodes
190 .iter()
191 .find(|n| n.info.idx == self.node_index)
192 .and_then(|info| info.addr.as_ref())
193 }
194
195 pub fn try_self_addr(&self) -> Result<&EndpointAddr> {
196 self.self_addr().context("missing node address")
197 }
198
199 pub fn node_count(&self) -> usize {
201 self.all_nodes.len()
202 }
203}
204
/// A [`Node`] that knows how to construct itself from a [`SpawnContext`].
pub trait Spawn<D: SetupData = ()>: Node + 'static {
    /// Creates the node from the given spawn context.
    fn spawn(context: &mut SpawnContext<'_, D>) -> impl Future<Output = Result<Self>> + Send
    where
        Self: Sized;

    /// Type-erased variant of [`Spawn::spawn`]: spawns the node and boxes it as a [`BoxNode`].
    fn spawn_dyn<'a>(context: &'a mut SpawnContext<'a, D>) -> BoxFuture<'a, Result<BoxNode>>
    where
        Self: Sized,
    {
        Box::pin(async move {
            let node = Self::spawn(context).await?;
            let node: Box<dyn DynNode> = Box::new(node);
            // `anyhow::Ok` pins the error type of the async block.
            anyhow::Ok(node)
        })
    }

    /// Creates a [`NodeBuilder`] for this node type that runs `round_fn` each round.
    fn builder(
        round_fn: impl for<'a> AsyncCallback<'a, Self, RoundContext<'a, D>, Result<bool>>,
    ) -> NodeBuilder<Self, D>
    where
        Self: Sized,
    {
        NodeBuilder::new(round_fn)
    }
}
254
/// A node taking part in a simulation.
///
/// Both methods have default implementations, so an implementor may leave the body empty.
pub trait Node: Send + 'static {
    /// Returns the node's endpoint, if it has one. Defaults to `None`.
    fn endpoint(&self) -> Option<&Endpoint> {
        None
    }

    /// Shuts the node down. The default implementation does nothing and succeeds.
    fn shutdown(&mut self) -> impl Future<Output = Result<()>> + Send + '_ {
        async { Ok(()) }
    }
}
281
/// A pinned, boxed, `Send` future with lifetime `'a` resolving to `T`.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A type-erased, boxed node.
pub type BoxNode = Box<dyn DynNode>;
286
/// Dyn-compatible counterpart of [`Node`], with boxed futures and `Any` access for downcasting.
pub trait DynNode: Send + Any + 'static {
    /// Shuts the node down. The default implementation does nothing and succeeds.
    fn shutdown(&mut self) -> BoxFuture<'_, Result<()>> {
        Box::pin(async { Ok(()) })
    }

    /// Returns the node's endpoint, if it has one. Defaults to `None`.
    fn endpoint(&self) -> Option<&Endpoint> {
        None
    }

    /// Returns the node as `&dyn Any` so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the node as `&mut dyn Any` so it can be downcast to its concrete type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
316
317impl<T: Node + Sized> DynNode for T {
318 fn shutdown(&mut self) -> BoxFuture<'_, Result<()>> {
319 Box::pin(<Self as Node>::shutdown(self))
320 }
321
322 fn endpoint(&self) -> Option<&Endpoint> {
323 <Self as Node>::endpoint(self)
324 }
325
326 fn as_any(&self) -> &dyn Any {
327 self
328 }
329
330 fn as_any_mut(&mut self) -> &mut dyn Any {
331 self
332 }
333}
334
335#[derive()]
336pub struct Builder<D = ()> {
341 setup_fn: BoxedSetupFn<D>,
342 node_builders: Vec<NodeBuilderWithCount<D>>,
343 rounds: u32,
344}
345
346#[derive(Clone)]
347pub struct NodeBuilder<N, D> {
352 phantom: PhantomData<N>,
353 spawn_fn: BoxedSpawnFn<D>,
354 round_fn: BoxedRoundFn<D>,
355 check_fn: Option<BoxedCheckFn<D>>,
356}
357
/// A [`NodeBuilder`] with the node type parameter erased, so that builders for different
/// node types can be stored together.
#[derive(Clone)]
struct ErasedNodeBuilder<D> {
    /// Creates the node.
    spawn_fn: BoxedSpawnFn<D>,
    /// Runs one round on the node.
    round_fn: BoxedRoundFn<D>,
    /// Optional check run after a successful round.
    check_fn: Option<BoxedCheckFn<D>>,
}
364
365impl<T, N: Spawn<D>, D: SetupData> From<T> for NodeBuilder<N, D>
366where
367 T: for<'a> AsyncCallback<'a, N, RoundContext<'a, D>, Result<bool>>,
368{
369 fn from(value: T) -> Self {
370 Self::new(value)
371 }
372}
373
impl<N: Spawn<D>, D: SetupData> NodeBuilder<N, D> {
    /// Creates a builder for nodes of type `N` that runs `round_fn` in each round.
    ///
    /// The node is created through [`Spawn::spawn_dyn`]; no check function is set.
    pub fn new(
        round_fn: impl for<'a> AsyncCallback<'a, N, RoundContext<'a, D>, Result<bool>>,
    ) -> Self {
        let spawn_fn: BoxedSpawnFn<D> = Arc::new(N::spawn_dyn);
        // Wrap the typed round function so that it accepts a type-erased node.
        let round_fn: BoxedRoundFn<D> = Arc::new(move |node, context| {
            let node = node
                .as_any_mut()
                .downcast_mut::<N>()
                // The node was created by `N::spawn_dyn`, so it is always an `N`.
                .expect("unreachable: type is statically guaranteed");
            Box::pin(round_fn(node, context))
        });
        Self {
            phantom: PhantomData,
            spawn_fn,
            round_fn,
            check_fn: None,
        }
    }

    /// Sets a check function, replacing any previously set one.
    ///
    /// The check is run after a round whose round function succeeded; an error from it
    /// fails the round.
    pub fn check(
        mut self,
        check_fn: impl 'static + for<'a> Fn(&'a N, &RoundContext<'a, D>) -> Result<()>,
    ) -> Self {
        // Wrap the typed check function so that it accepts a type-erased node.
        let check_fn: BoxedCheckFn<D> = Arc::new(move |node, context| {
            let node = node
                .as_any()
                .downcast_ref::<N>()
                // The node was created by `N::spawn_dyn`, so it is always an `N`.
                .expect("unreachable: type is statically guaranteed");
            check_fn(node, context)
        });
        self.check_fn = Some(check_fn);
        self
    }

    /// Drops the node type parameter, keeping only the type-erased callbacks.
    fn erase(self) -> ErasedNodeBuilder<D> {
        ErasedNodeBuilder {
            spawn_fn: self.spawn_fn,
            round_fn: self.round_fn,
            check_fn: self.check_fn,
        }
    }
}
429
/// Runtime state of a single spawned node while it runs its rounds.
struct SimNode<D> {
    /// The type-erased node.
    node: BoxNode,
    /// Id of the trace this node reports to.
    trace_id: Uuid,
    /// Index of this node within the simulation.
    idx: u32,
    /// Runs one round on the node.
    round_fn: BoxedRoundFn<D>,
    /// Optional check run after a successful round.
    check_fn: Option<BoxedCheckFn<D>>,
    /// Number of the round to run next (starts at 0).
    round: u32,
    /// Info about this node as reported to the trace client.
    info: EndpointInfo,
    /// Encoder used to export the node's metrics after each round.
    metrics_encoder: Encoder,
    /// All participating nodes, as returned by the trace client when the node starts.
    all_nodes: Vec<EndpointInfoWithAddr>,
}
441
impl<D: SetupData> SimNode<D> {
    /// Spawns a node from `builder` and runs it for up to `rounds` rounds.
    ///
    /// A fresh secret key and metrics registry are created for the node. If the spawned node
    /// exposes an endpoint, the endpoint's metrics are registered as well.
    ///
    /// # Errors
    ///
    /// Returns an error if spawning fails, or if running the node fails (the latter is also
    /// logged as a warning and wrapped with the node index).
    async fn spawn_and_run(
        builder: NodeBuilderWithIdx<D>,
        client: TraceClient,
        trace_id: Uuid,
        setup_data: &D,
        rounds: u32,
    ) -> Result<()> {
        let secret_key = SecretKey::generate(&mut rand::rng());
        let NodeBuilderWithIdx { idx, builder } = builder;
        let info = EndpointInfo {
            idx,
            endpoint_id: Some(secret_key.public()),
            label: None,
        };
        let mut registry = Registry::default();
        let mut context = SpawnContext {
            setup_data,
            idx,
            secret_key,
            registry: &mut registry,
        };
        let node = (builder.spawn_fn)(&mut context).await?;

        if let Some(endpoint) = node.endpoint() {
            registry.register_all(endpoint.metrics());
        }

        let mut node = Self {
            node,
            trace_id,
            idx,
            info,
            round: 0,
            round_fn: builder.round_fn,
            check_fn: builder.check_fn,
            metrics_encoder: Encoder::new(Arc::new(RwLock::new(registry))),
            all_nodes: Default::default(),
        };

        let res = node
            .run(&client, setup_data, rounds)
            .await
            .with_context(|| format!("node {} failed", node.idx));
        if let Err(err) = &res {
            warn!("node failed: {err:#}");
        }
        res
    }

    /// Registers the node with the trace, waits for the start signal, runs the rounds,
    /// shuts the node down and reports the outcome.
    ///
    /// A shutdown failure is only logged. A failure to report the outcome is returned and
    /// takes precedence over the result of the rounds.
    // NOTE(review): if `start_node` or `wait_start` fails, the function returns before
    // `shutdown` is called on the node — confirm this is acceptable.
    async fn run(&mut self, client: &TraceClient, setup_data: &D, rounds: u32) -> Result<()> {
        let client = client.start_node(self.trace_id, self.info.clone()).await?;

        info!(idx = self.idx, "start");

        let info = EndpointInfoWithAddr {
            addr: self.my_addr().await,
            info: self.info.clone(),
        };
        // Submit our own info and receive the list of all nodes.
        let all_nodes = client.wait_start(info).await?;
        self.all_nodes = all_nodes;

        let result = self.run_rounds(&client, setup_data, rounds).await;

        if let Err(err) = self.node.shutdown().await {
            warn!("failure during node shutdown: {err:#}");
        }

        client.end(to_str_err(&result)).await?;

        result
    }

    /// Runs rounds until `rounds` rounds have completed, a round returns `false`, or a
    /// round fails.
    async fn run_rounds(
        &mut self,
        client: &ActiveTrace,
        setup_data: &D,
        rounds: u32,
    ) -> Result<()> {
        while self.round < rounds {
            if !self
                .run_round(client, setup_data)
                .await
                .with_context(|| format!("failed at round {}", self.round))?
            {
                // The round function asked to stop early.
                return Ok(());
            }
            self.round += 1;
        }
        Ok(())
    }

    /// Runs a single round.
    ///
    /// The round function's result is reported as a checkpoint (numbered `round + 1`) and
    /// metrics are submitted before waiting on the checkpoint, whether or not the round
    /// function succeeded. The check function, if any, only runs when it succeeded.
    ///
    /// Returns the round function's `bool` result.
    #[tracing::instrument(name="round", skip_all, fields(round=self.round))]
    async fn run_round(&mut self, client: &ActiveTrace, setup_data: &D) -> Result<bool> {
        info!("start round");
        let context = RoundContext {
            round: self.round,
            node_index: self.idx,
            setup_data,
            all_nodes: &self.all_nodes,
        };

        let result = (self.round_fn)(&mut self.node, &context)
            .await
            .context("round function failed");

        // Checkpoints are numbered from 1, rounds from 0.
        let checkpoint = (context.round + 1) as u64;
        let label = format!("Round {} end", context.round);
        client
            .put_checkpoint(checkpoint, Some(label), to_str_err(&result))
            .await?;

        if let Some(id) = self.id() {
            client
                .put_metrics(id, Some(checkpoint), self.metrics_encoder.export())
                .await?;
        }

        client.wait_checkpoint(checkpoint).await?;

        match result {
            Ok(out) => {
                if let Some(check_fn) = self.check_fn.as_ref() {
                    (check_fn)(&self.node, &context).context("check function failed")?;
                }
                Ok(out)
            }
            Err(err) => Err(err),
        }
    }

    /// Returns the endpoint id recorded in this node's info.
    fn id(&self) -> Option<EndpointId> {
        self.info.endpoint_id
    }

    /// Returns the address of the node's endpoint, or `None` if the node has no endpoint.
    async fn my_addr(&self) -> Option<EndpointAddr> {
        if let Some(endpoint) = self.node.endpoint() {
            Some(addr(endpoint).await)
        } else {
            None
        }
    }
}
587
/// Awaits the endpoint's `online()` future, then returns the endpoint's current address.
async fn addr(endpoint: &Endpoint) -> EndpointAddr {
    endpoint.online().await;
    endpoint.addr()
}
592
593impl Default for Builder<()> {
594 fn default() -> Self {
595 Self::new()
596 }
597}
598
599impl Builder<()> {
600 pub fn new() -> Builder<()> {
602 let setup_fn: BoxedSetupFn<()> = Box::new(move || Box::pin(async move { Ok(()) }));
603 Builder {
604 node_builders: Vec::new(),
605 setup_fn,
606 rounds: 0,
607 }
608 }
609}
610impl<D: SetupData> Builder<D> {
611 pub fn with_setup<F, Fut>(setup_fn: F) -> Builder<D>
624 where
625 F: 'static + Send + Sync + FnOnce() -> Fut,
626 Fut: 'static + Send + Future<Output = Result<D>>,
627 {
628 let setup_fn: BoxedSetupFn<D> = Box::new(move || Box::pin(setup_fn()));
629 Builder {
630 node_builders: Vec::new(),
631 setup_fn,
632 rounds: 0,
633 }
634 }
635
636 pub fn rounds(mut self, rounds: u32) -> Self {
638 self.rounds = rounds;
639 self
640 }
641
642 pub fn spawn<N: Spawn<D>>(
650 mut self,
651 node_count: u32,
652 node_builder: impl Into<NodeBuilder<N, D>>,
653 ) -> Self {
654 let node_builder = node_builder.into();
655 self.node_builders.push(NodeBuilderWithCount {
656 count: node_count,
657 builder: node_builder.erase(),
658 });
659 self
660 }
661
662 pub async fn build(self, name: &str) -> Result<Simulation<D>> {
672 let client = TraceClient::from_env_or_local()?;
673 let run_mode = RunMode::from_env()?;
674
675 debug!(%name, ?run_mode, "build simulation run");
676
677 let (trace_id, setup_data) = if matches!(run_mode, RunMode::InitOnly | RunMode::Integrated)
678 {
679 let setup_data = (self.setup_fn)().await?;
680 let encoded_setup_data = Bytes::from(postcard::to_stdvec(&setup_data)?);
681 let node_count = self.node_builders.iter().map(|builder| builder.count).sum();
682 let trace_id = client
683 .init_trace(
684 TraceInfo {
685 name: name.to_string(),
686 node_count,
687 expected_checkpoints: Some(self.rounds as u64),
688 },
689 Some(encoded_setup_data),
690 )
691 .await?;
692 info!(%name, node_count, %trace_id, "init simulation");
693
694 (trace_id, setup_data)
695 } else {
696 let info = client.get_trace(name.to_string()).await?;
697 let GetTraceResponse {
698 trace_id,
699 info,
700 setup_data,
701 } = info;
702 info!(%name, node_count=info.node_count, %trace_id, "get simulation");
703 let setup_data = setup_data.context("expected setup data to be set")?;
704 let setup_data: D =
705 postcard::from_bytes(&setup_data).context("failed to decode setup data")?;
706 (trace_id, setup_data)
707 };
708
709 let mut node_builders = self
710 .node_builders
711 .into_iter()
712 .flat_map(|builder| (0..builder.count).map(move |_| builder.builder.clone()))
713 .enumerate()
714 .map(|(idx, builder)| NodeBuilderWithIdx {
715 idx: idx as u32,
716 builder,
717 });
718
719 let node_builders: Vec<_> = match run_mode {
720 RunMode::InitOnly => vec![],
721 RunMode::Integrated => node_builders.collect(),
722 RunMode::Isolated(idx) => vec![
723 node_builders
724 .nth(idx as usize)
725 .context("invalid isolated index")?,
726 ],
727 };
728
729 Ok(Simulation {
730 run_mode,
731 max_rounds: self.rounds,
732 node_builders,
733 client,
734 trace_id,
735 setup_data,
736 })
737 }
738}
739
/// A type-erased node builder together with the number of nodes to create from it.
struct NodeBuilderWithCount<D> {
    /// Number of nodes to create.
    count: u32,
    /// Builder shared by all `count` nodes.
    builder: ErasedNodeBuilder<D>,
}
744
/// A type-erased node builder for one concrete node, identified by its index.
struct NodeBuilderWithIdx<D> {
    /// Index of the node within the simulation.
    idx: u32,
    /// Builder used to create the node.
    builder: ErasedNodeBuilder<D>,
}
749
/// A built simulation, ready to be run.
pub struct Simulation<D> {
    /// Id of the trace this simulation reports to.
    trace_id: Uuid,
    /// How this process takes part in the simulation.
    run_mode: RunMode,
    /// Client used to talk to the trace server.
    client: TraceClient,
    /// Setup data shared by all nodes.
    setup_data: D,
    /// Maximum number of rounds each node runs.
    max_rounds: u32,
    /// Builders for the nodes this process runs.
    node_builders: Vec<NodeBuilderWithIdx<D>>,
}
762
763impl<D: SetupData> Simulation<D> {
764 pub async fn run(self) -> Result<()> {
773 let cancel_token = CancellationToken::new();
774
775 let logs_scope = match self.run_mode {
777 RunMode::Isolated(idx) => Some(Scope::Isolated(idx)),
778 RunMode::Integrated => Some(Scope::Integrated),
779 RunMode::InitOnly => None,
781 };
782 let logs_task = if let Some(scope) = logs_scope {
783 Some(spawn_logs_task(
784 self.client.clone(),
785 self.trace_id,
786 scope,
787 cancel_token.clone(),
788 ))
789 } else {
790 None
791 };
792
793 let result = self
795 .node_builders
796 .into_iter()
797 .map(async |builder| {
798 let span = error_span!("sim-node", idx = builder.idx);
799 SimNode::spawn_and_run(
800 builder,
801 self.client.clone(),
802 self.trace_id,
803 &self.setup_data,
804 self.max_rounds,
805 )
806 .instrument(span)
807 .await
808 })
809 .try_join_all()
810 .await
811 .map(|_list| ());
812
813 cancel_token.cancel();
814 if let Some(join_handle) = logs_task {
815 join_handle.await?;
816 }
817
818 if matches!(self.run_mode, RunMode::Integrated) {
819 self.client
820 .close_trace(self.trace_id, to_str_err(&result))
821 .await?;
822 }
823
824 result
825 }
826}
827
/// How the current process takes part in a simulation.
#[derive(Debug, Copy, Clone)]
enum RunMode {
    /// Only initialize the trace; run no nodes.
    InitOnly,
    /// Initialize the trace and run all nodes in this process.
    Integrated,
    /// Run only the node with the given index against an existing trace.
    Isolated(u32),
}
834
835impl RunMode {
836 fn from_env() -> Result<Self> {
837 if std::env::var(ENV_TRACE_INIT_ONLY).is_ok() {
838 Ok(Self::InitOnly)
839 } else {
840 match std::env::var(ENV_TRACE_ISOLATED) {
841 Err(_) => Ok(Self::Integrated),
842 Ok(s) => {
843 let idx = s.parse().with_context(|| {
844 format!("Failed to parse env var `{ENV_TRACE_ISOLATED}` as number")
845 })?;
846 Ok(Self::Isolated(idx))
847 }
848 }
849 }
850 }
851}
852
853fn spawn_logs_task(
856 client: TraceClient,
857 trace_id: Uuid,
858 scope: Scope,
859 cancel_token: CancellationToken,
860) -> tokio::task::JoinHandle<()> {
861 tokio::task::spawn(async move {
862 loop {
863 if cancel_token
864 .run_until_cancelled(tokio::time::sleep(Duration::from_secs(1)))
865 .await
866 .is_none()
867 {
868 break;
869 }
870 let lines = self::trace::get_logs();
871 if lines.is_empty() {
872 continue;
873 }
874 for lines_chunk in lines.chunks(500) {
877 if let Err(e) = client.put_logs(trace_id, scope, lines_chunk.to_vec()).await {
878 eprintln!(
879 "warning: failed to submit logs due to error, stopping log submission now: {e:?}"
880 );
881 break;
882 }
883 }
884 if cancel_token.is_cancelled() {
885 break;
886 }
887 }
888 })
889}
890
/// Single-permit semaphore held by `run_sim_fn` for the duration of a simulation, so that
/// at most one simulation runs at a time within this process.
static PERMIT: Semaphore = Semaphore::const_new(1);
892
893#[doc(hidden)]
903pub async fn run_sim_fn<F, Fut, D, E>(name: &str, sim_fn: F) -> anyhow::Result<()>
904where
905 F: Fn() -> Fut,
906 Fut: Future<Output = Result<Builder<D>, E>>,
907 D: SetupData,
908 anyhow::Error: From<E>,
909{
910 let permit = PERMIT.acquire().await.expect("semaphore closed");
912
913 self::trace::init();
915 self::trace::global_writer().clear();
917
918 eprintln!("running simulation: {name}");
919 let result = sim_fn()
920 .await
921 .map_err(anyhow::Error::from)
922 .with_context(|| format!("simulation builder function `{name}` failed"))?
923 .build(name)
924 .await
925 .with_context(|| format!("simulation `{name}` failed to start"))?
926 .run()
927 .await
928 .with_context(|| format!("simulation `{name}` failed to complete"));
929
930 match &result {
931 Ok(()) => eprintln!("simulation `{name}` passed"),
932 Err(err) => eprintln!("simulation `{name}` failed: {err:#}"),
933 };
934
935 drop(permit);
936
937 result
938}
939
940fn to_str_err<T>(res: &Result<T, anyhow::Error>) -> Result<(), String> {
941 if let Some(err) = res.as_ref().err() {
942 Err(format!("{err:?}"))
943 } else {
944 Ok(())
945 }
946}