Skip to main content

uni_query/query/df_graph/
locy_explain.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! EXPLAIN RULE derivation tree construction.
5//!
6//! Ported from `uni-locy/src/orchestrator/explain.rs`. Uses `DerivedFactSource`
7//! instead of `CypherExecutor`. Uses `RowStore` for row-based fact storage.
8//!
//! Implements Mode A (provenance-based, uses the ProvenanceStore recorded during the
//! fixpoint) with fallback to Mode B (re-execution) when no tracker is supplied, the
//! tracker has no entries for the rule, or none of its entries match the WHERE clause.
11
12use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, LocyConfig, LocyError, LocyStats, Row};
19
20use super::locy_delta::{
21    KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{eval_expr, record_batches_to_locy_rows, values_equal_for_join};
25use super::locy_slg::SLGResolver;
26use super::locy_traits::DerivedFactSource;
27
/// A single term in a derivation proof: identifies one IS-ref dependency.
///
/// Each `ProofTerm` records a dependency edge in the proof DAG — the source
/// rule that was referenced and the opaque hash of the specific base fact
/// consumed (Cui & Widom 2000, lineage).
///
/// Terms compare and hash by value, so a proof's support can be deduplicated
/// or used as a set/map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProofTerm {
    /// Name of the IS-ref rule that produced this dependency.
    pub source_rule: String,
    /// Opaque hash identifying the base fact consumed from `source_rule`.
    /// `compute_proof_probability` uses it as the key into the base-probability map.
    pub base_fact_id: Vec<u8>,
}
40
/// Provenance annotation for a derived fact (Green et al. 2007, Definition 3.1).
///
/// Captures a single derivation witness: the rule and clause that produced the
/// fact, plus the `support` set of proof terms that contributed to it.
#[derive(Clone, Debug)]
pub struct ProvenanceAnnotation {
    /// Name of the rule that derived this fact.
    pub rule_name: String,
    /// Index of the clause within the rule that produced this fact
    /// (an index into the compiled rule's `clauses`).
    pub clause_index: usize,
    /// Proof terms: IS-ref input facts (populated when IS-ref tracking is available).
    /// An empty `support` marks a base (leaf) fact; `ProvenanceStore::base_fact_probs`
    /// only reads probabilities from such annotations.
    pub support: Vec<ProofTerm>,
    /// ALONG column values captured at derivation time.
    pub along_values: HashMap<String, Value>,
    /// Fixpoint iteration number when the fact was first derived.
    /// Mode A EXPLAIN renders it as the `[iter=N]` prefix of the fact.
    pub iteration: usize,
    /// Full fact row stored for Mode A filtering/display. For base facts, a
    /// `PROB` float column (when present) is read by `ProvenanceStore::base_fact_probs`.
    pub fact_row: Row,
    /// Probability of this specific proof path (populated by top-k filtering).
    /// `ProvenanceStore::record_top_k` ranks annotations by it, treating `None` as 0.0.
    pub proof_probability: Option<f64>,
}
62
/// Provenance store for derived facts (Green et al. 2007, §3).
///
/// Stores provenance annotations keyed by opaque fact hashes. Enables
/// Mode A (provenance-based) EXPLAIN without re-execution.
///
/// Two recording policies are available:
/// - `record`: first-derivation-wins — once a fact hash is recorded, later
///   annotations for it are ignored.
/// - `record_top_k`: keeps up to `k` annotations per fact, ordered by
///   `proof_probability` (highest first), so later derivations may displace
///   earlier ones.
///
/// The map sits behind an `RwLock`, so recording works through `&self`.
#[derive(Debug)]
pub struct ProvenanceStore {
    /// Fact hash → annotations for that fact. The first element of each list is
    /// what `lookup` returns and what `entries_for_rule` / `base_fact_probs` inspect.
    entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
}
73
74impl ProvenanceStore {
75    pub fn new() -> Self {
76        Self {
77            entries: RwLock::new(HashMap::new()),
78        }
79    }
80
81    /// Record a provenance annotation. First-derivation-wins: if the hash is already
82    /// present, the existing annotation is kept (unlimited mode).
83    pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
84        if let Ok(mut guard) = self.entries.write() {
85            guard.entry(fact_hash).or_insert_with(|| vec![entry]);
86        }
87    }
88
89    /// Record a provenance annotation with top-k filtering.
90    ///
91    /// Retains at most `k` annotations per fact, ordered by `proof_probability`
92    /// (highest first). Annotations without a proof probability are treated as
93    /// having probability 0.0 for ordering purposes.
94    pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
95        if let Ok(mut guard) = self.entries.write() {
96            let vec = guard.entry(fact_hash).or_default();
97            vec.push(entry);
98            // Sort descending by proof_probability.
99            vec.sort_by(|a, b| {
100                b.proof_probability
101                    .unwrap_or(0.0)
102                    .partial_cmp(&a.proof_probability.unwrap_or(0.0))
103                    .unwrap_or(std::cmp::Ordering::Equal)
104            });
105            vec.truncate(k);
106        }
107    }
108
109    /// Look up the first (highest-priority) provenance annotation for a fact hash.
110    pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
111        self.entries.read().ok()?.get(fact_hash)?.first().cloned()
112    }
113
114    /// Look up all provenance annotations for a fact hash.
115    pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
116        let guard = self.entries.read().ok()?;
117        guard.get(fact_hash).cloned()
118    }
119
120    /// Collect base fact probabilities from stored annotations.
121    ///
122    /// Scans all annotations for base facts (those with empty support, i.e. leaf
123    /// nodes in the proof tree) and extracts the PROB column value from their
124    /// `fact_row`. Used by top-k proof filtering to compute `proof_probability`.
125    pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
126        let mut probs = HashMap::new();
127        if let Ok(guard) = self.entries.read() {
128            for (fact_hash, annotations) in guard.iter() {
129                if let Some(ann) = annotations.first()
130                    && ann.support.is_empty()
131                    && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
132                {
133                    probs.insert(fact_hash.clone(), *p);
134                }
135            }
136        }
137        probs
138    }
139
140    /// Get all entries for a specific rule name (returns first annotation per fact).
141    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
142        match self.entries.read() {
143            Ok(guard) => guard
144                .iter()
145                .filter_map(|(k, annotations)| {
146                    annotations
147                        .first()
148                        .filter(|e| e.rule_name == rule_name)
149                        .map(|e| (k.clone(), e.clone()))
150                })
151                .collect(),
152            Err(_) => vec![],
153        }
154    }
155}
156
157/// Compute the probability of a single proof path from its support terms.
158///
159/// Returns `None` if any base fact's probability is unknown.
160pub fn compute_proof_probability(
161    support: &[ProofTerm],
162    base_probs: &HashMap<Vec<u8>, f64>,
163) -> Option<f64> {
164    if support.is_empty() {
165        return None;
166    }
167    let mut product = 1.0;
168    for term in support {
169        let p = base_probs.get(&term.base_fact_id)?;
170        product *= p;
171    }
172    Some(product)
173}
174
175impl Default for ProvenanceStore {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
/// Set of `(rule_name, key_tuple)` pairs used to detect cycles during recursive
/// derivation (Mode B).
///
/// `build_derivation_node` inserts a fact's pair when it enters the fact and
/// removes it (backtracks) once the fact's subtree has been built. The set is
/// therefore intended to hold exactly the facts on the current root-to-node path.
/// A fact found already present is rendered as a `(cycle)` leaf.
type VisitedSet = HashSet<(String, KeyTuple)>;
183
184/// Build a derivation tree for a rule, showing how each fact was derived.
185///
186/// Tries Mode A (provenance-based, uses ProvenanceStore) first when a store is
187/// provided and has entries for the rule.  Falls through to Mode B (re-execution)
188/// when Mode A cannot produce a result.
189#[expect(
190    clippy::too_many_arguments,
191    reason = "explain requires full program context and tracker state"
192)]
193pub async fn explain_rule(
194    query: &ExplainRule,
195    program: &CompiledProgram,
196    fact_source: &dyn DerivedFactSource,
197    config: &LocyConfig,
198    derived_store: &mut RowStore,
199    stats: &mut LocyStats,
200    tracker: Option<&ProvenanceStore>,
201    approximate_groups: Option<&HashMap<String, Vec<String>>>,
202) -> Result<DerivationNode, LocyError> {
203    // Mode A: provenance-based (no re-execution required).
204    // Falls through to Mode B when tracker is absent or has no matching entries.
205    if let Some(Ok(node)) =
206        tracker.map(|t| explain_rule_mode_a(query, program, t, derived_store, approximate_groups))
207    {
208        return Ok(node);
209    }
210
211    // Mode B: re-execution fallback
212    explain_rule_mode_b(
213        query,
214        program,
215        fact_source,
216        config,
217        derived_store,
218        stats,
219        approximate_groups,
220    )
221    .await
222}
223
224/// Mode A: build derivation tree using recorded provenance from the fixpoint loop.
225///
226/// Returns `Err` when no tracker entries exist for the rule (signals Mode B fallback).
227fn explain_rule_mode_a(
228    query: &ExplainRule,
229    program: &CompiledProgram,
230    tracker: &ProvenanceStore,
231    _derived_store: &RowStore,
232    approximate_groups: Option<&HashMap<String, Vec<String>>>,
233) -> Result<DerivationNode, LocyError> {
234    let rule_name = query.rule_name.to_string();
235    let rule = program
236        .rule_catalog
237        .get(&rule_name)
238        .ok_or_else(|| LocyError::EvaluationError {
239            message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
240        })?;
241
242    let tracker_entries = tracker.entries_for_rule(&rule_name);
243    if tracker_entries.is_empty() {
244        return Err(LocyError::EvaluationError {
245            message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
246        });
247    }
248
249    // Filter tracker entries by WHERE expression
250    let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
251        tracker_entries
252            .into_iter()
253            .filter(|(_, entry)| {
254                eval_expr(where_expr, &entry.fact_row)
255                    .map(|v| v.as_bool().unwrap_or(false))
256                    .unwrap_or(false)
257            })
258            .collect()
259    } else {
260        tracker_entries
261    };
262
263    if matching_entries.is_empty() {
264        return Err(LocyError::EvaluationError {
265            message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
266        });
267    }
268
269    let is_approximate = approximate_groups
270        .map(|ag| ag.contains_key(&rule_name))
271        .unwrap_or(false);
272
273    let mut root = DerivationNode {
274        rule: rule_name.clone(),
275        clause_index: 0,
276        priority: rule.priority,
277        bindings: HashMap::new(),
278        along_values: HashMap::new(),
279        children: Vec::new(),
280        graph_fact: None,
281        approximate: is_approximate,
282        proof_probability: None,
283    };
284
285    for (_, entry) in matching_entries {
286        let along_values = extract_along_values(&entry.fact_row, rule);
287        let clause_priority = rule
288            .clauses
289            .get(entry.clause_index)
290            .and_then(|c| c.priority);
291        let base_fact = format!(
292            "[iter={}] {}",
293            entry.iteration,
294            format_graph_fact(&entry.fact_row)
295        );
296        let graph_fact = if is_approximate {
297            format!("[APPROXIMATE] {}", base_fact)
298        } else {
299            base_fact
300        };
301        let node = DerivationNode {
302            rule: rule_name.clone(),
303            clause_index: entry.clause_index,
304            priority: clause_priority.or(rule.priority),
305            bindings: entry.fact_row.clone(),
306            along_values,
307            // Mode A: children not tracked (inputs list is reserved for future recursion)
308            children: vec![],
309            graph_fact: Some(graph_fact),
310            approximate: is_approximate,
311            proof_probability: entry.proof_probability,
312        };
313        root.children.push(node);
314    }
315
316    Ok(root)
317}
318
319/// Mode B: re-execution fallback — re-executes clause queries to find which
320/// clause produced each matching fact, then recurses into IS references.
321async fn explain_rule_mode_b(
322    query: &ExplainRule,
323    program: &CompiledProgram,
324    fact_source: &dyn DerivedFactSource,
325    config: &LocyConfig,
326    derived_store: &mut RowStore,
327    stats: &mut LocyStats,
328    approximate_groups: Option<&HashMap<String, Vec<String>>>,
329) -> Result<DerivationNode, LocyError> {
330    let rule_name = query.rule_name.to_string();
331    let rule = program
332        .rule_catalog
333        .get(&rule_name)
334        .ok_or_else(|| LocyError::EvaluationError {
335            message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
336        })?;
337
338    let key_columns: Vec<String> = rule
339        .yield_schema
340        .iter()
341        .filter(|c| c.is_key)
342        .map(|c| c.name.clone())
343        .collect();
344
345    // Re-evaluate the rule via SLG to obtain rows with full node objects and properties.
346    // The native fixpoint's orch_store has VID-only integers that fail property-based
347    // WHERE filters (e.g. a.name = 'A') — we need actual Value::Node values here.
348    {
349        let mut fresh_store = RowStore::new();
350        let slg_start = std::time::Instant::now();
351        let mut resolver =
352            SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
353        resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
354        stats.queries_executed += resolver.stats.queries_executed;
355        // Merge full-node facts into derived_store so IS-ref lookups in
356        // build_derivation_node also get proper node objects (not VIDs).
357        for (name, relation) in fresh_store {
358            derived_store.insert(name, relation);
359        }
360    }
361
362    // Get all derived facts for this rule (now populated with full node objects)
363    let facts = derived_store
364        .get(&rule_name)
365        .map(|r| r.rows.clone())
366        .unwrap_or_default();
367
368    // Filter facts by WHERE expression
369    let filtered: Vec<Row> = if let Some(where_expr) = &query.where_expr {
370        facts
371            .into_iter()
372            .filter(|row| {
373                eval_expr(where_expr, row)
374                    .map(|v| v.as_bool().unwrap_or(false))
375                    .unwrap_or(false)
376            })
377            .collect()
378    } else {
379        facts
380    };
381
382    let is_approximate = approximate_groups
383        .map(|ag| ag.contains_key(&rule_name))
384        .unwrap_or(false);
385
386    // Build derivation tree root
387    let mut root = DerivationNode {
388        rule: rule_name.clone(),
389        clause_index: 0,
390        priority: rule.priority,
391        bindings: HashMap::new(),
392        along_values: HashMap::new(),
393        children: Vec::new(),
394        graph_fact: None,
395        approximate: is_approximate,
396        proof_probability: None,
397    };
398
399    // For each matching fact, recursively build a derivation node
400    for fact in &filtered {
401        let mut visited = VisitedSet::new();
402        let mut node = build_derivation_node(
403            &rule_name,
404            fact,
405            &key_columns,
406            program,
407            fact_source,
408            derived_store,
409            stats,
410            &mut visited,
411            config.max_explain_depth,
412        )
413        .await?;
414        if is_approximate {
415            node.approximate = true;
416            if let Some(ref gf) = node.graph_fact {
417                node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
418            }
419        }
420        root.children.push(node);
421    }
422
423    Ok(root)
424}
425
426/// Recursively build a derivation node for a single fact of a rule.
427///
428/// Finds which clause produced this fact, extracts ALONG values,
429/// and recurses into IS reference dependencies.
430#[expect(
431    clippy::too_many_arguments,
432    reason = "recursive derivation node builder requires full fact context"
433)]
434fn build_derivation_node<'a>(
435    rule_name: &'a str,
436    fact: &'a Row,
437    key_columns: &'a [String],
438    program: &'a CompiledProgram,
439    fact_source: &'a dyn DerivedFactSource,
440    derived_store: &'a mut RowStore,
441    stats: &'a mut LocyStats,
442    visited: &'a mut VisitedSet,
443    max_depth: usize,
444) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + 'a>> {
445    Box::pin(async move {
446        let rule =
447            program
448                .rule_catalog
449                .get(rule_name)
450                .ok_or_else(|| LocyError::EvaluationError {
451                    message: format!("rule '{}' not found during EXPLAIN", rule_name),
452                })?;
453
454        let key_tuple = extract_key(fact, key_columns);
455        let visit_key = (rule_name.to_string(), key_tuple);
456
457        // Cycle detection
458        if !visited.insert(visit_key.clone()) || max_depth == 0 {
459            return Ok(DerivationNode {
460                rule: rule_name.to_string(),
461                clause_index: 0,
462                priority: rule.priority,
463                bindings: fact.clone(),
464                along_values: extract_along_values(fact, rule),
465                children: Vec::new(),
466                graph_fact: Some("(cycle)".to_string()),
467                approximate: false,
468                proof_probability: None,
469            });
470        }
471
472        // Match on KEY columns only.  Clause-level resolution returns only
473        // base graph bindings (vertex/edge identifiers); non-KEY yield columns
474        // (FOLD-aggregated, similar_to, etc.) are absent from those rows.
475        // KEY columns uniquely identify a derived fact, so this is sufficient.
476
477        // Try each clause to find the one that produced this fact
478        for (clause_idx, clause) in rule.clauses.iter().enumerate() {
479            let has_is_refs = clause
480                .where_conditions
481                .iter()
482                .any(|c| matches!(c, RuleCondition::IsReference(_)));
483            let has_along = !clause.along.is_empty();
484
485            let resolved = if has_is_refs || has_along {
486                let rows = resolve_clause_with_is_refs(clause, fact_source, derived_store).await?;
487                stats.queries_executed += 1;
488                rows
489            } else {
490                let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
491                let raw_batches = fact_source
492                    .execute_pattern(&clause.match_pattern, &cypher_conditions)
493                    .await?;
494                stats.queries_executed += 1;
495                record_batches_to_locy_rows(&raw_batches)
496            };
497
498            // Use values_equal_for_join for VID/EID-based comparison: sidecar
499            // schema mode can add `overflow_json: Null` to nodes in some query
500            // paths, making structural equality unreliable.
501            let matching_row = resolved.iter().find(|row| {
502                key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
503                    (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
504                    (None, None) => true,
505                    _ => false,
506                })
507            });
508
509            if let Some(evidence_row) = matching_row {
510                let along_values = extract_along_values(fact, rule);
511
512                // Build children by recursing into IS references
513                let mut children = Vec::new();
514                for cond in &clause.where_conditions {
515                    if let RuleCondition::IsReference(is_ref) = cond {
516                        if is_ref.negated {
517                            continue;
518                        }
519                        let ref_rule_name = is_ref.rule_name.to_string();
520                        if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
521                            let ref_key_columns: Vec<String> = ref_rule
522                                .yield_schema
523                                .iter()
524                                .filter(|c| c.is_key)
525                                .map(|c| c.name.clone())
526                                .collect();
527
528                            let ref_facts: Vec<Row> = derived_store
529                                .get(&ref_rule_name)
530                                .map(|r| r.rows.clone())
531                                .unwrap_or_default();
532
533                            let matching_ref_facts: Vec<Row> = ref_facts
534                                .into_iter()
535                                .filter(|ref_fact| {
536                                    let subjects_match =
537                                        is_ref.subjects.iter().enumerate().all(|(i, subject)| {
538                                            binding_matches_key(
539                                                evidence_row,
540                                                fact,
541                                                subject,
542                                                ref_fact,
543                                                ref_key_columns.get(i),
544                                            )
545                                        });
546                                    let target_matches =
547                                        is_ref.target.as_ref().is_none_or(|target| {
548                                            binding_matches_key(
549                                                evidence_row,
550                                                fact,
551                                                target,
552                                                ref_fact,
553                                                ref_key_columns.get(is_ref.subjects.len()),
554                                            )
555                                        });
556                                    subjects_match && target_matches
557                                })
558                                .collect();
559
560                            for ref_fact in matching_ref_facts {
561                                let child = build_derivation_node(
562                                    &ref_rule_name,
563                                    &ref_fact,
564                                    &ref_key_columns,
565                                    program,
566                                    fact_source,
567                                    derived_store,
568                                    stats,
569                                    visited,
570                                    max_depth - 1,
571                                )
572                                .await?;
573                                children.push(child);
574                            }
575                        }
576                    }
577                }
578
579                // Backtrack visited set
580                visited.remove(&visit_key);
581
582                let mut merged_bindings = evidence_row.clone();
583                for (k, v) in fact {
584                    merged_bindings.entry(k.clone()).or_insert(v.clone());
585                }
586
587                return Ok(DerivationNode {
588                    rule: rule_name.to_string(),
589                    clause_index: clause_idx,
590                    priority: rule.clauses[clause_idx].priority,
591                    bindings: merged_bindings,
592                    along_values,
593                    children,
594                    graph_fact: Some(format_graph_fact(evidence_row)),
595                    approximate: false,
596                    proof_probability: None,
597                });
598            }
599        }
600
601        // No clause matched — leaf node
602        visited.remove(&visit_key);
603        Ok(DerivationNode {
604            rule: rule_name.to_string(),
605            clause_index: 0,
606            priority: rule.priority,
607            bindings: fact.clone(),
608            along_values: extract_along_values(fact, rule),
609            children: Vec::new(),
610            graph_fact: Some(format_graph_fact(fact)),
611            approximate: false,
612            proof_probability: None,
613        })
614    })
615}
616
617/// Check if a binding variable matches a ref-fact key column via VID-based join.
618///
619/// Looks up `var_name` in `primary` (falling back to `fallback`), then compares
620/// it against `ref_key_col` in `ref_fact` using `values_equal_for_join`.
621/// Returns `true` when the key column is out of range or the binding is absent.
622fn binding_matches_key(
623    primary: &Row,
624    fallback: &Row,
625    var_name: &str,
626    ref_fact: &Row,
627    ref_key_col: Option<&String>,
628) -> bool {
629    let Some(key_col) = ref_key_col else {
630        return true;
631    };
632    let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
633        return true;
634    };
635    ref_fact
636        .get(key_col)
637        .is_some_and(|rv| values_equal_for_join(rv, val))
638}
639
640fn extract_along_values(fact: &Row, rule: &CompiledRule) -> HashMap<String, Value> {
641    let mut along_values = HashMap::new();
642    for clause in &rule.clauses {
643        for along in &clause.along {
644            if let Some(v) = fact.get(&along.name) {
645                along_values.insert(along.name.clone(), v.clone());
646            }
647        }
648    }
649    along_values
650}
651
652pub(crate) fn format_graph_fact(row: &Row) -> String {
653    let mut entries: Vec<String> = row
654        .iter()
655        .map(|(k, v)| format!("{}: {}", k, format_value(v)))
656        .collect();
657    entries.sort();
658    format!("{{{}}}", entries.join(", "))
659}
660
661fn format_value(v: &Value) -> String {
662    match v {
663        Value::Null => "null".to_string(),
664        Value::Bool(b) => b.to_string(),
665        Value::Int(i) => i.to_string(),
666        Value::Float(f) => f.to_string(),
667        Value::String(s) => format!("\"{}\"", s),
668        Value::List(items) => {
669            let inner: Vec<String> = items.iter().map(format_value).collect();
670            format!("[{}]", inner.join(", "))
671        }
672        Value::Map(m) => {
673            let mut entries: Vec<String> = m
674                .iter()
675                .map(|(k, v)| format!("{}: {}", k, format_value(v)))
676                .collect();
677            entries.sort();
678            format!("{{{}}}", entries.join(", "))
679        }
680        Value::Node(n) => format!("Node({})", n.vid.as_u64()),
681        Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
682        _ => format!("{v:?}"),
683    }
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Annotation for `rule` with no support, an empty fact row and the given probability.
    fn make_annotation(rule: &str, prob: Option<f64>) -> ProvenanceAnnotation {
        ProvenanceAnnotation {
            rule_name: rule.to_string(),
            clause_index: 0,
            support: vec![],
            along_values: HashMap::new(),
            iteration: 0,
            fact_row: HashMap::new(),
            proof_probability: prob,
        }
    }

    /// Rule names of `annotations`, in order.
    fn names(annotations: &[ProvenanceAnnotation]) -> Vec<&str> {
        annotations.iter().map(|a| a.rule_name.as_str()).collect()
    }

    #[test]
    fn record_first_derivation_wins() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record(hash.clone(), make_annotation("rule_a", None));
        store.record(hash.clone(), make_annotation("rule_b", None));
        let entry = store.lookup(&hash).unwrap();
        assert_eq!(entry.rule_name, "rule_a");
    }

    #[test]
    fn lookup_returns_first_annotation() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record(hash.clone(), make_annotation("rule_a", None));
        assert_eq!(store.lookup(&hash).unwrap().rule_name, "rule_a");
        assert!(store.lookup(b"nonexistent").is_none());
    }

    #[test]
    fn lookup_all_returns_all_annotations() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        // record() is first-wins, so only one entry per hash via record()
        store.record(hash.clone(), make_annotation("rule_a", None));
        let all = store.lookup_all(&hash).unwrap();
        assert_eq!(all.len(), 1);
    }

    #[test]
    fn record_top_k_retains_highest() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record_top_k(hash.clone(), make_annotation("low", Some(0.1)), 2);
        store.record_top_k(hash.clone(), make_annotation("high", Some(0.9)), 2);
        store.record_top_k(hash.clone(), make_annotation("mid", Some(0.5)), 2);
        store.record_top_k(hash.clone(), make_annotation("highest", Some(0.95)), 2);
        store.record_top_k(hash.clone(), make_annotation("lowest", Some(0.01)), 2);

        let all = store.lookup_all(&hash).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].rule_name, "highest");
        assert_eq!(all[1].rule_name, "high");
    }

    #[test]
    fn record_top_k_keeps_arrival_order_on_ties() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        for name in ["first", "second", "third"] {
            store.record_top_k(hash.clone(), make_annotation(name, Some(0.5)), 2);
        }
        let all = store.lookup_all(&hash).unwrap();
        assert_eq!(names(&all), ["first", "second"]);
    }

    #[test]
    fn record_top_k_ranks_missing_probability_as_zero() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record_top_k(hash.clone(), make_annotation("unscored", None), 2);
        store.record_top_k(hash.clone(), make_annotation("scored", Some(0.2)), 2);
        let all = store.lookup_all(&hash).unwrap();
        assert_eq!(names(&all), ["scored", "unscored"]);
    }

    #[test]
    fn compute_proof_probability_basic() {
        let support = vec![
            ProofTerm {
                source_rule: "r1".to_string(),
                base_fact_id: b"f1".to_vec(),
            },
            ProofTerm {
                source_rule: "r2".to_string(),
                base_fact_id: b"f2".to_vec(),
            },
        ];
        let base_probs = HashMap::from([(b"f1".to_vec(), 0.3), (b"f2".to_vec(), 0.5)]);
        let prob = compute_proof_probability(&support, &base_probs);
        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
    }

    #[test]
    fn compute_proof_probability_empty_support() {
        let prob = compute_proof_probability(&[], &HashMap::new());
        assert!(prob.is_none());
    }

    #[test]
    fn compute_proof_probability_missing_fact() {
        let support = vec![ProofTerm {
            source_rule: "r1".to_string(),
            base_fact_id: b"unknown".to_vec(),
        }];
        let prob = compute_proof_probability(&support, &HashMap::new());
        assert!(prob.is_none());
    }

    #[test]
    fn entries_for_rule_filters_correctly() {
        let store = ProvenanceStore::new();
        store.record(b"f1".to_vec(), make_annotation("rule_a", None));
        store.record(b"f2".to_vec(), make_annotation("rule_b", None));
        store.record(b"f3".to_vec(), make_annotation("rule_a", None));

        let entries = store.entries_for_rule("rule_a");
        assert_eq!(entries.len(), 2);
        let entries_b = store.entries_for_rule("rule_b");
        assert_eq!(entries_b.len(), 1);
    }

    #[test]
    fn entries_for_rule_only_considers_first_annotation() {
        let store = ProvenanceStore::new();
        store.record_top_k(b"f1".to_vec(), make_annotation("rule_a", Some(0.9)), 2);
        store.record_top_k(b"f1".to_vec(), make_annotation("rule_b", Some(0.1)), 2);
        assert_eq!(store.entries_for_rule("rule_a").len(), 1);
        assert!(store.entries_for_rule("rule_b").is_empty());
    }

    #[test]
    fn base_fact_probs_reads_prob_from_leaf_annotations() {
        let store = ProvenanceStore::new();

        let mut leaf = make_annotation("base", None);
        leaf.fact_row
            .insert("PROB".to_string(), Value::Float(0.7));
        store.record(b"leaf".to_vec(), leaf);

        // Non-empty support: not a base fact, so its PROB is ignored.
        let mut derived = make_annotation("derived", None);
        derived
            .fact_row
            .insert("PROB".to_string(), Value::Float(0.2));
        derived.support.push(ProofTerm {
            source_rule: "base".to_string(),
            base_fact_id: b"leaf".to_vec(),
        });
        store.record(b"derived".to_vec(), derived);

        // Base fact without a PROB column contributes nothing.
        store.record(b"no_prob".to_vec(), make_annotation("base", None));

        let probs = store.base_fact_probs();
        assert_eq!(probs.len(), 1);
        assert_eq!(probs.get(&b"leaf"[..]), Some(&0.7));
    }

    #[test]
    fn default_store_is_empty() {
        let store = ProvenanceStore::default();
        assert!(store.lookup(b"anything").is_none());
        assert!(store.entries_for_rule("rule_a").is_empty());
        assert!(store.base_fact_probs().is_empty());
    }

    #[test]
    fn format_graph_fact_sorts_entries() {
        let row: Row = HashMap::from([
            ("b".to_string(), Value::Bool(true)),
            ("a".to_string(), Value::Null),
        ]);
        assert_eq!(format_graph_fact(&row), "{a: null, b: true}");
    }
}