Skip to main content

uni_query/query/df_graph/
locy_explain.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! EXPLAIN RULE derivation tree construction.
5//!
6//! Ported from `uni-locy/src/orchestrator/explain.rs`. Uses `DerivedFactSource`
7//! instead of `CypherExecutor`. Uses `RowStore` for row-based fact storage.
8//!
9//! Implements Mode A (provenance-based, uses ProvenanceStore recorded during fixpoint)
10//! with fallback to Mode B (re-execution) when tracker has no entries for the rule.
11
12use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, FactRow, LocyConfig, LocyError, LocyStats};
19
20use super::locy_delta::{
21    KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{eval_expr, record_batches_to_locy_rows, values_equal_for_join};
25use super::locy_slg::SLGResolver;
26use super::locy_traits::DerivedFactSource;
27
/// A single term in a derivation proof: identifies one IS-ref dependency.
///
/// Each `ProofTerm` records a dependency edge in the proof DAG — the source
/// rule that was referenced and the opaque hash of the specific base fact
/// consumed (Cui & Widom 2000, lineage).
///
/// `compute_proof_probability` multiplies the probabilities found under each
/// term's `base_fact_id`; `source_rule` is not consulted by that computation.
#[derive(Clone, Debug)]
pub struct ProofTerm {
    /// Name of the IS-ref rule that produced this dependency.
    pub source_rule: String,
    /// Opaque hash identifying the base fact consumed from `source_rule`.
    ///
    /// Serves as the lookup key into the base-fact probability map passed to
    /// `compute_proof_probability`.
    pub base_fact_id: Vec<u8>,
}
40
/// Provenance annotation for a derived fact (Green et al. 2007, Definition 3.1).
///
/// Captures a single derivation witness: the rule and clause that produced the
/// fact, plus the `support` set of proof terms that contributed to it.
///
/// Annotations are stored in a `ProvenanceStore` keyed by the fact's opaque
/// hash and are cloned out on lookup.
#[derive(Clone, Debug)]
pub struct ProvenanceAnnotation {
    /// Name of the rule that derived this fact.
    pub rule_name: String,
    /// Index of the clause within the rule that produced this fact.
    pub clause_index: usize,
    /// Proof terms: IS-ref input facts (populated when IS-ref tracking is available).
    ///
    /// `ProvenanceStore::base_fact_probs` treats an annotation with an empty
    /// `support` as a base fact (a leaf of the proof tree).
    pub support: Vec<ProofTerm>,
    /// ALONG column values captured at derivation time.
    ///
    /// NOTE(review): Mode A in this file re-extracts ALONG values from
    /// `fact_row` instead of reading this field — confirm which is authoritative.
    pub along_values: HashMap<String, Value>,
    /// Fixpoint iteration number when the fact was first derived.
    pub iteration: usize,
    /// Full fact row stored for Mode A filtering/display.
    pub fact_row: FactRow,
    /// Probability of this specific proof path (populated by top-k filtering).
    ///
    /// `None` means "not computed"; `ProvenanceStore::record_top_k` ranks such
    /// annotations as if their probability were 0.0.
    pub proof_probability: Option<f64>,
}
62
/// Provenance store for derived facts (Green et al. 2007, §3).
///
/// Stores provenance annotations keyed by opaque fact hashes. Enables
/// Mode A (provenance-based) EXPLAIN without re-execution.
///
/// Two recording policies exist:
/// - `record` is first-derivation-wins: once a fact hash is recorded, later
///   calls for the same hash do not overwrite it.
/// - `record_top_k` keeps up to `k` annotations per hash, ordered by
///   `proof_probability`, so later calls may displace earlier annotations.
///
/// All access goes through an internal `RwLock`, so the methods take `&self`.
#[derive(Debug)]
pub struct ProvenanceStore {
    /// Annotations per fact hash; index 0 is the one returned by `lookup`.
    entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
}
73
74impl ProvenanceStore {
75    pub fn new() -> Self {
76        Self {
77            entries: RwLock::new(HashMap::new()),
78        }
79    }
80
81    /// Record a provenance annotation. First-derivation-wins: if the hash is already
82    /// present, the existing annotation is kept (unlimited mode).
83    pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
84        if let Ok(mut guard) = self.entries.write() {
85            guard.entry(fact_hash).or_insert_with(|| vec![entry]);
86        }
87    }
88
89    /// Record a provenance annotation with top-k filtering.
90    ///
91    /// Retains at most `k` annotations per fact, ordered by `proof_probability`
92    /// (highest first). Annotations without a proof probability are treated as
93    /// having probability 0.0 for ordering purposes.
94    pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
95        if let Ok(mut guard) = self.entries.write() {
96            let vec = guard.entry(fact_hash).or_default();
97            vec.push(entry);
98            // Sort descending by proof_probability.
99            vec.sort_by(|a, b| {
100                b.proof_probability
101                    .unwrap_or(0.0)
102                    .partial_cmp(&a.proof_probability.unwrap_or(0.0))
103                    .unwrap_or(std::cmp::Ordering::Equal)
104            });
105            vec.truncate(k);
106        }
107    }
108
109    /// Look up the first (highest-priority) provenance annotation for a fact hash.
110    pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
111        self.entries.read().ok()?.get(fact_hash)?.first().cloned()
112    }
113
114    /// Look up all provenance annotations for a fact hash.
115    pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
116        let guard = self.entries.read().ok()?;
117        guard.get(fact_hash).cloned()
118    }
119
120    /// Collect base fact probabilities from stored annotations.
121    ///
122    /// Scans all annotations for base facts (those with empty support, i.e. leaf
123    /// nodes in the proof tree) and extracts the PROB column value from their
124    /// `fact_row`. Used by top-k proof filtering to compute `proof_probability`.
125    pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
126        let mut probs = HashMap::new();
127        if let Ok(guard) = self.entries.read() {
128            for (fact_hash, annotations) in guard.iter() {
129                if let Some(ann) = annotations.first()
130                    && ann.support.is_empty()
131                    && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
132                {
133                    probs.insert(fact_hash.clone(), *p);
134                }
135            }
136        }
137        probs
138    }
139
140    /// Get all entries for a specific rule name (returns first annotation per fact).
141    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
142        match self.entries.read() {
143            Ok(guard) => guard
144                .iter()
145                .filter_map(|(k, annotations)| {
146                    annotations
147                        .first()
148                        .filter(|e| e.rule_name == rule_name)
149                        .map(|e| (k.clone(), e.clone()))
150                })
151                .collect(),
152            Err(_) => vec![],
153        }
154    }
155}
156
157/// Compute the probability of a single proof path from its support terms.
158///
159/// Returns `None` if any base fact's probability is unknown.
160pub fn compute_proof_probability(
161    support: &[ProofTerm],
162    base_probs: &HashMap<Vec<u8>, f64>,
163) -> Option<f64> {
164    if support.is_empty() {
165        return None;
166    }
167    let mut product = 1.0;
168    for term in support {
169        let p = base_probs.get(&term.base_fact_id)?;
170        product *= p;
171    }
172    Some(product)
173}
174
175impl Default for ProvenanceStore {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
/// Set of (rule_name, key_tuple) to detect cycles during recursive derivation (Mode B).
///
/// `explain_rule_mode_b` creates a fresh set for every top-level fact, and
/// `build_derivation_node` removes a fact's entry again once its subtree has
/// been built, so the set holds the facts on the current recursion path.
type VisitedSet = HashSet<(String, KeyTuple)>;
183
184/// Build a derivation tree for a rule, showing how each fact was derived.
185///
186/// Tries Mode A (provenance-based, uses ProvenanceStore) first when a store is
187/// provided and has entries for the rule.  Falls through to Mode B (re-execution)
188/// when Mode A cannot produce a result.
189#[expect(
190    clippy::too_many_arguments,
191    reason = "explain requires full program context and tracker state"
192)]
193pub async fn explain_rule(
194    query: &ExplainRule,
195    program: &CompiledProgram,
196    fact_source: &dyn DerivedFactSource,
197    config: &LocyConfig,
198    derived_store: &mut RowStore,
199    stats: &mut LocyStats,
200    tracker: Option<&ProvenanceStore>,
201    approximate_groups: Option<&HashMap<String, Vec<String>>>,
202) -> Result<DerivationNode, LocyError> {
203    // Mode A: provenance-based (no re-execution required).
204    // Falls through to Mode B when tracker is absent or has no matching entries.
205    if let Some(Ok(node)) =
206        tracker.map(|t| explain_rule_mode_a(query, program, t, derived_store, approximate_groups))
207    {
208        return Ok(node);
209    }
210
211    // Mode B: re-execution fallback
212    explain_rule_mode_b(
213        query,
214        program,
215        fact_source,
216        config,
217        derived_store,
218        stats,
219        approximate_groups,
220    )
221    .await
222}
223
224/// Mode A: build derivation tree using recorded provenance from the fixpoint loop.
225///
226/// Returns `Err` when no tracker entries exist for the rule (signals Mode B fallback).
227fn explain_rule_mode_a(
228    query: &ExplainRule,
229    program: &CompiledProgram,
230    tracker: &ProvenanceStore,
231    _derived_store: &RowStore,
232    approximate_groups: Option<&HashMap<String, Vec<String>>>,
233) -> Result<DerivationNode, LocyError> {
234    let rule_name = query.rule_name.to_string();
235    let rule = program
236        .rule_catalog
237        .get(&rule_name)
238        .ok_or_else(|| LocyError::EvaluationError {
239            message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
240        })?;
241
242    let tracker_entries = tracker.entries_for_rule(&rule_name);
243    if tracker_entries.is_empty() {
244        return Err(LocyError::EvaluationError {
245            message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
246        });
247    }
248
249    // Filter tracker entries by WHERE expression
250    let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
251        tracker_entries
252            .into_iter()
253            .filter(|(_, entry)| {
254                eval_expr(where_expr, &entry.fact_row)
255                    .map(|v| v.as_bool().unwrap_or(false))
256                    .unwrap_or(false)
257            })
258            .collect()
259    } else {
260        tracker_entries
261    };
262
263    if matching_entries.is_empty() {
264        return Err(LocyError::EvaluationError {
265            message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
266        });
267    }
268
269    let is_approximate = approximate_groups
270        .map(|ag| ag.contains_key(&rule_name))
271        .unwrap_or(false);
272
273    let mut root = DerivationNode {
274        rule: rule_name.clone(),
275        clause_index: 0,
276        priority: rule.priority,
277        bindings: HashMap::new(),
278        along_values: HashMap::new(),
279        children: Vec::new(),
280        graph_fact: None,
281        approximate: is_approximate,
282        proof_probability: None,
283    };
284
285    for (_, entry) in matching_entries {
286        let along_values = extract_along_values(&entry.fact_row, rule);
287        let clause_priority = rule
288            .clauses
289            .get(entry.clause_index)
290            .and_then(|c| c.priority);
291        let base_fact = format!(
292            "[iter={}] {}",
293            entry.iteration,
294            format_graph_fact(&entry.fact_row)
295        );
296        let graph_fact = if is_approximate {
297            format!("[APPROXIMATE] {}", base_fact)
298        } else {
299            base_fact
300        };
301        let node = DerivationNode {
302            rule: rule_name.clone(),
303            clause_index: entry.clause_index,
304            priority: clause_priority.or(rule.priority),
305            bindings: entry.fact_row.clone(),
306            along_values,
307            // Mode A: children not tracked (inputs list is reserved for future recursion)
308            children: vec![],
309            graph_fact: Some(graph_fact),
310            approximate: is_approximate,
311            proof_probability: entry.proof_probability,
312        };
313        root.children.push(node);
314    }
315
316    Ok(root)
317}
318
319/// Mode B: re-execution fallback — re-executes clause queries to find which
320/// clause produced each matching fact, then recurses into IS references.
321async fn explain_rule_mode_b(
322    query: &ExplainRule,
323    program: &CompiledProgram,
324    fact_source: &dyn DerivedFactSource,
325    config: &LocyConfig,
326    derived_store: &mut RowStore,
327    stats: &mut LocyStats,
328    approximate_groups: Option<&HashMap<String, Vec<String>>>,
329) -> Result<DerivationNode, LocyError> {
330    let rule_name = query.rule_name.to_string();
331    let rule = program
332        .rule_catalog
333        .get(&rule_name)
334        .ok_or_else(|| LocyError::EvaluationError {
335            message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
336        })?;
337
338    let key_columns: Vec<String> = rule
339        .yield_schema
340        .iter()
341        .filter(|c| c.is_key)
342        .map(|c| c.name.clone())
343        .collect();
344
345    // Re-evaluate the rule via SLG to obtain rows with full node objects and properties.
346    // The native fixpoint's orch_store has VID-only integers that fail property-based
347    // WHERE filters (e.g. a.name = 'A') — we need actual Value::Node values here.
348    {
349        let mut fresh_store = RowStore::new();
350        let slg_start = std::time::Instant::now();
351        let mut resolver =
352            SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
353        resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
354        stats.queries_executed += resolver.stats.queries_executed;
355        // Merge full-node facts into derived_store so IS-ref lookups in
356        // build_derivation_node also get proper node objects (not VIDs).
357        for (name, relation) in fresh_store {
358            derived_store.insert(name, relation);
359        }
360    }
361
362    // Get all derived facts for this rule (now populated with full node objects)
363    let facts = derived_store
364        .get(&rule_name)
365        .map(|r| r.rows.clone())
366        .unwrap_or_default();
367
368    // Filter facts by WHERE expression
369    let filtered: Vec<FactRow> = if let Some(where_expr) = &query.where_expr {
370        facts
371            .into_iter()
372            .filter(|row| {
373                eval_expr(where_expr, row)
374                    .map(|v| v.as_bool().unwrap_or(false))
375                    .unwrap_or(false)
376            })
377            .collect()
378    } else {
379        facts
380    };
381
382    let is_approximate = approximate_groups
383        .map(|ag| ag.contains_key(&rule_name))
384        .unwrap_or(false);
385
386    // Build derivation tree root
387    let mut root = DerivationNode {
388        rule: rule_name.clone(),
389        clause_index: 0,
390        priority: rule.priority,
391        bindings: HashMap::new(),
392        along_values: HashMap::new(),
393        children: Vec::new(),
394        graph_fact: None,
395        approximate: is_approximate,
396        proof_probability: None,
397    };
398
399    // For each matching fact, recursively build a derivation node
400    for fact in &filtered {
401        let mut visited = VisitedSet::new();
402        let mut node = build_derivation_node(
403            &rule_name,
404            fact,
405            &key_columns,
406            program,
407            fact_source,
408            derived_store,
409            stats,
410            &mut visited,
411            config.max_explain_depth,
412        )
413        .await?;
414        if is_approximate {
415            node.approximate = true;
416            if let Some(ref gf) = node.graph_fact {
417                node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
418            }
419        }
420        root.children.push(node);
421    }
422
423    Ok(root)
424}
425
426/// Recursively build a derivation node for a single fact of a rule.
427///
428/// Finds which clause produced this fact, extracts ALONG values,
429/// and recurses into IS reference dependencies.
430#[expect(
431    clippy::too_many_arguments,
432    reason = "recursive derivation node builder requires full fact context"
433)]
434fn build_derivation_node<'a>(
435    rule_name: &'a str,
436    fact: &'a FactRow,
437    key_columns: &'a [String],
438    program: &'a CompiledProgram,
439    fact_source: &'a dyn DerivedFactSource,
440    derived_store: &'a mut RowStore,
441    stats: &'a mut LocyStats,
442    visited: &'a mut VisitedSet,
443    max_depth: usize,
444) -> std::pin::Pin<
445    Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + Send + 'a>,
446> {
447    Box::pin(async move {
448        let rule =
449            program
450                .rule_catalog
451                .get(rule_name)
452                .ok_or_else(|| LocyError::EvaluationError {
453                    message: format!("rule '{}' not found during EXPLAIN", rule_name),
454                })?;
455
456        let key_tuple = extract_key(fact, key_columns);
457        let visit_key = (rule_name.to_string(), key_tuple);
458
459        // Cycle detection
460        if !visited.insert(visit_key.clone()) || max_depth == 0 {
461            return Ok(DerivationNode {
462                rule: rule_name.to_string(),
463                clause_index: 0,
464                priority: rule.priority,
465                bindings: fact.clone(),
466                along_values: extract_along_values(fact, rule),
467                children: Vec::new(),
468                graph_fact: Some("(cycle)".to_string()),
469                approximate: false,
470                proof_probability: None,
471            });
472        }
473
474        // Match on KEY columns only.  Clause-level resolution returns only
475        // base graph bindings (vertex/edge identifiers); non-KEY yield columns
476        // (FOLD-aggregated, similar_to, etc.) are absent from those rows.
477        // KEY columns uniquely identify a derived fact, so this is sufficient.
478
479        // Try each clause to find the one that produced this fact
480        for (clause_idx, clause) in rule.clauses.iter().enumerate() {
481            let has_is_refs = clause
482                .where_conditions
483                .iter()
484                .any(|c| matches!(c, RuleCondition::IsReference(_)));
485            let has_along = !clause.along.is_empty();
486
487            let resolved = if has_is_refs || has_along {
488                let rows = resolve_clause_with_is_refs(
489                    clause,
490                    fact_source,
491                    derived_store,
492                    &program.rule_catalog,
493                    None, // EXPLAIN traces proofs, doesn't compute probabilities
494                )
495                .await?;
496                stats.queries_executed += 1;
497                rows
498            } else {
499                let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
500                let raw_batches = fact_source
501                    .execute_pattern(&clause.match_pattern, &cypher_conditions)
502                    .await?;
503                stats.queries_executed += 1;
504                record_batches_to_locy_rows(&raw_batches)
505            };
506
507            // Use values_equal_for_join for VID/EID-based comparison: sidecar
508            // schema mode can add `overflow_json: Null` to nodes in some query
509            // paths, making structural equality unreliable.
510            let matching_row = resolved.iter().find(|row| {
511                key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
512                    (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
513                    (None, None) => true,
514                    _ => false,
515                })
516            });
517
518            if let Some(evidence_row) = matching_row {
519                let along_values = extract_along_values(fact, rule);
520
521                // Build children by recursing into IS references
522                let mut children = Vec::new();
523                for cond in &clause.where_conditions {
524                    if let RuleCondition::IsReference(is_ref) = cond {
525                        if is_ref.negated {
526                            continue;
527                        }
528                        let ref_rule_name = is_ref.rule_name.to_string();
529                        if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
530                            let ref_key_columns: Vec<String> = ref_rule
531                                .yield_schema
532                                .iter()
533                                .filter(|c| c.is_key)
534                                .map(|c| c.name.clone())
535                                .collect();
536
537                            let ref_facts: Vec<FactRow> = derived_store
538                                .get(&ref_rule_name)
539                                .map(|r| r.rows.clone())
540                                .unwrap_or_default();
541
542                            let matching_ref_facts: Vec<FactRow> = ref_facts
543                                .into_iter()
544                                .filter(|ref_fact| {
545                                    let subjects_match =
546                                        is_ref.subjects.iter().enumerate().all(|(i, subject)| {
547                                            binding_matches_key(
548                                                evidence_row,
549                                                fact,
550                                                subject,
551                                                ref_fact,
552                                                ref_key_columns.get(i),
553                                            )
554                                        });
555                                    let target_matches =
556                                        is_ref.target.as_ref().is_none_or(|target| {
557                                            binding_matches_key(
558                                                evidence_row,
559                                                fact,
560                                                target,
561                                                ref_fact,
562                                                ref_key_columns.get(is_ref.subjects.len()),
563                                            )
564                                        });
565                                    subjects_match && target_matches
566                                })
567                                .collect();
568
569                            for ref_fact in matching_ref_facts {
570                                let child = build_derivation_node(
571                                    &ref_rule_name,
572                                    &ref_fact,
573                                    &ref_key_columns,
574                                    program,
575                                    fact_source,
576                                    derived_store,
577                                    stats,
578                                    visited,
579                                    max_depth - 1,
580                                )
581                                .await?;
582                                children.push(child);
583                            }
584                        }
585                    }
586                }
587
588                // Backtrack visited set
589                visited.remove(&visit_key);
590
591                let mut merged_bindings = evidence_row.clone();
592                for (k, v) in fact {
593                    merged_bindings.entry(k.clone()).or_insert(v.clone());
594                }
595
596                return Ok(DerivationNode {
597                    rule: rule_name.to_string(),
598                    clause_index: clause_idx,
599                    priority: rule.clauses[clause_idx].priority,
600                    bindings: merged_bindings,
601                    along_values,
602                    children,
603                    graph_fact: Some(format_graph_fact(evidence_row)),
604                    approximate: false,
605                    proof_probability: None,
606                });
607            }
608        }
609
610        // No clause matched — leaf node
611        visited.remove(&visit_key);
612        Ok(DerivationNode {
613            rule: rule_name.to_string(),
614            clause_index: 0,
615            priority: rule.priority,
616            bindings: fact.clone(),
617            along_values: extract_along_values(fact, rule),
618            children: Vec::new(),
619            graph_fact: Some(format_graph_fact(fact)),
620            approximate: false,
621            proof_probability: None,
622        })
623    })
624}
625
626/// Check if a binding variable matches a ref-fact key column via VID-based join.
627///
628/// Looks up `var_name` in `primary` (falling back to `fallback`), then compares
629/// it against `ref_key_col` in `ref_fact` using `values_equal_for_join`.
630/// Returns `true` when the key column is out of range or the binding is absent.
631fn binding_matches_key(
632    primary: &FactRow,
633    fallback: &FactRow,
634    var_name: &str,
635    ref_fact: &FactRow,
636    ref_key_col: Option<&String>,
637) -> bool {
638    let Some(key_col) = ref_key_col else {
639        return true;
640    };
641    let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
642        return true;
643    };
644    ref_fact
645        .get(key_col)
646        .is_some_and(|rv| values_equal_for_join(rv, val))
647}
648
649fn extract_along_values(fact: &FactRow, rule: &CompiledRule) -> HashMap<String, Value> {
650    let mut along_values = HashMap::new();
651    for clause in &rule.clauses {
652        for along in &clause.along {
653            if let Some(v) = fact.get(&along.name) {
654                along_values.insert(along.name.clone(), v.clone());
655            }
656        }
657    }
658    along_values
659}
660
661pub(crate) fn format_graph_fact(row: &FactRow) -> String {
662    let mut entries: Vec<String> = row
663        .iter()
664        .map(|(k, v)| format!("{}: {}", k, format_value(v)))
665        .collect();
666    entries.sort();
667    format!("{{{}}}", entries.join(", "))
668}
669
670fn format_value(v: &Value) -> String {
671    match v {
672        Value::Null => "null".to_string(),
673        Value::Bool(b) => b.to_string(),
674        Value::Int(i) => i.to_string(),
675        Value::Float(f) => f.to_string(),
676        Value::String(s) => format!("\"{}\"", s),
677        Value::List(items) => {
678            let inner: Vec<String> = items.iter().map(format_value).collect();
679            format!("[{}]", inner.join(", "))
680        }
681        Value::Map(m) => {
682            let mut entries: Vec<String> = m
683                .iter()
684                .map(|(k, v)| format!("{}: {}", k, format_value(v)))
685                .collect();
686            entries.sort();
687            format!("{{{}}}", entries.join(", "))
688        }
689        Value::Node(n) => format!("Node({})", n.vid.as_u64()),
690        Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
691        _ => format!("{v:?}"),
692    }
693}
694
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal annotation for `rule` with the given proof probability; every
    /// other field is empty or zero.
    fn annotation(rule: &str, proof_probability: Option<f64>) -> ProvenanceAnnotation {
        ProvenanceAnnotation {
            rule_name: rule.to_owned(),
            clause_index: 0,
            support: Vec::new(),
            along_values: HashMap::new(),
            iteration: 0,
            fact_row: HashMap::new(),
            proof_probability,
        }
    }

    /// Proof term referencing base fact `fact_id` of `rule`.
    fn term(rule: &str, fact_id: &[u8]) -> ProofTerm {
        ProofTerm {
            source_rule: rule.to_owned(),
            base_fact_id: fact_id.to_vec(),
        }
    }

    #[test]
    fn record_first_derivation_wins() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        for rule in ["rule_a", "rule_b"] {
            store.record(key.clone(), annotation(rule, None));
        }
        let kept = store.lookup(&key).unwrap();
        assert_eq!(kept.rule_name, "rule_a");
    }

    #[test]
    fn lookup_returns_first_annotation() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        store.record(key.clone(), annotation("rule_a", None));

        let found = store.lookup(&key);
        assert_eq!(found.unwrap().rule_name, "rule_a");

        let absent = store.lookup(b"nonexistent");
        assert!(absent.is_none());
    }

    #[test]
    fn lookup_all_returns_all_annotations() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        // record() is first-wins, so only one entry per hash via record()
        store.record(key.clone(), annotation("rule_a", None));
        let annotations = store.lookup_all(&key).unwrap();
        assert_eq!(annotations.len(), 1);
    }

    #[test]
    fn record_top_k_retains_highest() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        let candidates = [
            ("low", 0.1),
            ("high", 0.9),
            ("mid", 0.5),
            ("highest", 0.95),
            ("lowest", 0.01),
        ];
        for (rule, prob) in candidates {
            store.record_top_k(key.clone(), annotation(rule, Some(prob)), 2);
        }

        let retained = store.lookup_all(&key).unwrap();
        assert_eq!(retained.len(), 2);
        assert_eq!(retained[0].rule_name, "highest");
        assert_eq!(retained[1].rule_name, "high");
    }

    #[test]
    fn compute_proof_probability_basic() {
        let support = vec![term("r1", b"f1"), term("r2", b"f2")];
        let mut base_probs = HashMap::new();
        base_probs.insert(b"f1".to_vec(), 0.3);
        base_probs.insert(b"f2".to_vec(), 0.5);

        let prob = compute_proof_probability(&support, &base_probs);
        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
    }

    #[test]
    fn compute_proof_probability_empty_support() {
        let no_probs = HashMap::new();
        assert!(compute_proof_probability(&[], &no_probs).is_none());
    }

    #[test]
    fn compute_proof_probability_missing_fact() {
        let support = vec![term("r1", b"unknown")];
        let no_probs = HashMap::new();
        assert!(compute_proof_probability(&support, &no_probs).is_none());
    }

    #[test]
    fn entries_for_rule_filters_correctly() {
        let store = ProvenanceStore::new();
        let recorded: [(&[u8], &str); 3] =
            [(b"f1", "rule_a"), (b"f2", "rule_b"), (b"f3", "rule_a")];
        for (key, rule) in recorded {
            store.record(key.to_vec(), annotation(rule, None));
        }

        assert_eq!(store.entries_for_rule("rule_a").len(), 2);
        assert_eq!(store.entries_for_rule("rule_b").len(), 1);
    }
}