Skip to main content

uni_query/query/df_graph/
locy_explain.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! EXPLAIN RULE derivation tree construction.
5//!
6//! Ported from `uni-locy/src/orchestrator/explain.rs`. Uses `DerivedFactSource`
7//! instead of `CypherExecutor`. Uses `RowStore` for row-based fact storage.
8//!
9//! Implements Mode A (provenance-based, uses ProvenanceStore recorded during fixpoint)
10//! with fallback to Mode B (re-execution) when tracker has no entries for the rule.
11
12use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, FactRow, LocyConfig, LocyError, LocyStats};
19
20use super::locy_delta::{
21    KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{eval_expr, record_batches_to_locy_rows, values_equal_for_join};
25use super::locy_slg::SLGResolver;
26use super::locy_traits::DerivedFactSource;
27
/// One edge of the proof DAG: a single IS-ref dependency of a derived fact.
///
/// A `ProofTerm` names the rule that was referenced and carries the opaque
/// hash of the specific base fact consumed from it (lineage in the sense of
/// Cui & Widom 2000).
#[derive(Debug, Clone)]
pub struct ProofTerm {
    /// Rule referenced through IS that supplied this dependency.
    pub source_rule: String,
    /// Opaque hash of the base fact taken from `source_rule`.
    pub base_fact_id: Vec<u8>,
}
40
41/// Provenance annotation for a derived fact (Green et al. 2007, Definition 3.1).
42///
43/// Captures a single derivation witness: the rule and clause that produced the
44/// fact, plus the `support` set of proof terms that contributed to it.
45#[derive(Clone, Debug)]
46pub struct ProvenanceAnnotation {
47    /// Name of the rule that derived this fact.
48    pub rule_name: String,
49    /// Index of the clause within the rule that produced this fact.
50    pub clause_index: usize,
51    /// Proof terms: IS-ref input facts (populated when IS-ref tracking is available).
52    pub support: Vec<ProofTerm>,
53    /// ALONG column values captured at derivation time.
54    pub along_values: HashMap<String, Value>,
55    /// Fixpoint iteration number when the fact was first derived.
56    pub iteration: usize,
57    /// Full fact row stored for Mode A filtering/display.
58    pub fact_row: FactRow,
59    /// Probability of this specific proof path (populated by top-k filtering).
60    pub proof_probability: Option<f64>,
61}
62
63/// Provenance store for derived facts (Green et al. 2007, §3).
64///
65/// Stores provenance annotations keyed by opaque fact hashes. Enables
66/// Mode A (provenance-based) EXPLAIN without re-execution.
67/// First-derivation-wins: once a fact hash is recorded, later iterations
68/// do not overwrite it.
69#[derive(Debug)]
70pub struct ProvenanceStore {
71    entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
72}
73
74impl ProvenanceStore {
75    pub fn new() -> Self {
76        Self {
77            entries: RwLock::new(HashMap::new()),
78        }
79    }
80
81    /// Record a provenance annotation. First-derivation-wins: if the hash is already
82    /// present, the existing annotation is kept (unlimited mode).
83    pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
84        if let Ok(mut guard) = self.entries.write() {
85            guard.entry(fact_hash).or_insert_with(|| vec![entry]);
86        }
87    }
88
89    /// Record a provenance annotation with top-k filtering.
90    ///
91    /// Retains at most `k` annotations per fact, ordered by `proof_probability`
92    /// (highest first). Annotations without a proof probability are treated as
93    /// having probability 0.0 for ordering purposes.
94    pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
95        if let Ok(mut guard) = self.entries.write() {
96            let vec = guard.entry(fact_hash).or_default();
97            vec.push(entry);
98            // Sort descending by proof_probability.
99            vec.sort_by(|a, b| {
100                b.proof_probability
101                    .unwrap_or(0.0)
102                    .partial_cmp(&a.proof_probability.unwrap_or(0.0))
103                    .unwrap_or(std::cmp::Ordering::Equal)
104            });
105            vec.truncate(k);
106        }
107    }
108
109    /// Look up the first (highest-priority) provenance annotation for a fact hash.
110    pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
111        self.entries.read().ok()?.get(fact_hash)?.first().cloned()
112    }
113
114    /// Look up all provenance annotations for a fact hash.
115    pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
116        let guard = self.entries.read().ok()?;
117        guard.get(fact_hash).cloned()
118    }
119
120    /// Collect base fact probabilities from stored annotations.
121    ///
122    /// Scans all annotations for base facts (those with empty support, i.e. leaf
123    /// nodes in the proof tree) and extracts the PROB column value from their
124    /// `fact_row`. Used by top-k proof filtering to compute `proof_probability`.
125    pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
126        let mut probs = HashMap::new();
127        if let Ok(guard) = self.entries.read() {
128            for (fact_hash, annotations) in guard.iter() {
129                if let Some(ann) = annotations.first()
130                    && ann.support.is_empty()
131                    && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
132                {
133                    probs.insert(fact_hash.clone(), *p);
134                }
135            }
136        }
137        probs
138    }
139
140    /// Get all entries for a specific rule name (returns first annotation per fact).
141    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
142        match self.entries.read() {
143            Ok(guard) => guard
144                .iter()
145                .filter_map(|(k, annotations)| {
146                    annotations
147                        .first()
148                        .filter(|e| e.rule_name == rule_name)
149                        .map(|e| (k.clone(), e.clone()))
150                })
151                .collect(),
152            Err(_) => vec![],
153        }
154    }
155}
156
157/// Compute the probability of a single proof path from its support terms.
158///
159/// Returns `None` if any base fact's probability is unknown.
160pub fn compute_proof_probability(
161    support: &[ProofTerm],
162    base_probs: &HashMap<Vec<u8>, f64>,
163) -> Option<f64> {
164    if support.is_empty() {
165        return None;
166    }
167    let mut product = 1.0;
168    for term in support {
169        let p = base_probs.get(&term.base_fact_id)?;
170        product *= p;
171    }
172    Some(product)
173}
174
175impl Default for ProvenanceStore {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181/// Set of (rule_name, key_tuple) to detect cycles during recursive derivation (Mode B).
182type VisitedSet = HashSet<(String, KeyTuple)>;
183
184/// Build a derivation tree for a rule, showing how each fact was derived.
185///
186/// Tries Mode A (provenance-based, uses ProvenanceStore) first when a store is
187/// provided and has entries for the rule.  Falls through to Mode B (re-execution)
188/// when Mode A cannot produce a result.
189#[expect(
190    clippy::too_many_arguments,
191    reason = "explain requires full program context and tracker state"
192)]
193pub async fn explain_rule(
194    query: &ExplainRule,
195    program: &CompiledProgram,
196    fact_source: &dyn DerivedFactSource,
197    config: &LocyConfig,
198    derived_store: &mut RowStore,
199    stats: &mut LocyStats,
200    tracker: Option<&ProvenanceStore>,
201    approximate_groups: Option<&HashMap<String, Vec<String>>>,
202) -> Result<DerivationNode, LocyError> {
203    // Mode A: provenance-based (no re-execution required).
204    // Falls through to Mode B when tracker is absent or has no matching entries.
205    if let Some(Ok(node)) =
206        tracker.map(|t| explain_rule_mode_a(query, program, t, derived_store, approximate_groups))
207    {
208        return Ok(node);
209    }
210
211    // Mode B: re-execution fallback
212    explain_rule_mode_b(
213        query,
214        program,
215        fact_source,
216        config,
217        derived_store,
218        stats,
219        approximate_groups,
220    )
221    .await
222}
223
224/// Mode A: build derivation tree using recorded provenance from the fixpoint loop.
225///
226/// Returns `Err` when no tracker entries exist for the rule (signals Mode B fallback).
227fn explain_rule_mode_a(
228    query: &ExplainRule,
229    program: &CompiledProgram,
230    tracker: &ProvenanceStore,
231    _derived_store: &RowStore,
232    approximate_groups: Option<&HashMap<String, Vec<String>>>,
233) -> Result<DerivationNode, LocyError> {
234    let rule_name = query.rule_name.to_string();
235    let rule = program
236        .rule_catalog
237        .get(&rule_name)
238        .ok_or_else(|| LocyError::EvaluationError {
239            message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
240        })?;
241
242    let tracker_entries = tracker.entries_for_rule(&rule_name);
243    if tracker_entries.is_empty() {
244        return Err(LocyError::EvaluationError {
245            message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
246        });
247    }
248
249    // Filter tracker entries by WHERE expression
250    let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
251        tracker_entries
252            .into_iter()
253            .filter(|(_, entry)| {
254                eval_expr(where_expr, &entry.fact_row)
255                    .map(|v| v.as_bool().unwrap_or(false))
256                    .unwrap_or(false)
257            })
258            .collect()
259    } else {
260        tracker_entries
261    };
262
263    if matching_entries.is_empty() {
264        return Err(LocyError::EvaluationError {
265            message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
266        });
267    }
268
269    let is_approximate = approximate_groups
270        .map(|ag| ag.contains_key(&rule_name))
271        .unwrap_or(false);
272
273    let mut root = DerivationNode {
274        rule: rule_name.clone(),
275        clause_index: 0,
276        priority: rule.priority,
277        bindings: HashMap::new(),
278        along_values: HashMap::new(),
279        children: Vec::new(),
280        graph_fact: None,
281        approximate: is_approximate,
282        proof_probability: None,
283    };
284
285    for (_, entry) in matching_entries {
286        let along_values = extract_along_values(&entry.fact_row, rule);
287        let clause_priority = rule
288            .clauses
289            .get(entry.clause_index)
290            .and_then(|c| c.priority);
291        let base_fact = format!(
292            "[iter={}] {}",
293            entry.iteration,
294            format_graph_fact(&entry.fact_row)
295        );
296        let graph_fact = if is_approximate {
297            format!("[APPROXIMATE] {}", base_fact)
298        } else {
299            base_fact
300        };
301        let node = DerivationNode {
302            rule: rule_name.clone(),
303            clause_index: entry.clause_index,
304            priority: clause_priority.or(rule.priority),
305            bindings: entry.fact_row.clone(),
306            along_values,
307            // Mode A: children not tracked (inputs list is reserved for future recursion)
308            children: vec![],
309            graph_fact: Some(graph_fact),
310            approximate: is_approximate,
311            proof_probability: entry.proof_probability,
312        };
313        root.children.push(node);
314    }
315
316    Ok(root)
317}
318
319/// Mode B: re-execution fallback — re-executes clause queries to find which
320/// clause produced each matching fact, then recurses into IS references.
321async fn explain_rule_mode_b(
322    query: &ExplainRule,
323    program: &CompiledProgram,
324    fact_source: &dyn DerivedFactSource,
325    config: &LocyConfig,
326    derived_store: &mut RowStore,
327    stats: &mut LocyStats,
328    approximate_groups: Option<&HashMap<String, Vec<String>>>,
329) -> Result<DerivationNode, LocyError> {
330    let rule_name = query.rule_name.to_string();
331    let rule = program
332        .rule_catalog
333        .get(&rule_name)
334        .ok_or_else(|| LocyError::EvaluationError {
335            message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
336        })?;
337
338    let key_columns: Vec<String> = rule
339        .yield_schema
340        .iter()
341        .filter(|c| c.is_key)
342        .map(|c| c.name.clone())
343        .collect();
344
345    // Re-evaluate the rule via SLG to obtain rows with full node objects and properties.
346    // The native fixpoint's orch_store has VID-only integers that fail property-based
347    // WHERE filters (e.g. a.name = 'A') — we need actual Value::Node values here.
348    {
349        let mut fresh_store = RowStore::new();
350        let slg_start = std::time::Instant::now();
351        let mut resolver =
352            SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
353        resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
354        stats.queries_executed += resolver.stats.queries_executed;
355        // Merge full-node facts into derived_store so IS-ref lookups in
356        // build_derivation_node also get proper node objects (not VIDs).
357        for (name, relation) in fresh_store {
358            derived_store.insert(name, relation);
359        }
360    }
361
362    // Get all derived facts for this rule (now populated with full node objects)
363    let facts = derived_store
364        .get(&rule_name)
365        .map(|r| r.rows.clone())
366        .unwrap_or_default();
367
368    // Filter facts by WHERE expression
369    let filtered: Vec<FactRow> = if let Some(where_expr) = &query.where_expr {
370        facts
371            .into_iter()
372            .filter(|row| {
373                eval_expr(where_expr, row)
374                    .map(|v| v.as_bool().unwrap_or(false))
375                    .unwrap_or(false)
376            })
377            .collect()
378    } else {
379        facts
380    };
381
382    let is_approximate = approximate_groups
383        .map(|ag| ag.contains_key(&rule_name))
384        .unwrap_or(false);
385
386    // Build derivation tree root
387    let mut root = DerivationNode {
388        rule: rule_name.clone(),
389        clause_index: 0,
390        priority: rule.priority,
391        bindings: HashMap::new(),
392        along_values: HashMap::new(),
393        children: Vec::new(),
394        graph_fact: None,
395        approximate: is_approximate,
396        proof_probability: None,
397    };
398
399    // For each matching fact, recursively build a derivation node
400    for fact in &filtered {
401        let mut visited = VisitedSet::new();
402        let mut node = build_derivation_node(
403            &rule_name,
404            fact,
405            &key_columns,
406            program,
407            fact_source,
408            derived_store,
409            stats,
410            &mut visited,
411            config.max_explain_depth,
412        )
413        .await?;
414        if is_approximate {
415            node.approximate = true;
416            if let Some(ref gf) = node.graph_fact {
417                node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
418            }
419        }
420        root.children.push(node);
421    }
422
423    Ok(root)
424}
425
426/// Recursively build a derivation node for a single fact of a rule.
427///
428/// Finds which clause produced this fact, extracts ALONG values,
429/// and recurses into IS reference dependencies.
430#[expect(
431    clippy::too_many_arguments,
432    reason = "recursive derivation node builder requires full fact context"
433)]
434fn build_derivation_node<'a>(
435    rule_name: &'a str,
436    fact: &'a FactRow,
437    key_columns: &'a [String],
438    program: &'a CompiledProgram,
439    fact_source: &'a dyn DerivedFactSource,
440    derived_store: &'a mut RowStore,
441    stats: &'a mut LocyStats,
442    visited: &'a mut VisitedSet,
443    max_depth: usize,
444) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + 'a>> {
445    Box::pin(async move {
446        let rule =
447            program
448                .rule_catalog
449                .get(rule_name)
450                .ok_or_else(|| LocyError::EvaluationError {
451                    message: format!("rule '{}' not found during EXPLAIN", rule_name),
452                })?;
453
454        let key_tuple = extract_key(fact, key_columns);
455        let visit_key = (rule_name.to_string(), key_tuple);
456
457        // Cycle detection
458        if !visited.insert(visit_key.clone()) || max_depth == 0 {
459            return Ok(DerivationNode {
460                rule: rule_name.to_string(),
461                clause_index: 0,
462                priority: rule.priority,
463                bindings: fact.clone(),
464                along_values: extract_along_values(fact, rule),
465                children: Vec::new(),
466                graph_fact: Some("(cycle)".to_string()),
467                approximate: false,
468                proof_probability: None,
469            });
470        }
471
472        // Match on KEY columns only.  Clause-level resolution returns only
473        // base graph bindings (vertex/edge identifiers); non-KEY yield columns
474        // (FOLD-aggregated, similar_to, etc.) are absent from those rows.
475        // KEY columns uniquely identify a derived fact, so this is sufficient.
476
477        // Try each clause to find the one that produced this fact
478        for (clause_idx, clause) in rule.clauses.iter().enumerate() {
479            let has_is_refs = clause
480                .where_conditions
481                .iter()
482                .any(|c| matches!(c, RuleCondition::IsReference(_)));
483            let has_along = !clause.along.is_empty();
484
485            let resolved = if has_is_refs || has_along {
486                let rows = resolve_clause_with_is_refs(
487                    clause,
488                    fact_source,
489                    derived_store,
490                    &program.rule_catalog,
491                    None, // EXPLAIN traces proofs, doesn't compute probabilities
492                )
493                .await?;
494                stats.queries_executed += 1;
495                rows
496            } else {
497                let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
498                let raw_batches = fact_source
499                    .execute_pattern(&clause.match_pattern, &cypher_conditions)
500                    .await?;
501                stats.queries_executed += 1;
502                record_batches_to_locy_rows(&raw_batches)
503            };
504
505            // Use values_equal_for_join for VID/EID-based comparison: sidecar
506            // schema mode can add `overflow_json: Null` to nodes in some query
507            // paths, making structural equality unreliable.
508            let matching_row = resolved.iter().find(|row| {
509                key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
510                    (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
511                    (None, None) => true,
512                    _ => false,
513                })
514            });
515
516            if let Some(evidence_row) = matching_row {
517                let along_values = extract_along_values(fact, rule);
518
519                // Build children by recursing into IS references
520                let mut children = Vec::new();
521                for cond in &clause.where_conditions {
522                    if let RuleCondition::IsReference(is_ref) = cond {
523                        if is_ref.negated {
524                            continue;
525                        }
526                        let ref_rule_name = is_ref.rule_name.to_string();
527                        if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
528                            let ref_key_columns: Vec<String> = ref_rule
529                                .yield_schema
530                                .iter()
531                                .filter(|c| c.is_key)
532                                .map(|c| c.name.clone())
533                                .collect();
534
535                            let ref_facts: Vec<FactRow> = derived_store
536                                .get(&ref_rule_name)
537                                .map(|r| r.rows.clone())
538                                .unwrap_or_default();
539
540                            let matching_ref_facts: Vec<FactRow> = ref_facts
541                                .into_iter()
542                                .filter(|ref_fact| {
543                                    let subjects_match =
544                                        is_ref.subjects.iter().enumerate().all(|(i, subject)| {
545                                            binding_matches_key(
546                                                evidence_row,
547                                                fact,
548                                                subject,
549                                                ref_fact,
550                                                ref_key_columns.get(i),
551                                            )
552                                        });
553                                    let target_matches =
554                                        is_ref.target.as_ref().is_none_or(|target| {
555                                            binding_matches_key(
556                                                evidence_row,
557                                                fact,
558                                                target,
559                                                ref_fact,
560                                                ref_key_columns.get(is_ref.subjects.len()),
561                                            )
562                                        });
563                                    subjects_match && target_matches
564                                })
565                                .collect();
566
567                            for ref_fact in matching_ref_facts {
568                                let child = build_derivation_node(
569                                    &ref_rule_name,
570                                    &ref_fact,
571                                    &ref_key_columns,
572                                    program,
573                                    fact_source,
574                                    derived_store,
575                                    stats,
576                                    visited,
577                                    max_depth - 1,
578                                )
579                                .await?;
580                                children.push(child);
581                            }
582                        }
583                    }
584                }
585
586                // Backtrack visited set
587                visited.remove(&visit_key);
588
589                let mut merged_bindings = evidence_row.clone();
590                for (k, v) in fact {
591                    merged_bindings.entry(k.clone()).or_insert(v.clone());
592                }
593
594                return Ok(DerivationNode {
595                    rule: rule_name.to_string(),
596                    clause_index: clause_idx,
597                    priority: rule.clauses[clause_idx].priority,
598                    bindings: merged_bindings,
599                    along_values,
600                    children,
601                    graph_fact: Some(format_graph_fact(evidence_row)),
602                    approximate: false,
603                    proof_probability: None,
604                });
605            }
606        }
607
608        // No clause matched — leaf node
609        visited.remove(&visit_key);
610        Ok(DerivationNode {
611            rule: rule_name.to_string(),
612            clause_index: 0,
613            priority: rule.priority,
614            bindings: fact.clone(),
615            along_values: extract_along_values(fact, rule),
616            children: Vec::new(),
617            graph_fact: Some(format_graph_fact(fact)),
618            approximate: false,
619            proof_probability: None,
620        })
621    })
622}
623
624/// Check if a binding variable matches a ref-fact key column via VID-based join.
625///
626/// Looks up `var_name` in `primary` (falling back to `fallback`), then compares
627/// it against `ref_key_col` in `ref_fact` using `values_equal_for_join`.
628/// Returns `true` when the key column is out of range or the binding is absent.
629fn binding_matches_key(
630    primary: &FactRow,
631    fallback: &FactRow,
632    var_name: &str,
633    ref_fact: &FactRow,
634    ref_key_col: Option<&String>,
635) -> bool {
636    let Some(key_col) = ref_key_col else {
637        return true;
638    };
639    let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
640        return true;
641    };
642    ref_fact
643        .get(key_col)
644        .is_some_and(|rv| values_equal_for_join(rv, val))
645}
646
647fn extract_along_values(fact: &FactRow, rule: &CompiledRule) -> HashMap<String, Value> {
648    let mut along_values = HashMap::new();
649    for clause in &rule.clauses {
650        for along in &clause.along {
651            if let Some(v) = fact.get(&along.name) {
652                along_values.insert(along.name.clone(), v.clone());
653            }
654        }
655    }
656    along_values
657}
658
659pub(crate) fn format_graph_fact(row: &FactRow) -> String {
660    let mut entries: Vec<String> = row
661        .iter()
662        .map(|(k, v)| format!("{}: {}", k, format_value(v)))
663        .collect();
664    entries.sort();
665    format!("{{{}}}", entries.join(", "))
666}
667
668fn format_value(v: &Value) -> String {
669    match v {
670        Value::Null => "null".to_string(),
671        Value::Bool(b) => b.to_string(),
672        Value::Int(i) => i.to_string(),
673        Value::Float(f) => f.to_string(),
674        Value::String(s) => format!("\"{}\"", s),
675        Value::List(items) => {
676            let inner: Vec<String> = items.iter().map(format_value).collect();
677            format!("[{}]", inner.join(", "))
678        }
679        Value::Map(m) => {
680            let mut entries: Vec<String> = m
681                .iter()
682                .map(|(k, v)| format!("{}: {}", k, format_value(v)))
683                .collect();
684            entries.sort();
685            format!("{{{}}}", entries.join(", "))
686        }
687        Value::Node(n) => format!("Node({})", n.vid.as_u64()),
688        Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
689        _ => format!("{v:?}"),
690    }
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a leaf annotation (empty support, empty fact row) for `rule`
    /// carrying the given proof probability.
    fn make_annotation(rule: &str, prob: Option<f64>) -> ProvenanceAnnotation {
        ProvenanceAnnotation {
            rule_name: rule.to_owned(),
            clause_index: 0,
            support: Vec::new(),
            along_values: HashMap::new(),
            iteration: 0,
            fact_row: HashMap::new(),
            proof_probability: prob,
        }
    }

    /// Shorthand for a proof term referring to base fact `id` of `rule`.
    fn make_term(rule: &str, id: &[u8]) -> ProofTerm {
        ProofTerm {
            source_rule: rule.to_owned(),
            base_fact_id: id.to_vec(),
        }
    }

    #[test]
    fn record_first_derivation_wins() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        for rule in ["rule_a", "rule_b"] {
            store.record(key.clone(), make_annotation(rule, None));
        }
        let kept = store.lookup(&key).unwrap();
        assert_eq!(kept.rule_name, "rule_a");
    }

    #[test]
    fn lookup_returns_first_annotation() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        store.record(key.clone(), make_annotation("rule_a", None));

        let found = store.lookup(&key);
        assert_eq!(found.unwrap().rule_name, "rule_a");
        assert!(store.lookup(b"nonexistent").is_none());
    }

    #[test]
    fn lookup_all_returns_all_annotations() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        // record() is first-wins, so a hash recorded through it holds exactly
        // one annotation.
        store.record(key.clone(), make_annotation("rule_a", None));
        let annotations = store.lookup_all(&key).unwrap();
        assert_eq!(annotations.len(), 1);
    }

    #[test]
    fn record_top_k_retains_highest() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        let inputs = [
            ("low", 0.1),
            ("high", 0.9),
            ("mid", 0.5),
            ("highest", 0.95),
            ("lowest", 0.01),
        ];
        for (rule, prob) in inputs {
            store.record_top_k(key.clone(), make_annotation(rule, Some(prob)), 2);
        }

        let kept = store.lookup_all(&key).unwrap();
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].rule_name, "highest");
        assert_eq!(kept[1].rule_name, "high");
    }

    #[test]
    fn compute_proof_probability_basic() {
        let support = vec![make_term("r1", b"f1"), make_term("r2", b"f2")];
        let base_probs = HashMap::from([(b"f1".to_vec(), 0.3), (b"f2".to_vec(), 0.5)]);
        let prob = compute_proof_probability(&support, &base_probs);
        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
    }

    #[test]
    fn compute_proof_probability_empty_support() {
        let prob = compute_proof_probability(&[], &HashMap::new());
        assert!(prob.is_none());
    }

    #[test]
    fn compute_proof_probability_missing_fact() {
        let support = vec![make_term("r1", b"unknown")];
        let prob = compute_proof_probability(&support, &HashMap::new());
        assert!(prob.is_none());
    }

    #[test]
    fn entries_for_rule_filters_correctly() {
        let store = ProvenanceStore::new();
        let recorded: [(&[u8], &str); 3] =
            [(b"f1", "rule_a"), (b"f2", "rule_b"), (b"f3", "rule_a")];
        for (hash, rule) in recorded {
            store.record(hash.to_vec(), make_annotation(rule, None));
        }

        assert_eq!(store.entries_for_rule("rule_a").len(), 2);
        assert_eq!(store.entries_for_rule("rule_b").len(), 1);
    }
}