1use alloc::string::ToString;
8use alloc::sync::Arc;
9use core::fmt;
10
11#[cfg(feature = "std")]
12use rand10::TryRng;
13#[cfg(feature = "std")]
14use rand10::rngs::SysRng;
15
16use crate::srp::cache::ArtifactCache;
17use crate::srp::identity::{ArtifactDomain, ArtifactId, DerivationVersion, Seed, derive_seed};
18
19#[derive(Clone, Debug)]
21pub enum Mode {
22 Random,
24
25 Deterministic { master: Seed },
27}
28
29struct Inner {
30 mode: Mode,
31 cache: ArtifactCache,
32}
33
34#[derive(Clone)]
38pub struct Factory {
39 inner: Arc<Inner>,
40}
41
42impl fmt::Debug for Factory {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.debug_struct("Factory")
45 .field("mode", &self.inner.mode)
46 .field("cache_size", &self.inner.cache.len())
47 .finish()
48 }
49}
50
51impl Factory {
52 pub fn new(mode: Mode) -> Self {
54 Self {
55 inner: Arc::new(Inner {
56 mode,
57 cache: ArtifactCache::new(),
58 }),
59 }
60 }
61
62 pub fn random() -> Self {
64 Self::new(Mode::Random)
65 }
66
67 pub fn deterministic(master: Seed) -> Self {
69 Self::new(Mode::Deterministic { master })
70 }
71
72 pub fn deterministic_from_str(text: &str) -> Self {
78 Self::deterministic(Seed::from_text(text))
79 }
80
81 #[cfg(feature = "std")]
91 pub fn deterministic_from_env(var: &str) -> Result<Self, crate::Error> {
92 let raw = std::env::var(var).map_err(|_| crate::Error::MissingEnvVar {
93 var: var.to_string(),
94 })?;
95
96 let seed = Seed::from_env_value(&raw).map_err(|message| crate::Error::InvalidSeed {
97 var: var.to_string(),
98 message,
99 })?;
100
101 Ok(Self::deterministic(seed))
102 }
103
104 pub fn mode(&self) -> &Mode {
106 &self.inner.mode
107 }
108
109 pub fn clear_cache(&self) {
111 self.inner.cache.clear();
112 }
113
114 pub fn get_or_init<T, F>(
119 &self,
120 domain: ArtifactDomain,
121 label: &str,
122 spec_bytes: &[u8],
123 variant: &str,
124 init: F,
125 ) -> Arc<T>
126 where
127 T: core::any::Any + Send + Sync + 'static,
128 F: FnOnce(Seed) -> T,
129 {
130 let id = ArtifactId::new(
131 domain,
132 label.to_string(),
133 spec_bytes,
134 variant.to_string(),
135 DerivationVersion::V1,
136 );
137
138 if let Some(entry) = self.inner.cache.get_typed::<T>(&id) {
139 return entry;
140 }
141
142 let seed = self.seed_for(&id);
143 let value = init(seed);
144 let arc: Arc<T> = Arc::new(value);
145
146 self.inner.cache.insert_if_absent_typed(id, arc)
147 }
148
149 fn seed_for(&self, id: &ArtifactId) -> Seed {
150 match &self.inner.mode {
151 Mode::Random => random_seed(),
152 Mode::Deterministic { master } => derive_seed(master, id),
153 }
154 }
155}
156
/// Returns a seed made of 32 bytes of operating-system randomness.
///
/// # Panics
///
/// Panics if the OS entropy source cannot be read.
#[cfg(feature = "std")]
pub(crate) fn random_seed() -> Seed {
    let mut os_rng = SysRng;
    let mut entropy = [0u8; 32];
    os_rng
        .try_fill_bytes(&mut entropy)
        .expect("failed to read operating-system randomness");
    Seed::new(entropy)
}
165
166#[cfg(not(feature = "std"))]
167pub(crate) fn random_seed() -> Seed {
168 panic!("uselesskey-core-factory: Mode::Random requires the `std` feature")
169}
170
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{Factory, Mode, random_seed};
    use crate::Seed;
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Turns a seed into a single `u64` so seeded artifacts can be compared by value.
    fn draw_u64(seed: Seed) -> u64 {
        let mut bytes = [0u8; 8];
        seed.fill_bytes(&mut bytes);
        u64::from_le_bytes(bytes)
    }

    #[test]
    fn clear_cache_forces_reinit() {
        let fx = Factory::random();
        let hits = AtomicUsize::new(0);

        let first = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
            hits.fetch_add(1, Ordering::SeqCst);
            42u8
        });

        assert_eq!(hits.load(Ordering::SeqCst), 1);
        let second = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
            hits.fetch_add(1, Ordering::SeqCst);
            99u8
        });
        assert!(Arc::ptr_eq(&first, &second));

        fx.clear_cache();
        let third = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
            hits.fetch_add(1, Ordering::SeqCst);
            44u8
        });

        assert_eq!(hits.load(Ordering::SeqCst), 2);
        assert!(!Arc::ptr_eq(&first, &third));
    }

    #[test]
    fn get_or_init_type_mismatch_panics() {
        let fx = Factory::random();
        let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| 123u32);
        let result = catch_unwind(AssertUnwindSafe(|| {
            let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| {
                "oops".to_string()
            });
        }));

        assert!(result.is_err(), "expected panic on type mismatch");
    }

    #[test]
    fn random_seed_has_expected_length() {
        let seed = random_seed();
        assert_eq!(seed.bytes().len(), 32);
    }

    #[test]
    fn get_or_init_reentrant_does_not_deadlock() {
        let fx = Factory::deterministic(Seed::new([42u8; 32]));

        let outer: Arc<String> = fx.get_or_init("test:outer", "label", b"spec", "good", |_rng| {
            let inner: Arc<u64> =
                fx.get_or_init("test:inner", "label", b"spec", "good", |_rng| 42u64);
            format!("outer-{}", *inner)
        });

        assert_eq!(*outer, "outer-42");
    }

    #[test]
    fn debug_includes_cache_size() {
        let fx = Factory::random();
        let dbg = format!("{:?}", fx);
        assert!(dbg.contains("cache_size: 0"), "empty factory: {dbg}");

        let _ = fx.get_or_init("domain:test", "label", b"spec", "good", |_rng| 7u8);
        let dbg = format!("{:?}", fx);
        assert!(dbg.contains("cache_size: 1"), "after insert: {dbg}");
    }

    #[test]
    fn mode_pattern_matches_deterministic() {
        let seed = Seed::new([1u8; 32]);
        let fx = Factory::deterministic(seed);
        match fx.mode() {
            Mode::Deterministic { master } => assert_eq!(master.bytes(), seed.bytes()),
            Mode::Random => panic!("wrong mode"),
        }
    }

    #[test]
    fn mode_pattern_matches_random() {
        let fx = Factory::random();
        assert!(matches!(fx.mode(), Mode::Random));
    }

    #[test]
    fn deterministic_same_inputs_yield_same_output() {
        let fx = Factory::deterministic(Seed::new([7u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:det", "lbl", b"sp", "good", draw_u64);
        fx.clear_cache();
        // After clearing, the value must be rebuilt from the same derived seed.
        let b: Arc<u64> = fx.get_or_init("domain:det", "lbl", b"sp", "good", draw_u64);
        assert_eq!(*a, *b, "deterministic mode must reproduce the same value");
    }

    #[test]
    fn clone_shares_cache() {
        let fx = Factory::random();
        let _ = fx.get_or_init("domain:clone", "lbl", b"sp", "good", |_| 99u32);
        let fx2 = fx.clone();
        let val = fx2.get_or_init("domain:clone", "lbl", b"sp", "good", |_| 0u32);
        assert_eq!(*val, 99, "clone must share the same cache");
    }

    #[test]
    fn different_domains_produce_distinct_entries() {
        let fx = Factory::deterministic(Seed::new([1u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:a", "lbl", b"sp", "good", draw_u64);
        let b: Arc<u64> = fx.get_or_init("domain:b", "lbl", b"sp", "good", draw_u64);
        assert_ne!(*a, *b);
    }

    #[test]
    fn different_variants_produce_distinct_entries() {
        let fx = Factory::deterministic(Seed::new([2u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:v", "lbl", b"sp", "good", draw_u64);
        let b: Arc<u64> = fx.get_or_init("domain:v", "lbl", b"sp", "bad", draw_u64);
        assert_ne!(*a, *b);
    }

    #[test]
    fn different_specs_produce_distinct_entries() {
        let fx = Factory::deterministic(Seed::new([3u8; 32]));
        let a: Arc<u64> = fx.get_or_init("domain:s", "lbl", b"RS256", "good", draw_u64);
        let b: Arc<u64> = fx.get_or_init("domain:s", "lbl", b"RS384", "good", draw_u64);
        assert_ne!(*a, *b);
    }

    #[test]
    fn debug_mode_random() {
        let fx = Factory::random();
        let dbg = format!("{:?}", fx);
        assert!(
            dbg.contains("Random"),
            "debug should show Random mode: {dbg}"
        );
    }

    #[test]
    fn debug_mode_deterministic() {
        let fx = Factory::deterministic(Seed::new([0u8; 32]));
        let dbg = format!("{:?}", fx);
        assert!(
            dbg.contains("Deterministic"),
            "debug should show Deterministic mode: {dbg}"
        );
        assert!(
            dbg.contains("redacted"),
            "seed must be redacted in debug output: {dbg}"
        );
    }

    /// Two factories built from the same text must derive identical artifacts.
    #[test]
    fn deterministic_from_str_is_reproducible() {
        let a = Factory::deterministic_from_str("fixture-master");
        let b = Factory::deterministic_from_str("fixture-master");
        assert!(matches!(a.mode(), Mode::Deterministic { .. }));

        let x: Arc<u64> = a.get_or_init("domain:text", "lbl", b"sp", "good", draw_u64);
        let y: Arc<u64> = b.get_or_init("domain:text", "lbl", b"sp", "good", draw_u64);
        assert_eq!(*x, *y, "same text must yield the same master seed");
    }

    /// An unset variable must surface as `MissingEnvVar` naming that variable.
    /// Only reads the environment (no `set_var`), so it is safe to run in parallel.
    #[test]
    fn deterministic_from_env_reports_missing_var() {
        const VAR: &str = "USELESSKEY_FACTORY_TEST_DEFINITELY_UNSET_SEED";
        match Factory::deterministic_from_env(VAR) {
            Err(crate::Error::MissingEnvVar { var }) => assert_eq!(var, VAR),
            Err(_) => panic!("expected MissingEnvVar for unset {VAR}"),
            Ok(_) => panic!("unset {VAR} must not produce a factory"),
        }
    }
}