1use crate::errors::{Result, RuleEngineError};
7use crate::rete::facts::FactValue;
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11#[derive(Debug, Clone)]
13pub struct GlobalVar {
14 pub name: String,
15 pub value: FactValue,
16 pub read_only: bool,
17}
18
19impl GlobalVar {
20 pub fn new(name: impl Into<String>, value: FactValue) -> Self {
22 Self {
23 name: name.into(),
24 value,
25 read_only: false,
26 }
27 }
28
29 pub fn read_only(name: impl Into<String>, value: FactValue) -> Self {
31 Self {
32 name: name.into(),
33 value,
34 read_only: true,
35 }
36 }
37
38 pub fn set(&mut self, value: FactValue) -> Result<()> {
40 if self.read_only {
41 return Err(RuleEngineError::EvaluationError {
42 message: format!("Cannot modify read-only global '{}'", self.name),
43 });
44 }
45 self.value = value;
46 Ok(())
47 }
48
49 pub fn get(&self) -> &FactValue {
51 &self.value
52 }
53}
54
55#[derive(Debug, Clone)]
58pub struct GlobalsRegistry {
59 globals: Arc<RwLock<HashMap<String, GlobalVar>>>,
60}
61
62impl GlobalsRegistry {
63 pub fn new() -> Self {
65 Self {
66 globals: Arc::new(RwLock::new(HashMap::new())),
67 }
68 }
69
70 pub fn define(&self, name: impl Into<String>, value: FactValue) -> Result<()> {
72 let var_name = name.into();
73 let mut globals = self.globals.write().map_err(|e| {
74 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
75 })?;
76
77 globals.insert(var_name.clone(), GlobalVar::new(var_name, value));
78 Ok(())
79 }
80
81 pub fn define_readonly(&self, name: impl Into<String>, value: FactValue) -> Result<()> {
83 let var_name = name.into();
84 let mut globals = self.globals.write().map_err(|e| {
85 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
86 })?;
87
88 globals.insert(var_name.clone(), GlobalVar::read_only(var_name, value));
89 Ok(())
90 }
91
92 pub fn get(&self, name: &str) -> Result<FactValue> {
94 let globals = self.globals.read().map_err(|e| {
95 RuleEngineError::ExecutionError(format!("Failed to acquire read lock: {}", e))
96 })?;
97
98 globals
99 .get(name)
100 .map(|var| var.value.clone())
101 .ok_or_else(|| RuleEngineError::EvaluationError {
102 message: format!("Global variable '{}' not found", name),
103 })
104 }
105
106 pub fn set(&self, name: &str, value: FactValue) -> Result<()> {
108 let mut globals = self.globals.write().map_err(|e| {
109 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
110 })?;
111
112 let var = globals
113 .get_mut(name)
114 .ok_or_else(|| RuleEngineError::EvaluationError {
115 message: format!("Global variable '{}' not found", name),
116 })?;
117
118 var.set(value)
119 }
120
121 pub fn exists(&self, name: &str) -> bool {
123 if let Ok(globals) = self.globals.read() {
124 globals.contains_key(name)
125 } else {
126 false
127 }
128 }
129
130 pub fn remove(&self, name: &str) -> Result<()> {
132 let mut globals = self.globals.write().map_err(|e| {
133 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
134 })?;
135
136 globals
137 .remove(name)
138 .ok_or_else(|| RuleEngineError::EvaluationError {
139 message: format!("Global variable '{}' not found", name),
140 })?;
141
142 Ok(())
143 }
144
145 pub fn list_globals(&self) -> Vec<String> {
147 if let Ok(globals) = self.globals.read() {
148 globals.keys().cloned().collect()
149 } else {
150 Vec::new()
151 }
152 }
153
154 pub fn get_all(&self) -> HashMap<String, FactValue> {
156 if let Ok(globals) = self.globals.read() {
157 globals
158 .iter()
159 .map(|(k, v)| (k.clone(), v.value.clone()))
160 .collect()
161 } else {
162 HashMap::new()
163 }
164 }
165
166 pub fn clear(&self) {
168 if let Ok(mut globals) = self.globals.write() {
169 globals.clear();
170 }
171 }
172
173 pub fn increment(&self, name: &str, delta: f64) -> Result<()> {
175 let mut globals = self.globals.write().map_err(|e| {
176 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
177 })?;
178
179 let var = globals
180 .get_mut(name)
181 .ok_or_else(|| RuleEngineError::EvaluationError {
182 message: format!("Global variable '{}' not found", name),
183 })?;
184
185 if var.read_only {
186 return Err(RuleEngineError::EvaluationError {
187 message: format!("Cannot modify read-only global '{}'", name),
188 });
189 }
190
191 let new_value = match &var.value {
192 FactValue::Integer(i) => FactValue::Integer(i + delta as i64),
193 FactValue::Float(f) => FactValue::Float(f + delta),
194 _ => {
195 return Err(RuleEngineError::EvaluationError {
196 message: format!("Global '{}' is not numeric", name),
197 })
198 }
199 };
200
201 var.value = new_value;
202 Ok(())
203 }
204}
205
206impl Default for GlobalsRegistry {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212pub struct GlobalsBuilder {
214 registry: GlobalsRegistry,
215}
216
217impl GlobalsBuilder {
218 pub fn new() -> Self {
220 Self {
221 registry: GlobalsRegistry::new(),
222 }
223 }
224
225 pub fn define(self, name: impl Into<String>, value: FactValue) -> Self {
227 let _ = self.registry.define(name, value);
228 self
229 }
230
231 pub fn define_readonly(self, name: impl Into<String>, value: FactValue) -> Self {
233 let _ = self.registry.define_readonly(name, value);
234 self
235 }
236
237 pub fn build(self) -> GlobalsRegistry {
239 self.registry
240 }
241}
242
243impl Default for GlobalsBuilder {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
#[cfg(test)]
mod tests {
    //! Unit tests for `GlobalsRegistry` and `GlobalsBuilder`, covering both the
    //! success paths and the error paths (missing, read-only, non-numeric).
    use super::*;

    #[test]
    fn test_define_and_get() {
        let registry = GlobalsRegistry::new();
        registry.define("counter", FactValue::Integer(0)).unwrap();

        let value = registry.get("counter").unwrap();
        assert_eq!(value, FactValue::Integer(0));
    }

    #[test]
    fn test_set_global() {
        let registry = GlobalsRegistry::new();
        registry
            .define("status", FactValue::String("active".to_string()))
            .unwrap();

        registry
            .set("status", FactValue::String("inactive".to_string()))
            .unwrap();

        let value = registry.get("status").unwrap();
        assert_eq!(value, FactValue::String("inactive".to_string()));
    }

    #[test]
    fn test_readonly_global() {
        let registry = GlobalsRegistry::new();
        registry
            .define_readonly("PI", FactValue::Float(std::f64::consts::PI))
            .unwrap();

        // Writes to a read-only global must be rejected...
        let result = registry.set("PI", FactValue::Float(3.0));
        assert!(result.is_err());

        // ...and must leave the original value in place.
        let value = registry.get("PI").unwrap();
        assert_eq!(value, FactValue::Float(std::f64::consts::PI));
    }

    #[test]
    fn test_increment() {
        let registry = GlobalsRegistry::new();
        registry.define("counter", FactValue::Integer(10)).unwrap();

        registry.increment("counter", 5.0).unwrap();

        let value = registry.get("counter").unwrap();
        assert_eq!(value, FactValue::Integer(15));
    }

    #[test]
    fn test_increment_float() {
        let registry = GlobalsRegistry::new();
        registry.define("ratio", FactValue::Float(1.5)).unwrap();

        registry.increment("ratio", 0.25).unwrap();

        assert_eq!(registry.get("ratio").unwrap(), FactValue::Float(1.75));
    }

    #[test]
    fn test_increment_rejects_readonly_and_non_numeric() {
        let registry = GlobalsRegistry::new();
        registry
            .define_readonly("LIMIT", FactValue::Integer(5))
            .unwrap();
        registry
            .define("label", FactValue::String("x".to_string()))
            .unwrap();

        assert!(registry.increment("LIMIT", 1.0).is_err());
        assert!(registry.increment("label", 1.0).is_err());

        // Failed increments must not change the stored values.
        assert_eq!(registry.get("LIMIT").unwrap(), FactValue::Integer(5));
        assert_eq!(
            registry.get("label").unwrap(),
            FactValue::String("x".to_string())
        );
    }

    #[test]
    fn test_missing_global_errors() {
        let registry = GlobalsRegistry::new();

        assert!(registry.get("nope").is_err());
        assert!(registry.set("nope", FactValue::Integer(1)).is_err());
        assert!(registry.remove("nope").is_err());
        assert!(registry.increment("nope", 1.0).is_err());
        assert!(!registry.exists("nope"));
    }

    #[test]
    fn test_list_globals() {
        let registry = GlobalsRegistry::new();
        registry.define("var1", FactValue::Integer(1)).unwrap();
        registry.define("var2", FactValue::Integer(2)).unwrap();

        let list = registry.list_globals();
        assert_eq!(list.len(), 2);
        assert!(list.contains(&"var1".to_string()));
        assert!(list.contains(&"var2".to_string()));
    }

    #[test]
    fn test_remove_global() {
        let registry = GlobalsRegistry::new();
        registry.define("temp", FactValue::Boolean(true)).unwrap();

        assert!(registry.exists("temp"));

        registry.remove("temp").unwrap();

        assert!(!registry.exists("temp"));
    }

    #[test]
    fn test_builder() {
        let registry = GlobalsBuilder::new()
            .define("max_retries", FactValue::Integer(3))
            .define("timeout", FactValue::Float(30.0))
            .define_readonly("VERSION", FactValue::String("1.0.0".to_string()))
            .build();

        assert_eq!(registry.get("max_retries").unwrap(), FactValue::Integer(3));
        assert_eq!(registry.get("timeout").unwrap(), FactValue::Float(30.0));
        assert_eq!(
            registry.get("VERSION").unwrap(),
            FactValue::String("1.0.0".to_string())
        );
        // The builder must preserve the read-only flag.
        assert!(registry
            .set("VERSION", FactValue::String("2.0.0".to_string()))
            .is_err());
    }

    #[test]
    fn test_get_all() {
        let registry = GlobalsRegistry::new();
        registry.define("a", FactValue::Integer(1)).unwrap();
        registry.define("b", FactValue::Integer(2)).unwrap();

        let all = registry.get_all();
        assert_eq!(all.len(), 2);
        assert_eq!(all.get("a"), Some(&FactValue::Integer(1)));
        assert_eq!(all.get("b"), Some(&FactValue::Integer(2)));
    }

    #[test]
    fn test_clones_share_state_and_clear() {
        let registry = GlobalsRegistry::new();
        let handle = registry.clone();

        handle.define("x", FactValue::Integer(1)).unwrap();
        assert!(registry.exists("x"));

        registry.clear();
        assert!(handle.list_globals().is_empty());
        assert!(handle.get_all().is_empty());
    }

    #[test]
    fn test_thread_safety() {
        use std::thread;

        let registry = GlobalsRegistry::new();
        registry
            .define("shared_counter", FactValue::Integer(0))
            .unwrap();

        let registry_clone = registry.clone();
        let handle = thread::spawn(move || {
            registry_clone.increment("shared_counter", 1.0).unwrap();
        });

        handle.join().unwrap();

        let value = registry.get("shared_counter").unwrap();
        assert_eq!(value, FactValue::Integer(1));
    }

    #[test]
    fn test_concurrent_increments_are_not_lost() {
        use std::thread;

        let registry = GlobalsRegistry::new();
        registry.define("hits", FactValue::Integer(0)).unwrap();

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let registry = registry.clone();
                thread::spawn(move || {
                    for _ in 0..50 {
                        registry.increment("hits", 1.0).unwrap();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        // Each increment runs under the write lock, so none may be lost.
        assert_eq!(registry.get("hits").unwrap(), FactValue::Integer(200));
    }
}
375}