// zer_schema/artifact.rs

1use zer_core::{error::ZerError, scoring::ModelParams};
2
3use crate::fingerprint::SchemaFingerprint;
4
5/// Everything that must be persisted after a successful EM training run.
6///
7/// Serializes to roughly 2–10 KB per artifact (bincode).
8#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
9pub struct ModelArtifact {
10    /// Fingerprint of the schema and data distribution this model was trained on.
11    pub fingerprint: SchemaFingerprint,
12    /// Learned Fellegi-Sunter m/u parameters and decision thresholds.
13    pub params: ModelParams,
14    /// Optional human-readable label, e.g. `"brp_2024_q1"`.
15    pub tag: Option<String>,
16    /// Unix timestamp (seconds) when EM training completed.
17    pub trained_on: u64,
18    /// Number of EM iterations performed.
19    pub em_iterations: usize,
20}
21
22impl ModelArtifact {
23    /// Serialize this artifact to bytes using bincode.
24    pub fn to_bytes(&self) -> Result<Vec<u8>, ZerError> {
25        bincode::serialize(self).map_err(|e| ZerError::Serialization(e.to_string()))
26    }
27
28    /// Deserialize an artifact from bincode bytes.
29    pub fn from_bytes(bytes: &[u8]) -> Result<Self, ZerError> {
30        bincode::deserialize(bytes).map_err(|e| ZerError::Serialization(e.to_string()))
31    }
32}
33
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::schema::{FieldKind, SchemaBuilder};

    /// Builds a fully populated artifact for a small 3-field schema.
    ///
    /// Every field carries a non-default value (notably a non-zero
    /// `trained_on`), so a roundtrip that drops or zeroes a field is detected.
    fn dummy_artifact() -> ModelArtifact {
        let schema = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .build()
            .expect("static 3-field test schema must build");

        let fingerprint = SchemaFingerprint::from_schema(&schema);

        // Presumably one m/u row per schema field, in declaration order.
        let params = ModelParams {
            m: vec![
                vec![0.02, 0.06, 0.12, 0.80],
                vec![0.02, 0.06, 0.12, 0.80],
                vec![0.01, 0.04, 0.10, 0.85],
            ],
            u: vec![
                vec![0.70, 0.15, 0.10, 0.05],
                vec![0.70, 0.15, 0.10, 0.05],
                vec![0.80, 0.10, 0.07, 0.03],
            ],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        ModelArtifact {
            fingerprint,
            params,
            tag: Some("test_artifact".into()),
            // Arbitrary non-zero timestamp, so a zeroed field cannot pass.
            trained_on: 1_700_000_000,
            em_iterations: 25,
        }
    }

    #[test]
    fn roundtrip_preserves_all_fields() {
        let original = dummy_artifact();
        let bytes = original.to_bytes().expect("serialization must succeed");
        let loaded = ModelArtifact::from_bytes(&bytes).expect("deserialization must succeed");

        assert_eq!(original.tag, loaded.tag);
        assert_eq!(original.trained_on, loaded.trained_on);
        assert_eq!(original.em_iterations, loaded.em_iterations);
        assert_eq!(
            original.params.upper_threshold,
            loaded.params.upper_threshold
        );
        assert_eq!(
            original.params.lower_threshold,
            loaded.params.lower_threshold
        );
        assert_eq!(original.params.log_prior_odds, loaded.params.log_prior_odds);
        // NOTE(review): only `schema_hash` is compared; extend this if other
        // `SchemaFingerprint` fields must also survive the roundtrip.
        assert_eq!(
            original.fingerprint.schema_hash,
            loaded.fingerprint.schema_hash
        );
        // The m/u tables are covered by `roundtrip_preserves_m_u_tables`.
    }

    #[test]
    fn roundtrip_preserves_m_u_tables() {
        let original = dummy_artifact();
        let bytes = original.to_bytes().unwrap();
        let loaded = ModelArtifact::from_bytes(&bytes).unwrap();

        // Whole-table equality also checks row counts and row lengths, which a
        // `zip`-based comparison would silently truncate. Bincode stores floats
        // as their raw IEEE-754 bytes, so exact equality (no tolerance) is the
        // correct expectation.
        assert_eq!(
            original.params.m, loaded.params.m,
            "m table must survive the roundtrip exactly"
        );
        assert_eq!(
            original.params.u, loaded.params.u,
            "u table must survive the roundtrip exactly"
        );
    }

    #[test]
    fn serialized_size_under_10kb() {
        let artifact = dummy_artifact();
        let bytes = artifact.to_bytes().unwrap();
        assert!(
            bytes.len() < 10_240,
            "serialized artifact for 3-field schema should be under 10 KB, got {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn from_bytes_rejects_garbage() {
        let result = ModelArtifact::from_bytes(b"not valid bincode data");
        assert!(result.is_err(), "garbage bytes must return an error");
    }

    #[test]
    fn from_bytes_rejects_empty_input() {
        let result = ModelArtifact::from_bytes(&[]);
        assert!(result.is_err(), "empty input must return an error");
    }

    #[test]
    fn from_bytes_rejects_truncated_input() {
        let bytes = dummy_artifact().to_bytes().unwrap();
        // Dropping the final byte leaves the last field incomplete.
        let truncated = &bytes[..bytes.len() - 1];
        assert!(
            ModelArtifact::from_bytes(truncated).is_err(),
            "truncated input must return an error"
        );
    }
}