1use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, FactRow, LocyConfig, LocyError, LocyStats};
19
20use super::locy_delta::{
21 KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{eval_expr, record_batches_to_locy_rows, values_equal_for_join};
25use super::locy_slg::SLGResolver;
26use super::locy_traits::DerivedFactSource;
27
/// One element of a derived fact's support set: a reference to a base fact.
#[derive(Clone, Debug)]
pub struct ProofTerm {
    /// Presumably the name of the rule that produced the referenced fact —
    /// TODO confirm against the code that records proof terms.
    pub source_rule: String,
    /// Identifier of the referenced base fact. `compute_proof_probability`
    /// uses it as the lookup key into the base-probability map.
    pub base_fact_id: Vec<u8>,
}
40
/// Provenance recorded for one derivation of a fact.
#[derive(Clone, Debug)]
pub struct ProvenanceAnnotation {
    /// Name of the rule this derivation belongs to; `entries_for_rule`
    /// compares it against the requested rule name.
    pub rule_name: String,
    /// Index into the rule's clause list; Mode A uses it to look up the
    /// clause priority and copies it into the derivation node.
    pub clause_index: usize,
    /// Base facts this derivation depends on. `base_fact_probs` only
    /// considers annotations whose support is empty.
    pub support: Vec<ProofTerm>,
    /// ALONG values captured with the derivation.
    /// NOTE(review): not read anywhere in this file — confirm the producer.
    pub along_values: HashMap<String, Value>,
    /// Iteration number, rendered as `[iter=N]` in Mode A output.
    pub iteration: usize,
    /// Column values of the fact; WHERE expressions are evaluated against
    /// it and its `"PROB"` column is read by `base_fact_probs`.
    pub fact_row: FactRow,
    /// Probability of this particular proof, if known. `record_top_k`
    /// ranks annotations by it, treating `None` as `0.0`.
    pub proof_probability: Option<f64>,
}
62
/// Store of provenance annotations keyed by fact hash.
///
/// The map lives behind an `RwLock`, so every method takes `&self`.
#[derive(Debug)]
pub struct ProvenanceStore {
    /// Fact hash -> annotations recorded for that fact. `lookup` and
    /// `entries_for_rule` only consult the first annotation of each list.
    entries: RwLock<HashMap<Vec<u8>, Vec<ProvenanceAnnotation>>>,
}
73
74impl ProvenanceStore {
75 pub fn new() -> Self {
76 Self {
77 entries: RwLock::new(HashMap::new()),
78 }
79 }
80
81 pub fn record(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation) {
84 if let Ok(mut guard) = self.entries.write() {
85 guard.entry(fact_hash).or_insert_with(|| vec![entry]);
86 }
87 }
88
89 pub fn record_top_k(&self, fact_hash: Vec<u8>, entry: ProvenanceAnnotation, k: usize) {
95 if let Ok(mut guard) = self.entries.write() {
96 let vec = guard.entry(fact_hash).or_default();
97 vec.push(entry);
98 vec.sort_by(|a, b| {
100 b.proof_probability
101 .unwrap_or(0.0)
102 .partial_cmp(&a.proof_probability.unwrap_or(0.0))
103 .unwrap_or(std::cmp::Ordering::Equal)
104 });
105 vec.truncate(k);
106 }
107 }
108
109 pub fn lookup(&self, fact_hash: &[u8]) -> Option<ProvenanceAnnotation> {
111 self.entries.read().ok()?.get(fact_hash)?.first().cloned()
112 }
113
114 pub fn lookup_all(&self, fact_hash: &[u8]) -> Option<Vec<ProvenanceAnnotation>> {
116 let guard = self.entries.read().ok()?;
117 guard.get(fact_hash).cloned()
118 }
119
120 pub fn base_fact_probs(&self) -> HashMap<Vec<u8>, f64> {
126 let mut probs = HashMap::new();
127 if let Ok(guard) = self.entries.read() {
128 for (fact_hash, annotations) in guard.iter() {
129 if let Some(ann) = annotations.first()
130 && ann.support.is_empty()
131 && let Some(uni_common::Value::Float(p)) = ann.fact_row.get("PROB")
132 {
133 probs.insert(fact_hash.clone(), *p);
134 }
135 }
136 }
137 probs
138 }
139
140 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, ProvenanceAnnotation)> {
142 match self.entries.read() {
143 Ok(guard) => guard
144 .iter()
145 .filter_map(|(k, annotations)| {
146 annotations
147 .first()
148 .filter(|e| e.rule_name == rule_name)
149 .map(|e| (k.clone(), e.clone()))
150 })
151 .collect(),
152 Err(_) => vec![],
153 }
154 }
155}
156
157pub fn compute_proof_probability(
161 support: &[ProofTerm],
162 base_probs: &HashMap<Vec<u8>, f64>,
163) -> Option<f64> {
164 if support.is_empty() {
165 return None;
166 }
167 let mut product = 1.0;
168 for term in support {
169 let p = base_probs.get(&term.base_fact_id)?;
170 product *= p;
171 }
172 Some(product)
173}
174
175impl Default for ProvenanceStore {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
/// `(rule name, key tuple)` pairs used by `build_derivation_node` to detect
/// a fact that is already being expanded.
type VisitedSet = HashSet<(String, KeyTuple)>;
183
184#[expect(
190 clippy::too_many_arguments,
191 reason = "explain requires full program context and tracker state"
192)]
193pub async fn explain_rule(
194 query: &ExplainRule,
195 program: &CompiledProgram,
196 fact_source: &dyn DerivedFactSource,
197 config: &LocyConfig,
198 derived_store: &mut RowStore,
199 stats: &mut LocyStats,
200 tracker: Option<&ProvenanceStore>,
201 approximate_groups: Option<&HashMap<String, Vec<String>>>,
202) -> Result<DerivationNode, LocyError> {
203 if let Some(Ok(node)) =
206 tracker.map(|t| explain_rule_mode_a(query, program, t, derived_store, approximate_groups))
207 {
208 return Ok(node);
209 }
210
211 explain_rule_mode_b(
213 query,
214 program,
215 fact_source,
216 config,
217 derived_store,
218 stats,
219 approximate_groups,
220 )
221 .await
222}
223
224fn explain_rule_mode_a(
228 query: &ExplainRule,
229 program: &CompiledProgram,
230 tracker: &ProvenanceStore,
231 _derived_store: &RowStore,
232 approximate_groups: Option<&HashMap<String, Vec<String>>>,
233) -> Result<DerivationNode, LocyError> {
234 let rule_name = query.rule_name.to_string();
235 let rule = program
236 .rule_catalog
237 .get(&rule_name)
238 .ok_or_else(|| LocyError::EvaluationError {
239 message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
240 })?;
241
242 let tracker_entries = tracker.entries_for_rule(&rule_name);
243 if tracker_entries.is_empty() {
244 return Err(LocyError::EvaluationError {
245 message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
246 });
247 }
248
249 let matching_entries: Vec<_> = if let Some(where_expr) = &query.where_expr {
251 tracker_entries
252 .into_iter()
253 .filter(|(_, entry)| {
254 eval_expr(where_expr, &entry.fact_row)
255 .map(|v| v.as_bool().unwrap_or(false))
256 .unwrap_or(false)
257 })
258 .collect()
259 } else {
260 tracker_entries
261 };
262
263 if matching_entries.is_empty() {
264 return Err(LocyError::EvaluationError {
265 message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
266 });
267 }
268
269 let is_approximate = approximate_groups
270 .map(|ag| ag.contains_key(&rule_name))
271 .unwrap_or(false);
272
273 let mut root = DerivationNode {
274 rule: rule_name.clone(),
275 clause_index: 0,
276 priority: rule.priority,
277 bindings: HashMap::new(),
278 along_values: HashMap::new(),
279 children: Vec::new(),
280 graph_fact: None,
281 approximate: is_approximate,
282 proof_probability: None,
283 };
284
285 for (_, entry) in matching_entries {
286 let along_values = extract_along_values(&entry.fact_row, rule);
287 let clause_priority = rule
288 .clauses
289 .get(entry.clause_index)
290 .and_then(|c| c.priority);
291 let base_fact = format!(
292 "[iter={}] {}",
293 entry.iteration,
294 format_graph_fact(&entry.fact_row)
295 );
296 let graph_fact = if is_approximate {
297 format!("[APPROXIMATE] {}", base_fact)
298 } else {
299 base_fact
300 };
301 let node = DerivationNode {
302 rule: rule_name.clone(),
303 clause_index: entry.clause_index,
304 priority: clause_priority.or(rule.priority),
305 bindings: entry.fact_row.clone(),
306 along_values,
307 children: vec![],
309 graph_fact: Some(graph_fact),
310 approximate: is_approximate,
311 proof_probability: entry.proof_probability,
312 };
313 root.children.push(node);
314 }
315
316 Ok(root)
317}
318
319async fn explain_rule_mode_b(
322 query: &ExplainRule,
323 program: &CompiledProgram,
324 fact_source: &dyn DerivedFactSource,
325 config: &LocyConfig,
326 derived_store: &mut RowStore,
327 stats: &mut LocyStats,
328 approximate_groups: Option<&HashMap<String, Vec<String>>>,
329) -> Result<DerivationNode, LocyError> {
330 let rule_name = query.rule_name.to_string();
331 let rule = program
332 .rule_catalog
333 .get(&rule_name)
334 .ok_or_else(|| LocyError::EvaluationError {
335 message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
336 })?;
337
338 let key_columns: Vec<String> = rule
339 .yield_schema
340 .iter()
341 .filter(|c| c.is_key)
342 .map(|c| c.name.clone())
343 .collect();
344
345 {
349 let mut fresh_store = RowStore::new();
350 let slg_start = std::time::Instant::now();
351 let mut resolver =
352 SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
353 resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
354 stats.queries_executed += resolver.stats.queries_executed;
355 for (name, relation) in fresh_store {
358 derived_store.insert(name, relation);
359 }
360 }
361
362 let facts = derived_store
364 .get(&rule_name)
365 .map(|r| r.rows.clone())
366 .unwrap_or_default();
367
368 let filtered: Vec<FactRow> = if let Some(where_expr) = &query.where_expr {
370 facts
371 .into_iter()
372 .filter(|row| {
373 eval_expr(where_expr, row)
374 .map(|v| v.as_bool().unwrap_or(false))
375 .unwrap_or(false)
376 })
377 .collect()
378 } else {
379 facts
380 };
381
382 let is_approximate = approximate_groups
383 .map(|ag| ag.contains_key(&rule_name))
384 .unwrap_or(false);
385
386 let mut root = DerivationNode {
388 rule: rule_name.clone(),
389 clause_index: 0,
390 priority: rule.priority,
391 bindings: HashMap::new(),
392 along_values: HashMap::new(),
393 children: Vec::new(),
394 graph_fact: None,
395 approximate: is_approximate,
396 proof_probability: None,
397 };
398
399 for fact in &filtered {
401 let mut visited = VisitedSet::new();
402 let mut node = build_derivation_node(
403 &rule_name,
404 fact,
405 &key_columns,
406 program,
407 fact_source,
408 derived_store,
409 stats,
410 &mut visited,
411 config.max_explain_depth,
412 )
413 .await?;
414 if is_approximate {
415 node.approximate = true;
416 if let Some(ref gf) = node.graph_fact {
417 node.graph_fact = Some(format!("[APPROXIMATE] {}", gf));
418 }
419 }
420 root.children.push(node);
421 }
422
423 Ok(root)
424}
425
426#[expect(
431 clippy::too_many_arguments,
432 reason = "recursive derivation node builder requires full fact context"
433)]
434fn build_derivation_node<'a>(
435 rule_name: &'a str,
436 fact: &'a FactRow,
437 key_columns: &'a [String],
438 program: &'a CompiledProgram,
439 fact_source: &'a dyn DerivedFactSource,
440 derived_store: &'a mut RowStore,
441 stats: &'a mut LocyStats,
442 visited: &'a mut VisitedSet,
443 max_depth: usize,
444) -> std::pin::Pin<
445 Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + Send + 'a>,
446> {
447 Box::pin(async move {
448 let rule =
449 program
450 .rule_catalog
451 .get(rule_name)
452 .ok_or_else(|| LocyError::EvaluationError {
453 message: format!("rule '{}' not found during EXPLAIN", rule_name),
454 })?;
455
456 let key_tuple = extract_key(fact, key_columns);
457 let visit_key = (rule_name.to_string(), key_tuple);
458
459 if !visited.insert(visit_key.clone()) || max_depth == 0 {
461 return Ok(DerivationNode {
462 rule: rule_name.to_string(),
463 clause_index: 0,
464 priority: rule.priority,
465 bindings: fact.clone(),
466 along_values: extract_along_values(fact, rule),
467 children: Vec::new(),
468 graph_fact: Some("(cycle)".to_string()),
469 approximate: false,
470 proof_probability: None,
471 });
472 }
473
474 for (clause_idx, clause) in rule.clauses.iter().enumerate() {
481 let has_is_refs = clause
482 .where_conditions
483 .iter()
484 .any(|c| matches!(c, RuleCondition::IsReference(_)));
485 let has_along = !clause.along.is_empty();
486
487 let resolved = if has_is_refs || has_along {
488 let rows = resolve_clause_with_is_refs(
489 clause,
490 fact_source,
491 derived_store,
492 &program.rule_catalog,
493 None, )
495 .await?;
496 stats.queries_executed += 1;
497 rows
498 } else {
499 let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
500 let raw_batches = fact_source
501 .execute_pattern(&clause.match_pattern, &cypher_conditions)
502 .await?;
503 stats.queries_executed += 1;
504 record_batches_to_locy_rows(&raw_batches)
505 };
506
507 let matching_row = resolved.iter().find(|row| {
511 key_columns.iter().all(|k| match (row.get(k), fact.get(k)) {
512 (Some(v1), Some(v2)) => values_equal_for_join(v1, v2),
513 (None, None) => true,
514 _ => false,
515 })
516 });
517
518 if let Some(evidence_row) = matching_row {
519 let along_values = extract_along_values(fact, rule);
520
521 let mut children = Vec::new();
523 for cond in &clause.where_conditions {
524 if let RuleCondition::IsReference(is_ref) = cond {
525 if is_ref.negated {
526 continue;
527 }
528 let ref_rule_name = is_ref.rule_name.to_string();
529 if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
530 let ref_key_columns: Vec<String> = ref_rule
531 .yield_schema
532 .iter()
533 .filter(|c| c.is_key)
534 .map(|c| c.name.clone())
535 .collect();
536
537 let ref_facts: Vec<FactRow> = derived_store
538 .get(&ref_rule_name)
539 .map(|r| r.rows.clone())
540 .unwrap_or_default();
541
542 let matching_ref_facts: Vec<FactRow> = ref_facts
543 .into_iter()
544 .filter(|ref_fact| {
545 let subjects_match =
546 is_ref.subjects.iter().enumerate().all(|(i, subject)| {
547 binding_matches_key(
548 evidence_row,
549 fact,
550 subject,
551 ref_fact,
552 ref_key_columns.get(i),
553 )
554 });
555 let target_matches =
556 is_ref.target.as_ref().is_none_or(|target| {
557 binding_matches_key(
558 evidence_row,
559 fact,
560 target,
561 ref_fact,
562 ref_key_columns.get(is_ref.subjects.len()),
563 )
564 });
565 subjects_match && target_matches
566 })
567 .collect();
568
569 for ref_fact in matching_ref_facts {
570 let child = build_derivation_node(
571 &ref_rule_name,
572 &ref_fact,
573 &ref_key_columns,
574 program,
575 fact_source,
576 derived_store,
577 stats,
578 visited,
579 max_depth - 1,
580 )
581 .await?;
582 children.push(child);
583 }
584 }
585 }
586 }
587
588 visited.remove(&visit_key);
590
591 let mut merged_bindings = evidence_row.clone();
592 for (k, v) in fact {
593 merged_bindings.entry(k.clone()).or_insert(v.clone());
594 }
595
596 return Ok(DerivationNode {
597 rule: rule_name.to_string(),
598 clause_index: clause_idx,
599 priority: rule.clauses[clause_idx].priority,
600 bindings: merged_bindings,
601 along_values,
602 children,
603 graph_fact: Some(format_graph_fact(evidence_row)),
604 approximate: false,
605 proof_probability: None,
606 });
607 }
608 }
609
610 visited.remove(&visit_key);
612 Ok(DerivationNode {
613 rule: rule_name.to_string(),
614 clause_index: 0,
615 priority: rule.priority,
616 bindings: fact.clone(),
617 along_values: extract_along_values(fact, rule),
618 children: Vec::new(),
619 graph_fact: Some(format_graph_fact(fact)),
620 approximate: false,
621 proof_probability: None,
622 })
623 })
624}
625
626fn binding_matches_key(
632 primary: &FactRow,
633 fallback: &FactRow,
634 var_name: &str,
635 ref_fact: &FactRow,
636 ref_key_col: Option<&String>,
637) -> bool {
638 let Some(key_col) = ref_key_col else {
639 return true;
640 };
641 let Some(val) = primary.get(var_name).or_else(|| fallback.get(var_name)) else {
642 return true;
643 };
644 ref_fact
645 .get(key_col)
646 .is_some_and(|rv| values_equal_for_join(rv, val))
647}
648
649fn extract_along_values(fact: &FactRow, rule: &CompiledRule) -> HashMap<String, Value> {
650 let mut along_values = HashMap::new();
651 for clause in &rule.clauses {
652 for along in &clause.along {
653 if let Some(v) = fact.get(&along.name) {
654 along_values.insert(along.name.clone(), v.clone());
655 }
656 }
657 }
658 along_values
659}
660
661pub(crate) fn format_graph_fact(row: &FactRow) -> String {
662 let mut entries: Vec<String> = row
663 .iter()
664 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
665 .collect();
666 entries.sort();
667 format!("{{{}}}", entries.join(", "))
668}
669
670fn format_value(v: &Value) -> String {
671 match v {
672 Value::Null => "null".to_string(),
673 Value::Bool(b) => b.to_string(),
674 Value::Int(i) => i.to_string(),
675 Value::Float(f) => f.to_string(),
676 Value::String(s) => format!("\"{}\"", s),
677 Value::List(items) => {
678 let inner: Vec<String> = items.iter().map(format_value).collect();
679 format!("[{}]", inner.join(", "))
680 }
681 Value::Map(m) => {
682 let mut entries: Vec<String> = m
683 .iter()
684 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
685 .collect();
686 entries.sort();
687 format!("{{{}}}", entries.join(", "))
688 }
689 Value::Node(n) => format!("Node({})", n.vid.as_u64()),
690 Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
691 _ => format!("{v:?}"),
692 }
693}
694
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal annotation for `rule` with the given proof
    /// probability and otherwise empty fields.
    fn make_annotation(rule: &str, prob: Option<f64>) -> ProvenanceAnnotation {
        ProvenanceAnnotation {
            rule_name: rule.to_owned(),
            clause_index: 0,
            support: Vec::new(),
            along_values: HashMap::new(),
            iteration: 0,
            fact_row: HashMap::new(),
            proof_probability: prob,
        }
    }

    /// Builds a proof term for `rule` pointing at base fact `id`.
    fn term(rule: &str, id: &[u8]) -> ProofTerm {
        ProofTerm {
            source_rule: rule.to_owned(),
            base_fact_id: id.to_vec(),
        }
    }

    #[test]
    fn record_first_derivation_wins() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        for rule in ["rule_a", "rule_b"] {
            store.record(key.clone(), make_annotation(rule, None));
        }
        let found = store.lookup(&key).unwrap();
        assert_eq!(found.rule_name, "rule_a");
    }

    #[test]
    fn lookup_returns_first_annotation() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        store.record(key.clone(), make_annotation("rule_a", None));
        let found = store.lookup(&key).unwrap();
        assert_eq!(found.rule_name, "rule_a");
        assert!(store.lookup(b"nonexistent").is_none());
    }

    #[test]
    fn lookup_all_returns_all_annotations() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        store.record(key.clone(), make_annotation("rule_a", None));
        let annotations = store.lookup_all(&key).unwrap();
        assert_eq!(annotations.len(), 1);
    }

    #[test]
    fn record_top_k_retains_highest() {
        let store = ProvenanceStore::new();
        let key = b"fact1".to_vec();
        let candidates = [
            ("low", 0.1),
            ("high", 0.9),
            ("mid", 0.5),
            ("highest", 0.95),
            ("lowest", 0.01),
        ];
        for (rule, prob) in candidates {
            store.record_top_k(key.clone(), make_annotation(rule, Some(prob)), 2);
        }

        let kept = store.lookup_all(&key).unwrap();
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].rule_name, "highest");
        assert_eq!(kept[1].rule_name, "high");
    }

    #[test]
    fn compute_proof_probability_basic() {
        let support = vec![term("r1", b"f1"), term("r2", b"f2")];
        let mut base_probs = HashMap::new();
        base_probs.insert(b"f1".to_vec(), 0.3);
        base_probs.insert(b"f2".to_vec(), 0.5);
        let prob = compute_proof_probability(&support, &base_probs);
        assert!((prob.unwrap() - 0.15).abs() < 1e-10);
    }

    #[test]
    fn compute_proof_probability_empty_support() {
        let no_probs = HashMap::new();
        assert!(compute_proof_probability(&[], &no_probs).is_none());
    }

    #[test]
    fn compute_proof_probability_missing_fact() {
        let support = vec![term("r1", b"unknown")];
        let no_probs = HashMap::new();
        assert!(compute_proof_probability(&support, &no_probs).is_none());
    }

    #[test]
    fn entries_for_rule_filters_correctly() {
        let store = ProvenanceStore::new();
        let recorded = [
            (b"f1".to_vec(), "rule_a"),
            (b"f2".to_vec(), "rule_b"),
            (b"f3".to_vec(), "rule_a"),
        ];
        for (key, rule) in recorded {
            store.record(key, make_annotation(rule, None));
        }

        assert_eq!(store.entries_for_rule("rule_a").len(), 2);
        assert_eq!(store.entries_for_rule("rule_b").len(), 1);
    }
}