uselesskey_core/srp/
cache.rs1#[cfg(not(feature = "std"))]
10use alloc::collections::BTreeMap;
11use alloc::sync::Arc;
12use core::any::Any;
13use core::fmt;
14#[cfg(feature = "std")]
15use dashmap::DashMap;
16#[cfg(not(feature = "std"))]
17use spin::Mutex;
18
19use crate::srp::identity::ArtifactId;
20
/// Type-erased, shareable handle to a cached artifact value.
///
/// Values are stored behind `dyn Any` so a single cache can hold artifacts of
/// different concrete types; the concrete type is recovered on read via
/// `downcast_or_panic`.
type CacheValue = Arc<dyn Any + Send + Sync>;

/// Backing store for the `std` build: a concurrent map from the `dashmap` crate.
#[cfg(feature = "std")]
type Cache = DashMap<ArtifactId, CacheValue>;

/// Backing store for `no_std` builds: an ordered map guarded by a `spin::Mutex`.
#[cfg(not(feature = "std"))]
type Cache = Mutex<BTreeMap<ArtifactId, CacheValue>>;
28
/// A thread-safe, type-erased cache of artifacts keyed by [`ArtifactId`].
///
/// Each entry is stored as an `Arc<dyn Any + Send + Sync>`. The typed accessors
/// (`get_typed`, `insert_if_absent_typed`) downcast on the way out and panic if
/// an entry is read back as a different type than the one it was stored with.
///
/// With the `std` feature the store is a `DashMap`; without it, a `BTreeMap`
/// behind a `spin::Mutex`.
pub struct ArtifactCache {
    // Feature-dependent backing map (see the `Cache` type alias).
    inner: Cache,
}
49
50impl fmt::Debug for ArtifactCache {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("ArtifactCache")
53 .field("len", &self.len())
54 .finish()
55 }
56}
57
58impl ArtifactCache {
59 pub fn new() -> Self {
61 Self { inner: new_cache() }
62 }
63
64 pub fn len(&self) -> usize {
66 cache_len(&self.inner)
67 }
68
69 pub fn is_empty(&self) -> bool {
71 self.len() == 0
72 }
73
74 pub fn clear(&self) {
76 cache_clear(&self.inner);
77 }
78
79 pub fn get_typed<T>(&self, id: &ArtifactId) -> Option<Arc<T>>
83 where
84 T: Any + Send + Sync + 'static,
85 {
86 cache_get(&self.inner, id).map(|entry| downcast_or_panic::<T>(entry, id))
87 }
88
89 pub fn insert_if_absent_typed<T>(&self, id: ArtifactId, value: Arc<T>) -> Arc<T>
93 where
94 T: Any + Send + Sync + 'static,
95 {
96 let value_any: CacheValue = value;
97 let winner = cache_insert_if_absent(&self.inner, id.clone(), value_any);
98 downcast_or_panic::<T>(winner, &id)
99 }
100}
101
102impl Default for ArtifactCache {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
/// Creates the empty `DashMap` that backs the `std` build.
#[cfg(feature = "std")]
fn new_cache() -> Cache {
    Cache::new()
}
112
113#[cfg(not(feature = "std"))]
114fn new_cache() -> Cache {
115 Mutex::new(BTreeMap::new())
116}
117
/// Entry count of the `std` (DashMap) store.
#[cfg(feature = "std")]
fn cache_len(cache: &Cache) -> usize {
    DashMap::len(cache)
}
122
123#[cfg(not(feature = "std"))]
124fn cache_len(cache: &Cache) -> usize {
125 cache.lock().len()
126}
127
/// Drops every entry from the `std` (DashMap) store.
#[cfg(feature = "std")]
fn cache_clear(cache: &Cache) {
    DashMap::clear(cache);
}
132
133#[cfg(not(feature = "std"))]
134fn cache_clear(cache: &Cache) {
135 cache.lock().clear();
136}
137
/// Returns a new `Arc` handle to the value stored under `id`, if any (`std` store).
#[cfg(feature = "std")]
fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
    let entry = cache.get(id)?;
    Some(Arc::clone(entry.value()))
}
142
143#[cfg(not(feature = "std"))]
144fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
145 cache.lock().get(id).cloned()
146}
147
/// Inserts `value` under `id` if the slot is vacant and returns a handle to
/// whichever value is stored afterwards (`std` store). The check and the
/// insert both go through a single `entry` call.
#[cfg(feature = "std")]
fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
    let stored = cache.entry(id).or_insert(value);
    Arc::clone(stored.value())
}
152
153#[cfg(not(feature = "std"))]
154fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
155 use alloc::collections::btree_map::Entry;
156
157 let mut guard = cache.lock();
158 match guard.entry(id) {
159 Entry::Vacant(slot) => slot.insert(value).clone(),
160 Entry::Occupied(slot) => slot.get().clone(),
161 }
162}
163
164pub fn downcast_or_panic<T>(arc_any: CacheValue, id: &ArtifactId) -> Arc<T>
168where
169 T: Any + Send + Sync + 'static,
170{
171 match arc_any.downcast::<T>() {
172 Ok(v) => v,
173 Err(_) => {
174 panic!(
175 "uselesskey-core-cache: artifact type mismatch for domain={} label={} variant={}",
176 id.domain, id.label, id.variant
177 );
178 }
179 }
180}
181
#[cfg(all(test, feature = "std"))]
mod tests {
    // Unit tests for the typed cache API and `downcast_or_panic`. Compiled only
    // with the `std` feature, i.e. against the DashMap-backed store.

    use super::{ArtifactCache, downcast_or_panic};
    use crate::srp::identity::{ArtifactId, DerivationVersion};
    use core::any::Any;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::Arc;

    /// Fixed identity shared by tests that only need a single key.
    fn sample_id() -> ArtifactId {
        ArtifactId::new("domain:test", "label", b"spec", "good", DerivationVersion::V1)
    }

    #[test]
    fn typed_round_trip() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let stored = cache.insert_if_absent_typed(id.clone(), Arc::new(7u32));
        assert_eq!(*stored, 7);

        let looked_up = cache.get_typed::<u32>(&id);
        assert_eq!(looked_up.as_deref(), Some(&7), "value should be retrievable");
    }

    #[test]
    fn insert_if_absent_keeps_first_value() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let winner = cache.insert_if_absent_typed(id.clone(), Arc::new(11u32));
        let second_attempt = cache.insert_if_absent_typed(id, Arc::new(22u32));

        assert_eq!(*second_attempt, 11u32);
        assert!(Arc::ptr_eq(&winner, &second_attempt));
    }

    #[test]
    fn clear_empties_cache() {
        let cache = ArtifactCache::new();
        let _ = cache.insert_if_absent_typed(sample_id(), Arc::new(1u8));
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn debug_includes_type_name_and_len() {
        let cache = ArtifactCache::new();
        let _ = cache.insert_if_absent_typed(sample_id(), Arc::new(1u8));

        let rendered = format!("{cache:?}");
        assert!(rendered.contains("len: 1"), "debug output should include len");
        assert!(
            rendered.contains("ArtifactCache"),
            "debug output should include struct name"
        );
    }

    #[test]
    fn get_typed_type_mismatch_panics() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(123u32));

        let outcome = catch_unwind(AssertUnwindSafe(|| cache.get_typed::<String>(&id)));
        assert!(outcome.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn downcast_or_panic_type_mismatch_panics() {
        let id = sample_id();
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(123u32);

        let outcome = catch_unwind(AssertUnwindSafe(|| downcast_or_panic::<String>(erased, &id)));
        assert!(outcome.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn downcast_or_panic_ok_returns_value() {
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let typed = downcast_or_panic::<u32>(erased, &sample_id());
        assert_eq!(*typed, 123u32);
    }

    #[test]
    fn default_creates_empty_cache() {
        let cache: ArtifactCache = Default::default();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn get_typed_missing_key_returns_none() {
        let cache = ArtifactCache::new();
        let lookup = cache.get_typed::<u32>(&sample_id());
        assert!(lookup.is_none());
    }

    #[test]
    fn distinct_ids_are_stored_independently() {
        let cache = ArtifactCache::new();
        let first_id = ArtifactId::new("domain:a", "label", b"spec", "good", DerivationVersion::V1);
        let second_id = ArtifactId::new("domain:b", "label", b"spec", "good", DerivationVersion::V1);

        for (id, value) in [(&first_id, 1u32), (&second_id, 2u32)] {
            let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(value));
        }

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get_typed::<u32>(&first_id).map(|v| *v), Some(1));
        assert_eq!(cache.get_typed::<u32>(&second_id).map(|v| *v), Some(2));
    }

    #[test]
    fn concurrent_inserts_converge() {
        use std::thread;

        let cache = Arc::new(ArtifactCache::new());
        let id = sample_id();

        let workers: Vec<thread::JoinHandle<Arc<u32>>> = (0..8u32)
            .map(|seed| {
                let shared = Arc::clone(&cache);
                let key = id.clone();
                thread::spawn(move || shared.insert_if_absent_typed(key, Arc::new(seed)))
            })
            .collect();

        let observed: Vec<u32> = workers
            .into_iter()
            .map(|handle| *handle.join().expect("worker thread panicked"))
            .collect();

        // Every thread must observe the same winning value, and only one entry exists.
        let expected = observed[0];
        assert!(observed.iter().all(|&v| v == expected));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn downcast_or_panic_message_contains_id_fields() {
        let id = ArtifactId::new(
            "domain:msg",
            "my-label",
            b"spec",
            "my-variant",
            DerivationVersion::V1,
        );
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(42u32);

        let payload = catch_unwind(AssertUnwindSafe(|| downcast_or_panic::<String>(erased, &id)))
            .expect_err("downcast to the wrong type should panic");
        let message = payload
            .downcast_ref::<String>()
            .expect("panic payload should be a formatted String");

        for needle in ["domain:msg", "my-label", "my-variant"] {
            assert!(message.contains(needle), "panic should mention {needle}");
        }
    }
}