1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
/// Store of named facts together with a recorded type label per fact.
///
/// Both maps live behind `Arc<RwLock<..>>`, so the derived `Clone` produces
/// another handle to the *same* underlying maps rather than a deep copy.
#[derive(Debug, Clone)]
pub struct Facts {
    /// Fact values keyed by fact name.
    data: Arc<RwLock<HashMap<String, Value>>>,
    /// Type label recorded for each fact name when it is added.
    fact_types: Arc<RwLock<HashMap<String, String>>>,
}
14
15impl Facts {
16 pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
18 let mut map = HashMap::new();
19 for (key, value) in pairs {
20 map.insert(key, value);
21 }
22 Value::Object(map)
23 }
24
25 pub fn new() -> Self {
27 Self {
28 data: Arc::new(RwLock::new(HashMap::new())),
29 fact_types: Arc::new(RwLock::new(HashMap::new())),
30 }
31 }
32
33 pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
35 where
36 T: Serialize + std::fmt::Debug,
37 {
38 let value =
39 serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
40 message: e.to_string(),
41 })?;
42
43 let fact_value = Value::from(value);
44
45 let mut data = self.data.write().unwrap();
46 let mut types = self.fact_types.write().unwrap();
47
48 data.insert(name.to_string(), fact_value);
49 types.insert(name.to_string(), std::any::type_name::<T>().to_string());
50
51 Ok(())
52 }
53
54 pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
56 let mut data = self.data.write().unwrap();
57 let mut types = self.fact_types.write().unwrap();
58
59 data.insert(name.to_string(), value);
60 types.insert(name.to_string(), "Value".to_string());
61
62 Ok(())
63 }
64
65 pub fn get(&self, name: &str) -> Option<Value> {
67 let data = self.data.read().unwrap();
68 data.get(name).cloned()
69 }
70
71 pub fn get_nested(&self, path: &str) -> Option<Value> {
73 let parts: Vec<&str> = path.split('.').collect();
74 if parts.is_empty() {
75 return None;
76 }
77
78 let data = self.data.read().unwrap();
79 let mut current = data.get(parts[0])?;
80
81 for part in parts.iter().skip(1) {
82 match current {
83 Value::Object(ref obj) => {
84 current = obj.get(*part)?;
85 }
86 _ => return None,
87 }
88 }
89
90 Some(current.clone())
91 }
92
93 pub fn set(&self, name: &str, value: Value) {
95 let mut data = self.data.write().unwrap();
96 data.insert(name.to_string(), value);
97 }
98
99 pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
101 let parts: Vec<&str> = path.split('.').collect();
102 if parts.is_empty() {
103 return Err(RuleEngineError::FieldNotFound {
104 field: path.to_string(),
105 });
106 }
107
108 let mut data = self.data.write().unwrap();
109
110 if parts.len() == 1 {
111 data.insert(parts[0].to_string(), value);
112 return Ok(());
113 }
114
115 let root_key = parts[0];
117 let root_value = data
118 .get_mut(root_key)
119 .ok_or_else(|| RuleEngineError::FieldNotFound {
120 field: root_key.to_string(),
121 })?;
122
123 self.set_nested_in_value(root_value, &parts[1..], value)?;
124 Ok(())
125 }
126
127 #[allow(clippy::only_used_in_recursion)]
128 fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
129 if path.is_empty() {
130 return Ok(());
131 }
132
133 if path.len() == 1 {
134 match current {
136 Value::Object(ref mut obj) => {
137 obj.insert(path[0].to_string(), value);
138 Ok(())
139 }
140 _ => Err(RuleEngineError::TypeMismatch {
141 expected: "Object".to_string(),
142 actual: format!("{:?}", current),
143 }),
144 }
145 } else {
146 match current {
148 Value::Object(ref mut obj) => {
149 let next_value =
150 obj.get_mut(path[0])
151 .ok_or_else(|| RuleEngineError::FieldNotFound {
152 field: path[0].to_string(),
153 })?;
154 self.set_nested_in_value(next_value, &path[1..], value)
155 }
156 _ => Err(RuleEngineError::TypeMismatch {
157 expected: "Object".to_string(),
158 actual: format!("{:?}", current),
159 }),
160 }
161 }
162 }
163
164 pub fn remove(&self, name: &str) -> Option<Value> {
166 let mut data = self.data.write().unwrap();
167 let mut types = self.fact_types.write().unwrap();
168
169 types.remove(name);
170 data.remove(name)
171 }
172
173 pub fn clear(&self) {
175 let mut data = self.data.write().unwrap();
176 let mut types = self.fact_types.write().unwrap();
177
178 data.clear();
179 types.clear();
180 }
181
182 pub fn get_fact_names(&self) -> Vec<String> {
184 let data = self.data.read().unwrap();
185 data.keys().cloned().collect()
186 }
187
188 pub fn count(&self) -> usize {
190 let data = self.data.read().unwrap();
191 data.len()
192 }
193
194 pub fn contains(&self, name: &str) -> bool {
196 let data = self.data.read().unwrap();
197 data.contains_key(name)
198 }
199
200 pub fn get_all_facts(&self) -> HashMap<String, Value> {
202 let data = self.data.read().unwrap();
203 data.clone()
204 }
205
206 pub fn get_fact_type(&self, name: &str) -> Option<String> {
208 let types = self.fact_types.read().unwrap();
209 types.get(name).cloned()
210 }
211
212 pub fn to_context(&self) -> Context {
214 let data = self.data.read().unwrap();
215 data.clone()
216 }
217
218 pub fn from_context(context: Context) -> Self {
220 let facts = Facts::new();
221 {
222 let mut data = facts.data.write().unwrap();
223 *data = context;
224 }
225 facts
226 }
227
228 pub fn merge(&self, other: &Facts) {
230 let other_data = other.data.read().unwrap();
231 let other_types = other.fact_types.read().unwrap();
232
233 let mut data = self.data.write().unwrap();
234 let mut types = self.fact_types.write().unwrap();
235
236 for (key, value) in other_data.iter() {
237 data.insert(key.clone(), value.clone());
238 }
239
240 for (key, type_name) in other_types.iter() {
241 types.insert(key.clone(), type_name.clone());
242 }
243 }
244
245 pub fn snapshot(&self) -> FactsSnapshot {
247 let data = self.data.read().unwrap();
248 let types = self.fact_types.read().unwrap();
249
250 FactsSnapshot {
251 data: data.clone(),
252 fact_types: types.clone(),
253 }
254 }
255
256 pub fn restore(&self, snapshot: FactsSnapshot) {
258 let mut data = self.data.write().unwrap();
259 let mut types = self.fact_types.write().unwrap();
260
261 *data = snapshot.data;
262 *types = snapshot.fact_types;
263 }
264}
265
266impl Default for Facts {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
/// Owned, serializable copy of a `Facts` store's contents.
///
/// Produced by `Facts::snapshot` and consumed by `Facts::restore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactsSnapshot {
    /// Fact values keyed by fact name.
    pub data: HashMap<String, Value>,
    /// Recorded type label for each fact name.
    pub fact_types: HashMap<String, String>,
}
280
/// A serializable, debuggable type that carries a fixed fact name.
pub trait Fact: Serialize + std::fmt::Debug {
    /// Returns the static name associated with this fact type.
    ///
    /// This is an associated function: it is called on the type, not on an
    /// instance.
    fn fact_name() -> &'static str;
}
286
/// Implements `Fact` for `$type`, with `fact_name()` returning `$name`.
///
/// The expansion names the `Fact` trait without a path, so `Fact` must be in
/// scope wherever the macro is invoked.
#[macro_export]
macro_rules! impl_fact {
    ($type:ty, $name:expr) => {
        impl Fact for $type {
            fn fact_name() -> &'static str {
                $name
            }
        }
    };
}
298
/// Field-less type whose associated functions build ready-made `Value::Object`
/// fact values.
pub struct FactHelper;
301
302impl FactHelper {
303 pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
305 let mut object = HashMap::new();
306 for (key, value) in pairs {
307 object.insert(key.to_string(), value);
308 }
309 Value::Object(object)
310 }
311
312 pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
314 let mut user = HashMap::new();
315 user.insert("Name".to_string(), Value::String(name.to_string()));
316 user.insert("Age".to_string(), Value::Integer(age));
317 user.insert("Email".to_string(), Value::String(email.to_string()));
318 user.insert("Country".to_string(), Value::String(country.to_string()));
319 user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
320
321 Value::Object(user)
322 }
323
324 pub fn create_product(
326 name: &str,
327 price: f64,
328 category: &str,
329 in_stock: bool,
330 stock_count: i64,
331 ) -> Value {
332 let mut product = HashMap::new();
333 product.insert("Name".to_string(), Value::String(name.to_string()));
334 product.insert("Price".to_string(), Value::Number(price));
335 product.insert("Category".to_string(), Value::String(category.to_string()));
336 product.insert("InStock".to_string(), Value::Boolean(in_stock));
337 product.insert("StockCount".to_string(), Value::Integer(stock_count));
338
339 Value::Object(product)
340 }
341
342 pub fn create_order(
344 id: &str,
345 user_id: &str,
346 total: f64,
347 item_count: i64,
348 status: &str,
349 ) -> Value {
350 let mut order = HashMap::new();
351 order.insert("ID".to_string(), Value::String(id.to_string()));
352 order.insert("UserID".to_string(), Value::String(user_id.to_string()));
353 order.insert("Total".to_string(), Value::Number(total));
354 order.insert("ItemCount".to_string(), Value::Integer(item_count));
355 order.insert("Status".to_string(), Value::String(status.to_string()));
356
357 Value::Object(order)
358 }
359
360 pub fn create_test_car(
362 speed_up: bool,
363 speed: f64,
364 max_speed: f64,
365 speed_increment: f64,
366 ) -> Value {
367 let mut car = HashMap::new();
368 car.insert("speedUp".to_string(), Value::Boolean(speed_up));
369 car.insert("speed".to_string(), Value::Number(speed));
370 car.insert("maxSpeed".to_string(), Value::Number(max_speed));
371 car.insert("Speed".to_string(), Value::Number(speed));
372 car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
373 car.insert(
374 "_type".to_string(),
375 Value::String("TestCarClass".to_string()),
376 );
377
378 Value::Object(car)
379 }
380
381 pub fn create_distance_record(total_distance: f64) -> Value {
383 let mut record = HashMap::new();
384 record.insert("TotalDistance".to_string(), Value::Number(total_distance));
385 record.insert(
386 "_type".to_string(),
387 Value::String("DistanceRecordClass".to_string()),
388 );
389
390 Value::Object(record)
391 }
392
393 pub fn create_transaction(
395 id: &str,
396 amount: f64,
397 location: &str,
398 timestamp: i64,
399 user_id: &str,
400 ) -> Value {
401 let mut transaction = HashMap::new();
402 transaction.insert("ID".to_string(), Value::String(id.to_string()));
403 transaction.insert("Amount".to_string(), Value::Number(amount));
404 transaction.insert("Location".to_string(), Value::String(location.to_string()));
405 transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
406 transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
407
408 Value::Object(transaction)
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Values stored with `add_value` can be read back, counted and probed.
    #[test]
    fn test_facts_basic_operations() {
        let store = Facts::new();

        store.add_value("age", Value::Integer(25)).unwrap();
        store
            .add_value("name", Value::String("John".to_string()))
            .unwrap();

        let expected_age = Some(Value::Integer(25));
        let expected_name = Some(Value::String("John".to_string()));
        assert_eq!(store.get("age"), expected_age);
        assert_eq!(store.get("name"), expected_name);

        assert_eq!(store.count(), 2);

        assert!(store.contains("age"));
        assert!(!store.contains("email"));
    }

    /// Dotted paths read from and write into an object-valued fact.
    #[test]
    fn test_nested_facts() {
        let store = Facts::new();
        store
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .unwrap();

        assert_eq!(store.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(
            store.get_nested("User.Name"),
            Some(Value::String("John".to_string()))
        );

        // Overwrite a nested field and read it back.
        store.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(store.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    /// A snapshot taken before `clear` brings the facts back on `restore`.
    #[test]
    fn test_facts_snapshot() {
        let store = Facts::new();
        store
            .add_value("test", Value::String("value".to_string()))
            .unwrap();

        let saved = store.snapshot();

        store.clear();
        assert_eq!(store.count(), 0);

        store.restore(saved);
        assert_eq!(store.count(), 1);
        assert_eq!(store.get("test"), Some(Value::String("value".to_string())));
    }
}