1use std::path::PathBuf;
2
3use zer_cluster::ClusterConfig;
4use zer_core::field_mapping::FieldMapping;
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
11#[serde(rename_all = "PascalCase")]
12pub enum BatchStartupMode {
13 ColdStart,
15 WarmLoad,
17 WarmStart,
20}
21
22impl std::fmt::Display for BatchStartupMode {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::ColdStart => write!(f, "ColdStart"),
26 Self::WarmLoad => write!(f, "WarmLoad"),
27 Self::WarmStart => write!(f, "WarmStart"),
28 }
29 }
30}
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
34pub enum LinkMode {
35 #[default]
37 Deduplicate,
38
39 LinkOnly,
46
47 LinkAndDedupe,
50}
51
52impl LinkMode {
53 pub fn as_str(self) -> &'static str {
54 match self {
55 Self::Deduplicate => "deduplicate",
56 Self::LinkOnly => "link-only",
57 Self::LinkAndDedupe => "link-and-dedupe",
58 }
59 }
60}
61
62#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
64pub struct RateConfig {
65 pub slow_threshold: f32,
67 pub fast_threshold: f32,
69 pub bulk_threshold_multiplier: f32,
71}
72
73impl Default for RateConfig {
74 fn default() -> Self {
75 Self {
76 slow_threshold: 1.0,
77 fast_threshold: 100.0,
78 bulk_threshold_multiplier: 1.05,
79 }
80 }
81}
82
83#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
85pub struct PipelineConfig {
86 pub registry_path: PathBuf,
88
89 pub em_max_iter_cold: usize,
91
92 pub em_max_iter_warm: usize,
94
95 pub cluster_config: ClusterConfig,
97
98 pub gpu_min_batch: usize,
100
101 pub rate_config: RateConfig,
103
104 #[serde(default)]
111 pub upper_threshold: Option<f32>,
112
113 #[serde(default)]
118 pub lower_threshold: Option<f32>,
119
120 #[serde(default)]
127 pub link_mode: LinkMode,
128
129 #[serde(default = "default_max_bucket_size")]
139 pub max_bucket_size: usize,
140
141 #[serde(default)]
147 pub field_mappings: Vec<FieldMapping>,
148}
149
/// Fallback value for `PipelineConfig::max_bucket_size`.
const DEFAULT_MAX_BUCKET_SIZE: usize = 300;

/// Default provider referenced by `#[serde(default = "default_max_bucket_size")]`.
///
/// Serde needs a function path for a non-`Default::default()` fallback, so
/// this only hands back [`DEFAULT_MAX_BUCKET_SIZE`].
fn default_max_bucket_size() -> usize {
    DEFAULT_MAX_BUCKET_SIZE
}
155
156impl Default for PipelineConfig {
157 fn default() -> Self {
158 Self {
159 registry_path: PathBuf::from("schema.zsm"),
160 em_max_iter_cold: 25,
161 em_max_iter_warm: 3,
162 cluster_config: ClusterConfig::default(),
163 gpu_min_batch: 1_000,
164 rate_config: RateConfig::default(),
165 upper_threshold: None,
166 lower_threshold: None,
167 link_mode: LinkMode::Deduplicate,
168 max_bucket_size: DEFAULT_MAX_BUCKET_SIZE,
169 field_mappings: Vec::new(),
170 }
171 }
172}
173
/// Unit tests for the configuration defaults and their JSON round-trips.
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `cfg` to a JSON string and parses that string back.
    fn json_round_trip(cfg: &PipelineConfig) -> PipelineConfig {
        let encoded = serde_json::to_string(cfg).expect("serialize");
        serde_json::from_str(&encoded).expect("deserialize")
    }

    #[test]
    fn default_config_has_sensible_values() {
        let PipelineConfig {
            em_max_iter_cold,
            em_max_iter_warm,
            gpu_min_batch,
            ..
        } = PipelineConfig::default();

        assert_eq!(em_max_iter_cold, 25);
        assert_eq!(em_max_iter_warm, 3);
        assert_eq!(gpu_min_batch, 1_000);
    }

    #[test]
    fn default_threshold_overrides_are_none() {
        let defaults = PipelineConfig::default();
        assert!(
            defaults.upper_threshold.is_none(),
            "upper_threshold must default to None"
        );
        assert!(
            defaults.lower_threshold.is_none(),
            "lower_threshold must default to None"
        );
    }

    #[test]
    fn threshold_overrides_round_trip_json() {
        let restored = json_round_trip(&PipelineConfig {
            upper_threshold: Some(0.92),
            lower_threshold: Some(0.08),
            ..Default::default()
        });
        assert_eq!(restored.upper_threshold, Some(0.92));
        assert_eq!(restored.lower_threshold, Some(0.08));
    }

    #[test]
    fn threshold_override_none_round_trips_from_json_without_field() {
        // Only the required fields are present; every `#[serde(default …)]`
        // field must take its fallback value.
        let json = r#"{"registry_path":"schema.zsm","em_max_iter_cold":25,"em_max_iter_warm":3,"cluster_config":{"max_cluster_size":50,"within_cluster_min":0.85},"gpu_min_batch":1000,"rate_config":{"slow_threshold":1.0,"fast_threshold":100.0,"bulk_threshold_multiplier":1.05}}"#;
        let parsed: PipelineConfig = serde_json::from_str(json).expect("deserialize");

        assert!(parsed.upper_threshold.is_none());
        assert!(parsed.lower_threshold.is_none());
        assert_eq!(parsed.link_mode, LinkMode::Deduplicate);
        assert_eq!(parsed.max_bucket_size, 300);
        assert!(parsed.field_mappings.is_empty());
    }

    #[test]
    fn max_bucket_size_default_is_300() {
        assert_eq!(PipelineConfig::default().max_bucket_size, 300);
    }

    #[test]
    fn max_bucket_size_round_trips_json() {
        let restored = json_round_trip(&PipelineConfig {
            max_bucket_size: 500,
            ..Default::default()
        });
        assert_eq!(restored.max_bucket_size, 500);
    }

    #[test]
    fn link_mode_default_is_deduplicate() {
        assert_eq!(PipelineConfig::default().link_mode, LinkMode::Deduplicate);
    }

    #[test]
    fn link_mode_round_trips_json() {
        let restored = json_round_trip(&PipelineConfig {
            link_mode: LinkMode::LinkOnly,
            ..Default::default()
        });
        assert_eq!(restored.link_mode, LinkMode::LinkOnly);
    }

    #[test]
    fn link_mode_link_and_dedupe_round_trips_json() {
        let restored = json_round_trip(&PipelineConfig {
            link_mode: LinkMode::LinkAndDedupe,
            ..Default::default()
        });
        assert_eq!(restored.link_mode, LinkMode::LinkAndDedupe);
    }

    #[test]
    fn default_rate_config_thresholds_ordered() {
        let rates = RateConfig::default();
        assert!(rates.slow_threshold < rates.fast_threshold);
        assert!(rates.bulk_threshold_multiplier > 1.0);
    }

    #[test]
    fn pipeline_config_roundtrip_json() {
        let original = PipelineConfig::default();
        let restored = json_round_trip(&original);

        assert_eq!(original.em_max_iter_cold, restored.em_max_iter_cold);
        assert_eq!(original.em_max_iter_warm, restored.em_max_iter_warm);
        assert_eq!(
            original.rate_config.fast_threshold,
            restored.rate_config.fast_threshold
        );
    }

    #[test]
    fn cluster_config_default_reasonable() {
        let clusters = PipelineConfig::default().cluster_config;
        assert!(clusters.max_cluster_size > 0);
        assert!(clusters.within_cluster_min > 0.0);
        assert!(clusters.within_cluster_min < 1.0);
    }
}