1use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, FactRow, LocyConfig, LocyError, LocyStats};
19
20use super::locy_delta::{
21 KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{eval_expr, record_batches_to_locy_rows, values_equal_for_join};
25use super::locy_slg::SLGResolver;
26use super::locy_traits::DerivedFactSource;
27
/// One term in the support of a derived fact.
///
/// A term names a base fact by its byte identifier; that identifier is the
/// key used to look up the fact's probability in
/// `compute_proof_probability`.
#[derive(Debug, Clone)]
pub struct ProofTerm {
    /// Name of the rule this term is attributed to (presumably the rule that
    /// produced the base fact — confirm against the code that records terms).
    pub source_rule: String,
    /// Byte identifier of the supporting base fact.
    pub base_fact_id: Vec<u8>,
}
40
41#[derive(Clone, Debug)]
46pub struct ProvenanceAnnotation {
47 pub rule_name: String,
49 pub clause_index: usize,
51 pub support: Vec<ProofTerm>,
53 pub along_values: HashMap<String, Value>,
55 pub iteration: usize,
57 pub fact_row: FactRow,
59 pub proof_probability: Option<f64>,
61}
62
63#[derive(Debug)]
70pub struct ProvenanceStore {
71 entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
72}
73
74impl ProvenanceStore {
75 pub fn new() -> Self {
76 Self {
77 entries: RwLock::new(HashMap::new()),
78 }
79 }
80
81 pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
84 if let Ok(mut guard) = self.entries.write() {
85 guard.entry(fact_hash).or_insert_with(|| vec![entry]);
86 }
87 }
88
89 pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
95 if let Ok(mut guard) = self.entries.write() {
96 let vec = guard.entry(fact_hash).or_default();
97 vec.push(entry);
98 vec.sort_by(|a, b| {
100 b.proof_probability
101 .unwrap_or(0.0)
102 .partial_cmp(&a.proof_probability.unwrap_or(0.0))
103 .unwrap_or(std::cmp::Ordering::Equal)
104 });
105 vec.truncate(k);
106 }
107 }
108
109 pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
111 self.entries.read().ok()?.get(fact_hash)?.first().cloned()
112 }
113
114 pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
116 let guard = self.entries.read().ok()?;
117 guard.get(fact_hash).cloned()
118 }
119
120 pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
126 let mut probs = HashMap::new();
127 if let Ok(guard) = self.entries.read() {
128 for (fact_hash, annotations) in guard.iter() {
129 if let Some(ann) = annotations.first()
130 && ann.support.is_empty()
131 && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
132 {
133 probs.insert(fact_hash.clone(), *p);
134 }
135 }
136 }
137 probs
138 }
139
140 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
142 match self.entries.read() {
143 Ok(guard) => guard
144 .iter()
145 .filter_map(|(k, annotations)| {
146 annotations
147 .first()
148 .filter(|e| e.rule_name == rule_name)
149 .map(|e| (k.clone(), e.clone()))
150 })
151 .collect(),
152 Err(_) => vec![],
153 }
154 }
155}
156
157pub fn compute_proof_probability(
161 support: &[ProofTerm],
162 base_probs: &HashMap<Vec<u8>, f64>,
163) -> Option<f64> {
164 if support.is_empty() {
165 return None;
166 }
167 let mut product = 1.0;
168 for term in support {
169 let p = base_probs.get(&term.base_fact_id)?;
170 product *= p;
171 }
172 Some(product)
173}
174
175impl Default for ProvenanceStore {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181type VisitedSet = HashSet<(String, KeyTuple)>;
183
184#[expect(
190 clippy::too_many_arguments,
191 reason = "explain requires full program context and tracker state"
192)]
193pub async fn explain_rule(
194 query: &ExplainRule,
195 program: &CompiledProgram,
196 fact_source: &dyn DerivedFactSource,
197 config: &LocyConfig,
198 derived_store: &mut RowStore,
199 stats: &mut LocyStats,
200 tracker: Option<&ProvenanceStore>,
201 approximate_groups: Option<&HashMap<String, Vec<String>>>,
202) -> Result<DerivationNode, LocyError> {
203 if let Some(Ok(node)) =
206 tracker.map(|t| explain_rule_mode_a(query, program, t, derived_store, approximate_groups))
207 {
208 return Ok(node);
209 }
210
211 explain_rule_mode_b(
213 query,
214 program,
215 fact_source,
216 config,
217 derived_store,
218 stats,
219 approximate_groups,
220 )
221 .await
222}
223
224fn explain_rule_mode_a(
228 query: &ExplainRule,
229 program: &CompiledProgram,
230 tracker: &ProvenanceStore,
231 _derived_store: &RowStore,
232 approximate_groups: Option<&HashMap<String, Vec<String>>>,
233) -> Result<DerivationNode, LocyError> {
234 let rule_name = query.rule_name.to_string();
235 let rule = program
236 .rule_catalog
237 .get(&rule_name)
238 .ok_or_else(|| LocyError::EvaluationError {
239 message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
240 })?;
241
242 let tracker_entries = tracker.entries_for_rule(&rule_name);
243 if tracker_entries.is_empty() {
244 return Err(LocyError::EvaluationError {
245 message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
246 });
247 }
248
249 let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
251 tracker_entries
252 .into_iter()
253 .filter(|(_, entry)| {
254 eval_expr(where_expr, &entry.fact_row)
255 .map(|v| v.as_bool().unwrap_or(false))
256 .unwrap_or(false)
257 })
258 .collect()
259 } else {
260 tracker_entries
261 };
262
263 if matching_entries.is_empty() {
264 return Err(LocyError::EvaluationError {
265 message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
266 });
267 }
268
269 let is_approximate = approximate_groups
270 .map(|ag| ag.contains_key(&rule_name))
271 .unwrap_or(false);
272
273 let mut root = DerivationNode {
274 rule: rule_name.clone(),
275 clause_index: 0,
276 priority: rule.priority,
277 bindings: HashMap::new(),
278 along_values: HashMap::new(),
279 children: Vec::new(),
280 graph_fact: None,
281 approximate: is_approximate,
282 proof_probability: None,
283 };
284
285 for (_, entry) in matching_entries {
286 let along_values = extract_along_values(&entry.fact_row, rule);
287 let clause_priority = rule
288 .clauses
289 .get(entry.clause_index)
290 .and_then(|c| c.priority);
291 let base_fact = format!(
292 "[iter={}] {}",
293 entry.iteration,
294 format_graph_fact(&entry.fact_row)
295 );
296 let graph_fact = if is_approximate {
297 format!("[APPROXIMATE] {}", base_fact)
298 } else {
299 base_fact
300 };
301 let node = DerivationNode {
302 rule: rule_name.clone(),
303 clause_index: entry.clause_index,
304 priority: clause_priority.or(rule.priority),
305 bindings: entry.fact_row.clone(),
306 along_values,
307 children: vec![],
309 graph_fact: Some(graph_fact),
310 approximate: is_approximate,
311 proof_probability: entry.proof_probability,
312 };
313 root.children.push(node);
314 }
315
316 Ok(root)
317}
318
319async fn explain_rule_mode_b(
322 query: &ExplainRule,
323 program: &CompiledProgram,
324 fact_source: &dyn DerivedFactSource,
325 config: &LocyConfig,
326 derived_store: &mut RowStore,
327 stats: &mut LocyStats,
328 approximate_groups: Option<&HashMap<String, Vec<String>>>,
329) -> Result<DerivationNode, LocyError> {
330 let rule_name = query.rule_name.to_string();
331 let rule = program
332 .rule_catalog
333 .get(&rule_name)
334 .ok_or_else(|| LocyError::EvaluationError {
335 message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
336 })?;
337
338 let key_columns: Vec<String> = rule
339 .yield_schema
340 .iter()
341 .filter(|c| c.is_key)
342 .map(|c| c.name.clone())
343 .collect();
344
345 {
349 let mut fresh_store = RowStore::new();
350 let slg_start = std::time::Instant::now();
351 let mut resolver =
352 SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
353 resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
354 stats.queries_executed += resolver.stats.queries_executed;
355 for (name, relation) in fresh_store {
358 derived_store.insert(name, relation);
359 }
360 }
361
362 let facts = derived_store
364 .get(&rule_name)
365 .map(|r| r.rows.clone())
366 .unwrap_or_default();
367
368 let filtered: Vec<FactRow> = if let Some(where_expr) = &query.where_expr {
370 facts
371 .into_iter()
372 .filter(|row| {
373 eval_expr(where_expr, row)
374 .map(|v| v.as_bool().unwrap_or(false))
375 .unwrap_or(false)
376 })
377 .collect()
378 } else {
379 facts
380 };
381
382 let is_approximate = approximate_groups
383 .map(|ag| ag.contains_key(&rule_name))
384 .unwrap_or(false);
385
386 let mut root = DerivationNode {
388 rule: rule_name.clone(),
389 clause_index: 0,
390 priority: rule.priority,
391 bindings: HashMap::new(),
392 along_values: HashMap::new(),
393 children: Vec::new(),
394 graph_fact: None,
395 approximate: is_approximate,
396 proof_probability: None,
397 };
398
399 for fact in &filtered {
401 let mut visited = VisitedSet::new();
402 let mut node = build_derivation_node(
403 &rule_name,
404 fact,
405 &key_columns,
406 program,
407 fact_source,
408 derived_store,
409 stats,
410 &mut visited,
411 config.max_explain_depth,
412 )
413 .await?;
414 if is_approximate {
415 node.approximate = true;
416 if let Some(ref gf) = node.graph_fact {
417 node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
418 }
419 }
420 root.children.push(node);
421 }
422
423 Ok(root)
424}
425
426#[expect(
431 clippy::too_many_arguments,
432 reason = "recursive derivation node builder requires full fact context"
433)]
434fn build_derivation_node<'a>(
435 rule_name: &'a str,
436 fact: &'a FactRow,
437 key_columns: &'a [String],
438 program: &'a CompiledProgram,
439 fact_source: &'a dyn DerivedFactSource,
440 derived_store: &'a mut RowStore,
441 stats: &'a mut LocyStats,
442 visited: &'a mut VisitedSet,
443 max_depth: usize,
444) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + 'a>> {
445 Box::pin(async move {
446 let rule =
447 program
448 .rule_catalog
449 .get(rule_name)
450 .ok_or_else(|| LocyError::EvaluationError {
451 message: format!("rule '{}' not found during EXPLAIN", rule_name),
452 })?;
453
454 let key_tuple = extract_key(fact, key_columns);
455 let visit_key = (rule_name.to_string(), key_tuple);
456
457 if !visited.insert(visit_key.clone()) || max_depth == 0 {
459 return Ok(DerivationNode {
460 rule: rule_name.to_string(),
461 clause_index: 0,
462 priority: rule.priority,
463 bindings: fact.clone(),
464 along_values: extract_along_values(fact, rule),
465 children: Vec::new(),
466 graph_fact: Some("(cycle)".to_string()),
467 approximate: false,
468 proof_probability: None,
469 });
470 }
471
472 for (clause_idx, clause) in rule.clauses.iter().enumerate() {
479 let has_is_refs = clause
480 .where_conditions
481 .iter()
482 .any(|c| matches!(c, RuleCondition::IsReference(_)));
483 let has_along = !clause.along.is_empty();
484
485 let resolved = if has_is_refs || has_along {
486 let rows = resolve_clause_with_is_refs(
487 clause,
488 fact_source,
489 derived_store,
490 &program.rule_catalog,
491 None, )
493 .await?;
494 stats.queries_executed += 1;
495 rows
496 } else {
497 let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
498 let raw_batches = fact_source
499 .execute_pattern(&clause.match_pattern, &cypher_conditions)
500 .await?;
501 stats.queries_executed += 1;
502 record_batches_to_locy_rows(&raw_batches)
503 };
504
505 let matching_row = resolved.iter().find(|row| {
509 key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
510 (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
511 (None, None) => true,
512 _ => false,
513 })
514 });
515
516 if let Some(evidence_row) = matching_row {
517 let along_values = extract_along_values(fact, rule);
518
519 let mut children = Vec::new();
521 for cond in &clause.where_conditions {
522 if let RuleCondition::IsReference(is_ref) = cond {
523 if is_ref.negated {
524 continue;
525 }
526 let ref_rule_name = is_ref.rule_name.to_string();
527 if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
528 let ref_key_columns: Vec<String> = ref_rule
529 .yield_schema
530 .iter()
531 .filter(|c| c.is_key)
532 .map(|c| c.name.clone())
533 .collect();
534
535 let ref_facts: Vec<FactRow> = derived_store
536 .get(&ref_rule_name)
537 .map(|r| r.rows.clone())
538 .unwrap_or_default();
539
540 let matching_ref_facts: Vec<FactRow> = ref_facts
541 .into_iter()
542 .filter(|ref_fact| {
543 let subjects_match =
544 is_ref.subjects.iter().enumerate().all(|(i, subject)| {
545 binding_matches_key(
546 evidence_row,
547 fact,
548 subject,
549 ref_fact,
550 ref_key_columns.get(i),
551 )
552 });
553 let target_matches =
554 is_ref.target.as_ref().is_none_or(|target| {
555 binding_matches_key(
556 evidence_row,
557 fact,
558 target,
559 ref_fact,
560 ref_key_columns.get(is_ref.subjects.len()),
561 )
562 });
563 subjects_match && target_matches
564 })
565 .collect();
566
567 for ref_fact in matching_ref_facts {
568 let child = build_derivation_node(
569 &ref_rule_name,
570 &ref_fact,
571 &ref_key_columns,
572 program,
573 fact_source,
574 derived_store,
575 stats,
576 visited,
577 max_depth - 1,
578 )
579 .await?;
580 children.push(child);
581 }
582 }
583 }
584 }
585
586 visited.remove(&visit_key);
588
589 let mut merged_bindings = evidence_row.clone();
590 for (k, v) in fact {
591 merged_bindings.entry(k.clone()).or_insert(v.clone());
592 }
593
594 return Ok(DerivationNode {
595 rule: rule_name.to_string(),
596 clause_index: clause_idx,
597 priority: rule.clauses[clause_idx].priority,
598 bindings: merged_bindings,
599 along_values,
600 children,
601 graph_fact: Some(format_graph_fact(evidence_row)),
602 approximate: false,
603 proof_probability: None,
604 });
605 }
606 }
607
608 visited.remove(&visit_key);
610 Ok(DerivationNode {
611 rule: rule_name.to_string(),
612 clause_index: 0,
613 priority: rule.priority,
614 bindings: fact.clone(),
615 along_values: extract_along_values(fact, rule),
616 children: Vec::new(),
617 graph_fact: Some(format_graph_fact(fact)),
618 approximate: false,
619 proof_probability: None,
620 })
621 })
622}
623
624fn binding_matches_key(
630 primary: &FactRow,
631 fallback: &FactRow,
632 var_name: &str,
633 ref_fact: &FactRow,
634 ref_key_col: Option<&String>,
635) -> bool {
636 let Some(key_col) = ref_key_col else {
637 return true;
638 };
639 let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
640 return true;
641 };
642 ref_fact
643 .get(key_col)
644 .is_some_and(|rv| values_equal_for_join(rv, val))
645}
646
647fn extract_along_values(fact: &FactRow, rule: &CompiledRule) -> HashMap<String, Value> {
648 let mut along_values = HashMap::new();
649 for clause in &rule.clauses {
650 for along in &clause.along {
651 if let Some(v) = fact.get(&along.name) {
652 along_values.insert(along.name.clone(), v.clone());
653 }
654 }
655 }
656 along_values
657}
658
659pub(crate) fn format_graph_fact(row: &FactRow) -> String {
660 let mut entries: Vec<String> = row
661 .iter()
662 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
663 .collect();
664 entries.sort();
665 format!("{{{}}}", entries.join(", "))
666}
667
668fn format_value(v: &Value) -> String {
669 match v {
670 Value::Null => "null".to_string(),
671 Value::Bool(b) => b.to_string(),
672 Value::Int(i) => i.to_string(),
673 Value::Float(f) => f.to_string(),
674 Value::String(s) => format!("\"{}\"", s),
675 Value::List(items) => {
676 let inner: Vec<String> = items.iter().map(format_value).collect();
677 format!("[{}]", inner.join(", "))
678 }
679 Value::Map(m) => {
680 let mut entries: Vec<String> = m
681 .iter()
682 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
683 .collect();
684 entries.sort();
685 format!("{{{}}}", entries.join(", "))
686 }
687 Value::Node(n) => format!("Node({})", n.vid.as_u64()),
688 Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
689 _ => format!("{v:?}"),
690 }
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal annotation for `rule` with the given proof
    /// probability; every other field is empty or zero.
    fn make_annotation(rule: &str, prob: Option<f64>) -> ProvenanceAnnotation {
        ProvenanceAnnotation {
            rule_name: rule.to_string(),
            clause_index: 0,
            support: vec![],
            along_values: HashMap::new(),
            iteration: 0,
            fact_row: HashMap::new(),
            proof_probability: prob,
        }
    }

    /// Shorthand for a proof term attributed to `rule` over base fact `id`.
    fn make_term(rule: &str, id: &[u8]) -> ProofTerm {
        ProofTerm {
            source_rule: rule.to_string(),
            base_fact_id: id.to_vec(),
        }
    }

    /// A second `record` for the same hash must not replace the first.
    #[test]
    fn record_first_derivation_wins() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        for rule in ["rule_a", "rule_b"] {
            store.record(key.clone(), make_annotation(rule, None));
        }
        let found = store.lookup(&key).unwrap();
        assert_eq!(found.rule_name, "rule_a");
    }

    /// `lookup` yields the stored annotation and `None` for unknown hashes.
    #[test]
    fn lookup_returns_first_annotation() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        store.record(key.clone(), make_annotation("rule_a", None));
        let found = store.lookup(&key).unwrap();
        assert_eq!(found.rule_name, "rule_a");
        assert!(store.lookup(b"nonexistent").is_none());
    }

    /// `lookup_all` returns the whole annotation list for a hash.
    #[test]
    fn lookup_all_returns_all_annotations() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        store.record(key.clone(), make_annotation("rule_a", None));
        let annotations = store.lookup_all(&key).unwrap();
        assert_eq!(annotations.len(), 1);
    }

    /// With `k = 2`, only the two most probable proofs survive, best first.
    #[test]
    fn record_top_k_retains_highest() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        let proofs = [
            ("low", 0.1),
            ("high", 0.9),
            ("mid", 0.5),
            ("highest", 0.95),
            ("lowest", 0.01),
        ];
        for (rule, prob) in proofs {
            store.record_top_k(key.clone(), make_annotation(rule, Some(prob)), 2);
        }

        let kept = store.lookup_all(&key).unwrap();
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].rule_name, "highest");
        assert_eq!(kept[1].rule_name, "high");
    }

    /// The proof probability is the product of the base probabilities.
    #[test]
    fn compute_proof_probability_basic() {
        let support = vec![make_term("r1", b"f1"), make_term("r2", b"f2")];
        let base_probs = HashMap::from([(b"f1".to_vec(), 0.3), (b"f2".to_vec(), 0.5)]);
        let prob = compute_proof_probability(&support, &base_probs);
        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
    }

    /// An empty support list has no probability.
    #[test]
    fn compute_proof_probability_empty_support() {
        let prob = compute_proof_probability(&[], &HashMap::new());
        assert!(prob.is_none());
    }

    /// A term whose base fact is unknown makes the result `None`.
    #[test]
    fn compute_proof_probability_missing_fact() {
        let support = vec![make_term("r1", b"unknown")];
        let prob = compute_proof_probability(&support, &HashMap::new());
        assert!(prob.is_none());
    }

    /// `entries_for_rule` returns only the facts derived by the named rule.
    #[test]
    fn entries_for_rule_filters_correctly() {
        let store = ProvenanceStore::new();
        let recorded: [(&[u8], &str); 3] =
            [(b"f1", "rule_a"), (b"f2", "rule_b"), (b"f3", "rule_a")];
        for (hash, rule) in recorded {
            store.record(hash.to_vec(), make_annotation(rule, None));
        }

        assert_eq!(store.entries_for_rule("rule_a").len(), 2);
        assert_eq!(store.entries_for_rule("rule_b").len(), 1);
    }
}