Skip to main content

uni_query/query/df_graph/
locy_calibrate.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Phase C C2: `CALIBRATE` statement runtime.
5//!
6//! For each `CompiledCalibrate` command, this module:
7//!
8//! 1. Builds a Cypher `MATCH pattern [WHERE expr] RETURN <input vars>, <target>`
9//!    query from the compiled command's pieces.
10//! 2. Executes it through the same `execute_cypher_inline` path used for
11//!    Phase 4 inline Cypher commands — gets back a list of `FactRow`s.
12//! 3. Builds `ClassifyInput`s from each row using the model's INPUT
13//!    binding names, then batch-calls the registered classifier.
14//! 4. Converts target column values to bool labels.
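//! 5. Splits train / holdout deterministically: the first
//!    `ceil(n * holdout)` rows (in input order) form the holdout and the
//!    remaining rows form the training set.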
16//! 6. Fits the chosen `CalibratorFitter` on the training half.
17//! 7. Computes Brier + ECE on the holdout pre- and post-calibration.
18//! 8. Returns a [`uni_locy::CalibrationResult`] for surfacing in the
19//!    `LocyResult.command_results` slot.
20
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use uni_common::Value;
25use uni_cypher::ast::{Clause, Expr, MatchClause, ReturnClause, ReturnItem, Statement};
26use uni_cypher::locy_ast::CalibrationMethod;
27use uni_locy::{
28    BetaFitter, CalibrationMethodKind, CalibrationResult, CalibratorFitter, ClassifierRegistry,
29    ClassifyInput, CompiledCalibrate, CompiledModel, FactRow, FeatureValue, IdentityCalibrator,
30    IsotonicFitter, NeuralClassifier, PlattFitter, TemperatureFitter, brier_score,
31    expected_calibration_error,
32};
33
/// Number of bins used for ECE reporting in the CALIBRATE holdout
/// summary. The same bin count is passed to
/// `expected_calibration_error` for both the raw and the calibrated
/// holdout predictions, so the two reported ECE values are comparable.
/// The Phase C C2 result block is informational; C3 `VALIDATE` will
/// offer richer (debiased / classwise) variants.
const ECE_BINS: usize = 10;
38
/// Errors specific to the `CALIBRATE` runtime. Wrapped into a
/// `DataFusionError::Execution` at the dispatch site.
///
/// Every variant carries the name of the model the failing command
/// targeted, so the rendered message always identifies the command.
/// The type is `Clone + PartialEq + Eq` so callers and tests can compare
/// and duplicate errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CalibrateRuntimeError {
    /// No classifier is registered under the command's model name.
    ClassifierMissing {
        /// Model name used for the registry lookup.
        model_name: String,
    },
    /// The command's model name is absent from the compiled model
    /// catalog. The compiler is expected to reject this earlier.
    UnknownModelInCatalog {
        /// Model name used for the catalog lookup.
        model_name: String,
    },
    /// The collected row set was empty.
    EmptyDataset {
        /// Model the command was calibrating.
        model_name: String,
    },
    /// After the train / holdout split at least one side was empty.
    InsufficientData {
        /// Model the command was calibrating.
        model_name: String,
        /// Number of samples that landed in the training split.
        train: usize,
        /// Number of samples that landed in the holdout split.
        holdout: usize,
    },
    /// The calibration pipeline failed; `message` carries the
    /// underlying error text.
    FitFailure {
        /// Model the command was calibrating.
        model_name: String,
        /// Display text of the underlying failure.
        message: String,
    },
}

impl std::fmt::Display for CalibrateRuntimeError {
    /// Render a single-line, user-facing message prefixed with
    /// `CALIBRATE:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ClassifierMissing { model_name } => write!(
                f,
                "CALIBRATE: classifier '{model_name}' not registered; \
                 add it to LocyConfig::classifier_registry before evaluating"
            ),
            Self::UnknownModelInCatalog { model_name } => write!(
                f,
                "CALIBRATE: model '{model_name}' not in CompiledProgram.model_catalog \
                 (compiler should have rejected this earlier)"
            ),
            Self::EmptyDataset { model_name } => write!(
                f,
                "CALIBRATE: model '{model_name}' MATCH pattern produced zero rows"
            ),
            Self::InsufficientData {
                model_name,
                train,
                holdout,
            } => write!(
                f,
                "CALIBRATE: model '{model_name}' needs at least 1 sample in each \
                 split (got train={train}, holdout={holdout}); increase the data \
                 set or pick a different HOLDOUT fraction"
            ),
            Self::FitFailure {
                model_name,
                message,
            } => {
                write!(f, "CALIBRATE: model '{model_name}' fitter error: {message}")
            }
        }
    }
}

impl std::error::Error for CalibrateRuntimeError {}
104
105/// Build a Cypher `Query` from the CALIBRATE command's pattern + WHERE +
106/// projected variables. The projection returns one node-variable per
107/// model INPUT binding followed by the TARGET expression.
108fn build_collection_query(
109    cmd: &CompiledCalibrate,
110    model: &CompiledModel,
111) -> uni_cypher::ast::Query {
112    let mut items: Vec<ReturnItem> = Vec::with_capacity(model.inputs.len() + 1);
113    for binding in &model.inputs {
114        items.push(ReturnItem::Expr {
115            expr: Expr::Variable(binding.variable.clone()),
116            alias: Some(binding.variable.clone()),
117            source_text: None,
118        });
119    }
120    items.push(ReturnItem::Expr {
121        expr: cmd.target_expr.clone(),
122        alias: Some("__calibrate_target".to_string()),
123        source_text: None,
124    });
125    let stmt = Statement {
126        clauses: vec![
127            Clause::Match(MatchClause {
128                optional: false,
129                pattern: cmd.pattern.clone(),
130                where_clause: cmd.where_expr.clone(),
131                for_update: false,
132            }),
133            Clause::Return(ReturnClause {
134                distinct: false,
135                items,
136                order_by: None,
137                skip: None,
138                limit: None,
139            }),
140        ],
141    };
142    uni_cypher::ast::Query::Single(stmt)
143}
144
145/// Pull the FactRow value for column `name` and convert to FeatureValue.
146fn row_to_feature(row: &FactRow, name: &str) -> FeatureValue {
147    match row.get(name) {
148        Some(Value::Float(f)) => FeatureValue::Float(*f),
149        Some(Value::Int(i)) => FeatureValue::Int(*i),
150        Some(Value::String(s)) => FeatureValue::String(s.clone()),
151        Some(Value::Bool(b)) => FeatureValue::Bool(*b),
152        Some(Value::Null) | None => FeatureValue::Null,
153        // Other Value variants (List, Map, Node, Edge, …) fall back to
154        // Null in this slice — Slice 3+ may extend FeatureValue.
155        Some(_) => FeatureValue::Null,
156    }
157}
158
159/// Convert a target Value to a bool label. Non-null truthy values
160/// (true, non-zero numbers, non-empty strings) become 1; null /
161/// false / 0 become 0.
162fn target_to_label(v: Option<&Value>) -> bool {
163    match v {
164        Some(Value::Bool(b)) => *b,
165        Some(Value::Int(i)) => *i != 0,
166        Some(Value::Float(f)) => *f != 0.0,
167        Some(Value::String(s)) => !s.is_empty(),
168        Some(Value::Null) | None => false,
169        Some(_) => false,
170    }
171}
172
173/// Dispatch the chosen calibration method to its fitter, run the fit,
174/// wrap the resulting `Arc<dyn Calibrator>` in a `CalibrateRuntimeError`
175/// on failure.
176fn fit_method(
177    method: CalibrationMethod,
178    preds: &[f64],
179    labels: &[bool],
180    model_name: &str,
181) -> Result<Arc<dyn uni_locy::Calibrator>, CalibrateRuntimeError> {
182    let result = match method {
183        CalibrationMethod::PlattScaling => PlattFitter.fit(preds, labels),
184        CalibrationMethod::IsotonicRegression => IsotonicFitter.fit(preds, labels),
185        CalibrationMethod::TemperatureScaling => TemperatureFitter.fit(preds, labels),
186        CalibrationMethod::BetaCalibration => BetaFitter.fit(preds, labels),
187        CalibrationMethod::Conformal { alpha } => {
188            uni_locy::calibration::ConformalFitter { alpha }.fit(preds, labels)
189        }
190        CalibrationMethod::None => {
191            // Explicit "no-op" — caller asked for identity. Useful for
192            // exercising the CALIBRATE plumbing without modeling.
193            Ok(Arc::new(IdentityCalibrator) as Arc<dyn uni_locy::Calibrator>)
194        }
195        CalibrationMethod::Dirichlet => {
196            // Phase D D-C1d surface: the grammar accepts the keyword,
197            // but the binary CALIBRATE pipeline can't drive a
198            // multi-class fit — the trait expects `labels: &[bool]`
199            // and `preds: &[f64]`, whereas Dirichlet needs
200            // `labels: &[u32]` + `preds: &[Vec<f64>]`. Pending a
201            // surface form for multi-class CALIBRATE, callers should
202            // instantiate `DirichletFitter` directly via the Rust
203            // library API.
204            Err(uni_locy::calibration::CalibrationError::NumericIssue(
205                "Dirichlet is multi-class; the binary CALIBRATE statement \
206                 cannot fit it. Use `uni_locy::calibration::DirichletFitter` \
207                 directly until the multi-class CALIBRATE surface form ships.",
208            ))
209        }
210    };
211    result.map_err(|e| CalibrateRuntimeError::FitFailure {
212        model_name: model_name.to_string(),
213        message: e.to_string(),
214    })
215}
216
217/// Match the chosen method to its [`CalibrationMethodKind`] for the
218/// returned result block.
219fn method_kind(method: CalibrationMethod) -> CalibrationMethodKind {
220    match method {
221        CalibrationMethod::PlattScaling => CalibrationMethodKind::Platt,
222        CalibrationMethod::IsotonicRegression => CalibrationMethodKind::Isotonic,
223        CalibrationMethod::TemperatureScaling => CalibrationMethodKind::Temperature,
224        CalibrationMethod::BetaCalibration => CalibrationMethodKind::Beta,
225        CalibrationMethod::Conformal { .. } => CalibrationMethodKind::Conformal,
226        CalibrationMethod::Dirichlet => CalibrationMethodKind::Dirichlet,
227        CalibrationMethod::None => CalibrationMethodKind::Identity,
228    }
229}
230
231/// Run a `CALIBRATE` command end-to-end. The caller supplies an
232/// already-collected (input_value, label) row set — typically by
233/// driving the same `execute_cypher_inline` primitive used for Phase
234/// 4 inline Cypher.
235///
236/// This separation keeps the runtime testable without standing up a
237/// DataFusion session.
238pub async fn run_calibrate(
239    cmd: &CompiledCalibrate,
240    model_catalog: &HashMap<String, CompiledModel>,
241    classifier_registry: &Arc<ClassifierRegistry>,
242    rows: Vec<FactRow>,
243) -> Result<CalibrationResult, CalibrateRuntimeError> {
244    let model = model_catalog.get(&cmd.model_name).ok_or_else(|| {
245        CalibrateRuntimeError::UnknownModelInCatalog {
246            model_name: cmd.model_name.clone(),
247        }
248    })?;
249    let classifier: Arc<dyn NeuralClassifier> =
250        classifier_registry
251            .get(&cmd.model_name)
252            .cloned()
253            .ok_or_else(|| CalibrateRuntimeError::ClassifierMissing {
254                model_name: cmd.model_name.clone(),
255            })?;
256    if rows.is_empty() {
257        return Err(CalibrateRuntimeError::EmptyDataset {
258            model_name: cmd.model_name.clone(),
259        });
260    }
261    // Build ClassifyInputs and labels in row order.
262    let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(rows.len());
263    let mut labels: Vec<bool> = Vec::with_capacity(rows.len());
264    for row in &rows {
265        let mut features = HashMap::with_capacity(model.inputs.len());
266        for binding in &model.inputs {
267            features.insert(
268                binding.variable.clone(),
269                row_to_feature(row, &binding.variable),
270            );
271        }
272        inputs.push(ClassifyInput { features });
273        labels.push(target_to_label(row.get("__calibrate_target")));
274    }
275    // Classify everything once — same primitive Slice 3 uses for rule-body invocation.
276    let predictions =
277        classifier
278            .classify(&inputs)
279            .await
280            .map_err(|e| CalibrateRuntimeError::FitFailure {
281                model_name: cmd.model_name.clone(),
282                message: e.to_string(),
283            })?;
284    if predictions.len() != labels.len() {
285        return Err(CalibrateRuntimeError::FitFailure {
286            model_name: cmd.model_name.clone(),
287            message: format!(
288                "classifier returned {} predictions for {} inputs",
289                predictions.len(),
290                labels.len()
291            ),
292        });
293    }
294
295    // Deterministic holdout split: the holdout takes the FIRST
296    // ceil(n * holdout) rows in input order. A modulo-based stride
297    // would alias with label patterns that have the same period
298    // (e.g. label = `i % 2 == 0` aliases with stride 4), so prefix
299    // selection keeps the split label-distribution-independent.
300    // Tests rely on this exact behavior. Randomized splitting with
301    // a seedable RNG is a follow-up.
302    let n = predictions.len();
303    let holdout_size = ((n as f64) * cmd.holdout).ceil().max(1.0) as usize;
304    let holdout_size = holdout_size.min(n);
305    let mut train_preds: Vec<f64> = Vec::new();
306    let mut train_labels: Vec<bool> = Vec::new();
307    let mut holdout_preds: Vec<f64> = Vec::new();
308    let mut holdout_labels: Vec<bool> = Vec::new();
309    for (i, (p, y)) in predictions.iter().zip(labels.iter()).enumerate() {
310        if i < holdout_size {
311            holdout_preds.push(*p);
312            holdout_labels.push(*y);
313        } else {
314            train_preds.push(*p);
315            train_labels.push(*y);
316        }
317    }
318    if train_preds.is_empty() || holdout_preds.is_empty() {
319        return Err(CalibrateRuntimeError::InsufficientData {
320            model_name: cmd.model_name.clone(),
321            train: train_preds.len(),
322            holdout: holdout_preds.len(),
323        });
324    }
325
326    let calibrator = fit_method(cmd.method, &train_preds, &train_labels, &cmd.model_name)?;
327    let raw_brier = brier_score(&holdout_preds, &holdout_labels);
328    let raw_ece = expected_calibration_error(&holdout_preds, &holdout_labels, ECE_BINS);
329    let calibrated: Vec<f64> = calibrator.apply_batch(&holdout_preds);
330    let calibrated_brier = brier_score(&calibrated, &holdout_labels);
331    let calibrated_ece = expected_calibration_error(&calibrated, &holdout_labels, ECE_BINS);
332
333    // Phase C C1a: surface the conformal quantile in the result row
334    // for downstream EXPLAIN / band reporting. Only populated when
335    // the method is Conformal — extracted via the calibrator's
336    // confidence_band probe at p = 0.5 (the band half-width equals
337    // the quantile regardless of the probe point).
338    let confidence_band_quantile = calibrator
339        .confidence_band(0.5)
340        .map(|band| (band.upper - band.lower) / 2.0);
341
342    Ok(CalibrationResult {
343        model_name: cmd.model_name.clone(),
344        method: method_kind(cmd.method),
345        n_samples: predictions.len(),
346        holdout_size: holdout_preds.len(),
347        calibrator,
348        raw_brier,
349        raw_ece,
350        calibrated_brier,
351        calibrated_ece,
352        confidence_band_quantile,
353    })
354}
355
/// Public entry point to the collection-query builder, for the
/// dispatch layer that wraps `execute_cypher_inline`.
///
/// Delegates to `build_collection_query` without modification: the
/// returned query is `MATCH <pattern> [WHERE <expr>]` followed by a
/// `RETURN` of one column per model INPUT binding plus the TARGET
/// expression aliased to `__calibrate_target`.
pub fn calibrate_collection_query(
    cmd: &CompiledCalibrate,
    model: &CompiledModel,
) -> uni_cypher::ast::Query {
    build_collection_query(cmd, model)
}
364
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::locy_ast::{CalibrationMethod as AstCalibration, OutputType};
    use uni_locy::{CompiledInputBinding, MockClassifier};

    /// Build a `FactRow` from `(column, value)` pairs.
    fn fact_row(pairs: &[(&str, Value)]) -> FactRow {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// A model named `scorer` with a single INPUT binding `s`.
    fn model_with_one_input() -> CompiledModel {
        CompiledModel {
            name: "scorer".into(),
            inputs: vec![CompiledInputBinding {
                variable: "s".into(),
                label: Some("Supplier".into()),
            }],
            embedder_alias: None,
            features: vec![],
            path_context: None,
            output_type: OutputType::Prob,
            output_name: "risk".into(),
            xervo_alias: "classify/test".into(),
            calibration: None,
            version: None,
            annotations: Default::default(),
        }
    }

    fn dummy_pattern() -> uni_cypher::ast::Pattern {
        // A minimal pattern; the actual MATCH wouldn't be executed in
        // these tests since we feed `run_calibrate` rows directly.
        uni_cypher::ast::Pattern { paths: vec![] }
    }

    /// A `CALIBRATE scorer` command with a 25% holdout and the given method.
    fn cmd(method: AstCalibration) -> CompiledCalibrate {
        CompiledCalibrate {
            model_name: "scorer".into(),
            pattern: dummy_pattern(),
            where_expr: None,
            target_expr: Expr::Variable("label".into()),
            method,
            holdout: 0.25,
        }
    }

    /// Model catalog containing only the `scorer` model.
    fn catalog_with_scorer() -> HashMap<String, CompiledModel> {
        let mut catalog = HashMap::new();
        catalog.insert("scorer".to_string(), model_with_one_input());
        catalog
    }

    /// Registry with `classifier` registered under the `scorer` model name.
    fn registry_with(classifier: Arc<dyn NeuralClassifier>) -> Arc<ClassifierRegistry> {
        let mut registry = ClassifierRegistry::new();
        registry.insert("scorer".into(), classifier);
        Arc::new(registry)
    }

    /// `n` rows with `s = i` and a label that is true on even `i`.
    fn alternating_rows(n: i64) -> Vec<FactRow> {
        (0..n)
            .map(|i| {
                fact_row(&[
                    ("s", Value::Int(i)),
                    ("__calibrate_target", Value::Bool(i % 2 == 0)),
                ])
            })
            .collect()
    }

    #[tokio::test]
    async fn calibrate_constant_classifier_improves_ece() {
        // 100 rows, alternating labels, with a mock classifier that
        // always returns 0.95.
        let catalog = catalog_with_scorer();
        let registry = registry_with(Arc::new(MockClassifier::constant("classify/test", 0.95)));
        let result = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(100),
        )
        .await
        .unwrap();
        assert_eq!(result.model_name, "scorer");
        assert_eq!(result.method, CalibrationMethodKind::Platt);
        // Phase C gate: ECE should drop by at least 50% after Platt.
        assert!(
            result.calibrated_ece < result.raw_ece * 0.5,
            "Platt should reduce ECE by ≥50%: raw={} cal={}",
            result.raw_ece,
            result.calibrated_ece
        );
    }

    #[tokio::test]
    async fn calibrate_missing_classifier_errors() {
        let catalog = catalog_with_scorer();
        let registry = Arc::new(ClassifierRegistry::new());
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(1),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::ClassifierMissing { .. }
        ));
    }

    #[tokio::test]
    async fn calibrate_empty_dataset_errors() {
        let catalog = catalog_with_scorer();
        let registry = registry_with(Arc::new(MockClassifier::constant("classify/test", 0.5)));
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            vec![],
        )
        .await
        .unwrap_err();
        assert!(matches!(err, CalibrateRuntimeError::EmptyDataset { .. }));
    }

    #[tokio::test]
    async fn calibrate_unknown_model_errors() {
        // The catalog lookup happens before anything else, so an empty
        // catalog fails even with a registered classifier and data.
        let catalog: HashMap<String, CompiledModel> = HashMap::new();
        let registry = registry_with(Arc::new(MockClassifier::constant("classify/test", 0.5)));
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(4),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::UnknownModelInCatalog { .. }
        ));
    }

    #[tokio::test]
    async fn calibrate_single_row_is_insufficient() {
        // With one row the holdout takes it (minimum holdout size is 1),
        // leaving nothing to train on.
        let catalog = catalog_with_scorer();
        let registry = registry_with(Arc::new(MockClassifier::constant("classify/test", 0.5)));
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            alternating_rows(1),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::InsufficientData {
                train: 0,
                holdout: 1,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn calibrate_dirichlet_is_rejected() {
        // The binary CALIBRATE pipeline cannot fit the multi-class
        // Dirichlet method; it must surface as a FitFailure.
        let catalog = catalog_with_scorer();
        let registry = registry_with(Arc::new(MockClassifier::constant("classify/test", 0.5)));
        let err = run_calibrate(
            &cmd(AstCalibration::Dirichlet),
            &catalog,
            &registry,
            alternating_rows(8),
        )
        .await
        .unwrap_err();
        assert!(matches!(err, CalibrateRuntimeError::FitFailure { .. }));
    }

    #[tokio::test]
    async fn calibrate_identity_reports_prefix_holdout() {
        // 100 rows at HOLDOUT 0.25 → the first 25 rows are the holdout.
        let catalog = catalog_with_scorer();
        let registry = registry_with(Arc::new(MockClassifier::constant("classify/test", 0.5)));
        let result = run_calibrate(
            &cmd(AstCalibration::None),
            &catalog,
            &registry,
            alternating_rows(100),
        )
        .await
        .unwrap();
        assert_eq!(result.method, CalibrationMethodKind::Identity);
        assert_eq!(result.n_samples, 100);
        assert_eq!(result.holdout_size, 25);
    }
}
493}