1use crate::rete::facts::FactValue;
7use crate::errors::{Result, RuleEngineError};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11#[derive(Debug, Clone)]
13pub struct GlobalVar {
14 pub name: String,
15 pub value: FactValue,
16 pub read_only: bool,
17}
18
19impl GlobalVar {
20 pub fn new(name: impl Into<String>, value: FactValue) -> Self {
22 Self {
23 name: name.into(),
24 value,
25 read_only: false,
26 }
27 }
28
29 pub fn read_only(name: impl Into<String>, value: FactValue) -> Self {
31 Self {
32 name: name.into(),
33 value,
34 read_only: true,
35 }
36 }
37
38 pub fn set(&mut self, value: FactValue) -> Result<()> {
40 if self.read_only {
41 return Err(RuleEngineError::EvaluationError {
42 message: format!("Cannot modify read-only global '{}'", self.name),
43 });
44 }
45 self.value = value;
46 Ok(())
47 }
48
49 pub fn get(&self) -> &FactValue {
51 &self.value
52 }
53}
54
55#[derive(Debug, Clone)]
58pub struct GlobalsRegistry {
59 globals: Arc<RwLock<HashMap<String, GlobalVar>>>,
60}
61
62impl GlobalsRegistry {
63 pub fn new() -> Self {
65 Self {
66 globals: Arc::new(RwLock::new(HashMap::new())),
67 }
68 }
69
70 pub fn define(&self, name: impl Into<String>, value: FactValue) -> Result<()> {
72 let var_name = name.into();
73 let mut globals = self.globals.write().map_err(|e| {
74 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
75 })?;
76
77 globals.insert(var_name.clone(), GlobalVar::new(var_name, value));
78 Ok(())
79 }
80
81 pub fn define_readonly(&self, name: impl Into<String>, value: FactValue) -> Result<()> {
83 let var_name = name.into();
84 let mut globals = self.globals.write().map_err(|e| {
85 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
86 })?;
87
88 globals.insert(var_name.clone(), GlobalVar::read_only(var_name, value));
89 Ok(())
90 }
91
92 pub fn get(&self, name: &str) -> Result<FactValue> {
94 let globals = self.globals.read().map_err(|e| {
95 RuleEngineError::ExecutionError(format!("Failed to acquire read lock: {}", e))
96 })?;
97
98 globals
99 .get(name)
100 .map(|var| var.value.clone())
101 .ok_or_else(|| RuleEngineError::EvaluationError {
102 message: format!("Global variable '{}' not found", name),
103 })
104 }
105
106 pub fn set(&self, name: &str, value: FactValue) -> Result<()> {
108 let mut globals = self.globals.write().map_err(|e| {
109 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
110 })?;
111
112 let var = globals.get_mut(name).ok_or_else(|| {
113 RuleEngineError::EvaluationError {
114 message: format!("Global variable '{}' not found", name),
115 }
116 })?;
117
118 var.set(value)
119 }
120
121 pub fn exists(&self, name: &str) -> bool {
123 if let Ok(globals) = self.globals.read() {
124 globals.contains_key(name)
125 } else {
126 false
127 }
128 }
129
130 pub fn remove(&self, name: &str) -> Result<()> {
132 let mut globals = self.globals.write().map_err(|e| {
133 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
134 })?;
135
136 globals.remove(name).ok_or_else(|| {
137 RuleEngineError::EvaluationError {
138 message: format!("Global variable '{}' not found", name),
139 }
140 })?;
141
142 Ok(())
143 }
144
145 pub fn list_globals(&self) -> Vec<String> {
147 if let Ok(globals) = self.globals.read() {
148 globals.keys().cloned().collect()
149 } else {
150 Vec::new()
151 }
152 }
153
154 pub fn get_all(&self) -> HashMap<String, FactValue> {
156 if let Ok(globals) = self.globals.read() {
157 globals
158 .iter()
159 .map(|(k, v)| (k.clone(), v.value.clone()))
160 .collect()
161 } else {
162 HashMap::new()
163 }
164 }
165
166 pub fn clear(&self) {
168 if let Ok(mut globals) = self.globals.write() {
169 globals.clear();
170 }
171 }
172
173 pub fn increment(&self, name: &str, delta: f64) -> Result<()> {
175 let mut globals = self.globals.write().map_err(|e| {
176 RuleEngineError::ExecutionError(format!("Failed to acquire write lock: {}", e))
177 })?;
178
179 let var = globals.get_mut(name).ok_or_else(|| {
180 RuleEngineError::EvaluationError {
181 message: format!("Global variable '{}' not found", name),
182 }
183 })?;
184
185 if var.read_only {
186 return Err(RuleEngineError::EvaluationError {
187 message: format!("Cannot modify read-only global '{}'", name),
188 });
189 }
190
191 let new_value = match &var.value {
192 FactValue::Integer(i) => FactValue::Integer(i + delta as i64),
193 FactValue::Float(f) => FactValue::Float(f + delta),
194 _ => {
195 return Err(RuleEngineError::EvaluationError {
196 message: format!("Global '{}' is not numeric", name),
197 })
198 }
199 };
200
201 var.value = new_value;
202 Ok(())
203 }
204}
205
206impl Default for GlobalsRegistry {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212pub struct GlobalsBuilder {
214 registry: GlobalsRegistry,
215}
216
217impl GlobalsBuilder {
218 pub fn new() -> Self {
220 Self {
221 registry: GlobalsRegistry::new(),
222 }
223 }
224
225 pub fn define(self, name: impl Into<String>, value: FactValue) -> Self {
227 let _ = self.registry.define(name, value);
228 self
229 }
230
231 pub fn define_readonly(self, name: impl Into<String>, value: FactValue) -> Self {
233 let _ = self.registry.define_readonly(name, value);
234 self
235 }
236
237 pub fn build(self) -> GlobalsRegistry {
239 self.registry
240 }
241}
242
243impl Default for GlobalsBuilder {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;
    use std::thread;

    #[test]
    fn test_define_and_get() {
        let registry = GlobalsRegistry::new();
        registry.define("counter", FactValue::Integer(0)).unwrap();

        let value = registry.get("counter").unwrap();
        assert_eq!(value, FactValue::Integer(0));
    }

    #[test]
    fn test_define_overwrites_existing() {
        let registry = GlobalsRegistry::new();
        registry.define("a", FactValue::Integer(1)).unwrap();
        registry.define("a", FactValue::Integer(2)).unwrap();

        assert_eq!(registry.get("a").unwrap(), FactValue::Integer(2));
        assert_eq!(registry.list_globals().len(), 1);
    }

    #[test]
    fn test_set_global() {
        let registry = GlobalsRegistry::new();
        registry.define("status", FactValue::String("active".to_string())).unwrap();

        registry.set("status", FactValue::String("inactive".to_string())).unwrap();

        let value = registry.get("status").unwrap();
        assert_eq!(value, FactValue::String("inactive".to_string()));
    }

    #[test]
    fn test_readonly_global() {
        let registry = GlobalsRegistry::new();
        // Use the std constant: a hand-written 3.14159 trips clippy::approx_constant (deny).
        registry.define_readonly("PI", FactValue::Float(PI)).unwrap();

        // Both mutation paths must refuse a read-only global and leave it intact.
        assert!(registry.set("PI", FactValue::Float(3.0)).is_err());
        assert!(registry.increment("PI", 1.0).is_err());

        assert_eq!(registry.get("PI").unwrap(), FactValue::Float(PI));
    }

    #[test]
    fn test_increment() {
        let registry = GlobalsRegistry::new();
        registry.define("counter", FactValue::Integer(10)).unwrap();

        registry.increment("counter", 5.0).unwrap();
        assert_eq!(registry.get("counter").unwrap(), FactValue::Integer(15));

        registry.increment("counter", -20.0).unwrap();
        assert_eq!(registry.get("counter").unwrap(), FactValue::Integer(-5));
    }

    #[test]
    fn test_increment_float() {
        let registry = GlobalsRegistry::new();
        registry.define("ratio", FactValue::Float(1.5)).unwrap();

        registry.increment("ratio", 2.25).unwrap();

        assert_eq!(registry.get("ratio").unwrap(), FactValue::Float(3.75));
    }

    #[test]
    fn test_increment_non_numeric_fails() {
        let registry = GlobalsRegistry::new();
        registry.define("flag", FactValue::Boolean(true)).unwrap();

        assert!(registry.increment("flag", 1.0).is_err());
        assert_eq!(registry.get("flag").unwrap(), FactValue::Boolean(true));
    }

    #[test]
    fn test_missing_global_errors() {
        let registry = GlobalsRegistry::new();

        assert!(!registry.exists("missing"));
        assert!(registry.get("missing").is_err());
        assert!(registry.set("missing", FactValue::Integer(1)).is_err());
        assert!(registry.increment("missing", 1.0).is_err());
        assert!(registry.remove("missing").is_err());
    }

    #[test]
    fn test_list_globals() {
        let registry = GlobalsRegistry::new();
        registry.define("var1", FactValue::Integer(1)).unwrap();
        registry.define("var2", FactValue::Integer(2)).unwrap();

        let list = registry.list_globals();
        assert_eq!(list.len(), 2);
        assert!(list.contains(&"var1".to_string()));
        assert!(list.contains(&"var2".to_string()));
    }

    #[test]
    fn test_remove_global() {
        let registry = GlobalsRegistry::new();
        registry.define("temp", FactValue::Boolean(true)).unwrap();

        assert!(registry.exists("temp"));

        registry.remove("temp").unwrap();

        assert!(!registry.exists("temp"));
    }

    #[test]
    fn test_clear() {
        let registry = GlobalsRegistry::new();
        registry.define("a", FactValue::Integer(1)).unwrap();
        registry.define_readonly("b", FactValue::Integer(2)).unwrap();

        registry.clear();

        assert!(registry.list_globals().is_empty());
        assert!(registry.get_all().is_empty());
    }

    #[test]
    fn test_builder() {
        let registry = GlobalsBuilder::new()
            .define("max_retries", FactValue::Integer(3))
            .define("timeout", FactValue::Float(30.0))
            .define_readonly("VERSION", FactValue::String("1.0.0".to_string()))
            .build();

        assert_eq!(registry.get("max_retries").unwrap(), FactValue::Integer(3));
        assert_eq!(registry.get("timeout").unwrap(), FactValue::Float(30.0));
        assert_eq!(registry.get("VERSION").unwrap(), FactValue::String("1.0.0".to_string()));

        // Builder-defined globals keep their mutability flags.
        assert!(registry.set("max_retries", FactValue::Integer(5)).is_ok());
        assert!(registry
            .set("VERSION", FactValue::String("2.0.0".to_string()))
            .is_err());
    }

    #[test]
    fn test_get_all() {
        let registry = GlobalsRegistry::new();
        registry.define("a", FactValue::Integer(1)).unwrap();
        registry.define("b", FactValue::Integer(2)).unwrap();

        let all = registry.get_all();
        assert_eq!(all.len(), 2);
        assert_eq!(all.get("a"), Some(&FactValue::Integer(1)));
        assert_eq!(all.get("b"), Some(&FactValue::Integer(2)));
    }

    #[test]
    fn test_clones_share_state() {
        let registry = GlobalsRegistry::new();
        let handle = registry.clone();

        handle.define("x", FactValue::Integer(7)).unwrap();

        assert_eq!(registry.get("x").unwrap(), FactValue::Integer(7));
    }

    #[test]
    fn test_thread_safety() {
        const THREADS: i64 = 8;
        const PER_THREAD: i64 = 100;

        let registry = GlobalsRegistry::new();
        registry.define("shared_counter", FactValue::Integer(0)).unwrap();

        // Several concurrent writers: every increment must be applied exactly once.
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let registry = registry.clone();
                thread::spawn(move || {
                    for _ in 0..PER_THREAD {
                        registry.increment("shared_counter", 1.0).unwrap();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let value = registry.get("shared_counter").unwrap();
        assert_eq!(value, FactValue::Integer(THREADS * PER_THREAD));
    }
}