Skip to main content

uselesskey_core_cache/
lib.rs

1//! Per-process artifact cache keyed by [`ArtifactId`].
2//!
3//! Stores generated fixtures behind `Arc<dyn Any>` so expensive key generation
4//! (especially RSA) only happens once per unique identity tuple. Thread-safe:
5//! uses `DashMap` with `std`, `spin::Mutex` without.
6//!
7//! The primary type is [`ArtifactCache`].
8
9#![forbid(unsafe_code)]
10#![warn(missing_docs)]
11#![cfg_attr(not(feature = "std"), no_std)]
12
13extern crate alloc;
14
15#[cfg(not(feature = "std"))]
16use alloc::collections::BTreeMap;
17use alloc::sync::Arc;
18use core::any::Any;
19use core::fmt;
20#[cfg(feature = "std")]
21use dashmap::DashMap;
22#[cfg(not(feature = "std"))]
23use spin::Mutex;
24
25use uselesskey_core_id::ArtifactId;
26
27type CacheValue = Arc<dyn Any + Send + Sync>;
28
29#[cfg(feature = "std")]
30type Cache = DashMap<ArtifactId, CacheValue>;
31
32#[cfg(not(feature = "std"))]
33type Cache = Mutex<BTreeMap<ArtifactId, CacheValue>>;
34
35/// Cache keyed by [`ArtifactId`] that stores typed values behind `Arc<dyn Any>`.
36///
37/// # Examples
38///
39/// ```
40/// use std::sync::Arc;
41/// use uselesskey_core_cache::ArtifactCache;
42/// use uselesskey_core_id::{ArtifactId, DerivationVersion};
43///
44/// let cache = ArtifactCache::new();
45/// let id = ArtifactId::new("domain:rsa", "issuer", b"RS256", "good", DerivationVersion::V1);
46///
47/// // Insert once, retrieve many times
48/// cache.insert_if_absent_typed(id.clone(), Arc::new(42u32));
49/// let value = cache.get_typed::<u32>(&id).unwrap();
50/// assert_eq!(*value, 42);
51/// ```
52pub struct ArtifactCache {
53    inner: Cache,
54}
55
56impl fmt::Debug for ArtifactCache {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        f.debug_struct("ArtifactCache")
59            .field("len", &self.len())
60            .finish()
61    }
62}
63
64impl ArtifactCache {
65    /// Create an empty artifact cache.
66    pub fn new() -> Self {
67        Self { inner: new_cache() }
68    }
69
70    /// Number of cache entries.
71    pub fn len(&self) -> usize {
72        cache_len(&self.inner)
73    }
74
75    /// Returns `true` when there are no entries.
76    pub fn is_empty(&self) -> bool {
77        self.len() == 0
78    }
79
80    /// Remove all entries from the cache.
81    pub fn clear(&self) {
82        cache_clear(&self.inner);
83    }
84
85    /// Retrieve a typed value by id.
86    ///
87    /// Panics if the id exists with a different concrete type.
88    pub fn get_typed<T>(&self, id: &ArtifactId) -> Option<Arc<T>>
89    where
90        T: Any + Send + Sync + 'static,
91    {
92        cache_get(&self.inner, id).map(|entry| downcast_or_panic::<T>(entry, id))
93    }
94
95    /// Insert a typed value if the id is vacant and return the winning cached value.
96    ///
97    /// Panics if an existing value for the same id has a different concrete type.
98    pub fn insert_if_absent_typed<T>(&self, id: ArtifactId, value: Arc<T>) -> Arc<T>
99    where
100        T: Any + Send + Sync + 'static,
101    {
102        let value_any: CacheValue = value;
103        let winner = cache_insert_if_absent(&self.inner, id.clone(), value_any);
104        downcast_or_panic::<T>(winner, &id)
105    }
106}
107
108impl Default for ArtifactCache {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[cfg(feature = "std")]
115fn new_cache() -> Cache {
116    DashMap::new()
117}
118
119#[cfg(not(feature = "std"))]
120fn new_cache() -> Cache {
121    Mutex::new(BTreeMap::new())
122}
123
124#[cfg(feature = "std")]
125fn cache_len(cache: &Cache) -> usize {
126    cache.len()
127}
128
129#[cfg(not(feature = "std"))]
130fn cache_len(cache: &Cache) -> usize {
131    cache.lock().len()
132}
133
134#[cfg(feature = "std")]
135fn cache_clear(cache: &Cache) {
136    cache.clear();
137}
138
139#[cfg(not(feature = "std"))]
140fn cache_clear(cache: &Cache) {
141    cache.lock().clear();
142}
143
144#[cfg(feature = "std")]
145fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
146    cache.get(id).map(|entry| entry.value().clone())
147}
148
149#[cfg(not(feature = "std"))]
150fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
151    cache.lock().get(id).cloned()
152}
153
154#[cfg(feature = "std")]
155fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
156    cache.entry(id).or_insert(value).value().clone()
157}
158
159#[cfg(not(feature = "std"))]
160fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
161    use alloc::collections::btree_map::Entry;
162
163    let mut guard = cache.lock();
164    match guard.entry(id) {
165        Entry::Vacant(slot) => slot.insert(value).clone(),
166        Entry::Occupied(slot) => slot.get().clone(),
167    }
168}
169
170/// Downcast a cached `Any` value to the expected fixture type.
171///
172/// Panics when the cache key maps to a different concrete type.
173pub fn downcast_or_panic<T>(arc_any: CacheValue, id: &ArtifactId) -> Arc<T>
174where
175    T: Any + Send + Sync + 'static,
176{
177    match arc_any.downcast::<T>() {
178        Ok(v) => v,
179        Err(_) => {
180            panic!(
181                "uselesskey-core-cache: artifact type mismatch for domain={} label={} variant={}",
182                id.domain, id.label, id.variant
183            );
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::{ArtifactCache, downcast_or_panic};
    use core::any::Any;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::Arc;
    use uselesskey_core_id::{ArtifactId, DerivationVersion};

    /// Fixed id for tests that need only a single cache key.
    fn sample_id() -> ArtifactId {
        ArtifactId::new(
            "domain:test",
            "label",
            b"spec",
            "good",
            DerivationVersion::V1,
        )
    }

    #[test]
    fn typed_round_trip() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let inserted = cache.insert_if_absent_typed(id.clone(), Arc::new(7u32));
        let fetched = cache
            .get_typed::<u32>(&id)
            .expect("value should be retrievable");

        assert_eq!(*inserted, 7);
        assert_eq!(*fetched, 7);
    }

    #[test]
    fn insert_if_absent_keeps_first_value() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let first = cache.insert_if_absent_typed(id.clone(), Arc::new(11u32));
        let second = cache.insert_if_absent_typed(id, Arc::new(22u32));

        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(*second, 11u32);
    }

    #[test]
    fn clear_empties_cache() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        cache.insert_if_absent_typed(id, Arc::new(1u8));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn clear_allows_reinsertion_and_keeps_outstanding_handles() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let first = cache.insert_if_absent_typed(id.clone(), Arc::new(1u32));
        cache.clear();

        // After clearing, the same id may hold a value of a different type.
        let replaced = cache.insert_if_absent_typed(id.clone(), Arc::new(String::from("fresh")));

        assert_eq!(replaced.as_str(), "fresh");
        assert_eq!(*first, 1, "handles handed out before clear stay valid");
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get_typed::<String>(&id).unwrap().as_str(), "fresh");
    }

    #[test]
    fn debug_includes_type_name_and_len() {
        let cache = ArtifactCache::new();
        cache.insert_if_absent_typed(sample_id(), Arc::new(1u8));

        let dbg = format!("{cache:?}");
        assert!(
            dbg.contains("ArtifactCache"),
            "debug output should include struct name"
        );
        assert!(dbg.contains("len: 1"), "debug output should include len");
    }

    #[test]
    fn get_typed_type_mismatch_panics() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(123u32));

        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = cache.get_typed::<String>(&id);
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn insert_if_absent_type_mismatch_panics_and_keeps_original() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(123u32));

        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(String::from("other")));
        }));

        assert!(
            result.is_err(),
            "expected panic when inserting a different type under an occupied id"
        );
        // The failed insert must not replace or remove the existing entry.
        assert_eq!(cache.len(), 1);
        assert_eq!(*cache.get_typed::<u32>(&id).unwrap(), 123);
    }

    #[test]
    fn downcast_or_panic_type_mismatch_panics() {
        let id = sample_id();
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = downcast_or_panic::<String>(arc_any.clone(), &id);
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn downcast_or_panic_ok_returns_value() {
        let id = sample_id();
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let arc = downcast_or_panic::<u32>(arc_any, &id);
        assert_eq!(*arc, 123u32);
    }

    #[test]
    fn default_creates_empty_cache() {
        let cache = ArtifactCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn get_typed_missing_key_returns_none() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        assert!(cache.get_typed::<u32>(&id).is_none());
    }

    #[test]
    fn distinct_ids_are_stored_independently() {
        let cache = ArtifactCache::new();
        let id_a = ArtifactId::new("domain:a", "label", b"spec", "good", DerivationVersion::V1);
        let id_b = ArtifactId::new("domain:b", "label", b"spec", "good", DerivationVersion::V1);

        cache.insert_if_absent_typed(id_a.clone(), Arc::new(1u32));
        cache.insert_if_absent_typed(id_b.clone(), Arc::new(2u32));

        assert_eq!(cache.len(), 2);
        assert_eq!(*cache.get_typed::<u32>(&id_a).unwrap(), 1);
        assert_eq!(*cache.get_typed::<u32>(&id_b).unwrap(), 2);
    }

    #[test]
    fn concurrent_inserts_converge() {
        use std::thread;

        let cache = Arc::new(ArtifactCache::new());
        let id = sample_id();

        let handles: Vec<_> = (0..8)
            .map(|i| {
                let cache = Arc::clone(&cache);
                let id = id.clone();
                thread::spawn(move || cache.insert_if_absent_typed(id, Arc::new(i as u32)))
            })
            .collect();

        let results: Vec<u32> = handles.into_iter().map(|h| *h.join().unwrap()).collect();

        // All threads must see the same winning value, and it must be one that was inserted.
        let first = results[0];
        assert!(first < 8, "winner must come from one of the inserting threads");
        assert!(results.iter().all(|v| *v == first));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn downcast_or_panic_message_contains_id_fields() {
        let id = ArtifactId::new(
            "domain:msg",
            "my-label",
            b"spec",
            "my-variant",
            DerivationVersion::V1,
        );
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(42u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = downcast_or_panic::<String>(arc_any.clone(), &id);
        }));
        let err = result.unwrap_err();
        let msg = err.downcast_ref::<String>().unwrap();
        assert!(msg.contains("domain:msg"), "panic should mention domain");
        assert!(msg.contains("my-label"), "panic should mention label");
        assert!(msg.contains("my-variant"), "panic should mention variant");
    }
}