uselesskey_core_cache/
lib.rs1#![forbid(unsafe_code)]
10#![warn(missing_docs)]
11#![cfg_attr(not(feature = "std"), no_std)]
12
13extern crate alloc;
14
15#[cfg(not(feature = "std"))]
16use alloc::collections::BTreeMap;
17use alloc::sync::Arc;
18use core::any::Any;
19use core::fmt;
20#[cfg(feature = "std")]
21use dashmap::DashMap;
22#[cfg(not(feature = "std"))]
23use spin::Mutex;
24
25use uselesskey_core_id::ArtifactId;
26
27type CacheValue = Arc<dyn Any + Send + Sync>;
28
29#[cfg(feature = "std")]
30type Cache = DashMap<ArtifactId, CacheValue>;
31
32#[cfg(not(feature = "std"))]
33type Cache = Mutex<BTreeMap<ArtifactId, CacheValue>>;
34
35pub struct ArtifactCache {
53 inner: Cache,
54}
55
56impl fmt::Debug for ArtifactCache {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 f.debug_struct("ArtifactCache")
59 .field("len", &self.len())
60 .finish()
61 }
62}
63
64impl ArtifactCache {
65 pub fn new() -> Self {
67 Self { inner: new_cache() }
68 }
69
70 pub fn len(&self) -> usize {
72 cache_len(&self.inner)
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.len() == 0
78 }
79
80 pub fn clear(&self) {
82 cache_clear(&self.inner);
83 }
84
85 pub fn get_typed<T>(&self, id: &ArtifactId) -> Option<Arc<T>>
89 where
90 T: Any + Send + Sync + 'static,
91 {
92 cache_get(&self.inner, id).map(|entry| downcast_or_panic::<T>(entry, id))
93 }
94
95 pub fn insert_if_absent_typed<T>(&self, id: ArtifactId, value: Arc<T>) -> Arc<T>
99 where
100 T: Any + Send + Sync + 'static,
101 {
102 let value_any: CacheValue = value;
103 let winner = cache_insert_if_absent(&self.inner, id.clone(), value_any);
104 downcast_or_panic::<T>(winner, &id)
105 }
106}
107
108impl Default for ArtifactCache {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[cfg(feature = "std")]
115fn new_cache() -> Cache {
116 DashMap::new()
117}
118
119#[cfg(not(feature = "std"))]
120fn new_cache() -> Cache {
121 Mutex::new(BTreeMap::new())
122}
123
124#[cfg(feature = "std")]
125fn cache_len(cache: &Cache) -> usize {
126 cache.len()
127}
128
129#[cfg(not(feature = "std"))]
130fn cache_len(cache: &Cache) -> usize {
131 cache.lock().len()
132}
133
134#[cfg(feature = "std")]
135fn cache_clear(cache: &Cache) {
136 cache.clear();
137}
138
139#[cfg(not(feature = "std"))]
140fn cache_clear(cache: &Cache) {
141 cache.lock().clear();
142}
143
144#[cfg(feature = "std")]
145fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
146 cache.get(id).map(|entry| entry.value().clone())
147}
148
149#[cfg(not(feature = "std"))]
150fn cache_get(cache: &Cache, id: &ArtifactId) -> Option<CacheValue> {
151 cache.lock().get(id).cloned()
152}
153
154#[cfg(feature = "std")]
155fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
156 cache.entry(id).or_insert(value).value().clone()
157}
158
159#[cfg(not(feature = "std"))]
160fn cache_insert_if_absent(cache: &Cache, id: ArtifactId, value: CacheValue) -> CacheValue {
161 use alloc::collections::btree_map::Entry;
162
163 let mut guard = cache.lock();
164 match guard.entry(id) {
165 Entry::Vacant(slot) => slot.insert(value).clone(),
166 Entry::Occupied(slot) => slot.get().clone(),
167 }
168}
169
170pub fn downcast_or_panic<T>(arc_any: CacheValue, id: &ArtifactId) -> Arc<T>
174where
175 T: Any + Send + Sync + 'static,
176{
177 match arc_any.downcast::<T>() {
178 Ok(v) => v,
179 Err(_) => {
180 panic!(
181 "uselesskey-core-cache: artifact type mismatch for domain={} label={} variant={}",
182 id.domain, id.label, id.variant
183 );
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::{ArtifactCache, downcast_or_panic};
    use core::any::Any;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::Arc;
    use uselesskey_core_id::{ArtifactId, DerivationVersion};

    fn sample_id() -> ArtifactId {
        ArtifactId::new(
            "domain:test",
            "label",
            b"spec",
            "good",
            DerivationVersion::V1,
        )
    }

    #[test]
    fn typed_round_trip() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let inserted = cache.insert_if_absent_typed(id.clone(), Arc::new(7u32));
        let fetched = cache
            .get_typed::<u32>(&id)
            .expect("value should be retrievable");

        assert_eq!(*inserted, 7);
        assert_eq!(*fetched, 7);
    }

    #[test]
    fn insert_if_absent_keeps_first_value() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        let first = cache.insert_if_absent_typed(id.clone(), Arc::new(11u32));
        let second = cache.insert_if_absent_typed(id, Arc::new(22u32));

        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(*second, 11u32);
    }

    #[test]
    fn clear_empties_cache() {
        let cache = ArtifactCache::new();
        let id = sample_id();

        cache.insert_if_absent_typed(id, Arc::new(1u8));
        assert_eq!(cache.len(), 1);
        assert!(!cache.is_empty());

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn debug_includes_type_name_and_len() {
        let cache = ArtifactCache::new();
        cache.insert_if_absent_typed(sample_id(), Arc::new(1u8));

        let dbg = format!("{cache:?}");
        assert!(
            dbg.contains("ArtifactCache"),
            "debug output should include struct name"
        );
        assert!(dbg.contains("len: 1"), "debug output should include len");
    }

    #[test]
    fn get_typed_type_mismatch_panics() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(123u32));

        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = cache.get_typed::<String>(&id);
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn insert_if_absent_typed_type_mismatch_panics_and_keeps_original() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(1u32));

        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = cache.insert_if_absent_typed(id.clone(), Arc::new(String::from("x")));
        }));

        assert!(result.is_err(), "expected panic when the stored type differs");
        // The rejected insert must not have replaced the existing entry.
        assert_eq!(cache.len(), 1);
        assert_eq!(*cache.get_typed::<u32>(&id).expect("original kept"), 1);
    }

    #[test]
    fn downcast_or_panic_type_mismatch_panics() {
        let id = sample_id();
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = downcast_or_panic::<String>(arc_any.clone(), &id);
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn downcast_or_panic_ok_returns_value() {
        let id = sample_id();
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(123u32);
        let arc = downcast_or_panic::<u32>(arc_any, &id);
        assert_eq!(*arc, 123u32);
    }

    #[test]
    fn default_creates_empty_cache() {
        let cache = ArtifactCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn get_typed_missing_key_returns_none() {
        let cache = ArtifactCache::new();
        let id = sample_id();
        assert!(cache.get_typed::<u32>(&id).is_none());
    }

    #[test]
    fn distinct_ids_are_stored_independently() {
        let cache = ArtifactCache::new();
        let id_a = ArtifactId::new("domain:a", "label", b"spec", "good", DerivationVersion::V1);
        let id_b = ArtifactId::new("domain:b", "label", b"spec", "good", DerivationVersion::V1);

        cache.insert_if_absent_typed(id_a.clone(), Arc::new(1u32));
        cache.insert_if_absent_typed(id_b.clone(), Arc::new(2u32));

        assert_eq!(cache.len(), 2);
        assert_eq!(*cache.get_typed::<u32>(&id_a).unwrap(), 1);
        assert_eq!(*cache.get_typed::<u32>(&id_b).unwrap(), 2);
    }

    #[test]
    fn concurrent_inserts_converge() {
        use std::thread;

        let cache = Arc::new(ArtifactCache::new());
        let id = sample_id();

        let handles: Vec<_> = (0u32..8)
            .map(|i| {
                let cache = Arc::clone(&cache);
                let id = id.clone();
                thread::spawn(move || cache.insert_if_absent_typed(id, Arc::new(i)))
            })
            .collect();

        let results: Vec<Arc<u32>> = handles
            .into_iter()
            .map(|h| h.join().expect("inserting thread panicked"))
            .collect();

        // Every racer must get back the very same allocation, not merely an equal value.
        let winner = &results[0];
        assert!(results.iter().all(|r| Arc::ptr_eq(r, winner)));
        assert!(**winner < 8, "winner must be one of the inserted values");
        assert_eq!(cache.len(), 1);

        // The cached entry is that same allocation.
        let cached = cache.get_typed::<u32>(&id).expect("winner should be cached");
        assert!(Arc::ptr_eq(&cached, winner));
    }

    #[test]
    fn downcast_or_panic_message_contains_id_fields() {
        let id = ArtifactId::new(
            "domain:msg",
            "my-label",
            b"spec",
            "my-variant",
            DerivationVersion::V1,
        );
        let arc_any: Arc<dyn Any + Send + Sync> = Arc::new(42u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = downcast_or_panic::<String>(arc_any.clone(), &id);
        }));
        let err = result.unwrap_err();
        let msg = err.downcast_ref::<String>().unwrap();
        assert!(msg.contains("domain:msg"), "panic should mention domain");
        assert!(msg.contains("my-label"), "panic should mention label");
        assert!(msg.contains("my-variant"), "panic should mention variant");
    }
}