Skip to main content

rust_rule_engine/backward/
proof_graph.rs

//! Proof Graph for Incremental Caching and TMS Integration
//!
//! This module provides a global cache of proven facts with dependency tracking,
//! enabling reuse across multiple queries and incremental updates when facts change.
//!
//! Architecture:
//! - ProofGraph: maintains a mapping from FactKeys to proven facts with their justifications
//! - ProofGraphNode: represents a proven fact with its supporting premises and rules
//! - Integration with IncrementalEngine for TMS-aware (truth maintenance system)
//!   retraction propagation
10
11use crate::rete::FactHandle;
12use crate::types::Value;
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
/// Canonical key identifying a fact pattern in the proof graph.
///
/// Equality and hashing cover all three fields, so two keys are the same
/// cache entry only when type, field and pattern text all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FactKey {
    /// Fact type (e.g., "User", "Order")
    pub fact_type: String,

    /// Field name when the pattern targets a specific field
    /// (e.g., "Score" for the pattern "User.Score >= 80")
    pub field: Option<String>,

    /// Pattern text for this key; `from_pattern` stores the complete input
    /// string here (e.g., "User.Score >= 80")
    pub pattern: String,
}

impl FactKey {
    /// Create a new fact key by parsing a pattern string.
    ///
    /// A pattern such as `"User.Score >= 80"` is split at the first `.` into
    /// the fact type (`"User"`, surrounding whitespace trimmed) and the field
    /// name (`"Score"`: the leading run of alphanumeric/`_` characters after
    /// the dot, with whitespace directly after the dot skipped).
    ///
    /// When there is no separator, the whole pattern becomes `fact_type` and
    /// `field` is `None`. A first `.` that sits between two ASCII digits is
    /// treated as a decimal point rather than a separator, so
    /// `"Score >= 80.5"` is handled as a pattern without dot notation.
    ///
    /// The original text is always kept verbatim in `pattern`.
    pub fn from_pattern(pattern: &str) -> Self {
        let dot_pos = match Self::separator_index(pattern) {
            Some(pos) => pos,
            None => {
                // Simple pattern without dot notation
                return Self {
                    fact_type: pattern.to_string(),
                    field: None,
                    pattern: pattern.to_string(),
                };
            }
        };

        let fact_type = pattern[..dot_pos].trim().to_string();

        // Skip whitespace after the dot so "User . Score" yields "Score"
        // instead of an empty field name (the type side is trimmed as well).
        let rest = pattern[dot_pos + 1..].trim_start();

        // The field name ends at the first character that cannot be part of
        // an identifier (typically whitespace or the comparison operator).
        let field_end = rest
            .find(|c: char| !c.is_alphanumeric() && c != '_')
            .unwrap_or(rest.len());

        Self {
            fact_type,
            field: Some(rest[..field_end].to_string()),
            pattern: pattern.to_string(),
        }
    }

    /// Create from explicit components; nothing is parsed or normalized.
    pub fn new(fact_type: String, field: Option<String>, pattern: String) -> Self {
        Self {
            fact_type,
            field,
            pattern,
        }
    }

    /// Byte index of the `.` separating fact type from field, if any.
    ///
    /// Only the first `.` is considered. It is rejected when flanked by ASCII
    /// digits on both sides, because that shape is a numeric literal ("80.5").
    fn separator_index(pattern: &str) -> Option<usize> {
        let dot_pos = pattern.find('.')?;
        let bytes = pattern.as_bytes();

        let digit_before = dot_pos > 0 && bytes[dot_pos - 1].is_ascii_digit();
        let digit_after = bytes
            .get(dot_pos + 1)
            .map_or(false, |b| b.is_ascii_digit());

        if digit_before && digit_after {
            None
        } else {
            Some(dot_pos)
        }
    }
}
69
70/// A justification for a proven fact (one way it was derived)
71#[derive(Debug, Clone)]
72pub struct Justification {
73    /// Rule that produced this fact
74    pub rule_name: String,
75
76    /// Premise fact handles that were used
77    pub premises: Vec<FactHandle>,
78
79    /// Premise keys (for human-readable tracing)
80    pub premise_keys: Vec<String>,
81
82    /// When this justification was created (generation/timestamp)
83    pub generation: u64,
84}
85
86/// A node in the proof graph representing a proven fact
87#[derive(Debug, Clone)]
88pub struct ProofGraphNode {
89    /// Unique key for this fact
90    pub key: FactKey,
91
92    /// Fact handle from IncrementalEngine (if inserted logically)
93    pub handle: Option<FactHandle>,
94
95    /// All justifications (ways this fact was proven)
96    pub justifications: Vec<Justification>,
97
98    /// Dependents (facts that depend on this fact as premise)
99    pub dependents: HashSet<FactHandle>,
100
101    /// Whether this fact is currently valid
102    pub valid: bool,
103
104    /// Generation when last validated
105    pub generation: u64,
106
107    /// Variable bindings (if any) associated with this proof
108    pub bindings: HashMap<String, Value>,
109}
110
111impl ProofGraphNode {
112    /// Create a new proof graph node
113    pub fn new(key: FactKey) -> Self {
114        Self {
115            key,
116            handle: None,
117            justifications: Vec::new(),
118            dependents: HashSet::new(),
119            valid: true,
120            generation: 0,
121            bindings: HashMap::new(),
122        }
123    }
124
125    /// Add a justification
126    pub fn add_justification(
127        &mut self,
128        rule_name: String,
129        premises: Vec<FactHandle>,
130        premise_keys: Vec<String>,
131        generation: u64,
132    ) {
133        self.justifications.push(Justification {
134            rule_name,
135            premises,
136            premise_keys,
137            generation,
138        });
139        self.valid = true;
140        self.generation = generation;
141    }
142
143    /// Check if this node has any valid justifications
144    pub fn has_valid_justifications(&self) -> bool {
145        !self.justifications.is_empty()
146    }
147
148    /// Remove a justification involving a retracted premise
149    pub fn remove_justifications_with_premise(&mut self, premise_handle: &FactHandle) -> bool {
150        let before = self.justifications.len();
151        self.justifications
152            .retain(|j| !j.premises.contains(premise_handle));
153        let after = self.justifications.len();
154
155        // If no justifications left, mark invalid
156        if self.justifications.is_empty() {
157            self.valid = false;
158        }
159
160        before != after
161    }
162}
163
164/// Global proof graph cache
165pub struct ProofGraph {
166    /// Nodes indexed by fact handle
167    nodes_by_handle: HashMap<FactHandle, ProofGraphNode>,
168
169    /// Index from fact key to handles (for pattern lookup)
170    index_by_key: HashMap<FactKey, Vec<FactHandle>>,
171
172    /// Reverse dependency index (premise -> dependents)
173    dependencies: HashMap<FactHandle, HashSet<FactHandle>>,
174
175    /// Generation counter for tracking updates
176    generation: u64,
177
178    /// Statistics
179    pub stats: ProofGraphStats,
180}
181
/// Counters describing how the proof graph has been used.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct ProofGraphStats {
    /// Number of nodes created in the graph
    pub total_nodes: usize,
    /// Key lookups that found at least one valid node
    pub cache_hits: usize,
    /// Key lookups that found no valid node
    pub cache_misses: usize,
    /// Invalidations counted, including those caused by propagation
    pub invalidations: usize,
    /// Justifications recorded through proof insertion
    pub justifications_added: usize,
}
191
192impl ProofGraph {
193    /// Create a new proof graph
194    pub fn new() -> Self {
195        Self {
196            nodes_by_handle: HashMap::new(),
197            index_by_key: HashMap::new(),
198            dependencies: HashMap::new(),
199            generation: 0,
200            stats: ProofGraphStats::default(),
201        }
202    }
203
204    /// Insert a proof into the graph
205    pub fn insert_proof(
206        &mut self,
207        handle: FactHandle,
208        key: FactKey,
209        rule_name: String,
210        premises: Vec<FactHandle>,
211        premise_keys: Vec<String>,
212    ) {
213        self.generation += 1;
214
215        // Get or create node
216        let node = self.nodes_by_handle.entry(handle).or_insert_with(|| {
217            let mut node = ProofGraphNode::new(key.clone());
218            node.handle = Some(handle);
219            self.stats.total_nodes += 1;
220            node
221        });
222
223        // Add justification
224        node.add_justification(rule_name, premises.clone(), premise_keys, self.generation);
225        self.stats.justifications_added += 1;
226
227        // Update key index
228        self.index_by_key
229            .entry(key.clone())
230            .or_default()
231            .push(handle);
232
233        // Update dependency edges
234        for premise in &premises {
235            self.dependencies
236                .entry(*premise)
237                .or_default()
238                .insert(handle);
239
240            // Also update the premise node's dependents
241            if let Some(premise_node) = self.nodes_by_handle.get_mut(premise) {
242                premise_node.dependents.insert(handle);
243            }
244        }
245    }
246
247    /// Lookup proven facts by key pattern
248    pub fn lookup_by_key(&mut self, key: &FactKey) -> Option<Vec<&ProofGraphNode>> {
249        if let Some(handles) = self.index_by_key.get(key) {
250            let nodes: Vec<&ProofGraphNode> = handles
251                .iter()
252                .filter_map(|h| self.nodes_by_handle.get(h))
253                .filter(|n| n.valid)
254                .collect();
255
256            if !nodes.is_empty() {
257                self.stats.cache_hits += 1;
258                Some(nodes)
259            } else {
260                self.stats.cache_misses += 1;
261                None
262            }
263        } else {
264            self.stats.cache_misses += 1;
265            None
266        }
267    }
268
269    /// Check if a fact key has been proven
270    pub fn is_proven(&mut self, key: &FactKey) -> bool {
271        self.lookup_by_key(key).is_some()
272    }
273
274    /// Invalidate a fact handle (e.g., when retracted by TMS)
275    pub fn invalidate_handle(&mut self, handle: &FactHandle) {
276        self.stats.invalidations += 1;
277
278        // Get dependents before removing
279        let dependents = self.dependencies.get(handle).cloned();
280
281        // Mark node invalid
282        if let Some(node) = self.nodes_by_handle.get_mut(handle) {
283            node.valid = false;
284        }
285
286        // Propagate to dependents
287        if let Some(deps) = dependents {
288            for dep_handle in deps {
289                self.propagate_invalidation(&dep_handle, handle);
290            }
291        }
292    }
293
294    /// Propagate invalidation to a dependent fact
295    fn propagate_invalidation(
296        &mut self,
297        dependent_handle: &FactHandle,
298        premise_handle: &FactHandle,
299    ) {
300        if let Some(node) = self.nodes_by_handle.get_mut(dependent_handle) {
301            // Remove justifications that depend on this premise
302            let changed = node.remove_justifications_with_premise(premise_handle);
303
304            // If node became invalid (no justifications left), propagate further
305            if changed && !node.valid {
306                self.stats.invalidations += 1;
307
308                // Get dependents and propagate recursively
309                let further_deps = node.dependents.clone();
310                for further_dep in further_deps {
311                    self.propagate_invalidation(&further_dep, dependent_handle);
312                }
313            }
314        }
315    }
316
317    /// Get a node by handle
318    pub fn get_node(&self, handle: &FactHandle) -> Option<&ProofGraphNode> {
319        self.nodes_by_handle.get(handle)
320    }
321
322    /// Clear all cached proofs (reset graph)
323    pub fn clear(&mut self) {
324        self.nodes_by_handle.clear();
325        self.index_by_key.clear();
326        self.dependencies.clear();
327        self.generation = 0;
328        self.stats = ProofGraphStats::default();
329    }
330
331    /// Get current generation counter
332    pub fn generation(&self) -> u64 {
333        self.generation
334    }
335
336    /// Print statistics
337    pub fn print_stats(&self) {
338        println!("ProofGraph Statistics:");
339        println!("  Total nodes: {}", self.stats.total_nodes);
340        println!("  Cache hits: {}", self.stats.cache_hits);
341        println!("  Cache misses: {}", self.stats.cache_misses);
342        println!("  Invalidations: {}", self.stats.invalidations);
343        println!(
344            "  Justifications added: {}",
345            self.stats.justifications_added
346        );
347
348        if self.stats.cache_hits + self.stats.cache_misses > 0 {
349            let hit_rate = (self.stats.cache_hits as f64)
350                / ((self.stats.cache_hits + self.stats.cache_misses) as f64)
351                * 100.0;
352            println!("  Cache hit rate: {:.1}%", hit_rate);
353        }
354    }
355}
356
357impl Default for ProofGraph {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
363/// Thread-safe wrapper for ProofGraph
364pub type SharedProofGraph = Arc<std::sync::Mutex<ProofGraph>>;
365
366/// Create a new shared proof graph
367pub fn new_shared() -> SharedProofGraph {
368    Arc::new(std::sync::Mutex::new(ProofGraph::new()))
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert a proof that has no premises (an axiom-like fact).
    fn insert_axiom(graph: &mut ProofGraph, handle: FactHandle, key: &FactKey, rule: &str) {
        graph.insert_proof(handle, key.clone(), rule.to_string(), Vec::new(), Vec::new());
    }

    #[test]
    fn test_fact_key_from_pattern() {
        let parsed = FactKey::from_pattern("User.Score >= 80");

        assert_eq!(parsed.fact_type, "User");
        assert_eq!(parsed.field, Some("Score".to_string()));
        assert_eq!(parsed.pattern, "User.Score >= 80");
    }

    #[test]
    fn test_proof_graph_insert_and_lookup() {
        let mut graph = ProofGraph::new();
        let key = FactKey::from_pattern("User.Score >= 80");

        insert_axiom(&mut graph, FactHandle::new(1), &key, "ScoreRule");

        assert!(graph.is_proven(&key));
        assert_eq!(graph.stats.total_nodes, 1);
    }

    #[test]
    fn test_dependency_tracking() {
        let mut graph = ProofGraph::new();

        let premise = FactHandle::new(1);
        let premise_key = FactKey::from_pattern("User.Age >= 18");
        let conclusion = FactHandle::new(2);
        let conclusion_key = FactKey::from_pattern("User.CanVote == true");

        // The premise is proven without any supporting facts.
        insert_axiom(&mut graph, premise, &premise_key, "AgeRule");

        // The conclusion is derived from the premise.
        graph.insert_proof(
            conclusion,
            conclusion_key.clone(),
            "VotingRule".to_string(),
            vec![premise],
            vec!["User.Age >= 18".to_string()],
        );

        assert!(graph.is_proven(&premise_key));
        assert!(graph.is_proven(&conclusion_key));

        // Retracting the premise must take the conclusion down with it.
        graph.invalidate_handle(&premise);

        let conclusion_node = graph.get_node(&conclusion).unwrap();
        assert!(!conclusion_node.valid);
        assert_eq!(graph.stats.invalidations, 2); // premise + dependent
    }

    #[test]
    fn test_multiple_justifications() {
        let mut graph = ProofGraph::new();
        let vip = FactHandle::new(1);
        let vip_key = FactKey::from_pattern("User.IsVIP == true");

        // Two independent rules prove the same fact.
        insert_axiom(&mut graph, vip, &vip_key, "HighSpenderRule");
        insert_axiom(&mut graph, vip, &vip_key, "LoyaltyRule");

        let node = graph.get_node(&vip).unwrap();
        assert_eq!(node.justifications.len(), 2);
        assert!(node.valid);
    }

    #[test]
    fn test_cache_statistics() {
        let mut graph = ProofGraph::new();
        let key = FactKey::from_pattern("User.Active == true");

        // Nothing inserted yet: the lookup is a miss.
        assert!(!graph.is_proven(&key));
        assert_eq!(graph.stats.cache_misses, 1);

        insert_axiom(&mut graph, FactHandle::new(1), &key, "ActiveRule");

        // Now the same lookup is a hit.
        assert!(graph.is_proven(&key));
        assert_eq!(graph.stats.cache_hits, 1);
    }

    #[test]
    fn test_clear() {
        let mut graph = ProofGraph::new();
        let key = FactKey::from_pattern("Test.Value == 42");

        insert_axiom(&mut graph, FactHandle::new(1), &key, "TestRule");
        assert!(graph.is_proven(&key));

        graph.clear();

        assert!(!graph.is_proven(&key));
        assert_eq!(graph.stats.total_nodes, 0);
    }
}