1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4 sync::Mutex,
5};
6
7use zer_core::error::ZerError;
8
9use crate::{
10 artifact::ModelArtifact,
11 fingerprint::SchemaFingerprint,
12 similarity::{fingerprint_distance, WARM_START_THRESHOLD},
13};
14
/// Four-byte prefix of every registry file: ASCII `ZSM` followed by the byte `0x01`.
///
/// `flush` writes it ahead of the bincode payload and `load` rejects any file that
/// does not begin with it.
// NOTE(review): the trailing 0x01 looks like a format version — confirm.
const MAGIC: &[u8] = b"ZSM\x01";
16
/// Outcome of [`SchemaRegistry::lookup_startup_mode`]: how a model for a given schema
/// fingerprint can be initialised from what the registry already holds.
#[derive(Debug)]
pub enum StartupMode {
    /// The registry holds an artifact stored under exactly the requested schema hash.
    WarmLoad(ModelArtifact),
    /// No exact match, but the nearest stored artifact lies within
    /// `WARM_START_THRESHOLD`; `distance` is its fingerprint distance to the request.
    WarmStart { artifact: ModelArtifact, distance: f32 },
    /// Neither an exact match nor a sufficiently close artifact is available.
    ColdStart,
}
28
/// Mutable registry state guarded by the mutex inside [`SchemaRegistry`].
struct RegistryInner {
    /// Backing file; `None` makes `flush` a no-op, so the registry lives in memory only.
    path: Option<PathBuf>,
    /// Stored artifacts, keyed by each artifact's `fingerprint.schema_hash`.
    artifacts: HashMap<[u8; 32], ModelArtifact>,
}
33
/// Registry of trained model artifacts, looked up by schema fingerprint.
///
/// All state sits behind a [`Mutex`], so every method takes `&self`.
pub struct SchemaRegistry {
    /// Artifact map plus the optional backing-file path.
    inner: Mutex<RegistryInner>,
}
46
47impl SchemaRegistry {
48 pub fn open(path: &Path) -> Result<Self, ZerError> {
52 let artifacts = load(path)?;
53 Ok(Self {
54 inner: Mutex::new(RegistryInner {
55 path: Some(path.to_path_buf()),
56 artifacts,
57 }),
58 })
59 }
60
61 #[cfg(test)]
63 pub(crate) fn open_temporary() -> Result<Self, ZerError> {
64 Ok(Self {
65 inner: Mutex::new(RegistryInner {
66 path: None,
67 artifacts: HashMap::new(),
68 }),
69 })
70 }
71
72 pub fn save(&self, artifact: &ModelArtifact) -> Result<(), ZerError> {
77 let mut inner = self.inner.lock().unwrap();
78 inner.artifacts.insert(artifact.fingerprint.schema_hash, artifact.clone());
79 flush(&inner)?;
80 tracing::debug!(tag = artifact.tag.as_deref(), "saved model artifact");
81 Ok(())
82 }
83
84 pub fn get_exact(
88 &self,
89 fingerprint: &SchemaFingerprint,
90 ) -> Result<Option<ModelArtifact>, ZerError> {
91 let inner = self.inner.lock().unwrap();
92 Ok(inner.artifacts.get(&fingerprint.schema_hash).cloned())
93 }
94
95 pub fn get_nearest(
102 &self,
103 fingerprint: &SchemaFingerprint,
104 ) -> Result<Option<(ModelArtifact, f32)>, ZerError> {
105 let inner = self.inner.lock().unwrap();
106 let best = inner
107 .artifacts
108 .values()
109 .map(|a| {
110 let dist = fingerprint_distance(fingerprint, &a.fingerprint);
111 (a.clone(), dist)
112 })
113 .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal));
114 Ok(best)
115 }
116
117 pub fn lookup_startup_mode(
125 &self,
126 fingerprint: &SchemaFingerprint,
127 ) -> Result<StartupMode, ZerError> {
128 if let Some(exact) = self.get_exact(fingerprint)? {
129 tracing::info!("exact schema match, warm load");
130 return Ok(StartupMode::WarmLoad(exact));
131 }
132
133 match self.get_nearest(fingerprint)? {
134 Some((artifact, dist)) if dist <= WARM_START_THRESHOLD => {
135 tracing::info!(dist, "similar schema, warm start");
136 Ok(StartupMode::WarmStart { artifact, distance: dist })
137 }
138 _ => {
139 tracing::info!("no suitable prior, cold start");
140 Ok(StartupMode::ColdStart)
141 }
142 }
143 }
144
145 pub fn list_all(&self) -> Result<Vec<ModelArtifact>, ZerError> {
149 let inner = self.inner.lock().unwrap();
150 Ok(inner.artifacts.values().cloned().collect())
151 }
152
153 pub fn delete(&self, schema_hash: &[u8; 32]) -> Result<bool, ZerError> {
157 let mut inner = self.inner.lock().unwrap();
158 let removed = inner.artifacts.remove(schema_hash).is_some();
159 if removed {
160 flush(&inner)?;
161 }
162 Ok(removed)
163 }
164}
165
166fn flush(inner: &RegistryInner) -> Result<(), ZerError> {
169 let Some(path) = &inner.path else {
170 return Ok(());
171 };
172 let payload = bincode::serialize(&inner.artifacts)
173 .map_err(|e| ZerError::Serialization(e.to_string()))?;
174 let mut buf = Vec::with_capacity(4 + payload.len());
175 buf.extend_from_slice(MAGIC);
176 buf.extend(payload);
177 let tmp = path.with_extension("zsm.tmp");
178 std::fs::write(&tmp, &buf).map_err(|e| ZerError::Store(e.to_string()))?;
179 std::fs::rename(&tmp, path).map_err(|e| ZerError::Store(e.to_string()))?;
180 Ok(())
181}
182
183fn load(path: &Path) -> Result<HashMap<[u8; 32], ModelArtifact>, ZerError> {
184 if !path.exists() {
185 return Ok(HashMap::new());
186 }
187 let bytes = std::fs::read(path).map_err(|e| ZerError::Store(e.to_string()))?;
188 if bytes.get(..4) != Some(MAGIC) {
189 return Err(ZerError::Store("invalid .zsm magic".into()));
190 }
191 bincode::deserialize(&bytes[4..]).map_err(|e| ZerError::Serialization(e.to_string()))
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        schema::{FieldKind, Schema, SchemaBuilder},
        scoring::ModelParams,
    };

    use crate::{artifact::ModelArtifact, fingerprint::SchemaFingerprint};

    /// Placeholder parameters with one identical m/u row per field.
    fn dummy_params(n_fields: usize) -> ModelParams {
        let m_row = vec![0.02, 0.06, 0.12, 0.80];
        let u_row = vec![0.70, 0.15, 0.10, 0.05];
        ModelParams {
            m: vec![m_row; n_fields],
            u: vec![u_row; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    /// Builds an artifact for `schema` carrying `tag` and placeholder parameters.
    fn make_artifact(schema: &Schema, tag: &str) -> ModelArtifact {
        let fingerprint = SchemaFingerprint::from_schema(schema);
        let params = dummy_params(schema.len());
        ModelArtifact {
            fingerprint,
            params,
            tag: Some(tag.into()),
            trained_on: 0,
            em_iterations: 25,
        }
    }

    /// Five-field person schema used as the baseline in most tests.
    fn brp_schema() -> Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .build()
            .unwrap()
    }

    /// The baseline schema plus one extra categorical field.
    fn brp_extended_schema() -> Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .field("verblijfstitel", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    /// Seven-field schema that shares only part of its fields with the baseline.
    fn sim_schema() -> Schema {
        SchemaBuilder::new()
            .field("sim_id", FieldKind::Id)
            .field("msisdn", FieldKind::Phone)
            .field("imsi", FieldKind::Id)
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    /// Fresh in-memory registry.
    fn empty_registry() -> SchemaRegistry {
        SchemaRegistry::open_temporary().unwrap()
    }

    /// In-memory registry pre-populated with one artifact per `(schema, tag)` pair,
    /// saved in the given order.
    fn registry_with(entries: &[(&Schema, &str)]) -> SchemaRegistry {
        let registry = empty_registry();
        for (schema, tag) in entries {
            registry.save(&make_artifact(schema, tag)).unwrap();
        }
        registry
    }

    #[test]
    fn roundtrip_save_and_get_exact() {
        let store = empty_registry();
        let schema = brp_schema();
        let saved = make_artifact(&schema, "brp_test");

        store.save(&saved).unwrap();

        let key = SchemaFingerprint::from_schema(&schema);
        let fetched = store.get_exact(&key).unwrap().unwrap();

        assert_eq!(fetched.tag.as_deref(), Some("brp_test"));
        assert_eq!(fetched.fingerprint.schema_hash, saved.fingerprint.schema_hash);
        assert_eq!(fetched.params.upper_threshold, saved.params.upper_threshold);
    }

    #[test]
    fn get_exact_returns_none_for_unknown_schema() {
        let store = empty_registry();
        let key = SchemaFingerprint::from_schema(&brp_schema());
        assert!(store.get_exact(&key).unwrap().is_none());
    }

    #[test]
    fn list_all_returns_all_artifacts() {
        let (brp, sim) = (brp_schema(), sim_schema());
        let store = registry_with(&[(&brp, "brp"), (&sim, "sim")]);

        let listed = store.list_all().unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[test]
    fn delete_removes_artifact_and_returns_true() {
        let store = empty_registry();
        let schema = brp_schema();
        let saved = make_artifact(&schema, "brp");
        store.save(&saved).unwrap();

        let was_removed = store.delete(&saved.fingerprint.schema_hash).unwrap();
        assert!(was_removed, "delete should return true when the key existed");

        let key = SchemaFingerprint::from_schema(&schema);
        assert!(store.get_exact(&key).unwrap().is_none());
    }

    #[test]
    fn delete_returns_false_for_missing_key() {
        let store = empty_registry();
        let absent = [0u8; 32];
        assert!(!store.delete(&absent).unwrap());
    }

    #[test]
    fn startup_mode_exact_match_is_warm_load() {
        let schema = brp_schema();
        let store = registry_with(&[(&schema, "brp")]);

        let key = SchemaFingerprint::from_schema(&schema);
        let mode = store.lookup_startup_mode(&key).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmLoad(_)),
            "exact schema match must return WarmLoad"
        );
    }

    #[test]
    fn startup_mode_added_field_is_warm_start() {
        let store = registry_with(&[(&brp_schema(), "brp")]);

        let key = SchemaFingerprint::from_schema(&brp_extended_schema());
        let mode = store.lookup_startup_mode(&key).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmStart { .. }),
            "one added field should return WarmStart"
        );
    }

    #[test]
    fn startup_mode_incompatible_schema_is_cold_start() {
        let store = registry_with(&[(&brp_schema(), "brp")]);

        let key = SchemaFingerprint::from_schema(&sim_schema());
        let mode = store.lookup_startup_mode(&key).unwrap();

        assert!(
            matches!(mode, StartupMode::ColdStart),
            "BRP artifact vs SIM schema should return ColdStart"
        );
    }

    #[test]
    fn startup_mode_empty_registry_is_cold_start() {
        let store = empty_registry();
        let key = SchemaFingerprint::from_schema(&brp_schema());
        let mode = store.lookup_startup_mode(&key).unwrap();
        assert!(matches!(mode, StartupMode::ColdStart));
    }

    #[test]
    fn nearest_prefers_closer_artifact() {
        let store = registry_with(&[(&brp_schema(), "brp"), (&sim_schema(), "sim")]);

        let key = SchemaFingerprint::from_schema(&brp_extended_schema());
        let (closest, _dist) = store
            .get_nearest(&key)
            .unwrap()
            .expect("registry is not empty");

        assert_eq!(
            closest.tag.as_deref(),
            Some("brp"),
            "BRP-like schema should match the BRP artifact, not SIM"
        );
    }
}