Skip to main content

zer_schema/
artifact.rs

1use zer_core::{error::ZerError, scoring::ModelParams};
2
3use crate::fingerprint::SchemaFingerprint;
4
/// Everything that must be persisted after a successful EM training run.
///
/// Encoded with bincode through [`ModelArtifact::to_bytes`] and decoded with
/// [`ModelArtifact::from_bytes`]. bincode encodes structs positionally (field
/// names are not stored), so the declaration order below *is* the on-disk
/// layout: reordering, inserting, removing or retyping a field makes previously
/// saved artifacts unreadable.
///
/// Serializes to roughly 2–10 KB per artifact (bincode).
/// NOTE(review): the 2–10 KB range is not verified here; the unit tests only
/// assert that a 3-field fixture stays under 10 KB.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelArtifact {
    /// Fingerprint of the schema and data distribution this model was trained on.
    pub fingerprint: SchemaFingerprint,
    /// Learned Fellegi-Sunter m/u parameters and decision thresholds.
    pub params: ModelParams,
    /// Optional human-readable label, e.g. `"brp_2024_q1"`.
    pub tag: Option<String>,
    /// Unix timestamp (seconds) when EM training completed.
    /// Despite the `_on` name this is a timestamp, not a dataset identifier.
    pub trained_on: u64,
    /// Number of EM iterations performed.
    pub em_iterations: usize,
}
21
22impl ModelArtifact {
23    /// Serialize this artifact to bytes using bincode.
24    pub fn to_bytes(&self) -> Result<Vec<u8>, ZerError> {
25        bincode::serialize(self).map_err(|e| ZerError::Serialization(e.to_string()))
26    }
27
28    /// Deserialize an artifact from bincode bytes.
29    pub fn from_bytes(bytes: &[u8]) -> Result<Self, ZerError> {
30        bincode::deserialize(bytes).map_err(|e| ZerError::Serialization(e.to_string()))
31    }
32}
33
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::schema::{FieldKind, SchemaBuilder};

    /// Builds a small, fully populated artifact for a 3-field schema.
    ///
    /// Every field gets a non-default value, so a round-trip that silently drops
    /// or zeroes a field is caught by the assertions below.
    fn dummy_artifact() -> ModelArtifact {
        let schema = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .build()
            .unwrap();

        let fingerprint = SchemaFingerprint::from_schema(&schema);

        let params = ModelParams {
            m: vec![
                vec![0.02, 0.06, 0.12, 0.80],
                vec![0.02, 0.06, 0.12, 0.80],
                vec![0.01, 0.04, 0.10, 0.85],
            ],
            u: vec![
                vec![0.70, 0.15, 0.10, 0.05],
                vec![0.70, 0.15, 0.10, 0.05],
                vec![0.80, 0.10, 0.07, 0.03],
            ],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        ModelArtifact {
            fingerprint,
            params,
            tag: Some("test_artifact".into()),
            // Non-zero so a field lost in the round-trip cannot pass as the default.
            trained_on: 1_700_000_000,
            em_iterations: 25,
        }
    }

    #[test]
    fn roundtrip_preserves_all_fields() {
        let original = dummy_artifact();
        let bytes = original.to_bytes().expect("serialization must succeed");
        let loaded = ModelArtifact::from_bytes(&bytes).expect("deserialization must succeed");

        assert_eq!(original.tag, loaded.tag);
        assert_eq!(original.trained_on, loaded.trained_on);
        assert_eq!(original.em_iterations, loaded.em_iterations);
        assert_eq!(original.params.upper_threshold, loaded.params.upper_threshold);
        assert_eq!(original.params.lower_threshold, loaded.params.lower_threshold);
        assert_eq!(original.params.log_prior_odds, loaded.params.log_prior_odds);
        assert_eq!(original.fingerprint.schema_hash, loaded.fingerprint.schema_hash);
    }

    #[test]
    fn roundtrip_preserves_m_u_tables() {
        let original = dummy_artifact();
        let bytes = original.to_bytes().unwrap();
        let loaded = ModelArtifact::from_bytes(&bytes).unwrap();

        // bincode stores floats as raw IEEE-754 bits, so the tables must come back
        // exactly equal: same row count, same row lengths, same values.
        assert_eq!(original.params.m, loaded.params.m, "m table must survive roundtrip");
        assert_eq!(original.params.u, loaded.params.u, "u table must survive roundtrip");
    }

    #[test]
    fn serialized_size_under_10kb() {
        let artifact = dummy_artifact();
        let bytes = artifact.to_bytes().unwrap();
        assert!(
            bytes.len() < 10_240,
            "serialized artifact for 3-field schema should be under 10 KB, got {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn from_bytes_rejects_garbage() {
        let result = ModelArtifact::from_bytes(b"not valid bincode data");
        assert!(result.is_err(), "garbage bytes must return an error");
    }

    #[test]
    fn from_bytes_rejects_empty_input() {
        assert!(ModelArtifact::from_bytes(&[]).is_err(), "empty input must return an error");
    }

    #[test]
    fn from_bytes_rejects_truncated_input() {
        let bytes = dummy_artifact().to_bytes().unwrap();
        let truncated = &bytes[..bytes.len() - 1];
        assert!(
            ModelArtifact::from_bytes(truncated).is_err(),
            "an artifact missing its final byte must return an error"
        );
    }
}