use std::{cmp::Ordering, fmt::Debug};
use serde::{Deserialize, Serialize};
use crate::ContentStatus;
/// An entry that can be stored in a [`Store`] and reconciled by range-based set sync.
///
/// Keys decide which range an entry falls into (via their `Ord`), and values decide which
/// of two competing entries wins (see `Store::put` and the diff logic in
/// `Store::process_message`).
pub trait RangeEntry: Debug + Clone {
    /// Key type; ordered, and used for range membership and prefix relations.
    type Key: RangeKey;
    /// Value type; compared with `Ord` to pick the winning entry.
    type Value: RangeValue;
    /// Returns this entry's key.
    fn key(&self) -> &Self::Key;
    /// Returns this entry's value.
    fn value(&self) -> &Self::Value;
    /// Returns the fingerprint of this single entry.
    ///
    /// NOTE(review): the test store builds range fingerprints by XOR-folding these starting
    /// from `Fingerprint::empty()`; confirm other stores follow the same convention.
    fn as_fingerprint(&self) -> Fingerprint;
}
/// Key type of a [`RangeEntry`]: totally ordered so that keys can be grouped into ranges.
pub trait RangeKey: Sized + Debug + Ord + PartialEq + Clone + 'static {
    /// Returns `true` if `self` is a prefix of `other` (test builds only).
    #[cfg(test)]
    fn is_prefix_of(&self, other: &Self) -> bool;

    /// Returns `true` if `other` is a prefix of `self`; the mirror image of
    /// [`RangeKey::is_prefix_of`] (test builds only).
    #[cfg(test)]
    fn is_prefixed_by(&self, other: &Self) -> bool {
        Self::is_prefix_of(other, self)
    }
}
/// Marker trait for the value type of a [`RangeEntry`]; its `Ord` decides which entry wins
/// when keys are equal or prefixes of each other (see `Store::put`).
pub trait RangeValue: Sized + Debug + Ord + PartialEq + Clone + 'static {}
/// A range of keys on the ordered key space, with inclusive start `x` and exclusive end `y`.
///
/// As implemented by `Range::contains`:
/// - `x < y` covers `x <= k < y`;
/// - `x > y` wraps around and covers `k >= x` or `k < y`;
/// - `x == y` covers every key.
///
/// NOTE(review): field order is part of the derived serde representation; don't reorder.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct Range<K> {
    /// Start of the range (inclusive).
    x: K,
    /// End of the range (exclusive).
    y: K,
}
impl<K> Range<K> {
pub fn x(&self) -> &K {
&self.x
}
pub fn y(&self) -> &K {
&self.y
}
pub fn new(x: K, y: K) -> Self {
Range { x, y }
}
pub fn map<X>(self, f: impl FnOnce(K, K) -> (X, X)) -> Range<X> {
let (x, y) = f(self.x, self.y);
Range { x, y }
}
}
impl<K: Ord> Range<K> {
pub fn is_all(&self) -> bool {
self.x() == self.y()
}
pub fn contains(&self, t: &K) -> bool {
match self.x().cmp(self.y()) {
Ordering::Equal => true,
Ordering::Less => self.x() <= t && t < self.y(),
Ordering::Greater => self.x() <= t || t < self.y(),
}
}
}
impl<K> From<(K, K)> for Range<K> {
fn from((x, y): (K, K)) -> Self {
Range { x, y }
}
}
/// A 32-byte fingerprint summarising a set of entries.
///
/// Fingerprints are combined with XOR (`^=`). `Fingerprint::empty()` is the starting value
/// in the test store and serves as the "peer has nothing in this range" marker in
/// `Store::process_message`.
#[derive(Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct Fingerprint(pub [u8; 32]);
impl Debug for Fingerprint {
    /// Formats the fingerprint as `Fp(<hex>)`, using the hex rendering of the 32 bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = blake3::Hash::from(self.0).to_hex();
        write!(f, "Fp({hex})")
    }
}
impl Fingerprint {
    /// The fingerprint of the empty set: the BLAKE3 hash of zero input bytes.
    pub fn empty() -> Self {
        let digest = blake3::hash(b"");
        Self(*digest.as_bytes())
    }

    /// Fingerprint of a single entry; delegates to [`RangeEntry::as_fingerprint`].
    pub fn new<T: RangeEntry>(val: T) -> Self {
        T::as_fingerprint(&val)
    }
}
impl std::ops::BitXorAssign for Fingerprint {
    /// XORs `rhs` into `self`, byte by byte.
    ///
    /// XOR is commutative and self-inverse, so folding entry fingerprints with it gives a
    /// result that does not depend on the order entries are visited in.
    fn bitxor_assign(&mut self, rhs: Self) {
        self.0
            .iter_mut()
            .zip(rhs.0)
            .for_each(|(lhs_byte, rhs_byte)| *lhs_byte ^= rhs_byte);
    }
}
/// Message part stating "my entries in `range` have this `fingerprint`".
///
/// The receiver compares it with its own fingerprint for the same range
/// (see `Store::process_message`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RangeFingerprint<K> {
    /// The key range being summarised.
    #[serde(bound(
        serialize = "Range<K>: Serialize",
        deserialize = "Range<K>: Deserialize<'de>"
    ))]
    pub range: Range<K>,
    /// The sender's fingerprint over its entries in `range`.
    pub fingerprint: Fingerprint,
}
/// Message part carrying the sender's entries in `range` explicitly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RangeItem<E: RangeEntry> {
    /// The key range the `values` were taken from.
    #[serde(bound(
        serialize = "Range<E::Key>: Serialize",
        deserialize = "Range<E::Key>: Deserialize<'de>"
    ))]
    pub range: Range<E::Key>,
    /// The sender's entries in `range`, each paired with its content status.
    #[serde(bound(serialize = "E: Serialize", deserialize = "E: Deserialize<'de>"))]
    pub values: Vec<(E, ContentStatus)>,
    /// `false`: the receiver should answer with its own entries in `range` that the sender
    /// is missing. `true`: this part is already such an answer and gets no reply
    /// (see `Store::process_message`).
    pub have_local: bool,
}
/// One part of a [`Message`]: either a fingerprint summarising a range, or the entries of a
/// range.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MessagePart<E: RangeEntry> {
    /// "My entries in this range have this fingerprint."
    #[serde(bound(
        serialize = "RangeFingerprint<E::Key>: Serialize",
        deserialize = "RangeFingerprint<E::Key>: Deserialize<'de>"
    ))]
    RangeFingerprint(RangeFingerprint<E::Key>),
    /// "These are my entries in this range."
    #[serde(bound(
        serialize = "RangeItem<E>: Serialize",
        deserialize = "RangeItem<E>: Deserialize<'de>"
    ))]
    RangeItem(RangeItem<E>),
}
impl<E: RangeEntry> MessagePart<E> {
    /// Returns `true` if this part is a [`MessagePart::RangeFingerprint`].
    pub fn is_range_fingerprint(&self) -> bool {
        match self {
            MessagePart::RangeFingerprint(_) => true,
            MessagePart::RangeItem(_) => false,
        }
    }

    /// Returns `true` if this part is a [`MessagePart::RangeItem`].
    pub fn is_range_item(&self) -> bool {
        !self.is_range_fingerprint()
    }

    /// The entries carried by a [`MessagePart::RangeItem`], or `None` for a fingerprint part.
    pub fn values(&self) -> Option<&[(E, ContentStatus)]> {
        if let MessagePart::RangeItem(item) = self {
            Some(item.values.as_slice())
        } else {
            None
        }
    }
}
/// A sync message: the ordered list of parts one peer sends to the other in a round.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message<E: RangeEntry> {
    /// Fingerprint and item parts, in the order they were produced.
    #[serde(bound(
        serialize = "MessagePart<E>: Serialize",
        deserialize = "MessagePart<E>: Deserialize<'de>"
    ))]
    parts: Vec<MessagePart<E>>,
}
impl<E: RangeEntry> Message<E> {
    /// Builds the opening message of a sync: a single fingerprint over the whole key space.
    ///
    /// The store's first key is used as both bounds. A range with equal bounds covers all
    /// keys (see `Range::is_all`).
    fn init<S: Store<E>>(store: &mut S) -> Result<Self, S::Error> {
        let first = store.get_first()?;
        let range = Range::new(first.clone(), first);
        let fingerprint = store.get_fingerprint(&range)?;
        Ok(Message {
            parts: vec![MessagePart::RangeFingerprint(RangeFingerprint {
                range,
                fingerprint,
            })],
        })
    }

    /// The parts of this message, in order.
    pub fn parts(&self) -> &[MessagePart<E>] {
        self.parts.as_slice()
    }

    /// Iterates the entries of all item parts, in part order; fingerprint parts contribute
    /// nothing.
    pub fn values(&self) -> impl Iterator<Item = &(E, ContentStatus)> {
        self.parts
            .iter()
            .flat_map(|part| part.values().unwrap_or_default())
    }

    /// Total number of entries carried by the item parts.
    pub fn value_count(&self) -> usize {
        self.parts
            .iter()
            .filter_map(|part| part.values())
            .map(|values| values.len())
            .sum()
    }
}
pub trait Store<E: RangeEntry>: Sized {
type Error: Debug + Send + Sync + Into<anyhow::Error> + 'static;
type RangeIterator<'a>: Iterator<Item = Result<E, Self::Error>>
where
Self: 'a,
E: 'a;
type ParentIterator<'a>: Iterator<Item = Result<E, Self::Error>>
where
Self: 'a,
E: 'a;
fn get_first(&mut self) -> Result<E::Key, Self::Error>;
fn get(&mut self, key: &E::Key) -> Result<Option<E>, Self::Error>;
fn len(&mut self) -> Result<usize, Self::Error>;
fn is_empty(&mut self) -> Result<bool, Self::Error>;
fn get_fingerprint(&mut self, range: &Range<E::Key>) -> Result<Fingerprint, Self::Error>;
fn entry_put(&mut self, entry: E) -> Result<(), Self::Error>;
fn get_range(&mut self, range: Range<E::Key>) -> Result<Self::RangeIterator<'_>, Self::Error>;
fn get_range_len(&mut self, range: Range<E::Key>) -> Result<usize, Self::Error> {
let mut count = 0;
for el in self.get_range(range)? {
let _el = el?;
count += 1;
}
Ok(count)
}
fn prefixed_by(&mut self, prefix: &E::Key) -> Result<Self::RangeIterator<'_>, Self::Error>;
fn prefixes_of(&mut self, key: &E::Key) -> Result<Self::ParentIterator<'_>, Self::Error>;
fn all(&mut self) -> Result<Self::RangeIterator<'_>, Self::Error>;
fn entry_remove(&mut self, key: &E::Key) -> Result<Option<E>, Self::Error>;
fn remove_prefix_filtered(
&mut self,
prefix: &E::Key,
predicate: impl Fn(&E::Value) -> bool,
) -> Result<usize, Self::Error>;
fn initial_message(&mut self) -> Result<Message<E>, Self::Error> {
Message::init(self)
}
fn process_message<F, F2, F3>(
&mut self,
config: &SyncConfig,
message: Message<E>,
validate_cb: F,
mut on_insert_cb: F2,
content_status_cb: F3,
) -> Result<Option<Message<E>>, Self::Error>
where
F: Fn(&Self, &E, ContentStatus) -> bool,
F2: FnMut(&Self, E, ContentStatus),
F3: Fn(&Self, &E) -> ContentStatus,
{
let mut out = Vec::new();
let mut items = Vec::new();
let mut fingerprints = Vec::new();
for part in message.parts {
match part {
MessagePart::RangeItem(item) => {
items.push(item);
}
MessagePart::RangeFingerprint(fp) => {
fingerprints.push(fp);
}
}
}
for RangeItem {
range,
values,
have_local,
} in items
{
let diff: Option<Vec<_>> = if have_local {
None
} else {
Some({
let items = self
.get_range(range.clone())?
.filter_map(|our_entry| match our_entry {
Ok(our_entry) => {
if !values.iter().any(|(their_entry, _)| {
our_entry.key() == their_entry.key()
&& their_entry.value() >= our_entry.value()
}) {
Some(Ok(our_entry))
} else {
None
}
}
Err(err) => Some(Err(err)),
})
.collect::<Result<Vec<_>, _>>()?;
items
.into_iter()
.map(|entry| {
let content_status = content_status_cb(self, &entry);
(entry, content_status)
})
.collect()
})
};
for (entry, content_status) in values {
if validate_cb(self, &entry, content_status) {
let outcome = self.put(entry.clone())?;
if let InsertOutcome::Inserted { .. } = outcome {
on_insert_cb(self, entry, content_status);
}
}
}
if let Some(diff) = diff {
if !diff.is_empty() {
out.push(MessagePart::RangeItem(RangeItem {
range,
values: diff,
have_local: true,
}));
}
}
}
for RangeFingerprint { range, fingerprint } in fingerprints {
let local_fingerprint = self.get_fingerprint(&range)?;
if local_fingerprint == fingerprint {
continue;
}
let num_local_values = self.get_range_len(range.clone())?;
if num_local_values <= 1 || fingerprint == Fingerprint::empty() {
let values = self
.get_range(range.clone())?
.collect::<Result<Vec<_>, _>>()?;
let values = values
.into_iter()
.map(|entry| {
let content_status = content_status_cb(self, &entry);
(entry, content_status)
})
.collect();
out.push(MessagePart::RangeItem(RangeItem {
range,
values,
have_local: false,
}));
} else {
let mut ranges = Vec::with_capacity(config.split_factor);
let mut start_index = 0;
for el in self.get_range(range.clone())? {
let el = el?;
if el.key() >= range.x() {
break;
}
start_index += 1;
}
let mut pivot = |i: usize| {
let i = i % config.split_factor;
let offset = (num_local_values * (i + 1)) / config.split_factor;
let offset = (start_index + offset) % num_local_values;
self.get_range(range.clone())
.map(|mut i| i.nth(offset))
.and_then(|e| e.expect("missing entry"))
.map(|e| e.key().clone())
};
if range.is_all() {
for i in 0..config.split_factor {
let (x, y) = (pivot(i)?, pivot(i + 1)?);
if x != y {
ranges.push(Range { x, y })
}
}
} else {
ranges.push(Range {
x: range.x().clone(),
y: pivot(0)?,
});
for i in 0..config.split_factor - 2 {
let (x, y) = (pivot(i)?, pivot(i + 1)?);
if x != y {
ranges.push(Range { x, y })
}
}
ranges.push(Range {
x: pivot(config.split_factor - 2)?,
y: range.y().clone(),
});
}
let mut non_empty = 0;
for range in ranges {
let chunk: Vec<_> = self.get_range(range.clone())?.collect();
if !chunk.is_empty() {
non_empty += 1;
}
let fingerprint = self.get_fingerprint(&range)?;
if chunk.len() > config.max_set_size {
out.push(MessagePart::RangeFingerprint(RangeFingerprint {
range: range.clone(),
fingerprint,
}));
} else {
let values = chunk
.into_iter()
.map(|entry| {
entry.map(|entry| {
let content_status = content_status_cb(self, &entry);
(entry, content_status)
})
})
.collect::<Result<_, _>>()?;
out.push(MessagePart::RangeItem(RangeItem {
range,
values,
have_local: false,
}));
}
}
debug_assert!(non_empty > 1);
}
}
if !out.is_empty() {
Ok(Some(Message { parts: out }))
} else {
Ok(None)
}
}
fn put(&mut self, entry: E) -> Result<InsertOutcome, Self::Error> {
let prefix_entry = self.prefixes_of(entry.key())?;
for prefix_entry in prefix_entry {
let prefix_entry = prefix_entry?;
if entry.value() <= prefix_entry.value() {
return Ok(InsertOutcome::NotInserted);
}
}
let removed = self.remove_prefix_filtered(entry.key(), |value| entry.value() >= value)?;
self.entry_put(entry)?;
Ok(InsertOutcome::Inserted { removed })
}
}
impl<E: RangeEntry, S: Store<E>> Store<E> for &mut S {
type Error = S::Error;
type RangeIterator<'a>
= S::RangeIterator<'a>
where
Self: 'a,
E: 'a;
type ParentIterator<'a>
= S::ParentIterator<'a>
where
Self: 'a,
E: 'a;
fn get_first(&mut self) -> Result<<E as RangeEntry>::Key, Self::Error> {
(**self).get_first()
}
fn get(&mut self, key: &<E as RangeEntry>::Key) -> Result<Option<E>, Self::Error> {
(**self).get(key)
}
fn len(&mut self) -> Result<usize, Self::Error> {
(**self).len()
}
fn is_empty(&mut self) -> Result<bool, Self::Error> {
(**self).is_empty()
}
fn get_fingerprint(
&mut self,
range: &Range<<E as RangeEntry>::Key>,
) -> Result<Fingerprint, Self::Error> {
(**self).get_fingerprint(range)
}
fn entry_put(&mut self, entry: E) -> Result<(), Self::Error> {
(**self).entry_put(entry)
}
fn get_range(
&mut self,
range: Range<<E as RangeEntry>::Key>,
) -> Result<Self::RangeIterator<'_>, Self::Error> {
(**self).get_range(range)
}
fn prefixed_by(
&mut self,
prefix: &<E as RangeEntry>::Key,
) -> Result<Self::RangeIterator<'_>, Self::Error> {
(**self).prefixed_by(prefix)
}
fn prefixes_of(
&mut self,
key: &<E as RangeEntry>::Key,
) -> Result<Self::ParentIterator<'_>, Self::Error> {
(**self).prefixes_of(key)
}
fn all(&mut self) -> Result<Self::RangeIterator<'_>, Self::Error> {
(**self).all()
}
fn entry_remove(&mut self, key: &<E as RangeEntry>::Key) -> Result<Option<E>, Self::Error> {
(**self).entry_remove(key)
}
fn remove_prefix_filtered(
&mut self,
prefix: &<E as RangeEntry>::Key,
predicate: impl Fn(&<E as RangeEntry>::Value) -> bool,
) -> Result<usize, Self::Error> {
(**self).remove_prefix_filtered(prefix, predicate)
}
}
/// Tuning parameters for `Store::process_message`.
#[derive(Debug, Clone, Copy)]
pub struct SyncConfig {
    /// Sub-ranges with more than this many local entries are sent as fingerprints; smaller
    /// ones are sent as explicit items.
    max_set_size: usize,
    /// Number of sub-ranges a mismatching range is split into. The split code computes
    /// `split_factor - 2`, so it assumes a value >= 2.
    split_factor: usize,
}
impl Default for SyncConfig {
fn default() -> Self {
SyncConfig {
max_set_size: 1,
split_factor: 2,
}
}
}
/// Result of `Store::put`.
#[derive(Debug)]
pub enum InsertOutcome {
    /// The entry was rejected: an entry returned by `Store::prefixes_of` for its key has an
    /// equal or greater value.
    NotInserted,
    /// The entry was stored.
    Inserted {
        /// Number of existing entries removed because their key had the new key as a prefix
        /// and their value was not greater than the new value.
        removed: usize,
    },
}
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, collections::BTreeMap, convert::Infallible, fmt::Debug, rc::Rc};

    use proptest::prelude::*;
    use test_strategy::proptest;

    use super::*;

    /// Minimal in-memory [`Store`] backed by a `BTreeMap`, used to exercise the sync logic.
    #[derive(Debug)]
    struct SimpleStore<K, V> {
        data: BTreeMap<K, V>,
    }

    impl<K, V> Default for SimpleStore<K, V> {
        fn default() -> Self {
            SimpleStore {
                data: BTreeMap::default(),
            }
        }
    }

    /// Test entries are plain `(key, value)` tuples.
    impl<K, V> RangeEntry for (K, V)
    where
        K: RangeKey,
        V: RangeValue,
    {
        type Key = K;
        type Value = V;

        fn key(&self) -> &Self::Key {
            &self.0
        }

        fn value(&self) -> &Self::Value {
            &self.1
        }

        /// Hashes the `Debug` representations of key and value; good enough for tests.
        fn as_fingerprint(&self) -> Fingerprint {
            let mut hasher = blake3::Hasher::new();
            hasher.update(format!("{:?}", self.0).as_bytes());
            hasher.update(format!("{:?}", self.1).as_bytes());
            Fingerprint(hasher.finalize().into())
        }
    }

    impl RangeKey for &'static str {
        fn is_prefix_of(&self, other: &Self) -> bool {
            other.starts_with(self)
        }
    }

    impl RangeKey for String {
        fn is_prefix_of(&self, other: &Self) -> bool {
            other.starts_with(self)
        }
    }

    impl RangeValue for &'static [u8] {}
    impl RangeValue for i32 {}
    impl RangeValue for u8 {}
    impl RangeValue for () {}

    impl<K, V> Store<(K, V)> for SimpleStore<K, V>
    where
        K: RangeKey + Default,
        V: RangeValue,
    {
        type Error = Infallible;
        type ParentIterator<'a> = std::vec::IntoIter<Result<(K, V), Infallible>>;

        /// Smallest stored key, or `K::default()` when the store is empty.
        fn get_first(&mut self) -> Result<K, Self::Error> {
            if let Some((k, _)) = self.data.first_key_value() {
                Ok(k.clone())
            } else {
                Ok(Default::default())
            }
        }

        fn get(&mut self, key: &K) -> Result<Option<(K, V)>, Self::Error> {
            Ok(self.data.get(key).cloned().map(|v| (key.clone(), v)))
        }

        fn len(&mut self) -> Result<usize, Self::Error> {
            Ok(self.data.len())
        }

        fn is_empty(&mut self) -> Result<bool, Self::Error> {
            Ok(self.data.is_empty())
        }

        /// XOR of the entry fingerprints in `range`, starting from `Fingerprint::empty()`.
        fn get_fingerprint(&mut self, range: &Range<K>) -> Result<Fingerprint, Self::Error> {
            let elements = self.get_range(range.clone())?;
            let mut fp = Fingerprint::empty();
            for el in elements {
                let el = el?;
                fp ^= el.as_fingerprint();
            }
            Ok(fp)
        }

        fn entry_put(&mut self, e: (K, V)) -> Result<(), Self::Error> {
            self.data.insert(e.0, e.1);
            Ok(())
        }

        type RangeIterator<'a>
            = SimpleRangeIterator<'a, K, V>
        where
            K: 'a,
            V: 'a;

        /// Linear scan over all entries in key order, filtered by `Range::contains`.
        fn get_range(&mut self, range: Range<K>) -> Result<Self::RangeIterator<'_>, Self::Error> {
            let iter = self.data.iter();
            Ok(SimpleRangeIterator {
                iter,
                filter: SimpleFilter::Range(range),
            })
        }

        fn entry_remove(&mut self, key: &K) -> Result<Option<(K, V)>, Self::Error> {
            let res = self.data.remove(key).map(|v| (key.clone(), v));
            Ok(res)
        }

        fn all(&mut self) -> Result<Self::RangeIterator<'_>, Self::Error> {
            let iter = self.data.iter();
            Ok(SimpleRangeIterator {
                iter,
                filter: SimpleFilter::None,
            })
        }

        /// Entries whose key is a prefix of `key` (the key itself included, since
        /// `is_prefix_of` is reflexive for the string keys used here).
        fn prefixes_of(&mut self, key: &K) -> Result<Self::ParentIterator<'_>, Self::Error> {
            let mut res = vec![];
            for (k, v) in self.data.iter() {
                if k.is_prefix_of(key) {
                    res.push(Ok((k.clone(), v.clone())));
                }
            }
            Ok(res.into_iter())
        }

        fn prefixed_by(&mut self, prefix: &K) -> Result<Self::RangeIterator<'_>, Self::Error> {
            let iter = self.data.iter();
            Ok(SimpleRangeIterator {
                iter,
                filter: SimpleFilter::Prefix(prefix.clone()),
            })
        }

        fn remove_prefix_filtered(
            &mut self,
            prefix: &K,
            predicate: impl Fn(&V) -> bool,
        ) -> Result<usize, Self::Error> {
            let old_len = self.data.len();
            self.data.retain(|k, v| {
                let remove = prefix.is_prefix_of(k) && predicate(v);
                !remove
            });
            Ok(old_len - self.data.len())
        }
    }

    /// Iterator over a [`SimpleStore`]'s entries, in key order, that pass a [`SimpleFilter`].
    #[derive(Debug)]
    pub struct SimpleRangeIterator<'a, K, V> {
        iter: std::collections::btree_map::Iter<'a, K, V>,
        filter: SimpleFilter<K>,
    }

    /// Which entries a [`SimpleRangeIterator`] yields.
    #[derive(Debug)]
    enum SimpleFilter<K> {
        /// Every entry.
        None,
        /// Entries whose key is contained in the range.
        Range(Range<K>),
        /// Entries whose key has the given prefix.
        Prefix(K),
    }

    impl<'a, K, V> Iterator for SimpleRangeIterator<'a, K, V>
    where
        K: RangeKey + Default,
        V: Clone,
    {
        type Item = Result<(K, V), Infallible>;

        fn next(&mut self) -> Option<Self::Item> {
            let filter = &self.filter;
            self.iter
                .find(|(key, _)| match filter {
                    SimpleFilter::None => true,
                    SimpleFilter::Range(range) => range.contains(key),
                    SimpleFilter::Prefix(prefix) => prefix.is_prefix_of(key),
                })
                .map(|(key, value)| Ok((key.clone(), value.clone())))
        }
    }

    /// Pins the exact message sequence for a small pair of mismatching sets.
    #[test]
    fn test_paper_1() {
        let alice_set = [("ape", 1), ("eel", 1), ("fox", 1), ("gnu", 1)];
        let bob_set = [
            ("bee", 1),
            ("cat", 1),
            ("doe", 1),
            ("eel", 1),
            ("fox", 1),
            ("hog", 1),
        ];
        let res = sync(&alice_set, &bob_set);
        res.print_messages();
        assert_eq!(res.alice_to_bob.len(), 3, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 2, "B -> A message count");
        assert_eq!(res.alice_to_bob[0].parts.len(), 1);
        assert!(res.alice_to_bob[0].parts[0].is_range_fingerprint());
        assert_eq!(res.bob_to_alice[0].parts.len(), 2);
        assert!(res.bob_to_alice[0].parts[0].is_range_fingerprint());
        assert!(res.bob_to_alice[0].parts[1].is_range_fingerprint());
        assert_eq!(res.alice_to_bob[1].parts.len(), 3);
        assert!(res.alice_to_bob[1].parts[0].is_range_fingerprint());
        assert!(res.alice_to_bob[1].parts[1].is_range_fingerprint());
        assert!(res.alice_to_bob[1].parts[2].is_range_item());
        assert_eq!(res.bob_to_alice[1].parts.len(), 2);
        assert!(res.bob_to_alice[1].parts[0].is_range_item());
        assert!(res.bob_to_alice[1].parts[1].is_range_item());
    }

    #[test]
    fn test_paper_2() {
        let alice_set = [
            ("ape", 1),
            ("bee", 1),
            ("cat", 1),
            ("doe", 1),
            ("eel", 1),
            ("fox", 1),
            ("gnu", 1),
            ("hog", 1),
        ];
        let bob_set = [
            ("ape", 1),
            ("bee", 1),
            ("cat", 1),
            ("doe", 1),
            ("eel", 1),
            ("gnu", 1),
            ("hog", 1),
        ];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 3, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 2, "B -> A message count");
    }

    #[test]
    fn test_paper_3() {
        let alice_set = [
            ("ape", 1),
            ("bee", 1),
            ("cat", 1),
            ("doe", 1),
            ("eel", 1),
            ("fox", 1),
            ("gnu", 1),
            ("hog", 1),
        ];
        let bob_set = [("ape", 1), ("cat", 1), ("eel", 1), ("gnu", 1)];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 3, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 2, "B -> A message count");
    }

    #[test]
    fn test_limits() {
        let alice_set = [("ape", 1), ("bee", 1), ("cat", 1)];
        let bob_set = [("ape", 1), ("cat", 1), ("doe", 1)];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 2, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 2, "B -> A message count");
    }

    #[test]
    fn test_prefixes_simple() {
        let alice_set = [("/foo/bar", 1), ("/foo/baz", 1), ("/foo/cat", 1)];
        let bob_set = [("/foo/bar", 1), ("/alice/bar", 1), ("/alice/baz", 1)];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 2, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 2, "B -> A message count");
    }

    #[test]
    fn test_prefixes_empty_alice() {
        let alice_set = [];
        let bob_set = [("/foo/bar", 1), ("/alice/bar", 1), ("/alice/baz", 1)];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 1, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 1, "B -> A message count");
    }

    #[test]
    fn test_prefixes_empty_bob() {
        let alice_set = [("/foo/bar", 1), ("/foo/baz", 1), ("/foo/cat", 1)];
        let bob_set = [];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 2, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 1, "B -> A message count");
    }

    #[test]
    fn test_equal_key_higher_value() {
        let alice_set = [("foo", 2)];
        let bob_set = [("foo", 1)];
        let res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 2, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 1, "B -> A message count");
    }

    /// Keys made of an author plus a byte key; prefixes only relate keys of the same author.
    #[test]
    fn test_multikey() {
        #[derive(Default, Clone, PartialEq, Eq, PartialOrd, Ord)]
        struct Multikey {
            author: [u8; 4],
            key: Vec<u8>,
        }

        impl RangeKey for Multikey {
            /// `self` is a prefix of `other` if both have the same author and `other.key`
            /// starts with `self.key` (same direction as the `&str`/`String` impls).
            fn is_prefix_of(&self, other: &Self) -> bool {
                self.author == other.author && other.key.starts_with(&self.key)
            }
        }

        impl Debug for Multikey {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let key = if let Ok(key) = std::str::from_utf8(&self.key) {
                    key.to_string()
                } else {
                    hex::encode(&self.key)
                };
                f.debug_struct("Multikey")
                    .field("author", &hex::encode(self.author))
                    .field("key", &key)
                    .finish()
            }
        }

        impl Multikey {
            fn new(author: [u8; 4], key: impl AsRef<[u8]>) -> Self {
                Multikey {
                    author,
                    key: key.as_ref().to_vec(),
                }
            }
        }

        let author_a = [1u8; 4];
        let author_b = [2u8; 4];
        let alice_set = [
            (Multikey::new(author_a, "ape"), 1),
            (Multikey::new(author_a, "bee"), 1),
            (Multikey::new(author_b, "bee"), 1),
            (Multikey::new(author_a, "doe"), 1),
        ];
        let bob_set = [
            (Multikey::new(author_a, "ape"), 1),
            (Multikey::new(author_a, "bee"), 1),
            (Multikey::new(author_a, "cat"), 1),
            (Multikey::new(author_b, "cat"), 1),
        ];
        let mut res = sync(&alice_set, &bob_set);
        assert_eq!(res.alice_to_bob.len(), 2, "A -> B message count");
        assert_eq!(res.bob_to_alice.len(), 2, "B -> A message count");
        res.assert_alice_set(
            "no limit",
            &[
                (Multikey::new(author_a, "ape"), 1),
                (Multikey::new(author_a, "bee"), 1),
                (Multikey::new(author_b, "bee"), 1),
                (Multikey::new(author_a, "doe"), 1),
                (Multikey::new(author_a, "cat"), 1),
                (Multikey::new(author_b, "cat"), 1),
            ],
        );
        res.assert_bob_set(
            "no limit",
            &[
                (Multikey::new(author_a, "ape"), 1),
                (Multikey::new(author_a, "bee"), 1),
                (Multikey::new(author_b, "bee"), 1),
                (Multikey::new(author_a, "doe"), 1),
                (Multikey::new(author_a, "cat"), 1),
                (Multikey::new(author_b, "cat"), 1),
            ],
        );
    }

    /// Validators that reject everything: both stores stay unchanged, and each side is offered
    /// exactly the other side's entries.
    #[test]
    fn test_validate_cb() {
        let alice_set = [("alice1", 1), ("alice2", 2)];
        let bob_set = [("bob1", 3), ("bob2", 4), ("bob3", 5)];
        let alice_validate_set = Rc::new(RefCell::new(vec![]));
        let bob_validate_set = Rc::new(RefCell::new(vec![]));

        let validate_alice: ValidateCb<&str, i32> = Box::new({
            let alice_validate_set = alice_validate_set.clone();
            move |_, e, _| {
                alice_validate_set.borrow_mut().push(*e);
                false
            }
        });
        let validate_bob: ValidateCb<&str, i32> = Box::new({
            let bob_validate_set = bob_validate_set.clone();
            move |_, e, _| {
                bob_validate_set.borrow_mut().push(*e);
                false
            }
        });

        let mut alice = SimpleStore::default();
        for (k, v) in alice_set {
            alice.put((k, v)).unwrap();
        }

        let mut bob = SimpleStore::default();
        for (k, v) in bob_set {
            bob.put((k, v)).unwrap();
        }

        let mut res = sync_exchange_messages(alice, bob, &validate_alice, &validate_bob, 100);
        res.assert_alice_set("unchanged", &alice_set);
        res.assert_bob_set("unchanged", &bob_set);
        assert_eq!(alice_validate_set.take(), bob_set);
        assert_eq!(bob_validate_set.take(), alice_set);
    }

    /// Final stores and the exchanged messages of a simulated sync.
    struct SyncResult<K, V>
    where
        K: RangeKey + Default,
        V: RangeValue,
    {
        alice: SimpleStore<K, V>,
        bob: SimpleStore<K, V>,
        alice_to_bob: Vec<Message<(K, V)>>,
        bob_to_alice: Vec<Message<(K, V)>>,
    }

    impl<K, V> SyncResult<K, V>
    where
        K: RangeKey + Default,
        V: RangeValue,
    {
        /// Prints the messages interleaved in the order they were exchanged.
        fn print_messages(&self) {
            let len = std::cmp::max(self.alice_to_bob.len(), self.bob_to_alice.len());
            for i in 0..len {
                if let Some(msg) = self.alice_to_bob.get(i) {
                    println!("A -> B:");
                    print_message(msg);
                }
                if let Some(msg) = self.bob_to_alice.get(i) {
                    println!("B -> A:");
                    print_message(msg);
                }
            }
        }

        /// Asserts that alice's store holds exactly `expected`.
        fn assert_alice_set(&mut self, ctx: &str, expected: &[(K, V)]) {
            dbg!(self.alice.all().unwrap().collect::<Vec<_>>());
            for e in expected {
                assert_eq!(
                    self.alice.get(e.key()).unwrap().as_ref(),
                    Some(e),
                    "{}: (alice) missing key {:?}",
                    ctx,
                    e.key()
                );
            }
            assert_eq!(
                expected.len(),
                self.alice.len().unwrap(),
                "{}: (alice)",
                ctx
            );
        }

        /// Asserts that bob's store holds exactly `expected`.
        fn assert_bob_set(&mut self, ctx: &str, expected: &[(K, V)]) {
            dbg!(self.bob.all().unwrap().collect::<Vec<_>>());
            for e in expected {
                assert_eq!(
                    self.bob.get(e.key()).unwrap().as_ref(),
                    Some(e),
                    "{}: (bob) missing key {:?}",
                    ctx,
                    e.key()
                );
            }
            assert_eq!(expected.len(), self.bob.len().unwrap(), "{}: (bob)", ctx);
        }
    }

    /// Prints every part of `msg` on its own line.
    fn print_message<E: RangeEntry>(msg: &Message<E>) {
        for part in &msg.parts {
            match part {
                MessagePart::RangeFingerprint(RangeFingerprint { range, fingerprint }) => {
                    println!(
                        "  RangeFingerprint({:?}, {:?}, {:?})",
                        range.x(),
                        range.y(),
                        fingerprint
                    );
                }
                MessagePart::RangeItem(RangeItem {
                    range,
                    values,
                    have_local,
                }) => {
                    println!(
                        "  RangeItem({:?} | {:?}) (local?: {})\n  {:?}",
                        range.x(),
                        range.y(),
                        have_local,
                        values,
                    );
                }
            }
        }
    }

    type ValidateCb<K, V> = Box<dyn Fn(&SimpleStore<K, V>, &(K, V), ContentStatus) -> bool>;

    /// Syncs two stores built from the given sets, accepting every entry, and asserts that
    /// both end up with the expected merged set.
    fn sync<K, V>(alice_set: &[(K, V)], bob_set: &[(K, V)]) -> SyncResult<K, V>
    where
        K: RangeKey + Default,
        V: RangeValue,
    {
        let alice_validate_cb: ValidateCb<K, V> = Box::new(|_, _, _| true);
        let bob_validate_cb: ValidateCb<K, V> = Box::new(|_, _, _| true);
        sync_with_validate_cb_and_assert(alice_set, bob_set, &alice_validate_cb, &bob_validate_cb)
    }

    /// Reference model of `Store::put` on a plain map.
    ///
    /// Skips `key` if an existing key that is a prefix of it has an equal or greater value.
    /// Otherwise, drops entries prefixed by `key` whose value is not greater, then inserts.
    fn insert_if_larger<K: RangeKey, V: RangeValue>(map: &mut BTreeMap<K, V>, key: K, value: V) {
        let dominated = map
            .iter()
            .any(|(k, v)| k.is_prefix_of(&key) && v >= &value);
        if !dominated {
            map.retain(|k, v| !(key.is_prefix_of(k) && value >= *v));
            map.insert(key, value);
        }
    }

    /// Runs a sync and checks three things:
    /// - both stores converge to the reference merge computed with [`insert_if_larger`];
    /// - neither side sends the same key twice;
    /// - the initial stores match the reference model.
    fn sync_with_validate_cb_and_assert<K, V, F1, F2>(
        alice_set: &[(K, V)],
        bob_set: &[(K, V)],
        alice_validate_cb: F1,
        bob_validate_cb: F2,
    ) -> SyncResult<K, V>
    where
        K: RangeKey + Default,
        V: RangeValue,
        F1: Fn(&SimpleStore<K, V>, &(K, V), ContentStatus) -> bool,
        F2: Fn(&SimpleStore<K, V>, &(K, V), ContentStatus) -> bool,
    {
        let mut alice = SimpleStore::<K, V>::default();
        let mut bob = SimpleStore::<K, V>::default();

        let expected_set = {
            let mut expected_set = BTreeMap::new();
            let mut alice_expected = BTreeMap::new();
            for e in alice_set {
                alice.put(e.clone()).unwrap();
                insert_if_larger(&mut expected_set, e.0.clone(), e.1.clone());
                insert_if_larger(&mut alice_expected, e.0.clone(), e.1.clone());
            }
            let alice_expected = alice_expected.into_iter().collect::<Vec<_>>();
            let alice_now: Vec<_> = alice.all().unwrap().collect::<Result<_, _>>().unwrap();
            assert_eq!(
                alice_expected, alice_now,
                "alice initial set does not match"
            );

            let mut bob_expected = BTreeMap::new();
            for e in bob_set {
                bob.put(e.clone()).unwrap();
                insert_if_larger(&mut expected_set, e.0.clone(), e.1.clone());
                insert_if_larger(&mut bob_expected, e.0.clone(), e.1.clone());
            }
            let bob_expected = bob_expected.into_iter().collect::<Vec<_>>();
            let bob_now: Vec<_> = bob.all().unwrap().collect::<Result<_, _>>().unwrap();
            assert_eq!(bob_expected, bob_now, "bob initial set does not match");

            expected_set.into_iter().collect::<Vec<_>>()
        };

        let mut res = sync_exchange_messages(alice, bob, alice_validate_cb, bob_validate_cb, 100);

        let alice_now: Vec<_> = res.alice.all().unwrap().collect::<Result<_, _>>().unwrap();
        if alice_now != expected_set {
            res.print_messages();
            println!("alice_init: {alice_set:?}");
            println!("bob_init:   {bob_set:?}");
            println!("expected:   {expected_set:?}");
            println!("alice_now:  {alice_now:?}");
            panic!("alice_now does not match expected");
        }

        let bob_now: Vec<_> = res.bob.all().unwrap().collect::<Result<_, _>>().unwrap();
        if bob_now != expected_set {
            res.print_messages();
            println!("alice_init: {alice_set:?}");
            println!("bob_init:   {bob_set:?}");
            println!("expected:   {expected_set:?}");
            println!("bob_now:    {bob_now:?}");
            panic!("bob_now does not match expected");
        }

        // No side should send the same key twice.
        let mut alice_sent = BTreeMap::new();
        for msg in &res.alice_to_bob {
            for part in &msg.parts {
                if let Some(values) = part.values() {
                    for (e, _) in values {
                        assert!(
                            alice_sent.insert(e.key(), e).is_none(),
                            "alice: duplicate {:?}",
                            e
                        );
                    }
                }
            }
        }

        let mut bob_sent = BTreeMap::new();
        for msg in &res.bob_to_alice {
            for part in &msg.parts {
                if let Some(values) = part.values() {
                    for (e, _) in values {
                        assert!(
                            bob_sent.insert(e.key(), e).is_none(),
                            "bob: duplicate {:?}",
                            e
                        );
                    }
                }
            }
        }

        res
    }

    /// Plays the protocol between `alice` (who starts) and `bob` until no reply is produced,
    /// panicking after `max_rounds` alice-to-bob messages.
    fn sync_exchange_messages<K, V, F1, F2>(
        mut alice: SimpleStore<K, V>,
        mut bob: SimpleStore<K, V>,
        alice_validate_cb: F1,
        bob_validate_cb: F2,
        max_rounds: usize,
    ) -> SyncResult<K, V>
    where
        K: RangeKey + Default,
        V: RangeValue,
        F1: Fn(&SimpleStore<K, V>, &(K, V), ContentStatus) -> bool,
        F2: Fn(&SimpleStore<K, V>, &(K, V), ContentStatus) -> bool,
    {
        let mut alice_to_bob = Vec::new();
        let mut bob_to_alice = Vec::new();
        let initial_message = alice.initial_message().unwrap();

        let mut next_to_bob = Some(initial_message);
        let mut rounds = 0;
        while let Some(msg) = next_to_bob.take() {
            assert!(rounds < max_rounds, "too many rounds");
            rounds += 1;
            alice_to_bob.push(msg.clone());

            if let Some(msg) = bob
                .process_message(
                    &Default::default(),
                    msg,
                    &bob_validate_cb,
                    |_, _, _| (),
                    |_, _| ContentStatus::Complete,
                )
                .unwrap()
            {
                bob_to_alice.push(msg.clone());
                next_to_bob = alice
                    .process_message(
                        &Default::default(),
                        msg,
                        &alice_validate_cb,
                        |_, _, _| (),
                        |_, _| ContentStatus::Complete,
                    )
                    .unwrap();
            }
        }

        SyncResult {
            alice,
            bob,
            alice_to_bob,
            bob_to_alice,
        }
    }

    /// `get_range` on plain, prefix-bounded and wrapping ranges.
    #[test]
    fn store_get_range() {
        let mut store = SimpleStore::<&'static str, i32>::default();
        let set = [
            ("bee", 1),
            ("cat", 1),
            ("doe", 1),
            ("eel", 1),
            ("fox", 1),
            ("hog", 1),
        ];
        for (k, v) in &set {
            store.entry_put((*k, *v)).unwrap();
        }

        let all: Vec<_> = store
            .get_range(Range::new("", ""))
            .unwrap()
            .collect::<Result<_, Infallible>>()
            .unwrap();
        assert_eq!(&all, &set[..]);

        let regular: Vec<_> = store
            .get_range(("bee", "eel").into())
            .unwrap()
            .collect::<Result<_, Infallible>>()
            .unwrap();
        assert_eq!(&regular, &set[..3]);

        let regular: Vec<_> = store
            .get_range(("", "eel").into())
            .unwrap()
            .collect::<Result<_, Infallible>>()
            .unwrap();
        assert_eq!(&regular, &set[..3]);

        let regular: Vec<_> = store
            .get_range(("cat", "hog").into())
            .unwrap()
            .collect::<Result<_, Infallible>>()
            .unwrap();
        assert_eq!(&regular, &set[1..5]);

        let excluded: Vec<_> = store
            .get_range(("fox", "bee").into())
            .unwrap()
            .collect::<Result<_, Infallible>>()
            .unwrap();
        assert_eq!(excluded[0].0, "fox");
        assert_eq!(excluded[1].0, "hog");
        assert_eq!(excluded.len(), 2);

        let excluded: Vec<_> = store
            .get_range(("fox", "doe").into())
            .unwrap()
            .collect::<Result<_, Infallible>>()
            .unwrap();
        assert_eq!(excluded.len(), 4);
        assert_eq!(excluded[0].0, "bee");
        assert_eq!(excluded[1].0, "cat");
        assert_eq!(excluded[2].0, "fox");
        assert_eq!(excluded[3].0, "hog");
    }

    type TestSetStringUnit = BTreeMap<String, ()>;
    type TestSetStringU8 = BTreeMap<String, u8>;

    fn test_key() -> impl Strategy<Value = String> {
        "[a-z0-9]{0,5}"
    }

    fn test_set_string_unit() -> impl Strategy<Value = TestSetStringUnit> {
        prop::collection::btree_map(test_key(), Just(()), 0..10)
    }

    fn test_set_string_u8() -> impl Strategy<Value = TestSetStringU8> {
        prop::collection::btree_map(test_key(), test_value_u8(), 0..10)
    }

    fn test_value_u8() -> impl Strategy<Value = u8> {
        0u8..u8::MAX
    }

    fn test_vec_string_unit() -> impl Strategy<Value = Vec<(String, ())>> {
        test_set_string_unit().prop_map(|m| m.into_iter().collect::<Vec<_>>())
    }

    fn test_vec_string_u8() -> impl Strategy<Value = Vec<(String, u8)>> {
        test_set_string_u8().prop_map(|m| m.into_iter().collect::<Vec<_>>())
    }

    fn test_range() -> impl Strategy<Value = Range<String>> {
        (test_key(), test_key()).prop_map(|(x, y)| Range::new(x, y))
    }

    fn mk_test_set(values: impl IntoIterator<Item = impl AsRef<str>>) -> TestSetStringUnit {
        values
            .into_iter()
            .map(|v| v.as_ref().to_string())
            .map(|k| (k, ()))
            .collect()
    }

    fn mk_test_vec(values: impl IntoIterator<Item = impl AsRef<str>>) -> Vec<(String, ())> {
        mk_test_set(values).into_iter().collect()
    }

    #[test]
    fn simple_store_sync_1() {
        let alice = mk_test_vec(["3"]);
        let bob = mk_test_vec(["2", "3", "4", "5", "6", "7", "8"]);
        let _res = sync(&alice, &bob);
    }

    #[test]
    fn simple_store_sync_x() {
        let alice = mk_test_vec(["1", "3"]);
        let bob = mk_test_vec(["2"]);
        let _res = sync(&alice, &bob);
    }

    #[test]
    fn simple_store_sync_2() {
        let alice = mk_test_vec(["1", "3"]);
        let bob = mk_test_vec(["0", "2", "3"]);
        let _res = sync(&alice, &bob);
    }

    #[test]
    fn simple_store_sync_3() {
        let alice = mk_test_vec(["8", "9"]);
        let bob = mk_test_vec(["1", "2", "3"]);
        let _res = sync(&alice, &bob);
    }

    #[proptest]
    fn simple_store_sync(
        #[strategy(test_vec_string_unit())] alice: Vec<(String, ())>,
        #[strategy(test_vec_string_unit())] bob: Vec<(String, ())>,
    ) {
        let _res = sync(&alice, &bob);
    }

    #[proptest]
    fn simple_store_sync_u8(
        #[strategy(test_vec_string_u8())] alice: Vec<(String, u8)>,
        #[strategy(test_vec_string_u8())] bob: Vec<(String, u8)>,
    ) {
        let _res = sync(&alice, &bob);
    }

    /// Fills a fresh `S` with `elems` and returns the `(expected, actual)` entries of `range`,
    /// both sorted by key. `expected` is computed directly with `Range::contains`.
    // The `(Vec<E>, Vec<E>)` return type trips clippy's type_complexity lint; it is clear
    // enough here.
    #[allow(clippy::type_complexity)]
    fn store_get_ranges_test<S, E>(
        elems: impl IntoIterator<Item = E>,
        range: Range<E::Key>,
    ) -> (Vec<E>, Vec<E>)
    where
        S: Store<E> + Default,
        E: RangeEntry,
    {
        let mut store = S::default();
        let elems = elems.into_iter().collect::<Vec<_>>();
        for e in elems.iter().cloned() {
            store.entry_put(e).unwrap();
        }
        let mut actual = store
            .get_range(range.clone())
            .unwrap()
            .collect::<std::result::Result<Vec<_>, S::Error>>()
            .unwrap();
        let mut expected = elems
            .into_iter()
            .filter(|e| range.contains(e.key()))
            .collect::<Vec<_>>();

        actual.sort_by(|a, b| a.key().cmp(b.key()));
        expected.sort_by(|a, b| a.key().cmp(b.key()));
        (expected, actual)
    }

    #[proptest]
    fn simple_store_get_ranges(
        #[strategy(test_set_string_unit())] contents: BTreeMap<String, ()>,
        #[strategy(test_range())] range: Range<String>,
    ) {
        let (expected, actual) = store_get_ranges_test::<SimpleStore<_, _>, _>(contents, range);
        prop_assert_eq!(expected, actual);
    }
}