1use zer_core::{error::ZerError, scoring::ModelParams};
2
3use crate::fingerprint::SchemaFingerprint;
4
/// A trained model bundled with the metadata needed to reload it later.
///
/// Serialized to and from bytes via [`ModelArtifact::to_bytes`] /
/// [`ModelArtifact::from_bytes`] (bincode). The derived serde impls encode the
/// fields in declaration order. Reordering, removing or inserting fields therefore
/// changes the binary layout and makes previously written artifacts unreadable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelArtifact {
    /// Fingerprint of the schema the model was built for
    /// (constructed with `SchemaFingerprint::from_schema`).
    pub fingerprint: SchemaFingerprint,
    /// Fitted scoring parameters: per-field `m`/`u` tables, log prior odds and
    /// the upper/lower decision thresholds.
    pub params: ModelParams,
    /// Optional free-form label for this artifact.
    pub tag: Option<String>,
    // NOTE(review): the unit/meaning of `trained_on` is not established in this
    // file (could be a timestamp or a record count) — confirm against the trainer.
    pub trained_on: u64,
    // Presumably the number of EM iterations run while fitting `params`;
    // verify against the training code.
    pub em_iterations: usize,
}
21
22impl ModelArtifact {
23 pub fn to_bytes(&self) -> Result<Vec<u8>, ZerError> {
25 bincode::serialize(self).map_err(|e| ZerError::Serialization(e.to_string()))
26 }
27
28 pub fn from_bytes(bytes: &[u8]) -> Result<Self, ZerError> {
30 bincode::deserialize(bytes).map_err(|e| ZerError::Serialization(e.to_string()))
31 }
32}
33
// Round-trip tests for `ModelArtifact` binary (de)serialization.
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::schema::{FieldKind, SchemaBuilder};

    /// Builds a small, fully populated artifact for a 3-field schema.
    ///
    /// Every scalar field gets a non-default value, so a field that silently
    /// decodes to zero/empty is caught by the round-trip assertions.
    fn dummy_artifact() -> ModelArtifact {
        let schema = SchemaBuilder::new()
            .field("voornamen", FieldKind::Name)
            .field("achternaam", FieldKind::Name)
            .field("geboortedatum", FieldKind::Date)
            .build()
            .unwrap();

        let fingerprint = SchemaFingerprint::from_schema(&schema);

        let params = ModelParams {
            m: vec![
                vec![0.02, 0.06, 0.12, 0.80],
                vec![0.02, 0.06, 0.12, 0.80],
                vec![0.01, 0.04, 0.10, 0.85],
            ],
            u: vec![
                vec![0.70, 0.15, 0.10, 0.05],
                vec![0.70, 0.15, 0.10, 0.05],
                vec![0.80, 0.10, 0.07, 0.03],
            ],
            log_prior_odds: -2.0,
            upper_threshold: 0.9,
            lower_threshold: 0.1,
        };

        ModelArtifact {
            fingerprint,
            params,
            tag: Some("test_artifact".into()),
            // Non-zero so a field that decodes to its default value is detected.
            trained_on: 12_345,
            em_iterations: 25,
        }
    }

    #[test]
    fn roundtrip_preserves_all_fields() {
        let original = dummy_artifact();
        let bytes = original.to_bytes().expect("serialization must succeed");
        let loaded = ModelArtifact::from_bytes(&bytes).expect("deserialization must succeed");

        assert_eq!(original.tag, loaded.tag);
        assert_eq!(original.trained_on, loaded.trained_on);
        assert_eq!(original.em_iterations, loaded.em_iterations);
        assert_eq!(original.params.upper_threshold, loaded.params.upper_threshold);
        assert_eq!(original.params.lower_threshold, loaded.params.lower_threshold);
        assert_eq!(original.params.log_prior_odds, loaded.params.log_prior_odds);
        assert_eq!(original.fingerprint.schema_hash, loaded.fingerprint.schema_hash);
    }

    #[test]
    fn roundtrip_preserves_m_u_tables() {
        let original = dummy_artifact();
        let bytes = original.to_bytes().unwrap();
        let loaded = ModelArtifact::from_bytes(&bytes).unwrap();

        // Exact comparison, no tolerance: bincode stores floats as raw IEEE-754
        // bytes, so decoded values must be identical. Comparing whole tables also
        // checks the row count and each row's length, which a zipped
        // element-wise loop would silently truncate.
        assert_eq!(
            original.params.m, loaded.params.m,
            "m table must be bit-exact after roundtrip"
        );
        assert_eq!(
            original.params.u, loaded.params.u,
            "u table must be bit-exact after roundtrip"
        );
    }

    #[test]
    fn serialized_size_under_10kb() {
        let artifact = dummy_artifact();
        let bytes = artifact.to_bytes().unwrap();
        assert!(
            bytes.len() < 10_240,
            "serialized artifact for 3-field schema should be under 10 KB, got {} bytes",
            bytes.len()
        );
    }

    #[test]
    fn from_bytes_rejects_garbage() {
        let result = ModelArtifact::from_bytes(b"not valid bincode data");
        assert!(result.is_err(), "garbage bytes must return an error");
    }

    #[test]
    fn from_bytes_rejects_empty_and_truncated_input() {
        assert!(
            ModelArtifact::from_bytes(&[]).is_err(),
            "empty input must return an error"
        );

        let bytes = dummy_artifact().to_bytes().unwrap();
        // Dropping the final byte cuts the last field short, so decoding must fail
        // instead of yielding a partially-filled artifact.
        let truncated = &bytes[..bytes.len() - 1];
        assert!(
            ModelArtifact::from_bytes(truncated).is_err(),
            "truncated input must return an error"
        );
    }
}