Skip to main content

uni_query/query/df_graph/
locy_validate.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Phase C C3: `VALIDATE` statement runtime.
5//!
6//! For a `CompiledValidate { rule_name, prob_column, pattern, ... }`:
7//!
8//! 1. Build a Cypher `MATCH pattern [WHERE ...] RETURN <KEY vars>, target`
9//!    query — this is the ground-truth source.
10//! 2. Execute, pull `(key_tuple, label)` rows.
11//! 3. Look up the rule's derived facts in `DerivedStore`, indexed by
12//!    KEY column tuple → PROB column value.
13//! 4. Join the two on the key tuple to produce `(prediction, label)`
14//!    pairs (rows in either side without a match are dropped — this is
15//!    intentional, matches sklearn semantics).
16//! 5. Compute each requested metric via the `uni_locy::calibration`
17//!    library functions.
18//!
19//! Unlike `CALIBRATE`, this never invokes a classifier or fits
20//! anything — the rule has already been evaluated by the fixpoint
21//! loop and the metric pass just *measures*.
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use uni_common::Value;
27use uni_cypher::ast::{Clause, Expr, MatchClause, ReturnClause, ReturnItem, Statement};
28use uni_cypher::locy_ast::ValidationMetric;
29use uni_locy::{
30    CompiledValidate, FactRow, ValidationResult, accuracy, auc, brier_score, debiased_ece,
31    expected_calibration_error, log_loss,
32};
33
/// Histogram bin count used by the ECE and debiased-ECE metrics in the
/// VALIDATE pass.
const ECE_BINS: usize = 10;

/// Failures that abort a `VALIDATE` pass.
#[derive(Debug)]
pub enum ValidateRuntimeError {
    /// The rule under validation produced no derived facts.
    RuleNotDerived { rule_name: String },
    /// Joining the rule's facts with the TARGET rows yielded no
    /// `(prediction, label)` pairs.
    EmptyDataset { rule_name: String },
    /// A KEY column was absent from a fact row or a TARGET row; `key`
    /// carries the comma-joined list of expected KEY columns.
    JoinKeysMissing { rule_name: String, key: String },
}

impl std::fmt::Display for ValidateRuntimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each message is emitted as a formatted prefix followed by a
        // fixed suffix; the concatenation is the full user-facing text.
        match self {
            ValidateRuntimeError::RuleNotDerived { rule_name } => {
                write!(f, "VALIDATE: rule '{rule_name}' has no derived facts; ")?;
                f.write_str("ensure it appears in a stratum before VALIDATE")
            }
            ValidateRuntimeError::EmptyDataset { rule_name } => {
                write!(f, "VALIDATE: rule '{rule_name}' produced no ")?;
                f.write_str("(prediction, label) pairs (empty join)")
            }
            ValidateRuntimeError::JoinKeysMissing { rule_name, key } => {
                write!(f, "VALIDATE: rule '{rule_name}' KEY column '{key}' missing ")?;
                f.write_str("from either the rule's derived facts or the TARGET query rows")
            }
        }
    }
}

impl std::error::Error for ValidateRuntimeError {}
67
68/// Build the ground-truth collection query: the MATCH pattern + WHERE
69/// + RETURN of the TARGET expression plus all KEY-shaped projections
70///   needed to join with the rule's derived facts.
71///
72/// We use the YIELD-key variable names as both the pattern-bound
73/// variables and the RETURN aliases — this mirrors the rule's own
74/// KEY shape so the join can match by name.
75pub fn validate_collection_query(
76    cmd: &CompiledValidate,
77    key_columns: &[String],
78) -> uni_cypher::ast::Query {
79    let mut items: Vec<ReturnItem> = Vec::with_capacity(key_columns.len() + 1);
80    for col in key_columns {
81        items.push(ReturnItem::Expr {
82            expr: Expr::Variable(col.clone()),
83            alias: Some(col.clone()),
84            source_text: None,
85        });
86    }
87    items.push(ReturnItem::Expr {
88        expr: cmd.target_expr.clone(),
89        alias: Some("__validate_target".to_string()),
90        source_text: None,
91    });
92    let stmt = Statement {
93        clauses: vec![
94            Clause::Match(MatchClause {
95                optional: false,
96                pattern: cmd.pattern.clone(),
97                where_clause: cmd.where_expr.clone(),
98                for_update: false,
99            }),
100            Clause::Return(ReturnClause {
101                distinct: false,
102                items,
103                order_by: None,
104                skip: None,
105                limit: None,
106            }),
107        ],
108    };
109    uni_cypher::ast::Query::Single(stmt)
110}
111
112/// Convert a target Value to a bool label. Same rules as CALIBRATE.
113fn target_to_label(v: Option<&Value>) -> bool {
114    match v {
115        Some(Value::Bool(b)) => *b,
116        Some(Value::Int(i)) => *i != 0,
117        Some(Value::Float(f)) => *f != 0.0,
118        Some(Value::String(s)) => !s.is_empty(),
119        Some(Value::Null) | None => false,
120        Some(_) => false,
121    }
122}
123
124/// Canonicalize a Value for join-key comparison.
125///
126/// The rule's KEY column (typically a `Node` bound by the rule's
127/// MATCH) can be stored either as a `Value::Node` (full node carrier)
128/// or as `Value::Int(vid)` (just the integer vid) depending on which
129/// runtime path produced it — DerivedStore-to-FactRow conversion
130/// keeps vids as `Int`, while a fresh Cypher MATCH returns `Node`.
131/// To make the join work in both directions we extract the vid where
132/// applicable and stringify other primitive types directly.
133fn canonical_key(v: &Value) -> String {
134    // Critical: a rule's KEY column carrying a node-bound variable
135    // may show up as either `Value::Node` (after a fresh Cypher MATCH)
136    // or `Value::Int(vid)` (after the DerivedStore record-batch
137    // round-trip strips the rich Node value down to its vid). Both
138    // refer to the same node — canonicalize to the bare integer.
139    match v {
140        Value::Node(n) => format!("v:{}", n.vid),
141        Value::Edge(e) => format!("e:{}", e.eid),
142        // Treat any non-node integer as a potential vid for the same
143        // join. False positives here are tolerable because the rule's
144        // KEY column ALWAYS produces semantically-equivalent values
145        // under both encoding paths.
146        Value::Int(i) => format!("v:{i}"),
147        Value::Float(f) => format!("f:{f}"),
148        Value::Bool(b) => format!("b:{b}"),
149        Value::String(s) => format!("s:{s}"),
150        Value::Null => "null".into(),
151        other => format!("{other:?}"),
152    }
153}
154
155/// Build a stable join-key string from a row's KEY column values.
156/// `canonical_key` normalizes Node vs. Int(vid) on each side.
157fn join_key(row: &FactRow, key_columns: &[String]) -> Option<String> {
158    let mut parts = Vec::with_capacity(key_columns.len());
159    for col in key_columns {
160        let v = row.get(col)?;
161        parts.push(canonical_key(v));
162    }
163    Some(parts.join("|"))
164}
165
166/// Execute the validation pass. `rule_facts` is the rule's derived
167/// fact set (read from `LocyResult.derived[rule_name]`); `target_rows`
168/// is the Cypher MATCH+TARGET query result.
169pub fn run_validate(
170    cmd: &CompiledValidate,
171    rule_key_columns: &[String],
172    rule_facts: &[FactRow],
173    target_rows: Vec<FactRow>,
174) -> Result<ValidationResult, ValidateRuntimeError> {
175    if rule_facts.is_empty() {
176        return Err(ValidateRuntimeError::RuleNotDerived {
177            rule_name: cmd.rule_name.clone(),
178        });
179    }
180    // Index rule facts by KEY tuple → PROB value.
181    let mut by_key: HashMap<String, f64> = HashMap::with_capacity(rule_facts.len());
182    for row in rule_facts {
183        let key = join_key(row, rule_key_columns).ok_or_else(|| {
184            ValidateRuntimeError::JoinKeysMissing {
185                rule_name: cmd.rule_name.clone(),
186                key: rule_key_columns.join(","),
187            }
188        })?;
189        let prob = match row.get(&cmd.prob_column) {
190            Some(Value::Float(f)) => *f,
191            Some(Value::Int(i)) => *i as f64,
192            _ => continue,
193        };
194        by_key.insert(key, prob.clamp(0.0, 1.0));
195    }
196
197    // Join target rows onto the by_key index.
198    let mut preds: Vec<f64> = Vec::new();
199    let mut labels: Vec<bool> = Vec::new();
200    for row in &target_rows {
201        let key = join_key(row, rule_key_columns).ok_or_else(|| {
202            ValidateRuntimeError::JoinKeysMissing {
203                rule_name: cmd.rule_name.clone(),
204                key: rule_key_columns.join(","),
205            }
206        })?;
207        if let Some(&pred) = by_key.get(&key) {
208            preds.push(pred);
209            labels.push(target_to_label(row.get("__validate_target")));
210        }
211    }
212    if preds.is_empty() {
213        // Diagnostic: surface what we saw on each side of the join so
214        // mismatches are debuggable from the error message.
215        let rule_sample = rule_facts
216            .first()
217            .map(|r| r.keys().cloned().collect::<Vec<_>>().join(","));
218        let target_sample = target_rows
219            .first()
220            .map(|r| r.keys().cloned().collect::<Vec<_>>().join(","));
221        tracing::warn!(
222            "VALIDATE empty join for rule '{}'. rule_facts={}, target_rows={}, \
223             rule_cols={:?}, target_cols={:?}, key_columns={:?}, \
224             rule_key_sample={:?}, target_key_sample={:?}",
225            cmd.rule_name,
226            rule_facts.len(),
227            target_rows.len(),
228            rule_sample,
229            target_sample,
230            rule_key_columns,
231            rule_facts
232                .first()
233                .and_then(|r| r.get(&rule_key_columns[0]).cloned()),
234            target_rows
235                .first()
236                .and_then(|r| r.get(&rule_key_columns[0]).cloned()),
237        );
238        return Err(ValidateRuntimeError::EmptyDataset {
239            rule_name: cmd.rule_name.clone(),
240        });
241    }
242    let mut metrics_out: Vec<(ValidationMetric, f64)> = Vec::with_capacity(cmd.metrics.len());
243    for m in &cmd.metrics {
244        let v = match m {
245            ValidationMetric::BrierScore => brier_score(&preds, &labels),
246            ValidationMetric::LogLoss => log_loss(&preds, &labels),
247            ValidationMetric::Ece => expected_calibration_error(&preds, &labels, ECE_BINS),
248            ValidationMetric::DebiasedEce => debiased_ece(&preds, &labels, ECE_BINS),
249            ValidationMetric::Accuracy => accuracy(&preds, &labels),
250            ValidationMetric::Auc => auc(&preds, &labels),
251        };
252        metrics_out.push((*m, v));
253    }
254    Ok(ValidationResult {
255        rule_name: cmd.rule_name.clone(),
256        prob_column: cmd.prob_column.clone(),
257        n_samples: preds.len(),
258        metrics: metrics_out,
259    })
260}
261
262/// Re-export an `Arc<dyn ...>` wrapper for symmetry with locy_calibrate.
263/// (The runtime doesn't actually need this — included so dispatch
264/// callers can `?` on a unified `Result<ValidationResult, _>` shape.)
265pub fn into_arc_error(e: ValidateRuntimeError) -> Arc<dyn std::error::Error + Send + Sync> {
266    Arc::new(e)
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::Pattern;

    /// KEY column shared by the rule facts and the TARGET rows.
    const KEY: &str = "s";

    /// Build a `FactRow` from `(column, value)` pairs.
    fn fact_row(pairs: &[(&str, Value)]) -> FactRow {
        pairs
            .iter()
            .map(|(column, value)| ((*column).to_owned(), value.clone()))
            .collect()
    }

    /// A derived fact with `KEY = key` and PROB column `risk = risk`.
    fn rule_fact(key: Value, risk: f64) -> FactRow {
        fact_row(&[(KEY, key), ("risk", Value::Float(risk))])
    }

    /// A ground-truth row with `KEY = key` and the label under the
    /// `__validate_target` alias.
    fn target_row(key: Value, label: bool) -> FactRow {
        fact_row(&[(KEY, key), ("__validate_target", Value::Bool(label))])
    }

    /// The single-column KEY used by every test.
    fn keys() -> Vec<String> {
        vec![KEY.to_owned()]
    }

    fn dummy_cmd() -> CompiledValidate {
        CompiledValidate {
            rule_name: "risky".into(),
            pattern: Pattern { paths: vec![] },
            where_expr: None,
            target_expr: Expr::Variable("label".into()),
            metrics: vec![
                ValidationMetric::BrierScore,
                ValidationMetric::Accuracy,
                ValidationMetric::Auc,
            ],
            prob_column: "risk".into(),
        }
    }

    #[test]
    fn validate_joins_facts_with_target_rows() {
        let cmd = dummy_cmd();
        let rule_facts = vec![
            rule_fact(Value::Int(1), 0.9),
            rule_fact(Value::Int(2), 0.1),
            rule_fact(Value::Int(3), 0.8),
            rule_fact(Value::Int(4), 0.2),
        ];
        let target_rows = vec![
            target_row(Value::Int(1), true),
            target_row(Value::Int(2), false),
            target_row(Value::Int(3), true),
            target_row(Value::Int(4), false),
        ];
        let result = run_validate(&cmd, &keys(), &rule_facts, target_rows).unwrap();
        assert_eq!(result.n_samples, 4);

        // High probabilities sit on positives and low ones on negatives,
        // so the Brier score must be small.
        let brier = result.metric(ValidationMetric::BrierScore).unwrap();
        assert!(brier < 0.05, "expected small Brier, got {brier}");
        // Every prediction lands on the correct side of 0.5.
        assert_eq!(result.metric(ValidationMetric::Accuracy).unwrap(), 1.0);
        // Perfect separation gives AUC = 1.
        let auc_value = result.metric(ValidationMetric::Auc).unwrap();
        assert!((auc_value - 1.0).abs() < 1e-12);
    }

    #[test]
    fn validate_drops_unjoinable_rows() {
        let err = run_validate(
            &dummy_cmd(),
            &keys(),
            &[rule_fact(Value::Int(1), 0.9)],
            vec![target_row(Value::Int(99), true)],
        )
        .unwrap_err();
        assert!(matches!(err, ValidateRuntimeError::EmptyDataset { .. }));
    }

    #[test]
    fn validate_errors_on_no_rule_facts() {
        let outcome = run_validate(&dummy_cmd(), &keys(), &[], Vec::new());
        assert!(matches!(
            outcome,
            Err(ValidateRuntimeError::RuleNotDerived { .. })
        ));
    }
}