iroh_quinn_proto/
token_memory_cache.rs

//! Storing tokens sent from servers in NEW_TOKEN frames and using them in subsequent connections
2
3use std::{
4    collections::{HashMap, VecDeque, hash_map},
5    sync::{Arc, Mutex},
6};
7
8use bytes::Bytes;
9use lru_slab::LruSlab;
10use tracing::trace;
11
12use crate::token::TokenStore;
13
/// `TokenStore` implementation that stores up to `N` tokens per server name for up to a
/// limited number of server names, in-memory
///
/// All state lives behind an internal [`Mutex`], so the cache is safe to share
/// across threads behind an `Arc`.
#[derive(Debug)]
pub struct TokenMemoryCache(Mutex<State>);
18
19impl TokenMemoryCache {
20    /// Construct empty
21    pub fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
22        Self(Mutex::new(State::new(
23            max_server_names,
24            max_tokens_per_server,
25        )))
26    }
27}
28
29impl TokenStore for TokenMemoryCache {
30    fn insert(&self, server_name: &str, token: Bytes) {
31        trace!(%server_name, "storing token");
32        self.0.lock().unwrap().store(server_name, token)
33    }
34
35    fn take(&self, server_name: &str) -> Option<Bytes> {
36        let token = self.0.lock().unwrap().take(server_name);
37        trace!(%server_name, found=%token.is_some(), "taking token");
38        token
39    }
40}
41
42/// Defaults to a maximum of 256 servers and 2 tokens per server
43impl Default for TokenMemoryCache {
44    fn default() -> Self {
45        Self::new(256, 2)
46    }
47}
48
/// Lockable inner state of `TokenMemoryCache`
#[derive(Debug)]
struct State {
    // maximum number of distinct server names kept before the least recently
    // used entry is evicted
    max_server_names: u32,
    // maximum token-queue length per server name; oldest tokens are dropped first
    max_tokens_per_server: usize,
    // map from server name to index in lru
    lookup: HashMap<Arc<str>, u32>,
    // LRU slab of per-server token queues; kept in sync with `lookup`
    lru: LruSlab<CacheEntry>,
}
58
59impl State {
60    fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
61        Self {
62            max_server_names,
63            max_tokens_per_server,
64            lookup: HashMap::new(),
65            lru: LruSlab::default(),
66        }
67    }
68
69    fn store(&mut self, server_name: &str, token: Bytes) {
70        if self.max_server_names == 0 {
71            // the rest of this method assumes that we can always insert a new entry so long as
72            // we're willing to evict a pre-existing entry. thus, an entry limit of 0 is an edge
73            // case we must short-circuit on now.
74            return;
75        }
76        if self.max_tokens_per_server == 0 {
77            // similarly to above, the rest of this method assumes that we can always push a new
78            // token to a queue so long as we're willing to evict a pre-existing token, so we
79            // short-circuit on the edge case of a token limit of 0.
80            return;
81        }
82
83        let server_name = Arc::<str>::from(server_name);
84        match self.lookup.entry(server_name.clone()) {
85            hash_map::Entry::Occupied(hmap_entry) => {
86                // key already exists, push the new token to its token queue
87                let tokens = &mut self.lru.get_mut(*hmap_entry.get()).tokens;
88                if tokens.len() >= self.max_tokens_per_server {
89                    debug_assert!(tokens.len() == self.max_tokens_per_server);
90                    tokens.pop_front().unwrap();
91                }
92                tokens.push_back(token);
93            }
94            hash_map::Entry::Vacant(hmap_entry) => {
95                // key does not yet exist, create a new one, evicting the oldest if necessary
96                let removed_key = if self.lru.len() >= self.max_server_names {
97                    // unwrap safety: max_server_names is > 0, so there's at least one entry, so
98                    //                lru() is some
99                    Some(self.lru.remove(self.lru.lru().unwrap()).server_name)
100                } else {
101                    None
102                };
103
104                hmap_entry.insert(self.lru.insert(CacheEntry::new(server_name, token)));
105
106                // for borrowing reasons, we must defer removing the evicted hmap entry to here
107                if let Some(removed_slot) = removed_key {
108                    let removed = self.lookup.remove(&removed_slot);
109                    debug_assert!(removed.is_some());
110                }
111            }
112        };
113    }
114
115    fn take(&mut self, server_name: &str) -> Option<Bytes> {
116        let slab_key = *self.lookup.get(server_name)?;
117
118        // pop from entry's token queue
119        let entry = self.lru.get_mut(slab_key);
120        // unwrap safety: we never leave tokens empty
121        let token = entry.tokens.pop_front().unwrap();
122
123        if entry.tokens.is_empty() {
124            // token stack emptied, remove entry
125            self.lru.remove(slab_key);
126            self.lookup.remove(server_name);
127        }
128
129        Some(token)
130    }
131}
132
/// Cache entry within `TokenMemoryCache`'s LRU slab
#[derive(Debug)]
struct CacheEntry {
    // owner's key in `State::lookup`, kept here so eviction can unlink it
    server_name: Arc<str>,
    // invariant: tokens is never empty
    tokens: VecDeque<Bytes>,
}
140
141impl CacheEntry {
142    /// Construct with a single token
143    fn new(server_name: Arc<str>, token: Bytes) -> Self {
144        let mut tokens = VecDeque::new();
145        tokens.push_back(token);
146        Self {
147            server_name,
148            tokens,
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;
    use tracing::info;

    /// Deterministic RNG so failures reproduce
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes())
    }

    /// Randomized differential test: drive a `TokenMemoryCache` and a simple
    /// reference model with the same operations and assert they agree.
    #[test]
    fn cache_test() {
        let mut rng = new_rng();
        // Capacity limits shared by the reference model and the cache under
        // test. Previously the model used `N` while the cache was constructed
        // with hard-coded literals (20, 2), so the two could silently drift.
        const MAX_SERVER_NAMES: u32 = 20;
        const N: usize = 2;

        for _ in 0..10 {
            // reference model: kept sorted least to most recently used
            let mut cache_1: Vec<(u32, VecDeque<Bytes>)> = Vec::new();
            let cache_2 = TokenMemoryCache::new(MAX_SERVER_NAMES, N);

            for i in 0..200 {
                let server_name = rng.random::<u32>() % 10;
                if rng.random_bool(0.666) {
                    // store
                    let token = Bytes::from(vec![i]);
                    info!("STORE {server_name} {token:?}");
                    if let Some((j, _)) = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                    {
                        // existing name: bump to most-recent and cap its queue
                        let (_, mut queue) = cache_1.remove(j);
                        queue.push_back(token.clone());
                        if queue.len() > N {
                            queue.pop_front();
                        }
                        cache_1.push((server_name, queue));
                    } else {
                        // new name: insert, evicting the least recently used
                        let mut queue = VecDeque::new();
                        queue.push_back(token.clone());
                        cache_1.push((server_name, queue));
                        if cache_1.len() > MAX_SERVER_NAMES as usize {
                            cache_1.remove(0);
                        }
                    }
                    cache_2.insert(&server_name.to_string(), token);
                } else {
                    // take
                    info!("TAKE {server_name}");
                    let expecting = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                        .map(|(j, _)| j)
                        .map(|j| {
                            // pop oldest token; drop the entry once empty,
                            // otherwise bump it to most-recent
                            let (_, mut queue) = cache_1.remove(j);
                            let token = queue.pop_front().unwrap();
                            if !queue.is_empty() {
                                cache_1.push((server_name, queue));
                            }
                            token
                        });
                    info!("EXPECTING {expecting:?}");
                    assert_eq!(cache_2.take(&server_name.to_string()), expecting);
                }
            }
        }
    }

    /// A server-name limit of 0 must store nothing and never panic
    #[test]
    fn zero_max_server_names() {
        let cache = TokenMemoryCache::new(0, 2);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }

    /// A per-server token limit of 0 must store nothing and never panic
    #[test]
    fn zero_queue_length() {
        let cache = TokenMemoryCache::new(256, 0);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }
}