rust_rule_engine/rete/
working_memory.rs

1//! Working Memory for RETE-UL (Drools-style)
2//!
3//! This module implements a Working Memory system similar to Drools, providing:
4//! - FactHandle for tracking inserted objects
5//! - Insert, update, retract operations
6//! - Type indexing for fast lookups
7//! - Change tracking for incremental updates
8
9use super::facts::TypedFacts;
10use std::collections::{HashMap, HashSet};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Opaque, copyable identifier for a fact stored in a `WorkingMemory`
/// (the analogue of a Drools `FactHandle`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FactHandle(u64);

impl FactHandle {
    /// Wrap a raw numeric ID in a handle.
    ///
    /// Uniqueness is the caller's responsibility; `WorkingMemory` hands out
    /// IDs from a monotonically increasing counter.
    pub fn new(id: u64) -> Self {
        FactHandle(id)
    }

    /// Return the raw numeric ID wrapped by this handle.
    pub fn id(&self) -> u64 {
        let FactHandle(raw) = *self;
        raw
    }
}

impl std::fmt::Display for FactHandle {
    /// Render the handle as `FactHandle(<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("FactHandle({})", self.id()))
    }
}
34
35/// A fact stored in working memory
36#[derive(Debug, Clone)]
37pub struct WorkingMemoryFact {
38    /// The fact handle
39    pub handle: FactHandle,
40    /// Fact type (e.g., "Person", "Order")
41    pub fact_type: String,
42    /// The actual fact data
43    pub data: TypedFacts,
44    /// Metadata
45    pub metadata: FactMetadata,
46    /// Stream source (if this fact came from a stream)
47    #[cfg(feature = "streaming")]
48    pub stream_source: Option<String>,
49    /// Stream event (if this fact came from a stream)
50    #[cfg(feature = "streaming")]
51    pub stream_event: Option<crate::streaming::event::StreamEvent>,
52}
53
/// Bookkeeping attached to every fact in working memory.
#[derive(Clone, Debug)]
pub struct FactMetadata {
    /// Moment the fact entered working memory.
    pub inserted_at: std::time::Instant,
    /// Moment of the most recent update (equal to `inserted_at` until the first update).
    pub updated_at: std::time::Instant,
    /// How many times the fact has been updated.
    pub update_count: usize,
    /// Whether the fact has been retracted.
    pub retracted: bool,
}

impl Default for FactMetadata {
    /// Fresh metadata: both timestamps set to the current instant, no updates, not retracted.
    fn default() -> Self {
        let created = std::time::Instant::now();
        FactMetadata {
            retracted: false,
            update_count: 0,
            updated_at: created,
            inserted_at: created,
        }
    }
}
78
79/// Working Memory - stores and manages facts (Drools-style)
80pub struct WorkingMemory {
81    /// All facts by handle
82    facts: HashMap<FactHandle, WorkingMemoryFact>,
83    /// Type index: fact_type -> set of handles
84    type_index: HashMap<String, HashSet<FactHandle>>,
85    /// Next fact ID
86    next_id: AtomicU64,
87    /// Modified handles since last propagation
88    modified_handles: HashSet<FactHandle>,
89    /// Retracted handles since last propagation
90    retracted_handles: HashSet<FactHandle>,
91}
92
93impl WorkingMemory {
94    /// Create a new empty working memory
95    pub fn new() -> Self {
96        Self {
97            facts: HashMap::new(),
98            type_index: HashMap::new(),
99            next_id: AtomicU64::new(1),
100            modified_handles: HashSet::new(),
101            retracted_handles: HashSet::new(),
102        }
103    }
104
105    /// Insert a fact into working memory (returns FactHandle)
106    pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
107        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
108        let handle = FactHandle::new(id);
109
110        let fact = WorkingMemoryFact {
111            handle,
112            fact_type: fact_type.clone(),
113            data,
114            metadata: FactMetadata::default(),
115            #[cfg(feature = "streaming")]
116            stream_source: None,
117            #[cfg(feature = "streaming")]
118            stream_event: None,
119        };
120
121        self.facts.insert(handle, fact);
122        self.type_index.entry(fact_type).or_default().insert(handle);
123        self.modified_handles.insert(handle);
124
125        handle
126    }
127
128    /// Insert a fact from a stream event
129    #[cfg(feature = "streaming")]
130    pub fn insert_from_stream(
131        &mut self,
132        stream_name: String,
133        event: crate::streaming::event::StreamEvent,
134    ) -> FactHandle {
135        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
136        let handle = FactHandle::new(id);
137
138        // Convert event data to TypedFacts
139        let mut typed_facts = super::facts::TypedFacts::new();
140        for (key, value) in &event.data {
141            let fact_value: super::facts::FactValue = value.clone().into();
142            typed_facts.set(key.clone(), fact_value);
143        }
144
145        let fact = WorkingMemoryFact {
146            handle,
147            fact_type: event.event_type.clone(),
148            data: typed_facts,
149            metadata: FactMetadata::default(),
150            stream_source: Some(stream_name.clone()),
151            stream_event: Some(event),
152        };
153
154        self.facts.insert(handle, fact);
155        self.type_index
156            .entry(stream_name)
157            .or_default()
158            .insert(handle);
159        self.modified_handles.insert(handle);
160
161        handle
162    }
163
164    /// Update a fact in working memory
165    pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
166        let fact = self
167            .facts
168            .get_mut(&handle)
169            .ok_or_else(|| format!("FactHandle {} not found", handle))?;
170
171        if fact.metadata.retracted {
172            return Err(format!("FactHandle {} is retracted", handle));
173        }
174
175        fact.data = data;
176        fact.metadata.updated_at = std::time::Instant::now();
177        fact.metadata.update_count += 1;
178        self.modified_handles.insert(handle);
179
180        Ok(())
181    }
182
183    /// Retract (delete) a fact from working memory
184    pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
185        let fact = self
186            .facts
187            .get_mut(&handle)
188            .ok_or_else(|| format!("FactHandle {} not found", handle))?;
189
190        if fact.metadata.retracted {
191            return Err(format!("FactHandle {} already retracted", handle));
192        }
193
194        fact.metadata.retracted = true;
195        self.retracted_handles.insert(handle);
196
197        // Remove from type index
198        if let Some(handles) = self.type_index.get_mut(&fact.fact_type) {
199            handles.remove(&handle);
200        }
201
202        Ok(())
203    }
204
205    /// Get a fact by handle
206    pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
207        self.facts.get(handle).filter(|f| !f.metadata.retracted)
208    }
209
210    /// Get all facts of a specific type
211    pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
212        if let Some(handles) = self.type_index.get(fact_type) {
213            handles
214                .iter()
215                .filter_map(|h| self.facts.get(h))
216                .filter(|f| !f.metadata.retracted)
217                .collect()
218        } else {
219            Vec::new()
220        }
221    }
222
223    /// Get all facts
224    pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
225        self.facts
226            .values()
227            .filter(|f| !f.metadata.retracted)
228            .collect()
229    }
230
231    /// Get all fact handles
232    pub fn get_all_handles(&self) -> Vec<FactHandle> {
233        self.facts
234            .values()
235            .filter(|f| !f.metadata.retracted)
236            .map(|f| f.handle)
237            .collect()
238    }
239
240    /// Get modified handles since last clear
241    pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
242        &self.modified_handles
243    }
244
245    /// Get retracted handles since last clear
246    pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
247        &self.retracted_handles
248    }
249
250    /// Clear modification tracking (after propagation)
251    pub fn clear_modification_tracking(&mut self) {
252        self.modified_handles.clear();
253        self.retracted_handles.clear();
254    }
255
256    /// Get statistics
257    pub fn stats(&self) -> WorkingMemoryStats {
258        let active_facts = self
259            .facts
260            .values()
261            .filter(|f| !f.metadata.retracted)
262            .count();
263        let retracted_facts = self.facts.values().filter(|f| f.metadata.retracted).count();
264
265        WorkingMemoryStats {
266            total_facts: self.facts.len(),
267            active_facts,
268            retracted_facts,
269            types: self.type_index.len(),
270            modified_pending: self.modified_handles.len(),
271            retracted_pending: self.retracted_handles.len(),
272        }
273    }
274
275    /// Clear all facts
276    pub fn clear(&mut self) {
277        self.facts.clear();
278        self.type_index.clear();
279        self.modified_handles.clear();
280        self.retracted_handles.clear();
281    }
282
283    /// Flatten all facts into a single TypedFacts for evaluation
284    /// Each fact's fields are prefixed with "type.handle."
285    pub fn to_typed_facts(&self) -> TypedFacts {
286        let mut result = TypedFacts::new();
287
288        for fact in self.get_all_facts() {
289            let prefix = format!("{}.{}", fact.fact_type, fact.handle.id());
290            for (key, value) in fact.data.get_all() {
291                result.set(format!("{}.{}", prefix, key), value.clone());
292            }
293
294            // Also add without prefix for simple access (last fact of this type wins)
295            for (key, value) in fact.data.get_all() {
296                result.set(format!("{}.{}", fact.fact_type, key), value.clone());
297            }
298
299            // Store handle for this fact type (last fact wins, but better than nothing)
300            // Actions can use Type._handle to get the handle
301            result.set_fact_handle(fact.fact_type.clone(), fact.handle);
302        }
303
304        result
305    }
306}
307
308impl Default for WorkingMemory {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
/// Counters describing the state of a working memory at one point in time.
#[derive(Debug, Clone)]
pub struct WorkingMemoryStats {
    /// Facts currently stored, live and retracted together.
    pub total_facts: usize,
    /// Facts that have not been retracted.
    pub active_facts: usize,
    /// Facts flagged as retracted but still stored.
    pub retracted_facts: usize,
    /// Number of keys in the type index.
    pub types: usize,
    /// Handles inserted or updated since tracking was last cleared.
    pub modified_pending: usize,
    /// Handles retracted since tracking was last cleared.
    pub retracted_pending: usize,
}

impl std::fmt::Display for WorkingMemoryStats {
    /// One-line summary; `total_facts` is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let WorkingMemoryStats {
            active_facts,
            retracted_facts,
            types,
            modified_pending,
            retracted_pending,
            ..
        } = self;
        write!(
            f,
            "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
            active_facts, retracted_facts, types, modified_pending, retracted_pending
        )
    }
}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TypedFacts` holding a single integer `id` field.
    fn facts_with_id(id: i64) -> TypedFacts {
        let mut data = TypedFacts::new();
        data.set("id", id);
        data
    }

    #[test]
    fn test_insert_and_get() {
        let mut wm = WorkingMemory::new();
        let mut person_data = TypedFacts::new();
        person_data.set("name", "John");
        person_data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), person_data);

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.fact_type, "Person");
        assert_eq!(fact.data.get("name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_update() {
        let mut wm = WorkingMemory::new();
        let mut data = TypedFacts::new();
        data.set("age", 25i64);

        let handle = wm.insert("Person".to_string(), data);

        let mut updated_data = TypedFacts::new();
        updated_data.set("age", 26i64);
        wm.update(handle, updated_data).unwrap();

        let fact = wm.get(&handle).unwrap();
        assert_eq!(fact.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(fact.metadata.update_count, 1);
    }

    #[test]
    fn test_update_unknown_or_retracted_handle_fails() {
        let mut wm = WorkingMemory::new();
        assert!(wm.update(FactHandle::new(999), TypedFacts::new()).is_err());

        let handle = wm.insert("Person".to_string(), TypedFacts::new());
        wm.retract(handle).unwrap();
        assert!(wm.update(handle, TypedFacts::new()).is_err());
    }

    #[test]
    fn test_retract() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let handle = wm.insert("Person".to_string(), data);

        wm.retract(handle).unwrap();

        assert!(wm.get(&handle).is_none());
        assert_eq!(wm.get_all_facts().len(), 0);
    }

    #[test]
    fn test_retract_twice_or_unknown_fails() {
        let mut wm = WorkingMemory::new();
        let handle = wm.insert("Person".to_string(), TypedFacts::new());

        wm.retract(handle).unwrap();
        assert!(wm.retract(handle).is_err());
        assert!(wm.retract(FactHandle::new(999)).is_err());
    }

    #[test]
    fn test_type_index() {
        let mut wm = WorkingMemory::new();

        for i in 0..5i64 {
            wm.insert("Person".to_string(), facts_with_id(i));
        }

        for i in 0..3i64 {
            wm.insert("Order".to_string(), facts_with_id(i));
        }

        assert_eq!(wm.get_by_type("Person").len(), 5);
        assert_eq!(wm.get_by_type("Order").len(), 3);
        assert_eq!(wm.get_by_type("Unknown").len(), 0);
    }

    #[test]
    fn test_retracted_fact_leaves_type_index() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), facts_with_id(1));
        wm.insert("Person".to_string(), facts_with_id(2));

        wm.retract(h1).unwrap();

        let people = wm.get_by_type("Person");
        assert_eq!(people.len(), 1);
        assert!(people.iter().all(|f| f.handle != h1));
    }

    #[test]
    fn test_handles_are_unique() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        let h2 = wm.insert("Person".to_string(), TypedFacts::new());
        assert_ne!(h1, h2);

        let mut handles = wm.get_all_handles();
        handles.sort_by_key(|h| h.id());
        assert_eq!(handles, vec![h1, h2]);
    }

    #[test]
    fn test_modification_tracking() {
        let mut wm = WorkingMemory::new();
        let data = TypedFacts::new();
        let h1 = wm.insert("Person".to_string(), data.clone());
        let h2 = wm.insert("Person".to_string(), data.clone());

        assert_eq!(wm.get_modified_handles().len(), 2);

        wm.clear_modification_tracking();
        assert_eq!(wm.get_modified_handles().len(), 0);

        wm.update(h1, data).unwrap();
        assert_eq!(wm.get_modified_handles().len(), 1);

        wm.retract(h2).unwrap();
        assert_eq!(wm.get_retracted_handles().len(), 1);
    }

    #[test]
    fn test_stats() {
        let mut wm = WorkingMemory::new();
        let h1 = wm.insert("Person".to_string(), TypedFacts::new());
        wm.insert("Person".to_string(), TypedFacts::new());
        wm.retract(h1).unwrap();

        let stats = wm.stats();
        assert_eq!(stats.total_facts, 2);
        assert_eq!(stats.active_facts, 1);
        assert_eq!(stats.retracted_facts, 1);
        assert_eq!(stats.types, 1);
        assert_eq!(stats.modified_pending, 2);
        assert_eq!(stats.retracted_pending, 1);
    }

    #[test]
    fn test_clear() {
        let mut wm = WorkingMemory::new();
        let handle = wm.insert("Person".to_string(), facts_with_id(1));

        wm.clear();

        assert!(wm.get(&handle).is_none());
        assert_eq!(wm.stats().total_facts, 0);
        assert!(wm.get_modified_handles().is_empty());
        assert!(wm.get_by_type("Person").is_empty());
    }
}