1use std::collections::HashMap;
22use super::facts::FactValue;
23
24pub trait AccumulateFunction: Send + Sync {
26 fn init(&self) -> Box<dyn AccumulateState>;
28
29 fn name(&self) -> &str;
31
32 fn clone_box(&self) -> Box<dyn AccumulateFunction>;
34}
35
36pub trait AccumulateState: Send {
38 fn accumulate(&mut self, value: &FactValue);
40
41 fn get_result(&self) -> FactValue;
43
44 fn reset(&mut self);
46
47 fn clone_box(&self) -> Box<dyn AccumulateState>;
49}
50
51#[derive(Debug, Clone)]
57pub struct SumFunction;
58
59impl AccumulateFunction for SumFunction {
60 fn init(&self) -> Box<dyn AccumulateState> {
61 Box::new(SumState { total: 0.0 })
62 }
63
64 fn name(&self) -> &str {
65 "sum"
66 }
67
68 fn clone_box(&self) -> Box<dyn AccumulateFunction> {
69 Box::new(self.clone())
70 }
71}
72
73#[derive(Debug, Clone)]
74struct SumState {
75 total: f64,
76}
77
78impl AccumulateState for SumState {
79 fn accumulate(&mut self, value: &FactValue) {
80 match value {
81 FactValue::Integer(i) => self.total += *i as f64,
82 FactValue::Float(f) => self.total += f,
83 _ => {} }
85 }
86
87 fn get_result(&self) -> FactValue {
88 FactValue::Float(self.total)
89 }
90
91 fn reset(&mut self) {
92 self.total = 0.0;
93 }
94
95 fn clone_box(&self) -> Box<dyn AccumulateState> {
96 Box::new(self.clone())
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct CountFunction;
103
104impl AccumulateFunction for CountFunction {
105 fn init(&self) -> Box<dyn AccumulateState> {
106 Box::new(CountState { count: 0 })
107 }
108
109 fn name(&self) -> &str {
110 "count"
111 }
112
113 fn clone_box(&self) -> Box<dyn AccumulateFunction> {
114 Box::new(self.clone())
115 }
116}
117
118#[derive(Debug, Clone)]
119struct CountState {
120 count: i64,
121}
122
123impl AccumulateState for CountState {
124 fn accumulate(&mut self, _value: &FactValue) {
125 self.count += 1;
126 }
127
128 fn get_result(&self) -> FactValue {
129 FactValue::Integer(self.count)
130 }
131
132 fn reset(&mut self) {
133 self.count = 0;
134 }
135
136 fn clone_box(&self) -> Box<dyn AccumulateState> {
137 Box::new(self.clone())
138 }
139}
140
141#[derive(Debug, Clone)]
143pub struct AverageFunction;
144
145impl AccumulateFunction for AverageFunction {
146 fn init(&self) -> Box<dyn AccumulateState> {
147 Box::new(AverageState { sum: 0.0, count: 0 })
148 }
149
150 fn name(&self) -> &str {
151 "average"
152 }
153
154 fn clone_box(&self) -> Box<dyn AccumulateFunction> {
155 Box::new(self.clone())
156 }
157}
158
159#[derive(Debug, Clone)]
160struct AverageState {
161 sum: f64,
162 count: usize,
163}
164
165impl AccumulateState for AverageState {
166 fn accumulate(&mut self, value: &FactValue) {
167 match value {
168 FactValue::Integer(i) => {
169 self.sum += *i as f64;
170 self.count += 1;
171 }
172 FactValue::Float(f) => {
173 self.sum += f;
174 self.count += 1;
175 }
176 _ => {} }
178 }
179
180 fn get_result(&self) -> FactValue {
181 if self.count == 0 {
182 FactValue::Float(0.0)
183 } else {
184 FactValue::Float(self.sum / self.count as f64)
185 }
186 }
187
188 fn reset(&mut self) {
189 self.sum = 0.0;
190 self.count = 0;
191 }
192
193 fn clone_box(&self) -> Box<dyn AccumulateState> {
194 Box::new(self.clone())
195 }
196}
197
198#[derive(Debug, Clone)]
200pub struct MinFunction;
201
202impl AccumulateFunction for MinFunction {
203 fn init(&self) -> Box<dyn AccumulateState> {
204 Box::new(MinState { min: None })
205 }
206
207 fn name(&self) -> &str {
208 "min"
209 }
210
211 fn clone_box(&self) -> Box<dyn AccumulateFunction> {
212 Box::new(self.clone())
213 }
214}
215
216#[derive(Debug, Clone)]
217struct MinState {
218 min: Option<f64>,
219}
220
221impl AccumulateState for MinState {
222 fn accumulate(&mut self, value: &FactValue) {
223 let num = match value {
224 FactValue::Integer(i) => Some(*i as f64),
225 FactValue::Float(f) => Some(*f),
226 _ => None,
227 };
228
229 if let Some(n) = num {
230 self.min = Some(match self.min {
231 Some(current) => current.min(n),
232 None => n,
233 });
234 }
235 }
236
237 fn get_result(&self) -> FactValue {
238 match self.min {
239 Some(m) => FactValue::Float(m),
240 None => FactValue::Float(0.0),
241 }
242 }
243
244 fn reset(&mut self) {
245 self.min = None;
246 }
247
248 fn clone_box(&self) -> Box<dyn AccumulateState> {
249 Box::new(self.clone())
250 }
251}
252
253#[derive(Debug, Clone)]
255pub struct MaxFunction;
256
257impl AccumulateFunction for MaxFunction {
258 fn init(&self) -> Box<dyn AccumulateState> {
259 Box::new(MaxState { max: None })
260 }
261
262 fn name(&self) -> &str {
263 "max"
264 }
265
266 fn clone_box(&self) -> Box<dyn AccumulateFunction> {
267 Box::new(self.clone())
268 }
269}
270
271#[derive(Debug, Clone)]
272struct MaxState {
273 max: Option<f64>,
274}
275
276impl AccumulateState for MaxState {
277 fn accumulate(&mut self, value: &FactValue) {
278 let num = match value {
279 FactValue::Integer(i) => Some(*i as f64),
280 FactValue::Float(f) => Some(*f),
281 _ => None,
282 };
283
284 if let Some(n) = num {
285 self.max = Some(match self.max {
286 Some(current) => current.max(n),
287 None => n,
288 });
289 }
290 }
291
292 fn get_result(&self) -> FactValue {
293 match self.max {
294 Some(m) => FactValue::Float(m),
295 None => FactValue::Float(0.0),
296 }
297 }
298
299 fn reset(&mut self) {
300 self.max = None;
301 }
302
303 fn clone_box(&self) -> Box<dyn AccumulateState> {
304 Box::new(self.clone())
305 }
306}
307
308pub struct AccumulatePattern {
314 pub result_var: String,
316
317 pub source_pattern: String,
319
320 pub extract_field: String,
322
323 pub source_conditions: Vec<String>,
325
326 pub function: Box<dyn AccumulateFunction>,
328}
329
330impl Clone for AccumulatePattern {
331 fn clone(&self) -> Self {
332 Self {
333 result_var: self.result_var.clone(),
334 source_pattern: self.source_pattern.clone(),
335 extract_field: self.extract_field.clone(),
336 source_conditions: self.source_conditions.clone(),
337 function: self.function.clone_box(),
338 }
339 }
340}
341
342impl std::fmt::Debug for AccumulatePattern {
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 f.debug_struct("AccumulatePattern")
345 .field("result_var", &self.result_var)
346 .field("source_pattern", &self.source_pattern)
347 .field("extract_field", &self.extract_field)
348 .field("source_conditions", &self.source_conditions)
349 .field("function", &self.function.name())
350 .finish()
351 }
352}
353
354impl AccumulatePattern {
355 pub fn new(
357 result_var: String,
358 source_pattern: String,
359 extract_field: String,
360 function: Box<dyn AccumulateFunction>,
361 ) -> Self {
362 Self {
363 result_var,
364 source_pattern,
365 extract_field,
366 source_conditions: Vec::new(),
367 function,
368 }
369 }
370
371 pub fn with_condition(mut self, condition: String) -> Self {
373 self.source_conditions.push(condition);
374 self
375 }
376}
377
378pub struct AccumulateFunctionRegistry {
384 functions: HashMap<String, Box<dyn AccumulateFunction>>,
385}
386
387impl AccumulateFunctionRegistry {
388 pub fn new() -> Self {
390 let mut registry = Self {
391 functions: HashMap::new(),
392 };
393
394 registry.register(Box::new(SumFunction));
396 registry.register(Box::new(CountFunction));
397 registry.register(Box::new(AverageFunction));
398 registry.register(Box::new(MinFunction));
399 registry.register(Box::new(MaxFunction));
400
401 registry
402 }
403
404 pub fn register(&mut self, function: Box<dyn AccumulateFunction>) {
406 self.functions.insert(function.name().to_string(), function);
407 }
408
409 pub fn get(&self, name: &str) -> Option<Box<dyn AccumulateFunction>> {
411 self.functions.get(name).map(|f| f.clone_box())
412 }
413
414 pub fn available_functions(&self) -> Vec<String> {
416 self.functions.keys().cloned().collect()
417 }
418}
419
420impl Default for AccumulateFunctionRegistry {
421 fn default() -> Self {
422 Self::new()
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `values` into a fresh state of `function` and returns the aggregate.
    fn run(function: &dyn AccumulateFunction, values: &[FactValue]) -> FactValue {
        let mut state = function.init();
        values.iter().for_each(|v| state.accumulate(v));
        state.get_result()
    }

    /// Extracts the payload of a `FactValue::Float`, failing the test otherwise.
    fn expect_float(value: FactValue) -> f64 {
        match value {
            FactValue::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    /// Extracts the payload of a `FactValue::Integer`, failing the test otherwise.
    fn expect_integer(value: FactValue) -> i64 {
        match value {
            FactValue::Integer(i) => i,
            _ => panic!("Expected Integer"),
        }
    }

    #[test]
    fn test_sum_function() {
        let inputs = [
            FactValue::Integer(10),
            FactValue::Integer(20),
            FactValue::Float(15.5),
        ];
        assert_eq!(expect_float(run(&SumFunction, &inputs)), 45.5);
    }

    #[test]
    fn test_count_function() {
        // Non-numeric values are counted as well.
        let inputs = [
            FactValue::Integer(10),
            FactValue::String("test".to_string()),
            FactValue::Boolean(true),
        ];
        assert_eq!(expect_integer(run(&CountFunction, &inputs)), 3);
    }

    #[test]
    fn test_average_function() {
        let inputs = [
            FactValue::Integer(10),
            FactValue::Integer(20),
            FactValue::Integer(30),
        ];
        assert_eq!(expect_float(run(&AverageFunction, &inputs)), 20.0);
    }

    #[test]
    fn test_min_max_functions() {
        let inputs = [
            FactValue::Integer(15),
            FactValue::Integer(5),
            FactValue::Integer(25),
        ];
        assert_eq!(expect_float(run(&MinFunction, &inputs)), 5.0);
        assert_eq!(expect_float(run(&MaxFunction, &inputs)), 25.0);
    }

    #[test]
    fn test_registry() {
        let registry = AccumulateFunctionRegistry::new();

        for name in ["sum", "count", "average", "min", "max"].iter() {
            assert!(registry.get(name).is_some());
        }
        assert!(registry.get("unknown").is_none());

        assert_eq!(registry.available_functions().len(), 5);
    }
}