1use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, FactRow, LocyConfig, LocyError, LocyStats};
19
20use super::locy_delta::{
21 KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{
25 eval_expr, normalize_graph_row, record_batches_to_locy_rows, values_equal_for_join,
26};
27use super::locy_slg::SLGResolver;
28use super::locy_traits::DerivedFactSource;
29
/// One supporting fact referenced by a derivation's proof.
///
/// Equality and hashing are structural, so proof terms can be compared, deduplicated, or
/// used as set/map keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProofTerm {
    /// Name of the rule the supporting fact is attributed to.
    pub source_rule: String,
    /// Identifier (hash) of the supporting fact; `compute_proof_probability` looks this up
    /// in its `base_probs` map.
    pub base_fact_id: Vec<u8>,
}
42
43#[derive(Clone, Debug)]
48pub struct ProvenanceAnnotation {
49 pub rule_name: String,
51 pub clause_index: usize,
53 pub support: Vec<ProofTerm>,
55 pub along_values: HashMap<String, Value>,
57 pub iteration: usize,
59 pub fact_row: FactRow,
61 pub proof_probability: Option<f64>,
63 pub neural_calls: Vec<uni_locy::NeuralProvenance>,
69}
70
71#[derive(Debug)]
78pub struct ProvenanceStore {
79 entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
80}
81
82impl ProvenanceStore {
83 pub fn new() -> Self {
84 Self {
85 entries: RwLock::new(HashMap::new()),
86 }
87 }
88
89 pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
92 if let Ok(mut guard) = self.entries.write() {
93 guard.entry(fact_hash).or_insert_with(|| vec![entry]);
94 }
95 }
96
97 pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
103 if let Ok(mut guard) = self.entries.write() {
104 let vec = guard.entry(fact_hash).or_default();
105 vec.push(entry);
106 vec.sort_by(|a, b| {
108 b.proof_probability
109 .unwrap_or(0.0)
110 .partial_cmp(&a.proof_probability.unwrap_or(0.0))
111 .unwrap_or(std::cmp::Ordering::Equal)
112 });
113 vec.truncate(k);
114 }
115 }
116
117 pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
119 self.entries.read().ok()?.get(fact_hash)?.first().cloned()
120 }
121
122 pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
124 let guard = self.entries.read().ok()?;
125 guard.get(fact_hash).cloned()
126 }
127
128 pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
134 let mut probs = HashMap::new();
135 if let Ok(guard) = self.entries.read() {
136 for (fact_hash, annotations) in guard.iter() {
137 if let Some(ann) = annotations.first()
138 && ann.support.is_empty()
139 && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
140 {
141 probs.insert(fact_hash.clone(), *p);
142 }
143 }
144 }
145 probs
146 }
147
148 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
150 match self.entries.read() {
151 Ok(guard) => guard
152 .iter()
153 .filter_map(|(k, annotations)| {
154 annotations
155 .first()
156 .filter(|e| e.rule_name == rule_name)
157 .map(|e| (k.clone(), e.clone()))
158 })
159 .collect(),
160 Err(_) => vec![],
161 }
162 }
163}
164
165pub fn compute_proof_probability(
169 support: &[ProofTerm],
170 base_probs: &HashMap<Vec<u8>, f64>,
171) -> Option<f64> {
172 if support.is_empty() {
173 return None;
174 }
175 let mut product = 1.0;
176 for term in support {
177 let p = base_probs.get(&term.base_fact_id)?;
178 product *= p;
179 }
180 Some(product)
181}
182
183impl Default for ProvenanceStore {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
/// `(rule name, key tuple)` pairs already being expanded by `build_derivation_node`.
///
/// Revisiting an entry ends the recursion with a leaf node instead of looping forever.
type VisitedSet = HashSet<(String, KeyTuple)>;
191
192#[expect(
198 clippy::too_many_arguments,
199 reason = "explain requires full program context and tracker state"
200)]
201pub async fn explain_rule(
202 query: &ExplainRule,
203 program: &CompiledProgram,
204 fact_source: &dyn DerivedFactSource,
205 config: &LocyConfig,
206 derived_store: &mut RowStore,
207 stats: &mut LocyStats,
208 tracker: Option<&ProvenanceStore>,
209 approximate_groups: Option<&HashMap<String, Vec<String>>>,
210) -> Result<DerivationNode, LocyError> {
211 if let Some(t) = tracker
214 && let Ok(node) = explain_rule_mode_a(
215 query,
216 program,
217 t,
218 fact_source,
219 derived_store,
220 approximate_groups,
221 )
222 .await
223 {
224 return Ok(node);
225 }
226
227 explain_rule_mode_b(
229 query,
230 program,
231 fact_source,
232 config,
233 derived_store,
234 stats,
235 approximate_groups,
236 )
237 .await
238}
239
240async fn explain_rule_mode_a(
244 query: &ExplainRule,
245 program: &CompiledProgram,
246 tracker: &ProvenanceStore,
247 fact_source: &dyn DerivedFactSource,
248 _derived_store: &RowStore,
249 approximate_groups: Option<&HashMap<String, Vec<String>>>,
250) -> Result<DerivationNode, LocyError> {
251 let rule_name = query.rule_name.to_string();
252 let rule = program
253 .rule_catalog
254 .get(&rule_name)
255 .ok_or_else(|| LocyError::EvaluationError {
256 message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
257 })?;
258
259 let tracker_entries = tracker.entries_for_rule(&rule_name);
260 if tracker_entries.is_empty() {
261 return Err(LocyError::EvaluationError {
262 message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
263 });
264 }
265
266 let key_cols: HashSet<String> = rule
279 .yield_schema
280 .iter()
281 .filter(|c| c.is_key && !c.is_prob)
282 .map(|c| c.name.clone())
283 .collect();
284
285 let mut tracker_entries: Vec<(Vec<u8>, ProvenanceAnnotation)> = tracker_entries
286 .into_iter()
287 .map(|(h, mut ann)| {
288 normalize_graph_row(&mut ann.fact_row);
290 (h, ann)
291 })
292 .collect();
293
294 let mut candidate_vids: Vec<u64> = Vec::new();
296 for (_, entry) in &tracker_entries {
297 for (k, v) in &entry.fact_row {
298 if key_cols.contains(k)
299 && let Value::Int(i) = v
300 && *i >= 0
301 {
302 candidate_vids.push(*i as u64);
303 }
304 }
305 }
306 candidate_vids.sort_unstable();
307 candidate_vids.dedup();
308 let vid_to_node: HashMap<u64, Value> = if candidate_vids.is_empty() {
309 HashMap::new()
310 } else {
311 fact_source
312 .lookup_nodes_by_vids(&candidate_vids)
313 .await
314 .unwrap_or_default()
315 };
316 if !vid_to_node.is_empty() {
317 for (_, ann) in &mut tracker_entries {
318 for (k, v) in ann.fact_row.iter_mut() {
319 if key_cols.contains(k)
320 && let Value::Int(i) = v
321 && *i >= 0
322 && let Some(node) = vid_to_node.get(&(*i as u64))
323 {
324 *v = node.clone();
325 }
326 }
327 }
328 }
329
330 let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
332 tracker_entries
333 .into_iter()
334 .filter(|(_, entry)| {
335 eval_expr(where_expr, &entry.fact_row)
336 .map(|v| v.as_bool().unwrap_or(false))
337 .unwrap_or(false)
338 })
339 .collect()
340 } else {
341 tracker_entries
342 };
343
344 if matching_entries.is_empty() {
345 return Err(LocyError::EvaluationError {
346 message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
347 });
348 }
349
350 let is_approximate = approximate_groups
351 .map(|ag| ag.contains_key(&rule_name))
352 .unwrap_or(false);
353
354 let mut root = DerivationNode {
355 rule: rule_name.clone(),
356 clause_index: 0,
357 priority: rule.priority,
358 bindings: HashMap::new(),
359 along_values: HashMap::new(),
360 children: Vec::new(),
361 graph_fact: None,
362 approximate: is_approximate,
363 proof_probability: None,
364 neural_calls: Vec::new(),
365 };
366
367 for (_, entry) in matching_entries {
368 let along_values = extract_along_values(&entry.fact_row, rule);
369 let clause_priority = rule
370 .clauses
371 .get(entry.clause_index)
372 .and_then(|c| c.priority);
373 let base_fact = format!(
374 "[iter={}] {}",
375 entry.iteration,
376 format_graph_fact(&entry.fact_row)
377 );
378 let graph_fact = if is_approximate {
379 format!("[APPROXIMATE] {}", base_fact)
380 } else {
381 base_fact
382 };
383 let node = DerivationNode {
384 rule: rule_name.clone(),
385 clause_index: entry.clause_index,
386 priority: clause_priority.or(rule.priority),
387 bindings: entry.fact_row.clone(),
388 along_values,
389 children: vec![],
391 graph_fact: Some(graph_fact),
392 approximate: is_approximate,
393 proof_probability: entry.proof_probability,
394 neural_calls: entry.neural_calls.clone(),
398 };
399 root.children.push(node);
400 }
401
402 Ok(root)
403}
404
405async fn explain_rule_mode_b(
408 query: &ExplainRule,
409 program: &CompiledProgram,
410 fact_source: &dyn DerivedFactSource,
411 config: &LocyConfig,
412 derived_store: &mut RowStore,
413 stats: &mut LocyStats,
414 approximate_groups: Option<&HashMap<String, Vec<String>>>,
415) -> Result<DerivationNode, LocyError> {
416 let rule_name = query.rule_name.to_string();
417 let rule = program
418 .rule_catalog
419 .get(&rule_name)
420 .ok_or_else(|| LocyError::EvaluationError {
421 message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
422 })?;
423
424 let key_columns: Vec<String> = rule
425 .yield_schema
426 .iter()
427 .filter(|c| c.is_key)
428 .map(|c| c.name.clone())
429 .collect();
430
431 {
435 let mut fresh_store = RowStore::new();
436 let slg_start = std::time::Instant::now();
437 let mut resolver =
438 SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
439 resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
440 stats.queries_executed += resolver.stats.queries_executed;
441 for (name, relation) in fresh_store {
444 derived_store.insert(name, relation);
445 }
446 }
447
448 let facts = derived_store
450 .get(&rule_name)
451 .map(|r| r.rows.clone())
452 .unwrap_or_default();
453
454 let filtered: Vec<FactRow> = if let Some(where_expr) = &query.where_expr {
456 facts
457 .into_iter()
458 .filter(|row| {
459 eval_expr(where_expr, row)
460 .map(|v| v.as_bool().unwrap_or(false))
461 .unwrap_or(false)
462 })
463 .collect()
464 } else {
465 facts
466 };
467
468 let is_approximate = approximate_groups
469 .map(|ag| ag.contains_key(&rule_name))
470 .unwrap_or(false);
471
472 let mut root = DerivationNode {
474 rule: rule_name.clone(),
475 clause_index: 0,
476 priority: rule.priority,
477 bindings: HashMap::new(),
478 along_values: HashMap::new(),
479 children: Vec::new(),
480 graph_fact: None,
481 approximate: is_approximate,
482 proof_probability: None,
483 neural_calls: Vec::new(),
484 };
485
486 for fact in &filtered {
488 let mut visited = VisitedSet::new();
489 let mut node = build_derivation_node(
490 &rule_name,
491 fact,
492 &key_columns,
493 program,
494 fact_source,
495 derived_store,
496 stats,
497 &mut visited,
498 config.max_explain_depth,
499 )
500 .await?;
501 if is_approximate {
502 node.approximate = true;
503 if let Some(ref gf) = node.graph_fact {
504 node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
505 }
506 }
507 root.children.push(node);
508 }
509
510 Ok(root)
511}
512
513#[expect(
518 clippy::too_many_arguments,
519 reason = "recursive derivation node builder requires full fact context"
520)]
521fn build_derivation_node<'a>(
522 rule_name: &'a str,
523 fact: &'a FactRow,
524 key_columns: &'a [String],
525 program: &'a CompiledProgram,
526 fact_source: &'a dyn DerivedFactSource,
527 derived_store: &'a mut RowStore,
528 stats: &'a mut LocyStats,
529 visited: &'a mut VisitedSet,
530 max_depth: usize,
531) -> std::pin::Pin<
532 Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + Send + 'a>,
533> {
534 Box::pin(async move {
535 let rule =
536 program
537 .rule_catalog
538 .get(rule_name)
539 .ok_or_else(|| LocyError::EvaluationError {
540 message: format!("rule '{}' not found during EXPLAIN", rule_name),
541 })?;
542
543 let key_tuple = extract_key(fact, key_columns);
544 let visit_key = (rule_name.to_string(), key_tuple);
545
546 if !visited.insert(visit_key.clone()) || max_depth == 0 {
548 return Ok(DerivationNode {
549 rule: rule_name.to_string(),
550 clause_index: 0,
551 priority: rule.priority,
552 bindings: fact.clone(),
553 along_values: extract_along_values(fact, rule),
554 children: Vec::new(),
555 graph_fact: Some("(cycle)".to_string()),
556 approximate: false,
557 proof_probability: None,
558 neural_calls: Vec::new(),
559 });
560 }
561
562 for (clause_idx, clause) in rule.clauses.iter().enumerate() {
569 let has_is_refs = clause
570 .where_conditions
571 .iter()
572 .any(|c| matches!(c, RuleCondition::IsReference(_)));
573 let has_along = !clause.along.is_empty();
574
575 let resolved = if has_is_refs || has_along {
576 let rows = resolve_clause_with_is_refs(
577 clause,
578 fact_source,
579 derived_store,
580 &program.rule_catalog,
581 None, )
583 .await?;
584 stats.queries_executed += 1;
585 rows
586 } else {
587 let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
588 let raw_batches = fact_source
589 .execute_pattern(&clause.match_pattern, &cypher_conditions)
590 .await?;
591 stats.queries_executed += 1;
592 record_batches_to_locy_rows(&raw_batches)
593 };
594
595 let matching_row = resolved.iter().find(|row| {
599 key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
600 (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
601 (None, None) => true,
602 _ => false,
603 })
604 });
605
606 if let Some(evidence_row) = matching_row {
607 let along_values = extract_along_values(fact, rule);
608
609 let mut children = Vec::new();
611 for cond in &clause.where_conditions {
612 if let RuleCondition::IsReference(is_ref) = cond {
613 if is_ref.negated {
614 continue;
615 }
616 let ref_rule_name = is_ref.rule_name.to_string();
617 if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
618 let ref_key_columns: Vec<String> = ref_rule
619 .yield_schema
620 .iter()
621 .filter(|c| c.is_key)
622 .map(|c| c.name.clone())
623 .collect();
624
625 let ref_facts: Vec<FactRow> = derived_store
626 .get(&ref_rule_name)
627 .map(|r| r.rows.clone())
628 .unwrap_or_default();
629
630 let matching_ref_facts: Vec<FactRow> = ref_facts
631 .into_iter()
632 .filter(|ref_fact| {
633 let subjects_match =
634 is_ref.subjects.iter().enumerate().all(|(i, subject)| {
635 binding_matches_key(
636 evidence_row,
637 fact,
638 subject,
639 ref_fact,
640 ref_key_columns.get(i),
641 )
642 });
643 let target_matches =
644 is_ref.target.as_ref().is_none_or(|target| {
645 binding_matches_key(
646 evidence_row,
647 fact,
648 target,
649 ref_fact,
650 ref_key_columns.get(is_ref.subjects.len()),
651 )
652 });
653 subjects_match && target_matches
654 })
655 .collect();
656
657 for ref_fact in matching_ref_facts {
658 let child = build_derivation_node(
659 &ref_rule_name,
660 &ref_fact,
661 &ref_key_columns,
662 program,
663 fact_source,
664 derived_store,
665 stats,
666 visited,
667 max_depth - 1,
668 )
669 .await?;
670 children.push(child);
671 }
672 }
673 }
674 }
675
676 visited.remove(&visit_key);
678
679 let mut merged_bindings = evidence_row.clone();
680 for (k, v) in fact {
681 merged_bindings.entry(k.clone()).or_insert(v.clone());
682 }
683
684 return Ok(DerivationNode {
685 rule: rule_name.to_string(),
686 clause_index: clause_idx,
687 priority: rule.clauses[clause_idx].priority,
688 bindings: merged_bindings,
689 along_values,
690 children,
691 graph_fact: Some(format_graph_fact(evidence_row)),
692 approximate: false,
693 proof_probability: None,
694 neural_calls: Vec::new(),
695 });
696 }
697 }
698
699 visited.remove(&visit_key);
701 Ok(DerivationNode {
702 rule: rule_name.to_string(),
703 clause_index: 0,
704 priority: rule.priority,
705 bindings: fact.clone(),
706 along_values: extract_along_values(fact, rule),
707 children: Vec::new(),
708 graph_fact: Some(format_graph_fact(fact)),
709 approximate: false,
710 proof_probability: None,
711 neural_calls: Vec::new(),
712 })
713 })
714}
715
716fn binding_matches_key(
722 primary: &FactRow,
723 fallback: &FactRow,
724 var_name: &str,
725 ref_fact: &FactRow,
726 ref_key_col: Option<&String>,
727) -> bool {
728 let Some(key_col) = ref_key_col else {
729 return true;
730 };
731 let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
732 return true;
733 };
734 ref_fact
735 .get(key_col)
736 .is_some_and(|rv| values_equal_for_join(rv, val))
737}
738
739fn extract_along_values(fact: &FactRow, rule: &CompiledRule) -> HashMap<String, Value> {
740 let mut along_values = HashMap::new();
741 for clause in &rule.clauses {
742 for along in &clause.along {
743 if let Some(v) = fact.get(&along.name) {
744 along_values.insert(along.name.clone(), v.clone());
745 }
746 }
747 }
748 along_values
749}
750
751pub(crate) fn format_graph_fact(row: &FactRow) -> String {
752 let mut entries: Vec<String> = row
753 .iter()
754 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
755 .collect();
756 entries.sort();
757 format!("{{{}}}", entries.join(", "))
758}
759
760fn format_value(v: &Value) -> String {
761 match v {
762 Value::Null => "null".to_string(),
763 Value::Bool(b) => b.to_string(),
764 Value::Int(i) => i.to_string(),
765 Value::Float(f) => f.to_string(),
766 Value::String(s) => format!("\"{}\"", s),
767 Value::List(items) => {
768 let inner: Vec<String> = items.iter().map(format_value).collect();
769 format!("[{}]", inner.join(", "))
770 }
771 Value::Map(m) => {
772 let mut entries: Vec<String> = m
773 .iter()
774 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
775 .collect();
776 entries.sort();
777 format!("{{{}}}", entries.join(", "))
778 }
779 Value::Node(n) => format!("Node({})", n.vid.as_u64()),
780 Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
781 _ => format!("{v:?}"),
782 }
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// Annotation with empty support, an empty row, and the given rule name and probability.
    fn make_annotation(rule: &str, prob: Option<f64>) -> ProvenanceAnnotation {
        ProvenanceAnnotation {
            rule_name: rule.to_string(),
            clause_index: 0,
            support: vec![],
            along_values: HashMap::new(),
            iteration: 0,
            fact_row: HashMap::new(),
            proof_probability: prob,
            neural_calls: Vec::new(),
        }
    }

    #[test]
    fn record_first_derivation_wins() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record(hash.clone(), make_annotation("rule_a", None));
        store.record(hash.clone(), make_annotation("rule_b", None));
        let entry = store.lookup(&hash).unwrap();
        assert_eq!(entry.rule_name, "rule_a");
    }

    #[test]
    fn lookup_returns_first_annotation() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record(hash.clone(), make_annotation("rule_a", None));
        assert_eq!(store.lookup(&hash).unwrap().rule_name, "rule_a");
        assert!(store.lookup(b"nonexistent").is_none());
    }

    #[test]
    fn lookup_all_returns_all_annotations() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record(hash.clone(), make_annotation("rule_a", None));
        let all = store.lookup_all(&hash).unwrap();
        assert_eq!(all.len(), 1);
    }

    #[test]
    fn lookup_all_returns_every_retained_annotation() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record_top_k(hash.clone(), make_annotation("rule_a", Some(0.4)), 3);
        store.record_top_k(hash.clone(), make_annotation("rule_b", Some(0.6)), 3);
        let all = store.lookup_all(&hash).unwrap();
        let names: Vec<&str> = all.iter().map(|a| a.rule_name.as_str()).collect();
        assert_eq!(names, ["rule_b", "rule_a"]);
        assert!(store.lookup_all(b"missing").is_none());
    }

    #[test]
    fn record_top_k_retains_highest() {
        let store = ProvenanceStore::new();
        let hash = b"fact1".to_vec();
        store.record_top_k(hash.clone(), make_annotation("low", Some(0.1)), 2);
        store.record_top_k(hash.clone(), make_annotation("high", Some(0.9)), 2);
        store.record_top_k(hash.clone(), make_annotation("mid", Some(0.5)), 2);
        store.record_top_k(hash.clone(), make_annotation("highest", Some(0.95)), 2);
        store.record_top_k(hash.clone(), make_annotation("lowest", Some(0.01)), 2);

        let all = store.lookup_all(&hash).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].rule_name, "highest");
        assert_eq!(all[1].rule_name, "high");
    }

    #[test]
    fn base_fact_probs_only_reports_support_free_facts() {
        let store = ProvenanceStore::new();

        let mut base = make_annotation("edge", None);
        base.fact_row
            .insert("PROB".to_string(), Value::Float(0.7));
        store.record(b"base".to_vec(), base);

        let mut derived = make_annotation("path", None);
        derived
            .fact_row
            .insert("PROB".to_string(), Value::Float(0.2));
        derived.support.push(ProofTerm {
            source_rule: "edge".to_string(),
            base_fact_id: b"base".to_vec(),
        });
        store.record(b"derived".to_vec(), derived);

        let probs = store.base_fact_probs();
        assert_eq!(probs.len(), 1);
        assert_eq!(probs.get(&b"base".to_vec()), Some(&0.7));
    }

    #[test]
    fn compute_proof_probability_basic() {
        let support = vec![
            ProofTerm {
                source_rule: "r1".to_string(),
                base_fact_id: b"f1".to_vec(),
            },
            ProofTerm {
                source_rule: "r2".to_string(),
                base_fact_id: b"f2".to_vec(),
            },
        ];
        let base_probs = HashMap::from([(b"f1".to_vec(), 0.3), (b"f2".to_vec(), 0.5)]);
        let prob = compute_proof_probability(&support, &base_probs);
        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
    }

    #[test]
    fn compute_proof_probability_empty_support() {
        let prob = compute_proof_probability(&[], &HashMap::new());
        assert!(prob.is_none());
    }

    #[test]
    fn compute_proof_probability_missing_fact() {
        let support = vec![ProofTerm {
            source_rule: "r1".to_string(),
            base_fact_id: b"unknown".to_vec(),
        }];
        let prob = compute_proof_probability(&support, &HashMap::new());
        assert!(prob.is_none());
    }

    #[test]
    fn entries_for_rule_filters_correctly() {
        let store = ProvenanceStore::new();
        store.record(b"f1".to_vec(), make_annotation("rule_a", None));
        store.record(b"f2".to_vec(), make_annotation("rule_b", None));
        store.record(b"f3".to_vec(), make_annotation("rule_a", None));

        let entries = store.entries_for_rule("rule_a");
        assert_eq!(entries.len(), 2);
        let entries_b = store.entries_for_rule("rule_b");
        assert_eq!(entries_b.len(), 1);
    }

    #[test]
    fn format_graph_fact_sorts_fields() {
        let row: FactRow = HashMap::from([
            ("b".to_string(), Value::Int(2)),
            ("a".to_string(), Value::Bool(true)),
            ("c".to_string(), Value::Null),
        ]);
        assert_eq!(format_graph_fact(&row), "{a: true, b: 2, c: null}");
    }
}