rust_rule_engine/engine/
facts.rs

1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Facts - represents the working memory of data objects that rules read and modify.
///
/// Similar to Grule's DataContext concept. Each fact is stored as a [`Value`]
/// under a top-level name; facts inserted through `add` / `add_value` also get a
/// recorded type name.
///
/// Cloning a `Facts` is shallow: every field is an `Arc`, so a clone shares the
/// same underlying storage (including the undo-frame stack) with the original,
/// and writes through either handle are visible through both.
#[derive(Debug, Clone)]
pub struct Facts {
    /// Fact values keyed by their top-level name.
    data: Arc<RwLock<HashMap<String, Value>>>,
    /// Type name per fact: the Rust type name for facts added via `add`,
    /// `"Value"` for facts added via `add_value`.
    fact_types: Arc<RwLock<HashMap<String, String>>>,
    /// Stack of undo frames (innermost frame last) used for lightweight
    /// snapshots. Each frame records, per top-level key, the value and type the
    /// key had before its first change in that frame, so a rollback restores
    /// only the changed keys instead of cloning the whole facts map.
    undo_frames: Arc<RwLock<Vec<Vec<UndoEntry>>>>,
}
18
19impl Facts {
20    /// Create a generic object from key-value pairs
21    pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
22        let mut map = HashMap::new();
23        for (key, value) in pairs {
24            map.insert(key, value);
25        }
26        Value::Object(map)
27    }
28
29    /// Create a user object
30    pub fn new() -> Self {
31        Self {
32            data: Arc::new(RwLock::new(HashMap::new())),
33            fact_types: Arc::new(RwLock::new(HashMap::new())),
34            undo_frames: Arc::new(RwLock::new(Vec::new())),
35        }
36    }
37
38    /// Add a fact object to the working memory
39    pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
40    where
41        T: Serialize + std::fmt::Debug,
42    {
43        let value =
44            serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
45                message: e.to_string(),
46            })?;
47
48        let fact_value = Value::from(value);
49
50        let mut data = self.data.write().unwrap();
51        let mut types = self.fact_types.write().unwrap();
52
53        data.insert(name.to_string(), fact_value);
54        types.insert(name.to_string(), std::any::type_name::<T>().to_string());
55
56        Ok(())
57    }
58
59    /// Add a simple value fact
60    pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
61        let mut data = self.data.write().unwrap();
62        let mut types = self.fact_types.write().unwrap();
63
64        data.insert(name.to_string(), value);
65        types.insert(name.to_string(), "Value".to_string());
66
67        Ok(())
68    }
69
70    /// Get a fact by name
71    pub fn get(&self, name: &str) -> Option<Value> {
72        let data = self.data.read().unwrap();
73        data.get(name).cloned()
74    }
75
76    /// Get a nested fact property (e.g., "User.Profile.Age")
77    pub fn get_nested(&self, path: &str) -> Option<Value> {
78        let parts: Vec<&str> = path.split('.').collect();
79        if parts.is_empty() {
80            return None;
81        }
82
83        let data = self.data.read().unwrap();
84        let mut current = data.get(parts[0])?;
85
86        for part in parts.iter().skip(1) {
87            match current {
88                Value::Object(ref obj) => {
89                    current = obj.get(*part)?;
90                }
91                _ => return None,
92            }
93        }
94
95        Some(current.clone())
96    }
97
98    /// Set a fact value
99    pub fn set(&self, name: &str, value: Value) {
100        // Record previous value for undo if an undo frame is active
101        self.record_undo_for_key(name);
102
103        let mut data = self.data.write().unwrap();
104        data.insert(name.to_string(), value);
105    }
106
107    /// Set a nested fact property
108    pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
109        let parts: Vec<&str> = path.split('.').collect();
110        if parts.is_empty() {
111            return Err(RuleEngineError::FieldNotFound {
112                field: path.to_string(),
113            });
114        }
115
116        // Record previous top-level key for undo semantics
117        self.record_undo_for_key(parts[0]);
118
119        let mut data = self.data.write().unwrap();
120
121        if parts.len() == 1 {
122            data.insert(parts[0].to_string(), value);
123            return Ok(());
124        }
125
126        // Navigate to parent and set the nested value
127        let root_key = parts[0];
128        let root_value = data
129            .get_mut(root_key)
130            .ok_or_else(|| RuleEngineError::FieldNotFound {
131                field: root_key.to_string(),
132            })?;
133
134        self.set_nested_in_value(root_value, &parts[1..], value)?;
135        Ok(())
136    }
137
138    #[allow(clippy::only_used_in_recursion)]
139    fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
140        if path.is_empty() {
141            return Ok(());
142        }
143
144        if path.len() == 1 {
145            // We're at the target field
146            match current {
147                Value::Object(ref mut obj) => {
148                    obj.insert(path[0].to_string(), value);
149                    Ok(())
150                }
151                _ => Err(RuleEngineError::TypeMismatch {
152                    expected: "Object".to_string(),
153                    actual: format!("{:?}", current),
154                }),
155            }
156        } else {
157            // Continue navigating
158            match current {
159                Value::Object(ref mut obj) => {
160                    let next_value =
161                        obj.get_mut(path[0])
162                            .ok_or_else(|| RuleEngineError::FieldNotFound {
163                                field: path[0].to_string(),
164                            })?;
165                    self.set_nested_in_value(next_value, &path[1..], value)
166                }
167                _ => Err(RuleEngineError::TypeMismatch {
168                    expected: "Object".to_string(),
169                    actual: format!("{:?}", current),
170                }),
171            }
172        }
173    }
174
175    /// Remove a fact
176    pub fn remove(&self, name: &str) -> Option<Value> {
177        // Record undo before removing
178        self.record_undo_for_key(name);
179
180        let mut data = self.data.write().unwrap();
181        let mut types = self.fact_types.write().unwrap();
182
183        types.remove(name);
184        data.remove(name)
185    }
186
187    /// Clear all facts
188    pub fn clear(&self) {
189        let mut data = self.data.write().unwrap();
190        let mut types = self.fact_types.write().unwrap();
191
192        data.clear();
193        types.clear();
194    }
195
196    /// Get all fact names
197    pub fn get_fact_names(&self) -> Vec<String> {
198        let data = self.data.read().unwrap();
199        data.keys().cloned().collect()
200    }
201
202    /// Get fact count
203    pub fn count(&self) -> usize {
204        let data = self.data.read().unwrap();
205        data.len()
206    }
207
208    /// Check if a fact exists
209    pub fn contains(&self, name: &str) -> bool {
210        let data = self.data.read().unwrap();
211        data.contains_key(name)
212    }
213
214    /// Get all facts as a HashMap (for pattern matching evaluation)
215    pub fn get_all_facts(&self) -> HashMap<String, Value> {
216        let data = self.data.read().unwrap();
217        data.clone()
218    }
219
220    /// Get the type name of a fact
221    pub fn get_fact_type(&self, name: &str) -> Option<String> {
222        let types = self.fact_types.read().unwrap();
223        types.get(name).cloned()
224    }
225
226    /// Convert to Context for rule evaluation
227    pub fn to_context(&self) -> Context {
228        let data = self.data.read().unwrap();
229        data.clone()
230    }
231
232    /// Create Facts from Context
233    pub fn from_context(context: Context) -> Self {
234        let facts = Facts::new();
235        {
236            let mut data = facts.data.write().unwrap();
237            *data = context;
238        }
239        facts
240    }
241
242    /// Merge another Facts instance into this one
243    pub fn merge(&self, other: &Facts) {
244        let other_data = other.data.read().unwrap();
245        let other_types = other.fact_types.read().unwrap();
246
247        let mut data = self.data.write().unwrap();
248        let mut types = self.fact_types.write().unwrap();
249
250        for (key, value) in other_data.iter() {
251            data.insert(key.clone(), value.clone());
252        }
253
254        for (key, type_name) in other_types.iter() {
255            types.insert(key.clone(), type_name.clone());
256        }
257    }
258
259    /// Get a snapshot of all facts
260    pub fn snapshot(&self) -> FactsSnapshot {
261        let data = self.data.read().unwrap();
262        let types = self.fact_types.read().unwrap();
263
264        FactsSnapshot {
265            data: data.clone(),
266            fact_types: types.clone(),
267        }
268    }
269
270    /// Restore from a snapshot
271    pub fn restore(&self, snapshot: FactsSnapshot) {
272        let mut data = self.data.write().unwrap();
273        let mut types = self.fact_types.write().unwrap();
274
275        *data = snapshot.data;
276        *types = snapshot.fact_types;
277    }
278}
279
280impl Default for Facts {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
/// A full, owned copy of a `Facts` instance's data and type information,
/// produced by `Facts::snapshot` and consumed by `Facts::restore`.
///
/// Unlike `Facts`, it holds plain maps (no locks, no shared `Arc`s) and is
/// serializable. It does not include the undo-frame stack.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactsSnapshot {
    /// The fact data stored as key-value pairs, keyed by top-level fact name.
    pub data: HashMap<String, Value>,
    /// Recorded type name for each fact, keyed by top-level fact name.
    pub fact_types: HashMap<String, String>,
}
294
/// Undo entry for a single top-level key: the state the key had before the
/// first change made to it within an undo frame.
///
/// A `None` field means the key was absent, so rollback removes it.
#[derive(Debug, Clone)]
struct UndoEntry {
    /// Top-level fact name this entry restores.
    key: String,
    /// Previous value, or `None` if the key did not exist.
    prev_value: Option<Value>,
    /// Previous recorded type name, or `None` if none was recorded.
    prev_type: Option<String>,
}
302
303impl Facts {
304    /// Start a new undo frame. Call `rollback_undo_frame` to revert or
305    /// `commit_undo_frame` to discard recorded changes.
306    pub fn begin_undo_frame(&self) {
307        let mut frames = self.undo_frames.write().unwrap();
308        frames.push(Vec::new());
309    }
310
311    /// Commit (discard) the top-most undo frame
312    pub fn commit_undo_frame(&self) {
313        let mut frames = self.undo_frames.write().unwrap();
314        frames.pop();
315    }
316
317    /// Rollback the top-most undo frame, restoring prior values
318    pub fn rollback_undo_frame(&self) {
319        let mut frames = self.undo_frames.write().unwrap();
320        if let Some(frame) = frames.pop() {
321            // Restore in reverse order
322            let mut data = self.data.write().unwrap();
323            let mut types = self.fact_types.write().unwrap();
324
325            for entry in frame.into_iter().rev() {
326                match entry.prev_value {
327                    Some(v) => {
328                        data.insert(entry.key.clone(), v);
329                    }
330                    None => {
331                        data.remove(&entry.key);
332                    }
333                }
334
335                match entry.prev_type {
336                    Some(t) => {
337                        types.insert(entry.key.clone(), t);
338                    }
339                    None => {
340                        types.remove(&entry.key);
341                    }
342                }
343            }
344        }
345    }
346
347    /// Record prior state for a top-level key if an undo frame is active
348    fn record_undo_for_key(&self, key: &str) {
349        let mut frames = self.undo_frames.write().unwrap();
350        if let Some(frame) = frames.last_mut() {
351            // capture previous value & type
352            let data = self.data.read().unwrap();
353            let types = self.fact_types.read().unwrap();
354
355            // Only record once per key in this frame
356            if frame.iter().any(|e: &UndoEntry| e.key == key) {
357                return;
358            }
359
360            let prev_value = data.get(key).cloned();
361            let prev_type = types.get(key).cloned();
362
363            frame.push(UndoEntry {
364                key: key.to_string(),
365                prev_value,
366                prev_type,
367            });
368        }
369    }
370}
371
/// Trait for objects that can be used as facts.
///
/// Implementors must satisfy the same `Serialize + Debug` bounds that
/// `Facts::add` requires, and provide a static name for the fact type. The
/// `impl_fact!` macro in this module can generate the implementation.
pub trait Fact: Serialize + std::fmt::Debug {
    /// Get the name of this fact type.
    fn fact_name() -> &'static str;
}
377
/// Macro to implement the `Fact` trait for a type in one line.
///
/// `$type` is the type to implement `Fact` for. `$name` is returned from
/// `fact_name()`, so it must be an expression of type `&'static str` (for
/// example a string literal).
///
/// Note: the expansion names `Fact` unqualified, so the `Fact` trait must be in
/// scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_fact {
    ($type:ty, $name:expr) => {
        impl Fact for $type {
            fn fact_name() -> &'static str {
                $name
            }
        }
    };
}
389
/// Helper functions for building fact objects.
///
/// A unit struct used as a namespace: every helper is an associated function
/// that returns a `Value::Object`.
pub struct FactHelper;
392
393impl FactHelper {
394    /// Create a generic object with key-value pairs
395    pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
396        let mut object = HashMap::new();
397        for (key, value) in pairs {
398            object.insert(key.to_string(), value);
399        }
400        Value::Object(object)
401    }
402
403    /// Create a User fact from common fields
404    pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
405        let mut user = HashMap::new();
406        user.insert("Name".to_string(), Value::String(name.to_string()));
407        user.insert("Age".to_string(), Value::Integer(age));
408        user.insert("Email".to_string(), Value::String(email.to_string()));
409        user.insert("Country".to_string(), Value::String(country.to_string()));
410        user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
411
412        Value::Object(user)
413    }
414
415    /// Create a Product fact
416    pub fn create_product(
417        name: &str,
418        price: f64,
419        category: &str,
420        in_stock: bool,
421        stock_count: i64,
422    ) -> Value {
423        let mut product = HashMap::new();
424        product.insert("Name".to_string(), Value::String(name.to_string()));
425        product.insert("Price".to_string(), Value::Number(price));
426        product.insert("Category".to_string(), Value::String(category.to_string()));
427        product.insert("InStock".to_string(), Value::Boolean(in_stock));
428        product.insert("StockCount".to_string(), Value::Integer(stock_count));
429
430        Value::Object(product)
431    }
432
433    /// Create an Order fact
434    pub fn create_order(
435        id: &str,
436        user_id: &str,
437        total: f64,
438        item_count: i64,
439        status: &str,
440    ) -> Value {
441        let mut order = HashMap::new();
442        order.insert("ID".to_string(), Value::String(id.to_string()));
443        order.insert("UserID".to_string(), Value::String(user_id.to_string()));
444        order.insert("Total".to_string(), Value::Number(total));
445        order.insert("ItemCount".to_string(), Value::Integer(item_count));
446        order.insert("Status".to_string(), Value::String(status.to_string()));
447
448        Value::Object(order)
449    }
450
451    /// Create a TestCar object for method call demo
452    pub fn create_test_car(
453        speed_up: bool,
454        speed: f64,
455        max_speed: f64,
456        speed_increment: f64,
457    ) -> Value {
458        let mut car = HashMap::new();
459        car.insert("speedUp".to_string(), Value::Boolean(speed_up));
460        car.insert("speed".to_string(), Value::Number(speed));
461        car.insert("maxSpeed".to_string(), Value::Number(max_speed));
462        car.insert("Speed".to_string(), Value::Number(speed));
463        car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
464        car.insert(
465            "_type".to_string(),
466            Value::String("TestCarClass".to_string()),
467        );
468
469        Value::Object(car)
470    }
471
472    /// Create a DistanceRecord object for method call demo  
473    pub fn create_distance_record(total_distance: f64) -> Value {
474        let mut record = HashMap::new();
475        record.insert("TotalDistance".to_string(), Value::Number(total_distance));
476        record.insert(
477            "_type".to_string(),
478            Value::String("DistanceRecordClass".to_string()),
479        );
480
481        Value::Object(record)
482    }
483
484    /// Create a Transaction fact for fraud detection
485    pub fn create_transaction(
486        id: &str,
487        amount: f64,
488        location: &str,
489        timestamp: i64,
490        user_id: &str,
491    ) -> Value {
492        let mut transaction = HashMap::new();
493        transaction.insert("ID".to_string(), Value::String(id.to_string()));
494        transaction.insert("Amount".to_string(), Value::Number(amount));
495        transaction.insert("Location".to_string(), Value::String(location.to_string()));
496        transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
497        transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
498
499        Value::Object(transaction)
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_facts_basic_operations() {
        let facts = Facts::new();

        // Add facts
        facts.add_value("age", Value::Integer(25)).unwrap();
        facts
            .add_value("name", Value::String("John".to_string()))
            .unwrap();

        // Get facts
        assert_eq!(facts.get("age"), Some(Value::Integer(25)));
        assert_eq!(facts.get("name"), Some(Value::String("John".to_string())));

        // Count
        assert_eq!(facts.count(), 2);

        // Contains
        assert!(facts.contains("age"));
        assert!(!facts.contains("email"));
    }

    #[test]
    fn test_nested_facts() {
        let facts = Facts::new();
        let user = FactHelper::create_user("John", 25, "john@example.com", "US", true);

        facts.add_value("User", user).unwrap();

        // Get nested values
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(
            facts.get_nested("User.Name"),
            Some(Value::String("John".to_string()))
        );

        // Set nested values
        facts.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    #[test]
    fn test_nested_path_errors() {
        let facts = Facts::new();
        facts
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .unwrap();

        // Root fact does not exist.
        assert!(facts.set_nested("Missing.Age", Value::Integer(1)).is_err());
        // Intermediate object does not exist.
        assert!(facts.set_nested("User.Profile.Age", Value::Integer(1)).is_err());
        // Intermediate value is not an object.
        assert!(facts.set_nested("User.Name.First", Value::Integer(1)).is_err());
        // Reading through a non-object yields None.
        assert_eq!(facts.get_nested("User.Age.Value"), None);

        // A new leaf can be added to an existing object.
        facts
            .set_nested("User.Nickname", Value::String("JJ".to_string()))
            .unwrap();
        assert_eq!(
            facts.get_nested("User.Nickname"),
            Some(Value::String("JJ".to_string()))
        );
        // Failed writes left existing fields intact.
        assert_eq!(
            facts.get_nested("User.Name"),
            Some(Value::String("John".to_string()))
        );
    }

    #[test]
    fn test_facts_snapshot() {
        let facts = Facts::new();
        facts
            .add_value("test", Value::String("value".to_string()))
            .unwrap();

        let snapshot = facts.snapshot();

        facts.clear();
        assert_eq!(facts.count(), 0);

        facts.restore(snapshot);
        assert_eq!(facts.count(), 1);
        assert_eq!(facts.get("test"), Some(Value::String("value".to_string())));
    }

    #[test]
    fn test_undo_frame_rollback_restores_previous_values() {
        let facts = Facts::new();
        facts.add_value("age", Value::Integer(25)).unwrap();

        facts.begin_undo_frame();
        facts.set("age", Value::Integer(30));
        facts.set("status", Value::String("active".to_string()));
        facts.remove("age");
        facts.rollback_undo_frame();

        assert_eq!(facts.get("age"), Some(Value::Integer(25)));
        assert_eq!(facts.get_fact_type("age"), Some("Value".to_string()));
        assert!(!facts.contains("status"));
    }

    #[test]
    fn test_undo_frame_commit_keeps_changes() {
        let facts = Facts::new();
        facts.add_value("age", Value::Integer(25)).unwrap();

        facts.begin_undo_frame();
        facts.set("age", Value::Integer(30));
        facts.commit_undo_frame();

        // No frame is open any more, so this is a no-op.
        facts.rollback_undo_frame();
        assert_eq!(facts.get("age"), Some(Value::Integer(30)));
    }

    #[test]
    fn test_merge_overwrites_and_adds() {
        let target = Facts::new();
        target.add_value("x", Value::Integer(1)).unwrap();

        let source = Facts::new();
        source.add_value("x", Value::Integer(3)).unwrap();
        source.add_value("y", Value::Integer(2)).unwrap();

        target.merge(&source);

        assert_eq!(target.get("x"), Some(Value::Integer(3)));
        assert_eq!(target.get("y"), Some(Value::Integer(2)));
        assert_eq!(target.get_fact_type("y"), Some("Value".to_string()));
        assert_eq!(target.count(), 2);
        // The source is left untouched.
        assert_eq!(source.count(), 2);
    }
}