rust_rule_engine/engine/facts.rs

1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7/// Facts - represents the working memory of data objects
8/// Similar to Grule's DataContext concept
9#[derive(Debug, Clone)]
10pub struct Facts {
11    data: Arc<RwLock<HashMap<String, Value>>>,
12    fact_types: Arc<RwLock<HashMap<String, String>>>,
13}
14
15impl Facts {
16    /// Create a generic object from key-value pairs
17    pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
18        let mut map = HashMap::new();
19        for (key, value) in pairs {
20            map.insert(key, value);
21        }
22        Value::Object(map)
23    }
24
25    /// Create a user object
26    pub fn new() -> Self {
27        Self {
28            data: Arc::new(RwLock::new(HashMap::new())),
29            fact_types: Arc::new(RwLock::new(HashMap::new())),
30        }
31    }
32
33    /// Add a fact object to the working memory
34    pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
35    where
36        T: Serialize + std::fmt::Debug,
37    {
38        let value =
39            serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
40                message: e.to_string(),
41            })?;
42
43        let fact_value = Value::from(value);
44
45        let mut data = self.data.write().unwrap();
46        let mut types = self.fact_types.write().unwrap();
47
48        data.insert(name.to_string(), fact_value);
49        types.insert(name.to_string(), std::any::type_name::<T>().to_string());
50
51        Ok(())
52    }
53
54    /// Add a simple value fact
55    pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
56        let mut data = self.data.write().unwrap();
57        let mut types = self.fact_types.write().unwrap();
58
59        data.insert(name.to_string(), value);
60        types.insert(name.to_string(), "Value".to_string());
61
62        Ok(())
63    }
64
65    /// Get a fact by name
66    pub fn get(&self, name: &str) -> Option<Value> {
67        let data = self.data.read().unwrap();
68        data.get(name).cloned()
69    }
70
71    /// Get a nested fact property (e.g., "User.Profile.Age")
72    pub fn get_nested(&self, path: &str) -> Option<Value> {
73        let parts: Vec<&str> = path.split('.').collect();
74        if parts.is_empty() {
75            return None;
76        }
77
78        let data = self.data.read().unwrap();
79        let mut current = data.get(parts[0])?;
80
81        for part in parts.iter().skip(1) {
82            match current {
83                Value::Object(ref obj) => {
84                    current = obj.get(*part)?;
85                }
86                _ => return None,
87            }
88        }
89
90        Some(current.clone())
91    }
92
93    /// Set a fact value
94    pub fn set(&self, name: &str, value: Value) {
95        let mut data = self.data.write().unwrap();
96        data.insert(name.to_string(), value);
97    }
98
99    /// Set a nested fact property
100    pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
101        let parts: Vec<&str> = path.split('.').collect();
102        if parts.is_empty() {
103            return Err(RuleEngineError::FieldNotFound {
104                field: path.to_string(),
105            });
106        }
107
108        let mut data = self.data.write().unwrap();
109
110        if parts.len() == 1 {
111            data.insert(parts[0].to_string(), value);
112            return Ok(());
113        }
114
115        // Navigate to parent and set the nested value
116        let root_key = parts[0];
117        let root_value = data
118            .get_mut(root_key)
119            .ok_or_else(|| RuleEngineError::FieldNotFound {
120                field: root_key.to_string(),
121            })?;
122
123        self.set_nested_in_value(root_value, &parts[1..], value)?;
124        Ok(())
125    }
126
127    #[allow(clippy::only_used_in_recursion)]
128    fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
129        if path.is_empty() {
130            return Ok(());
131        }
132
133        if path.len() == 1 {
134            // We're at the target field
135            match current {
136                Value::Object(ref mut obj) => {
137                    obj.insert(path[0].to_string(), value);
138                    Ok(())
139                }
140                _ => Err(RuleEngineError::TypeMismatch {
141                    expected: "Object".to_string(),
142                    actual: format!("{:?}", current),
143                }),
144            }
145        } else {
146            // Continue navigating
147            match current {
148                Value::Object(ref mut obj) => {
149                    let next_value =
150                        obj.get_mut(path[0])
151                            .ok_or_else(|| RuleEngineError::FieldNotFound {
152                                field: path[0].to_string(),
153                            })?;
154                    self.set_nested_in_value(next_value, &path[1..], value)
155                }
156                _ => Err(RuleEngineError::TypeMismatch {
157                    expected: "Object".to_string(),
158                    actual: format!("{:?}", current),
159                }),
160            }
161        }
162    }
163
164    /// Remove a fact
165    pub fn remove(&self, name: &str) -> Option<Value> {
166        let mut data = self.data.write().unwrap();
167        let mut types = self.fact_types.write().unwrap();
168
169        types.remove(name);
170        data.remove(name)
171    }
172
173    /// Clear all facts
174    pub fn clear(&self) {
175        let mut data = self.data.write().unwrap();
176        let mut types = self.fact_types.write().unwrap();
177
178        data.clear();
179        types.clear();
180    }
181
182    /// Get all fact names
183    pub fn get_fact_names(&self) -> Vec<String> {
184        let data = self.data.read().unwrap();
185        data.keys().cloned().collect()
186    }
187
188    /// Get fact count
189    pub fn count(&self) -> usize {
190        let data = self.data.read().unwrap();
191        data.len()
192    }
193
194    /// Check if a fact exists
195    pub fn contains(&self, name: &str) -> bool {
196        let data = self.data.read().unwrap();
197        data.contains_key(name)
198    }
199
200    /// Get all facts as a HashMap (for pattern matching evaluation)
201    pub fn get_all_facts(&self) -> HashMap<String, Value> {
202        let data = self.data.read().unwrap();
203        data.clone()
204    }
205
206    /// Get the type name of a fact
207    pub fn get_fact_type(&self, name: &str) -> Option<String> {
208        let types = self.fact_types.read().unwrap();
209        types.get(name).cloned()
210    }
211
212    /// Convert to Context for rule evaluation
213    pub fn to_context(&self) -> Context {
214        let data = self.data.read().unwrap();
215        data.clone()
216    }
217
218    /// Create Facts from Context
219    pub fn from_context(context: Context) -> Self {
220        let facts = Facts::new();
221        {
222            let mut data = facts.data.write().unwrap();
223            *data = context;
224        }
225        facts
226    }
227
228    /// Merge another Facts instance into this one
229    pub fn merge(&self, other: &Facts) {
230        let other_data = other.data.read().unwrap();
231        let other_types = other.fact_types.read().unwrap();
232
233        let mut data = self.data.write().unwrap();
234        let mut types = self.fact_types.write().unwrap();
235
236        for (key, value) in other_data.iter() {
237            data.insert(key.clone(), value.clone());
238        }
239
240        for (key, type_name) in other_types.iter() {
241            types.insert(key.clone(), type_name.clone());
242        }
243    }
244
245    /// Get a snapshot of all facts
246    pub fn snapshot(&self) -> FactsSnapshot {
247        let data = self.data.read().unwrap();
248        let types = self.fact_types.read().unwrap();
249
250        FactsSnapshot {
251            data: data.clone(),
252            fact_types: types.clone(),
253        }
254    }
255
256    /// Restore from a snapshot
257    pub fn restore(&self, snapshot: FactsSnapshot) {
258        let mut data = self.data.write().unwrap();
259        let mut types = self.fact_types.write().unwrap();
260
261        *data = snapshot.data;
262        *types = snapshot.fact_types;
263    }
264}
265
266impl Default for Facts {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272/// A snapshot of Facts state
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct FactsSnapshot {
275    /// The fact data stored as key-value pairs
276    pub data: HashMap<String, Value>,
277    /// Type information for each fact
278    pub fact_types: HashMap<String, String>,
279}
280
281/// Trait for objects that can be used as facts
282pub trait Fact: Serialize + std::fmt::Debug {
283    /// Get the name of this fact type
284    fn fact_name() -> &'static str;
285}
286
/// Implement the `Fact` trait for a type, using the given expression as its
/// fact name.
///
/// ```ignore
/// impl_fact!(User, "User");
/// ```
///
/// The expansion refers to the trait by the bare name `Fact`, so that trait
/// must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_fact {
    ($target:ty, $fact_name:expr) => {
        impl Fact for $target {
            fn fact_name() -> &'static str {
                $fact_name
            }
        }
    };
}
298
299/// Helper functions for working with fact objects
300pub struct FactHelper;
301
302impl FactHelper {
303    /// Create a generic object with key-value pairs
304    pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
305        let mut object = HashMap::new();
306        for (key, value) in pairs {
307            object.insert(key.to_string(), value);
308        }
309        Value::Object(object)
310    }
311
312    /// Create a User fact from common fields
313    pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
314        let mut user = HashMap::new();
315        user.insert("Name".to_string(), Value::String(name.to_string()));
316        user.insert("Age".to_string(), Value::Integer(age));
317        user.insert("Email".to_string(), Value::String(email.to_string()));
318        user.insert("Country".to_string(), Value::String(country.to_string()));
319        user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
320
321        Value::Object(user)
322    }
323
324    /// Create a Product fact
325    pub fn create_product(
326        name: &str,
327        price: f64,
328        category: &str,
329        in_stock: bool,
330        stock_count: i64,
331    ) -> Value {
332        let mut product = HashMap::new();
333        product.insert("Name".to_string(), Value::String(name.to_string()));
334        product.insert("Price".to_string(), Value::Number(price));
335        product.insert("Category".to_string(), Value::String(category.to_string()));
336        product.insert("InStock".to_string(), Value::Boolean(in_stock));
337        product.insert("StockCount".to_string(), Value::Integer(stock_count));
338
339        Value::Object(product)
340    }
341
342    /// Create an Order fact
343    pub fn create_order(
344        id: &str,
345        user_id: &str,
346        total: f64,
347        item_count: i64,
348        status: &str,
349    ) -> Value {
350        let mut order = HashMap::new();
351        order.insert("ID".to_string(), Value::String(id.to_string()));
352        order.insert("UserID".to_string(), Value::String(user_id.to_string()));
353        order.insert("Total".to_string(), Value::Number(total));
354        order.insert("ItemCount".to_string(), Value::Integer(item_count));
355        order.insert("Status".to_string(), Value::String(status.to_string()));
356
357        Value::Object(order)
358    }
359
360    /// Create a TestCar object for method call demo
361    pub fn create_test_car(
362        speed_up: bool,
363        speed: f64,
364        max_speed: f64,
365        speed_increment: f64,
366    ) -> Value {
367        let mut car = HashMap::new();
368        car.insert("speedUp".to_string(), Value::Boolean(speed_up));
369        car.insert("speed".to_string(), Value::Number(speed));
370        car.insert("maxSpeed".to_string(), Value::Number(max_speed));
371        car.insert("Speed".to_string(), Value::Number(speed));
372        car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
373        car.insert(
374            "_type".to_string(),
375            Value::String("TestCarClass".to_string()),
376        );
377
378        Value::Object(car)
379    }
380
381    /// Create a DistanceRecord object for method call demo  
382    pub fn create_distance_record(total_distance: f64) -> Value {
383        let mut record = HashMap::new();
384        record.insert("TotalDistance".to_string(), Value::Number(total_distance));
385        record.insert(
386            "_type".to_string(),
387            Value::String("DistanceRecordClass".to_string()),
388        );
389
390        Value::Object(record)
391    }
392
393    /// Create a Transaction fact for fraud detection
394    pub fn create_transaction(
395        id: &str,
396        amount: f64,
397        location: &str,
398        timestamp: i64,
399        user_id: &str,
400    ) -> Value {
401        let mut transaction = HashMap::new();
402        transaction.insert("ID".to_string(), Value::String(id.to_string()));
403        transaction.insert("Amount".to_string(), Value::Number(amount));
404        transaction.insert("Location".to_string(), Value::String(location.to_string()));
405        transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
406        transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
407
408        Value::Object(transaction)
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Value::String` built from a string slice.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    #[test]
    fn test_facts_basic_operations() {
        let facts = Facts::new();

        // Store two plain values.
        facts.add_value("age", Value::Integer(25)).unwrap();
        facts.add_value("name", text("John")).unwrap();

        // Both are readable back under their names.
        assert_eq!(facts.get("age"), Some(Value::Integer(25)));
        assert_eq!(facts.get("name"), Some(text("John")));

        // Exactly two facts are present; an unknown name is absent.
        assert_eq!(facts.count(), 2);
        assert!(facts.contains("age"));
        assert!(!facts.contains("email"));
    }

    #[test]
    fn test_nested_facts() {
        let facts = Facts::new();
        facts
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .unwrap();

        // Dotted paths reach into the user object.
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(facts.get_nested("User.Name"), Some(text("John")));

        // An update through a dotted path is visible to later reads.
        facts.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    #[test]
    fn test_facts_snapshot() {
        let facts = Facts::new();
        facts.add_value("test", text("value")).unwrap();

        let saved = facts.snapshot();

        // Clearing empties the working memory...
        facts.clear();
        assert_eq!(facts.count(), 0);

        // ...and restoring brings the saved fact back.
        facts.restore(saved);
        assert_eq!(facts.count(), 1);
        assert_eq!(facts.get("test"), Some(text("value")));
    }
}