1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7#[derive(Debug, Clone)]
10pub struct Facts {
11 data: Arc<RwLock<HashMap<String, Value>>>,
12 fact_types: Arc<RwLock<HashMap<String, String>>>,
13 undo_frames: Arc<RwLock<Vec<Vec<UndoEntry>>>>,
17}
18
19impl Facts {
20 pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
22 let mut map = HashMap::new();
23 for (key, value) in pairs {
24 map.insert(key, value);
25 }
26 Value::Object(map)
27 }
28
29 pub fn new() -> Self {
31 Self {
32 data: Arc::new(RwLock::new(HashMap::new())),
33 fact_types: Arc::new(RwLock::new(HashMap::new())),
34 undo_frames: Arc::new(RwLock::new(Vec::new())),
35 }
36 }
37
38 pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
40 where
41 T: Serialize + std::fmt::Debug,
42 {
43 let value =
44 serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
45 message: e.to_string(),
46 })?;
47
48 let fact_value = Value::from(value);
49
50 let mut data = self.data.write().unwrap();
51 let mut types = self.fact_types.write().unwrap();
52
53 data.insert(name.to_string(), fact_value);
54 types.insert(name.to_string(), std::any::type_name::<T>().to_string());
55
56 Ok(())
57 }
58
59 pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
61 let mut data = self.data.write().unwrap();
62 let mut types = self.fact_types.write().unwrap();
63
64 data.insert(name.to_string(), value);
65 types.insert(name.to_string(), "Value".to_string());
66
67 Ok(())
68 }
69
70 pub fn get(&self, name: &str) -> Option<Value> {
72 let data = self.data.read().unwrap();
73 data.get(name).cloned()
74 }
75
76 pub fn get_nested(&self, path: &str) -> Option<Value> {
78 let parts: Vec<&str> = path.split('.').collect();
79 if parts.is_empty() {
80 return None;
81 }
82
83 let data = self.data.read().unwrap();
84 let mut current = data.get(parts[0])?;
85
86 for part in parts.iter().skip(1) {
87 match current {
88 Value::Object(ref obj) => {
89 current = obj.get(*part)?;
90 }
91 _ => return None,
92 }
93 }
94
95 Some(current.clone())
96 }
97
98 pub fn set(&self, name: &str, value: Value) {
100 self.record_undo_for_key(name);
102
103 let mut data = self.data.write().unwrap();
104 data.insert(name.to_string(), value);
105 }
106
107 pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
109 let parts: Vec<&str> = path.split('.').collect();
110 if parts.is_empty() {
111 return Err(RuleEngineError::FieldNotFound {
112 field: path.to_string(),
113 });
114 }
115
116 self.record_undo_for_key(parts[0]);
118
119 let mut data = self.data.write().unwrap();
120
121 if parts.len() == 1 {
122 data.insert(parts[0].to_string(), value);
123 return Ok(());
124 }
125
126 let root_key = parts[0];
128 let root_value = data
129 .get_mut(root_key)
130 .ok_or_else(|| RuleEngineError::FieldNotFound {
131 field: root_key.to_string(),
132 })?;
133
134 self.set_nested_in_value(root_value, &parts[1..], value)?;
135 Ok(())
136 }
137
138 #[allow(clippy::only_used_in_recursion)]
139 fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
140 if path.is_empty() {
141 return Ok(());
142 }
143
144 if path.len() == 1 {
145 match current {
147 Value::Object(ref mut obj) => {
148 obj.insert(path[0].to_string(), value);
149 Ok(())
150 }
151 _ => Err(RuleEngineError::TypeMismatch {
152 expected: "Object".to_string(),
153 actual: format!("{:?}", current),
154 }),
155 }
156 } else {
157 match current {
159 Value::Object(ref mut obj) => {
160 let next_value =
161 obj.get_mut(path[0])
162 .ok_or_else(|| RuleEngineError::FieldNotFound {
163 field: path[0].to_string(),
164 })?;
165 self.set_nested_in_value(next_value, &path[1..], value)
166 }
167 _ => Err(RuleEngineError::TypeMismatch {
168 expected: "Object".to_string(),
169 actual: format!("{:?}", current),
170 }),
171 }
172 }
173 }
174
175 pub fn remove(&self, name: &str) -> Option<Value> {
177 self.record_undo_for_key(name);
179
180 let mut data = self.data.write().unwrap();
181 let mut types = self.fact_types.write().unwrap();
182
183 types.remove(name);
184 data.remove(name)
185 }
186
187 pub fn clear(&self) {
189 let mut data = self.data.write().unwrap();
190 let mut types = self.fact_types.write().unwrap();
191
192 data.clear();
193 types.clear();
194 }
195
196 pub fn get_fact_names(&self) -> Vec<String> {
198 let data = self.data.read().unwrap();
199 data.keys().cloned().collect()
200 }
201
202 pub fn count(&self) -> usize {
204 let data = self.data.read().unwrap();
205 data.len()
206 }
207
208 pub fn contains(&self, name: &str) -> bool {
210 let data = self.data.read().unwrap();
211 data.contains_key(name)
212 }
213
214 pub fn get_all_facts(&self) -> HashMap<String, Value> {
216 let data = self.data.read().unwrap();
217 data.clone()
218 }
219
220 pub fn get_fact_type(&self, name: &str) -> Option<String> {
222 let types = self.fact_types.read().unwrap();
223 types.get(name).cloned()
224 }
225
226 pub fn to_context(&self) -> Context {
228 let data = self.data.read().unwrap();
229 data.clone()
230 }
231
232 pub fn from_context(context: Context) -> Self {
234 let facts = Facts::new();
235 {
236 let mut data = facts.data.write().unwrap();
237 *data = context;
238 }
239 facts
240 }
241
242 pub fn merge(&self, other: &Facts) {
244 let other_data = other.data.read().unwrap();
245 let other_types = other.fact_types.read().unwrap();
246
247 let mut data = self.data.write().unwrap();
248 let mut types = self.fact_types.write().unwrap();
249
250 for (key, value) in other_data.iter() {
251 data.insert(key.clone(), value.clone());
252 }
253
254 for (key, type_name) in other_types.iter() {
255 types.insert(key.clone(), type_name.clone());
256 }
257 }
258
259 pub fn snapshot(&self) -> FactsSnapshot {
261 let data = self.data.read().unwrap();
262 let types = self.fact_types.read().unwrap();
263
264 FactsSnapshot {
265 data: data.clone(),
266 fact_types: types.clone(),
267 }
268 }
269
270 pub fn restore(&self, snapshot: FactsSnapshot) {
272 let mut data = self.data.write().unwrap();
273 let mut types = self.fact_types.write().unwrap();
274
275 *data = snapshot.data;
276 *types = snapshot.fact_types;
277 }
278}
279
280impl Default for Facts {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
286#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct FactsSnapshot {
289 pub data: HashMap<String, Value>,
291 pub fact_types: HashMap<String, String>,
293}
294
295#[derive(Debug, Clone)]
297struct UndoEntry {
298 key: String,
299 prev_value: Option<Value>,
300 prev_type: Option<String>,
301}
302
303impl Facts {
304 pub fn begin_undo_frame(&self) {
307 let mut frames = self.undo_frames.write().unwrap();
308 frames.push(Vec::new());
309 }
310
311 pub fn commit_undo_frame(&self) {
313 let mut frames = self.undo_frames.write().unwrap();
314 frames.pop();
315 }
316
317 pub fn rollback_undo_frame(&self) {
319 let mut frames = self.undo_frames.write().unwrap();
320 if let Some(frame) = frames.pop() {
321 let mut data = self.data.write().unwrap();
323 let mut types = self.fact_types.write().unwrap();
324
325 for entry in frame.into_iter().rev() {
326 match entry.prev_value {
327 Some(v) => {
328 data.insert(entry.key.clone(), v);
329 }
330 None => {
331 data.remove(&entry.key);
332 }
333 }
334
335 match entry.prev_type {
336 Some(t) => {
337 types.insert(entry.key.clone(), t);
338 }
339 None => {
340 types.remove(&entry.key);
341 }
342 }
343 }
344 }
345 }
346
347 fn record_undo_for_key(&self, key: &str) {
349 let mut frames = self.undo_frames.write().unwrap();
350 if let Some(frame) = frames.last_mut() {
351 let data = self.data.read().unwrap();
353 let types = self.fact_types.read().unwrap();
354
355 if frame.iter().any(|e: &UndoEntry| e.key == key) {
357 return;
358 }
359
360 let prev_value = data.get(key).cloned();
361 let prev_type = types.get(key).cloned();
362
363 frame.push(UndoEntry {
364 key: key.to_string(),
365 prev_value,
366 prev_type,
367 });
368 }
369 }
370}
371
372pub trait Fact: Serialize + std::fmt::Debug {
374 fn fact_name() -> &'static str;
376}
377
/// Implements `Fact` for the given type, with `fact_name()` returning the given expression.
///
/// `macro_rules!` resolves item paths at the invocation site, so `Fact` must be in scope
/// wherever this macro is used.
///
/// ```ignore
/// impl_fact!(MyType, "MyType");
/// ```
#[macro_export]
macro_rules! impl_fact {
    ($target:ty, $fact_name:expr) => {
        impl Fact for $target {
            fn fact_name() -> &'static str {
                $fact_name
            }
        }
    };
}
389
390pub struct FactHelper;
392
393impl FactHelper {
394 pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
396 let mut object = HashMap::new();
397 for (key, value) in pairs {
398 object.insert(key.to_string(), value);
399 }
400 Value::Object(object)
401 }
402
403 pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
405 let mut user = HashMap::new();
406 user.insert("Name".to_string(), Value::String(name.to_string()));
407 user.insert("Age".to_string(), Value::Integer(age));
408 user.insert("Email".to_string(), Value::String(email.to_string()));
409 user.insert("Country".to_string(), Value::String(country.to_string()));
410 user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
411
412 Value::Object(user)
413 }
414
415 pub fn create_product(
417 name: &str,
418 price: f64,
419 category: &str,
420 in_stock: bool,
421 stock_count: i64,
422 ) -> Value {
423 let mut product = HashMap::new();
424 product.insert("Name".to_string(), Value::String(name.to_string()));
425 product.insert("Price".to_string(), Value::Number(price));
426 product.insert("Category".to_string(), Value::String(category.to_string()));
427 product.insert("InStock".to_string(), Value::Boolean(in_stock));
428 product.insert("StockCount".to_string(), Value::Integer(stock_count));
429
430 Value::Object(product)
431 }
432
433 pub fn create_order(
435 id: &str,
436 user_id: &str,
437 total: f64,
438 item_count: i64,
439 status: &str,
440 ) -> Value {
441 let mut order = HashMap::new();
442 order.insert("ID".to_string(), Value::String(id.to_string()));
443 order.insert("UserID".to_string(), Value::String(user_id.to_string()));
444 order.insert("Total".to_string(), Value::Number(total));
445 order.insert("ItemCount".to_string(), Value::Integer(item_count));
446 order.insert("Status".to_string(), Value::String(status.to_string()));
447
448 Value::Object(order)
449 }
450
451 pub fn create_test_car(
453 speed_up: bool,
454 speed: f64,
455 max_speed: f64,
456 speed_increment: f64,
457 ) -> Value {
458 let mut car = HashMap::new();
459 car.insert("speedUp".to_string(), Value::Boolean(speed_up));
460 car.insert("speed".to_string(), Value::Number(speed));
461 car.insert("maxSpeed".to_string(), Value::Number(max_speed));
462 car.insert("Speed".to_string(), Value::Number(speed));
463 car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
464 car.insert(
465 "_type".to_string(),
466 Value::String("TestCarClass".to_string()),
467 );
468
469 Value::Object(car)
470 }
471
472 pub fn create_distance_record(total_distance: f64) -> Value {
474 let mut record = HashMap::new();
475 record.insert("TotalDistance".to_string(), Value::Number(total_distance));
476 record.insert(
477 "_type".to_string(),
478 Value::String("DistanceRecordClass".to_string()),
479 );
480
481 Value::Object(record)
482 }
483
484 pub fn create_transaction(
486 id: &str,
487 amount: f64,
488 location: &str,
489 timestamp: i64,
490 user_id: &str,
491 ) -> Value {
492 let mut transaction = HashMap::new();
493 transaction.insert("ID".to_string(), Value::String(id.to_string()));
494 transaction.insert("Amount".to_string(), Value::Number(amount));
495 transaction.insert("Location".to_string(), Value::String(location.to_string()));
496 transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
497 transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
498
499 Value::Object(transaction)
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Value::String` built from a literal.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    #[test]
    fn test_facts_basic_operations() {
        let facts = Facts::new();

        for (fact_name, fact_value) in vec![("age", Value::Integer(25)), ("name", text("John"))] {
            facts.add_value(fact_name, fact_value).unwrap();
        }

        assert_eq!(facts.get("age"), Some(Value::Integer(25)));
        assert_eq!(facts.get("name"), Some(text("John")));
        assert_eq!(facts.count(), 2);
        assert!(facts.contains("age"));
        assert!(!facts.contains("email"));
    }

    #[test]
    fn test_nested_facts() {
        let facts = Facts::new();
        facts
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .unwrap();

        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(facts.get_nested("User.Name"), Some(text("John")));

        facts.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(facts.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    #[test]
    fn test_facts_snapshot() {
        let facts = Facts::new();
        facts.add_value("test", text("value")).unwrap();

        let saved = facts.snapshot();

        facts.clear();
        assert_eq!(facts.count(), 0);

        facts.restore(saved);
        assert_eq!(facts.count(), 1);
        assert_eq!(facts.get("test"), Some(text("value")));
    }
}