n0_watcher/
lib.rs

1//! Watchable values.
2//!
3//! A [`Watchable`] exists to keep track of a value which may change over time.  It allows
4//! observers to be notified of changes to the value.  The aim is to always be aware of the
5//! **last** value, not to observe *every* value change.
6//!
7//! The reason for this is ergonomics and predictable resource usage: Requiring every
8//! intermediate value to be observable would mean that either the side that sets new values
9//! using [`Watchable::set`] would need to wait for all "receivers" of these intermediate
10//! values to catch up and thus be an async operation, or it would require the receivers
11//! to buffer intermediate values until they've been "received" on the [`Watcher`]s with
12//! an unlimited buffer size and thus potentially unlimited memory growth.
13//!
14//! # Example
15//!
16//! ```
17//! use n0_future::StreamExt;
18//! use n0_watcher::{Watchable, Watcher as _};
19//!
20//! #[tokio::main(flavor = "current_thread", start_paused = true)]
21//! async fn main() {
22//!     let watchable = Watchable::new(None);
23//!
24//!     // A task that waits for the watcher to be initialized to Some(value) before printing it
25//!     let mut watcher = watchable.watch();
26//!     tokio::spawn(async move {
27//!         let initialized_value = watcher.initialized().await;
28//!         println!("initialized: {initialized_value}");
29//!     });
30//!
31//!     // A task that prints every update to the watcher since the initial one:
32//!     let mut updates = watchable.watch().stream_updates_only();
33//!     tokio::spawn(async move {
34//!         while let Some(update) = updates.next().await {
35//!             println!("update: {update:?}");
36//!         }
37//!     });
38//!
39//!     // A task that prints the current value and then every update it can catch,
40//!     // but it also does something else which makes it very slow to pick up new
41//!     // values, so it'll skip some:
42//!     let mut current_and_updates = watchable.watch().stream();
43//!     tokio::spawn(async move {
44//!         while let Some(update) = current_and_updates.next().await {
45//!             println!("update2: {update:?}");
46//!             tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
47//!         }
48//!     });
49//!
50//!     for i in 0..20 {
51//!         println!("Setting watchable to {i}");
52//!         watchable.set(Some(i)).ok();
53//!         tokio::time::sleep(tokio::time::Duration::from_millis(250)).await;
54//!     }
55//! }
56//! ```
57//!
58//! # Similar but different
59//!
//! - `async_channel`: This is a multi-producer, multi-consumer channel implementation.
//!   At most one consumer will receive each "produced" value.
//!   What we want is for every "produced" value to be "broadcast" to every receiver.
63//! - `tokio::broadcast`: Also a multi-producer, multi-consumer channel implementation.
64//!   This is very similar to this crate (`tokio::broadcast::Sender` is like [`Watchable`]
65//!   and `tokio::broadcast::Receiver` is like [`Watcher`]), but you can't get the latest
66//!   value without `.await`ing on the receiver, and it'll internally store a queue of
67//!   intermediate values.
//! - `tokio::watch`: Also a multi-consumer channel, and unlike `tokio::broadcast` only retains the
69//!   latest value. That module has pretty much the same purpose as this crate, but doesn't
70//!   implement a poll-based method of getting updates and doesn't implement combinators.
//! - [`std::sync::RwLock`]: (wrapped in an [`std::sync::Arc`]) This gives you access
//!   to the latest value, but might block while it's being set (but that could be short
73//!   enough not to matter for async rust purposes).
74//!   This doesn't allow you to be notified whenever a new value is written.
75//! - The `watchable` crate: We used to use this crate at n0, but we wanted to experiment
76//!   with different APIs and needed Wasm support.
77#[cfg(not(watcher_loom))]
78use std::sync;
79use std::{
80    collections::VecDeque,
81    future::Future,
82    pin::Pin,
83    sync::{Arc, Weak},
84    task::{self, ready, Poll, Waker},
85};
86
87#[cfg(watcher_loom)]
88use loom::sync;
89use n0_error::StackError;
90use sync::{Mutex, RwLock};
91
92/// A wrapper around a value that notifies [`Watcher`]s when the value is modified.
93///
94/// Only the most recent value is available to any observer, but the observer is guaranteed
95/// to be notified of the most recent value.
96#[derive(Debug, Default)]
97pub struct Watchable<T> {
98    shared: Arc<Shared<T>>,
99}
100
101impl<T> Clone for Watchable<T> {
102    fn clone(&self) -> Self {
103        Self {
104            shared: self.shared.clone(),
105        }
106    }
107}
108
/// Abstracts over values that may or may not hold an item, such as `Option<T>` and
/// `Vec<T>`.
///
/// Used by [`Watcher::initialized`] to wait until a watched value holds an item.
pub trait Nullable<T> {
    /// Converts this value into an `Option`, yielding `None` when there is no item.
    fn into_option(self) -> Option<T>;
}

impl<T> Nullable<T> for Option<T> {
    /// An `Option` is returned unchanged.
    fn into_option(self) -> Option<T> {
        self
    }
}

impl<T> Nullable<T> for Vec<T> {
    /// Yields the last element of the vector, or `None` if it is empty.
    fn into_option(self) -> Option<T> {
        self.into_iter().next_back()
    }
}
126
127impl<T: Clone + Eq> Watchable<T> {
128    /// Creates a [`Watchable`] initialized to given value.
129    pub fn new(value: T) -> Self {
130        Self {
131            shared: Arc::new(Shared {
132                state: RwLock::new(State {
133                    value,
134                    epoch: INITIAL_EPOCH,
135                }),
136                watchers: Default::default(),
137            }),
138        }
139    }
140
141    /// Sets a new value.
142    ///
143    /// Returns `Ok(previous_value)` if the value was different from the one set, or
144    /// returns the provided value back as `Err(value)` if the value didn't change.
145    ///
146    /// Watchers are only notified if the value changed.
147    pub fn set(&self, value: T) -> Result<T, T> {
148        // We don't actually write when the value didn't change, but there's unfortunately
149        // no way to upgrade a read guard to a write guard, and locking as read first, then
150        // dropping and locking as write introduces a possible race condition.
151        let mut state = self.shared.state.write().expect("poisoned");
152
153        // Find out if the value changed
154        let changed = state.value != value;
155
156        let ret = if changed {
157            let old = std::mem::replace(&mut state.value, value);
158            state.epoch += 1;
159            Ok(old)
160        } else {
161            Err(value)
162        };
163        drop(state); // No need to write anymore
164
165        // Notify watchers
166        if changed {
167            for watcher in self.shared.watchers.lock().expect("poisoned").drain(..) {
168                watcher.wake();
169            }
170        }
171        ret
172    }
173
174    /// Creates a [`Direct`] [`Watcher`], allowing the value to be observed, but not modified.
175    pub fn watch(&self) -> Direct<T> {
176        Direct {
177            state: self.shared.state(),
178            shared: Arc::downgrade(&self.shared),
179        }
180    }
181
182    /// Creates a [`LazyDirect`] [`Watcher`], allowing the value to be observed, but not modified.
183    ///
184    /// The [`LazyDirect`] watcher does not store the current value, making it smaller. If the watchable
185    /// is dropped, [`LazyDirect::get`] returns `T::default`.
186    pub fn watch_lazy(&self) -> LazyDirect<T>
187    where
188        T: Default,
189    {
190        LazyDirect {
191            epoch: self.shared.state().epoch,
192            shared: Arc::downgrade(&self.shared),
193        }
194    }
195
196    /// Creates a [`WeakWatcher`], which is a weak reference to the watchable's shared state.
197    ///
198    /// It has the size of a single pointer, and can be upgraded to a [`Direct`] or [`LazyDirect`].
199    pub fn weak_watcher(&self) -> WeakWatcher<T> {
200        WeakWatcher {
201            shared: Arc::downgrade(&self.shared),
202        }
203    }
204
205    /// Returns the currently stored value.
206    pub fn get(&self) -> T {
207        self.shared.get()
208    }
209
210    /// Returns true when there are any watchers actively listening on changes,
211    /// or false when all watchers have been dropped or none have been created yet.
212    pub fn has_watchers(&self) -> bool {
213        // `Watchable`s will increase the strong count
214        // `Direct`s watchers (which all watchers descend from) will increase the weak count
215        Arc::weak_count(&self.shared) != 0
216    }
217}
218
219impl<T> Drop for Watchable<T> {
220    fn drop(&mut self) {
221        let Ok(mut watchers) = self.shared.watchers.lock() else {
222            return; // Poisoned waking?
223        };
224        // Wake all watchers every time we drop.
225        // This allows us to notify `NextFut::poll`s that the underlying
226        // watchable might be dropped.
227        for watcher in watchers.drain(..) {
228            watcher.wake();
229        }
230    }
231}
232
233/// A handle to a value that's represented by one or more underlying [`Watchable`]s.
234///
235/// A [`Watcher`] can get the current value, and will be notified when the value changes.
236/// Only the most recent value is accessible, and if the threads with the underlying [`Watchable`]s
237/// change the value faster than the threads with the [`Watcher`] can keep up with, then
238/// it'll miss in-between values.
239/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always
240/// end up reporting the most recent state eventually.
241///
242/// Watchers can be modified via [`Watcher::map`] to observe a value derived from the original
243/// value via a function.
244///
245/// Watchers can be combined via [`Watcher::or`] to allow observing multiple values at once and
246/// getting an update in case any of the values updates.
247///
248/// One of the underlying [`Watchable`]s might already be dropped. In that case,
249/// the watcher will be "disconnected" and return [`Err(Disconnected)`](Disconnected)
250/// on some function calls or, when turned into a stream, that stream will end.
251/// This property can also be checked with [`Watcher::is_connected`].
252pub trait Watcher: Clone {
253    /// The type of value that can change.
254    ///
255    /// We require `Clone`, because we need to be able to make
256    /// the values have a lifetime that's detached from the original [`Watchable`]'s
257    /// lifetime.
258    ///
259    /// We require `Eq`, to be able to check whether the value actually changed or
260    /// not, so we can notify or not notify accordingly.
261    type Value: Clone + Eq;
262
263    /// Returns the current state of the underlying value.
264    ///
265    /// If any of the underlying [`Watchable`] values has been dropped, then this
266    /// might return an outdated value for that watchable, specifically, the latest
267    /// value that was fetched for that watchable, as opposed to the latest value
268    /// that was set on the watchable before it was dropped.
269    fn get(&mut self) -> Self::Value;
270
271    /// Whether this watcher is still connected to all of its underlying [`Watchable`]s.
272    ///
273    /// Returns false when any of the underlying watchables has been dropped.
274    fn is_connected(&self) -> bool;
275
276    /// Polls for the next value, or returns [`Disconnected`] if one of the underlying
277    /// [`Watchable`]s has been dropped.
278    fn poll_updated(
279        &mut self,
280        cx: &mut task::Context<'_>,
281    ) -> Poll<Result<Self::Value, Disconnected>>;
282
283    /// Returns a future completing with `Ok(value)` once a new value is set, or with
284    /// [`Err(Disconnected)`](Disconnected) if the connected [`Watchable`] was dropped.
285    ///
286    /// # Cancel Safety
287    ///
288    /// The returned future is cancel-safe.
289    fn updated(&mut self) -> NextFut<'_, Self> {
290        NextFut { watcher: self }
291    }
292
293    /// Returns a future completing once the value is set to [`Some`] value.
294    ///
295    /// If the current value is [`Some`] value, this future will resolve immediately.
296    ///
297    /// This is a utility for the common case of storing an [`Option`] inside a
298    /// [`Watchable`].
299    ///
300    /// # Cancel Safety
301    ///
302    /// The returned future is cancel-safe.
303    fn initialized<T, W>(&mut self) -> InitializedFut<'_, T, W, Self>
304    where
305        W: Nullable<T>,
306        Self: Watcher<Value = W>,
307    {
308        InitializedFut {
309            initial: self.get().into_option(),
310            watcher: self,
311        }
312    }
313
314    /// Returns a stream which will yield the most recent values as items.
315    ///
316    /// The first item of the stream is the current value, so that this stream can be easily
317    /// used to operate on the most recent value.
318    ///
319    /// Note however, that only the last item is stored.  If the stream is not polled when an
320    /// item is available it can be replaced with another item by the time it is polled.
321    ///
322    /// This stream ends once the original [`Watchable`] has been dropped.
323    ///
324    /// # Cancel Safety
325    ///
326    /// The returned stream is cancel-safe.
327    fn stream(mut self) -> Stream<Self>
328    where
329        Self: Unpin,
330    {
331        Stream {
332            initial: Some(self.get()),
333            watcher: self,
334        }
335    }
336
337    /// Returns a stream which will yield the most recent values as items, starting from
338    /// the next unobserved future value.
339    ///
340    /// This means this stream will only yield values when the watched value changes,
341    /// the value stored at the time the stream is created is not yielded.
342    ///
343    /// Note however, that only the last item is stored.  If the stream is not polled when an
344    /// item is available it can be replaced with another item by the time it is polled.
345    ///
346    /// This stream ends once the original [`Watchable`] has been dropped.
347    ///
348    /// # Cancel Safety
349    ///
350    /// The returned stream is cancel-safe.
351    fn stream_updates_only(self) -> Stream<Self>
352    where
353        Self: Unpin,
354    {
355        Stream {
356            initial: None,
357            watcher: self,
358        }
359    }
360
361    /// Maps this watcher with a function that transforms the observed values.
362    ///
363    /// The returned watcher will only register updates, when the *mapped* value
364    /// observably changes.
365    fn map<T: Clone + Eq>(
366        mut self,
367        map: impl Fn(Self::Value) -> T + Send + Sync + 'static,
368    ) -> Map<Self, T> {
369        Map {
370            current: (map)(self.get()),
371            map: Arc::new(map),
372            watcher: self,
373        }
374    }
375
376    /// Returns a watcher that updates every time this or the other watcher
377    /// updates, and yields both watcher's items together when that happens.
378    fn or<W: Watcher>(self, other: W) -> (Self, W) {
379        (self, other)
380    }
381}
382
383/// A weak reference to a watchable value that can be upgraded to a full watcher later.
384#[derive(Debug, Clone)]
385pub struct WeakWatcher<T> {
386    shared: Weak<Shared<T>>,
387}
388
389impl<T: Clone + Eq> WeakWatcher<T> {
390    /// Upgrade to a [`Direct`] watcher, allowing to observe the value.
391    ///
392    /// Returns `None` if the underlying [`Watchable`] has been dropped.
393    pub fn upgrade(&self) -> Option<Direct<T>> {
394        let shared = self.shared.upgrade()?;
395        let state = shared.state();
396        Some(Direct {
397            state,
398            shared: self.shared.clone(),
399        })
400    }
401}
402
403impl<T: Clone + Default + Eq> WeakWatcher<T> {
404    /// Upgrade to a [`LazyDirect`] watcher, allowing to observe the value.
405    ///
406    /// The [`LazyDirect`] fetches the value on demand, and thus `lazy_upgrade` succeeds
407    /// even if the underlying watchable has been dropped.
408    pub fn upgrade_lazy(&self) -> LazyDirect<T> {
409        LazyDirect {
410            epoch: 0,
411            shared: self.shared.clone(),
412        }
413    }
414}
415
416/// The immediate, direct observer of a [`Watchable`] value.
417///
418/// This type is mainly used via the [`Watcher`] interface.
419#[derive(Debug, Clone)]
420pub struct Direct<T> {
421    state: State<T>,
422    shared: Weak<Shared<T>>,
423}
424
425impl<T: Clone + Eq> Watcher for Direct<T> {
426    type Value = T;
427
428    fn get(&mut self) -> Self::Value {
429        if let Some(shared) = self.shared.upgrade() {
430            self.state = shared.state();
431        }
432        self.state.value.clone()
433    }
434
435    fn is_connected(&self) -> bool {
436        self.shared.upgrade().is_some()
437    }
438
439    fn poll_updated(
440        &mut self,
441        cx: &mut task::Context<'_>,
442    ) -> Poll<Result<Self::Value, Disconnected>> {
443        let Some(shared) = self.shared.upgrade() else {
444            return Poll::Ready(Err(Disconnected));
445        };
446        self.state = ready!(shared.poll_updated(cx, self.state.epoch));
447        Poll::Ready(Ok(self.state.value.clone()))
448    }
449}
450
451/// A lazy direct observer of a [`Watchable`] value.
452///
453/// Other than [`Direct`] it does not store the current value. It needs `T` to implement [`Default`].
454/// If the watchable is dropped, [`Self::get`] will return `T::default()`.
455///
456/// This type is mainly used via the [`Watcher`] interface.
457#[derive(Debug, Clone)]
458pub struct LazyDirect<T> {
459    epoch: u64,
460    shared: Weak<Shared<T>>,
461}
462
463impl<T: Clone + Default + Eq> Watcher for LazyDirect<T> {
464    type Value = T;
465
466    fn get(&mut self) -> Self::Value {
467        if let Some(shared) = self.shared.upgrade() {
468            let state = shared.state();
469            self.epoch = state.epoch;
470            state.value
471        } else {
472            T::default()
473        }
474    }
475
476    fn is_connected(&self) -> bool {
477        self.shared.upgrade().is_some()
478    }
479
480    fn poll_updated(
481        &mut self,
482        cx: &mut task::Context<'_>,
483    ) -> Poll<Result<Self::Value, Disconnected>> {
484        let Some(shared) = self.shared.upgrade() else {
485            return Poll::Ready(Err(Disconnected));
486        };
487        let state = ready!(shared.poll_updated(cx, self.epoch));
488        self.epoch = state.epoch;
489        Poll::Ready(Ok(state.value))
490    }
491}
492
493impl<S: Watcher, T: Watcher> Watcher for (S, T) {
494    type Value = (S::Value, T::Value);
495
496    fn get(&mut self) -> Self::Value {
497        (self.0.get(), self.1.get())
498    }
499
500    fn is_connected(&self) -> bool {
501        self.0.is_connected() && self.1.is_connected()
502    }
503
504    fn poll_updated(
505        &mut self,
506        cx: &mut task::Context<'_>,
507    ) -> Poll<Result<Self::Value, Disconnected>> {
508        let poll_0 = self.0.poll_updated(cx)?;
509        let poll_1 = self.1.poll_updated(cx)?;
510        match (poll_0, poll_1) {
511            (Poll::Ready(s), Poll::Ready(t)) => Poll::Ready(Ok((s, t))),
512            (Poll::Ready(s), Poll::Pending) => Poll::Ready(Ok((s, self.1.get()))),
513            (Poll::Pending, Poll::Ready(t)) => Poll::Ready(Ok((self.0.get(), t))),
514            (Poll::Pending, Poll::Pending) => Poll::Pending,
515        }
516    }
517}
518
519impl<S: Watcher, T: Watcher, U: Watcher> Watcher for (S, T, U) {
520    type Value = (S::Value, T::Value, U::Value);
521
522    fn get(&mut self) -> Self::Value {
523        (self.0.get(), self.1.get(), self.2.get())
524    }
525
526    fn is_connected(&self) -> bool {
527        self.0.is_connected() && self.1.is_connected() && self.2.is_connected()
528    }
529
530    fn poll_updated(
531        &mut self,
532        cx: &mut task::Context<'_>,
533    ) -> Poll<Result<Self::Value, Disconnected>> {
534        let poll_0 = self.0.poll_updated(cx)?;
535        let poll_1 = self.1.poll_updated(cx)?;
536        let poll_2 = self.2.poll_updated(cx)?;
537
538        if poll_0.is_pending() && poll_1.is_pending() && poll_2.is_pending() {
539            Poll::Pending
540        } else {
541            fn to_option<T>(poll: Poll<T>) -> Option<T> {
542                match poll {
543                    Poll::Ready(t) => Some(t),
544                    Poll::Pending => None,
545                }
546            }
547
548            let s = to_option(poll_0).unwrap_or_else(|| self.0.get());
549            let t = to_option(poll_1).unwrap_or_else(|| self.1.get());
550            let u = to_option(poll_2).unwrap_or_else(|| self.2.get());
551            Poll::Ready(Ok((s, t, u)))
552        }
553    }
554}
555
556/// Combinator to join two watchers
557#[derive(Debug, Clone)]
558pub struct Join<T: Clone + Eq, W: Watcher<Value = T>> {
559    watchers: Vec<W>,
560}
561impl<T: Clone + Eq, W: Watcher<Value = T>> Join<T, W> {
562    /// Joins a set of watchers into a single watcher
563    pub fn new(watchers: impl Iterator<Item = W>) -> Self {
564        let watchers: Vec<W> = watchers.into_iter().collect();
565
566        Self { watchers }
567    }
568}
569
570impl<T: Clone + Eq, W: Watcher<Value = T>> Watcher for Join<T, W> {
571    type Value = Vec<T>;
572
573    fn get(&mut self) -> Self::Value {
574        let mut out = Vec::with_capacity(self.watchers.len());
575        for watcher in &mut self.watchers {
576            out.push(watcher.get());
577        }
578
579        out
580    }
581
582    fn is_connected(&self) -> bool {
583        self.watchers.iter().all(|w| w.is_connected())
584    }
585
586    fn poll_updated(
587        &mut self,
588        cx: &mut task::Context<'_>,
589    ) -> Poll<Result<Self::Value, Disconnected>> {
590        let mut new_value = None;
591        for (i, watcher) in self.watchers.iter_mut().enumerate() {
592            match watcher.poll_updated(cx) {
593                Poll::Pending => {}
594                Poll::Ready(Ok(value)) => {
595                    new_value.replace((i, value));
596                    break;
597                }
598                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
599            }
600        }
601
602        if let Some((j, new_value)) = new_value {
603            let mut new = Vec::with_capacity(self.watchers.len());
604            for (i, watcher) in self.watchers.iter_mut().enumerate() {
605                if i != j {
606                    new.push(watcher.get());
607                } else {
608                    new.push(new_value.clone());
609                }
610            }
611            Poll::Ready(Ok(new))
612        } else {
613            Poll::Pending
614        }
615    }
616}
617
618/// Wraps a [`Watcher`] to allow observing a derived value.
619///
620/// See [`Watcher::map`].
621#[derive(derive_more::Debug, Clone)]
622pub struct Map<W: Watcher, T: Clone + Eq> {
623    #[debug("Arc<dyn Fn(W::Value) -> T + 'static>")]
624    map: Arc<dyn Fn(W::Value) -> T + Send + Sync + 'static>,
625    watcher: W,
626    current: T,
627}
628
629impl<W: Watcher, T: Clone + Eq> Watcher for Map<W, T> {
630    type Value = T;
631
632    fn get(&mut self) -> Self::Value {
633        (self.map)(self.watcher.get())
634    }
635
636    fn is_connected(&self) -> bool {
637        self.watcher.is_connected()
638    }
639
640    fn poll_updated(
641        &mut self,
642        cx: &mut task::Context<'_>,
643    ) -> Poll<Result<Self::Value, Disconnected>> {
644        loop {
645            let value = ready!(self.watcher.poll_updated(cx)?);
646            let mapped = (self.map)(value);
647            if mapped != self.current {
648                self.current = mapped.clone();
649                return Poll::Ready(Ok(mapped));
650            } else {
651                self.current = mapped;
652            }
653        }
654    }
655}
656
657/// Future returning the next item after the current one in a [`Watcher`].
658///
659/// See [`Watcher::updated`].
660///
661/// # Cancel Safety
662///
663/// This future is cancel-safe.
664#[derive(Debug)]
665pub struct NextFut<'a, W: Watcher> {
666    watcher: &'a mut W,
667}
668
669impl<W: Watcher> Future for NextFut<'_, W> {
670    type Output = Result<W::Value, Disconnected>;
671
672    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
673        self.watcher.poll_updated(cx)
674    }
675}
676
677/// Future returning the current or next value that's [`Some`] value.
678/// in a [`Watcher`].
679///
680/// See [`Watcher::initialized`].
681///
682/// # Cancel Safety
683///
684/// This Future is cancel-safe.
685#[derive(Debug)]
686pub struct InitializedFut<'a, T, V: Nullable<T>, W: Watcher<Value = V>> {
687    initial: Option<T>,
688    watcher: &'a mut W,
689}
690
691impl<T: Clone + Eq + Unpin, V: Nullable<T>, W: Watcher<Value = V> + Unpin> Future
692    for InitializedFut<'_, T, V, W>
693{
694    type Output = T;
695
696    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
697        if let Some(value) = self.as_mut().initial.take() {
698            return Poll::Ready(value);
699        }
700        loop {
701            let Ok(value) = ready!(self.as_mut().watcher.poll_updated(cx)) else {
702                // The value will never be initialized
703                return Poll::Pending;
704            };
705            if let Some(value) = value.into_option() {
706                return Poll::Ready(value);
707            }
708        }
709    }
710}
711
712/// A stream for a [`Watcher`]'s next values.
713///
714/// See [`Watcher::stream`] and [`Watcher::stream_updates_only`].
715///
716/// # Cancel Safety
717///
718/// This stream is cancel-safe.
719#[derive(Debug, Clone)]
720pub struct Stream<W: Watcher + Unpin> {
721    initial: Option<W::Value>,
722    watcher: W,
723}
724
725impl<W: Watcher + Unpin> n0_future::Stream for Stream<W>
726where
727    W::Value: Unpin,
728{
729    type Item = W::Value;
730
731    fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
732        if let Some(value) = self.as_mut().initial.take() {
733            return Poll::Ready(Some(value));
734        }
735        match self.as_mut().watcher.poll_updated(cx) {
736            Poll::Ready(Ok(value)) => Poll::Ready(Some(value)),
737            Poll::Ready(Err(Disconnected)) => Poll::Ready(None),
738            Poll::Pending => Poll::Pending,
739        }
740    }
741}
742
/// The error for when a [`Watcher`] is disconnected from its underlying
/// [`Watchable`] value, because of that watchable having been dropped.
///
/// Returned by [`Watcher::poll_updated`] and by the future from [`Watcher::updated`].
/// A [`Stream`] ends instead of yielding this error.
#[derive(StackError)]
#[error("Watcher lost connection to underlying Watchable, it was dropped")]
pub struct Disconnected;
748
// Private:

/// The epoch of a freshly created [`State`].
///
/// Watchers detect updates by checking `last_epoch < state.epoch`, so a watcher holding
/// epoch `0` treats any state as an update.
const INITIAL_EPOCH: u64 = 1;
752
/// The shared state for a [`Watchable`].
///
/// Owned (strongly) by every [`Watchable`] clone and referenced weakly by watchers.
#[derive(Debug, Default)]
struct Shared<T> {
    /// The value to be watched and its current epoch.
    state: RwLock<State<T>>,
    /// Wakers of tasks waiting in `poll_updated`. They are drained and woken when
    /// `Watchable::set` changes the value or when a `Watchable` is dropped.
    watchers: Mutex<VecDeque<Waker>>,
}
760
761#[derive(Debug, Clone)]
762struct State<T> {
763    value: T,
764    epoch: u64,
765}
766
767impl<T: Default> Default for State<T> {
768    fn default() -> Self {
769        Self {
770            value: Default::default(),
771            epoch: INITIAL_EPOCH,
772        }
773    }
774}
775
776impl<T: Clone> Shared<T> {
777    /// Returns the value, initialized or not.
778    fn get(&self) -> T {
779        self.state.read().expect("poisoned").value.clone()
780    }
781
782    fn state(&self) -> State<T> {
783        self.state.read().expect("poisoned").clone()
784    }
785
786    fn poll_updated(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<State<T>> {
787        {
788            let state = self.state.read().expect("poisoned");
789
790            // We might get spurious wakeups due to e.g. a second-to-last Watchable being dropped.
791            // This makes sure we don't accidentally return an update that's not actually an update.
792            if last_epoch < state.epoch {
793                return Poll::Ready(state.clone());
794            }
795        }
796
797        self.watchers
798            .lock()
799            .expect("poisoned")
800            .push_back(cx.waker().to_owned());
801
802        #[cfg(watcher_loom)]
803        loom::thread::yield_now();
804
805        // We check for an update again to prevent races between putting in wakers and looking for updates.
806        {
807            let state = self.state.read().expect("poisoned");
808
809            if last_epoch < state.epoch {
810                return Poll::Ready(state.clone());
811            }
812        }
813
814        Poll::Pending
815    }
816}
817
818#[cfg(test)]
819mod tests {
820
821    use n0_future::{future::poll_once, StreamExt};
822    use rand::{rng, Rng};
823    use tokio::{
824        task::JoinSet,
825        time::{Duration, Instant},
826    };
827    use tokio_util::sync::CancellationToken;
828
829    use super::*;
830
831    #[tokio::test]
832    async fn test_watcher() {
833        let cancel = CancellationToken::new();
834        let watchable = Watchable::new(17);
835
836        assert_eq!(watchable.watch().stream().next().await.unwrap(), 17);
837
838        let start = Instant::now();
839        // spawn watchers
840        let mut tasks = JoinSet::new();
841        for i in 0..3 {
842            let mut watch = watchable.watch().stream();
843            let cancel = cancel.clone();
844            tasks.spawn(async move {
845                println!("[{i}] spawn");
846                let mut expected_value = 17;
847                loop {
848                    tokio::select! {
849                        biased;
850                        Some(value) = &mut watch.next() => {
851                            println!("{:?} [{i}] update: {value}", start.elapsed());
852                            assert_eq!(value, expected_value);
853                            if expected_value == 17 {
854                                expected_value = 0;
855                            } else {
856                                expected_value += 1;
857                            }
858                        },
859                        _ = cancel.cancelled() => {
860                            println!("{:?} [{i}] cancel", start.elapsed());
861                            assert_eq!(expected_value, 10);
862                            break;
863                        }
864                    }
865                }
866            });
867        }
868        for i in 0..3 {
869            let mut watch = watchable.watch().stream_updates_only();
870            let cancel = cancel.clone();
871            tasks.spawn(async move {
872                println!("[{i}] spawn");
873                let mut expected_value = 0;
874                loop {
875                    tokio::select! {
876                        biased;
877                        Some(value) = watch.next() => {
878                            println!("{:?} [{i}] stream update: {value}", start.elapsed());
879                            assert_eq!(value, expected_value);
880                            expected_value += 1;
881                        },
882                        _ = cancel.cancelled() => {
883                            println!("{:?} [{i}] cancel", start.elapsed());
884                            assert_eq!(expected_value, 10);
885                            break;
886                        }
887                        else => {
888                            panic!("stream died");
889                        }
890                    }
891                }
892            });
893        }
894
895        // set value
896        for next_value in 0..10 {
897            let sleep = Duration::from_nanos(rng().random_range(0..100_000_000));
898            println!("{:?} sleep {sleep:?}", start.elapsed());
899            tokio::time::sleep(sleep).await;
900
901            let changed = watchable.set(next_value);
902            println!("{:?} set {next_value} changed={changed:?}", start.elapsed());
903        }
904
905        println!("cancel");
906        cancel.cancel();
907        while let Some(res) = tasks.join_next().await {
908            res.expect("task failed");
909        }
910    }
911
912    #[test]
913    fn test_get() {
914        let watchable = Watchable::new(None);
915        assert!(watchable.get().is_none());
916
917        watchable.set(Some(1u8)).ok();
918        assert_eq!(watchable.get(), Some(1u8));
919    }
920
921    #[tokio::test]
922    async fn test_initialize() {
923        let watchable = Watchable::new(None);
924
925        let mut watcher = watchable.watch();
926        let mut initialized = watcher.initialized();
927
928        let poll = poll_once(&mut initialized).await;
929        assert!(poll.is_none());
930
931        watchable.set(Some(1u8)).ok();
932
933        let poll = poll_once(&mut initialized).await;
934        assert_eq!(poll.unwrap(), 1u8);
935    }
936
937    #[tokio::test]
938    async fn test_initialize_already_init() {
939        let watchable = Watchable::new(Some(1u8));
940
941        let mut watcher = watchable.watch();
942        let mut initialized = watcher.initialized();
943
944        let poll = poll_once(&mut initialized).await;
945        assert_eq!(poll.unwrap(), 1u8);
946    }
947
948    #[test]
949    fn test_initialized_always_resolves() {
950        #[cfg(not(watcher_loom))]
951        use std::thread;
952
953        #[cfg(watcher_loom)]
954        use loom::thread;
955
956        let test_case = || {
957            let watchable = Watchable::<Option<u8>>::new(None);
958
959            let mut watch = watchable.watch();
960            let thread = thread::spawn(move || n0_future::future::block_on(watch.initialized()));
961
962            watchable.set(Some(42)).ok();
963
964            thread::yield_now();
965
966            let value: u8 = thread.join().unwrap();
967
968            assert_eq!(value, 42);
969        };
970
971        #[cfg(watcher_loom)]
972        loom::model(test_case);
973        #[cfg(not(watcher_loom))]
974        test_case();
975    }
976
977    #[tokio::test(flavor = "multi_thread")]
978    async fn test_update_cancel_safety() {
979        let watchable = Watchable::new(0);
980        let mut watch = watchable.watch();
981        const MAX: usize = 100_000;
982
983        let handle = tokio::spawn(async move {
984            let mut last_observed = 0;
985
986            while last_observed != MAX {
987                tokio::select! {
988                    val = watch.updated() => {
989                        let Ok(val) = val else {
990                            return;
991                        };
992
993                        assert_ne!(val, last_observed, "never observe the same value twice, even with cancellation");
994                        last_observed = val;
995                    }
996                    _ = tokio::time::sleep(Duration::from_micros(rng().random_range(0..10_000))) => {
997                        // We cancel the other future and start over again
998                        continue;
999                    }
1000                }
1001            }
1002        });
1003
1004        for i in 1..=MAX {
1005            watchable.set(i).ok();
1006            if rng().random_bool(0.2) {
1007                tokio::task::yield_now().await;
1008            }
1009        }
1010
1011        tokio::time::timeout(Duration::from_secs(10), handle)
1012            .await
1013            .unwrap()
1014            .unwrap()
1015    }
1016
1017    #[tokio::test]
1018    async fn test_join_simple() {
1019        let a = Watchable::new(1u8);
1020        let b = Watchable::new(1u8);
1021
1022        let mut ab = Join::new([a.watch(), b.watch()].into_iter());
1023
1024        let stream = ab.clone().stream();
1025        let handle = tokio::task::spawn(async move { stream.take(5).collect::<Vec<_>>().await });
1026
1027        // get
1028        assert_eq!(ab.get(), vec![1, 1]);
1029        // set a
1030        a.set(2u8).unwrap();
1031        tokio::task::yield_now().await;
1032        assert_eq!(ab.get(), vec![2, 1]);
1033        // set b
1034        b.set(3u8).unwrap();
1035        tokio::task::yield_now().await;
1036        assert_eq!(ab.get(), vec![2, 3]);
1037
1038        a.set(3u8).unwrap();
1039        tokio::task::yield_now().await;
1040        b.set(4u8).unwrap();
1041        tokio::task::yield_now().await;
1042
1043        let values = tokio::time::timeout(Duration::from_secs(5), handle)
1044            .await
1045            .unwrap()
1046            .unwrap();
1047        assert_eq!(
1048            values,
1049            vec![vec![1, 1], vec![2, 1], vec![2, 3], vec![3, 3], vec![3, 4]]
1050        );
1051    }
1052
1053    #[tokio::test]
1054    async fn test_updated_then_disconnect_then_get() {
1055        let watchable = Watchable::new(10);
1056        let mut watcher = watchable.watch();
1057        assert_eq!(watchable.get(), 10);
1058        watchable.set(42).ok();
1059        assert_eq!(watcher.updated().await.unwrap(), 42);
1060        drop(watchable);
1061        assert_eq!(watcher.get(), 42);
1062    }
1063
1064    #[tokio::test(start_paused = true)]
1065    async fn test_update_wakeup_on_watchable_drop() {
1066        let watchable = Watchable::new(10);
1067        let mut watcher = watchable.watch();
1068
1069        let start = Instant::now();
1070        let (_, result) = tokio::time::timeout(Duration::from_secs(2), async move {
1071            tokio::join!(
1072                async move {
1073                    tokio::time::sleep(Duration::from_secs(1)).await;
1074                    drop(watchable);
1075                },
1076                async move { watcher.updated().await }
1077            )
1078        })
1079        .await
1080        .expect("watcher never updated");
1081        // We should've updated 1s after start, since that's when the watchable was dropped.
1082        // If this is 2s, then the watchable dropping didn't wake up the `Watcher::updated` future.
1083        assert_eq!(start.elapsed(), Duration::from_secs(1));
1084        assert!(result.is_err());
1085    }
1086
1087    #[tokio::test(start_paused = true)]
1088    async fn test_update_wakeup_always_a_change() {
1089        let watchable = Watchable::new(10);
1090        let mut watcher = watchable.watch();
1091
1092        let task = tokio::spawn(async move {
1093            let mut last_value = watcher.get();
1094            let mut values = Vec::new();
1095            while let Ok(value) = watcher.updated().await {
1096                values.push(value);
1097                if last_value == value {
1098                    return Err("value duplicated");
1099                }
1100                last_value = value;
1101            }
1102            Ok(values)
1103        });
1104
1105        // wait for the task to get set up and polled till pending for once
1106        tokio::time::sleep(Duration::from_millis(100)).await;
1107
1108        watchable.set(11).ok();
1109        tokio::time::sleep(Duration::from_millis(100)).await;
1110        let clone = watchable.clone();
1111        drop(clone); // this shouldn't trigger an update
1112        tokio::time::sleep(Duration::from_millis(100)).await;
1113        for i in 1..=10 {
1114            watchable.set(i + 11).ok();
1115            tokio::time::sleep(Duration::from_millis(100)).await;
1116        }
1117        drop(watchable);
1118
1119        let values = task
1120            .await
1121            .expect("task panicked")
1122            .expect("value duplicated");
1123        assert_eq!(values, vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]);
1124    }
1125
1126    #[test]
1127    fn test_has_watchers() {
1128        let a = Watchable::new(1u8);
1129        assert!(!a.has_watchers());
1130        let b = a.clone();
1131        assert!(!a.has_watchers());
1132        assert!(!b.has_watchers());
1133
1134        let watcher = a.watch();
1135        assert!(a.has_watchers());
1136        assert!(b.has_watchers());
1137
1138        drop(watcher);
1139
1140        assert!(!a.has_watchers());
1141        assert!(!b.has_watchers());
1142    }
1143
1144    #[tokio::test]
1145    async fn test_three_watchers_basic() {
1146        let watchable = Watchable::new(1u8);
1147
1148        let mut w1 = watchable.watch();
1149        let mut w2 = watchable.watch();
1150        let mut w3 = watchable.watch();
1151
1152        // All see the initial value
1153
1154        assert_eq!(w1.get(), 1);
1155        assert_eq!(w2.get(), 1);
1156        assert_eq!(w3.get(), 1);
1157
1158        // Change  value
1159        watchable.set(42).unwrap();
1160
1161        // All watchers get notified
1162        assert_eq!(w1.updated().await.unwrap(), 42);
1163        assert_eq!(w2.updated().await.unwrap(), 42);
1164        assert_eq!(w3.updated().await.unwrap(), 42);
1165    }
1166
1167    #[tokio::test]
1168    async fn test_three_watchers_skip_intermediate() {
1169        let watchable = Watchable::new(0u8);
1170        let mut watcher = watchable.watch();
1171
1172        watchable.set(1).ok();
1173        watchable.set(2).ok();
1174        watchable.set(3).ok();
1175        watchable.set(4).ok();
1176
1177        let value = watcher.updated().await.unwrap();
1178
1179        assert_eq!(value, 4);
1180    }
1181
1182    #[tokio::test]
1183    async fn test_three_watchers_with_streams() {
1184        let watchable = Watchable::new(10u8);
1185
1186        let mut stream1 = watchable.watch().stream();
1187        let mut stream2 = watchable.watch().stream();
1188        let mut stream3 = watchable.watch().stream_updates_only();
1189
1190        assert_eq!(stream1.next().await.unwrap(), 10);
1191        assert_eq!(stream2.next().await.unwrap(), 10);
1192
1193        // Update the value
1194        watchable.set(20).ok();
1195
1196        // All streams see the update
1197        assert_eq!(stream1.next().await.unwrap(), 20);
1198        assert_eq!(stream2.next().await.unwrap(), 20);
1199        assert_eq!(stream3.next().await.unwrap(), 20);
1200    }
1201
1202    #[tokio::test]
1203    async fn test_three_watchers_independent() {
1204        let watchable = Watchable::new(0u8);
1205
1206        let mut fast_watcher = watchable.watch();
1207        let mut slow_watcher = watchable.watch();
1208        let mut lazy_watcher = watchable.watch();
1209
1210        watchable.set(1).ok();
1211        assert_eq!(fast_watcher.updated().await.unwrap(), 1);
1212
1213        // More updates happen
1214        watchable.set(2).ok();
1215        watchable.set(3).ok();
1216
1217        assert_eq!(slow_watcher.updated().await.unwrap(), 3);
1218        assert_eq!(lazy_watcher.get(), 3);
1219    }
1220
1221    #[tokio::test]
1222    async fn test_combine_three_watchers() {
1223        let a = Watchable::new(1u8);
1224        let b = Watchable::new(2u8);
1225        let c = Watchable::new(3u8);
1226
1227        let mut combined = (a.watch(), b.watch(), c.watch());
1228
1229        assert_eq!(combined.get(), (1, 2, 3));
1230
1231        // Update one
1232        b.set(20).ok();
1233
1234        assert_eq!(combined.updated().await.unwrap(), (1, 20, 3));
1235
1236        c.set(30).ok();
1237        assert_eq!(combined.updated().await.unwrap(), (1, 20, 30));
1238    }
1239
1240    #[tokio::test]
1241    async fn test_three_watchers_disconnection() {
1242        let watchable = Watchable::new(5u8);
1243
1244        // All connected
1245        let mut w1 = watchable.watch();
1246        let mut w2 = watchable.watch();
1247        let mut w3 = watchable.watch();
1248
1249        // Drop the watchable
1250        drop(watchable);
1251
1252        // All become disconnected
1253        assert!(!w1.is_connected());
1254        assert!(!w2.is_connected());
1255        assert!(!w3.is_connected());
1256
1257        // Can still get last known value
1258        assert_eq!(w1.get(), 5);
1259        assert_eq!(w2.get(), 5);
1260
1261        // But updates fail
1262        assert!(w3.updated().await.is_err());
1263    }
1264
1265    #[tokio::test]
1266    async fn test_three_watchers_truly_concurrent() {
1267        use tokio::time::sleep;
1268        let watchable = Watchable::new(0u8);
1269
1270        // Spawn three READER tasks
1271        let mut reader_handles = vec![];
1272        for i in 0..3 {
1273            let mut watcher = watchable.watch();
1274            let handle = tokio::spawn(async move {
1275                let mut values = vec![];
1276                // Collect up to 5 updates
1277                for _ in 0..5 {
1278                    if let Ok(value) = watcher.updated().await {
1279                        values.push(value);
1280                    } else {
1281                        break;
1282                    }
1283                }
1284                (i, values)
1285            });
1286            reader_handles.push(handle);
1287        }
1288
1289        // Spawn three WRITER tasks that update concurrently
1290        let mut writer_handles = vec![];
1291        for i in 0..3 {
1292            let watchable_clone = watchable.clone();
1293            let handle = tokio::spawn(async move {
1294                for j in 0..5 {
1295                    let value = (i * 10) + j;
1296                    watchable_clone.set(value).ok();
1297                    sleep(Duration::from_millis(5)).await;
1298                }
1299            });
1300            writer_handles.push(handle);
1301        }
1302
1303        // Wait for writers to finish
1304        for handle in writer_handles {
1305            handle.await.unwrap();
1306        }
1307
1308        // Wait for readers and check results
1309        for handle in reader_handles {
1310            let (task_id, values) = handle.await.unwrap();
1311            println!("Reader {}: saw values {:?}", task_id, values);
1312            assert!(!values.is_empty());
1313        }
1314    }
1315
1316    #[test]
1317    fn test_lazy_direct() {
1318        let a = Watchable::new(1u8);
1319        let mut w1 = a.watch_lazy();
1320        let mut w2 = a.watch_lazy();
1321        assert_eq!(w1.get(), 1u8);
1322        assert_eq!(w2.get(), 1u8);
1323        a.set(2u8).unwrap();
1324        assert_eq!(w1.get(), 2u8);
1325        assert_eq!(w2.get(), 2u8);
1326        let mut s1 = w1.stream_updates_only();
1327        a.set(3u8).unwrap();
1328        assert_eq!(n0_future::future::now_or_never(s1.next()), Some(Some(3u8)));
1329        assert_eq!(w2.get(), 3u8);
1330        drop(a);
1331        assert_eq!(n0_future::future::now_or_never(s1.next()), Some(None));
1332        assert_eq!(w2.get(), 0u8);
1333    }
1334}