Skip to main content

uselesskey_core/srp/
cache.rs

1//! Per-factory artifact cache keyed by [`ArtifactId`].
2//!
3//! Stores generated fixtures behind `Arc<dyn Any>` so expensive key generation
//! (especially RSA) only happens once per unique identity tuple. Thread-safe:
//! backed by a `DashMap` when the `std` feature is enabled, and by a
//! `spin::Mutex`-guarded `BTreeMap` otherwise.
6//!
7//! The primary type is [`ArtifactCache`].
8
9#[cfg(not(feature = "std"))]
10use alloc::collections::BTreeMap;
11use alloc::sync::Arc;
12use core::any::Any;
13use core::fmt;
14#[cfg(feature = "std")]
15use dashmap::DashMap;
16#[cfg(not(feature = "std"))]
17use spin::Mutex;
18
19use crate::srp::identity::ArtifactId;
20
/// Type-erased, reference-counted cache entry. Callers recover the concrete
/// type with [`downcast_or_panic`].
type CacheValue = Arc<dyn Any + Send + Sync>;

/// Backing store when the `std` feature is enabled: a concurrent `DashMap`.
#[cfg(feature = "std")]
type Cache = DashMap<ArtifactId, CacheValue>;

/// Backing store without `std`: a `BTreeMap` guarded by a `spin::Mutex`.
#[cfg(not(feature = "std"))]
type Cache = Mutex<BTreeMap<ArtifactId, CacheValue>>;
28
/// Cache keyed by [`ArtifactId`] that stores typed values behind `Arc<dyn Any>`.
///
/// Every method takes `&self`. Synchronization comes from the backing store:
/// `DashMap` with the `std` feature, or a `spin::Mutex` around a `BTreeMap`
/// otherwise. A single cache can therefore be shared (e.g. via `Arc`)
/// between threads.
///
/// Lookups hand out clones of the stored `Arc`, never copies of the value.
///
/// Each id is expected to map to exactly one concrete type. Asking for a
/// different type than the one stored under an id panics.
///
/// # Examples
///
/// ```
/// use std::sync::Arc;
/// use uselesskey_core::srp::cache::ArtifactCache;
/// use uselesskey_core::srp::identity::{ArtifactId, DerivationVersion};
///
/// let cache = ArtifactCache::new();
/// let id = ArtifactId::new("domain:rsa", "issuer", b"RS256", "good", DerivationVersion::V1);
///
/// // Insert once, retrieve many times
/// cache.insert_if_absent_typed(id.clone(), Arc::new(42u32));
/// let value = cache.get_typed::<u32>(&id).unwrap();
/// assert_eq!(*value, 42);
/// ```
pub struct ArtifactCache {
    // Platform-specific backing store; see the `Cache` type alias.
    inner: Cache,
}
49
50impl fmt::Debug for ArtifactCache {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("ArtifactCache")
53            .field("len", &self.len())
54            .finish()
55    }
56}
57
58impl ArtifactCache {
59    /// Create an empty artifact cache.
60    pub fn new() -> Self {
61        Self { inner: new_cache() }
62    }
63
64    /// Number of cache entries.
65    pub fn len(&self) -> usize {
66        cache_len(&self.inner)
67    }
68
69    /// Returns `true` when there are no entries.
70    pub fn is_empty(&self) -> bool {
71        self.len() == 0
72    }
73
74    /// Remove all entries from the cache.
75    pub fn clear(&self) {
76        cache_clear(&self.inner);
77    }
78
79    /// Retrieve a typed value by id.
80    ///
81    /// Panics if the id exists with a different concrete type.
82    pub fn get_typed<T>(&self, id: &ArtifactId) -> Option<Arc<T>>
83    where
84        T: Any + Send + Sync + 'static,
85    {
86        cache_get(&self.inner, id).map(|entry| downcast_or_panic::<T>(entry, id))
87    }
88
89    /// Insert a typed value if the id is vacant and return the winning cached value.
90    ///
91    /// Panics if an existing value for the same id has a different concrete type.
92    pub fn insert_if_absent_typed<T>(&self, id: ArtifactId, value: Arc<T>) -> Arc<T>
93    where
94        T: Any + Send + Sync + 'static,
95    {
96        let value_any: CacheValue = value;
97        let winner = cache_insert_if_absent(&self.inner, id.clone(), value_any);
98        downcast_or_panic::<T>(winner, &id)
99    }
100}
101
102impl Default for ArtifactCache {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108#[cfg(feature = "std")]
109fn new_cache() -> Cache {
110    DashMap::new()
111}
112
113#[cfg(not(feature = "std"))]
114fn new_cache() -> Cache {
115    Mutex::new(BTreeMap::new())
116}
117
118#[cfg(feature = "std")]
119fn cache_len(cache: &Cache) -> usize {
120    cache.len()
121}
122
123#[cfg(not(feature = "std"))]
124fn cache_len(cache: &Cache) -> usize {
125    cache.lock().len()
126}
127
128#[cfg(feature = "std")]
129fn cache_clear(cache: &Cache) {
130    cache.clear();
131}
132
133#[cfg(not(feature = "std"))]
134fn cache_clear(cache: &Cache) {
135    cache.lock().clear();
136}
137
138#[cfg(feature = "std")]
139fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
140    cache.get(id).map(|entry| entry.value().clone())
141}
142
143#[cfg(not(feature = "std"))]
144fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
145    cache.lock().get(id).cloned()
146}
147
148#[cfg(feature = "std")]
149fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
150    cache.entry(id).or_insert(value).value().clone()
151}
152
153#[cfg(not(feature = "std"))]
154fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
155    use alloc::collections::btree_map::Entry;
156
157    let mut guard = cache.lock();
158    match guard.entry(id) {
159        Entry::Vacant(slot) => slot.insert(value).clone(),
160        Entry::Occupied(slot) => slot.get().clone(),
161    }
162}
163
164/// Downcast a cached `Any` value to the expected fixture type.
165///
166/// Panics when the cache key maps to a different concrete type.
167pub fn downcast_or_panic<T>(arc_any: CacheValue, id: &ArtifactId) -> Arc<T>
168where
169    T: Any + Send + Sync + 'static,
170{
171    match arc_any.downcast::<T>() {
172        Ok(v) => v,
173        Err(_) => {
174            panic!(
175                "uselesskey-core-cache: artifact type mismatch for domain={} label={} variant={}",
176                id.domain, id.label, id.variant
177            );
178        }
179    }
180}
181
// Unit tests for `ArtifactCache` and `downcast_or_panic` (std-only because they
// rely on `catch_unwind` and `std::thread`).
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{ArtifactCache, downcast_or_panic};
    use crate::srp::identity::{ArtifactId, DerivationVersion};
    use core::any::Any;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::Arc;

    fn sample_id() -> ArtifactId {
        ArtifactId::new(
            "domain:test",
            "label",
            b"spec",
            "good",
            DerivationVersion::V1,
        )
    }

    #[test]
    fn typed_round_trip() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let inserted = cache.insert_if_absent_typed(id.clone(), Arc::new(7u32));
        let fetched = cache
            .get_typed::<u32>(&id)
            .expect("value should be retrievable");

        assert_eq!(*inserted, 7);
        assert_eq!(*fetched, 7);
    }

    #[test]
    fn insert_if_absent_keeps_first_value() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let first = cache.insert_if_absent_typed(id.clone(), Arc::new(11u32));
        let second = cache.insert_if_absent_typed(id, Arc::new(22u32));

        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(*second, 11u32);
    }

    #[test]
    fn clear_empties_cache() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        cache.insert_if_absent_typed(id.clone(), Arc::new(1u8));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        // Cleared entries must no longer be retrievable.
        assert!(cache.get_typed::<u8>(&id).is_none());
    }

    #[test]
    fn debug_includes_type_name_and_len() {
        let cache = ArtifactCache::new();
        cache.insert_if_absent_typed(sample_id(), Arc::new(1u8));

        let dbg = format!("{cache:?}");
        assert!(
            dbg.contains("ArtifactCache"),
            "debug output should include struct name"
        );
        assert!(dbg.contains("len: 1"), "debug output should include len");
    }

    #[test]
    fn get_typed_type_mismatch_panics() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(123u32));

        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = cache.get_typed::<String>(&id);
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn insert_if_absent_typed_type_mismatch_panics() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(123u32));

        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(String::from("other")));
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
        // The original entry must survive the rejected insert.
        assert_eq!(cache.len(), 1);
        assert_eq!(*cache.get_typed::<u32>(&id).unwrap(), 123);
    }

    #[test]
    fn downcast_or_panic_type_mismatch_panics() {
        let id = sample_id();
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = downcast_or_panic::<String>(arc_any.clone(), &id);
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn downcast_or_panic_ok_returns_value() {
        let id = sample_id();
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let arc = downcast_or_panic::<u32>(arc_any, &id);
        assert_eq!(*arc, 123u32);
    }

    #[test]
    fn default_creates_empty_cache() {
        let cache = ArtifactCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn get_typed_missing_key_returns_none() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        assert!(cache.get_typed::<u32>(&id).is_none());
    }

    #[test]
    fn distinct_ids_are_stored_independently() {
        let cache = ArtifactCache::new();
        let id_a = ArtifactId::new("domain:a", "label", b"spec", "good", DerivationVersion::V1);
        let id_b = ArtifactId::new("domain:b", "label", b"spec", "good", DerivationVersion::V1);

        cache.insert_if_absent_typed(id_a.clone(), Arc::new(1u32));
        cache.insert_if_absent_typed(id_b.clone(), Arc::new(2u32));

        assert_eq!(cache.len(), 2);
        assert_eq!(*cache.get_typed::<u32>(&id_a).unwrap(), 1);
        assert_eq!(*cache.get_typed::<u32>(&id_b).unwrap(), 2);
    }

    #[test]
    fn concurrent_inserts_converge() {
        use std::thread;

        let cache = Arc::new(ArtifactCache::new());
        let id = sample_id();

        let handles: Vec<_> = (0..8u32)
            .map(|i| {
                let cache = Arc::clone(&cache);
                let id = id.clone();
                thread::spawn(move || cache.insert_if_absent_typed(id, Arc::new(i)))
            })
            .collect();

        let winners: Vec<Arc<u32>> = handles
            .into_iter()
            .map(|h| h.join().expect("inserting thread panicked"))
            .collect();

        // All threads must see the very same stored allocation, not merely
        // equal values...
        let first = &winners[0];
        assert!(winners.iter().all(|w| Arc::ptr_eq(w, first)));
        // ...and that allocation must be what later lookups return.
        let stored = cache
            .get_typed::<u32>(&id)
            .expect("winning value should be cached");
        assert!(Arc::ptr_eq(&stored, first));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn downcast_or_panic_message_contains_id_fields() {
        let id = ArtifactId::new(
            "domain:msg",
            "my-label",
            b"spec",
            "my-variant",
            DerivationVersion::V1,
        );
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(42u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = downcast_or_panic::<String>(arc_any.clone(), &id);
        }));
        let err = result.unwrap_err();
        let msg = err.downcast_ref::<String>().unwrap();
        assert!(msg.contains("domain:msg"), "panic should mention domain");
        assert!(msg.contains("my-label"), "panic should mention label");
        assert!(msg.contains("my-variant"), "panic should mention variant");
    }
}