Skip to main content

rust_rule_engine/engine/
facts.rs

1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7/// Facts - represents the working memory of data objects
8/// Similar to Grule's DataContext concept
9#[derive(Debug, Clone)]
10pub struct Facts {
11    data: Arc<RwLock<HashMap<String, Value>>>,
12    fact_types: Arc<RwLock<HashMap<String, String>>>,
13    /// Undo log frames for lightweight snapshots (stack of frames)
14    /// Each frame records per-key previous values so rollback can restore only
15    /// changed keys instead of cloning the whole facts map.
16    undo_frames: Arc<RwLock<Vec<Vec<UndoEntry>>>>,
17}
18
19impl Facts {
20    /// Create a generic object from key-value pairs
21    pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
22        let mut map = HashMap::new();
23        for (key, value) in pairs {
24            map.insert(key, value);
25        }
26        Value::Object(map)
27    }
28
29    /// Create a user object
30    pub fn new() -> Self {
31        Self {
32            data: Arc::new(RwLock::new(HashMap::new())),
33            fact_types: Arc::new(RwLock::new(HashMap::new())),
34            undo_frames: Arc::new(RwLock::new(Vec::new())),
35        }
36    }
37
38    /// Add a fact object to the working memory
39    pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
40    where
41        T: Serialize + std::fmt::Debug,
42    {
43        let value =
44            serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
45                message: e.to_string(),
46            })?;
47
48        let fact_value = Value::from(value);
49
50        let mut data = self.data.write().unwrap();
51        let mut types = self.fact_types.write().unwrap();
52
53        data.insert(name.to_string(), fact_value);
54        types.insert(name.to_string(), std::any::type_name::<T>().to_string());
55
56        Ok(())
57    }
58
59    /// Add a simple value fact
60    pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
61        let mut data = self.data.write().unwrap();
62        let mut types = self.fact_types.write().unwrap();
63
64        data.insert(name.to_string(), value);
65        types.insert(name.to_string(), "Value".to_string());
66
67        Ok(())
68    }
69
70    /// Get a fact by name
71    pub fn get(&self, name: &str) -> Option<Value> {
72        let data = self.data.read().unwrap();
73        data.get(name).cloned()
74    }
75
76    /// Access a fact value by reference via a callback, avoiding clone
77    pub fn with_value<F, R>(&self, name: &str, f: F) -> Option<R>
78    where
79        F: FnOnce(&Value) -> R,
80    {
81        let data = self.data.read().unwrap();
82        data.get(name).map(f)
83    }
84
85    /// Get a nested fact property (e.g., "User.Profile.Age")
86    pub fn get_nested(&self, path: &str) -> Option<Value> {
87        let parts: Vec<&str> = path.split('.').collect();
88        if parts.is_empty() {
89            return None;
90        }
91
92        let data = self.data.read().unwrap();
93        let mut current = data.get(parts[0])?;
94
95        for part in parts.iter().skip(1) {
96            match current {
97                Value::Object(ref obj) => {
98                    current = obj.get(*part)?;
99                }
100                _ => return None,
101            }
102        }
103
104        Some(current.clone())
105    }
106
107    /// Set a fact value
108    pub fn set(&self, name: &str, value: Value) {
109        // Record previous value for undo if an undo frame is active
110        self.record_undo_for_key(name);
111
112        let mut data = self.data.write().unwrap();
113        data.insert(name.to_string(), value);
114    }
115
116    /// Set a nested fact property
117    pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
118        let parts: Vec<&str> = path.split('.').collect();
119        if parts.is_empty() {
120            return Err(RuleEngineError::FieldNotFound {
121                field: path.to_string(),
122            });
123        }
124
125        // Record previous top-level key for undo semantics
126        self.record_undo_for_key(parts[0]);
127
128        let mut data = self.data.write().unwrap();
129
130        if parts.len() == 1 {
131            data.insert(parts[0].to_string(), value);
132            return Ok(());
133        }
134
135        // Navigate to parent and set the nested value
136        let root_key = parts[0];
137        let root_value = data
138            .get_mut(root_key)
139            .ok_or_else(|| RuleEngineError::FieldNotFound {
140                field: root_key.to_string(),
141            })?;
142
143        self.set_nested_in_value(root_value, &parts[1..], value)?;
144        Ok(())
145    }
146
147    #[allow(clippy::only_used_in_recursion)]
148    fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
149        if path.is_empty() {
150            return Ok(());
151        }
152
153        if path.len() == 1 {
154            // We're at the target field
155            match current {
156                Value::Object(ref mut obj) => {
157                    obj.insert(path[0].to_string(), value);
158                    Ok(())
159                }
160                _ => Err(RuleEngineError::TypeMismatch {
161                    expected: "Object".to_string(),
162                    actual: format!("{:?}", current),
163                }),
164            }
165        } else {
166            // Continue navigating
167            match current {
168                Value::Object(ref mut obj) => {
169                    let next_value =
170                        obj.get_mut(path[0])
171                            .ok_or_else(|| RuleEngineError::FieldNotFound {
172                                field: path[0].to_string(),
173                            })?;
174                    self.set_nested_in_value(next_value, &path[1..], value)
175                }
176                _ => Err(RuleEngineError::TypeMismatch {
177                    expected: "Object".to_string(),
178                    actual: format!("{:?}", current),
179                }),
180            }
181        }
182    }
183
184    /// Remove a fact
185    pub fn remove(&self, name: &str) -> Option<Value> {
186        // Record undo before removing
187        self.record_undo_for_key(name);
188
189        let mut data = self.data.write().unwrap();
190        let mut types = self.fact_types.write().unwrap();
191
192        types.remove(name);
193        data.remove(name)
194    }
195
196    /// Clear all facts
197    pub fn clear(&self) {
198        let mut data = self.data.write().unwrap();
199        let mut types = self.fact_types.write().unwrap();
200
201        data.clear();
202        types.clear();
203    }
204
205    /// Get all fact names
206    pub fn get_fact_names(&self) -> Vec<String> {
207        let data = self.data.read().unwrap();
208        data.keys().cloned().collect()
209    }
210
211    /// Get fact count
212    pub fn count(&self) -> usize {
213        let data = self.data.read().unwrap();
214        data.len()
215    }
216
217    /// Check if a fact exists
218    pub fn contains(&self, name: &str) -> bool {
219        let data = self.data.read().unwrap();
220        data.contains_key(name)
221    }
222
223    /// Get all facts as a HashMap (for pattern matching evaluation)
224    pub fn get_all_facts(&self) -> HashMap<String, Value> {
225        let data = self.data.read().unwrap();
226        data.clone()
227    }
228
229    /// Get the type name of a fact
230    pub fn get_fact_type(&self, name: &str) -> Option<String> {
231        let types = self.fact_types.read().unwrap();
232        types.get(name).cloned()
233    }
234
235    /// Convert to Context for rule evaluation
236    pub fn to_context(&self) -> Context {
237        let data = self.data.read().unwrap();
238        data.clone()
239    }
240
241    /// Create Facts from Context
242    pub fn from_context(context: Context) -> Self {
243        let facts = Facts::new();
244        {
245            let mut data = facts.data.write().unwrap();
246            *data = context;
247        }
248        facts
249    }
250
251    /// Merge another Facts instance into this one
252    pub fn merge(&self, other: &Facts) {
253        let other_data = other.data.read().unwrap();
254        let other_types = other.fact_types.read().unwrap();
255
256        let mut data = self.data.write().unwrap();
257        let mut types = self.fact_types.write().unwrap();
258
259        for (key, value) in other_data.iter() {
260            data.insert(key.clone(), value.clone());
261        }
262
263        for (key, type_name) in other_types.iter() {
264            types.insert(key.clone(), type_name.clone());
265        }
266    }
267
268    /// Get a snapshot of all facts
269    pub fn snapshot(&self) -> FactsSnapshot {
270        let data = self.data.read().unwrap();
271        let types = self.fact_types.read().unwrap();
272
273        FactsSnapshot {
274            data: data.clone(),
275            fact_types: types.clone(),
276        }
277    }
278
279    /// Restore from a snapshot
280    pub fn restore(&self, snapshot: FactsSnapshot) {
281        let mut data = self.data.write().unwrap();
282        let mut types = self.fact_types.write().unwrap();
283
284        *data = snapshot.data;
285        *types = snapshot.fact_types;
286    }
287}
288
289impl Default for Facts {
290    fn default() -> Self {
291        Self::new()
292    }
293}
294
295/// A snapshot of Facts state
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct FactsSnapshot {
298    /// The fact data stored as key-value pairs
299    pub data: HashMap<String, Value>,
300    /// Type information for each fact
301    pub fact_types: HashMap<String, String>,
302}
303
304/// Undo entry for a single key
305#[derive(Debug, Clone)]
306struct UndoEntry {
307    key: String,
308    prev_value: Option<Value>,
309    prev_type: Option<String>,
310}
311
312impl Facts {
313    /// Start a new undo frame. Call `rollback_undo_frame` to revert or
314    /// `commit_undo_frame` to discard recorded changes.
315    pub fn begin_undo_frame(&self) {
316        let mut frames = self.undo_frames.write().unwrap();
317        frames.push(Vec::new());
318    }
319
320    /// Commit (discard) the top-most undo frame
321    pub fn commit_undo_frame(&self) {
322        let mut frames = self.undo_frames.write().unwrap();
323        frames.pop();
324    }
325
326    /// Rollback the top-most undo frame, restoring prior values
327    pub fn rollback_undo_frame(&self) {
328        let mut frames = self.undo_frames.write().unwrap();
329        if let Some(frame) = frames.pop() {
330            // Restore in reverse order
331            let mut data = self.data.write().unwrap();
332            let mut types = self.fact_types.write().unwrap();
333
334            for entry in frame.into_iter().rev() {
335                match entry.prev_value {
336                    Some(v) => {
337                        data.insert(entry.key.clone(), v);
338                    }
339                    None => {
340                        data.remove(&entry.key);
341                    }
342                }
343
344                match entry.prev_type {
345                    Some(t) => {
346                        types.insert(entry.key.clone(), t);
347                    }
348                    None => {
349                        types.remove(&entry.key);
350                    }
351                }
352            }
353        }
354    }
355
356    /// Record prior state for a top-level key if an undo frame is active
357    fn record_undo_for_key(&self, key: &str) {
358        let mut frames = self.undo_frames.write().unwrap();
359        if let Some(frame) = frames.last_mut() {
360            // capture previous value & type
361            let data = self.data.read().unwrap();
362            let types = self.fact_types.read().unwrap();
363
364            // Only record once per key in this frame
365            if frame.iter().any(|e: &UndoEntry| e.key == key) {
366                return;
367            }
368
369            let prev_value = data.get(key).cloned();
370            let prev_type = types.get(key).cloned();
371
372            frame.push(UndoEntry {
373                key: key.to_string(),
374                prev_value,
375                prev_type,
376            });
377        }
378    }
379}
380
381/// Trait for objects that can be used as facts
382pub trait Fact: Serialize + std::fmt::Debug {
383    /// Get the name of this fact type
384    fn fact_name() -> &'static str;
385}
386
/// Implement the `Fact` trait for a type, returning the given expression from
/// `fact_name`.
///
/// Invoke as `impl_fact!(MyType, "MyType");`. The expansion names the trait
/// as bare `Fact`, so `Fact` must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_fact {
    ($target:ty, $fact_name:expr) => {
        impl Fact for $target {
            fn fact_name() -> &'static str {
                $fact_name
            }
        }
    };
}
398
399/// Helper functions for working with fact objects
400pub struct FactHelper;
401
402impl FactHelper {
403    /// Create a generic object with key-value pairs
404    pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
405        let mut object = HashMap::new();
406        for (key, value) in pairs {
407            object.insert(key.to_string(), value);
408        }
409        Value::Object(object)
410    }
411
412    /// Create a User fact from common fields
413    pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
414        let mut user = HashMap::new();
415        user.insert("Name".to_string(), Value::String(name.to_string()));
416        user.insert("Age".to_string(), Value::Integer(age));
417        user.insert("Email".to_string(), Value::String(email.to_string()));
418        user.insert("Country".to_string(), Value::String(country.to_string()));
419        user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
420
421        Value::Object(user)
422    }
423
424    /// Create a Product fact
425    pub fn create_product(
426        name: &str,
427        price: f64,
428        category: &str,
429        in_stock: bool,
430        stock_count: i64,
431    ) -> Value {
432        let mut product = HashMap::new();
433        product.insert("Name".to_string(), Value::String(name.to_string()));
434        product.insert("Price".to_string(), Value::Number(price));
435        product.insert("Category".to_string(), Value::String(category.to_string()));
436        product.insert("InStock".to_string(), Value::Boolean(in_stock));
437        product.insert("StockCount".to_string(), Value::Integer(stock_count));
438
439        Value::Object(product)
440    }
441
442    /// Create an Order fact
443    pub fn create_order(
444        id: &str,
445        user_id: &str,
446        total: f64,
447        item_count: i64,
448        status: &str,
449    ) -> Value {
450        let mut order = HashMap::new();
451        order.insert("ID".to_string(), Value::String(id.to_string()));
452        order.insert("UserID".to_string(), Value::String(user_id.to_string()));
453        order.insert("Total".to_string(), Value::Number(total));
454        order.insert("ItemCount".to_string(), Value::Integer(item_count));
455        order.insert("Status".to_string(), Value::String(status.to_string()));
456
457        Value::Object(order)
458    }
459
460    /// Create a TestCar object for method call demo
461    pub fn create_test_car(
462        speed_up: bool,
463        speed: f64,
464        max_speed: f64,
465        speed_increment: f64,
466    ) -> Value {
467        let mut car = HashMap::new();
468        car.insert("speedUp".to_string(), Value::Boolean(speed_up));
469        car.insert("speed".to_string(), Value::Number(speed));
470        car.insert("maxSpeed".to_string(), Value::Number(max_speed));
471        car.insert("Speed".to_string(), Value::Number(speed));
472        car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
473        car.insert(
474            "_type".to_string(),
475            Value::String("TestCarClass".to_string()),
476        );
477
478        Value::Object(car)
479    }
480
481    /// Create a DistanceRecord object for method call demo  
482    pub fn create_distance_record(total_distance: f64) -> Value {
483        let mut record = HashMap::new();
484        record.insert("TotalDistance".to_string(), Value::Number(total_distance));
485        record.insert(
486            "_type".to_string(),
487            Value::String("DistanceRecordClass".to_string()),
488        );
489
490        Value::Object(record)
491    }
492
493    /// Create a Transaction fact for fraud detection
494    pub fn create_transaction(
495        id: &str,
496        amount: f64,
497        location: &str,
498        timestamp: i64,
499        user_id: &str,
500    ) -> Value {
501        let mut transaction = HashMap::new();
502        transaction.insert("ID".to_string(), Value::String(id.to_string()));
503        transaction.insert("Amount".to_string(), Value::Number(amount));
504        transaction.insert("Location".to_string(), Value::String(location.to_string()));
505        transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
506        transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
507
508        Value::Object(transaction)
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_facts_basic_operations() {
        let facts = Facts::new();
        let john = Value::String("John".to_string());

        facts.add_value("age", Value::Integer(25)).unwrap();
        facts.add_value("name", john.clone()).unwrap();

        // Lookups hand back clones of what was stored.
        assert_eq!(facts.get("age"), Some(Value::Integer(25)));
        assert_eq!(facts.get("name"), Some(john));

        assert_eq!(facts.count(), 2);
        assert!(facts.contains("age"));
        assert!(!facts.contains("email"));
    }

    #[test]
    fn test_nested_facts() {
        let facts = Facts::new();
        facts
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .unwrap();

        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(
            facts.get_nested("User.Name"),
            Some(Value::String("John".to_string()))
        );

        // Overwrite a field inside the stored object, then read it back.
        facts.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    #[test]
    fn test_facts_snapshot() {
        let facts = Facts::new();
        let stored = Value::String("value".to_string());
        facts.add_value("test", stored.clone()).unwrap();

        let saved = facts.snapshot();
        facts.clear();
        assert_eq!(facts.count(), 0);

        facts.restore(saved);
        assert_eq!(facts.count(), 1);
        assert_eq!(facts.get("test"), Some(stored));
    }
}