rust_rule_engine/backward/
proof_graph.rs1use crate::rete::FactHandle;
12use crate::types::Value;
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
/// Identifies a proven fact by its type, optional field, and the raw pattern it came from.
///
/// Equality and hashing cover all three components, so `pattern` is compared verbatim
/// (no whitespace normalisation).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FactKey {
    /// Fact type name, e.g. `User` in `User.Score >= 80`.
    pub fact_type: String,

    /// Field referenced after the dot, e.g. `Score` in `User.Score >= 80`, if any.
    pub field: Option<String>,

    /// The original, unmodified pattern text.
    pub pattern: String,
}

impl FactKey {
    /// Builds a key from a condition pattern such as `User.Score >= 80`.
    ///
    /// The `Type.field` reference is only searched for in the leading part of the pattern:
    /// everything before the first character that is not an identifier character
    /// (alphanumeric or `_`), a dot, or whitespace — typically a comparison operator or a
    /// quote. This keeps a dot inside a value (e.g. `Score > 1.5`) from being mistaken for
    /// the type/field separator.
    ///
    /// * `Type.field ...` → `fact_type = "Type"` (trimmed), `field = Some("field")`.
    /// * `Type.a.b ...` → `field` is the first segment, `Some("a")`.
    /// * No separating dot → `fact_type` is the whole pattern and `field` is `None`.
    /// * A dot not followed by an identifier (e.g. `User.`) → `field` is `None`.
    pub fn from_pattern(pattern: &str) -> Self {
        let is_ident = |c: char| c.is_alphanumeric() || c == '_';

        // Leading run that may contain the `Type.field` reference.
        let head_end = pattern
            .find(|c: char| !is_ident(c) && c != '.' && !c.is_whitespace())
            .unwrap_or(pattern.len());

        match pattern[..head_end].split_once('.') {
            Some((type_part, rest)) => {
                // Tolerate whitespace between the dot and the field name.
                let rest = rest.trim_start();
                let field_end = rest.find(|c: char| !is_ident(c)).unwrap_or(rest.len());
                let field = &rest[..field_end];

                Self {
                    fact_type: type_part.trim().to_string(),
                    field: (!field.is_empty()).then(|| field.to_string()),
                    pattern: pattern.to_string(),
                }
            }
            None => Self {
                fact_type: pattern.to_string(),
                field: None,
                pattern: pattern.to_string(),
            },
        }
    }

    /// Builds a key from already-separated parts; no parsing or validation is performed.
    pub fn new(fact_type: String, field: Option<String>, pattern: String) -> Self {
        Self {
            fact_type,
            field,
            pattern,
        }
    }
}
69
/// One way a fact was derived: the rule that fired and the premises it relied on.
#[derive(Debug, Clone)]
pub struct Justification {
    /// Name of the rule that produced this derivation.
    pub rule_name: String,

    /// Handles of the facts this derivation relied on; the graph drops the justification
    /// when one of these premises is invalidated.
    pub premises: Vec<FactHandle>,

    /// Textual form of the premises — presumably their pattern strings (the tests pass
    /// e.g. `"User.Age >= 18"`); stored but not interpreted by this module.
    pub premise_keys: Vec<String>,

    /// Graph generation at which this justification was recorded.
    pub generation: u64,
}
85
86#[derive(Debug, Clone)]
88pub struct ProofGraphNode {
89 pub key: FactKey,
91
92 pub handle: Option<FactHandle>,
94
95 pub justifications: Vec<Justification>,
97
98 pub dependents: HashSet<FactHandle>,
100
101 pub valid: bool,
103
104 pub generation: u64,
106
107 pub bindings: HashMap<String, Value>,
109}
110
111impl ProofGraphNode {
112 pub fn new(key: FactKey) -> Self {
114 Self {
115 key,
116 handle: None,
117 justifications: Vec::new(),
118 dependents: HashSet::new(),
119 valid: true,
120 generation: 0,
121 bindings: HashMap::new(),
122 }
123 }
124
125 pub fn add_justification(
127 &mut self,
128 rule_name: String,
129 premises: Vec<FactHandle>,
130 premise_keys: Vec<String>,
131 generation: u64,
132 ) {
133 self.justifications.push(Justification {
134 rule_name,
135 premises,
136 premise_keys,
137 generation,
138 });
139 self.valid = true;
140 self.generation = generation;
141 }
142
143 pub fn has_valid_justifications(&self) -> bool {
145 !self.justifications.is_empty()
146 }
147
148 pub fn remove_justifications_with_premise(&mut self, premise_handle: &FactHandle) -> bool {
150 let before = self.justifications.len();
151 self.justifications
152 .retain(|j| !j.premises.contains(premise_handle));
153 let after = self.justifications.len();
154
155 if self.justifications.is_empty() {
157 self.valid = false;
158 }
159
160 before != after
161 }
162}
163
/// Cache of proven facts and the justifications behind them, with dependency tracking so
/// that invalidating a fact retracts the conclusions derived from it.
pub struct ProofGraph {
    /// One node per proven fact handle.
    nodes_by_handle: HashMap<FactHandle, ProofGraphNode>,

    /// Fact key → handles proven under that key; consulted by `lookup_by_key`.
    index_by_key: HashMap<FactKey, Vec<FactHandle>>,

    /// Premise handle → handles of facts that cite it in a justification.
    dependencies: HashMap<FactHandle, HashSet<FactHandle>>,

    /// Counter incremented on every `insert_proof`; stamped on each new justification.
    generation: u64,

    /// Usage counters (lookups, invalidations, insertions).
    pub stats: ProofGraphStats,
}
181
/// Counters describing how a proof graph has been used.
#[derive(Debug, Clone, Default)]
pub struct ProofGraphStats {
    /// Number of distinct nodes created by `insert_proof` (reset by `clear`).
    pub total_nodes: usize,
    /// `lookup_by_key` calls that found at least one valid node.
    pub cache_hits: usize,
    /// `lookup_by_key` calls that found no valid node.
    pub cache_misses: usize,
    /// Explicit `invalidate_handle` calls plus dependents invalidated by propagation.
    pub invalidations: usize,
    /// Total number of justifications recorded through `insert_proof`.
    pub justifications_added: usize,
}
191
192impl ProofGraph {
193 pub fn new() -> Self {
195 Self {
196 nodes_by_handle: HashMap::new(),
197 index_by_key: HashMap::new(),
198 dependencies: HashMap::new(),
199 generation: 0,
200 stats: ProofGraphStats::default(),
201 }
202 }
203
204 pub fn insert_proof(
206 &mut self,
207 handle: FactHandle,
208 key: FactKey,
209 rule_name: String,
210 premises: Vec<FactHandle>,
211 premise_keys: Vec<String>,
212 ) {
213 self.generation += 1;
214
215 let node = self.nodes_by_handle.entry(handle).or_insert_with(|| {
217 let mut node = ProofGraphNode::new(key.clone());
218 node.handle = Some(handle);
219 self.stats.total_nodes += 1;
220 node
221 });
222
223 node.add_justification(rule_name, premises.clone(), premise_keys, self.generation);
225 self.stats.justifications_added += 1;
226
227 self.index_by_key
229 .entry(key.clone())
230 .or_default()
231 .push(handle);
232
233 for premise in &premises {
235 self.dependencies
236 .entry(*premise)
237 .or_default()
238 .insert(handle);
239
240 if let Some(premise_node) = self.nodes_by_handle.get_mut(premise) {
242 premise_node.dependents.insert(handle);
243 }
244 }
245 }
246
247 pub fn lookup_by_key(&mut self, key: &FactKey) -> Option<Vec<&ProofGraphNode>> {
249 if let Some(handles) = self.index_by_key.get(key) {
250 let nodes: Vec<&ProofGraphNode> = handles
251 .iter()
252 .filter_map(|h| self.nodes_by_handle.get(h))
253 .filter(|n| n.valid)
254 .collect();
255
256 if !nodes.is_empty() {
257 self.stats.cache_hits += 1;
258 Some(nodes)
259 } else {
260 self.stats.cache_misses += 1;
261 None
262 }
263 } else {
264 self.stats.cache_misses += 1;
265 None
266 }
267 }
268
269 pub fn is_proven(&mut self, key: &FactKey) -> bool {
271 self.lookup_by_key(key).is_some()
272 }
273
274 pub fn invalidate_handle(&mut self, handle: &FactHandle) {
276 self.stats.invalidations += 1;
277
278 let dependents = self.dependencies.get(handle).cloned();
280
281 if let Some(node) = self.nodes_by_handle.get_mut(handle) {
283 node.valid = false;
284 }
285
286 if let Some(deps) = dependents {
288 for dep_handle in deps {
289 self.propagate_invalidation(&dep_handle, handle);
290 }
291 }
292 }
293
294 fn propagate_invalidation(
296 &mut self,
297 dependent_handle: &FactHandle,
298 premise_handle: &FactHandle,
299 ) {
300 if let Some(node) = self.nodes_by_handle.get_mut(dependent_handle) {
301 let changed = node.remove_justifications_with_premise(premise_handle);
303
304 if changed && !node.valid {
306 self.stats.invalidations += 1;
307
308 let further_deps = node.dependents.clone();
310 for further_dep in further_deps {
311 self.propagate_invalidation(&further_dep, dependent_handle);
312 }
313 }
314 }
315 }
316
317 pub fn get_node(&self, handle: &FactHandle) -> Option<&ProofGraphNode> {
319 self.nodes_by_handle.get(handle)
320 }
321
322 pub fn clear(&mut self) {
324 self.nodes_by_handle.clear();
325 self.index_by_key.clear();
326 self.dependencies.clear();
327 self.generation = 0;
328 self.stats = ProofGraphStats::default();
329 }
330
331 pub fn generation(&self) -> u64 {
333 self.generation
334 }
335
336 pub fn print_stats(&self) {
338 println!("ProofGraph Statistics:");
339 println!(" Total nodes: {}", self.stats.total_nodes);
340 println!(" Cache hits: {}", self.stats.cache_hits);
341 println!(" Cache misses: {}", self.stats.cache_misses);
342 println!(" Invalidations: {}", self.stats.invalidations);
343 println!(
344 " Justifications added: {}",
345 self.stats.justifications_added
346 );
347
348 if self.stats.cache_hits + self.stats.cache_misses > 0 {
349 let hit_rate = (self.stats.cache_hits as f64)
350 / ((self.stats.cache_hits + self.stats.cache_misses) as f64)
351 * 100.0;
352 println!(" Cache hit rate: {:.1}%", hit_rate);
353 }
354 }
355}
356
357impl Default for ProofGraph {
358 fn default() -> Self {
359 Self::new()
360 }
361}
362
363pub type SharedProofGraph = Arc<std::sync::Mutex<ProofGraph>>;
365
366pub fn new_shared() -> SharedProofGraph {
368 Arc::new(std::sync::Mutex::new(ProofGraph::new()))
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `key` under `handle` as proven by `rule` without any premises.
    fn prove_axiom(graph: &mut ProofGraph, handle: FactHandle, key: &FactKey, rule: &str) {
        graph.insert_proof(handle, key.clone(), rule.to_string(), Vec::new(), Vec::new());
    }

    #[test]
    fn test_fact_key_from_pattern() {
        let pattern = "User.Score >= 80";
        let key = FactKey::from_pattern(pattern);

        assert_eq!(key.fact_type, "User");
        assert_eq!(key.field.as_deref(), Some("Score"));
        assert_eq!(key.pattern, pattern);
    }

    #[test]
    fn test_proof_graph_insert_and_lookup() {
        let mut graph = ProofGraph::new();
        let key = FactKey::from_pattern("User.Score >= 80");

        prove_axiom(&mut graph, FactHandle::new(1), &key, "ScoreRule");

        assert!(graph.is_proven(&key));
        assert_eq!(graph.stats.total_nodes, 1);
    }

    #[test]
    fn test_dependency_tracking() {
        let mut graph = ProofGraph::new();
        let (premise_handle, conclusion_handle) = (FactHandle::new(1), FactHandle::new(2));

        let premise_pattern = "User.Age >= 18";
        let premise_key = FactKey::from_pattern(premise_pattern);
        let conclusion_key = FactKey::from_pattern("User.CanVote == true");

        prove_axiom(&mut graph, premise_handle, &premise_key, "AgeRule");
        graph.insert_proof(
            conclusion_handle,
            conclusion_key.clone(),
            "VotingRule".to_string(),
            vec![premise_handle],
            vec![premise_pattern.to_string()],
        );

        assert!(graph.is_proven(&premise_key));
        assert!(graph.is_proven(&conclusion_key));

        graph.invalidate_handle(&premise_handle);

        let conclusion = graph
            .get_node(&conclusion_handle)
            .expect("conclusion node should exist");
        assert!(!conclusion.valid);
        // One invalidation for the premise itself, one propagated to the conclusion.
        assert_eq!(graph.stats.invalidations, 2);
    }

    #[test]
    fn test_multiple_justifications() {
        let mut graph = ProofGraph::new();
        let handle = FactHandle::new(1);
        let key = FactKey::from_pattern("User.IsVIP == true");

        for rule in ["HighSpenderRule", "LoyaltyRule"].iter() {
            prove_axiom(&mut graph, handle, &key, rule);
        }

        let node = graph.get_node(&handle).expect("node should exist");
        assert_eq!(node.justifications.len(), 2);
        assert!(node.valid);
    }

    #[test]
    fn test_cache_statistics() {
        let mut graph = ProofGraph::new();
        let key = FactKey::from_pattern("User.Active == true");

        // Lookup before anything is proven is a miss.
        assert!(!graph.is_proven(&key));
        assert_eq!(graph.stats.cache_misses, 1);

        prove_axiom(&mut graph, FactHandle::new(1), &key, "ActiveRule");

        // Lookup after proving is a hit.
        assert!(graph.is_proven(&key));
        assert_eq!(graph.stats.cache_hits, 1);
    }

    #[test]
    fn test_clear() {
        let mut graph = ProofGraph::new();
        let key = FactKey::from_pattern("Test.Value == 42");

        prove_axiom(&mut graph, FactHandle::new(1), &key, "TestRule");
        assert!(graph.is_proven(&key));

        graph.clear();

        assert!(!graph.is_proven(&key));
        assert_eq!(graph.stats.total_nodes, 0);
    }
}