1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6use std::collections::HashSet;
7
8#[derive(Debug, Clone)]
11pub struct Facts {
12 data: Arc<RwLock<HashMap<String, Value>>>,
13 fact_types: Arc<RwLock<HashMap<String, String>>>,
14 undo_frames: Arc<RwLock<Vec<Vec<UndoEntry>>>>,
18}
19
20impl Facts {
21 pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
23 let mut map = HashMap::new();
24 for (key, value) in pairs {
25 map.insert(key, value);
26 }
27 Value::Object(map)
28 }
29
30 pub fn new() -> Self {
32 Self {
33 data: Arc::new(RwLock::new(HashMap::new())),
34 fact_types: Arc::new(RwLock::new(HashMap::new())),
35 undo_frames: Arc::new(RwLock::new(Vec::new())),
36 }
37 }
38
39 pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
41 where
42 T: Serialize + std::fmt::Debug,
43 {
44 let value =
45 serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
46 message: e.to_string(),
47 })?;
48
49 let fact_value = Value::from(value);
50
51 let mut data = self.data.write().unwrap();
52 let mut types = self.fact_types.write().unwrap();
53
54 data.insert(name.to_string(), fact_value);
55 types.insert(name.to_string(), std::any::type_name::<T>().to_string());
56
57 Ok(())
58 }
59
60 pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
62 let mut data = self.data.write().unwrap();
63 let mut types = self.fact_types.write().unwrap();
64
65 data.insert(name.to_string(), value);
66 types.insert(name.to_string(), "Value".to_string());
67
68 Ok(())
69 }
70
71 pub fn get(&self, name: &str) -> Option<Value> {
73 let data = self.data.read().unwrap();
74 data.get(name).cloned()
75 }
76
77 pub fn get_nested(&self, path: &str) -> Option<Value> {
79 let parts: Vec<&str> = path.split('.').collect();
80 if parts.is_empty() {
81 return None;
82 }
83
84 let data = self.data.read().unwrap();
85 let mut current = data.get(parts[0])?;
86
87 for part in parts.iter().skip(1) {
88 match current {
89 Value::Object(ref obj) => {
90 current = obj.get(*part)?;
91 }
92 _ => return None,
93 }
94 }
95
96 Some(current.clone())
97 }
98
99 pub fn set(&self, name: &str, value: Value) {
101 self.record_undo_for_key(name);
103
104 let mut data = self.data.write().unwrap();
105 data.insert(name.to_string(), value);
106 }
107
108 pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
110 let parts: Vec<&str> = path.split('.').collect();
111 if parts.is_empty() {
112 return Err(RuleEngineError::FieldNotFound {
113 field: path.to_string(),
114 });
115 }
116
117 self.record_undo_for_key(parts[0]);
119
120 let mut data = self.data.write().unwrap();
121
122 if parts.len() == 1 {
123 data.insert(parts[0].to_string(), value);
124 return Ok(());
125 }
126
127 let root_key = parts[0];
129 let root_value = data
130 .get_mut(root_key)
131 .ok_or_else(|| RuleEngineError::FieldNotFound {
132 field: root_key.to_string(),
133 })?;
134
135 self.set_nested_in_value(root_value, &parts[1..], value)?;
136 Ok(())
137 }
138
139 #[allow(clippy::only_used_in_recursion)]
140 fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
141 if path.is_empty() {
142 return Ok(());
143 }
144
145 if path.len() == 1 {
146 match current {
148 Value::Object(ref mut obj) => {
149 obj.insert(path[0].to_string(), value);
150 Ok(())
151 }
152 _ => Err(RuleEngineError::TypeMismatch {
153 expected: "Object".to_string(),
154 actual: format!("{:?}", current),
155 }),
156 }
157 } else {
158 match current {
160 Value::Object(ref mut obj) => {
161 let next_value =
162 obj.get_mut(path[0])
163 .ok_or_else(|| RuleEngineError::FieldNotFound {
164 field: path[0].to_string(),
165 })?;
166 self.set_nested_in_value(next_value, &path[1..], value)
167 }
168 _ => Err(RuleEngineError::TypeMismatch {
169 expected: "Object".to_string(),
170 actual: format!("{:?}", current),
171 }),
172 }
173 }
174 }
175
176 pub fn remove(&self, name: &str) -> Option<Value> {
178 self.record_undo_for_key(name);
180
181 let mut data = self.data.write().unwrap();
182 let mut types = self.fact_types.write().unwrap();
183
184 types.remove(name);
185 data.remove(name)
186 }
187
188 pub fn clear(&self) {
190 let mut data = self.data.write().unwrap();
191 let mut types = self.fact_types.write().unwrap();
192
193 data.clear();
194 types.clear();
195 }
196
197 pub fn get_fact_names(&self) -> Vec<String> {
199 let data = self.data.read().unwrap();
200 data.keys().cloned().collect()
201 }
202
203 pub fn count(&self) -> usize {
205 let data = self.data.read().unwrap();
206 data.len()
207 }
208
209 pub fn contains(&self, name: &str) -> bool {
211 let data = self.data.read().unwrap();
212 data.contains_key(name)
213 }
214
215 pub fn get_all_facts(&self) -> HashMap<String, Value> {
217 let data = self.data.read().unwrap();
218 data.clone()
219 }
220
221 pub fn get_fact_type(&self, name: &str) -> Option<String> {
223 let types = self.fact_types.read().unwrap();
224 types.get(name).cloned()
225 }
226
227 pub fn to_context(&self) -> Context {
229 let data = self.data.read().unwrap();
230 data.clone()
231 }
232
233 pub fn from_context(context: Context) -> Self {
235 let facts = Facts::new();
236 {
237 let mut data = facts.data.write().unwrap();
238 *data = context;
239 }
240 facts
241 }
242
243 pub fn merge(&self, other: &Facts) {
245 let other_data = other.data.read().unwrap();
246 let other_types = other.fact_types.read().unwrap();
247
248 let mut data = self.data.write().unwrap();
249 let mut types = self.fact_types.write().unwrap();
250
251 for (key, value) in other_data.iter() {
252 data.insert(key.clone(), value.clone());
253 }
254
255 for (key, type_name) in other_types.iter() {
256 types.insert(key.clone(), type_name.clone());
257 }
258 }
259
260 pub fn snapshot(&self) -> FactsSnapshot {
262 let data = self.data.read().unwrap();
263 let types = self.fact_types.read().unwrap();
264
265 FactsSnapshot {
266 data: data.clone(),
267 fact_types: types.clone(),
268 }
269 }
270
271 pub fn restore(&self, snapshot: FactsSnapshot) {
273 let mut data = self.data.write().unwrap();
274 let mut types = self.fact_types.write().unwrap();
275
276 *data = snapshot.data;
277 *types = snapshot.fact_types;
278 }
279}
280
281impl Default for Facts {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
287#[derive(Debug, Clone, Serialize, Deserialize)]
289pub struct FactsSnapshot {
290 pub data: HashMap<String, Value>,
292 pub fact_types: HashMap<String, String>,
294}
295
296#[derive(Debug, Clone)]
298struct UndoEntry {
299 key: String,
300 prev_value: Option<Value>,
301 prev_type: Option<String>,
302}
303
304impl Facts {
305 pub fn begin_undo_frame(&self) {
308 let mut frames = self.undo_frames.write().unwrap();
309 frames.push(Vec::new());
310 }
311
312 pub fn commit_undo_frame(&self) {
314 let mut frames = self.undo_frames.write().unwrap();
315 frames.pop();
316 }
317
318 pub fn rollback_undo_frame(&self) {
320 let mut frames = self.undo_frames.write().unwrap();
321 if let Some(frame) = frames.pop() {
322 let mut data = self.data.write().unwrap();
324 let mut types = self.fact_types.write().unwrap();
325
326 for entry in frame.into_iter().rev() {
327 match entry.prev_value {
328 Some(v) => { data.insert(entry.key.clone(), v); }
329 None => { data.remove(&entry.key); }
330 }
331
332 match entry.prev_type {
333 Some(t) => { types.insert(entry.key.clone(), t); }
334 None => { types.remove(&entry.key); }
335 }
336 }
337 }
338 }
339
340 fn record_undo_for_key(&self, key: &str) {
342 let mut frames = self.undo_frames.write().unwrap();
343 if let Some(frame) = frames.last_mut() {
344 let data = self.data.read().unwrap();
346 let types = self.fact_types.read().unwrap();
347
348 if frame.iter().any(|e: &UndoEntry| e.key == key) {
350 return;
351 }
352
353 let prev_value = data.get(key).cloned();
354 let prev_type = types.get(key).cloned();
355
356 frame.push(UndoEntry {
357 key: key.to_string(),
358 prev_value,
359 prev_type,
360 });
361 }
362 }
363}
364
365pub trait Fact: Serialize + std::fmt::Debug {
367 fn fact_name() -> &'static str;
369}
370
/// Implements `Fact` for `$type`, making `fact_name()` return `$name`.
///
/// `$name` must evaluate to a `&'static str`. A trailing comma after the name
/// is accepted.
///
/// The expansion refers to the trait by its bare name `Fact`, so the trait has
/// to be in scope at the place where the macro is invoked.
#[macro_export]
macro_rules! impl_fact {
    ($type:ty, $name:expr $(,)?) => {
        impl Fact for $type {
            fn fact_name() -> &'static str {
                $name
            }
        }
    };
}
382
/// Namespace for associated functions that build ready-made object values.
pub struct FactHelper;
385
386impl FactHelper {
387 pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
389 let mut object = HashMap::new();
390 for (key, value) in pairs {
391 object.insert(key.to_string(), value);
392 }
393 Value::Object(object)
394 }
395
396 pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
398 let mut user = HashMap::new();
399 user.insert("Name".to_string(), Value::String(name.to_string()));
400 user.insert("Age".to_string(), Value::Integer(age));
401 user.insert("Email".to_string(), Value::String(email.to_string()));
402 user.insert("Country".to_string(), Value::String(country.to_string()));
403 user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
404
405 Value::Object(user)
406 }
407
408 pub fn create_product(
410 name: &str,
411 price: f64,
412 category: &str,
413 in_stock: bool,
414 stock_count: i64,
415 ) -> Value {
416 let mut product = HashMap::new();
417 product.insert("Name".to_string(), Value::String(name.to_string()));
418 product.insert("Price".to_string(), Value::Number(price));
419 product.insert("Category".to_string(), Value::String(category.to_string()));
420 product.insert("InStock".to_string(), Value::Boolean(in_stock));
421 product.insert("StockCount".to_string(), Value::Integer(stock_count));
422
423 Value::Object(product)
424 }
425
426 pub fn create_order(
428 id: &str,
429 user_id: &str,
430 total: f64,
431 item_count: i64,
432 status: &str,
433 ) -> Value {
434 let mut order = HashMap::new();
435 order.insert("ID".to_string(), Value::String(id.to_string()));
436 order.insert("UserID".to_string(), Value::String(user_id.to_string()));
437 order.insert("Total".to_string(), Value::Number(total));
438 order.insert("ItemCount".to_string(), Value::Integer(item_count));
439 order.insert("Status".to_string(), Value::String(status.to_string()));
440
441 Value::Object(order)
442 }
443
444 pub fn create_test_car(
446 speed_up: bool,
447 speed: f64,
448 max_speed: f64,
449 speed_increment: f64,
450 ) -> Value {
451 let mut car = HashMap::new();
452 car.insert("speedUp".to_string(), Value::Boolean(speed_up));
453 car.insert("speed".to_string(), Value::Number(speed));
454 car.insert("maxSpeed".to_string(), Value::Number(max_speed));
455 car.insert("Speed".to_string(), Value::Number(speed));
456 car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
457 car.insert(
458 "_type".to_string(),
459 Value::String("TestCarClass".to_string()),
460 );
461
462 Value::Object(car)
463 }
464
465 pub fn create_distance_record(total_distance: f64) -> Value {
467 let mut record = HashMap::new();
468 record.insert("TotalDistance".to_string(), Value::Number(total_distance));
469 record.insert(
470 "_type".to_string(),
471 Value::String("DistanceRecordClass".to_string()),
472 );
473
474 Value::Object(record)
475 }
476
477 pub fn create_transaction(
479 id: &str,
480 amount: f64,
481 location: &str,
482 timestamp: i64,
483 user_id: &str,
484 ) -> Value {
485 let mut transaction = HashMap::new();
486 transaction.insert("ID".to_string(), Value::String(id.to_string()));
487 transaction.insert("Amount".to_string(), Value::Number(amount));
488 transaction.insert("Location".to_string(), Value::String(location.to_string()));
489 transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
490 transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
491
492 Value::Object(transaction)
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Facts stored with `add_value` can be counted, probed and read back.
    #[test]
    fn test_facts_basic_operations() {
        let store = Facts::new();

        store.add_value("age", Value::Integer(25)).unwrap();
        store
            .add_value("name", Value::String("John".to_string()))
            .unwrap();

        let expected_age = Some(Value::Integer(25));
        let expected_name = Some(Value::String("John".to_string()));
        assert_eq!(store.get("age"), expected_age);
        assert_eq!(store.get("name"), expected_name);

        assert_eq!(store.count(), 2);

        assert!(store.contains("age"));
        assert!(!store.contains("email"));
    }

    /// Dotted paths read from and write into an object-valued fact.
    #[test]
    fn test_nested_facts() {
        let store = Facts::new();
        store
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .unwrap();

        assert_eq!(store.get_nested("User.Age"), Some(Value::Integer(25)));
        assert_eq!(
            store.get_nested("User.Name"),
            Some(Value::String("John".to_string()))
        );

        store.set_nested("User.Age", Value::Integer(26)).unwrap();
        assert_eq!(store.get_nested("User.Age"), Some(Value::Integer(26)));
    }

    /// A snapshot taken before `clear` brings the facts back on `restore`.
    #[test]
    fn test_facts_snapshot() {
        let store = Facts::new();
        store
            .add_value("test", Value::String("value".to_string()))
            .unwrap();

        let saved = store.snapshot();

        store.clear();
        assert_eq!(store.count(), 0);

        store.restore(saved);
        assert_eq!(store.count(), 1);
        assert_eq!(store.get("test"), Some(Value::String("value".to_string())));
    }
}