Skip to main content

zer_schema/
registry.rs

use std::{
    collections::HashMap,
    fs::File,
    io::Write,
    path::{Path, PathBuf},
    sync::{Mutex, MutexGuard},
};

use zer_core::error::ZerError;

use crate::{
    artifact::ModelArtifact,
    fingerprint::SchemaFingerprint,
    similarity::{fingerprint_distance, WARM_START_THRESHOLD},
};

/// Signature written at the start of every `.zsm` registry file: the ASCII
/// bytes `ZSM` followed by `0x01` (presumably a format version — confirm).
const MAGIC: &[u8] = &[b'Z', b'S', b'M', 0x01];

17/// Decides how the pipeline should initialize when a new dataset arrives.
18#[derive(Debug)]
19pub enum StartupMode {
20    /// Schema hash matches exactly, skip EM and use the saved params directly.
21    WarmLoad(ModelArtifact),
22    /// Schema is similar (distance ≤ threshold), use saved params as the EM
23    /// warm-start initializer and run 2–3 iterations to fine-tune.
24    WarmStart {
25        artifact: ModelArtifact,
26        distance: f32,
27    },
28    /// Schema is new or too different, initialize from priors and run full EM.
29    ColdStart,
30}
31
32struct RegistryInner {
33    path: Option<PathBuf>,
34    artifacts: HashMap<[u8; 32], ModelArtifact>,
35}
36
37/// Persistent store for trained [`ModelArtifact`]s.
38///
39/// Backed by a single portable `.zsm` binary file (`b"ZSM\x01"` magic +
40/// bincode-serialized `HashMap`). The file is written atomically on every
41/// mutation, a `.zsm.tmp` file is written first then renamed into place, so a
42/// crash during flush can never leave a partially-written registry.
43///
44/// The registry is small in practice (< 1 000 entries), so nearest-neighbor
45/// lookup performs a full linear scan without an index.
46pub struct SchemaRegistry {
47    inner: Mutex<RegistryInner>,
48}
49
50impl SchemaRegistry {
51    /// Open (or create) a registry at the given `.zsm` file path.
52    ///
53    /// If the file does not exist yet it is created on the first [`Self::save`] call.
54    pub fn open(path: &Path) -> Result<Self, ZerError> {
55        let artifacts = load(path)?;
56        Ok(Self {
57            inner: Mutex::new(RegistryInner {
58                path: Some(path.to_path_buf()),
59                artifacts,
60            }),
61        })
62    }
63
64    /// Create an in-memory registry. No file I/O; data is lost on drop.
65    #[cfg(test)]
66    pub(crate) fn open_temporary() -> Result<Self, ZerError> {
67        Ok(Self {
68            inner: Mutex::new(RegistryInner {
69                path: None,
70                artifacts: HashMap::new(),
71            }),
72        })
73    }
74
75    // ── Write ────────────────────────────────────────────────────────────────
76
77    /// Persist a trained model artifact. Overwrites any existing artifact with
78    /// the same schema hash and atomically flushes to disk.
79    pub fn save(&self, artifact: &ModelArtifact) -> Result<(), ZerError> {
80        let mut inner = self.inner.lock().unwrap();
81        inner
82            .artifacts
83            .insert(artifact.fingerprint.schema_hash, artifact.clone());
84        flush(&inner)?;
85        tracing::debug!(tag = artifact.tag.as_deref(), "saved model artifact");
86        Ok(())
87    }
88
89    // ── Read ─────────────────────────────────────────────────────────────────
90
91    /// Exact lookup by schema hash. Returns `None` if no matching artifact exists.
92    pub fn get_exact(
93        &self,
94        fingerprint: &SchemaFingerprint,
95    ) -> Result<Option<ModelArtifact>, ZerError> {
96        let inner = self.inner.lock().unwrap();
97        Ok(inner.artifacts.get(&fingerprint.schema_hash).cloned())
98    }
99
100    /// Nearest-neighbor lookup: returns the closest artifact and its distance.
101    ///
102    /// Performs a full linear scan, acceptable because the registry is expected
103    /// to hold far fewer than 1 000 entries.
104    ///
105    /// Returns `None` when the registry is empty.
106    pub fn get_nearest(
107        &self,
108        fingerprint: &SchemaFingerprint,
109    ) -> Result<Option<(ModelArtifact, f32)>, ZerError> {
110        let inner = self.inner.lock().unwrap();
111        let best = inner
112            .artifacts
113            .values()
114            .map(|a| {
115                let dist = fingerprint_distance(fingerprint, &a.fingerprint);
116                (a.clone(), dist)
117            })
118            .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap_or(std::cmp::Ordering::Equal));
119        Ok(best)
120    }
121
122    /// Determine the startup mode for an incoming dataset given its fingerprint.
123    ///
124    /// ```text
125    /// exact hash match         → WarmLoad   (skip EM entirely)
126    /// distance ≤ 0.25          → WarmStart  (2–3 EM iterations from saved init)
127    /// distance  > 0.25 / empty → ColdStart  (full EM from priors)
128    /// ```
129    pub fn lookup_startup_mode(
130        &self,
131        fingerprint: &SchemaFingerprint,
132    ) -> Result<StartupMode, ZerError> {
133        if let Some(exact) = self.get_exact(fingerprint)? {
134            tracing::info!("exact schema match, warm load");
135            return Ok(StartupMode::WarmLoad(exact));
136        }
137
138        match self.get_nearest(fingerprint)? {
139            Some((artifact, dist)) if dist <= WARM_START_THRESHOLD => {
140                tracing::info!(dist, "similar schema, warm start");
141                Ok(StartupMode::WarmStart {
142                    artifact,
143                    distance: dist,
144                })
145            }
146            _ => {
147                tracing::info!("no suitable prior, cold start");
148                Ok(StartupMode::ColdStart)
149            }
150        }
151    }
152
153    // ── Enumeration / deletion ────────────────────────────────────────────────
154
155    /// Return all stored artifacts in arbitrary order.
156    pub fn list_all(&self) -> Result<Vec<ModelArtifact>, ZerError> {
157        let inner = self.inner.lock().unwrap();
158        Ok(inner.artifacts.values().cloned().collect())
159    }
160
161    /// Delete the artifact for the given schema hash.
162    ///
163    /// Returns `true` if an artifact was found and removed, `false` otherwise.
164    pub fn delete(&self, schema_hash: &[u8; 32]) -> Result<bool, ZerError> {
165        let mut inner = self.inner.lock().unwrap();
166        let removed = inner.artifacts.remove(schema_hash).is_some();
167        if removed {
168            flush(&inner)?;
169        }
170        Ok(removed)
171    }
172}
173
174// ── File I/O ──────────────────────────────────────────────────────────────────
175
176fn flush(inner: &RegistryInner) -> Result<(), ZerError> {
177    let Some(path) = &inner.path else {
178        return Ok(());
179    };
180    let payload =
181        bincode::serialize(&inner.artifacts).map_err(|e| ZerError::Serialization(e.to_string()))?;
182    let mut buf = Vec::with_capacity(4 + payload.len());
183    buf.extend_from_slice(MAGIC);
184    buf.extend(payload);
185    let tmp = path.with_extension("zsm.tmp");
186    std::fs::write(&tmp, &buf).map_err(|e| ZerError::Store(e.to_string()))?;
187    std::fs::rename(&tmp, path).map_err(|e| ZerError::Store(e.to_string()))?;
188    Ok(())
189}
190
191fn load(path: &Path) -> Result<HashMap<[u8; 32], ModelArtifact>, ZerError> {
192    if !path.exists() {
193        return Ok(HashMap::new());
194    }
195    let bytes = std::fs::read(path).map_err(|e| ZerError::Store(e.to_string()))?;
196    if bytes.get(..4) != Some(MAGIC) {
197        return Err(ZerError::Store("invalid .zsm magic".into()));
198    }
199    bincode::deserialize(&bytes[4..]).map_err(|e| ZerError::Serialization(e.to_string()))
200}
201
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{
        schema::{FieldKind, SchemaBuilder},
        scoring::ModelParams,
    };

    use crate::{artifact::ModelArtifact, fingerprint::SchemaFingerprint};

    fn dummy_params(n_fields: usize) -> ModelParams {
        ModelParams {
            m: vec![vec![0.02, 0.06, 0.12, 0.80]; n_fields],
            u: vec![vec![0.70, 0.15, 0.10, 0.05]; n_fields],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        }
    }

    fn make_artifact(schema: &zer_core::schema::Schema, tag: &str) -> ModelArtifact {
        ModelArtifact {
            fingerprint: SchemaFingerprint::from_schema(schema),
            params: dummy_params(schema.len()),
            tag: Some(tag.into()),
            trained_on: 0,
            em_iterations: 25,
        }
    }

    fn brp_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .build()
            .unwrap()
    }

    fn sim_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("sim_id", FieldKind::Id)
            .field("msisdn", FieldKind::Phone)
            .field("imsi", FieldKind::Id)
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    /// The BRP schema plus one extra categorical field: close to BRP, far from SIM.
    fn brp_extended_schema() -> zer_core::schema::Schema {
        SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .field("nationaliteit", FieldKind::Categorical)
            .field("postcode", FieldKind::Id)
            .field("verblijfstitel", FieldKind::Categorical)
            .build()
            .unwrap()
    }

    /// A fresh, empty scratch directory unique to this process and test, so
    /// tests running in parallel never share files.
    fn scratch_dir(test_name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "zer_schema_registry_{}_{}",
            std::process::id(),
            test_name
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn roundtrip_save_and_get_exact() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        let artifact = make_artifact(&schema, "brp_test");

        registry.save(&artifact).unwrap();

        let fp = SchemaFingerprint::from_schema(&schema);
        let loaded = registry.get_exact(&fp).unwrap().unwrap();

        assert_eq!(loaded.tag.as_deref(), Some("brp_test"));
        assert_eq!(
            loaded.fingerprint.schema_hash,
            artifact.fingerprint.schema_hash
        );
        assert_eq!(
            loaded.params.upper_threshold,
            artifact.params.upper_threshold
        );
    }

    #[test]
    fn get_exact_returns_none_for_unknown_schema() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        let result = registry.get_exact(&fp).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn list_all_returns_all_artifacts() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let brp = brp_schema();
        let sim = sim_schema();

        registry.save(&make_artifact(&brp, "brp")).unwrap();
        registry.save(&make_artifact(&sim, "sim")).unwrap();

        let all = registry.list_all().unwrap();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn delete_removes_artifact_and_returns_true() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        let artifact = make_artifact(&schema, "brp");
        registry.save(&artifact).unwrap();

        let removed = registry.delete(&artifact.fingerprint.schema_hash).unwrap();
        assert!(removed, "delete should return true when the key existed");

        let fp = SchemaFingerprint::from_schema(&schema);
        assert!(registry.get_exact(&fp).unwrap().is_none());
    }

    #[test]
    fn delete_returns_false_for_missing_key() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let hash = [0u8; 32];
        assert!(!registry.delete(&hash).unwrap());
    }

    #[test]
    fn startup_mode_exact_match_is_warm_load() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let schema = brp_schema();
        registry.save(&make_artifact(&schema, "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&schema);
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmLoad(_)),
            "exact schema match must return WarmLoad"
        );
    }

    #[test]
    fn startup_mode_added_field_is_warm_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&brp_extended_schema());
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::WarmStart { .. }),
            "one added field should return WarmStart"
        );
    }

    #[test]
    fn startup_mode_incompatible_schema_is_cold_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();

        let fp = SchemaFingerprint::from_schema(&sim_schema());
        let mode = registry.lookup_startup_mode(&fp).unwrap();

        assert!(
            matches!(mode, StartupMode::ColdStart),
            "BRP artifact vs SIM schema should return ColdStart"
        );
    }

    #[test]
    fn startup_mode_empty_registry_is_cold_start() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        assert!(matches!(
            registry.lookup_startup_mode(&fp).unwrap(),
            StartupMode::ColdStart
        ));
    }

    #[test]
    fn nearest_prefers_closer_artifact() {
        let registry = SchemaRegistry::open_temporary().unwrap();
        registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
        registry.save(&make_artifact(&sim_schema(), "sim")).unwrap();

        let (nearest, _dist) = registry
            .get_nearest(&SchemaFingerprint::from_schema(&brp_extended_schema()))
            .unwrap()
            .expect("registry is not empty");

        assert_eq!(
            nearest.tag.as_deref(),
            Some("brp"),
            "BRP-like schema should match the BRP artifact, not SIM"
        );
    }

    // ── On-disk persistence ───────────────────────────────────────────────────

    #[test]
    fn open_missing_file_yields_empty_registry() {
        let dir = scratch_dir("open_missing");
        let registry = SchemaRegistry::open(&dir.join("absent.zsm")).unwrap();
        assert!(registry.list_all().unwrap().is_empty());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn saved_artifacts_survive_reopen() {
        let dir = scratch_dir("reopen");
        let path = dir.join("registry.zsm");
        {
            let registry = SchemaRegistry::open(&path).unwrap();
            registry.save(&make_artifact(&brp_schema(), "brp")).unwrap();
            registry.save(&make_artifact(&sim_schema(), "sim")).unwrap();
        }
        assert!(
            !path.with_extension("zsm.tmp").exists(),
            "temporary file must be renamed into place"
        );

        let reopened = SchemaRegistry::open(&path).unwrap();
        assert_eq!(reopened.list_all().unwrap().len(), 2);
        let fp = SchemaFingerprint::from_schema(&brp_schema());
        let loaded = reopened
            .get_exact(&fp)
            .unwrap()
            .expect("brp artifact should be persisted");
        assert_eq!(loaded.tag.as_deref(), Some("brp"));
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn deletions_are_persisted() {
        let dir = scratch_dir("delete_persisted");
        let path = dir.join("registry.zsm");
        let artifact = make_artifact(&brp_schema(), "brp");
        {
            let registry = SchemaRegistry::open(&path).unwrap();
            registry.save(&artifact).unwrap();
            assert!(registry.delete(&artifact.fingerprint.schema_hash).unwrap());
        }
        let reopened = SchemaRegistry::open(&path).unwrap();
        assert!(reopened.list_all().unwrap().is_empty());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn open_rejects_file_without_magic() {
        let dir = scratch_dir("bad_magic");
        let path = dir.join("garbage.zsm");

        std::fs::write(&path, b"not a registry").unwrap();
        assert!(SchemaRegistry::open(&path).is_err());

        // Shorter than the magic header itself.
        std::fs::write(&path, b"ZS").unwrap();
        assert!(SchemaRegistry::open(&path).is_err());

        let _ = std::fs::remove_dir_all(&dir);
    }
}