// iroh_quinn_proto/token_memory_cache.rs
use std::{
    collections::{HashMap, VecDeque, hash_map},
    sync::{Arc, Mutex},
};

use bytes::Bytes;
use lru_slab::LruSlab;
use tracing::trace;

use crate::token::TokenStore;
/// A thread-safe, bounded, in-memory cache of address validation tokens,
/// keyed by server name with least-recently-used eviction.
///
/// All state lives behind a single [`Mutex`]; see [`State`] for the limits
/// and eviction behavior.
#[derive(Debug)]
pub struct TokenMemoryCache(Mutex<State>);
18
19impl TokenMemoryCache {
20 pub fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
22 Self(Mutex::new(State::new(
23 max_server_names,
24 max_tokens_per_server,
25 )))
26 }
27}
28
29impl TokenStore for TokenMemoryCache {
30 fn insert(&self, server_name: &str, token: Bytes) {
31 trace!(%server_name, "storing token");
32 self.0.lock().unwrap().store(server_name, token)
33 }
34
35 fn take(&self, server_name: &str) -> Option<Bytes> {
36 let token = self.0.lock().unwrap().take(server_name);
37 trace!(%server_name, found=%token.is_some(), "taking token");
38 token
39 }
40}
41
42impl Default for TokenMemoryCache {
44 fn default() -> Self {
45 Self::new(256, 2)
46 }
47}
48
/// Interior state of [`TokenMemoryCache`], guarded by its `Mutex`.
#[derive(Debug)]
struct State {
    // Maximum number of distinct server names tracked; 0 disables storage.
    max_server_names: u32,
    // Maximum tokens queued per server name; 0 disables storage.
    max_tokens_per_server: usize,
    // Maps a server name to its slot key in `lru`.
    lookup: HashMap<Arc<str>, u32>,
    // Recency-tracked slab of per-server token queues; the least-recently
    // used entry is evicted when `max_server_names` is exceeded.
    lru: LruSlab<CacheEntry>,
}
58
59impl State {
60 fn new(max_server_names: u32, max_tokens_per_server: usize) -> Self {
61 Self {
62 max_server_names,
63 max_tokens_per_server,
64 lookup: HashMap::new(),
65 lru: LruSlab::default(),
66 }
67 }
68
69 fn store(&mut self, server_name: &str, token: Bytes) {
70 if self.max_server_names == 0 {
71 return;
75 }
76 if self.max_tokens_per_server == 0 {
77 return;
81 }
82
83 let server_name = Arc::<str>::from(server_name);
84 match self.lookup.entry(server_name.clone()) {
85 hash_map::Entry::Occupied(hmap_entry) => {
86 let tokens = &mut self.lru.get_mut(*hmap_entry.get()).tokens;
88 if tokens.len() >= self.max_tokens_per_server {
89 debug_assert!(tokens.len() == self.max_tokens_per_server);
90 tokens.pop_front().unwrap();
91 }
92 tokens.push_back(token);
93 }
94 hash_map::Entry::Vacant(hmap_entry) => {
95 let removed_key = if self.lru.len() >= self.max_server_names {
97 Some(self.lru.remove(self.lru.lru().unwrap()).server_name)
100 } else {
101 None
102 };
103
104 hmap_entry.insert(self.lru.insert(CacheEntry::new(server_name, token)));
105
106 if let Some(removed_slot) = removed_key {
108 let removed = self.lookup.remove(&removed_slot);
109 debug_assert!(removed.is_some());
110 }
111 }
112 };
113 }
114
115 fn take(&mut self, server_name: &str) -> Option<Bytes> {
116 let slab_key = *self.lookup.get(server_name)?;
117
118 let entry = self.lru.get_mut(slab_key);
120 let token = entry.tokens.pop_front().unwrap();
122
123 if entry.tokens.is_empty() {
124 self.lru.remove(slab_key);
126 self.lookup.remove(server_name);
127 }
128
129 Some(token)
130 }
131}
132
/// Per-server-name cache entry stored in the LRU slab.
#[derive(Debug)]
struct CacheEntry {
    // Owned copy of the key in `State::lookup`, used to remove the matching
    // lookup entry when this slot is evicted from the slab.
    server_name: Arc<str>,
    // Tokens queued oldest-first; never empty while the entry exists.
    tokens: VecDeque<Bytes>,
}
140
141impl CacheEntry {
142 fn new(server_name: Arc<str>, token: Bytes) -> Self {
144 let mut tokens = VecDeque::new();
145 tokens.push_back(token);
146 Self {
147 server_name,
148 tokens,
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use super::*;
    use rand::prelude::*;
    use rand_pcg::Pcg32;
    use tracing::info;

    /// Fixed-seed RNG so the randomized test below is reproducible.
    fn new_rng() -> impl Rng {
        Pcg32::from_seed(0xdeadbeefdeadbeefdeadbeefdeadbeefu128.to_le_bytes())
    }

    /// Randomized differential test: drives `TokenMemoryCache` and a naive
    /// reference model with the same operation stream and asserts that every
    /// `take` agrees.
    ///
    /// The model (`cache_1`) is a `Vec` kept in recency order — index 0 is
    /// the least recently used name — pairing each server name with its
    /// FIFO token queue.
    #[test]
    fn cache_test() {
        let mut rng = new_rng();
        // Per-server token limit; mirrors the `2` passed to the real cache.
        const N: usize = 2;

        for _ in 0..10 {
            // Fresh model and cache under test (20 names, 2 tokens each).
            let mut cache_1: Vec<(u32, VecDeque<Bytes>)> = Vec::new();
            let cache_2 = TokenMemoryCache::new(20, 2);

            for i in 0..200 {
                // Names are drawn from a small space (0..10) so stores and
                // takes collide often.
                let server_name = rng.random::<u32>() % 10;
                if rng.random_bool(0.666) {
                    // Store a token whose payload identifies the iteration.
                    let token = Bytes::from(vec![i]);
                    info!("STORE {server_name} {token:?}");
                    if let Some((j, _)) = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                    {
                        // Known name: move it to the most-recent position
                        // and append, capping the queue at N tokens.
                        let (_, mut queue) = cache_1.remove(j);
                        queue.push_back(token.clone());
                        if queue.len() > N {
                            queue.pop_front();
                        }
                        cache_1.push((server_name, queue));
                    } else {
                        // New name: push as most recent, evicting the LRU
                        // name (index 0) if over the 20-name cap.
                        let mut queue = VecDeque::new();
                        queue.push_back(token.clone());
                        cache_1.push((server_name, queue));
                        if cache_1.len() > 20 {
                            cache_1.remove(0);
                        }
                    }
                    cache_2.insert(&server_name.to_string(), token);
                } else {
                    info!("TAKE {server_name}");
                    // Model take: pop the oldest token; if tokens remain,
                    // re-insert the name at the most-recent position,
                    // otherwise drop the name entirely.
                    let expecting = cache_1
                        .iter()
                        .enumerate()
                        .find(|&(_, &(server_name_2, _))| server_name_2 == server_name)
                        .map(|(j, _)| j)
                        .map(|j| {
                            let (_, mut queue) = cache_1.remove(j);
                            let token = queue.pop_front().unwrap();
                            if !queue.is_empty() {
                                cache_1.push((server_name, queue));
                            }
                            token
                        });
                    info!("EXPECTING {expecting:?}");
                    assert_eq!(cache_2.take(&server_name.to_string()), expecting);
                }
            }
        }
    }

    /// With zero server names allowed, inserts are no-ops and takes always
    /// miss.
    #[test]
    fn zero_max_server_names() {
        let cache = TokenMemoryCache::new(0, 2);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }

    /// With a zero-length per-server queue, inserts are no-ops and takes
    /// always miss.
    #[test]
    fn zero_queue_length() {
        let cache = TokenMemoryCache::new(256, 0);
        for i in 0..10 {
            cache.insert(&i.to_string(), Bytes::from(vec![i]));
            for j in 0..10 {
                assert!(cache.take(&j.to_string()).is_none());
            }
        }
    }
}
247}