1use crate::errors::{Result, RuleEngineError};
2use crate::types::{Context, Value};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7#[derive(Debug, Clone)]
10pub struct Facts {
11 data: Arc<RwLock<HashMap<String, Value>>>,
12 fact_types: Arc<RwLock<HashMap<String, String>>>,
13 undo_frames: Arc<RwLock<Vec<Vec<UndoEntry>>>>,
17}
18
19impl Facts {
20 pub fn create_object(pairs: Vec<(String, Value)>) -> Value {
22 let mut map = HashMap::new();
23 for (key, value) in pairs {
24 map.insert(key, value);
25 }
26 Value::Object(map)
27 }
28
29 pub fn new() -> Self {
31 Self {
32 data: Arc::new(RwLock::new(HashMap::new())),
33 fact_types: Arc::new(RwLock::new(HashMap::new())),
34 undo_frames: Arc::new(RwLock::new(Vec::new())),
35 }
36 }
37
38 pub fn add<T>(&self, name: &str, fact: T) -> Result<()>
40 where
41 T: Serialize + std::fmt::Debug,
42 {
43 let value =
44 serde_json::to_value(&fact).map_err(|e| RuleEngineError::SerializationError {
45 message: e.to_string(),
46 })?;
47
48 let fact_value = Value::from(value);
49
50 let mut data = self.data.write().unwrap();
51 let mut types = self.fact_types.write().unwrap();
52
53 data.insert(name.to_string(), fact_value);
54 types.insert(name.to_string(), std::any::type_name::<T>().to_string());
55
56 Ok(())
57 }
58
59 pub fn add_value(&self, name: &str, value: Value) -> Result<()> {
61 let mut data = self.data.write().unwrap();
62 let mut types = self.fact_types.write().unwrap();
63
64 data.insert(name.to_string(), value);
65 types.insert(name.to_string(), "Value".to_string());
66
67 Ok(())
68 }
69
70 pub fn get(&self, name: &str) -> Option<Value> {
72 let data = self.data.read().unwrap();
73 data.get(name).cloned()
74 }
75
76 pub fn with_value<F, R>(&self, name: &str, f: F) -> Option<R>
78 where
79 F: FnOnce(&Value) -> R,
80 {
81 let data = self.data.read().unwrap();
82 data.get(name).map(f)
83 }
84
85 pub fn get_nested(&self, path: &str) -> Option<Value> {
87 let parts: Vec<&str> = path.split('.').collect();
88 if parts.is_empty() {
89 return None;
90 }
91
92 let data = self.data.read().unwrap();
93 let mut current = data.get(parts[0])?;
94
95 for part in parts.iter().skip(1) {
96 match current {
97 Value::Object(ref obj) => {
98 current = obj.get(*part)?;
99 }
100 _ => return None,
101 }
102 }
103
104 Some(current.clone())
105 }
106
107 pub fn set(&self, name: &str, value: Value) {
109 self.record_undo_for_key(name);
111
112 let mut data = self.data.write().unwrap();
113 data.insert(name.to_string(), value);
114 }
115
116 pub fn set_nested(&self, path: &str, value: Value) -> Result<()> {
118 let parts: Vec<&str> = path.split('.').collect();
119 if parts.is_empty() {
120 return Err(RuleEngineError::FieldNotFound {
121 field: path.to_string(),
122 });
123 }
124
125 self.record_undo_for_key(parts[0]);
127
128 let mut data = self.data.write().unwrap();
129
130 if parts.len() == 1 {
131 data.insert(parts[0].to_string(), value);
132 return Ok(());
133 }
134
135 let root_key = parts[0];
137 let root_value = data
138 .get_mut(root_key)
139 .ok_or_else(|| RuleEngineError::FieldNotFound {
140 field: root_key.to_string(),
141 })?;
142
143 self.set_nested_in_value(root_value, &parts[1..], value)?;
144 Ok(())
145 }
146
147 #[allow(clippy::only_used_in_recursion)]
148 fn set_nested_in_value(&self, current: &mut Value, path: &[&str], value: Value) -> Result<()> {
149 if path.is_empty() {
150 return Ok(());
151 }
152
153 if path.len() == 1 {
154 match current {
156 Value::Object(ref mut obj) => {
157 obj.insert(path[0].to_string(), value);
158 Ok(())
159 }
160 _ => Err(RuleEngineError::TypeMismatch {
161 expected: "Object".to_string(),
162 actual: format!("{:?}", current),
163 }),
164 }
165 } else {
166 match current {
168 Value::Object(ref mut obj) => {
169 let next_value =
170 obj.get_mut(path[0])
171 .ok_or_else(|| RuleEngineError::FieldNotFound {
172 field: path[0].to_string(),
173 })?;
174 self.set_nested_in_value(next_value, &path[1..], value)
175 }
176 _ => Err(RuleEngineError::TypeMismatch {
177 expected: "Object".to_string(),
178 actual: format!("{:?}", current),
179 }),
180 }
181 }
182 }
183
184 pub fn remove(&self, name: &str) -> Option<Value> {
186 self.record_undo_for_key(name);
188
189 let mut data = self.data.write().unwrap();
190 let mut types = self.fact_types.write().unwrap();
191
192 types.remove(name);
193 data.remove(name)
194 }
195
196 pub fn clear(&self) {
198 let mut data = self.data.write().unwrap();
199 let mut types = self.fact_types.write().unwrap();
200
201 data.clear();
202 types.clear();
203 }
204
205 pub fn get_fact_names(&self) -> Vec<String> {
207 let data = self.data.read().unwrap();
208 data.keys().cloned().collect()
209 }
210
211 pub fn count(&self) -> usize {
213 let data = self.data.read().unwrap();
214 data.len()
215 }
216
217 pub fn contains(&self, name: &str) -> bool {
219 let data = self.data.read().unwrap();
220 data.contains_key(name)
221 }
222
223 pub fn get_all_facts(&self) -> HashMap<String, Value> {
225 let data = self.data.read().unwrap();
226 data.clone()
227 }
228
229 pub fn get_fact_type(&self, name: &str) -> Option<String> {
231 let types = self.fact_types.read().unwrap();
232 types.get(name).cloned()
233 }
234
235 pub fn to_context(&self) -> Context {
237 let data = self.data.read().unwrap();
238 data.clone()
239 }
240
241 pub fn from_context(context: Context) -> Self {
243 let facts = Facts::new();
244 {
245 let mut data = facts.data.write().unwrap();
246 *data = context;
247 }
248 facts
249 }
250
251 pub fn merge(&self, other: &Facts) {
253 let other_data = other.data.read().unwrap();
254 let other_types = other.fact_types.read().unwrap();
255
256 let mut data = self.data.write().unwrap();
257 let mut types = self.fact_types.write().unwrap();
258
259 for (key, value) in other_data.iter() {
260 data.insert(key.clone(), value.clone());
261 }
262
263 for (key, type_name) in other_types.iter() {
264 types.insert(key.clone(), type_name.clone());
265 }
266 }
267
268 pub fn snapshot(&self) -> FactsSnapshot {
270 let data = self.data.read().unwrap();
271 let types = self.fact_types.read().unwrap();
272
273 FactsSnapshot {
274 data: data.clone(),
275 fact_types: types.clone(),
276 }
277 }
278
279 pub fn restore(&self, snapshot: FactsSnapshot) {
281 let mut data = self.data.write().unwrap();
282 let mut types = self.fact_types.write().unwrap();
283
284 *data = snapshot.data;
285 *types = snapshot.fact_types;
286 }
287}
288
289impl Default for Facts {
290 fn default() -> Self {
291 Self::new()
292 }
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct FactsSnapshot {
298 pub data: HashMap<String, Value>,
300 pub fact_types: HashMap<String, String>,
302}
303
304#[derive(Debug, Clone)]
306struct UndoEntry {
307 key: String,
308 prev_value: Option<Value>,
309 prev_type: Option<String>,
310}
311
312impl Facts {
313 pub fn begin_undo_frame(&self) {
316 let mut frames = self.undo_frames.write().unwrap();
317 frames.push(Vec::new());
318 }
319
320 pub fn commit_undo_frame(&self) {
322 let mut frames = self.undo_frames.write().unwrap();
323 frames.pop();
324 }
325
326 pub fn rollback_undo_frame(&self) {
328 let mut frames = self.undo_frames.write().unwrap();
329 if let Some(frame) = frames.pop() {
330 let mut data = self.data.write().unwrap();
332 let mut types = self.fact_types.write().unwrap();
333
334 for entry in frame.into_iter().rev() {
335 match entry.prev_value {
336 Some(v) => {
337 data.insert(entry.key.clone(), v);
338 }
339 None => {
340 data.remove(&entry.key);
341 }
342 }
343
344 match entry.prev_type {
345 Some(t) => {
346 types.insert(entry.key.clone(), t);
347 }
348 None => {
349 types.remove(&entry.key);
350 }
351 }
352 }
353 }
354 }
355
356 fn record_undo_for_key(&self, key: &str) {
358 let mut frames = self.undo_frames.write().unwrap();
359 if let Some(frame) = frames.last_mut() {
360 let data = self.data.read().unwrap();
362 let types = self.fact_types.read().unwrap();
363
364 if frame.iter().any(|e: &UndoEntry| e.key == key) {
366 return;
367 }
368
369 let prev_value = data.get(key).cloned();
370 let prev_type = types.get(key).cloned();
371
372 frame.push(UndoEntry {
373 key: key.to_string(),
374 prev_value,
375 prev_type,
376 });
377 }
378 }
379}
380
381pub trait Fact: Serialize + std::fmt::Debug {
383 fn fact_name() -> &'static str;
385}
386
/// Implements the `Fact` trait for `$type`, returning `$name` from `fact_name`.
///
/// The generated `impl` names `Fact` without a path, so the trait must be in
/// scope where the macro is invoked. A trailing comma is accepted.
///
/// ```ignore
/// impl_fact!(User, "User");
/// impl_fact!(Order, "Order",);
/// ```
#[macro_export]
macro_rules! impl_fact {
    ($type:ty, $name:expr $(,)?) => {
        impl Fact for $type {
            fn fact_name() -> &'static str {
                $name
            }
        }
    };
}
398
399pub struct FactHelper;
401
402impl FactHelper {
403 pub fn create_object(pairs: Vec<(&str, Value)>) -> Value {
405 let mut object = HashMap::new();
406 for (key, value) in pairs {
407 object.insert(key.to_string(), value);
408 }
409 Value::Object(object)
410 }
411
412 pub fn create_user(name: &str, age: i64, email: &str, country: &str, is_vip: bool) -> Value {
414 let mut user = HashMap::new();
415 user.insert("Name".to_string(), Value::String(name.to_string()));
416 user.insert("Age".to_string(), Value::Integer(age));
417 user.insert("Email".to_string(), Value::String(email.to_string()));
418 user.insert("Country".to_string(), Value::String(country.to_string()));
419 user.insert("IsVIP".to_string(), Value::Boolean(is_vip));
420
421 Value::Object(user)
422 }
423
424 pub fn create_product(
426 name: &str,
427 price: f64,
428 category: &str,
429 in_stock: bool,
430 stock_count: i64,
431 ) -> Value {
432 let mut product = HashMap::new();
433 product.insert("Name".to_string(), Value::String(name.to_string()));
434 product.insert("Price".to_string(), Value::Number(price));
435 product.insert("Category".to_string(), Value::String(category.to_string()));
436 product.insert("InStock".to_string(), Value::Boolean(in_stock));
437 product.insert("StockCount".to_string(), Value::Integer(stock_count));
438
439 Value::Object(product)
440 }
441
442 pub fn create_order(
444 id: &str,
445 user_id: &str,
446 total: f64,
447 item_count: i64,
448 status: &str,
449 ) -> Value {
450 let mut order = HashMap::new();
451 order.insert("ID".to_string(), Value::String(id.to_string()));
452 order.insert("UserID".to_string(), Value::String(user_id.to_string()));
453 order.insert("Total".to_string(), Value::Number(total));
454 order.insert("ItemCount".to_string(), Value::Integer(item_count));
455 order.insert("Status".to_string(), Value::String(status.to_string()));
456
457 Value::Object(order)
458 }
459
460 pub fn create_test_car(
462 speed_up: bool,
463 speed: f64,
464 max_speed: f64,
465 speed_increment: f64,
466 ) -> Value {
467 let mut car = HashMap::new();
468 car.insert("speedUp".to_string(), Value::Boolean(speed_up));
469 car.insert("speed".to_string(), Value::Number(speed));
470 car.insert("maxSpeed".to_string(), Value::Number(max_speed));
471 car.insert("Speed".to_string(), Value::Number(speed));
472 car.insert("SpeedIncrement".to_string(), Value::Number(speed_increment));
473 car.insert(
474 "_type".to_string(),
475 Value::String("TestCarClass".to_string()),
476 );
477
478 Value::Object(car)
479 }
480
481 pub fn create_distance_record(total_distance: f64) -> Value {
483 let mut record = HashMap::new();
484 record.insert("TotalDistance".to_string(), Value::Number(total_distance));
485 record.insert(
486 "_type".to_string(),
487 Value::String("DistanceRecordClass".to_string()),
488 );
489
490 Value::Object(record)
491 }
492
493 pub fn create_transaction(
495 id: &str,
496 amount: f64,
497 location: &str,
498 timestamp: i64,
499 user_id: &str,
500 ) -> Value {
501 let mut transaction = HashMap::new();
502 transaction.insert("ID".to_string(), Value::String(id.to_string()));
503 transaction.insert("Amount".to_string(), Value::Number(amount));
504 transaction.insert("Location".to_string(), Value::String(location.to_string()));
505 transaction.insert("Timestamp".to_string(), Value::Integer(timestamp));
506 transaction.insert("UserID".to_string(), Value::String(user_id.to_string()));
507
508 Value::Object(transaction)
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Value::String` in assertions.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    #[test]
    fn test_facts_basic_operations() {
        let facts = Facts::new();
        facts
            .add_value("age", Value::Integer(25))
            .expect("adding `age` should succeed");
        facts
            .add_value("name", text("John"))
            .expect("adding `name` should succeed");

        // Both facts are retrievable and counted.
        assert_eq!(Some(Value::Integer(25)), facts.get("age"));
        assert_eq!(Some(text("John")), facts.get("name"));
        assert_eq!(2, facts.count());

        // Membership reflects exactly what was added.
        assert!(facts.contains("age"));
        assert!(!facts.contains("email"));
    }

    #[test]
    fn test_nested_facts() {
        let facts = Facts::new();
        facts
            .add_value(
                "User",
                FactHelper::create_user("John", 25, "john@example.com", "US", true),
            )
            .expect("adding `User` should succeed");

        assert_eq!(Some(Value::Integer(25)), facts.get_nested("User.Age"));
        assert_eq!(Some(text("John")), facts.get_nested("User.Name"));

        facts
            .set_nested("User.Age", Value::Integer(26))
            .expect("updating `User.Age` should succeed");
        assert_eq!(Some(Value::Integer(26)), facts.get_nested("User.Age"));
    }

    #[test]
    fn test_facts_snapshot() {
        let facts = Facts::new();
        facts
            .add_value("test", text("value"))
            .expect("adding `test` should succeed");

        let saved = facts.snapshot();

        facts.clear();
        assert_eq!(0, facts.count());

        facts.restore(saved);
        assert_eq!(1, facts.count());
        assert_eq!(Some(text("value")), facts.get("test"));
    }
}