// rust_rule_engine/rete/working_memory.rs

//! Working Memory for RETE-UL (Drools-style)
//!
//! This module implements a Working Memory system similar to Drools, providing:
//! - FactHandle for tracking inserted objects
//! - Insert, update, retract operations
//! - Type indexing for fast lookups
//! - Change tracking for incremental updates

9use std::collections::{HashMap, HashSet};
10use std::sync::atomic::{AtomicU64, Ordering};
11use super::facts::{FactValue, TypedFacts};
12
/// Copyable identifier for a fact held in working memory (the counterpart of a
/// Drools `FactHandle`). Two handles are equal exactly when their ids are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FactHandle(u64);

impl FactHandle {
    /// Wraps a raw id. Kept private so handles are only minted inside this module.
    fn new(id: u64) -> Self {
        FactHandle(id)
    }

    /// Returns the numeric id backing this handle.
    pub fn id(&self) -> u64 {
        let FactHandle(raw) = *self;
        raw
    }
}

impl std::fmt::Display for FactHandle {
    /// Renders the handle as `FactHandle(<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "FactHandle({})", self.id())
    }
}
34
35/// A fact stored in working memory
36#[derive(Debug, Clone)]
37pub struct WorkingMemoryFact {
38    /// The fact handle
39    pub handle: FactHandle,
40    /// Fact type (e.g., "Person", "Order")
41    pub fact_type: String,
42    /// The actual fact data
43    pub data: TypedFacts,
44    /// Metadata
45    pub metadata: FactMetadata,
46}
47
/// Bookkeeping attached to every fact stored in working memory.
#[derive(Debug, Clone)]
pub struct FactMetadata {
    /// Instant at which the fact was inserted.
    pub inserted_at: std::time::Instant,
    /// Instant of the most recent update; equal to `inserted_at` for a default value.
    pub updated_at: std::time::Instant,
    /// Number of updates applied to the fact.
    pub update_count: usize,
    /// Whether the fact has been retracted.
    pub retracted: bool,
}

impl Default for FactMetadata {
    /// Fresh metadata: both timestamps share one `Instant::now()` reading,
    /// no updates recorded, and not retracted.
    fn default() -> Self {
        let created = std::time::Instant::now();
        FactMetadata {
            retracted: false,
            update_count: 0,
            updated_at: created,
            inserted_at: created,
        }
    }
}
72
73/// Working Memory - stores and manages facts (Drools-style)
74pub struct WorkingMemory {
75    /// All facts by handle
76    facts: HashMap<FactHandle, WorkingMemoryFact>,
77    /// Type index: fact_type -> set of handles
78    type_index: HashMap<String, HashSet<FactHandle>>,
79    /// Next fact ID
80    next_id: AtomicU64,
81    /// Modified handles since last propagation
82    modified_handles: HashSet<FactHandle>,
83    /// Retracted handles since last propagation
84    retracted_handles: HashSet<FactHandle>,
85}
86
87impl WorkingMemory {
88    /// Create a new empty working memory
89    pub fn new() -> Self {
90        Self {
91            facts: HashMap::new(),
92            type_index: HashMap::new(),
93            next_id: AtomicU64::new(1),
94            modified_handles: HashSet::new(),
95            retracted_handles: HashSet::new(),
96        }
97    }
98
99    /// Insert a fact into working memory (returns FactHandle)
100    pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
101        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
102        let handle = FactHandle::new(id);
103
104        let fact = WorkingMemoryFact {
105            handle,
106            fact_type: fact_type.clone(),
107            data,
108            metadata: FactMetadata::default(),
109        };
110
111        self.facts.insert(handle, fact);
112        self.type_index
113            .entry(fact_type)
114            .or_insert_with(HashSet::new)
115            .insert(handle);
116        self.modified_handles.insert(handle);
117
118        handle
119    }
120
121    /// Update a fact in working memory
122    pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
123        let fact = self.facts.get_mut(&handle)
124            .ok_or_else(|| format!("FactHandle {} not found", handle))?;
125
126        if fact.metadata.retracted {
127            return Err(format!("FactHandle {} is retracted", handle));
128        }
129
130        fact.data = data;
131        fact.metadata.updated_at = std::time::Instant::now();
132        fact.metadata.update_count += 1;
133        self.modified_handles.insert(handle);
134
135        Ok(())
136    }
137
138    /// Retract (delete) a fact from working memory
139    pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
140        let fact = self.facts.get_mut(&handle)
141            .ok_or_else(|| format!("FactHandle {} not found", handle))?;
142
143        if fact.metadata.retracted {
144            return Err(format!("FactHandle {} already retracted", handle));
145        }
146
147        fact.metadata.retracted = true;
148        self.retracted_handles.insert(handle);
149
150        // Remove from type index
151        if let Some(handles) = self.type_index.get_mut(&fact.fact_type) {
152            handles.remove(&handle);
153        }
154
155        Ok(())
156    }
157
158    /// Get a fact by handle
159    pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
160        self.facts.get(handle).filter(|f| !f.metadata.retracted)
161    }
162
163    /// Get all facts of a specific type
164    pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
165        if let Some(handles) = self.type_index.get(fact_type) {
166            handles
167                .iter()
168                .filter_map(|h| self.facts.get(h))
169                .filter(|f| !f.metadata.retracted)
170                .collect()
171        } else {
172            Vec::new()
173        }
174    }
175
176    /// Get all facts
177    pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
178        self.facts
179            .values()
180            .filter(|f| !f.metadata.retracted)
181            .collect()
182    }
183
184    /// Get all fact handles
185    pub fn get_all_handles(&self) -> Vec<FactHandle> {
186        self.facts
187            .values()
188            .filter(|f| !f.metadata.retracted)
189            .map(|f| f.handle)
190            .collect()
191    }
192
193    /// Get modified handles since last clear
194    pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
195        &self.modified_handles
196    }
197
198    /// Get retracted handles since last clear
199    pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
200        &self.retracted_handles
201    }
202
203    /// Clear modification tracking (after propagation)
204    pub fn clear_modification_tracking(&mut self) {
205        self.modified_handles.clear();
206        self.retracted_handles.clear();
207    }
208
209    /// Get statistics
210    pub fn stats(&self) -> WorkingMemoryStats {
211        let active_facts = self.facts.values().filter(|f| !f.metadata.retracted).count();
212        let retracted_facts = self.facts.values().filter(|f| f.metadata.retracted).count();
213
214        WorkingMemoryStats {
215            total_facts: self.facts.len(),
216            active_facts,
217            retracted_facts,
218            types: self.type_index.len(),
219            modified_pending: self.modified_handles.len(),
220            retracted_pending: self.retracted_handles.len(),
221        }
222    }
223
224    /// Clear all facts
225    pub fn clear(&mut self) {
226        self.facts.clear();
227        self.type_index.clear();
228        self.modified_handles.clear();
229        self.retracted_handles.clear();
230    }
231
232    /// Flatten all facts into a single TypedFacts for evaluation
233    /// Each fact's fields are prefixed with "type.handle."
234    pub fn to_typed_facts(&self) -> TypedFacts {
235        let mut result = TypedFacts::new();
236
237        for fact in self.get_all_facts() {
238            let prefix = format!("{}.{}", fact.fact_type, fact.handle.id());
239            for (key, value) in fact.data.get_all() {
240                result.set(format!("{}.{}", prefix, key), value.clone());
241            }
242
243            // Also add without prefix for simple access
244            for (key, value) in fact.data.get_all() {
245                result.set(format!("{}.{}", fact.fact_type, key), value.clone());
246            }
247        }
248
249        result
250    }
251}
252
253impl Default for WorkingMemory {
254    fn default() -> Self {
255        Self::new()
256    }
257}
258
/// Snapshot of working-memory counters, as produced by `WorkingMemory::stats`.
#[derive(Debug, Clone)]
pub struct WorkingMemoryStats {
    /// Every stored fact, active or retracted.
    pub total_facts: usize,
    /// Facts not yet retracted.
    pub active_facts: usize,
    /// Facts flagged as retracted but still stored.
    pub retracted_facts: usize,
    /// Number of fact types present in the type index.
    pub types: usize,
    /// Handles awaiting propagation as insertions/updates.
    pub modified_pending: usize,
    /// Handles awaiting propagation as retractions.
    pub retracted_pending: usize,
}

impl std::fmt::Display for WorkingMemoryStats {
    /// One-line summary. `total_facts` is intentionally omitted because it is
    /// the sum of the active and retracted counts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WorkingMemoryStats {
            active_facts,
            retracted_facts,
            types,
            modified_pending,
            retracted_pending,
            ..
        } = self;
        write!(
            f,
            "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
            active_facts, retracted_facts, types, modified_pending, retracted_pending
        )
    }
}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TypedFacts` containing a single integer `id` field.
    fn facts_with_id(id: i64) -> TypedFacts {
        let mut facts = TypedFacts::new();
        facts.set("id", id);
        facts
    }

    #[test]
    fn test_insert_and_get() {
        let mut memory = WorkingMemory::new();
        let mut john = TypedFacts::new();
        john.set("name", "John");
        john.set("age", 25i64);

        let handle = memory.insert("Person".to_string(), john);
        let stored = memory.get(&handle).unwrap();

        assert_eq!(stored.fact_type, "Person");
        assert_eq!(stored.data.get("name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_update() {
        let mut memory = WorkingMemory::new();
        let mut original = TypedFacts::new();
        original.set("age", 25i64);
        let handle = memory.insert("Person".to_string(), original);

        let mut replacement = TypedFacts::new();
        replacement.set("age", 26i64);
        memory.update(handle, replacement).unwrap();

        let stored = memory.get(&handle).unwrap();
        assert_eq!(stored.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(stored.metadata.update_count, 1);
    }

    #[test]
    fn test_retract() {
        let mut memory = WorkingMemory::new();
        let handle = memory.insert("Person".to_string(), TypedFacts::new());

        memory.retract(handle).unwrap();

        assert!(memory.get(&handle).is_none());
        assert_eq!(memory.get_all_facts().len(), 0);
    }

    #[test]
    fn test_type_index() {
        let mut memory = WorkingMemory::new();

        for id in 0..5i64 {
            memory.insert("Person".to_string(), facts_with_id(id));
        }
        for id in 0..3i64 {
            memory.insert("Order".to_string(), facts_with_id(id));
        }

        assert_eq!(memory.get_by_type("Person").len(), 5);
        assert_eq!(memory.get_by_type("Order").len(), 3);
        assert_eq!(memory.get_by_type("Unknown").len(), 0);
    }

    #[test]
    fn test_modification_tracking() {
        let mut memory = WorkingMemory::new();
        let empty = TypedFacts::new();
        let first = memory.insert("Person".to_string(), empty.clone());
        let second = memory.insert("Person".to_string(), empty.clone());

        assert_eq!(memory.get_modified_handles().len(), 2);

        memory.clear_modification_tracking();
        assert_eq!(memory.get_modified_handles().len(), 0);

        memory.update(first, empty.clone()).unwrap();
        assert_eq!(memory.get_modified_handles().len(), 1);

        memory.retract(second).unwrap();
        assert_eq!(memory.get_retracted_handles().len(), 1);
    }
}