iroh_gossip/proto/
util.rs

1//! Utilities used in the protocol implementation
2
3use std::{
4    collections::{hash_map, BinaryHeap, HashMap},
5    hash::Hash,
6};
7
8use n0_future::time::Instant;
9use rand::{
10    seq::{IteratorRandom, SliceRandom},
11    Rng,
12};
13
/// Generates the shared API for a newtype that wraps a 32-byte identifier.
///
/// Invoke as `idbytes_impls!(MyId, "MyId")` for a tuple struct `MyId([u8; 32])`; the string
/// argument is used as the prefix of the `Debug` output. The expansion provides:
/// - inherent `from_bytes` (a `const fn`) and `as_bytes`,
/// - a blanket `From<T>` for every `T: Into<[u8; 32]>`,
/// - `Display` as a hex string and `Debug` as `Name(hex)`,
/// - `FromStr`, decoding a hex string into the 32-byte buffer,
/// - `AsRef<[u8]>` and `AsRef<[u8; 32]>`.
///
/// The expansion refers to the `hex` crate by absolute path (`::hex`), so the crate invoking
/// this macro must depend on `hex`.
macro_rules! idbytes_impls {
    ($ty:ty, $name:expr) => {
        impl $ty {
            /// Builds the identifier from its raw 32 bytes.
            pub const fn from_bytes(bytes: [u8; 32]) -> Self {
                Self(bytes)
            }

            /// Borrows the raw 32 bytes of the identifier.
            pub fn as_bytes(&self) -> &[u8; 32] {
                &self.0
            }
        }

        impl<T> ::std::convert::From<T> for $ty
        where
            T: ::std::convert::Into<[u8; 32]>,
        {
            fn from(value: T) -> Self {
                Self::from_bytes(value.into())
            }
        }

        impl ::std::fmt::Display for $ty {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                f.write_str(&::hex::encode(self.as_bytes()))
            }
        }

        impl ::std::fmt::Debug for $ty {
            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
                // Reuse the `Display` impl above for the hex part.
                write!(f, "{}({})", $name, self)
            }
        }

        impl ::std::str::FromStr for $ty {
            type Err = ::hex::FromHexError;

            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                let mut decoded = [0u8; 32];
                ::hex::decode_to_slice(s, &mut decoded).map(|()| Self::from_bytes(decoded))
            }
        }

        impl ::std::convert::AsRef<[u8; 32]> for $ty {
            fn as_ref(&self) -> &[u8; 32] {
                self.as_bytes()
            }
        }

        impl ::std::convert::AsRef<[u8]> for $ty {
            fn as_ref(&self) -> &[u8] {
                self.as_bytes()
            }
        }
    };
}

pub(crate) use idbytes_impls;
71
72/// A hash set where the iteration order of the values is independent of their
73/// hash values.
74///
75/// This is wrapper around [indexmap::IndexSet] which couple of utility methods
76/// to randomly select elements from the set.
77#[derive(Default, Debug, Clone, derive_more::Deref)]
78pub(crate) struct IndexSet<T> {
79    inner: indexmap::IndexSet<T>,
80}
81
82impl<T: Hash + Eq> PartialEq for IndexSet<T> {
83    fn eq(&self, other: &Self) -> bool {
84        self.inner == other.inner
85    }
86}
87
88impl<T: Hash + Eq + PartialEq> IndexSet<T> {
89    pub fn new() -> Self {
90        Self {
91            inner: indexmap::IndexSet::new(),
92        }
93    }
94
95    pub fn insert(&mut self, value: T) -> bool {
96        self.inner.insert(value)
97    }
98
99    /// Remove a random element from the set.
100    pub fn remove_random<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Option<T> {
101        self.pick_random_index(rng)
102            .and_then(|idx| self.inner.shift_remove_index(idx))
103    }
104
105    /// Pick a random element from the set.
106    pub fn pick_random<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<&T> {
107        self.pick_random_index(rng)
108            .and_then(|idx| self.inner.get_index(idx))
109    }
110
111    /// Pick a random element from the set, but not any of the elements in `without`.
112    pub fn pick_random_without<R: Rng + ?Sized>(&self, without: &[&T], rng: &mut R) -> Option<&T> {
113        self.iter().filter(|x| !without.contains(x)).choose(rng)
114    }
115
116    /// Pick a random index for an element in the set.
117    pub fn pick_random_index<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<usize> {
118        if self.is_empty() {
119            None
120        } else {
121            Some(rng.random_range(0..self.inner.len()))
122        }
123    }
124
125    /// Remove an element from the set.
126    ///
127    /// NOTE: the value is removed by swapping it with the last element of the set and popping it off.
128    /// **This modifies the order of element by moving the last element**
129    pub fn remove(&mut self, value: &T) -> Option<T> {
130        self.inner.swap_remove_full(value).map(|(_i, v)| v)
131    }
132
133    /// Remove an element from the set by its index.
134    ///
135    /// NOTE: the value is removed by swapping it with the last element of the set and popping it off.
136    /// **This modifies the order of element by moving the last element**
137    pub fn remove_index(&mut self, index: usize) -> Option<T> {
138        self.inner.swap_remove_index(index)
139    }
140
141    /// Create an iterator over the set in the order of insertion, while skipping the element in
142    /// `without`.
143    pub fn iter_without<'a>(&'a self, value: &'a T) -> impl Iterator<Item = &'a T> {
144        self.iter().filter(move |x| *x != value)
145    }
146}
147
148impl<T> IndexSet<T>
149where
150    T: Hash + Eq + Clone,
151{
152    /// Create a vector of all elements in the set in random order.
153    pub fn shuffled<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T> {
154        let mut items: Vec<_> = self.inner.iter().cloned().collect();
155        items.shuffle(rng);
156        items
157    }
158
159    /// Create a vector of all elements in the set in random order, and shorten to
160    /// the first `len` elements after shuffling.
161    pub fn shuffled_and_capped<R: Rng + ?Sized>(&self, len: usize, rng: &mut R) -> Vec<T> {
162        let mut items = self.shuffled(rng);
163        items.truncate(len);
164        items
165    }
166
167    /// Create a vector of the elements in the set in random order while omitting
168    /// the elements in `without`.
169    pub fn shuffled_without<R: Rng + ?Sized>(&self, without: &[&T], rng: &mut R) -> Vec<T> {
170        let mut items = self
171            .inner
172            .iter()
173            .filter(|x| !without.contains(x))
174            .cloned()
175            .collect::<Vec<_>>();
176        items.shuffle(rng);
177        items
178    }
179
180    /// Create a vector of the elements in the set in random order while omitting
181    /// the elements in `without`, and shorten to the first `len` elements.
182    pub fn shuffled_without_and_capped<R: Rng + ?Sized>(
183        &self,
184        without: &[&T],
185        len: usize,
186        rng: &mut R,
187    ) -> Vec<T> {
188        let mut items = self.shuffled_without(without, rng);
189        items.truncate(len);
190        items
191    }
192}
193
194impl<T> IntoIterator for IndexSet<T> {
195    type Item = T;
196    type IntoIter = <indexmap::IndexSet<T> as IntoIterator>::IntoIter;
197    fn into_iter(self) -> Self::IntoIter {
198        self.inner.into_iter()
199    }
200}
201
202impl<T> FromIterator<T> for IndexSet<T>
203where
204    T: Hash + Eq,
205{
206    fn from_iter<I: IntoIterator<Item = T>>(iterable: I) -> Self {
207        IndexSet {
208            inner: indexmap::IndexSet::from_iter(iterable),
209        }
210    }
211}
212
/// A queue of items keyed by [`Instant`], used to pull out items whose time has come.
///
/// Backed by a [`BinaryHeap`] whose entry ordering is reversed, so the entry with the earliest
/// instant is always on top. Entries that share an instant come out in insertion order, which is
/// tracked with an increasing sequence number.
#[derive(Debug)]
pub struct TimerMap<T> {
    heap: BinaryHeap<TimerMapEntry<T>>,
    seq: u64,
}

// Hand-written: `#[derive(Default)]` would add an unwanted `T: Default` bound.
impl<T> Default for TimerMap<T> {
    fn default() -> Self {
        TimerMap {
            heap: BinaryHeap::new(),
            seq: 0,
        }
    }
}

impl<T> TimerMap<T> {
    /// Creates an empty `TimerMap`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Schedules `item` at `instant`.
    ///
    /// Items scheduled for the same instant are later returned in the order they were inserted.
    pub fn insert(&mut self, instant: Instant, item: T) {
        let entry = TimerMapEntry {
            time: instant,
            seq: self.seq,
            item,
        };
        self.seq += 1;
        self.heap.push(entry);
    }

    /// Removes and yields, earliest first, every entry whose instant is at or before `from`.
    ///
    /// The iterator is lazy: entries are only removed as it is advanced.
    pub fn drain_until(
        &mut self,
        from: &Instant,
    ) -> impl Iterator<Item = (Instant, T)> + '_ + use<'_, T> {
        let limit = *from;
        std::iter::from_fn(move || self.pop_before(limit))
    }

    /// Removes and returns the earliest entry if its instant is at or before `limit`.
    pub fn pop_before(&mut self, limit: Instant) -> Option<(Instant, T)> {
        let top = self.heap.peek()?;
        if top.time > limit {
            return None;
        }
        let entry = self.heap.pop()?;
        Some((entry.time, entry.item))
    }

    /// Returns the instant of the earliest entry, if any.
    pub fn first(&self) -> Option<&Instant> {
        self.heap.peek().map(|top| &top.time)
    }

    /// Returns a copy of all entries, earliest first.
    #[cfg(test)]
    fn to_vec(&self) -> Vec<(Instant, T)>
    where
        T: Clone,
    {
        // `into_sorted_vec` sorts ascending by our reversed `Ord`, i.e. latest first; flip it.
        let mut entries = self.heap.clone().into_sorted_vec();
        entries.reverse();
        entries
            .into_iter()
            .map(|entry| (entry.time, entry.item))
            .collect()
    }
}

/// One scheduled item. Its ordering makes the [`BinaryHeap`] (a max-heap) pop the earliest
/// `time` first and, among equal times, the lowest `seq` first.
#[derive(Debug, Clone)]
struct TimerMapEntry<T> {
    time: Instant,
    seq: u64,
    item: T,
}

// Equality and ordering look only at `(time, seq)`, so `T` needs no trait bounds. Each insert
// takes a fresh `seq`, so no two entries in one map compare equal.
impl<T> PartialEq for TimerMapEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        (self.time, self.seq) == (other.time, other.seq)
    }
}

impl<T> Eq for TimerMapEntry<T> {}

impl<T> PartialOrd for TimerMapEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for TimerMapEntry<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Comparing `other` against `self` reverses the natural `(time, seq)` order, which turns
        // the max-heap into a min-heap.
        (other.time, other.seq).cmp(&(self.time, self.seq))
    }
}
314
315/// A hash map where entries expire after a time
316#[derive(Debug)]
317pub struct TimeBoundCache<K, V> {
318    map: HashMap<K, (Instant, V)>,
319    expiry: TimerMap<K>,
320}
321
322impl<K, V> Default for TimeBoundCache<K, V> {
323    fn default() -> Self {
324        Self {
325            map: Default::default(),
326            expiry: Default::default(),
327        }
328    }
329}
330
331impl<K: Hash + Eq + Clone, V> TimeBoundCache<K, V> {
332    /// Insert an item into the cache, marked with an expiration time.
333    pub fn insert(&mut self, key: K, value: V, expires: Instant) {
334        self.map.insert(key.clone(), (expires, value));
335        self.expiry.insert(expires, key);
336    }
337
338    /// Returns `true` if the map contains a value for the specified key.
339    pub fn contains_key(&self, key: &K) -> bool {
340        self.map.contains_key(key)
341    }
342
343    /// Get the number of entries in the cache.
344    pub fn len(&self) -> usize {
345        self.map.len()
346    }
347
348    /// Returns `true` if the map contains no elements.
349    pub fn is_empty(&self) -> bool {
350        self.map.is_empty()
351    }
352
353    /// Get an item from the cache.
354    pub fn get(&self, key: &K) -> Option<&V> {
355        self.map.get(key).map(|(_expires, value)| value)
356    }
357
358    /// Get the expiration time for an item.
359    pub fn expires(&self, key: &K) -> Option<&Instant> {
360        self.map.get(key).map(|(expires, _value)| expires)
361    }
362
363    /// Iterate over all items in the cache.
364    pub fn iter(&self) -> impl Iterator<Item = (&K, &V, &Instant)> {
365        self.map.iter().map(|(k, (expires, v))| (k, v, expires))
366    }
367
368    /// Remove all entries with an expiry instant lower or equal to `instant`.
369    ///
370    /// Returns the number of items that were removed.
371    pub fn expire_until(&mut self, instant: Instant) -> usize {
372        let drain = self.expiry.drain_until(&instant);
373        let mut count = 0;
374        for (time, key) in drain {
375            match self.map.entry(key) {
376                hash_map::Entry::Occupied(entry) if entry.get().0 == time => {
377                    // If the entry's time matches that of the item we are draining from the expiry list,
378                    // remove the entry from the map and increase the count of items we removed.
379                    entry.remove();
380                    count += 1;
381                }
382                hash_map::Entry::Occupied(_entry) => {
383                    // If the entry's time does not match the time of the item we are draining,
384                    // do not remove the entry: It means that it was re-added with a later time.
385                }
386                hash_map::Entry::Vacant(_) => {
387                    // If the entry is not in the map, it means that it was already removed,
388                    // which can happen if it was inserted multiple times.
389                }
390            }
391        }
392        count
393    }
394}
395
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use n0_future::time::{Duration, Instant};
    use rand::SeedableRng;

    use super::{IndexSet, TimeBoundCache, TimerMap};

    /// Deterministically seeded RNG, so the expected random picks below are reproducible.
    fn test_rng() -> rand_chacha::ChaCha12Rng {
        rand_chacha::ChaCha12Rng::seed_from_u64(42)
    }

    #[test]
    fn indexset() {
        let elems = [1, 2, 3, 4];
        let set = IndexSet::from_iter(elems);
        let x = set.shuffled(&mut test_rng());
        assert_eq!(x, vec![2, 1, 4, 3]);
        let x = set.shuffled_and_capped(2, &mut test_rng());
        assert_eq!(x, vec![2, 1]);
        let x = set.shuffled_without(&[&1], &mut test_rng());
        assert_eq!(x, vec![3, 2, 4]);
        let x = set.shuffled_without_and_capped(&[&1], 2, &mut test_rng());
        assert_eq!(x, vec![3, 2]);

        // recreate the rng - otherwise we get failures on some architectures when cross-compiling,
        // likely due to usize differences pulling different amounts of randomness.
        let x = set.pick_random(&mut test_rng());
        assert_eq!(x, Some(&1));
        let x = set.pick_random_without(&[&3], &mut test_rng());
        assert_eq!(x, Some(&4));

        let mut set = set;
        set.remove_random(&mut test_rng());
        assert_eq!(set, IndexSet::from_iter([2, 3, 4]));
    }

    #[test]
    fn timer_map() {
        let mut map = TimerMap::new();
        let now = Instant::now();

        let times = [
            now - Duration::from_secs(1),
            now,
            now + Duration::from_secs(1),
            now + Duration::from_secs(2),
        ];
        map.insert(times[0], -1);
        map.insert(times[0], -2);
        map.insert(times[1], 0);
        map.insert(times[2], 1);
        map.insert(times[3], 2);
        map.insert(times[3], 3);

        assert_eq!(
            map.to_vec(),
            vec![
                (times[0], -1),
                (times[0], -2),
                (times[1], 0),
                (times[2], 1),
                (times[3], 2),
                (times[3], 3)
            ]
        );

        assert_eq!(map.first(), Some(&times[0]));

        let drain = map.drain_until(&now);
        assert_eq!(
            drain.collect::<Vec<_>>(),
            vec![(times[0], -1), (times[0], -2), (times[1], 0),]
        );
        assert_eq!(
            map.to_vec(),
            vec![(times[2], 1), (times[3], 2), (times[3], 3)]
        );
        let drain = map.drain_until(&now);
        assert_eq!(drain.collect::<Vec<_>>(), vec![]);
        let drain = map.drain_until(&(now + Duration::from_secs(10)));
        assert_eq!(
            drain.collect::<Vec<_>>(),
            vec![(times[2], 1), (times[3], 2), (times[3], 3)]
        );
    }

    #[test]
    fn hex() {
        #[derive(Eq, PartialEq)]
        struct Id([u8; 32]);
        idbytes_impls!(Id, "Id");
        let id: Id = [1u8; 32].into();
        assert_eq!(id, Id::from_str(&format!("{id}")).unwrap());
        assert_eq!(
            &format!("{id}"),
            "0101010101010101010101010101010101010101010101010101010101010101"
        );
        assert_eq!(
            &format!("{id:?}"),
            "Id(0101010101010101010101010101010101010101010101010101010101010101)"
        );
        assert_eq!(id.as_bytes(), &[1u8; 32]);

        // Parsing must reject input that does not decode to exactly 32 bytes, and non-hex input.
        assert!(Id::from_str("0101").is_err());
        assert!(Id::from_str(&"zz".repeat(32)).is_err());
    }

    #[test]
    fn time_bound_cache() {
        let mut cache = TimeBoundCache::default();

        let t0 = Instant::now();
        let t1 = t0 + Duration::from_secs(1);
        let t2 = t0 + Duration::from_secs(2);

        cache.insert(1, 10, t0);
        cache.insert(2, 20, t1);
        cache.insert(3, 30, t1);
        cache.insert(4, 40, t2);

        assert_eq!(cache.get(&2), Some(&20));
        assert_eq!(cache.len(), 4);
        let removed = cache.expire_until(t1);
        assert_eq!(removed, 3);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&4), Some(&40));

        let t3 = t2 + Duration::from_secs(1);
        cache.insert(5, 50, t2);
        assert_eq!(cache.expires(&5), Some(&t2));
        cache.insert(5, 50, t3);
        assert_eq!(cache.expires(&5), Some(&t3));
        // Key 5 still has a stale schedule record at t2. It must be skipped rather than counted,
        // so only key 4 is removed here.
        assert_eq!(cache.expire_until(t2), 1);
        assert_eq!(cache.get(&4), None);
        assert_eq!(cache.get(&5), Some(&50));

        // Key 5's current expiry (t3) removes it.
        assert_eq!(cache.expire_until(t3), 1);
        assert!(cache.is_empty());
        assert!(!cache.contains_key(&5));
    }
}