rust_rule_engine/engine/
facts.rs

1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use std::collections::HashSet;
7
/// Facts - represents the working memory of named data objects that rules read and modify.
/// Similar to Grule's DataContext concept.
///
/// Every piece of state lives behind an `Arc<RwLock<..>>`, so the derived `Clone` yields a
/// second handle onto the *same* storage (facts, type map and undo frames), not an
/// independent copy. Use `snapshot` to obtain a detached copy of the data.
#[derive(Debug, Clone)]
pub struct Facts {
    /// Fact name -> fact value.
    data: Arc<RwLock<HashMap<String, Value>>>,
    /// Fact name -> type name: the Rust type name for facts stored with `add`, or the
    /// literal "Value" for facts stored with `add_value`.
    fact_types: Arc<RwLock<HashMap<String, String>>>,
    /// Undo log frames for lightweight snapshots (stack of frames).
    /// Each frame records per-key previous values so rollback can restore only
    /// changed keys instead of cloning the whole facts map.
    undo_frames: Arc<RwLock<Vec<Vec<UndoEntry>>>>,
}
19
20impl Facts {
21    /// Create a generic object from key-value pairs
22    pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
23        let mut map = HashMap::new();
24        for (key, value) in pairs {
25            map.insert(key, value);
26        }
27        Value::Object(map)
28    }
29
30    /// Create a user object
31    pub fn new() -> Self {
32        Self {
33            data: Arc::new(RwLock::new(HashMap::new())),
34            fact_types: Arc::new(RwLock::new(HashMap::new())),
35            undo_frames: Arc::new(RwLock::new(Vec::new())),
36        }
37    }
38
39    /// Add a fact object to the working memory
40    pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
41    where
42        T: Serialize + std::fmt::Debug,
43    {
44        let value =
45            serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
46                message: e.to_string(),
47            })?;
48
49        let fact_value = Value::from(value);
50
51        let mut data = self.data.write().unwrap();
52        let mut types = self.fact_types.write().unwrap();
53
54        data.insert(name.to_string(), fact_value);
55        types.insert(name.to_string(), std::any::type_name::<T>().to_string());
56
57        Ok(())
58    }
59
60    /// Add a simple value fact
61    pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
62        let mut data = self.data.write().unwrap();
63        let mut types = self.fact_types.write().unwrap();
64
65        data.insert(name.to_string(), value);
66        types.insert(name.to_string(), "Value".to_string());
67
68        Ok(())
69    }
70
71    /// Get a fact by name
72    pub fn get(&self, name: &str) -> Option<Value> {
73        let data = self.data.read().unwrap();
74        data.get(name).cloned()
75    }
76
77    /// Get a nested fact property (e.g., "User.Profile.Age")
78    pub fn get_nested(&self, path: &str) -> Option<Value> {
79        let parts: Vec<&str> = path.split('.').collect();
80        if parts.is_empty() {
81            return None;
82        }
83
84        let data = self.data.read().unwrap();
85        let mut current = data.get(parts[0])?;
86
87        for part in parts.iter().skip(1) {
88            match current {
89                Value::Object(ref obj) => {
90                    current = obj.get(*part)?;
91                }
92                _ => return None,
93            }
94        }
95
96        Some(current.clone())
97    }
98
99    /// Set a fact value
100    pub fn set(&self, name: &str, value: Value) {
101        // Record previous value for undo if an undo frame is active
102        self.record_undo_for_key(name);
103
104        let mut data = self.data.write().unwrap();
105        data.insert(name.to_string(), value);
106    }
107
108    /// Set a nested fact property
109    pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
110        let parts: Vec<&str> = path.split('.').collect();
111        if parts.is_empty() {
112            return Err(RuleEngineError::FieldNotFound {
113                field: path.to_string(),
114            });
115        }
116
117    // Record previous top-level key for undo semantics
118    self.record_undo_for_key(parts[0]);
119
120    let mut data = self.data.write().unwrap();
121
122        if parts.len() == 1 {
123            data.insert(parts[0].to_string(), value);
124            return Ok(());
125        }
126
127        // Navigate to parent and set the nested value
128        let root_key = parts[0];
129        let root_value = data
130            .get_mut(root_key)
131            .ok_or_else(|| RuleEngineError::FieldNotFound {
132                field: root_key.to_string(),
133            })?;
134
135        self.set_nested_in_value(root_value, &parts[1..], value)?;
136        Ok(())
137    }
138
139    #[allow(clippy::only_used_in_recursion)]
140    fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
141        if path.is_empty() {
142            return Ok(());
143        }
144
145        if path.len() == 1 {
146            // We're at the target field
147            match current {
148                Value::Object(ref mut obj) => {
149                    obj.insert(path[0].to_string(), value);
150                    Ok(())
151                }
152                _ => Err(RuleEngineError::TypeMismatch {
153                    expected: "Object".to_string(),
154                    actual: format!("{:?}", current),
155                }),
156            }
157        } else {
158            // Continue navigating
159            match current {
160                Value::Object(ref mut obj) => {
161                    let next_value =
162                        obj.get_mut(path[0])
163                            .ok_or_else(|| RuleEngineError::FieldNotFound {
164                                field: path[0].to_string(),
165                            })?;
166                    self.set_nested_in_value(next_value, &path[1..], value)
167                }
168                _ => Err(RuleEngineError::TypeMismatch {
169                    expected: "Object".to_string(),
170                    actual: format!("{:?}", current),
171                }),
172            }
173        }
174    }
175
176    /// Remove a fact
177    pub fn remove(&self, name: &str) -> Option<Value> {
178        // Record undo before removing
179        self.record_undo_for_key(name);
180
181        let mut data = self.data.write().unwrap();
182        let mut types = self.fact_types.write().unwrap();
183
184        types.remove(name);
185        data.remove(name)
186    }
187
188    /// Clear all facts
189    pub fn clear(&self) {
190        let mut data = self.data.write().unwrap();
191        let mut types = self.fact_types.write().unwrap();
192
193        data.clear();
194        types.clear();
195    }
196
197    /// Get all fact names
198    pub fn get_fact_names(&self) -> Vec<String> {
199        let data = self.data.read().unwrap();
200        data.keys().cloned().collect()
201    }
202
203    /// Get fact count
204    pub fn count(&self) -> usize {
205        let data = self.data.read().unwrap();
206        data.len()
207    }
208
209    /// Check if a fact exists
210    pub fn contains(&self, name: &str) -> bool {
211        let data = self.data.read().unwrap();
212        data.contains_key(name)
213    }
214
215    /// Get all facts as a HashMap (for pattern matching evaluation)
216    pub fn get_all_facts(&self) -> HashMap<String, Value> {
217        let data = self.data.read().unwrap();
218        data.clone()
219    }
220
221    /// Get the type name of a fact
222    pub fn get_fact_type(&self, name: &str) -> Option<String> {
223        let types = self.fact_types.read().unwrap();
224        types.get(name).cloned()
225    }
226
227    /// Convert to Context for rule evaluation
228    pub fn to_context(&self) -> Context {
229        let data = self.data.read().unwrap();
230        data.clone()
231    }
232
233    /// Create Facts from Context
234    pub fn from_context(context: Context) -> Self {
235        let facts = Facts::new();
236        {
237            let mut data = facts.data.write().unwrap();
238            *data = context;
239        }
240        facts
241    }
242
243    /// Merge another Facts instance into this one
244    pub fn merge(&self, other: &Facts) {
245        let other_data = other.data.read().unwrap();
246        let other_types = other.fact_types.read().unwrap();
247
248        let mut data = self.data.write().unwrap();
249        let mut types = self.fact_types.write().unwrap();
250
251        for (key, value) in other_data.iter() {
252            data.insert(key.clone(), value.clone());
253        }
254
255        for (key, type_name) in other_types.iter() {
256            types.insert(key.clone(), type_name.clone());
257        }
258    }
259
260    /// Get a snapshot of all facts
261    pub fn snapshot(&self) -> FactsSnapshot {
262        let data = self.data.read().unwrap();
263        let types = self.fact_types.read().unwrap();
264
265        FactsSnapshot {
266            data: data.clone(),
267            fact_types: types.clone(),
268        }
269    }
270
271    /// Restore from a snapshot
272    pub fn restore(&self, snapshot: FactsSnapshot) {
273        let mut data = self.data.write().unwrap();
274        let mut types = self.fact_types.write().unwrap();
275
276        *data = snapshot.data;
277        *types = snapshot.fact_types;
278    }
279}
280
281impl Default for Facts {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
/// A snapshot of Facts state: an owned, detached copy of the fact data and type map,
/// produced by `Facts::snapshot` and consumed by `Facts::restore`.
///
/// Undo frames are not part of a snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactsSnapshot {
    /// The fact data stored as key-value pairs (fact name -> value)
    pub data: HashMap<String, Value>,
    /// Type information for each fact (fact name -> recorded type name)
    pub fact_types: HashMap<String, String>,
}
295
/// Undo entry for a single top-level key, recorded the first time that key is changed
/// while an undo frame is open.
#[derive(Debug, Clone)]
struct UndoEntry {
    /// Top-level fact name this entry restores.
    key: String,
    /// Value before the change; `None` means the key was absent, and rollback removes it.
    prev_value: Option<Value>,
    /// Type name before the change; `None` means no type was recorded, and rollback
    /// removes any type entry for the key.
    prev_type: Option<String>,
}
303
304impl Facts {
305    /// Start a new undo frame. Call `rollback_undo_frame` to revert or
306    /// `commit_undo_frame` to discard recorded changes.
307    pub fn begin_undo_frame(&self) {
308        let mut frames = self.undo_frames.write().unwrap();
309        frames.push(Vec::new());
310    }
311
312    /// Commit (discard) the top-most undo frame
313    pub fn commit_undo_frame(&self) {
314        let mut frames = self.undo_frames.write().unwrap();
315        frames.pop();
316    }
317
318    /// Rollback the top-most undo frame, restoring prior values
319    pub fn rollback_undo_frame(&self) {
320        let mut frames = self.undo_frames.write().unwrap();
321        if let Some(frame) = frames.pop() {
322            // Restore in reverse order
323            let mut data = self.data.write().unwrap();
324            let mut types = self.fact_types.write().unwrap();
325
326            for entry in frame.into_iter().rev() {
327                match entry.prev_value {
328                    Some(v) => { data.insert(entry.key.clone(), v); }
329                    None => { data.remove(&entry.key); }
330                }
331
332                match entry.prev_type {
333                    Some(t) => { types.insert(entry.key.clone(), t); }
334                    None => { types.remove(&entry.key); }
335                }
336            }
337        }
338    }
339
340    /// Record prior state for a top-level key if an undo frame is active
341    fn record_undo_for_key(&self, key: &str) {
342        let mut frames = self.undo_frames.write().unwrap();
343        if let Some(frame) = frames.last_mut() {
344            // capture previous value & type
345            let data = self.data.read().unwrap();
346            let types = self.fact_types.read().unwrap();
347
348            // Only record once per key in this frame
349            if frame.iter().any(|e: &UndoEntry| e.key == key) {
350                return;
351            }
352
353            let prev_value = data.get(key).cloned();
354            let prev_type = types.get(key).cloned();
355
356            frame.push(UndoEntry {
357                key: key.to_string(),
358                prev_value,
359                prev_type,
360            });
361        }
362    }
363}
364
/// Trait for objects that can be used as facts.
///
/// Implementors must be serializable (`Serialize`) and debug-printable (`Debug`);
/// the `impl_fact!` macro provides a one-line implementation.
pub trait Fact: Serialize + std::fmt::Debug {
    /// Get the name of this fact type
    fn fact_name() -> &'static str;
}
370
/// Implement [`Fact`] for a type, using the given expression as its `fact_name()`.
///
/// Usage: `impl_fact!(MyType, "MyType");`
///
/// NOTE(review): the expansion names the trait as plain `Fact`, so the trait must be in
/// scope at the call site. A `$crate::...` path would remove that requirement; confirm
/// the public module path before changing it.
#[macro_export]
macro_rules! impl_fact {
    ($fact_ty:ty, $fact_name:expr) => {
        impl Fact for $fact_ty {
            fn fact_name() -> &'static str {
                $fact_name
            }
        }
    };
}
382
/// Helper functions for building fact objects (`Value::Object` maps) for common
/// entities such as users, products, orders, transactions and demo objects.
pub struct FactHelper;
385
386impl FactHelper {
387    /// Create a generic object with key-value pairs
388    pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
389        let mut object = HashMap::new();
390        for (key, value) in pairs {
391            object.insert(key.to_string(), value);
392        }
393        Value::Object(object)
394    }
395
396    /// Create a User fact from common fields
397    pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
398        let mut user = HashMap::new();
399        user.insert("Name".to_string(), Value::String(name.to_string()));
400        user.insert("Age".to_string(), Value::Integer(age));
401        user.insert("Email".to_string(), Value::String(email.to_string()));
402        user.insert("Country".to_string(), Value::String(country.to_string()));
403        user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
404
405        Value::Object(user)
406    }
407
408    /// Create a Product fact
409    pub fn create_product(
410        name: &str,
411        price: f64,
412        category: &str,
413        in_stock: bool,
414        stock_count: i64,
415    ) -> Value {
416        let mut product = HashMap::new();
417        product.insert("Name".to_string(), Value::String(name.to_string()));
418        product.insert("Price".to_string(), Value::Number(price));
419        product.insert("Category".to_string(), Value::String(category.to_string()));
420        product.insert("InStock".to_string(), Value::Boolean(in_stock));
421        product.insert("StockCount".to_string(), Value::Integer(stock_count));
422
423        Value::Object(product)
424    }
425
426    /// Create an Order fact
427    pub fn create_order(
428        id: &str,
429        user_id: &str,
430        total: f64,
431        item_count: i64,
432        status: &str,
433    ) -> Value {
434        let mut order = HashMap::new();
435        order.insert("ID".to_string(), Value::String(id.to_string()));
436        order.insert("UserID".to_string(), Value::String(user_id.to_string()));
437        order.insert("Total".to_string(), Value::Number(total));
438        order.insert("ItemCount".to_string(), Value::Integer(item_count));
439        order.insert("Status".to_string(), Value::String(status.to_string()));
440
441        Value::Object(order)
442    }
443
444    /// Create a TestCar object for method call demo
445    pub fn create_test_car(
446        speed_up: bool,
447        speed: f64,
448        max_speed: f64,
449        speed_increment: f64,
450    ) -> Value {
451        let mut car = HashMap::new();
452        car.insert("speedUp".to_string(), Value::Boolean(speed_up));
453        car.insert("speed".to_string(), Value::Number(speed));
454        car.insert("maxSpeed".to_string(), Value::Number(max_speed));
455        car.insert("Speed".to_string(), Value::Number(speed));
456        car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
457        car.insert(
458            "_type".to_string(),
459            Value::String("TestCarClass".to_string()),
460        );
461
462        Value::Object(car)
463    }
464
465    /// Create a DistanceRecord object for method call demo  
466    pub fn create_distance_record(total_distance: f64) -> Value {
467        let mut record = HashMap::new();
468        record.insert("TotalDistance".to_string(), Value::Number(total_distance));
469        record.insert(
470            "_type".to_string(),
471            Value::String("DistanceRecordClass".to_string()),
472        );
473
474        Value::Object(record)
475    }
476
477    /// Create a Transaction fact for fraud detection
478    pub fn create_transaction(
479        id: &str,
480        amount: f64,
481        location: &str,
482        timestamp: i64,
483        user_id: &str,
484    ) -> Value {
485        let mut transaction = HashMap::new();
486        transaction.insert("ID".to_string(), Value::String(id.to_string()));
487        transaction.insert("Amount".to_string(), Value::Number(amount));
488        transaction.insert("Location".to_string(), Value::String(location.to_string()));
489        transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
490        transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
491
492        Value::Object(transaction)
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_facts_basic_operations() {
        let facts = Facts::new();

        // Add facts
        facts.add_value("age", Value::Integer(25)).unwrap();
        facts
            .add_value("name", Value::String("John".to_string()))
            .unwrap();

        // Get facts
        assert_eq!(facts.get("age"), Some(Value::Integer(25)));
        assert_eq!(facts.get("name"), Some(Value::String("John".to_string())));

        // Count
        assert_eq!(facts.count(), 2);

        // Contains
        assert!(facts.contains("age"));
        assert!(!facts.contains("email"));
    }

    #[test]
    fn test_nested_facts() {
        let facts = Facts::new();
        let user = FactHelper::create_user("John", 25, "john@example.com", "US", true);

        facts.add_value("User", user).unwrap();

        // Get nested values
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(
            facts.get_nested("User.Name"),
            Some(Value::String("John".to_string()))
        );

        // Set nested values
        facts.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    #[test]
    fn test_facts_snapshot() {
        let facts = Facts::new();
        facts
            .add_value("test", Value::String("value".to_string()))
            .unwrap();

        let snapshot = facts.snapshot();

        facts.clear();
        assert_eq!(facts.count(), 0);

        facts.restore(snapshot);
        assert_eq!(facts.count(), 1);
        assert_eq!(facts.get("test"), Some(Value::String("value".to_string())));
    }

    #[test]
    fn test_nested_paths_that_do_not_resolve() {
        let facts = Facts::new();
        let user = FactHelper::create_user("John", 25, "john@example.com", "US", true);
        facts.add_value("User", user).unwrap();

        // Lookups through missing keys or non-object values yield None.
        assert_eq!(facts.get_nested("User.Missing"), None);
        assert_eq!(facts.get_nested("Nobody.Age"), None);
        assert_eq!(facts.get_nested("User.Age.Deep"), None);

        // Updates through the same kinds of paths are reported as errors.
        assert!(facts.set_nested("Nobody.Age", Value::Integer(1)).is_err());
        assert!(facts
            .set_nested("User.Missing.Deep", Value::Integer(1))
            .is_err());
        assert!(facts.set_nested("User.Age.Deep", Value::Integer(1)).is_err());

        // A single-segment path behaves like `set` on a top-level fact.
        facts.set_nested("Flag", Value::Boolean(true)).unwrap();
        assert_eq!(facts.get("Flag"), Some(Value::Boolean(true)));
    }

    #[test]
    fn test_undo_rollback_restores_prior_state() {
        let facts = Facts::new();
        facts.add_value("a", Value::Integer(1)).unwrap();
        facts.add_value("b", Value::Integer(2)).unwrap();

        facts.begin_undo_frame();
        facts.set("a", Value::Integer(10));
        facts.set("a", Value::Integer(20));
        facts.remove("b");
        facts.set("c", Value::Integer(3));
        facts.rollback_undo_frame();

        assert_eq!(facts.get("a"), Some(Value::Integer(1)));
        assert_eq!(facts.get("b"), Some(Value::Integer(2)));
        assert_eq!(facts.get_fact_type("b"), Some("Value".to_string()));
        assert!(!facts.contains("c"));
    }

    #[test]
    fn test_undo_commit_keeps_changes() {
        let facts = Facts::new();
        facts.add_value("a", Value::Integer(1)).unwrap();

        facts.begin_undo_frame();
        facts.set("a", Value::Integer(2));
        facts.commit_undo_frame();

        // With no frame open, rollback has nothing to undo.
        facts.rollback_undo_frame();
        assert_eq!(facts.get("a"), Some(Value::Integer(2)));
    }

    #[test]
    fn test_undo_rollback_of_nested_set() {
        let facts = Facts::new();
        let user = FactHelper::create_user("John", 25, "john@example.com", "US", true);
        facts.add_value("User", user).unwrap();

        facts.begin_undo_frame();
        facts.set_nested("User.Age", Value::Integer(30)).unwrap();
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(30)));
        facts.rollback_undo_frame();

        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(25)));
    }

    #[test]
    fn test_merge_overwrites_values_and_types() {
        let target = Facts::new();
        target.add_value("x", Value::Integer(1)).unwrap();
        target.add_value("keep", Value::Boolean(true)).unwrap();

        let source = Facts::new();
        source.add_value("x", Value::Integer(2)).unwrap();
        source
            .add_value("y", Value::String("z".to_string()))
            .unwrap();

        target.merge(&source);

        assert_eq!(target.get("x"), Some(Value::Integer(2)));
        assert_eq!(target.get("y"), Some(Value::String("z".to_string())));
        assert_eq!(target.get("keep"), Some(Value::Boolean(true)));
        assert_eq!(target.get_fact_type("y"), Some("Value".to_string()));
        assert_eq!(target.count(), 3);
        // The source is left untouched.
        assert_eq!(source.count(), 2);
    }
}