1use std::collections::{HashMap, HashSet};
13use std::sync::RwLock;
14
15use uni_common::Value;
16use uni_cypher::locy_ast::{ExplainRule, RuleCondition};
17use uni_locy::types::CompiledRule;
18use uni_locy::{CompiledProgram, DerivationNode, LocyConfig, LocyError, LocyStats, Row};
19
20use super::locy_delta::{
21 KeyTuple, RowStore, extract_cypher_conditions, extract_key, resolve_clause_with_is_refs,
22};
23
24use super::locy_eval::{eval_expr, record_batches_to_locy_rows};
25use super::locy_slg::SLGResolver;
26use super::locy_traits::DerivedFactSource;
27
/// A fact consumed by a recorded derivation, identified by the rule it belongs to and
/// its hash.
///
/// NOTE(review): presumably one entry per fact matched through an `IS <rule>` condition
/// of the deriving clause (see the field name `is_ref_rule`) — confirm with the recorder.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DerivationInput {
    /// Name of the rule the consumed fact belongs to.
    pub is_ref_rule: String,
    /// Hash identifying the consumed fact.
    pub fact_hash: Vec<u8>,
}
35
/// One recorded derivation of a fact, as stored by `DerivationTracker` and read back by
/// `EXPLAIN RULE` Mode A.
#[derive(Clone)]
pub struct DerivationEntry {
    /// Name of the rule that derived the fact.
    pub rule_name: String,
    /// Index into the rule's `clauses` of the clause that produced the fact.
    pub clause_index: usize,
    /// Facts consumed by the derivation (NOTE(review): presumably those matched through
    /// the clause's `IS` references — confirm with the recorder).
    pub inputs: Vec<DerivationInput>,
    /// `ALONG` values recorded with the derivation, keyed by name.
    pub along_values: HashMap<String, Value>,
    /// Evaluation iteration in which the fact was recorded (shown as `[iter=N]` in
    /// Mode A output; presumably the fixpoint round — confirm).
    pub iteration: usize,
    /// The derived fact's column values; Mode A evaluates the `WHERE` clause against it.
    pub fact_row: Row,
}
52
53pub struct DerivationTracker {
59 entries: RwLock<HashMap<Vec<u8>, DerivationEntry>>,
60}
61
62impl DerivationTracker {
63 pub fn new() -> Self {
64 Self {
65 entries: RwLock::new(HashMap::new()),
66 }
67 }
68
69 pub fn record(&self, fact_hash: Vec<u8>, entry: DerivationEntry) {
72 if let Ok(mut guard) = self.entries.write() {
73 guard.entry(fact_hash).or_insert(entry);
74 }
75 }
76
77 pub fn lookup(&self, fact_hash: &[u8]) -> Option<DerivationEntry> {
79 self.entries.read().ok()?.get(fact_hash).cloned()
80 }
81
82 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<(Vec<u8>, DerivationEntry)> {
84 match self.entries.read() {
85 Ok(guard) => guard
86 .iter()
87 .filter(|(_, e)| e.rule_name == rule_name)
88 .map(|(k, v)| (k.clone(), v.clone()))
89 .collect(),
90 Err(_) => vec![],
91 }
92 }
93}
94
95impl Default for DerivationTracker {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
/// `(rule name, key tuple)` pairs used by `build_derivation_node` to detect cycles while
/// recursing through derivations during `EXPLAIN RULE` Mode B.
type VisitedSet = HashSet<(String, KeyTuple)>;
103
104pub async fn explain_rule(
110 query: &ExplainRule,
111 program: &CompiledProgram,
112 fact_source: &dyn DerivedFactSource,
113 config: &LocyConfig,
114 derived_store: &mut RowStore,
115 stats: &mut LocyStats,
116 tracker: Option<&DerivationTracker>,
117) -> Result<DerivationNode, LocyError> {
118 if let Some(Ok(node)) = tracker.map(|t| explain_rule_mode_a(query, program, t, derived_store)) {
121 return Ok(node);
122 }
123
124 explain_rule_mode_b(query, program, fact_source, config, derived_store, stats).await
126}
127
128fn explain_rule_mode_a(
132 query: &ExplainRule,
133 program: &CompiledProgram,
134 tracker: &DerivationTracker,
135 derived_store: &RowStore,
136) -> Result<DerivationNode, LocyError> {
137 let rule_name = query.rule_name.to_string();
138 let rule = program
139 .rule_catalog
140 .get(&rule_name)
141 .ok_or_else(|| LocyError::EvaluationError {
142 message: format!("rule '{}' not found for EXPLAIN RULE (Mode A)", rule_name),
143 })?;
144
145 let tracker_entries = tracker.entries_for_rule(&rule_name);
146 if tracker_entries.is_empty() {
147 return Err(LocyError::EvaluationError {
148 message: format!("no tracker entries for rule '{rule_name}' (falling back to Mode B)"),
149 });
150 }
151
152 let matching_entries: Vec<_> = tracker_entries
154 .into_iter()
155 .filter(|(_, entry)| {
156 eval_expr(&query.where_expr, &entry.fact_row)
157 .map(|v| v.as_bool().unwrap_or(false))
158 .unwrap_or(false)
159 })
160 .collect();
161
162 if matching_entries.is_empty() {
163 return Err(LocyError::EvaluationError {
164 message: format!("no tracker entries match WHERE clause for rule '{rule_name}'"),
165 });
166 }
167
168 let orch_facts = derived_store
171 .get(&rule_name)
172 .map(|r| r.rows.clone())
173 .unwrap_or_default();
174 let _ = orch_facts; let mut root = DerivationNode {
177 rule: rule_name.clone(),
178 clause_index: 0,
179 priority: rule.priority,
180 bindings: HashMap::new(),
181 along_values: HashMap::new(),
182 children: Vec::new(),
183 graph_fact: None,
184 };
185
186 for (_, entry) in matching_entries {
187 let along_values = extract_along_values(&entry.fact_row, rule);
188 let clause_priority = rule
189 .clauses
190 .get(entry.clause_index)
191 .and_then(|c| c.priority);
192 let node = DerivationNode {
193 rule: rule_name.clone(),
194 clause_index: entry.clause_index,
195 priority: clause_priority.or(rule.priority),
196 bindings: entry.fact_row.clone(),
197 along_values,
198 children: vec![],
200 graph_fact: Some(format!(
201 "[iter={}] {}",
202 entry.iteration,
203 format_graph_fact(&entry.fact_row)
204 )),
205 };
206 root.children.push(node);
207 }
208
209 Ok(root)
210}
211
212async fn explain_rule_mode_b(
215 query: &ExplainRule,
216 program: &CompiledProgram,
217 fact_source: &dyn DerivedFactSource,
218 config: &LocyConfig,
219 derived_store: &mut RowStore,
220 stats: &mut LocyStats,
221) -> Result<DerivationNode, LocyError> {
222 let rule_name = query.rule_name.to_string();
223 let rule = program
224 .rule_catalog
225 .get(&rule_name)
226 .ok_or_else(|| LocyError::EvaluationError {
227 message: format!("rule '{}' not found for EXPLAIN RULE", rule_name),
228 })?;
229
230 let key_columns: Vec<String> = rule
231 .yield_schema
232 .iter()
233 .filter(|c| c.is_key)
234 .map(|c| c.name.clone())
235 .collect();
236
237 {
241 let mut fresh_store = RowStore::new();
242 let slg_start = std::time::Instant::now();
243 let mut resolver =
244 SLGResolver::new(program, fact_source, config, &mut fresh_store, slg_start);
245 resolver.resolve_goal(&rule_name, &HashMap::new()).await?;
246 stats.queries_executed += resolver.stats.queries_executed;
247 for (name, relation) in fresh_store {
250 derived_store.insert(name, relation);
251 }
252 }
253
254 let facts = derived_store
256 .get(&rule_name)
257 .map(|r| r.rows.clone())
258 .unwrap_or_default();
259
260 let filtered: Vec<Row> = facts
262 .into_iter()
263 .filter(|row| {
264 eval_expr(&query.where_expr, row)
265 .map(|v| v.as_bool().unwrap_or(false))
266 .unwrap_or(false)
267 })
268 .collect();
269
270 let mut root = DerivationNode {
272 rule: rule_name.clone(),
273 clause_index: 0,
274 priority: rule.priority,
275 bindings: HashMap::new(),
276 along_values: HashMap::new(),
277 children: Vec::new(),
278 graph_fact: None,
279 };
280
281 for fact in &filtered {
283 let mut visited = VisitedSet::new();
284 let node = build_derivation_node(
285 &rule_name,
286 fact,
287 &key_columns,
288 program,
289 fact_source,
290 derived_store,
291 stats,
292 &mut visited,
293 config.max_explain_depth,
294 )
295 .await?;
296 root.children.push(node);
297 }
298
299 Ok(root)
300}
301
302#[allow(clippy::too_many_arguments)]
307fn build_derivation_node<'a>(
308 rule_name: &'a str,
309 fact: &'a Row,
310 key_columns: &'a [String],
311 program: &'a CompiledProgram,
312 fact_source: &'a dyn DerivedFactSource,
313 derived_store: &'a mut RowStore,
314 stats: &'a mut LocyStats,
315 visited: &'a mut VisitedSet,
316 max_depth: usize,
317) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<DerivationNode, LocyError>> + 'a>> {
318 Box::pin(async move {
319 let rule =
320 program
321 .rule_catalog
322 .get(rule_name)
323 .ok_or_else(|| LocyError::EvaluationError {
324 message: format!("rule '{}' not found during EXPLAIN", rule_name),
325 })?;
326
327 let key_tuple = extract_key(fact, key_columns);
328 let visit_key = (rule_name.to_string(), key_tuple);
329
330 if !visited.insert(visit_key.clone()) || max_depth == 0 {
332 return Ok(DerivationNode {
333 rule: rule_name.to_string(),
334 clause_index: 0,
335 priority: rule.priority,
336 bindings: fact.clone(),
337 along_values: extract_along_values(fact, rule),
338 children: Vec::new(),
339 graph_fact: Some("(cycle)".to_string()),
340 });
341 }
342
343 let yield_columns: Vec<String> = rule.yield_schema.iter().map(|c| c.name.clone()).collect();
345
346 for (clause_idx, clause) in rule.clauses.iter().enumerate() {
348 let has_is_refs = clause
349 .where_conditions
350 .iter()
351 .any(|c| matches!(c, RuleCondition::IsReference(_)));
352 let has_along = !clause.along.is_empty();
353
354 let resolved = if has_is_refs || has_along {
355 let rows = resolve_clause_with_is_refs(clause, fact_source, derived_store).await?;
356 stats.queries_executed += 1;
357 rows
358 } else {
359 let cypher_conditions = extract_cypher_conditions(&clause.where_conditions);
360 let raw_batches = fact_source
361 .execute_pattern(&clause.match_pattern, &cypher_conditions)
362 .await?;
363 stats.queries_executed += 1;
364 record_batches_to_locy_rows(&raw_batches)
365 };
366
367 let matching_row = resolved
369 .iter()
370 .find(|row| yield_columns.iter().all(|k| row.get(k) == fact.get(k)));
371
372 if let Some(evidence_row) = matching_row {
373 let along_values = extract_along_values(fact, rule);
374
375 let mut children = Vec::new();
377 for cond in &clause.where_conditions {
378 if let RuleCondition::IsReference(is_ref) = cond {
379 if is_ref.negated {
380 continue;
381 }
382 let ref_rule_name = is_ref.rule_name.to_string();
383 if let Some(ref_rule) = program.rule_catalog.get(&ref_rule_name) {
384 let ref_key_columns: Vec<String> = ref_rule
385 .yield_schema
386 .iter()
387 .filter(|c| c.is_key)
388 .map(|c| c.name.clone())
389 .collect();
390
391 let ref_facts: Vec<Row> = derived_store
392 .get(&ref_rule_name)
393 .map(|r| r.rows.clone())
394 .unwrap_or_default();
395
396 let matching_ref_facts: Vec<Row> = ref_facts
397 .into_iter()
398 .filter(|ref_fact| {
399 let subjects_match =
400 is_ref.subjects.iter().enumerate().all(|(i, subject)| {
401 if i < ref_key_columns.len() {
402 let subject_val = evidence_row
403 .get(subject)
404 .or_else(|| fact.get(subject));
405 match subject_val {
406 Some(val) => {
407 ref_fact.get(&ref_key_columns[i])
408 == Some(val)
409 }
410 None => true,
411 }
412 } else {
413 true
414 }
415 });
416 let target_matches = if let Some(target) = &is_ref.target {
417 let target_idx = is_ref.subjects.len();
418 if target_idx < ref_key_columns.len() {
419 let target_val = evidence_row
420 .get(target)
421 .or_else(|| fact.get(target));
422 match target_val {
423 Some(val) => {
424 ref_fact.get(&ref_key_columns[target_idx])
425 == Some(val)
426 }
427 None => true,
428 }
429 } else {
430 true
431 }
432 } else {
433 true
434 };
435 subjects_match && target_matches
436 })
437 .collect();
438
439 for ref_fact in matching_ref_facts {
440 let child = build_derivation_node(
441 &ref_rule_name,
442 &ref_fact,
443 &ref_key_columns,
444 program,
445 fact_source,
446 derived_store,
447 stats,
448 visited,
449 max_depth - 1,
450 )
451 .await?;
452 children.push(child);
453 }
454 }
455 }
456 }
457
458 visited.remove(&visit_key);
460
461 let mut merged_bindings = evidence_row.clone();
462 for (k, v) in fact {
463 merged_bindings.entry(k.clone()).or_insert(v.clone());
464 }
465
466 return Ok(DerivationNode {
467 rule: rule_name.to_string(),
468 clause_index: clause_idx,
469 priority: rule.clauses[clause_idx].priority,
470 bindings: merged_bindings,
471 along_values,
472 children,
473 graph_fact: Some(format_graph_fact(evidence_row)),
474 });
475 }
476 }
477
478 visited.remove(&visit_key);
480 Ok(DerivationNode {
481 rule: rule_name.to_string(),
482 clause_index: 0,
483 priority: rule.priority,
484 bindings: fact.clone(),
485 along_values: extract_along_values(fact, rule),
486 children: Vec::new(),
487 graph_fact: Some(format_graph_fact(fact)),
488 })
489 })
490}
491
492fn extract_along_values(fact: &Row, rule: &CompiledRule) -> HashMap<String, Value> {
493 let mut along_values = HashMap::new();
494 for clause in &rule.clauses {
495 for along in &clause.along {
496 if let Some(v) = fact.get(&along.name) {
497 along_values.insert(along.name.clone(), v.clone());
498 }
499 }
500 }
501 along_values
502}
503
504pub(crate) fn format_graph_fact(row: &Row) -> String {
505 let mut entries: Vec<String> = row
506 .iter()
507 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
508 .collect();
509 entries.sort();
510 format!("{{{}}}", entries.join(", "))
511}
512
513fn format_value(v: &Value) -> String {
514 match v {
515 Value::Null => "null".to_string(),
516 Value::Bool(b) => b.to_string(),
517 Value::Int(i) => i.to_string(),
518 Value::Float(f) => f.to_string(),
519 Value::String(s) => format!("\"{}\"", s),
520 Value::List(items) => {
521 let inner: Vec<String> = items.iter().map(format_value).collect();
522 format!("[{}]", inner.join(", "))
523 }
524 Value::Map(m) => {
525 let mut entries: Vec<String> = m
526 .iter()
527 .map(|(k, v)| format!("{}: {}", k, format_value(v)))
528 .collect();
529 entries.sort();
530 format!("{{{}}}", entries.join(", "))
531 }
532 Value::Node(n) => format!("Node({})", n.vid.as_u64()),
533 Value::Edge(e) => format!("Edge({})", e.eid.as_u64()),
534 _ => format!("{:?}", v),
535 }
536}