1use std::collections::HashMap;
22use std::sync::Arc;
23
24use uni_common::Value;
25use uni_cypher::ast::{Clause, Expr, MatchClause, ReturnClause, ReturnItem, Statement};
26use uni_cypher::locy_ast::CalibrationMethod;
27use uni_locy::{
28 BetaFitter, CalibrationMethodKind, CalibrationResult, CalibratorFitter, ClassifierRegistry,
29 ClassifyInput, CompiledCalibrate, CompiledModel, FactRow, FeatureValue, IdentityCalibrator,
30 IsotonicFitter, NeuralClassifier, PlattFitter, TemperatureFitter, brier_score,
31 expected_calibration_error,
32};
33
/// Number of bins handed to [`expected_calibration_error`] when scoring the
/// holdout split, both before and after calibration.
const ECE_BINS: usize = 10;
38
/// Failures that can occur while executing a compiled `CALIBRATE` statement.
///
/// Every variant carries the name of the model being calibrated so the
/// rendered message can point at the offending `CALIBRATE` command.
#[derive(Debug)]
pub enum CalibrateRuntimeError {
    /// No classifier is registered under the model's name.
    ClassifierMissing { model_name: String },
    /// The model name is absent from the compiled model catalog.
    UnknownModelInCatalog { model_name: String },
    /// The MATCH pattern produced zero rows.
    EmptyDataset { model_name: String },
    /// The holdout split left the training side or the holdout side empty.
    InsufficientData {
        model_name: String,
        /// Number of samples on the training side of the split.
        train: usize,
        /// Number of samples on the holdout side of the split.
        holdout: usize,
    },
    /// Classification or calibrator fitting failed; `message` carries the
    /// underlying error text.
    FitFailure { model_name: String, message: String },
}

impl std::fmt::Display for CalibrateRuntimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ClassifierMissing { model_name: name } => write!(
                f,
                "CALIBRATE: classifier '{name}' not registered; add it to \
                 LocyConfig::classifier_registry before evaluating"
            ),
            Self::UnknownModelInCatalog { model_name: name } => write!(
                f,
                "CALIBRATE: model '{name}' not in CompiledProgram.model_catalog \
                 (compiler should have rejected this earlier)"
            ),
            Self::EmptyDataset { model_name: name } => {
                write!(f, "CALIBRATE: model '{name}' MATCH pattern produced zero rows")
            }
            Self::InsufficientData {
                model_name: name,
                train,
                holdout,
            } => write!(
                f,
                "CALIBRATE: model '{name}' needs at least 1 sample in each split \
                 (got train={train}, holdout={holdout}); increase the data set or \
                 pick a different HOLDOUT fraction"
            ),
            Self::FitFailure {
                model_name: name,
                message: detail,
            } => write!(f, "CALIBRATE: model '{name}' fitter error: {detail}"),
        }
    }
}

impl std::error::Error for CalibrateRuntimeError {}
104
105fn build_collection_query(
109 cmd: &CompiledCalibrate,
110 model: &CompiledModel,
111) -> uni_cypher::ast::Query {
112 let mut items: Vec<ReturnItem> = Vec::with_capacity(model.inputs.len() + 1);
113 for binding in &model.inputs {
114 items.push(ReturnItem::Expr {
115 expr: Expr::Variable(binding.variable.clone()),
116 alias: Some(binding.variable.clone()),
117 source_text: None,
118 });
119 }
120 items.push(ReturnItem::Expr {
121 expr: cmd.target_expr.clone(),
122 alias: Some("__calibrate_target".to_string()),
123 source_text: None,
124 });
125 let stmt = Statement {
126 clauses: vec![
127 Clause::Match(MatchClause {
128 optional: false,
129 pattern: cmd.pattern.clone(),
130 where_clause: cmd.where_expr.clone(),
131 for_update: false,
132 }),
133 Clause::Return(ReturnClause {
134 distinct: false,
135 items,
136 order_by: None,
137 skip: None,
138 limit: None,
139 }),
140 ],
141 };
142 uni_cypher::ast::Query::Single(stmt)
143}
144
145fn row_to_feature(row: &FactRow, name: &str) -> FeatureValue {
147 match row.get(name) {
148 Some(Value::Float(f)) => FeatureValue::Float(*f),
149 Some(Value::Int(i)) => FeatureValue::Int(*i),
150 Some(Value::String(s)) => FeatureValue::String(s.clone()),
151 Some(Value::Bool(b)) => FeatureValue::Bool(*b),
152 Some(Value::Null) | None => FeatureValue::Null,
153 Some(_) => FeatureValue::Null,
156 }
157}
158
159fn target_to_label(v: Option<&Value>) -> bool {
163 match v {
164 Some(Value::Bool(b)) => *b,
165 Some(Value::Int(i)) => *i != 0,
166 Some(Value::Float(f)) => *f != 0.0,
167 Some(Value::String(s)) => !s.is_empty(),
168 Some(Value::Null) | None => false,
169 Some(_) => false,
170 }
171}
172
173fn fit_method(
177 method: CalibrationMethod,
178 preds: &[f64],
179 labels: &[bool],
180 model_name: &str,
181) -> Result<Arc<dyn uni_locy::Calibrator>, CalibrateRuntimeError> {
182 let result = match method {
183 CalibrationMethod::PlattScaling => PlattFitter.fit(preds, labels),
184 CalibrationMethod::IsotonicRegression => IsotonicFitter.fit(preds, labels),
185 CalibrationMethod::TemperatureScaling => TemperatureFitter.fit(preds, labels),
186 CalibrationMethod::BetaCalibration => BetaFitter.fit(preds, labels),
187 CalibrationMethod::Conformal { alpha } => {
188 uni_locy::calibration::ConformalFitter { alpha }.fit(preds, labels)
189 }
190 CalibrationMethod::None => {
191 Ok(Arc::new(IdentityCalibrator) as Arc<dyn uni_locy::Calibrator>)
194 }
195 CalibrationMethod::Dirichlet => {
196 Err(uni_locy::calibration::CalibrationError::NumericIssue(
205 "Dirichlet is multi-class; the binary CALIBRATE statement \
206 cannot fit it. Use `uni_locy::calibration::DirichletFitter` \
207 directly until the multi-class CALIBRATE surface form ships.",
208 ))
209 }
210 };
211 result.map_err(|e| CalibrateRuntimeError::FitFailure {
212 model_name: model_name.to_string(),
213 message: e.to_string(),
214 })
215}
216
217fn method_kind(method: CalibrationMethod) -> CalibrationMethodKind {
220 match method {
221 CalibrationMethod::PlattScaling => CalibrationMethodKind::Platt,
222 CalibrationMethod::IsotonicRegression => CalibrationMethodKind::Isotonic,
223 CalibrationMethod::TemperatureScaling => CalibrationMethodKind::Temperature,
224 CalibrationMethod::BetaCalibration => CalibrationMethodKind::Beta,
225 CalibrationMethod::Conformal { .. } => CalibrationMethodKind::Conformal,
226 CalibrationMethod::Dirichlet => CalibrationMethodKind::Dirichlet,
227 CalibrationMethod::None => CalibrationMethodKind::Identity,
228 }
229}
230
231pub async fn run_calibrate(
239 cmd: &CompiledCalibrate,
240 model_catalog: &HashMap<String, CompiledModel>,
241 classifier_registry: &Arc<ClassifierRegistry>,
242 rows: Vec<FactRow>,
243) -> Result<CalibrationResult, CalibrateRuntimeError> {
244 let model = model_catalog.get(&cmd.model_name).ok_or_else(|| {
245 CalibrateRuntimeError::UnknownModelInCatalog {
246 model_name: cmd.model_name.clone(),
247 }
248 })?;
249 let classifier: Arc<dyn NeuralClassifier> =
250 classifier_registry
251 .get(&cmd.model_name)
252 .cloned()
253 .ok_or_else(|| CalibrateRuntimeError::ClassifierMissing {
254 model_name: cmd.model_name.clone(),
255 })?;
256 if rows.is_empty() {
257 return Err(CalibrateRuntimeError::EmptyDataset {
258 model_name: cmd.model_name.clone(),
259 });
260 }
261 let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(rows.len());
263 let mut labels: Vec<bool> = Vec::with_capacity(rows.len());
264 for row in &rows {
265 let mut features = HashMap::with_capacity(model.inputs.len());
266 for binding in &model.inputs {
267 features.insert(
268 binding.variable.clone(),
269 row_to_feature(row, &binding.variable),
270 );
271 }
272 inputs.push(ClassifyInput { features });
273 labels.push(target_to_label(row.get("__calibrate_target")));
274 }
275 let predictions =
277 classifier
278 .classify(&inputs)
279 .await
280 .map_err(|e| CalibrateRuntimeError::FitFailure {
281 model_name: cmd.model_name.clone(),
282 message: e.to_string(),
283 })?;
284 if predictions.len() != labels.len() {
285 return Err(CalibrateRuntimeError::FitFailure {
286 model_name: cmd.model_name.clone(),
287 message: format!(
288 "classifier returned {} predictions for {} inputs",
289 predictions.len(),
290 labels.len()
291 ),
292 });
293 }
294
295 let n = predictions.len();
303 let holdout_size = ((n as f64) * cmd.holdout).ceil().max(1.0) as usize;
304 let holdout_size = holdout_size.min(n);
305 let mut train_preds: Vec<f64> = Vec::new();
306 let mut train_labels: Vec<bool> = Vec::new();
307 let mut holdout_preds: Vec<f64> = Vec::new();
308 let mut holdout_labels: Vec<bool> = Vec::new();
309 for (i, (p, y)) in predictions.iter().zip(labels.iter()).enumerate() {
310 if i < holdout_size {
311 holdout_preds.push(*p);
312 holdout_labels.push(*y);
313 } else {
314 train_preds.push(*p);
315 train_labels.push(*y);
316 }
317 }
318 if train_preds.is_empty() || holdout_preds.is_empty() {
319 return Err(CalibrateRuntimeError::InsufficientData {
320 model_name: cmd.model_name.clone(),
321 train: train_preds.len(),
322 holdout: holdout_preds.len(),
323 });
324 }
325
326 let calibrator = fit_method(cmd.method, &train_preds, &train_labels, &cmd.model_name)?;
327 let raw_brier = brier_score(&holdout_preds, &holdout_labels);
328 let raw_ece = expected_calibration_error(&holdout_preds, &holdout_labels, ECE_BINS);
329 let calibrated: Vec<f64> = calibrator.apply_batch(&holdout_preds);
330 let calibrated_brier = brier_score(&calibrated, &holdout_labels);
331 let calibrated_ece = expected_calibration_error(&calibrated, &holdout_labels, ECE_BINS);
332
333 let confidence_band_quantile = calibrator
339 .confidence_band(0.5)
340 .map(|band| (band.upper - band.lower) / 2.0);
341
342 Ok(CalibrationResult {
343 model_name: cmd.model_name.clone(),
344 method: method_kind(cmd.method),
345 n_samples: predictions.len(),
346 holdout_size: holdout_preds.len(),
347 calibrator,
348 raw_brier,
349 raw_ece,
350 calibrated_brier,
351 calibrated_ece,
352 confidence_band_quantile,
353 })
354}
355
356pub fn calibrate_collection_query(
359 cmd: &CompiledCalibrate,
360 model: &CompiledModel,
361) -> uni_cypher::ast::Query {
362 build_collection_query(cmd, model)
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::locy_ast::{CalibrationMethod as AstCalibration, OutputType};
    use uni_locy::{CompiledInputBinding, MockClassifier};

    fn fact_row(pairs: &[(&str, Value)]) -> FactRow {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    fn model_with_one_input() -> CompiledModel {
        CompiledModel {
            name: "scorer".into(),
            inputs: vec![CompiledInputBinding {
                variable: "s".into(),
                label: Some("Supplier".into()),
            }],
            embedder_alias: None,
            features: vec![],
            path_context: None,
            output_type: OutputType::Prob,
            output_name: "risk".into(),
            xervo_alias: "classify/test".into(),
            calibration: None,
            version: None,
            annotations: Default::default(),
        }
    }

    fn dummy_pattern() -> uni_cypher::ast::Pattern {
        uni_cypher::ast::Pattern { paths: vec![] }
    }

    fn cmd(method: AstCalibration) -> CompiledCalibrate {
        CompiledCalibrate {
            model_name: "scorer".into(),
            pattern: dummy_pattern(),
            where_expr: None,
            target_expr: Expr::Variable("label".into()),
            method,
            holdout: 0.25,
        }
    }

    /// Catalog containing only the `scorer` model.
    fn scorer_catalog() -> HashMap<String, CompiledModel> {
        let mut catalog = HashMap::new();
        catalog.insert("scorer".to_string(), model_with_one_input());
        catalog
    }

    /// `n` rows whose target alternates true/false, starting with `true`.
    fn alternating_rows(n: i64) -> Vec<FactRow> {
        (0..n)
            .map(|i| {
                fact_row(&[
                    ("s", Value::Int(i)),
                    ("__calibrate_target", Value::Bool(i % 2 == 0)),
                ])
            })
            .collect()
    }

    #[tokio::test]
    async fn calibrate_constant_classifier_improves_ece() {
        let catalog = scorer_catalog();
        let mut registry = ClassifierRegistry::new();
        let c: Arc<dyn NeuralClassifier> =
            Arc::new(MockClassifier::constant("classify/test", 0.95));
        registry.insert("scorer".into(), c);
        let registry = Arc::new(registry);
        let result = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(100),
        )
        .await
        .unwrap();
        assert_eq!(result.model_name, "scorer");
        assert_eq!(result.method, CalibrationMethodKind::Platt);
        assert!(
            result.calibrated_ece < result.raw_ece * 0.5,
            "Platt should reduce ECE by ≥50%: raw={} cal={}",
            result.raw_ece,
            result.calibrated_ece
        );
    }

    #[tokio::test]
    async fn calibrate_missing_classifier_errors() {
        let catalog = scorer_catalog();
        let registry = Arc::new(ClassifierRegistry::new());
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(1),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::ClassifierMissing { .. }
        ));
    }

    #[tokio::test]
    async fn calibrate_unknown_model_errors() {
        let catalog: HashMap<String, CompiledModel> = HashMap::new();
        let registry = Arc::new(ClassifierRegistry::new());
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(1),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::UnknownModelInCatalog { .. }
        ));
    }

    #[tokio::test]
    async fn calibrate_empty_dataset_errors() {
        let catalog = scorer_catalog();
        let mut registry = ClassifierRegistry::new();
        let c: Arc<dyn NeuralClassifier> = Arc::new(MockClassifier::constant("classify/test", 0.5));
        registry.insert("scorer".into(), c);
        let registry = Arc::new(registry);
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            vec![],
        )
        .await
        .unwrap_err();
        assert!(matches!(err, CalibrateRuntimeError::EmptyDataset { .. }));
    }

    #[tokio::test]
    async fn calibrate_single_row_reports_insufficient_data() {
        let catalog = scorer_catalog();
        let mut registry = ClassifierRegistry::new();
        let c: Arc<dyn NeuralClassifier> = Arc::new(MockClassifier::constant("classify/test", 0.5));
        registry.insert("scorer".into(), c);
        let registry = Arc::new(registry);
        // HOLDOUT 0.25 of a single row rounds up to a one-row holdout, leaving
        // nothing to train the calibrator on.
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(1),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::InsufficientData {
                train: 0,
                holdout: 1,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn calibrate_dirichlet_is_rejected_as_fit_failure() {
        let catalog = scorer_catalog();
        let mut registry = ClassifierRegistry::new();
        let c: Arc<dyn NeuralClassifier> = Arc::new(MockClassifier::constant("classify/test", 0.5));
        registry.insert("scorer".into(), c);
        let registry = Arc::new(registry);
        let err = run_calibrate(
            &cmd(AstCalibration::Dirichlet),
            &catalog,
            &registry,
            alternating_rows(100),
        )
        .await
        .unwrap_err();
        match err {
            CalibrateRuntimeError::FitFailure { message, .. } => {
                assert!(message.contains("multi-class"), "unexpected message: {message}")
            }
            other => panic!("expected FitFailure, got {other:?}"),
        }
    }
}