1#![forbid(unsafe_code)]
2#![cfg_attr(not(feature = "std"), no_std)]
3extern crate alloc;
10
11use alloc::string::ToString;
12use alloc::sync::Arc;
13use core::fmt;
14
15#[cfg(feature = "std")]
16use rand10::TryRng;
17#[cfg(feature = "std")]
18use rand10::rngs::SysRng;
19use uselesskey_core_cache::ArtifactCache;
20use uselesskey_core_id::{ArtifactDomain, ArtifactId, DerivationVersion, Seed, derive_seed};
21
22#[derive(Clone, Debug)]
24pub enum Mode {
25 Random,
27
28 Deterministic { master: Seed },
30}
31
/// State shared by every clone of a [`Factory`] (held behind an `Arc`).
struct Inner {
    /// Seed-selection strategy, fixed when the factory is constructed.
    mode: Mode,
    /// Typed artifact cache keyed by `ArtifactId`.
    cache: ArtifactCache,
}
36
/// Handle that builds and memoizes artifacts, seeding them per its [`Mode`].
///
/// Cloning only clones the inner `Arc`, so all clones share one mode and one cache.
#[derive(Clone)]
pub struct Factory {
    inner: Arc<Inner>,
}
44
45impl fmt::Debug for Factory {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 f.debug_struct("Factory")
48 .field("mode", &self.inner.mode)
49 .field("cache_size", &self.inner.cache.len())
50 .finish()
51 }
52}
53
54impl Factory {
55 pub fn new(mode: Mode) -> Self {
57 Self {
58 inner: Arc::new(Inner {
59 mode,
60 cache: ArtifactCache::new(),
61 }),
62 }
63 }
64
65 pub fn random() -> Self {
67 Self::new(Mode::Random)
68 }
69
70 pub fn deterministic(master: Seed) -> Self {
72 Self::new(Mode::Deterministic { master })
73 }
74
75 pub fn mode(&self) -> &Mode {
77 &self.inner.mode
78 }
79
80 pub fn clear_cache(&self) {
82 self.inner.cache.clear();
83 }
84
85 pub fn get_or_init<T, F>(
90 &self,
91 domain: ArtifactDomain,
92 label: &str,
93 spec_bytes: &[u8],
94 variant: &str,
95 init: F,
96 ) -> Arc<T>
97 where
98 T: core::any::Any + Send + Sync + 'static,
99 F: FnOnce(Seed) -> T,
100 {
101 let id = ArtifactId::new(
102 domain,
103 label.to_string(),
104 spec_bytes,
105 variant.to_string(),
106 DerivationVersion::V1,
107 );
108
109 if let Some(entry) = self.inner.cache.get_typed::<T>(&id) {
110 return entry;
111 }
112
113 let seed = self.seed_for(&id);
114 let value = init(seed);
115 let arc: Arc<T> = Arc::new(value);
116
117 self.inner.cache.insert_if_absent_typed(id, arc)
118 }
119
120 fn seed_for(&self, id: &ArtifactId) -> Seed {
121 match &self.inner.mode {
122 Mode::Random => random_seed(),
123 Mode::Deterministic { master } => derive_seed(master, id),
124 }
125 }
126}
127
/// Draws a fresh 32-byte seed from the operating system's RNG (`SysRng`).
///
/// # Panics
///
/// Panics if the OS randomness source reports an error.
#[cfg(feature = "std")]
pub(crate) fn random_seed() -> Seed {
    let mut entropy = [0u8; 32];
    let mut os_rng = SysRng;
    os_rng
        .try_fill_bytes(&mut entropy)
        .expect("failed to read operating-system randomness");
    Seed::new(entropy)
}
136
/// Fallback used when the `std` feature is disabled: there is no randomness
/// source, so any request for a random seed fails loudly.
///
/// # Panics
///
/// Always panics. A `Factory` in `Mode::Random` therefore panics on its first
/// cache miss in `no_std` builds.
#[cfg(not(feature = "std"))]
pub(crate) fn random_seed() -> Seed {
    panic!("uselesskey-core-factory: Mode::Random requires the `std` feature")
}
141
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{Factory, Mode, random_seed};
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use uselesskey_core_id::Seed;

    /// Fills 8 bytes from `seed` and reads them as a little-endian `u64`, giving
    /// tests a cheap value that depends on the seed.
    fn draw_u64(seed: Seed) -> u64 {
        let mut bytes = [0u8; 8];
        seed.fill_bytes(&mut bytes);
        u64::from_le_bytes(bytes)
    }

    #[test]
    fn clear_cache_forces_reinit() {
        let fx = Factory::random();
        let hits = AtomicUsize::new(0);

        let first = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
            hits.fetch_add(1, Ordering::SeqCst);
            42u8
        });

        assert_eq!(hits.load(Ordering::SeqCst), 1);
        let second = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
            hits.fetch_add(1, Ordering::SeqCst);
            99u8
        });
        assert!(Arc::ptr_eq(&first, &second));

        fx.clear_cache();
        let third = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
            hits.fetch_add(1, Ordering::SeqCst);
            44u8
        });

        assert_eq!(hits.load(Ordering::SeqCst), 2);
        assert!(!Arc::ptr_eq(&first, &third));
    }

    #[test]
    fn clear_cache_empties_cache() {
        let fx = Factory::random();
        let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| 1u8);
        let _ = fx.get_or_init("domain:test", "label", b"spec", "bad", |_rng| 2u8);
        fx.clear_cache();
        let dbg = format!("{:?}", fx);
        assert!(dbg.contains("cache_size: 0"), "after clear: {dbg}");
    }

    #[test]
    fn get_or_init_type_mismatch_panics() {
        let fx = Factory::random();
        let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| 123u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
                "oops".to_string()
            });
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn random_seed_has_expected_length() {
        let seed = random_seed();
        assert_eq!(seed.bytes().len(), 32);
    }

    #[test]
    fn random_seed_draws_differ() {
        // Two 256-bit OS draws colliding is astronomically unlikely; equality
        // would indicate a stuck or constant randomness source.
        let first = random_seed();
        let second = random_seed();
        assert_ne!(first.bytes(), second.bytes());
    }

    #[test]
    fn get_or_init_reentrant_does_not_deadlock() {
        let fx = Factory::deterministic(Seed::new([42u8; 32]));

        let outer: Arc<String> = fx.get_or_init("test:outer", "label", b"spec", "good", |_rng| {
            let inner: Arc<u64> =
                fx.get_or_init("test:inner", "label", b"spec", "good", |_rng| 42u64);
            format!("outer-{}", *inner)
        });

        assert_eq!(*outer, "outer-42");
    }

    #[test]
    fn debug_includes_cache_size() {
        let fx = Factory::random();
        let dbg = format!("{:?}", fx);
        assert!(dbg.contains("cache_size: 0"), "empty factory: {dbg}");

        let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| 7u8);
        let dbg = format!("{:?}", fx);
        assert!(dbg.contains("cache_size: 1"), "after insert: {dbg}");
    }

    #[test]
    fn mode_pattern_matches_deterministic() {
        let seed = Seed::new([1u8; 32]);
        let fx = Factory::deterministic(seed);
        match fx.mode() {
            Mode::Deterministic { master } => assert_eq!(master.bytes(), seed.bytes()),
            Mode::Random => panic!("wrong mode"),
        }
    }

    #[test]
    fn mode_pattern_matches_random() {
        let fx = Factory::random();
        assert!(matches!(fx.mode(), Mode::Random));
    }

    #[test]
    fn deterministic_same_inputs_yield_same_output() {
        let fx = Factory::deterministic(Seed::new([7u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:det", "lbl", b"sp", "good", draw_u64);
        fx.clear_cache();
        let b: Arc<u64> = fx.get_or_init("domain:det", "lbl", b"sp", "good", draw_u64);
        assert_eq!(*a, *b, "deterministic mode must reproduce the same value");
    }

    #[test]
    fn deterministic_independent_factories_agree() {
        // Reproducibility must hold across separately constructed factories,
        // not only within one factory after a cache clear.
        let master = Seed::new([9u8; 32]);
        let a: Arc<u64> =
            Factory::deterministic(master).get_or_init("domain:det", "lbl", b"sp", "good", draw_u64);
        let b: Arc<u64> =
            Factory::deterministic(master).get_or_init("domain:det", "lbl", b"sp", "good", draw_u64);
        assert_eq!(*a, *b, "same master seed must yield the same artifact");
    }

    #[test]
    fn random_factories_produce_different_values() {
        // Collision probability for two independent u64 draws is 2^-64.
        let a: Arc<u64> =
            Factory::random().get_or_init("domain:rnd", "lbl", b"sp", "good", draw_u64);
        let b: Arc<u64> =
            Factory::random().get_or_init("domain:rnd", "lbl", b"sp", "good", draw_u64);
        assert_ne!(*a, *b, "independent random factories must not agree");
    }

    #[test]
    fn clone_shares_cache() {
        let fx = Factory::random();
        let _ = fx.get_or_init("domain:clone", "lbl", b"sp", "good", |_| 99u32);
        let fx2 = fx.clone();
        let val = fx2.get_or_init("domain:clone", "lbl", b"sp", "good", |_| 0u32);
        assert_eq!(*val, 99, "clone must share the same cache");
    }

    #[test]
    fn different_domains_produce_distinct_entries() {
        let fx = Factory::deterministic(Seed::new([1u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:a", "lbl", b"sp", "good", draw_u64);
        let b: Arc<u64> = fx.get_or_init("domain:b", "lbl", b"sp", "good", draw_u64);
        assert_ne!(*a, *b);
    }

    #[test]
    fn different_variants_produce_distinct_entries() {
        let fx = Factory::deterministic(Seed::new([2u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:v", "lbl", b"sp", "good", draw_u64);
        let b: Arc<u64> = fx.get_or_init("domain:v", "lbl", b"sp", "bad", draw_u64);
        assert_ne!(*a, *b);
    }

    #[test]
    fn different_specs_produce_distinct_entries() {
        let fx = Factory::deterministic(Seed::new([3u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:s", "lbl", b"RS256", "good", draw_u64);
        let b: Arc<u64> = fx.get_or_init("domain:s", "lbl", b"RS384", "good", draw_u64);
        assert_ne!(*a, *b);
    }

    #[test]
    fn debug_mode_random() {
        let fx = Factory::random();
        let dbg = format!("{:?}", fx);
        assert!(
            dbg.contains("Random"),
            "debug should show Random mode: {dbg}"
        );
    }

    #[test]
    fn debug_mode_deterministic() {
        let fx = Factory::deterministic(Seed::new([0u8; 32]));
        let dbg = format!("{:?}", fx);
        assert!(
            dbg.contains("Deterministic"),
            "debug should show Deterministic mode: {dbg}"
        );
        assert!(
            dbg.contains("redacted"),
            "seed must be redacted in debug output: {dbg}"
        );
    }
}