Skip to main content

uni_query/query/df_graph/
locy_explain.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! EXPLAIN RULE derivation tree construction.
5//!
6//! Ported from `uni-locy/src/orchestrator/explain.rs`. Uses `DerivedFactSource`
7//! instead of `CypherExecutor`. Uses `RowStore` for row-based fact storage.
8//!
9//! Implements Mode A (provenance-based, uses ProvenanceStore recorded during fixpoint)
10//! with fallback to Mode B (re-execution) when tracker has no entries for the rule.
11
12use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, FactRow, LocyConfig, LocyError, LocyStats};
19
20use super::locy_delta::{
21    KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{
25    eval_expr, normalize_graph_row, record_batches_to_locy_rows, values_equal_for_join,
26};
27use super::locy_slg::SLGResolver;
28use super::locy_traits::DerivedFactSource;
29
30/// A single term in a derivation proof: identifies one IS-ref dependency.
31///
32/// Each `ProofTerm` records a dependency edge in the proof DAG — the source
33/// rule that was referenced and the opaque hash of the specific base fact
34/// consumed (Cui & Widom 2000, lineage).
35#[derive(Clone, Debug)]
36pub struct ProofTerm {
37    /// Name of the IS-ref rule that produced this dependency.
38    pub source_rule: String,
39    /// Opaque hash identifying the base fact consumed from `source_rule`.
40    pub base_fact_id: Vec<u8>,
41}
42
43/// Provenance annotation for a derived fact (Green et al. 2007, Definition 3.1).
44///
45/// Captures a single derivation witness: the rule and clause that produced the
46/// fact, plus the `support` set of proof terms that contributed to it.
47#[derive(Clone, Debug)]
48pub struct ProvenanceAnnotation {
49    /// Name of the rule that derived this fact.
50    pub rule_name: String,
51    /// Index of the clause within the rule that produced this fact.
52    pub clause_index: usize,
53    /// Proof terms: IS-ref input facts (populated when IS-ref tracking is available).
54    pub support: Vec<ProofTerm>,
55    /// ALONG column values captured at derivation time.
56    pub along_values: HashMap<String, Value>,
57    /// Fixpoint iteration number when the fact was first derived.
58    pub iteration: usize,
59    /// Full fact row stored for Mode A filtering/display.
60    pub fact_row: FactRow,
61    /// Probability of this specific proof path (populated by top-k filtering).
62    pub proof_probability: Option<f64>,
63    /// Phase C B1–B3: per neural-model invocation that contributed
64    /// to this fact's derivation. Populated by
65    /// `LocyModelInvokeExec` when it runs a classifier; consumed
66    /// by Mode A EXPLAIN to surface model_name + raw + calibrated
67    /// + optional confidence band per call.
68    pub neural_calls: Vec<uni_locy::NeuralProvenance>,
69}
70
71/// Provenance store for derived facts (Green et al. 2007, §3).
72///
73/// Stores provenance annotations keyed by opaque fact hashes. Enables
74/// Mode A (provenance-based) EXPLAIN without re-execution.
75/// First-derivation-wins: once a fact hash is recorded, later iterations
76/// do not overwrite it.
77#[derive(Debug)]
78pub struct ProvenanceStore {
79    entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
80}
81
82impl ProvenanceStore {
83    pub fn new() -> Self {
84        Self {
85            entries: RwLock::new(HashMap::new()),
86        }
87    }
88
89    /// Record a provenance annotation. First-derivation-wins: if the hash is already
90    /// present, the existing annotation is kept (unlimited mode).
91    pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
92        if let Ok(mut guard) = self.entries.write() {
93            guard.entry(fact_hash).or_insert_with(|| vec![entry]);
94        }
95    }
96
97    /// Record a provenance annotation with top-k filtering.
98    ///
99    /// Retains at most `k` annotations per fact, ordered by `proof_probability`
100    /// (highest first). Annotations without a proof probability are treated as
101    /// having probability 0.0 for ordering purposes.
102    pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
103        if let Ok(mut guard) = self.entries.write() {
104            let vec = guard.entry(fact_hash).or_default();
105            vec.push(entry);
106            // Sort descending by proof_probability.
107            vec.sort_by(|a, b| {
108                b.proof_probability
109                    .unwrap_or(0.0)
110                    .partial_cmp(&a.proof_probability.unwrap_or(0.0))
111                    .unwrap_or(std::cmp::Ordering::Equal)
112            });
113            vec.truncate(k);
114        }
115    }
116
117    /// Look up the first (highest-priority) provenance annotation for a fact hash.
118    pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
119        self.entries.read().ok()?.get(fact_hash)?.first().cloned()
120    }
121
122    /// Look up all provenance annotations for a fact hash.
123    pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
124        let guard = self.entries.read().ok()?;
125        guard.get(fact_hash).cloned()
126    }
127
128    /// Collect base fact probabilities from stored annotations.
129    ///
130    /// Scans all annotations for base facts (those with empty support, i.e. leaf
131    /// nodes in the proof tree) and extracts the PROB column value from their
132    /// `fact_row`. Used by top-k proof filtering to compute `proof_probability`.
133    pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
134        let mut probs = HashMap::new();
135        if let Ok(guard) = self.entries.read() {
136            for (fact_hash, annotations) in guard.iter() {
137                if let Some(ann) = annotations.first()
138                    && ann.support.is_empty()
139                    && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
140                {
141                    probs.insert(fact_hash.clone(), *p);
142                }
143            }
144        }
145        probs
146    }
147
148    /// Get all entries for a specific rule name (returns first annotation per fact).
149    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
150        match self.entries.read() {
151            Ok(guard) => guard
152                .iter()
153                .filter_map(|(k, annotations)| {
154                    annotations
155                        .first()
156                        .filter(|e| e.rule_name == rule_name)
157                        .map(|e| (k.clone(), e.clone()))
158                })
159                .collect(),
160            Err(_) => vec![],
161        }
162    }
163}
164
165/// Compute the probability of a single proof path from its support terms.
166///
167/// Returns `None` if any base fact's probability is unknown.
168pub fn compute_proof_probability(
169    support: &[ProofTerm],
170    base_probs: &HashMap<Vec<u8>, f64>,
171) -> Option<f64> {
172    if support.is_empty() {
173        return None;
174    }
175    let mut product = 1.0;
176    for term in support {
177        let p = base_probs.get(&term.base_fact_id)?;
178        product *= p;
179    }
180    Some(product)
181}
182
183impl Default for ProvenanceStore {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189/// Set of (rule_name, key_tuple) to detect cycles during recursive derivation (Mode B).
190type VisitedSet = HashSet<(String, KeyTuple)>;
191
192/// Build a derivation tree for a rule, showing how each fact was derived.
193///
194/// Tries Mode A (provenance-based, uses ProvenanceStore) first when a store is
195/// provided and has entries for the rule.  Falls through to Mode B (re-execution)
196/// when Mode A cannot produce a result.
197#[expect(
198    clippy::too_many_arguments,
199    reason = "explain requires full program context and tracker state"
200)]
201pub async fn explain_rule(
202    query: &ExplainRule,
203    program: &CompiledProgram,
204    fact_source: &dyn DerivedFactSource,
205    config: &LocyConfig,
206    derived_store: &mut RowStore,
207    stats: &mut LocyStats,
208    tracker: Option<&ProvenanceStore>,
209    approximate_groups: Option<&HashMap<String, Vec<String>>>,
210) -> Result<DerivationNode, LocyError> {
211    // Mode A: provenance-based (no re-execution required).
212    // Falls through to Mode B when tracker is absent or has no matching entries.
213    if let Some(t) = tracker
214        && let Ok(node) = explain_rule_mode_a(
215            query,
216            program,
217            t,
218            fact_source,
219            derived_store,
220            approximate_groups,
221        )
222        .await
223    {
224        return Ok(node);
225    }
226
227    // Mode B: re-execution fallback
228    explain_rule_mode_b(
229        query,
230        program,
231        fact_source,
232        config,
233        derived_store,
234        stats,
235        approximate_groups,
236    )
237    .await
238}
239
240/// Mode A: build derivation tree using recorded provenance from the fixpoint loop.
241///
242/// Returns `Err` when no tracker entries exist for the rule (signals Mode B fallback).
243async fn explain_rule_mode_a(
244    query: &ExplainRule,
245    program: &CompiledProgram,
246    tracker: &ProvenanceStore,
247    fact_source: &dyn DerivedFactSource,
248    _derived_store: &RowStore,
249    approximate_groups: Option<&HashMap<String, Vec<String>>>,
250) -> Result<DerivationNode, LocyError> {
251    let rule_name = query.rule_name.to_string();
252    let rule = program
253        .rule_catalog
254        .get(&rule_name)
255        .ok_or_else(|| LocyError::EvaluationError {
256            message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
257        })?;
258
259    let tracker_entries = tracker.entries_for_rule(&rule_name);
260    if tracker_entries.is_empty() {
261        return Err(LocyError::EvaluationError {
262            message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
263        });
264    }
265
266    // Enrich each tracker entry's fact_row so WHERE predicates like
267    // `n.prop = X` or `e.prop = X` resolve against full Node / Edge
268    // values. Two encodings to handle:
269    //   - Node bindings: tracker stores `Value::Int(vid)` (Arrow UInt64).
270    //     Resolved via the side-channel `lookup_nodes_by_vids` Cypher.
271    //   - Edge bindings: tracker stores `Value::Map({_eid,_type,...})`
272    //     (Arrow struct). Resolved by `normalize_graph_row`, the same
273    //     helper Mode B uses, which converts Map → Value::Edge.
274    //
275    // Enrichment is scoped to KEY columns from the rule's yield_schema
276    // (`is_key && !is_prob`) so non-KEY scalar Ints (e.g. count columns)
277    // never get coerced into Nodes.
278    let key_cols: HashSet<String> = rule
279        .yield_schema
280        .iter()
281        .filter(|c| c.is_key && !c.is_prob)
282        .map(|c| c.name.clone())
283        .collect();
284
285    let mut tracker_entries: Vec<(Vec<u8>, ProvenanceAnnotation)> = tracker_entries
286        .into_iter()
287        .map(|(h, mut ann)| {
288            // Map → Edge / Node conversion for bare graph-entity vars.
289            normalize_graph_row(&mut ann.fact_row);
290            (h, ann)
291        })
292        .collect();
293
294    // Collect residual Int-vid candidates from KEY columns only.
295    let mut candidate_vids: Vec<u64> = Vec::new();
296    for (_, entry) in &tracker_entries {
297        for (k, v) in &entry.fact_row {
298            if key_cols.contains(k)
299                && let Value::Int(i) = v
300                && *i >= 0
301            {
302                candidate_vids.push(*i as u64);
303            }
304        }
305    }
306    candidate_vids.sort_unstable();
307    candidate_vids.dedup();
308    let vid_to_node: HashMap<u64, Value> = if candidate_vids.is_empty() {
309        HashMap::new()
310    } else {
311        fact_source
312            .lookup_nodes_by_vids(&candidate_vids)
313            .await
314            .unwrap_or_default()
315    };
316    if !vid_to_node.is_empty() {
317        for (_, ann) in &mut tracker_entries {
318            for (k, v) in ann.fact_row.iter_mut() {
319                if key_cols.contains(k)
320                    && let Value::Int(i) = v
321                    && *i >= 0
322                    && let Some(node) = vid_to_node.get(&(*i as u64))
323                {
324                    *v = node.clone();
325                }
326            }
327        }
328    }
329
330    // Filter tracker entries by WHERE expression
331    let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
332        tracker_entries
333            .into_iter()
334            .filter(|(_, entry)| {
335                eval_expr(where_expr, &entry.fact_row)
336                    .map(|v| v.as_bool().unwrap_or(false))
337                    .unwrap_or(false)
338            })
339            .collect()
340    } else {
341        tracker_entries
342    };
343
344    if matching_entries.is_empty() {
345        return Err(LocyError::EvaluationError {
346            message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
347        });
348    }
349
350    let is_approximate = approximate_groups
351        .map(|ag| ag.contains_key(&rule_name))
352        .unwrap_or(false);
353
354    let mut root = DerivationNode {
355        rule: rule_name.clone(),
356        clause_index: 0,
357        priority: rule.priority,
358        bindings: HashMap::new(),
359        along_values: HashMap::new(),
360        children: Vec::new(),
361        graph_fact: None,
362        approximate: is_approximate,
363        proof_probability: None,
364        neural_calls: Vec::new(),
365    };
366
367    for (_, entry) in matching_entries {
368        let along_values = extract_along_values(&entry.fact_row, rule);
369        let clause_priority = rule
370            .clauses
371            .get(entry.clause_index)
372            .and_then(|c| c.priority);
373        let base_fact = format!(
374            "[iter={}] {}",
375            entry.iteration,
376            format_graph_fact(&entry.fact_row)
377        );
378        let graph_fact = if is_approximate {
379            format!("[APPROXIMATE] {}", base_fact)
380        } else {
381            base_fact
382        };
383        let node = DerivationNode {
384            rule: rule_name.clone(),
385            clause_index: entry.clause_index,
386            priority: clause_priority.or(rule.priority),
387            bindings: entry.fact_row.clone(),
388            along_values,
389            // Mode A: children not tracked (inputs list is reserved for future recursion)
390            children: vec![],
391            graph_fact: Some(graph_fact),
392            approximate: is_approximate,
393            proof_probability: entry.proof_probability,
394            // Phase C B1–B3: surface the per-call neural metadata
395            // captured by `collect_neural_calls_for_row` during
396            // fixpoint into the EXPLAIN derivation tree.
397            neural_calls: entry.neural_calls.clone(),
398        };
399        root.children.push(node);
400    }
401
402    Ok(root)
403}
404
405/// Mode B: re-execution fallback — re-executes clause queries to find which
406/// clause produced each matching fact, then recurses into IS references.
407async fn explain_rule_mode_b(
408    query: &ExplainRule,
409    program: &CompiledProgram,
410    fact_source: &dyn DerivedFactSource,
411    config: &LocyConfig,
412    derived_store: &mut RowStore,
413    stats: &mut LocyStats,
414    approximate_groups: Option<&HashMap<String, Vec<String>>>,
415) -> Result<DerivationNode, LocyError> {
416    let rule_name = query.rule_name.to_string();
417    let rule = program
418        .rule_catalog
419        .get(&rule_name)
420        .ok_or_else(|| LocyError::EvaluationError {
421            message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
422        })?;
423
424    let key_columns: Vec<String> = rule
425        .yield_schema
426        .iter()
427        .filter(|c| c.is_key)
428        .map(|c| c.name.clone())
429        .collect();
430
431    // Re-evaluate the rule via SLG to obtain rows with full node objects and properties.
432    // The native fixpoint's orch_store has VID-only integers that fail property-based
433    // WHERE filters (e.g. a.name = 'A') — we need actual Value::Node values here.
434    {
435        let mut fresh_store = RowStore::new();
436        let slg_start = std::time::Instant::now();
437        let mut resolver =
438            SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
439        resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
440        stats.queries_executed += resolver.stats.queries_executed;
441        // Merge full-node facts into derived_store so IS-ref lookups in
442        // build_derivation_node also get proper node objects (not VIDs).
443        for (name, relation) in fresh_store {
444            derived_store.insert(name, relation);
445        }
446    }
447
448    // Get all derived facts for this rule (now populated with full node objects)
449    let facts = derived_store
450        .get(&rule_name)
451        .map(|r| r.rows.clone())
452        .unwrap_or_default();
453
454    // Filter facts by WHERE expression
455    let filtered: Vec<FactRow> = if let Some(where_expr) = &query.where_expr {
456        facts
457            .into_iter()
458            .filter(|row| {
459                eval_expr(where_expr, row)
460                    .map(|v| v.as_bool().unwrap_or(false))
461                    .unwrap_or(false)
462            })
463            .collect()
464    } else {
465        facts
466    };
467
468    let is_approximate = approximate_groups
469        .map(|ag| ag.contains_key(&rule_name))
470        .unwrap_or(false);
471
472    // Build derivation tree root
473    let mut root = DerivationNode {
474        rule: rule_name.clone(),
475        clause_index: 0,
476        priority: rule.priority,
477        bindings: HashMap::new(),
478        along_values: HashMap::new(),
479        children: Vec::new(),
480        graph_fact: None,
481        approximate: is_approximate,
482        proof_probability: None,
483        neural_calls: Vec::new(),
484    };
485
486    // For each matching fact, recursively build a derivation node
487    for fact in &filtered {
488        let mut visited = VisitedSet::new();
489        let mut node = build_derivation_node(
490            &rule_name,
491            fact,
492            &key_columns,
493            program,
494            fact_source,
495            derived_store,
496            stats,
497            &mut visited,
498            config.max_explain_depth,
499        )
500        .await?;
501        if is_approximate {
502            node.approximate = true;
503            if let Some(ref gf) = node.graph_fact {
504                node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
505            }
506        }
507        root.children.push(node);
508    }
509
510    Ok(root)
511}
512
513/// Recursively build a derivation node for a single fact of a rule.
514///
515/// Finds which clause produced this fact, extracts ALONG values,
516/// and recurses into IS reference dependencies.
517#[expect(
518    clippy::too_many_arguments,
519    reason = "recursive derivation node builder requires full fact context"
520)]
521fn build_derivation_node<'a>(
522    rule_name: &'a str,
523    fact: &'a FactRow,
524    key_columns: &'a [String],
525    program: &'a CompiledProgram,
526    fact_source: &'a dyn DerivedFactSource,
527    derived_store: &'a mut RowStore,
528    stats: &'a mut LocyStats,
529    visited: &'a mut VisitedSet,
530    max_depth: usize,
531) -> std::pin::Pin<
532    Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + Send + 'a>,
533> {
534    Box::pin(async move {
535        let rule =
536            program
537                .rule_catalog
538                .get(rule_name)
539                .ok_or_else(|| LocyError::EvaluationError {
540                    message: format!("rule '{}' not found during EXPLAIN", rule_name),
541                })?;
542
543        let key_tuple = extract_key(fact, key_columns);
544        let visit_key = (rule_name.to_string(), key_tuple);
545
546        // Cycle detection
547        if !visited.insert(visit_key.clone()) || max_depth == 0 {
548            return Ok(DerivationNode {
549                rule: rule_name.to_string(),
550                clause_index: 0,
551                priority: rule.priority,
552                bindings: fact.clone(),
553                along_values: extract_along_values(fact, rule),
554                children: Vec::new(),
555                graph_fact: Some("(cycle)".to_string()),
556                approximate: false,
557                proof_probability: None,
558                neural_calls: Vec::new(),
559            });
560        }
561
562        // Match on KEY columns only.  Clause-level resolution returns only
563        // base graph bindings (vertex/edge identifiers); non-KEY yield columns
564        // (FOLD-aggregated, similar_to, etc.) are absent from those rows.
565        // KEY columns uniquely identify a derived fact, so this is sufficient.
566
567        // Try each clause to find the one that produced this fact
568        for (clause_idx, clause) in rule.clauses.iter().enumerate() {
569            let has_is_refs = clause
570                .where_conditions
571                .iter()
572                .any(|c| matches!(c, RuleCondition::IsReference(_)));
573            let has_along = !clause.along.is_empty();
574
575            let resolved = if has_is_refs || has_along {
576                let rows = resolve_clause_with_is_refs(
577                    clause,
578                    fact_source,
579                    derived_store,
580                    &program.rule_catalog,
581                    None, // EXPLAIN traces proofs, doesn't compute probabilities
582                )
583                .await?;
584                stats.queries_executed += 1;
585                rows
586            } else {
587                let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
588                let raw_batches = fact_source
589                    .execute_pattern(&clause.match_pattern, &cypher_conditions)
590                    .await?;
591                stats.queries_executed += 1;
592                record_batches_to_locy_rows(&raw_batches)
593            };
594
595            // Use values_equal_for_join for VID/EID-based comparison: sidecar
596            // schema mode can add `overflow_json: Null` to nodes in some query
597            // paths, making structural equality unreliable.
598            let matching_row = resolved.iter().find(|row| {
599                key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
600                    (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
601                    (None, None) => true,
602                    _ => false,
603                })
604            });
605
606            if let Some(evidence_row) = matching_row {
607                let along_values = extract_along_values(fact, rule);
608
609                // Build children by recursing into IS references
610                let mut children = Vec::new();
611                for cond in &clause.where_conditions {
612                    if let RuleCondition::IsReference(is_ref) = cond {
613                        if is_ref.negated {
614                            continue;
615                        }
616                        let ref_rule_name = is_ref.rule_name.to_string();
617                        if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
618                            let ref_key_columns: Vec<String> = ref_rule
619                                .yield_schema
620                                .iter()
621                                .filter(|c| c.is_key)
622                                .map(|c| c.name.clone())
623                                .collect();
624
625                            let ref_facts: Vec<FactRow> = derived_store
626                                .get(&ref_rule_name)
627                                .map(|r| r.rows.clone())
628                                .unwrap_or_default();
629
630                            let matching_ref_facts: Vec<FactRow> = ref_facts
631                                .into_iter()
632                                .filter(|ref_fact| {
633                                    let subjects_match =
634                                        is_ref.subjects.iter().enumerate().all(|(i, subject)| {
635                                            binding_matches_key(
636                                                evidence_row,
637                                                fact,
638                                                subject,
639                                                ref_fact,
640                                                ref_key_columns.get(i),
641                                            )
642                                        });
643                                    let target_matches =
644                                        is_ref.target.as_ref().is_none_or(|target| {
645                                            binding_matches_key(
646                                                evidence_row,
647                                                fact,
648                                                target,
649                                                ref_fact,
650                                                ref_key_columns.get(is_ref.subjects.len()),
651                                            )
652                                        });
653                                    subjects_match && target_matches
654                                })
655                                .collect();
656
657                            for ref_fact in matching_ref_facts {
658                                let child = build_derivation_node(
659                                    &ref_rule_name,
660                                    &ref_fact,
661                                    &ref_key_columns,
662                                    program,
663                                    fact_source,
664                                    derived_store,
665                                    stats,
666                                    visited,
667                                    max_depth - 1,
668                                )
669                                .await?;
670                                children.push(child);
671                            }
672                        }
673                    }
674                }
675
676                // Backtrack visited set
677                visited.remove(&visit_key);
678
679                let mut merged_bindings = evidence_row.clone();
680                for (k, v) in fact {
681                    merged_bindings.entry(k.clone()).or_insert(v.clone());
682                }
683
684                return Ok(DerivationNode {
685                    rule: rule_name.to_string(),
686                    clause_index: clause_idx,
687                    priority: rule.clauses[clause_idx].priority,
688                    bindings: merged_bindings,
689                    along_values,
690                    children,
691                    graph_fact: Some(format_graph_fact(evidence_row)),
692                    approximate: false,
693                    proof_probability: None,
694                    neural_calls: Vec::new(),
695                });
696            }
697        }
698
699        // No clause matched — leaf node
700        visited.remove(&visit_key);
701        Ok(DerivationNode {
702            rule: rule_name.to_string(),
703            clause_index: 0,
704            priority: rule.priority,
705            bindings: fact.clone(),
706            along_values: extract_along_values(fact, rule),
707            children: Vec::new(),
708            graph_fact: Some(format_graph_fact(fact)),
709            approximate: false,
710            proof_probability: None,
711            neural_calls: Vec::new(),
712        })
713    })
714}
715
716/// Check if a binding variable matches a ref-fact key column via VID-based join.
717///
718/// Looks up `var_name` in `primary` (falling back to `fallback`), then compares
719/// it against `ref_key_col` in `ref_fact` using `values_equal_for_join`.
720/// Returns `true` when the key column is out of range or the binding is absent.
721fn binding_matches_key(
722    primary: &FactRow,
723    fallback: &FactRow,
724    var_name: &str,
725    ref_fact: &FactRow,
726    ref_key_col: Option<&String>,
727) -> bool {
728    let Some(key_col) = ref_key_col else {
729        return true;
730    };
731    let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
732        return true;
733    };
734    ref_fact
735        .get(key_col)
736        .is_some_and(|rv| values_equal_for_join(rv, val))
737}
738
739fn extract_along_values(fact: &FactRow, rule: &CompiledRule) -> HashMap<String, Value> {
740    let mut along_values = HashMap::new();
741    for clause in &rule.clauses {
742        for along in &clause.along {
743            if let Some(v) = fact.get(&along.name) {
744                along_values.insert(along.name.clone(), v.clone());
745            }
746        }
747    }
748    along_values
749}
750
751pub(crate) fn format_graph_fact(row: &FactRow) -> String {
752    let mut entries: Vec<String> = row
753        .iter()
754        .map(|(k, v)| format!("{}: {}", k, format_value(v)))
755        .collect();
756    entries.sort();
757    format!("{{{}}}", entries.join(", "))
758}
759
760fn format_value(v: &Value) -> String {
761    match v {
762        Value::Null => "null".to_string(),
763        Value::Bool(b) => b.to_string(),
764        Value::Int(i) => i.to_string(),
765        Value::Float(f) => f.to_string(),
766        Value::String(s) => format!("\"{}\"", s),
767        Value::List(items) => {
768            let inner: Vec<String> = items.iter().map(format_value).collect();
769            format!("[{}]", inner.join(", "))
770        }
771        Value::Map(m) => {
772            let mut entries: Vec<String> = m
773                .iter()
774                .map(|(k, v)| format!("{}: {}", k, format_value(v)))
775                .collect();
776            entries.sort();
777            format!("{{{}}}", entries.join(", "))
778        }
779        Value::Node(n) => format!("Node({})", n.vid.as_u64()),
780        Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
781        _ => format!("{v:?}"),
782    }
783}
784
785#[cfg(test)]
786mod tests {
787    use super::*;
788
789    fn make_annotation(rule: &str, prob: Option<f64>) -> ProvenanceAnnotation {
790        ProvenanceAnnotation {
791            rule_name: rule.to_string(),
792            clause_index: 0,
793            support: vec![],
794            along_values: HashMap::new(),
795            iteration: 0,
796            fact_row: HashMap::new(),
797            proof_probability: prob,
798            neural_calls: Vec::new(),
799        }
800    }
801
802    #[test]
803    fn record_first_derivation_wins() {
804        let store = ProvenanceStore::new();
805        let hash = b"fact1".to_vec();
806        store.record(hash.clone(), make_annotation("rule_a", None));
807        store.record(hash.clone(), make_annotation("rule_b", None));
808        let entry = store.lookup(&hash).unwrap();
809        assert_eq!(entry.rule_name, "rule_a");
810    }
811
812    #[test]
813    fn lookup_returns_first_annotation() {
814        let store = ProvenanceStore::new();
815        let hash = b"fact1".to_vec();
816        store.record(hash.clone(), make_annotation("rule_a", None));
817        assert_eq!(store.lookup(&hash).unwrap().rule_name, "rule_a");
818        assert!(store.lookup(b"nonexistent").is_none());
819    }
820
821    #[test]
822    fn lookup_all_returns_all_annotations() {
823        let store = ProvenanceStore::new();
824        let hash = b"fact1".to_vec();
825        // record() is first-wins, so only one entry per hash via record()
826        store.record(hash.clone(), make_annotation("rule_a", None));
827        let all = store.lookup_all(&hash).unwrap();
828        assert_eq!(all.len(), 1);
829    }
830
831    #[test]
832    fn record_top_k_retains_highest() {
833        let store = ProvenanceStore::new();
834        let hash = b"fact1".to_vec();
835        store.record_top_k(hash.clone(), make_annotation("low", Some(0.1)), 2);
836        store.record_top_k(hash.clone(), make_annotation("high", Some(0.9)), 2);
837        store.record_top_k(hash.clone(), make_annotation("mid", Some(0.5)), 2);
838        store.record_top_k(hash.clone(), make_annotation("highest", Some(0.95)), 2);
839        store.record_top_k(hash.clone(), make_annotation("lowest", Some(0.01)), 2);
840
841        let all = store.lookup_all(&hash).unwrap();
842        assert_eq!(all.len(), 2);
843        assert_eq!(all[0].rule_name, "highest");
844        assert_eq!(all[1].rule_name, "high");
845    }
846
847    #[test]
848    fn compute_proof_probability_basic() {
849        let support = vec![
850            ProofTerm {
851                source_rule: "r1".to_string(),
852                base_fact_id: b"f1".to_vec(),
853            },
854            ProofTerm {
855                source_rule: "r2".to_string(),
856                base_fact_id: b"f2".to_vec(),
857            },
858        ];
859        let base_probs = HashMap::from([(b"f1".to_vec(), 0.3), (b"f2".to_vec(), 0.5)]);
860        let prob = compute_proof_probability(&support, &base_probs);
861        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
862    }
863
864    #[test]
865    fn compute_proof_probability_empty_support() {
866        let prob = compute_proof_probability(&[], &HashMap::new());
867        assert!(prob.is_none());
868    }
869
870    #[test]
871    fn compute_proof_probability_missing_fact() {
872        let support = vec![ProofTerm {
873            source_rule: "r1".to_string(),
874            base_fact_id: b"unknown".to_vec(),
875        }];
876        let prob = compute_proof_probability(&support, &HashMap::new());
877        assert!(prob.is_none());
878    }
879
880    #[test]
881    fn entries_for_rule_filters_correctly() {
882        let store = ProvenanceStore::new();
883        store.record(b"f1".to_vec(), make_annotation("rule_a", None));
884        store.record(b"f2".to_vec(), make_annotation("rule_b", None));
885        store.record(b"f3".to_vec(), make_annotation("rule_a", None));
886
887        let entries = store.entries_for_rule("rule_a");
888        assert_eq!(entries.len(), 2);
889        let entries_b = store.entries_for_rule("rule_b");
890        assert_eq!(entries_b.len(), 1);
891    }
892}