1use std::collections::HashMap;
24use std::sync::Arc;
25
26use uni_common::Value;
27use uni_cypher::ast::{Clause, Expr, MatchClause, ReturnClause, ReturnItem, Statement};
28use uni_cypher::locy_ast::ValidationMetric;
29use uni_locy::{
30 CompiledValidate, FactRow, ValidationResult, accuracy, auc, brier_score, debiased_ece,
31 expected_calibration_error, log_loss,
32};
33
/// Bin count passed to `expected_calibration_error` and `debiased_ece` when the
/// corresponding VALIDATE metrics are requested.
const ECE_BINS: usize = 10;
36
/// Errors raised while evaluating a `VALIDATE` command at runtime.
///
/// Derives `Clone`/`PartialEq`/`Eq` (all fields are plain `String`s) so callers
/// can store, compare, and assert on these errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidateRuntimeError {
    /// The rule being validated produced no facts at all.
    RuleNotDerived { rule_name: String },
    /// Joining the rule's facts with the TARGET rows yielded no
    /// (prediction, label) pairs.
    EmptyDataset { rule_name: String },
    /// A KEY column was absent from a rule fact or a TARGET row.
    /// `key` carries the comma-joined list of KEY column names.
    JoinKeysMissing { rule_name: String, key: String },
}

impl std::fmt::Display for ValidateRuntimeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::RuleNotDerived { rule_name } => write!(
                f,
                "VALIDATE: rule '{rule_name}' has no derived facts; \
                 ensure it appears in a stratum before VALIDATE"
            ),
            Self::EmptyDataset { rule_name } => write!(
                f,
                "VALIDATE: rule '{rule_name}' produced no \
                 (prediction, label) pairs (empty join)"
            ),
            Self::JoinKeysMissing { rule_name, key } => write!(
                f,
                "VALIDATE: rule '{rule_name}' KEY column '{key}' missing \
                 from either the rule's derived facts or the TARGET query rows"
            ),
        }
    }
}

// No underlying cause is carried, so the default `source()` (None) is correct.
impl std::error::Error for ValidateRuntimeError {}
67
68pub fn validate_collection_query(
76 cmd: &CompiledValidate,
77 key_columns: &[String],
78) -> uni_cypher::ast::Query {
79 let mut items: Vec<ReturnItem> = Vec::with_capacity(key_columns.len() + 1);
80 for col in key_columns {
81 items.push(ReturnItem::Expr {
82 expr: Expr::Variable(col.clone()),
83 alias: Some(col.clone()),
84 source_text: None,
85 });
86 }
87 items.push(ReturnItem::Expr {
88 expr: cmd.target_expr.clone(),
89 alias: Some("__validate_target".to_string()),
90 source_text: None,
91 });
92 let stmt = Statement {
93 clauses: vec![
94 Clause::Match(MatchClause {
95 optional: false,
96 pattern: cmd.pattern.clone(),
97 where_clause: cmd.where_expr.clone(),
98 for_update: false,
99 }),
100 Clause::Return(ReturnClause {
101 distinct: false,
102 items,
103 order_by: None,
104 skip: None,
105 limit: None,
106 }),
107 ],
108 };
109 uni_cypher::ast::Query::Single(stmt)
110}
111
112fn target_to_label(v: Option<&Value>) -> bool {
114 match v {
115 Some(Value::Bool(b)) => *b,
116 Some(Value::Int(i)) => *i != 0,
117 Some(Value::Float(f)) => *f != 0.0,
118 Some(Value::String(s)) => !s.is_empty(),
119 Some(Value::Null) | None => false,
120 Some(_) => false,
121 }
122}
123
124fn canonical_key(v: &Value) -> String {
134 match v {
140 Value::Node(n) => format!("v:{}", n.vid),
141 Value::Edge(e) => format!("e:{}", e.eid),
142 Value::Int(i) => format!("v:{i}"),
147 Value::Float(f) => format!("f:{f}"),
148 Value::Bool(b) => format!("b:{b}"),
149 Value::String(s) => format!("s:{s}"),
150 Value::Null => "null".into(),
151 other => format!("{other:?}"),
152 }
153}
154
155fn join_key(row: &FactRow, key_columns: &[String]) -> Option<String> {
158 let mut parts = Vec::with_capacity(key_columns.len());
159 for col in key_columns {
160 let v = row.get(col)?;
161 parts.push(canonical_key(v));
162 }
163 Some(parts.join("|"))
164}
165
166pub fn run_validate(
170 cmd: &CompiledValidate,
171 rule_key_columns: &[String],
172 rule_facts: &[FactRow],
173 target_rows: Vec<FactRow>,
174) -> Result<ValidationResult, ValidateRuntimeError> {
175 if rule_facts.is_empty() {
176 return Err(ValidateRuntimeError::RuleNotDerived {
177 rule_name: cmd.rule_name.clone(),
178 });
179 }
180 let mut by_key: HashMap<String, f64> = HashMap::with_capacity(rule_facts.len());
182 for row in rule_facts {
183 let key = join_key(row, rule_key_columns).ok_or_else(|| {
184 ValidateRuntimeError::JoinKeysMissing {
185 rule_name: cmd.rule_name.clone(),
186 key: rule_key_columns.join(","),
187 }
188 })?;
189 let prob = match row.get(&cmd.prob_column) {
190 Some(Value::Float(f)) => *f,
191 Some(Value::Int(i)) => *i as f64,
192 _ => continue,
193 };
194 by_key.insert(key, prob.clamp(0.0, 1.0));
195 }
196
197 let mut preds: Vec<f64> = Vec::new();
199 let mut labels: Vec<bool> = Vec::new();
200 for row in &target_rows {
201 let key = join_key(row, rule_key_columns).ok_or_else(|| {
202 ValidateRuntimeError::JoinKeysMissing {
203 rule_name: cmd.rule_name.clone(),
204 key: rule_key_columns.join(","),
205 }
206 })?;
207 if let Some(&pred) = by_key.get(&key) {
208 preds.push(pred);
209 labels.push(target_to_label(row.get("__validate_target")));
210 }
211 }
212 if preds.is_empty() {
213 let rule_sample = rule_facts
216 .first()
217 .map(|r| r.keys().cloned().collect::<Vec<_>>().join(","));
218 let target_sample = target_rows
219 .first()
220 .map(|r| r.keys().cloned().collect::<Vec<_>>().join(","));
221 tracing::warn!(
222 "VALIDATE empty join for rule '{}'. rule_facts={}, target_rows={}, \
223 rule_cols={:?}, target_cols={:?}, key_columns={:?}, \
224 rule_key_sample={:?}, target_key_sample={:?}",
225 cmd.rule_name,
226 rule_facts.len(),
227 target_rows.len(),
228 rule_sample,
229 target_sample,
230 rule_key_columns,
231 rule_facts
232 .first()
233 .and_then(|r| r.get(&rule_key_columns[0]).cloned()),
234 target_rows
235 .first()
236 .and_then(|r| r.get(&rule_key_columns[0]).cloned()),
237 );
238 return Err(ValidateRuntimeError::EmptyDataset {
239 rule_name: cmd.rule_name.clone(),
240 });
241 }
242 let mut metrics_out: Vec<(ValidationMetric, f64)> = Vec::with_capacity(cmd.metrics.len());
243 for m in &cmd.metrics {
244 let v = match m {
245 ValidationMetric::BrierScore => brier_score(&preds, &labels),
246 ValidationMetric::LogLoss => log_loss(&preds, &labels),
247 ValidationMetric::Ece => expected_calibration_error(&preds, &labels, ECE_BINS),
248 ValidationMetric::DebiasedEce => debiased_ece(&preds, &labels, ECE_BINS),
249 ValidationMetric::Accuracy => accuracy(&preds, &labels),
250 ValidationMetric::Auc => auc(&preds, &labels),
251 };
252 metrics_out.push((*m, v));
253 }
254 Ok(ValidationResult {
255 rule_name: cmd.rule_name.clone(),
256 prob_column: cmd.prob_column.clone(),
257 n_samples: preds.len(),
258 metrics: metrics_out,
259 })
260}
261
262pub fn into_arc_error(e: ValidateRuntimeError) -> Arc<dyn std::error::Error + Send + Sync> {
266 Arc::new(e)
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::Pattern;

    /// Builds a `FactRow` from `(column, value)` pairs.
    fn fact_row(pairs: &[(&str, Value)]) -> FactRow {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// A VALIDATE command over rule `risky`, reading probabilities from `risk`
    /// and requesting Brier, accuracy and AUC.
    fn dummy_cmd() -> CompiledValidate {
        CompiledValidate {
            rule_name: "risky".into(),
            pattern: Pattern { paths: vec![] },
            where_expr: None,
            target_expr: Expr::Variable("label".into()),
            metrics: vec![
                ValidationMetric::BrierScore,
                ValidationMetric::Accuracy,
                ValidationMetric::Auc,
            ],
            prob_column: "risk".into(),
        }
    }

    #[test]
    fn validate_joins_facts_with_target_rows() {
        let cmd = dummy_cmd();
        let rule_facts = vec![
            fact_row(&[("s", Value::Int(1)), ("risk", Value::Float(0.9))]),
            fact_row(&[("s", Value::Int(2)), ("risk", Value::Float(0.1))]),
            fact_row(&[("s", Value::Int(3)), ("risk", Value::Float(0.8))]),
            fact_row(&[("s", Value::Int(4)), ("risk", Value::Float(0.2))]),
        ];
        let target_rows = vec![
            fact_row(&[
                ("s", Value::Int(1)),
                ("__validate_target", Value::Bool(true)),
            ]),
            fact_row(&[
                ("s", Value::Int(2)),
                ("__validate_target", Value::Bool(false)),
            ]),
            fact_row(&[
                ("s", Value::Int(3)),
                ("__validate_target", Value::Bool(true)),
            ]),
            fact_row(&[
                ("s", Value::Int(4)),
                ("__validate_target", Value::Bool(false)),
            ]),
        ];
        let res = run_validate(&cmd, &["s".to_string()], &rule_facts, target_rows).unwrap();
        assert_eq!(res.n_samples, 4);
        let brier = res.metric(ValidationMetric::BrierScore).unwrap();
        assert!(brier < 0.05, "expected small Brier, got {brier}");
        let acc = res.metric(ValidationMetric::Accuracy).unwrap();
        assert_eq!(acc, 1.0);
        let a = res.metric(ValidationMetric::Auc).unwrap();
        assert!((a - 1.0).abs() < 1e-12);
    }

    #[test]
    fn validate_drops_unjoinable_rows() {
        let cmd = dummy_cmd();
        let rule_facts = vec![fact_row(&[
            ("s", Value::Int(1)),
            ("risk", Value::Float(0.9)),
        ])];
        let target_rows = vec![fact_row(&[
            ("s", Value::Int(99)),
            ("__validate_target", Value::Bool(true)),
        ])];
        let err = run_validate(&cmd, &["s".to_string()], &rule_facts, target_rows).unwrap_err();
        assert!(matches!(err, ValidateRuntimeError::EmptyDataset { .. }));
    }

    #[test]
    fn validate_errors_on_no_rule_facts() {
        let cmd = dummy_cmd();
        let err = run_validate(&cmd, &["s".to_string()], &[], vec![]).unwrap_err();
        assert!(matches!(err, ValidateRuntimeError::RuleNotDerived { .. }));
    }

    #[test]
    fn validate_errors_when_key_column_missing() {
        // The rule fact has a probability but no `s` KEY column.
        let cmd = dummy_cmd();
        let rule_facts = vec![fact_row(&[("risk", Value::Float(0.5))])];
        let err = run_validate(&cmd, &["s".to_string()], &rule_facts, vec![]).unwrap_err();
        assert!(matches!(err, ValidateRuntimeError::JoinKeysMissing { .. }));
    }

    #[test]
    fn validate_skips_rows_with_non_numeric_probability() {
        // Only accuracy is requested so single-class inputs are not an issue for other metrics.
        let mut cmd = dummy_cmd();
        cmd.metrics = vec![ValidationMetric::Accuracy];
        let rule_facts = vec![
            fact_row(&[("s", Value::Int(1)), ("risk", Value::String("high".into()))]),
            fact_row(&[("s", Value::Int(2)), ("risk", Value::Int(1))]),
        ];
        let target_rows = vec![
            fact_row(&[
                ("s", Value::Int(1)),
                ("__validate_target", Value::Bool(true)),
            ]),
            fact_row(&[
                ("s", Value::Int(2)),
                ("__validate_target", Value::Bool(true)),
            ]),
        ];
        let res = run_validate(&cmd, &["s".to_string()], &rule_facts, target_rows).unwrap();
        // The fact with a string probability is skipped, so only key 2 joins.
        assert_eq!(res.n_samples, 1);
        assert_eq!(res.prob_column, "risk");
        assert!(res.metric(ValidationMetric::Accuracy).is_some());
    }
}