rust_rule_engine/rete/working_memory.rs

1//! Working Memory for RETE-UL (Drools-style)
2//!
3//! This module implements a Working Memory system similar to Drools, providing:
4//! - FactHandle for tracking inserted objects
5//! - Insert, update, retract operations
6//! - Type indexing for fast lookups
7//! - Change tracking for incremental updates
8
9use std::collections::{HashMap, HashSet};
10use std::sync::atomic::{AtomicU64, Ordering};
11use super::facts::{FactValue, TypedFacts};
12
/// Opaque identifier of a fact stored in working memory (the analogue of a
/// Drools `FactHandle`). Handles are cheap to copy and usable as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FactHandle(u64);

impl FactHandle {
    /// Wrap a raw numeric id in a handle.
    pub fn new(id: u64) -> Self {
        FactHandle(id)
    }

    /// Return the raw numeric id wrapped by this handle.
    pub fn id(&self) -> u64 {
        let FactHandle(raw) = *self;
        raw
    }
}

impl std::fmt::Display for FactHandle {
    /// Renders as `FactHandle(<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.id();
        write!(f, "FactHandle({})", raw)
    }
}
34
35/// A fact stored in working memory
36#[derive(Debug, Clone)]
37pub struct WorkingMemoryFact {
38    /// The fact handle
39    pub handle: FactHandle,
40    /// Fact type (e.g., "Person", "Order")
41    pub fact_type: String,
42    /// The actual fact data
43    pub data: TypedFacts,
44    /// Metadata
45    pub metadata: FactMetadata,
46}
47
/// Bookkeeping attached to every fact held in working memory.
#[derive(Clone, Debug)]
pub struct FactMetadata {
    /// Instant at which the fact was inserted.
    pub inserted_at: std::time::Instant,
    /// Instant of the most recent update (equal to `inserted_at` for a fresh fact).
    pub updated_at: std::time::Instant,
    /// How many times the fact has been updated.
    pub update_count: usize,
    /// Whether the fact has been retracted.
    pub retracted: bool,
}

impl Default for FactMetadata {
    /// Metadata for a freshly created fact: both timestamps set to the same
    /// current instant, zero updates, not retracted.
    fn default() -> Self {
        let created = std::time::Instant::now();
        FactMetadata {
            retracted: false,
            update_count: 0,
            updated_at: created,
            inserted_at: created,
        }
    }
}
72
73/// Working Memory - stores and manages facts (Drools-style)
74pub struct WorkingMemory {
75    /// All facts by handle
76    facts: HashMap<FactHandle, WorkingMemoryFact>,
77    /// Type index: fact_type -> set of handles
78    type_index: HashMap<String, HashSet<FactHandle>>,
79    /// Next fact ID
80    next_id: AtomicU64,
81    /// Modified handles since last propagation
82    modified_handles: HashSet<FactHandle>,
83    /// Retracted handles since last propagation
84    retracted_handles: HashSet<FactHandle>,
85}
86
87impl WorkingMemory {
88    /// Create a new empty working memory
89    pub fn new() -> Self {
90        Self {
91            facts: HashMap::new(),
92            type_index: HashMap::new(),
93            next_id: AtomicU64::new(1),
94            modified_handles: HashSet::new(),
95            retracted_handles: HashSet::new(),
96        }
97    }
98
99    /// Insert a fact into working memory (returns FactHandle)
100    pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
101        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
102        let handle = FactHandle::new(id);
103
104        let fact = WorkingMemoryFact {
105            handle,
106            fact_type: fact_type.clone(),
107            data,
108            metadata: FactMetadata::default(),
109        };
110
111        self.facts.insert(handle, fact);
112        self.type_index
113            .entry(fact_type)
114            .or_insert_with(HashSet::new)
115            .insert(handle);
116        self.modified_handles.insert(handle);
117
118        handle
119    }
120
121    /// Update a fact in working memory
122    pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
123        let fact = self.facts.get_mut(&handle)
124            .ok_or_else(|| format!("FactHandle {} not found", handle))?;
125
126        if fact.metadata.retracted {
127            return Err(format!("FactHandle {} is retracted", handle));
128        }
129
130        fact.data = data;
131        fact.metadata.updated_at = std::time::Instant::now();
132        fact.metadata.update_count += 1;
133        self.modified_handles.insert(handle);
134
135        Ok(())
136    }
137
138    /// Retract (delete) a fact from working memory
139    pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
140        let fact = self.facts.get_mut(&handle)
141            .ok_or_else(|| format!("FactHandle {} not found", handle))?;
142
143        if fact.metadata.retracted {
144            return Err(format!("FactHandle {} already retracted", handle));
145        }
146
147        fact.metadata.retracted = true;
148        self.retracted_handles.insert(handle);
149
150        // Remove from type index
151        if let Some(handles) = self.type_index.get_mut(&fact.fact_type) {
152            handles.remove(&handle);
153        }
154
155        Ok(())
156    }
157
158    /// Get a fact by handle
159    pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
160        self.facts.get(handle).filter(|f| !f.metadata.retracted)
161    }
162
163    /// Get all facts of a specific type
164    pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
165        if let Some(handles) = self.type_index.get(fact_type) {
166            handles
167                .iter()
168                .filter_map(|h| self.facts.get(h))
169                .filter(|f| !f.metadata.retracted)
170                .collect()
171        } else {
172            Vec::new()
173        }
174    }
175
176    /// Get all facts
177    pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
178        self.facts
179            .values()
180            .filter(|f| !f.metadata.retracted)
181            .collect()
182    }
183
184    /// Get all fact handles
185    pub fn get_all_handles(&self) -> Vec<FactHandle> {
186        self.facts
187            .values()
188            .filter(|f| !f.metadata.retracted)
189            .map(|f| f.handle)
190            .collect()
191    }
192
193    /// Get modified handles since last clear
194    pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
195        &self.modified_handles
196    }
197
198    /// Get retracted handles since last clear
199    pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
200        &self.retracted_handles
201    }
202
203    /// Clear modification tracking (after propagation)
204    pub fn clear_modification_tracking(&mut self) {
205        self.modified_handles.clear();
206        self.retracted_handles.clear();
207    }
208
209    /// Get statistics
210    pub fn stats(&self) -> WorkingMemoryStats {
211        let active_facts = self.facts.values().filter(|f| !f.metadata.retracted).count();
212        let retracted_facts = self.facts.values().filter(|f| f.metadata.retracted).count();
213
214        WorkingMemoryStats {
215            total_facts: self.facts.len(),
216            active_facts,
217            retracted_facts,
218            types: self.type_index.len(),
219            modified_pending: self.modified_handles.len(),
220            retracted_pending: self.retracted_handles.len(),
221        }
222    }
223
224    /// Clear all facts
225    pub fn clear(&mut self) {
226        self.facts.clear();
227        self.type_index.clear();
228        self.modified_handles.clear();
229        self.retracted_handles.clear();
230    }
231
232    /// Flatten all facts into a single TypedFacts for evaluation
233    /// Each fact's fields are prefixed with "type.handle."
234    pub fn to_typed_facts(&self) -> TypedFacts {
235        let mut result = TypedFacts::new();
236
237        for fact in self.get_all_facts() {
238            let prefix = format!("{}.{}", fact.fact_type, fact.handle.id());
239            for (key, value) in fact.data.get_all() {
240                result.set(format!("{}.{}", prefix, key), value.clone());
241            }
242
243            // Also add without prefix for simple access (last fact of this type wins)
244            for (key, value) in fact.data.get_all() {
245                result.set(format!("{}.{}", fact.fact_type, key), value.clone());
246            }
247            
248            // Store handle for this fact type (last fact wins, but better than nothing)
249            // Actions can use Type._handle to get the handle
250            result.set_fact_handle(fact.fact_type.clone(), fact.handle);
251        }
252
253        result
254    }
255}
256
257impl Default for WorkingMemory {
258    fn default() -> Self {
259        Self::new()
260    }
261}
262
/// Snapshot of working-memory counters, as produced by `WorkingMemory::stats`.
#[derive(Debug, Clone)]
pub struct WorkingMemoryStats {
    /// Number of stored facts, retracted ones included.
    pub total_facts: usize,
    /// Number of facts that are not retracted.
    pub active_facts: usize,
    /// Number of facts flagged as retracted but still stored.
    pub retracted_facts: usize,
    /// Number of entries in the type index.
    pub types: usize,
    /// Handles inserted or updated since tracking was last cleared.
    pub modified_pending: usize,
    /// Handles retracted since tracking was last cleared.
    pub retracted_pending: usize,
}

impl std::fmt::Display for WorkingMemoryStats {
    /// One-line human-readable summary (`total_facts` is not printed).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WorkingMemoryStats {
            active_facts,
            retracted_facts,
            types,
            modified_pending,
            retracted_pending,
            ..
        } = self;
        write!(
            f,
            "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
            active_facts, retracted_facts, types, modified_pending, retracted_pending
        )
    }
}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get() {
        let mut wm = WorkingMemory::new();
        let mut person_data = TypedFacts::new();
        person_data.set("name", "John");
        person_data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), person_data);

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.fact_type, "Person");
        assert_eq!(fact.data.get("name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_update() {
        let mut wm = WorkingMemory::new();
        let mut data = TypedFacts::new();
        data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), data);

        let mut updated_data = TypedFacts::new();
        updated_data.set("age", 26i64);
        wm.update(handle, updated_data).unwrap();

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(fact.metadata.update_count, 1);
    }

    #[test]
    fn test_retract() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let handle = wm.insert("Person".to_string(), data);

        wm.retract(handle).unwrap();

        assert!(wm.get(&handle).is_none());
        assert_eq!(wm.get_all_facts().len(), 0);
    }

    #[test]
    fn test_type_index() {
        let mut wm = WorkingMemory::new();

        for i in 0..5 {
            let mut data = TypedFacts::new();
            data.set("id", i as i64);
            wm.insert("Person".to_string(), data);
        }

        for i in 0..3 {
            let mut data = TypedFacts::new();
            data.set("id", i as i64);
            wm.insert("Order".to_string(), data);
        }

        assert_eq!(wm.get_by_type("Person").len(), 5);
        assert_eq!(wm.get_by_type("Order").len(), 3);
        assert_eq!(wm.get_by_type("Unknown").len(), 0);
    }

    #[test]
    fn test_modification_tracking() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let h1 = wm.insert("Person".to_string(), data.clone());
        let h2 = wm.insert("Person".to_string(), data.clone());

        assert_eq!(wm.get_modified_handles().len(), 2);

        wm.clear_modification_tracking();
        assert_eq!(wm.get_modified_handles().len(), 0);

        wm.update(h1, data.clone()).unwrap();
        assert_eq!(wm.get_modified_handles().len(), 1);

        wm.retract(h2).unwrap();
        assert_eq!(wm.get_retracted_handles().len(), 1);
    }

    #[test]
    fn test_unknown_handle_is_rejected() {
        let mut wm = WorkingMemory::new();
        let missing = FactHandle::new(999);

        assert!(wm.get(&missing).is_none());
        assert!(wm.update(missing, TypedFacts::new()).is_err());
        assert!(wm.retract(missing).is_err());
    }

    #[test]
    fn test_retracted_fact_cannot_be_updated_or_retracted_again() {
        let mut wm = WorkingMemory::new();
        let handle = wm.insert("Person".to_string(), TypedFacts::new());
        wm.retract(handle).unwrap();

        assert!(wm.update(handle, TypedFacts::new()).is_err());
        assert!(wm.retract(handle).is_err());
    }

    #[test]
    fn test_retract_removes_from_type_index() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        let h2 = wm.insert("Person".to_string(), TypedFacts::new());
        wm.retract(h1).unwrap();

        let remaining = wm.get_by_type("Person");
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].handle, h2);
        assert_eq!(wm.get_all_handles(), vec![h2]);
    }

    #[test]
    fn test_stats_and_clear() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        wm.insert("Order".to_string(), TypedFacts::new());
        wm.retract(h1).unwrap();

        let stats = wm.stats();
        assert_eq!(stats.total_facts, 2);
        assert_eq!(stats.active_facts, 1);
        assert_eq!(stats.retracted_facts, 1);
        assert_eq!(stats.types, 2);
        assert_eq!(stats.modified_pending, 2);
        assert_eq!(stats.retracted_pending, 1);

        wm.clear();
        let stats = wm.stats();
        assert_eq!(stats.total_facts, 0);
        assert_eq!(stats.active_facts, 0);
        assert_eq!(stats.types, 0);
        assert_eq!(stats.modified_pending, 0);
        assert_eq!(stats.retracted_pending, 0);
    }

    #[test]
    fn test_handles_stay_unique_after_clear() {
        let mut wm = WorkingMemory::new();
        let before = wm.insert("Person".to_string(), TypedFacts::new());
        wm.clear();
        let after = wm.insert("Person".to_string(), TypedFacts::new());

        assert_ne!(before, after);
    }
}