1use std::{
4 collections::{hash_map, BinaryHeap, HashMap},
5 hash::Hash,
6};
7
8use n0_future::time::Instant;
9use rand::{
10 seq::{IteratorRandom, SliceRandom},
11 Rng,
12};
13
/// Implements byte conversions and hex formatting for a tuple newtype `$ty`
/// whose single field is a `[u8; 32]`.
///
/// Generated items:
/// - inherent `from_bytes` (const) and `as_bytes`;
/// - `From<T>` for every `T: Into<[u8; 32]>`;
/// - `Display` as the bare hex encoding, `Debug` as `$name(<hex>)`;
/// - `FromStr` parsing the hex encoding back (error type `hex::FromHexError`);
/// - `AsRef<[u8]>` and `AsRef<[u8; 32]>`.
macro_rules! idbytes_impls {
    ($ty:ty, $name:expr) => {
        impl $ty {
            /// Wraps the given 32 bytes.
            pub const fn from_bytes(bytes: [u8; 32]) -> Self {
                Self(bytes)
            }

            /// Returns a reference to the wrapped 32 bytes.
            pub fn as_bytes(&self) -> &[u8; 32] {
                &self.0
            }
        }

        impl<T: ::std::convert::Into<[u8; 32]>> ::std::convert::From<T> for $ty {
            fn from(value: T) -> Self {
                let raw: [u8; 32] = value.into();
                Self::from_bytes(raw)
            }
        }

        impl ::std::fmt::Display for $ty {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                f.write_str(&::hex::encode(&self.0))
            }
        }

        impl ::std::fmt::Debug for $ty {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                let encoded = ::hex::encode(&self.0);
                write!(f, "{}({})", $name, encoded)
            }
        }

        impl ::std::str::FromStr for $ty {
            type Err = ::hex::FromHexError;
            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
                let mut bytes = [0u8; 32];
                // `decode_to_slice` rejects input that does not decode to exactly 32 bytes.
                ::hex::decode_to_slice(s, &mut bytes).map(|()| Self::from_bytes(bytes))
            }
        }

        impl ::std::convert::AsRef<[u8]> for $ty {
            fn as_ref(&self) -> &[u8] {
                self.0.as_slice()
            }
        }

        impl ::std::convert::AsRef<[u8; 32]> for $ty {
            fn as_ref(&self) -> &[u8; 32] {
                &self.0
            }
        }
    };
}
69
70pub(crate) use idbytes_impls;
71
72#[derive(Default, Debug, Clone, derive_more::Deref)]
78pub(crate) struct IndexSet<T> {
79 inner: indexmap::IndexSet<T>,
80}
81
82impl<T: Hash + Eq> PartialEq for IndexSet<T> {
83 fn eq(&self, other: &Self) -> bool {
84 self.inner == other.inner
85 }
86}
87
88impl<T: Hash + Eq + PartialEq> IndexSet<T> {
89 pub fn new() -> Self {
90 Self {
91 inner: indexmap::IndexSet::new(),
92 }
93 }
94
95 pub fn insert(&mut self, value: T) -> bool {
96 self.inner.insert(value)
97 }
98
99 pub fn remove_random<R: Rng + ?Sized>(&mut self, rng: &mut R) -> Option<T> {
101 self.pick_random_index(rng)
102 .and_then(|idx| self.inner.shift_remove_index(idx))
103 }
104
105 pub fn pick_random<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<&T> {
107 self.pick_random_index(rng)
108 .and_then(|idx| self.inner.get_index(idx))
109 }
110
111 pub fn pick_random_without<R: Rng + ?Sized>(&self, without: &[&T], rng: &mut R) -> Option<&T> {
113 self.iter().filter(|x| !without.contains(x)).choose(rng)
114 }
115
116 pub fn pick_random_index<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<usize> {
118 if self.is_empty() {
119 None
120 } else {
121 Some(rng.random_range(0..self.inner.len()))
122 }
123 }
124
125 pub fn remove(&mut self, value: &T) -> Option<T> {
130 self.inner.swap_remove_full(value).map(|(_i, v)| v)
131 }
132
133 pub fn remove_index(&mut self, index: usize) -> Option<T> {
138 self.inner.swap_remove_index(index)
139 }
140
141 pub fn iter_without<'a>(&'a self, value: &'a T) -> impl Iterator<Item = &'a T> {
144 self.iter().filter(move |x| *x != value)
145 }
146}
147
148impl<T> IndexSet<T>
149where
150 T: Hash + Eq + Clone,
151{
152 pub fn shuffled<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T> {
154 let mut items: Vec<_> = self.inner.iter().cloned().collect();
155 items.shuffle(rng);
156 items
157 }
158
159 pub fn shuffled_and_capped<R: Rng + ?Sized>(&self, len: usize, rng: &mut R) -> Vec<T> {
162 let mut items = self.shuffled(rng);
163 items.truncate(len);
164 items
165 }
166
167 pub fn shuffled_without<R: Rng + ?Sized>(&self, without: &[&T], rng: &mut R) -> Vec<T> {
170 let mut items = self
171 .inner
172 .iter()
173 .filter(|x| !without.contains(x))
174 .cloned()
175 .collect::<Vec<_>>();
176 items.shuffle(rng);
177 items
178 }
179
180 pub fn shuffled_without_and_capped<R: Rng + ?Sized>(
183 &self,
184 without: &[&T],
185 len: usize,
186 rng: &mut R,
187 ) -> Vec<T> {
188 let mut items = self.shuffled_without(without, rng);
189 items.truncate(len);
190 items
191 }
192}
193
194impl<T> IntoIterator for IndexSet<T> {
195 type Item = T;
196 type IntoIter = <indexmap::IndexSet<T> as IntoIterator>::IntoIter;
197 fn into_iter(self) -> Self::IntoIter {
198 self.inner.into_iter()
199 }
200}
201
202impl<T> FromIterator<T> for IndexSet<T>
203where
204 T: Hash + Eq,
205{
206 fn from_iter<I: IntoIterator<Item = T>>(iterable: I) -> Self {
207 IndexSet {
208 inner: indexmap::IndexSet::from_iter(iterable),
209 }
210 }
211}
212
/// A priority queue of items scheduled at [`Instant`]s, earliest first.
///
/// Items scheduled for the same instant come out in insertion order: every
/// entry carries a monotonically increasing sequence number that breaks ties.
#[derive(Debug)]
pub struct TimerMap<T> {
    heap: BinaryHeap<TimerMapEntry<T>>,
    seq: u64,
}

impl<T> Default for TimerMap<T> {
    // Hand-written so that `T` does not need to implement `Default`.
    fn default() -> Self {
        Self {
            heap: BinaryHeap::new(),
            seq: 0,
        }
    }
}

impl<T> TimerMap<T> {
    /// Creates an empty timer map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Schedules `item` at `instant`.
    pub fn insert(&mut self, instant: Instant, item: T) {
        let entry = TimerMapEntry {
            time: instant,
            seq: self.seq,
            item,
        };
        self.seq += 1;
        self.heap.push(entry);
    }

    /// Returns an iterator that removes and yields, earliest first, every
    /// item scheduled at or before `from`.
    ///
    /// The iterator is lazy: only the items actually pulled from it are removed.
    pub fn drain_until(
        &mut self,
        from: &Instant,
    ) -> impl Iterator<Item = (Instant, T)> + '_ + use<'_, T> {
        let limit = *from;
        std::iter::from_fn(move || self.pop_before(limit))
    }

    /// Removes and returns the earliest item if it is scheduled at or before
    /// `limit`; otherwise leaves the map untouched and returns `None`.
    pub fn pop_before(&mut self, limit: Instant) -> Option<(Instant, T)> {
        if self.heap.peek()?.time > limit {
            return None;
        }
        self.heap.pop().map(|entry| (entry.time, entry.item))
    }

    /// Returns the instant of the earliest scheduled item, if any.
    pub fn first(&self) -> Option<&Instant> {
        self.heap.peek().map(|entry| &entry.time)
    }

    /// Test helper: a copy of all entries, earliest first.
    #[cfg(test)]
    fn to_vec(&self) -> Vec<(Instant, T)>
    where
        T: Clone,
    {
        let mut entries = self.heap.clone().into_vec();
        // Entry `Ord` is reversed, so sorting by descending `Ord` yields earliest first.
        entries.sort_by(|a, b| b.cmp(a));
        entries
            .into_iter()
            .map(|entry| (entry.time, entry.item))
            .collect()
    }
}

/// A heap entry: the scheduled time, an insertion sequence number that
/// breaks ties, and the payload.
#[derive(Debug, Clone)]
struct TimerMapEntry<T> {
    time: Instant,
    seq: u64,
    item: T,
}

// Equality and ordering look only at `(time, seq)`. The payload is ignored,
// so `T` needs no comparison traits.
impl<T> PartialEq for TimerMapEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        (self.time, self.seq) == (other.time, other.seq)
    }
}

impl<T> Eq for TimerMapEntry<T> {}

impl<T> PartialOrd for TimerMapEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Ord for TimerMapEntry<T> {
    /// Reversed `(time, seq)` order. `BinaryHeap` is a max-heap, so reversing
    /// makes the earliest time (and, among equal times, the lowest sequence
    /// number) the greatest element, and therefore the first one popped.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (other.time, other.seq).cmp(&(self.time, self.seq))
    }
}
314
315#[derive(Debug)]
317pub struct TimeBoundCache<K, V> {
318 map: HashMap<K, (Instant, V)>,
319 expiry: TimerMap<K>,
320}
321
322impl<K, V> Default for TimeBoundCache<K, V> {
323 fn default() -> Self {
324 Self {
325 map: Default::default(),
326 expiry: Default::default(),
327 }
328 }
329}
330
331impl<K: Hash + Eq + Clone, V> TimeBoundCache<K, V> {
332 pub fn insert(&mut self, key: K, value: V, expires: Instant) {
334 self.map.insert(key.clone(), (expires, value));
335 self.expiry.insert(expires, key);
336 }
337
338 pub fn contains_key(&self, key: &K) -> bool {
340 self.map.contains_key(key)
341 }
342
343 pub fn len(&self) -> usize {
345 self.map.len()
346 }
347
348 pub fn is_empty(&self) -> bool {
350 self.map.is_empty()
351 }
352
353 pub fn get(&self, key: &K) -> Option<&V> {
355 self.map.get(key).map(|(_expires, value)| value)
356 }
357
358 pub fn expires(&self, key: &K) -> Option<&Instant> {
360 self.map.get(key).map(|(expires, _value)| expires)
361 }
362
363 pub fn iter(&self) -> impl Iterator<Item = (&K, &V, &Instant)> {
365 self.map.iter().map(|(k, (expires, v))| (k, v, expires))
366 }
367
368 pub fn expire_until(&mut self, instant: Instant) -> usize {
372 let drain = self.expiry.drain_until(&instant);
373 let mut count = 0;
374 for (time, key) in drain {
375 match self.map.entry(key) {
376 hash_map::Entry::Occupied(entry) if entry.get().0 == time => {
377 entry.remove();
380 count += 1;
381 }
382 hash_map::Entry::Occupied(_entry) => {
383 }
386 hash_map::Entry::Vacant(_) => {
387 }
390 }
391 }
392 count
393 }
394}
395
// Unit tests for the helpers in this module. RNG-dependent expectations are
// pinned to the fixed seed returned by `test_rng`.
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use n0_future::time::{Duration, Instant};
    use rand::SeedableRng;

    use super::{IndexSet, TimeBoundCache, TimerMap};

    fn test_rng() -> rand_chacha::ChaCha12Rng {
        rand_chacha::ChaCha12Rng::seed_from_u64(42)
    }

    #[test]
    fn indexset() {
        let elems = [1, 2, 3, 4];
        let set = IndexSet::from_iter(elems);
        let x = set.shuffled(&mut test_rng());
        assert_eq!(x, vec![2, 1, 4, 3]);
        let x = set.shuffled_and_capped(2, &mut test_rng());
        assert_eq!(x, vec![2, 1]);
        let x = set.shuffled_without(&[&1], &mut test_rng());
        assert_eq!(x, vec![3, 2, 4]);
        let x = set.shuffled_without_and_capped(&[&1], 2, &mut test_rng());
        assert_eq!(x, vec![3, 2]);

        let x = set.pick_random(&mut test_rng());
        assert_eq!(x, Some(&1));
        let x = set.pick_random_without(&[&3], &mut test_rng());
        assert_eq!(x, Some(&4));

        let mut set = set;
        set.remove_random(&mut test_rng());
        assert_eq!(set, IndexSet::from_iter([2, 3, 4]));
    }

    #[test]
    fn timer_map() {
        let mut map = TimerMap::new();
        let now = Instant::now();

        let times = [
            now - Duration::from_secs(1),
            now,
            now + Duration::from_secs(1),
            now + Duration::from_secs(2),
        ];
        map.insert(times[0], -1);
        map.insert(times[0], -2);
        map.insert(times[1], 0);
        map.insert(times[2], 1);
        map.insert(times[3], 2);
        map.insert(times[3], 3);

        assert_eq!(
            map.to_vec(),
            vec![
                (times[0], -1),
                (times[0], -2),
                (times[1], 0),
                (times[2], 1),
                (times[3], 2),
                (times[3], 3)
            ]
        );

        // Was garbled to `Some(×[0])` (an `&times` entity decoded to `×`), which does not compile.
        assert_eq!(map.first(), Some(&times[0]));

        let drain = map.drain_until(&now);
        assert_eq!(
            drain.collect::<Vec<_>>(),
            vec![(times[0], -1), (times[0], -2), (times[1], 0),]
        );
        assert_eq!(
            map.to_vec(),
            vec![(times[2], 1), (times[3], 2), (times[3], 3)]
        );
        let drain = map.drain_until(&now);
        assert_eq!(drain.collect::<Vec<_>>(), vec![]);
        let drain = map.drain_until(&(now + Duration::from_secs(10)));
        assert_eq!(
            drain.collect::<Vec<_>>(),
            vec![(times[2], 1), (times[3], 2), (times[3], 3)]
        );
    }

    #[test]
    fn hex() {
        #[derive(Eq, PartialEq)]
        struct Id([u8; 32]);
        idbytes_impls!(Id, "Id");
        let id: Id = [1u8; 32].into();
        assert_eq!(id, Id::from_str(&format!("{id}")).unwrap());
        assert_eq!(
            &format!("{id}"),
            "0101010101010101010101010101010101010101010101010101010101010101"
        );
        assert_eq!(
            &format!("{id:?}"),
            "Id(0101010101010101010101010101010101010101010101010101010101010101)"
        );
        assert_eq!(id.as_bytes(), &[1u8; 32]);
    }

    #[test]
    fn time_bound_cache() {
        let mut cache = TimeBoundCache::default();

        let t0 = Instant::now();
        let t1 = t0 + Duration::from_secs(1);
        let t2 = t0 + Duration::from_secs(2);

        cache.insert(1, 10, t0);
        cache.insert(2, 20, t1);
        cache.insert(3, 30, t1);
        cache.insert(4, 40, t2);

        assert_eq!(cache.get(&2), Some(&20));
        assert_eq!(cache.len(), 4);
        let removed = cache.expire_until(t1);
        assert_eq!(removed, 3);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&4), Some(&40));

        let t3 = t2 + Duration::from_secs(1);
        cache.insert(5, 50, t2);
        assert_eq!(cache.expires(&5), Some(&t2));
        cache.insert(5, 50, t3);
        assert_eq!(cache.expires(&5), Some(&t3));
        cache.expire_until(t2);
        assert_eq!(cache.get(&4), None);
        assert_eq!(cache.get(&5), Some(&50));
    }
}