1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4 sync::Mutex,
5};
6
7use zer_core::error::ZerError;
8
9use crate::{
10 artifact::ModelArtifact,
11 fingerprint::SchemaFingerprint,
12 similarity::{fingerprint_distance, WARM_START_THRESHOLD},
13};
14
/// Four-byte header at the start of every `.zsm` registry file: ASCII `ZSM` followed by
/// format version byte `0x01`.
const MAGIC: &[u8] = &[b'Z', b'S', b'M', 0x01];
16
/// Result of [`SchemaRegistry::lookup_startup_mode`]: which stored artifact (if any) should seed
/// a model for a given schema fingerprint.
#[derive(Debug)]
pub enum StartupMode {
    /// A stored artifact has exactly the requested fingerprint's `schema_hash`.
    WarmLoad(ModelArtifact),
    /// No exact match, but the nearest stored artifact is within `WARM_START_THRESHOLD`
    /// (inclusive).
    WarmStart {
        /// The nearest stored artifact.
        artifact: ModelArtifact,
        /// `fingerprint_distance` between the requested fingerprint and `artifact`'s fingerprint.
        distance: f32,
    },
    /// The registry is empty or no stored artifact is close enough.
    ColdStart,
}
31
/// Mutable registry state; lives behind the mutex in [`SchemaRegistry`].
struct RegistryInner {
    /// Backing `.zsm` file. `None` for registries built by `open_temporary`, in which case
    /// `flush` writes nothing.
    path: Option<PathBuf>,
    /// Stored artifacts keyed by their fingerprint's `schema_hash`.
    artifacts: HashMap<[u8; 32], ModelArtifact>,
}
36
/// Store of trained [`ModelArtifact`]s keyed by schema hash, optionally persisted to a `.zsm`
/// file.
///
/// All state sits behind a single `Mutex`, so every method takes `&self`. Mutating methods
/// (`save`, `delete`) write the whole registry back to the backing file while holding the lock.
pub struct SchemaRegistry {
    inner: Mutex<RegistryInner>,
}
49
50impl SchemaRegistry {
51 pub fn open(path: &Path) -> Result<Self, ZerError> {
55 let artifacts = load(path)?;
56 Ok(Self {
57 inner: Mutex::new(RegistryInner {
58 path: Some(path.to_path_buf()),
59 artifacts,
60 }),
61 })
62 }
63
64 #[cfg(test)]
66 pub(crate) fn open_temporary() -> Result<Self, ZerError> {
67 Ok(Self {
68 inner: Mutex::new(RegistryInner {
69 path: None,
70 artifacts: HashMap::new(),
71 }),
72 })
73 }
74
75 pub fn save(&self, artifact: &ModelArtifact) -> Result<(), ZerError> {
80 let mut inner = self.inner.lock().unwrap();
81 inner
82 .artifacts
83 .insert(artifact.fingerprint.schema_hash, artifact.clone());
84 flush(&inner)?;
85 tracing::debug!(tag = artifact.tag.as_deref(), "saved model artifact");
86 Ok(())
87 }
88
89 pub fn get_exact(
93 &self,
94 fingerprint: &SchemaFingerprint,
95 ) -> Result<Option<ModelArtifact>, ZerError> {
96 let inner = self.inner.lock().unwrap();
97 Ok(inner.artifacts.get(&fingerprint.schema_hash).cloned())
98 }
99
100 pub fn get_nearest(
107 &self,
108 fingerprint: &SchemaFingerprint,
109 ) -> Result<Option<(ModelArtifact, f32)>, ZerError> {
110 let inner = self.inner.lock().unwrap();
111 let best = inner
112 .artifacts
113 .values()
114 .map(|a| {
115 let dist = fingerprint_distance(fingerprint, &a.fingerprint);
116 (a.clone(), dist)
117 })
118 .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal));
119 Ok(best)
120 }
121
122 pub fn lookup_startup_mode(
130 &self,
131 fingerprint: &SchemaFingerprint,
132 ) -> Result<StartupMode, ZerError> {
133 if let Some(exact) = self.get_exact(fingerprint)? {
134 tracing::info!("exact schema match, warm load");
135 return Ok(StartupMode::WarmLoad(exact));
136 }
137
138 match self.get_nearest(fingerprint)? {
139 Some((artifact, dist)) if dist <= WARM_START_THRESHOLD => {
140 tracing::info!(dist, "similar schema, warm start");
141 Ok(StartupMode::WarmStart {
142 artifact,
143 distance: dist,
144 })
145 }
146 _ => {
147 tracing::info!("no suitable prior, cold start");
148 Ok(StartupMode::ColdStart)
149 }
150 }
151 }
152
153 pub fn list_all(&self) -> Result<Vec<ModelArtifact>, ZerError> {
157 let inner = self.inner.lock().unwrap();
158 Ok(inner.artifacts.values().cloned().collect())
159 }
160
161 pub fn delete(&self, schema_hash: &[u8; 32]) -> Result<bool, ZerError> {
165 let mut inner = self.inner.lock().unwrap();
166 let removed = inner.artifacts.remove(schema_hash).is_some();
167 if removed {
168 flush(&inner)?;
169 }
170 Ok(removed)
171 }
172}
173
174fn flush(inner: &RegistryInner) -> Result<(), ZerError> {
177 let Some(path) = &inner.path else {
178 return Ok(());
179 };
180 let payload =
181 bincode::serialize(&inner.artifacts).map_err(|e| ZerError::Serialization(e.to_string()))?;
182 let mut buf = Vec::with_capacity(4 + payload.len());
183 buf.extend_from_slice(MAGIC);
184 buf.extend(payload);
185 let tmp = path.with_extension("zsm.tmp");
186 std::fs::write(&tmp, &buf).map_err(|e| ZerError::Store(e.to_string()))?;
187 std::fs::rename(&tmp, path).map_err(|e| ZerError::Store(e.to_string()))?;
188 Ok(())
189}
190
191fn load(path: &Path) -> Result<HashMap<[u8; 32], ModelArtifact>, ZerError> {
192 if !path.exists() {
193 return Ok(HashMap::new());
194 }
195 let bytes = std::fs::read(path).map_err(|e| ZerError::Store(e.to_string()))?;
196 if bytes.get(..4) != Some(MAGIC) {
197 return Err(ZerError::Store("invalid .zsm magic".into()));
198 }
199 bincode::deserialize(&bytes[4..]).map_err(|e| ZerError::Serialization(e.to_string()))
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        schema::{FieldKind, SchemaBuilder},
        scoring::ModelParams,
    };

    use crate::{artifact::ModelArtifact, fingerprint::SchemaFingerprint};

    fn dummy_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    fn make_artifact(schema: &zer_core::schema::Schema, tag: &str) -> ModelArtifact {
        ModelArtifact {
            fingerprint: SchemaFingerprint::from_schema(schema),
            params: dummy_params(schema.len()),
            tag: Some(tag.into()),
            trained_on: 0,
            em_iterations: 25,
        }
    }

    fn brp_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .build()
            .unwrap()
    }

    fn sim_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("sim_id", FieldKind::Id)
            .field("msisdn", FieldKind::Phone)
            .field("imsi", FieldKind::Id)
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    /// Returns `<tmp>/zer_registry_<pid>_<name>/registry.zsm`, creating the directory and
    /// removing any file left over from an earlier run. Using a distinct `name` per test keeps
    /// tests that run in parallel from sharing a file.
    fn scratch_registry_path(name: &str) -> PathBuf {
        let dir =
            std::env::temp_dir().join(format!("zer_registry_{}_{name}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("registry.zsm");
        let _ = std::fs::remove_file(&path);
        path
    }

    /// Removes the scratch directory created by `scratch_registry_path` (best effort).
    fn remove_scratch(path: &Path) {
        if let Some(dir) = path.parent() {
            let _ = std::fs::remove_dir_all(dir);
        }
    }

    #[test]
    fn roundtrip_save_and_get_exact() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        let artifact = make_artifact(&schema, "brp_test");

        registry.save(&artifact).unwrap();

        let fp = SchemaFingerprint::from_schema(&schema);
        let loaded = registry.get_exact(&fp).unwrap().unwrap();

        assert_eq!(loaded.tag.as_deref(), Some("brp_test"));
        assert_eq!(
            loaded.fingerprint.schema_hash,
            artifact.fingerprint.schema_hash
        );
        assert_eq!(
            loaded.params.upper_threshold,
            artifact.params.upper_threshold
        );
    }

    #[test]
    fn get_exact_returns_none_for_unknown_schema() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        let result = registry.get_exact(&fp).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn list_all_returns_all_artifacts() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let brp = brp_schema();
        let sim = sim_schema();

        registry.save(&make_artifact(&brp, "brp")).unwrap();
        registry.save(&make_artifact(&sim, "sim")).unwrap();

        let all = registry.list_all().unwrap();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn delete_removes_artifact_and_returns_true() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        let artifact = make_artifact(&schema, "brp");
        registry.save(&artifact).unwrap();

        let removed = registry.delete(&artifact.fingerprint.schema_hash).unwrap();
        assert!(removed, "delete should return true when the key existed");

        let fp = SchemaFingerprint::from_schema(&schema);
        assert!(registry.get_exact(&fp).unwrap().is_none());
    }

    #[test]
    fn delete_returns_false_for_missing_key() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let hash = [0u8; 32];
        assert!(!registry.delete(&hash).unwrap());
    }

    #[test]
    fn startup_mode_exact_match_is_warm_load() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        registry.save(&make_artifact(&schema, "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&schema);
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmLoad(_)),
            "exact schema match must return WarmLoad"
        );
    }

    #[test]
    fn startup_mode_added_field_is_warm_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();

        let extended = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .field("verblijfstitel", FieldKind::Categorical)
            .build()
            .unwrap();

        let fp = SchemaFingerprint::from_schema(&extended);
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmStart { .. }),
            "one added field should return WarmStart"
        );
    }

    #[test]
    fn startup_mode_incompatible_schema_is_cold_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&sim_schema());
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::ColdStart),
            "BRP artifact vs SIM schema should return ColdStart"
        );
    }

    #[test]
    fn startup_mode_empty_registry_is_cold_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        assert!(matches!(
            registry.lookup_startup_mode(&fp).unwrap(),
            StartupMode::ColdStart
        ));
    }

    #[test]
    fn nearest_prefers_closer_artifact() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
        registry.save(&make_artifact(&sim_schema(), "sim")).unwrap();

        let brp_like = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .field("verblijfstitel", FieldKind::Categorical)
            .build()
            .unwrap();

        let (nearest, _dist) = registry
            .get_nearest(&SchemaFingerprint::from_schema(&brp_like))
            .unwrap()
            .expect("registry is not empty");

        assert_eq!(
            nearest.tag.as_deref(),
            Some("brp"),
            "BRP-like schema should match the BRP artifact, not SIM"
        );
    }

    #[test]
    fn open_missing_file_yields_empty_registry() {
        let path = scratch_registry_path("missing");

        let registry = SchemaRegistry::open(&path).unwrap();
        assert!(registry.list_all().unwrap().is_empty());
        assert!(!path.exists(), "opening must not create the file");

        remove_scratch(&path);
    }

    #[test]
    fn saved_artifacts_survive_reopen() {
        let path = scratch_registry_path("reopen");
        let schema = brp_schema();

        {
            let registry = SchemaRegistry::open(&path).unwrap();
            registry.save(&make_artifact(&schema, "brp")).unwrap();
        }

        let reopened = SchemaRegistry::open(&path).unwrap();
        let fp = SchemaFingerprint::from_schema(&schema);
        let loaded = reopened
            .get_exact(&fp)
            .unwrap()
            .expect("artifact should be persisted to disk");
        assert_eq!(loaded.tag.as_deref(), Some("brp"));

        remove_scratch(&path);
    }

    #[test]
    fn open_rejects_file_without_magic() {
        let path = scratch_registry_path("bad_magic");
        std::fs::write(&path, b"not a registry").unwrap();

        assert!(
            SchemaRegistry::open(&path).is_err(),
            "a file without the .zsm header must be rejected"
        );

        remove_scratch(&path);
    }
}