// rust_rule_engine/backward/proof_tree.rs

// Proof Tree for Backward-Chaining Explanations
//
// This module provides data structures for capturing and visualizing
// the reasoning process of backward-chaining queries.
//
// Version: 1.9.0
7
8use super::unification::Bindings;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Represents a single node in the proof tree
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ProofNode {
15    /// The goal that was proven at this node
16    pub goal: String,
17
18    /// Name of the rule that was used (if any)
19    pub rule_name: Option<String>,
20
21    /// Variable bindings at this node
22    #[serde(skip_serializing_if = "HashMap::is_empty")]
23    pub bindings: HashMap<String, String>,
24
25    /// Child nodes (sub-goals that were proven)
26    #[serde(skip_serializing_if = "Vec::is_empty")]
27    pub children: Vec<ProofNode>,
28
29    /// Depth in the proof tree
30    pub depth: usize,
31
32    /// Whether this goal was proven successfully
33    pub proven: bool,
34
35    /// Type of proof node
36    pub node_type: ProofNodeType,
37}
38
39/// Type of proof node
40#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
41pub enum ProofNodeType {
42    /// Goal proven by a fact
43    Fact,
44
45    /// Goal proven by a rule
46    Rule,
47
48    /// Negated goal (NOT)
49    Negation,
50
51    /// Goal failed to prove
52    Failed,
53}
54
55impl ProofNode {
56    /// Create a new proof node
57    pub fn new(goal: String, depth: usize) -> Self {
58        ProofNode {
59            goal,
60            rule_name: None,
61            bindings: HashMap::new(),
62            children: Vec::new(),
63            depth,
64            proven: false,
65            node_type: ProofNodeType::Failed,
66        }
67    }
68
69    /// Create a fact node
70    pub fn fact(goal: String, depth: usize) -> Self {
71        ProofNode {
72            goal,
73            rule_name: None,
74            bindings: HashMap::new(),
75            children: Vec::new(),
76            depth,
77            proven: true,
78            node_type: ProofNodeType::Fact,
79        }
80    }
81
82    /// Create a rule node
83    pub fn rule(goal: String, rule_name: String, depth: usize) -> Self {
84        ProofNode {
85            goal,
86            rule_name: Some(rule_name),
87            bindings: HashMap::new(),
88            children: Vec::new(),
89            depth,
90            proven: true,
91            node_type: ProofNodeType::Rule,
92        }
93    }
94
95    /// Create a negation node
96    pub fn negation(goal: String, depth: usize, proven: bool) -> Self {
97        ProofNode {
98            goal,
99            rule_name: None,
100            bindings: HashMap::new(),
101            children: Vec::new(),
102            depth,
103            proven,
104            node_type: ProofNodeType::Negation,
105        }
106    }
107
108    /// Add a child node
109    pub fn add_child(&mut self, child: ProofNode) {
110        self.children.push(child);
111    }
112
113    /// Set bindings from Bindings object
114    pub fn set_bindings(&mut self, bindings: &Bindings) {
115        // Convert Bindings to HashMap using to_map() method
116        let binding_map = bindings.to_map();
117        self.bindings = binding_map
118            .iter()
119            .map(|(k, v)| (k.clone(), format!("{:?}", v)))
120            .collect();
121    }
122
123    /// Set bindings from HashMap
124    pub fn set_bindings_map(&mut self, bindings: HashMap<String, String>) {
125        self.bindings = bindings;
126    }
127
128    /// Check if this is a leaf node
129    pub fn is_leaf(&self) -> bool {
130        self.children.is_empty()
131    }
132
133    /// Print the proof tree
134    pub fn print_tree(&self, indent: usize) {
135        let prefix = "  ".repeat(indent);
136        let status = if self.proven { "✓" } else { "✗" };
137
138        println!("{}{} {}", prefix, status, self.goal);
139
140        if let Some(rule) = &self.rule_name {
141            println!("{}  [Rule: {}]", prefix, rule);
142        }
143
144        match self.node_type {
145            ProofNodeType::Fact => println!("{}  [FACT]", prefix),
146            ProofNodeType::Negation => println!("{}  [NEGATION]", prefix),
147            _ => {}
148        }
149
150        if !self.bindings.is_empty() {
151            println!("{}  Bindings: {:?}", prefix, self.bindings);
152        }
153
154        for child in &self.children {
155            child.print_tree(indent + 1);
156        }
157    }
158
159    /// Get tree height
160    pub fn height(&self) -> usize {
161        if self.children.is_empty() {
162            1
163        } else {
164            1 + self.children.iter().map(|c| c.height()).max().unwrap_or(0)
165        }
166    }
167
168    /// Count total nodes
169    pub fn node_count(&self) -> usize {
170        1 + self.children.iter().map(|c| c.node_count()).sum::<usize>()
171    }
172}
173
174/// Complete proof tree with metadata
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct ProofTree {
177    /// Root node of the proof tree
178    pub root: ProofNode,
179
180    /// Whether the query was proven
181    pub success: bool,
182
183    /// Original query string
184    pub query: String,
185
186    /// Statistics
187    pub stats: ProofStats,
188}
189
190/// Statistics about the proof
191#[derive(Debug, Clone, Default, Serialize, Deserialize)]
192pub struct ProofStats {
193    /// Total number of goals explored
194    pub goals_explored: usize,
195
196    /// Total number of rules evaluated
197    pub rules_evaluated: usize,
198
199    /// Total number of facts checked
200    pub facts_checked: usize,
201
202    /// Maximum depth reached
203    pub max_depth: usize,
204
205    /// Total nodes in proof tree
206    pub total_nodes: usize,
207}
208
209impl ProofTree {
210    /// Create a new proof tree
211    pub fn new(root: ProofNode, query: String) -> Self {
212        let success = root.proven;
213        let total_nodes = root.node_count();
214        let max_depth = root.height();
215
216        ProofTree {
217            root,
218            success,
219            query,
220            stats: ProofStats {
221                goals_explored: 0,
222                rules_evaluated: 0,
223                facts_checked: 0,
224                max_depth,
225                total_nodes,
226            },
227        }
228    }
229
230    /// Set statistics
231    pub fn set_stats(&mut self, stats: ProofStats) {
232        self.stats = stats;
233    }
234
235    /// Print the entire tree
236    pub fn print(&self) {
237        println!("Query: {}", self.query);
238        println!("Result: {}", if self.success { "✓ Proven" } else { "✗ Unprovable" });
239        println!("\nProof Tree:");
240        println!("{}", "=".repeat(80));
241        self.root.print_tree(0);
242        println!("{}", "=".repeat(80));
243        self.print_stats();
244    }
245
246    /// Print statistics
247    pub fn print_stats(&self) {
248        println!("\nStatistics:");
249        println!("  Goals explored: {}", self.stats.goals_explored);
250        println!("  Rules evaluated: {}", self.stats.rules_evaluated);
251        println!("  Facts checked: {}", self.stats.facts_checked);
252        println!("  Max depth: {}", self.stats.max_depth);
253        println!("  Total nodes: {}", self.stats.total_nodes);
254    }
255
256    /// Convert to JSON string
257    pub fn to_json(&self) -> Result<String, serde_json::Error> {
258        serde_json::to_string_pretty(self)
259    }
260
261    /// Convert to Markdown
262    pub fn to_markdown(&self) -> String {
263        let mut md = String::new();
264
265        md.push_str("# Proof Explanation\n\n");
266        md.push_str(&format!("**Query:** `{}`\n\n", self.query));
267        md.push_str(&format!("**Result:** {}\n\n",
268            if self.success { "✓ Proven" } else { "✗ Unprovable" }
269        ));
270
271        md.push_str("## Proof Tree\n\n");
272        self.node_to_markdown(&self.root, &mut md, 0);
273
274        md.push_str("\n## Statistics\n\n");
275        md.push_str(&format!("- **Goals explored:** {}\n", self.stats.goals_explored));
276        md.push_str(&format!("- **Rules evaluated:** {}\n", self.stats.rules_evaluated));
277        md.push_str(&format!("- **Facts checked:** {}\n", self.stats.facts_checked));
278        md.push_str(&format!("- **Max depth:** {}\n", self.stats.max_depth));
279        md.push_str(&format!("- **Total nodes:** {}\n", self.stats.total_nodes));
280
281        md
282    }
283
284    /// Convert node to markdown recursively
285    fn node_to_markdown(&self, node: &ProofNode, md: &mut String, depth: usize) {
286        let prefix = "  ".repeat(depth);
287        let status = if node.proven { "✓" } else { "✗" };
288
289        md.push_str(&format!("{}* {} `{}`", prefix, status, node.goal));
290
291        if let Some(rule) = &node.rule_name {
292            md.push_str(&format!(" **[Rule: {}]**", rule));
293        }
294
295        match node.node_type {
296            ProofNodeType::Fact => md.push_str(" *[FACT]*"),
297            ProofNodeType::Negation => md.push_str(" *[NEGATION]*"),
298            _ => {}
299        }
300
301        md.push('\n');
302
303        if !node.bindings.is_empty() {
304            md.push_str(&format!("{}  * Bindings: `{:?}`\n", prefix, node.bindings));
305        }
306
307        for child in &node.children {
308            self.node_to_markdown(child, md, depth + 1);
309        }
310    }
311
312    /// Convert to HTML
313    pub fn to_html(&self) -> String {
314        let mut html = String::new();
315
316        html.push_str("<!DOCTYPE html>\n<html>\n<head>\n");
317        html.push_str("  <title>Proof Explanation</title>\n");
318        html.push_str("  <style>\n");
319        html.push_str("    body { font-family: 'Courier New', monospace; margin: 20px; }\n");
320        html.push_str("    .proven { color: green; }\n");
321        html.push_str("    .failed { color: red; }\n");
322        html.push_str("    .node { margin-left: 20px; }\n");
323        html.push_str("    .rule { color: blue; font-style: italic; }\n");
324        html.push_str("    .bindings { color: gray; font-size: 0.9em; }\n");
325        html.push_str("    .stats { margin-top: 20px; padding: 10px; background: #f0f0f0; }\n");
326        html.push_str("  </style>\n");
327        html.push_str("</head>\n<body>\n");
328
329        html.push_str(&format!("<h1>Proof Explanation</h1>\n"));
330        html.push_str(&format!("<p><strong>Query:</strong> <code>{}</code></p>\n", self.query));
331        html.push_str(&format!("<p><strong>Result:</strong> <span class=\"{}\">{}</span></p>\n",
332            if self.success { "proven" } else { "failed" },
333            if self.success { "✓ Proven" } else { "✗ Unprovable" }
334        ));
335
336        html.push_str("<h2>Proof Tree</h2>\n");
337        self.node_to_html(&self.root, &mut html);
338
339        html.push_str("<div class=\"stats\">\n");
340        html.push_str("<h2>Statistics</h2>\n");
341        html.push_str(&format!("<p>Goals explored: {}</p>\n", self.stats.goals_explored));
342        html.push_str(&format!("<p>Rules evaluated: {}</p>\n", self.stats.rules_evaluated));
343        html.push_str(&format!("<p>Facts checked: {}</p>\n", self.stats.facts_checked));
344        html.push_str(&format!("<p>Max depth: {}</p>\n", self.stats.max_depth));
345        html.push_str(&format!("<p>Total nodes: {}</p>\n", self.stats.total_nodes));
346        html.push_str("</div>\n");
347
348        html.push_str("</body>\n</html>");
349        html
350    }
351
352    /// Convert node to HTML recursively
353    fn node_to_html(&self, node: &ProofNode, html: &mut String) {
354        let status = if node.proven { "✓" } else { "✗" };
355        let class = if node.proven { "proven" } else { "failed" };
356
357        html.push_str("<div class=\"node\">\n");
358        html.push_str(&format!("  <span class=\"{}\">{} {}</span>",
359            class, status, node.goal));
360
361        if let Some(rule) = &node.rule_name {
362            html.push_str(&format!(" <span class=\"rule\">[Rule: {}]</span>", rule));
363        }
364
365        match node.node_type {
366            ProofNodeType::Fact => html.push_str(" <em>[FACT]</em>"),
367            ProofNodeType::Negation => html.push_str(" <em>[NEGATION]</em>"),
368            _ => {}
369        }
370
371        if !node.bindings.is_empty() {
372            html.push_str(&format!("<br><span class=\"bindings\">Bindings: {:?}</span>",
373                node.bindings));
374        }
375
376        html.push_str("\n");
377
378        for child in &node.children {
379            self.node_to_html(child, html);
380        }
381
382        html.push_str("</div>\n");
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_proof_node_creation() {
        let node = ProofNode::new("test_goal".to_string(), 0);
        assert_eq!(node.goal, "test_goal");
        assert_eq!(node.depth, 0);
        assert!(!node.proven);
        assert_eq!(node.node_type, ProofNodeType::Failed);
    }

    #[test]
    fn test_fact_node() {
        let node = ProofNode::fact("fact_goal".to_string(), 1);
        assert!(node.proven);
        assert_eq!(node.node_type, ProofNodeType::Fact);
        assert!(node.is_leaf());
    }

    #[test]
    fn test_rule_node() {
        let node = ProofNode::rule("rule_goal".to_string(), "test_rule".to_string(), 2);
        assert!(node.proven);
        assert_eq!(node.node_type, ProofNodeType::Rule);
        assert_eq!(node.rule_name, Some("test_rule".to_string()));
    }

    #[test]
    fn test_negation_node() {
        let node = ProofNode::negation("not_goal".to_string(), 1, true);
        assert!(node.proven);
        assert_eq!(node.node_type, ProofNodeType::Negation);
        assert!(node.rule_name.is_none());
        assert!(node.is_leaf());
    }

    #[test]
    fn test_add_child() {
        let mut parent = ProofNode::rule("parent".to_string(), "rule1".to_string(), 0);
        let child = ProofNode::fact("child".to_string(), 1);

        parent.add_child(child);
        assert_eq!(parent.children.len(), 1);
        assert!(!parent.is_leaf());
    }

    #[test]
    fn test_tree_height() {
        let mut root = ProofNode::rule("root".to_string(), "rule1".to_string(), 0);
        let mut child1 = ProofNode::rule("child1".to_string(), "rule2".to_string(), 1);
        let child2 = ProofNode::fact("child2".to_string(), 2);

        child1.add_child(child2);
        root.add_child(child1);

        assert_eq!(root.height(), 3);
    }

    #[test]
    fn test_single_node_height() {
        assert_eq!(ProofNode::fact("leaf".to_string(), 0).height(), 1);
    }

    #[test]
    fn test_node_count() {
        let mut root = ProofNode::rule("root".to_string(), "rule1".to_string(), 0);
        let child1 = ProofNode::fact("child1".to_string(), 1);
        let child2 = ProofNode::fact("child2".to_string(), 1);

        root.add_child(child1);
        root.add_child(child2);

        assert_eq!(root.node_count(), 3);
    }

    #[test]
    fn test_proof_tree_creation() {
        let root = ProofNode::fact("test".to_string(), 0);
        let tree = ProofTree::new(root, "test query".to_string());

        assert!(tree.success);
        assert_eq!(tree.query, "test query");
        assert_eq!(tree.stats.total_nodes, 1);
    }

    #[test]
    fn test_proof_tree_stats_from_root() {
        let mut root = ProofNode::rule("root".to_string(), "rule1".to_string(), 0);
        root.add_child(ProofNode::fact("a".to_string(), 1));
        root.add_child(ProofNode::fact("b".to_string(), 1));
        let tree = ProofTree::new(root, "q".to_string());

        assert_eq!(tree.stats.total_nodes, 3);
        assert_eq!(tree.stats.max_depth, 2);
        assert_eq!(tree.stats.goals_explored, 0);
        assert_eq!(tree.stats.rules_evaluated, 0);
        assert_eq!(tree.stats.facts_checked, 0);
    }

    #[test]
    fn test_failed_root_is_not_success() {
        let tree = ProofTree::new(ProofNode::new("x".to_string(), 0), "q".to_string());

        assert!(!tree.success);
        assert!(tree.to_markdown().contains("✗ Unprovable"));
    }

    #[test]
    fn test_set_stats_replaces_stats() {
        let mut tree = ProofTree::new(ProofNode::fact("t".to_string(), 0), "q".to_string());
        tree.set_stats(ProofStats {
            goals_explored: 4,
            rules_evaluated: 2,
            facts_checked: 3,
            max_depth: 7,
            total_nodes: 9,
        });

        assert_eq!(tree.stats.goals_explored, 4);
        assert_eq!(tree.stats.total_nodes, 9);
        assert!(tree.to_markdown().contains("- **Goals explored:** 4"));
        assert!(tree.to_html().contains("<p>Total nodes: 9</p>"));
    }

    #[test]
    fn test_json_serialization() {
        let root = ProofNode::fact("test".to_string(), 0);
        let tree = ProofTree::new(root, "test query".to_string());

        let json = tree.to_json().unwrap();
        assert!(json.contains("test query"));
        assert!(json.contains("Fact"));
    }

    #[test]
    fn test_markdown_generation() {
        let root = ProofNode::fact("test".to_string(), 0);
        let tree = ProofTree::new(root, "test query".to_string());

        let md = tree.to_markdown();
        assert!(md.contains("# Proof Explanation"));
        assert!(md.contains("test query"));
        assert!(md.contains("✓"));
    }

    #[test]
    fn test_markdown_nested_rule_and_bindings() {
        let mut root = ProofNode::rule("root".to_string(), "rule1".to_string(), 0);
        let mut child = ProofNode::fact("child".to_string(), 1);
        let mut map = std::collections::HashMap::new();
        map.insert("X".to_string(), "1".to_string());
        child.set_bindings_map(map);
        assert_eq!(child.bindings.get("X"), Some(&"1".to_string()));
        root.add_child(child);

        let md = ProofTree::new(root, "q".to_string()).to_markdown();
        assert!(md.contains("* ✓ `root` **[Rule: rule1]**"));
        assert!(md.contains("  * ✓ `child` *[FACT]*"));
        assert!(md.contains("Bindings:"));
        assert!(md.contains("\"X\""));
    }

    #[test]
    fn test_html_generation() {
        let root = ProofNode::fact("test".to_string(), 0);
        let tree = ProofTree::new(root, "test query".to_string());

        let html = tree.to_html();
        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("test query"));
        assert!(html.contains("✓"));
    }
}
490}