Skip to main content

uni_locy/
result.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use uni_common::{Properties, Value};
5
6use crate::types::{RuntimeWarning, RuntimeWarningCode};
7
8/// A single row of bindings from a Locy evaluation result.
9pub type FactRow = HashMap<String, Value>;
10
11/// The result of evaluating a compiled Locy program.
12#[derive(Debug, Clone)]
13pub struct LocyResult {
14    /// Derived facts per rule name.
15    pub derived: HashMap<String, Vec<FactRow>>,
16    /// Execution statistics.
17    pub stats: LocyStats,
18    /// Results from Phase 4 commands.
19    pub command_results: Vec<CommandResult>,
20    /// Runtime warnings collected during evaluation.
21    pub warnings: Vec<RuntimeWarning>,
22    /// Compile-time warnings carried over from `CompiledProgram.warnings`.
23    /// Phase C C4: surfaces `UncalibratedNeuralPredicate` /
24    /// `FoldInRecursivePath` / `UncalibratedLLMLogprobs` /
25    /// `MsumNonNegativity` / `ProbabilityDomainViolation` so test
26    /// harnesses and downstream tooling can inspect them on the
27    /// returned `LocyResult` rather than re-running the compiler.
28    pub compile_warnings: Vec<crate::types::CompilerWarning>,
29    /// Groups where BDD computation fell back to independence mode.
30    /// Maps rule name → list of human-readable key group descriptions.
31    pub approximate_groups: HashMap<String, Vec<String>>,
32    /// When present, contains the derived facts from a session-level DERIVE
33    /// that have not yet been applied. Use `tx.apply(derived)` to materialize.
34    pub derived_fact_set: Option<DerivedFactSet>,
35    /// Diagnostics for an evaluation that stopped before completing, present
36    /// only on the `allow_partial` path. Names which rules were left
37    /// incomplete or skipped (so a zero-row count can be distinguished from a
38    /// genuinely empty rule) and which complement rules are consequently
39    /// unsound. `None` for a normal, complete evaluation.
40    pub incomplete: Option<uni_common::LocyIncomplete>,
41}
42
43/// Result of executing a single Phase 4 command.
44#[derive(Debug, Clone)]
45pub enum CommandResult {
46    Query(Vec<FactRow>),
47    Assume(Vec<FactRow>),
48    Explain(DerivationNode),
49    Abduce(AbductionResult),
50    Derive {
51        affected: usize,
52    },
53    Cypher(Vec<FactRow>),
54    /// Phase C C2: result of a `CALIBRATE` statement — the fitted
55    /// calibrator plus pre- and post-calibration holdout metrics.
56    Calibrate(CalibrationResult),
57    /// Phase C C3: result of a `VALIDATE` statement — the metric
58    /// values computed over `(rule_output, ground_truth)` pairs.
59    Validate(ValidationResult),
60}
61
62/// Outcome of a single `VALIDATE` invocation. Phase C C3.
63///
64/// `metrics` maps each requested metric to its scalar value. The
65/// `n_samples` field reports how many `(prediction, label)` pairs
66/// were retained after joining the rule's PROB column with the
67/// TARGET expression. Bare `ECE` produces a `EceBinningBias`
68/// compile-time warning (surfaced via `LocyResult.compile_warnings`).
69#[derive(Debug, Clone)]
70pub struct ValidationResult {
71    pub rule_name: String,
72    pub prob_column: String,
73    pub n_samples: usize,
74    pub metrics: Vec<(uni_cypher::locy_ast::ValidationMetric, f64)>,
75}
76
77impl ValidationResult {
78    pub fn metric(&self, m: uni_cypher::locy_ast::ValidationMetric) -> Option<f64> {
79        self.metrics
80            .iter()
81            .find(|(name, _)| *name == m)
82            .map(|(_, v)| *v)
83    }
84}
85
/// Phase C C1a: confidence interval attached to a single prediction by an
/// uncertainty-aware calibrator.
///
/// With split-conformal calibration the band is `[p - q, p + q]`, clipped to
/// `[0, 1]`. Here `q` is the `(1 - alpha)`-quantile of the holdout
/// nonconformity scores.
#[derive(Clone, Copy, Debug)]
pub struct ConfidenceBand {
    /// Lower end of the interval.
    pub lower: f64,
    /// Upper end of the interval.
    pub upper: f64,
    /// Which uncertainty-quantification method produced the interval.
    pub source: ConfidenceSource,
}

/// Phase C C1a: provenance of a [`ConfidenceBand`], naming the
/// uncertainty-quantification machinery behind its bounds.
///
/// `Conformal` shipped in C1a. The ensemble and credal variants were added
/// in D-C1e as extension points for future calibrators.
#[derive(Clone, Copy, Debug)]
pub enum ConfidenceSource {
    /// Split-conformal predictor. `alpha` is the miscoverage rate, so
    /// `0.1` means 90% coverage. The band is the point estimate plus or
    /// minus the `(1 - alpha)`-quantile of holdout nonconformity scores.
    /// Shipped in C1a.
    Conformal { alpha: f64 },
    /// Phase D D-C1e: bootstrap or N-of-K ensemble calibrator.
    ///
    /// The band comes from cross-estimator spread: `[p - σ, p + σ]`,
    /// clipped to `[0, 1]`. Here `σ` is the standard deviation of the
    /// per-estimator holdout predictions. `n_estimators` is the number of
    /// base learners that voted, so consumers can judge the band's noise
    /// floor.
    EnsembleVariance { n_estimators: usize },
    /// Phase D D-C1e: credal (imprecise-probability) calibrator.
    ///
    /// The band is an explicit `[lower, upper]` interval derived from a
    /// credal prior, not a point estimate with a symmetric halo. The two
    /// `_prior` fields expose the calibrator's lower and upper prior
    /// hyperparameters, so consumers can relate the band to its
    /// belief-revision shape.
    Credal { lower_prior: f64, upper_prior: f64 },
}
123
#[cfg(test)]
mod confidence_source_tests {
    use super::ConfidenceSource;

    /// Formats `source` with its derived `Debug` impl and asserts that every
    /// entry of `needles` appears in the output.
    fn assert_debug_mentions(source: ConfidenceSource, needles: &[&str]) {
        let rendered = format!("{:?}", source);
        for needle in needles {
            assert!(
                rendered.contains(*needle),
                "expected {:?} to mention {:?}",
                rendered,
                needle
            );
        }
    }

    #[test]
    fn conformal_debug_format() {
        assert_debug_mentions(ConfidenceSource::Conformal { alpha: 0.1 }, &["Conformal", "0.1"]);
    }

    #[test]
    fn ensemble_variance_debug_format() {
        assert_debug_mentions(
            ConfidenceSource::EnsembleVariance { n_estimators: 50 },
            &["EnsembleVariance", "50"],
        );
    }

    #[test]
    fn credal_debug_format() {
        let credal = ConfidenceSource::Credal {
            lower_prior: 0.1,
            upper_prior: 0.9,
        };
        assert_debug_mentions(credal, &["Credal", "0.1", "0.9"]);
    }
}
156
157/// Outcome of a single `CALIBRATE` invocation. Phase C C2.
158///
159/// `calibrator` is the fitted transform; user code typically wraps it
160/// over the base classifier via `CalibratedClassifier` and re-registers
161/// the wrapped classifier in `LocyConfig::classifier_registry` for
162/// subsequent evaluations.
163#[derive(Debug, Clone)]
164pub struct CalibrationResult {
165    pub model_name: String,
166    pub method: crate::calibration::CalibrationMethodKind,
167    pub n_samples: usize,
168    pub holdout_size: usize,
169    pub calibrator: std::sync::Arc<dyn crate::calibration::Calibrator>,
170    pub raw_brier: f64,
171    pub raw_ece: f64,
172    pub calibrated_brier: f64,
173    pub calibrated_ece: f64,
174    /// Phase C C1a: for conformal calibrators, the
175    /// `(1 - alpha)`-quantile of holdout nonconformity scores —
176    /// the half-width of every confidence band the calibrator will
177    /// emit at inference. `None` for non-conformal methods.
178    pub confidence_band_quantile: Option<f64>,
179}
180
181/// Phase C B1–B3: per neural-model invocation provenance, attached
182/// to a [`DerivationNode`] when the derivation's body invoked one
183/// or more classifiers. `raw_probability` is the classifier's
184/// direct output; `calibrated_probability` is the post-Calibrator
185/// value (when any calibrator other than `Identity` is registered).
186/// `confidence_band` is populated when the active calibrator is
187/// conformal (or any future band-emitting calibrator).
188#[derive(Debug, Clone)]
189pub struct NeuralProvenance {
190    pub model_name: String,
191    pub raw_probability: f64,
192    pub calibrated_probability: Option<f64>,
193    pub confidence_band: Option<ConfidenceBand>,
194}
195
196/// A node in a derivation tree, produced by EXPLAIN RULE.
197#[derive(Debug, Clone)]
198pub struct DerivationNode {
199    pub rule: String,
200    pub clause_index: usize,
201    pub priority: Option<i64>,
202    pub bindings: HashMap<String, Value>,
203    pub along_values: HashMap<String, Value>,
204    pub children: Vec<DerivationNode>,
205    pub graph_fact: Option<String>,
206    /// True when this node's probability was computed via BDD fallback
207    /// (independence mode) because the group exceeded `max_bdd_variables`.
208    pub approximate: bool,
209    /// Probability of this specific proof path, populated when top-k proof
210    /// filtering is active (Scallop, Huang et al. 2021).
211    pub proof_probability: Option<f64>,
212    /// Phase C B1–B3: neural-model invocations that contributed to
213    /// this fact's derivation. Empty for purely-symbolic
214    /// derivations.
215    pub neural_calls: Vec<NeuralProvenance>,
216}
217
218/// Result of an ABDUCE query.
219#[derive(Debug, Clone, serde::Serialize)]
220pub struct AbductionResult {
221    pub modifications: Vec<ValidatedModification>,
222}
223
224/// A modification with validation status and cost.
225#[derive(Debug, Clone, serde::Serialize)]
226pub struct ValidatedModification {
227    pub modification: Modification,
228    /// Whether this modification satisfies the ABDUCE goal when applied via savepoint.
229    pub validated: bool,
230    /// Cost metric for ranking modifications: RemoveEdge=1.0, ChangeProperty=0.5, AddEdge=1.5.
231    pub cost: f64,
232}
233
234/// A proposed graph modification from ABDUCE.
235#[derive(Debug, Clone, serde::Serialize)]
236pub enum Modification {
237    RemoveEdge {
238        source_var: String,
239        target_var: String,
240        edge_var: String,
241        edge_type: String,
242        /// Property constraints used to identify the specific edge to remove.
243        match_properties: HashMap<String, Value>,
244    },
245    ChangeProperty {
246        element_var: String,
247        property: String,
248        old_value: Box<Value>,
249        new_value: Box<Value>,
250    },
251    AddEdge {
252        source_var: String,
253        target_var: String,
254        edge_type: String,
255        properties: HashMap<String, Value>,
256    },
257}
258
259/// A derived edge to be materialized.
260#[derive(Debug, Clone)]
261pub struct DerivedEdge {
262    pub edge_type: String,
263    pub source_label: String,
264    pub source_properties: Properties,
265    pub target_label: String,
266    pub target_properties: Properties,
267    pub edge_properties: Properties,
268}
269
270/// Pure-data representation of facts derived by a session-level DERIVE.
271///
272/// Apply to a transaction via `tx.apply(derived)` or `tx.apply_with(derived)`.
273#[derive(Debug, Clone)]
274pub struct DerivedFactSet {
275    /// New vertices grouped by label.
276    pub vertices: HashMap<String, Vec<Properties>>,
277    /// Derived edges connecting source/target vertices.
278    pub edges: Vec<DerivedEdge>,
279    /// Evaluation statistics from the DERIVE run.
280    pub stats: LocyStats,
281    /// Database version at evaluation time (for staleness detection).
282    pub evaluated_at_version: u64,
283    /// Internal: Cypher ASTs for faithful replay during `tx.apply()`.
284    #[doc(hidden)]
285    pub mutation_queries: Vec<uni_cypher::ast::Query>,
286}
287
288impl DerivedFactSet {
289    /// Total number of derived facts (vertices + edges).
290    pub fn fact_count(&self) -> usize {
291        self.vertices.values().map(|v| v.len()).sum::<usize>() + self.edges.len()
292    }
293
294    /// True when no facts were derived.
295    pub fn is_empty(&self) -> bool {
296        self.vertices.is_empty() && self.edges.is_empty()
297    }
298}
299
/// Counters and timings gathered while a Locy program is evaluated. Every
/// field starts at zero by default.
#[derive(Clone, Debug, Default)]
pub struct LocyStats {
    pub strata_evaluated: usize,
    pub total_iterations: usize,
    pub derived_nodes: usize,
    pub derived_edges: usize,
    pub evaluation_time: Duration,
    pub queries_executed: usize,
    pub mutations_executed: usize,
    /// Peak memory held by derived relations, in bytes.
    pub peak_memory_bytes: usize,
}
313
314impl LocyResult {
315    /// Get derived facts for a specific rule.
316    pub fn derived_facts(&self, rule: &str) -> Option<&Vec<FactRow>> {
317        self.derived.get(rule)
318    }
319
320    /// Get rows from the first Query command result.
321    pub fn rows(&self) -> Option<&Vec<FactRow>> {
322        self.command_results.iter().find_map(|cr| cr.as_query())
323    }
324
325    /// Get column names from the first Query command result's first row.
326    ///
327    /// Column names are returned in deterministic (sorted) order. The
328    /// underlying [`FactRow`] is a `HashMap`, whose iteration order is
329    /// randomized per-run; callers (snapshot tests, golden outputs,
330    /// downstream display) rely on a stable ordering.
331    pub fn columns(&self) -> Option<Vec<String>> {
332        self.rows().and_then(|rows| {
333            rows.first().map(|row| {
334                let mut cols: Vec<String> = row.keys().cloned().collect();
335                cols.sort();
336                cols
337            })
338        })
339    }
340
341    /// Get execution statistics.
342    pub fn stats(&self) -> &LocyStats {
343        &self.stats
344    }
345
346    /// Get the total number of fixpoint iterations.
347    pub fn iterations(&self) -> usize {
348        self.stats.total_iterations
349    }
350
351    /// Get runtime warnings collected during evaluation.
352    pub fn compile_warnings(&self) -> &[crate::types::CompilerWarning] {
353        &self.compile_warnings
354    }
355
356    pub fn command_results(&self) -> &[CommandResult] {
357        &self.command_results
358    }
359
360    pub fn warnings(&self) -> &[RuntimeWarning] {
361        &self.warnings
362    }
363
364    /// Check whether a specific warning code was emitted.
365    pub fn has_warning(&self, code: &RuntimeWarningCode) -> bool {
366        self.warnings.iter().any(|w| w.code == *code)
367    }
368
369    /// True when the evaluation was cut short by a timeout or iteration
370    /// limit. The `derived` map then contains whatever facts were
371    /// accumulated before the cutoff; partial results may not satisfy
372    /// the fixpoint invariant.
373    ///
374    /// This is exactly `self.incomplete.is_some()`. Inspect
375    /// [`incomplete`](LocyResult::incomplete) for the reason and the
376    /// skipped/unsound rule lists. Note it is only ever `true` on the
377    /// opt-in `allow_partial` path — by default an incomplete evaluation
378    /// returns [`UniError::LocyIncomplete`] instead of a result.
379    ///
380    /// [`UniError::LocyIncomplete`]: uni_common::UniError::LocyIncomplete
381    pub fn timed_out(&self) -> bool {
382        self.incomplete.is_some()
383    }
384}
385
386impl CommandResult {
387    /// If this is an Explain result, return the derivation node.
388    pub fn as_explain(&self) -> Option<&DerivationNode> {
389        match self {
390            CommandResult::Explain(node) => Some(node),
391            _ => None,
392        }
393    }
394
395    /// If this is a Query result, return the rows.
396    pub fn as_query(&self) -> Option<&Vec<FactRow>> {
397        match self {
398            CommandResult::Query(rows) => Some(rows),
399            _ => None,
400        }
401    }
402
403    /// If this is an Abduce result, return it.
404    pub fn as_abduce(&self) -> Option<&AbductionResult> {
405        match self {
406            CommandResult::Abduce(result) => Some(result),
407            _ => None,
408        }
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `rows` in a `LocyResult` whose only command result is a single
    /// `Query`. Every other field is empty or defaulted.
    fn single_query_result(rows: Vec<FactRow>) -> LocyResult {
        LocyResult {
            derived: HashMap::new(),
            stats: LocyStats::default(),
            command_results: vec![CommandResult::Query(rows)],
            warnings: Vec::new(),
            compile_warnings: Vec::new(),
            approximate_groups: HashMap::new(),
            derived_fact_set: None,
            incomplete: None,
        }
    }

    /// Regression: `columns()` once returned `HashMap::keys()` order, which
    /// changes from run to run. Snapshot tests and downstream consumers need
    /// a deterministic column order.
    #[test]
    fn columns_returned_in_sorted_order() {
        // Keys are supplied out of alphabetical order. With a HashMap that is
        // enough to expose nondeterminism on many runs; the sort inside
        // `columns()` makes the outcome stable regardless of hasher state.
        let row: FactRow = [("zeta", 1), ("alpha", 2), ("mu", 3)]
            .iter()
            .map(|&(name, n)| (name.to_owned(), Value::Int(n)))
            .collect();

        let cols = single_query_result(vec![row])
            .columns()
            .expect("expected columns for non-empty result");

        let expected: Vec<String> = ["alpha", "mu", "zeta"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        assert_eq!(cols, expected);
    }

    #[test]
    fn abduce_result_serializes_to_json() {
        let change = ValidatedModification {
            modification: Modification::ChangeProperty {
                element_var: "a".into(),
                property: "flagged".into(),
                old_value: Box::new(Value::String("false".into())),
                new_value: Box::new(Value::String("true".into())),
            },
            validated: true,
            cost: 0.5,
        };
        let removal = ValidatedModification {
            modification: Modification::RemoveEdge {
                source_var: "a".into(),
                target_var: "b".into(),
                edge_var: "e".into(),
                edge_type: "TRANSFERS_TO".into(),
                match_properties: HashMap::from([("amount".into(), Value::Float(1000.0))]),
            },
            validated: false,
            cost: 1.0,
        };
        let addition = ValidatedModification {
            modification: Modification::AddEdge {
                source_var: "a".into(),
                target_var: "b".into(),
                edge_type: "FLAGGED_BY".into(),
                properties: HashMap::new(),
            },
            validated: true,
            cost: 1.5,
        };
        let result = AbductionResult {
            modifications: vec![change, removal, addition],
        };

        let json = serde_json::to_value(&result).expect("serialization failed");
        let mods = json["modifications"].as_array().unwrap();
        assert_eq!(mods.len(), 3);
        assert_eq!(mods[0]["validated"], true);
        assert_eq!(mods[0]["cost"], 0.5);

        // Externally tagged enums serialize as `{ "<Variant>": { ... } }`.
        let variants = ["ChangeProperty", "RemoveEdge", "AddEdge"];
        for (entry, variant) in mods.iter().zip(variants.iter()) {
            assert!(entry["modification"][*variant].is_object());
        }
    }
}