// rust_rule_engine/rete/accumulate.rs

//! Accumulate Functions for RETE-UL Engine
//!
//! This module implements Drools/CLIPS-style accumulate functions that aggregate
//! values across multiple facts in rule conditions.
//!
//! Built-in functions: `sum`, `count`, `average`, `min`, `max`.
//!
//! # Examples
//!
//! ```grl
//! rule "TotalSales" {
//!     when
//!         $total: accumulate(
//!             Order($amount: amount, status == "completed"),
//!             sum($amount)
//!         )
//!         $total > 10000
//!     then
//!         Report.highValue = true;
//! }
//! ```

use std::collections::HashMap;

use super::facts::FactValue;

24/// Accumulate function trait - defines how to aggregate values
25pub trait AccumulateFunction: Send + Sync {
26    /// Initialize the accumulator
27    fn init(&self) -> Box<dyn AccumulateState>;
28
29    /// Get the function name
30    fn name(&self) -> &str;
31
32    /// Clone the function
33    fn clone_box(&self) -> Box<dyn AccumulateFunction>;
34}
35
36/// State maintained during accumulation
37pub trait AccumulateState: Send {
38    /// Accumulate a new value
39    fn accumulate(&mut self, value: &FactValue);
40
41    /// Get the final result
42    fn get_result(&self) -> FactValue;
43
44    /// Reset the state
45    fn reset(&mut self);
46
47    /// Clone the state
48    fn clone_box(&self) -> Box<dyn AccumulateState>;
49}
50
51// ============================================================================
52// Built-in Accumulate Functions
53// ============================================================================
54
55/// Sum accumulator - adds up numeric values
56#[derive(Debug, Clone)]
57pub struct SumFunction;
58
59impl AccumulateFunction for SumFunction {
60    fn init(&self) -> Box<dyn AccumulateState> {
61        Box::new(SumState { total: 0.0 })
62    }
63
64    fn name(&self) -> &str {
65        "sum"
66    }
67
68    fn clone_box(&self) -> Box<dyn AccumulateFunction> {
69        Box::new(self.clone())
70    }
71}
72
73#[derive(Debug, Clone)]
74struct SumState {
75    total: f64,
76}
77
78impl AccumulateState for SumState {
79    fn accumulate(&mut self, value: &FactValue) {
80        match value {
81            FactValue::Integer(i) => self.total += *i as f64,
82            FactValue::Float(f) => self.total += f,
83            _ => {} // Ignore non-numeric values
84        }
85    }
86
87    fn get_result(&self) -> FactValue {
88        FactValue::Float(self.total)
89    }
90
91    fn reset(&mut self) {
92        self.total = 0.0;
93    }
94
95    fn clone_box(&self) -> Box<dyn AccumulateState> {
96        Box::new(self.clone())
97    }
98}
99
100/// Count accumulator - counts number of matching facts
101#[derive(Debug, Clone)]
102pub struct CountFunction;
103
104impl AccumulateFunction for CountFunction {
105    fn init(&self) -> Box<dyn AccumulateState> {
106        Box::new(CountState { count: 0 })
107    }
108
109    fn name(&self) -> &str {
110        "count"
111    }
112
113    fn clone_box(&self) -> Box<dyn AccumulateFunction> {
114        Box::new(self.clone())
115    }
116}
117
118#[derive(Debug, Clone)]
119struct CountState {
120    count: i64,
121}
122
123impl AccumulateState for CountState {
124    fn accumulate(&mut self, _value: &FactValue) {
125        self.count += 1;
126    }
127
128    fn get_result(&self) -> FactValue {
129        FactValue::Integer(self.count)
130    }
131
132    fn reset(&mut self) {
133        self.count = 0;
134    }
135
136    fn clone_box(&self) -> Box<dyn AccumulateState> {
137        Box::new(self.clone())
138    }
139}
140
141/// Average accumulator - calculates mean of numeric values
142#[derive(Debug, Clone)]
143pub struct AverageFunction;
144
145impl AccumulateFunction for AverageFunction {
146    fn init(&self) -> Box<dyn AccumulateState> {
147        Box::new(AverageState { sum: 0.0, count: 0 })
148    }
149
150    fn name(&self) -> &str {
151        "average"
152    }
153
154    fn clone_box(&self) -> Box<dyn AccumulateFunction> {
155        Box::new(self.clone())
156    }
157}
158
159#[derive(Debug, Clone)]
160struct AverageState {
161    sum: f64,
162    count: usize,
163}
164
165impl AccumulateState for AverageState {
166    fn accumulate(&mut self, value: &FactValue) {
167        match value {
168            FactValue::Integer(i) => {
169                self.sum += *i as f64;
170                self.count += 1;
171            }
172            FactValue::Float(f) => {
173                self.sum += f;
174                self.count += 1;
175            }
176            _ => {} // Ignore non-numeric values
177        }
178    }
179
180    fn get_result(&self) -> FactValue {
181        if self.count == 0 {
182            FactValue::Float(0.0)
183        } else {
184            FactValue::Float(self.sum / self.count as f64)
185        }
186    }
187
188    fn reset(&mut self) {
189        self.sum = 0.0;
190        self.count = 0;
191    }
192
193    fn clone_box(&self) -> Box<dyn AccumulateState> {
194        Box::new(self.clone())
195    }
196}
197
198/// Minimum accumulator - finds minimum numeric value
199#[derive(Debug, Clone)]
200pub struct MinFunction;
201
202impl AccumulateFunction for MinFunction {
203    fn init(&self) -> Box<dyn AccumulateState> {
204        Box::new(MinState { min: None })
205    }
206
207    fn name(&self) -> &str {
208        "min"
209    }
210
211    fn clone_box(&self) -> Box<dyn AccumulateFunction> {
212        Box::new(self.clone())
213    }
214}
215
216#[derive(Debug, Clone)]
217struct MinState {
218    min: Option<f64>,
219}
220
221impl AccumulateState for MinState {
222    fn accumulate(&mut self, value: &FactValue) {
223        let num = match value {
224            FactValue::Integer(i) => Some(*i as f64),
225            FactValue::Float(f) => Some(*f),
226            _ => None,
227        };
228
229        if let Some(n) = num {
230            self.min = Some(match self.min {
231                Some(current) => current.min(n),
232                None => n,
233            });
234        }
235    }
236
237    fn get_result(&self) -> FactValue {
238        match self.min {
239            Some(m) => FactValue::Float(m),
240            None => FactValue::Float(0.0),
241        }
242    }
243
244    fn reset(&mut self) {
245        self.min = None;
246    }
247
248    fn clone_box(&self) -> Box<dyn AccumulateState> {
249        Box::new(self.clone())
250    }
251}
252
253/// Maximum accumulator - finds maximum numeric value
254#[derive(Debug, Clone)]
255pub struct MaxFunction;
256
257impl AccumulateFunction for MaxFunction {
258    fn init(&self) -> Box<dyn AccumulateState> {
259        Box::new(MaxState { max: None })
260    }
261
262    fn name(&self) -> &str {
263        "max"
264    }
265
266    fn clone_box(&self) -> Box<dyn AccumulateFunction> {
267        Box::new(self.clone())
268    }
269}
270
271#[derive(Debug, Clone)]
272struct MaxState {
273    max: Option<f64>,
274}
275
276impl AccumulateState for MaxState {
277    fn accumulate(&mut self, value: &FactValue) {
278        let num = match value {
279            FactValue::Integer(i) => Some(*i as f64),
280            FactValue::Float(f) => Some(*f),
281            _ => None,
282        };
283
284        if let Some(n) = num {
285            self.max = Some(match self.max {
286                Some(current) => current.max(n),
287                None => n,
288            });
289        }
290    }
291
292    fn get_result(&self) -> FactValue {
293        match self.max {
294            Some(m) => FactValue::Float(m),
295            None => FactValue::Float(0.0),
296        }
297    }
298
299    fn reset(&mut self) {
300        self.max = None;
301    }
302
303    fn clone_box(&self) -> Box<dyn AccumulateState> {
304        Box::new(self.clone())
305    }
306}
307
308// ============================================================================
309// Accumulate Pattern - for use in RETE conditions
310// ============================================================================
311
312/// Accumulate pattern in a rule condition
313pub struct AccumulatePattern {
314    /// Variable to bind the result to (e.g., "$total")
315    pub result_var: String,
316
317    /// Source pattern to match facts (e.g., "Order")
318    pub source_pattern: String,
319
320    /// Field to extract from matching facts (e.g., "amount")
321    pub extract_field: String,
322
323    /// Conditions on the source pattern (e.g., "status == 'completed'")
324    pub source_conditions: Vec<String>,
325
326    /// Accumulate function to apply (sum, avg, count, etc.)
327    pub function: Box<dyn AccumulateFunction>,
328}
329
330impl Clone for AccumulatePattern {
331    fn clone(&self) -> Self {
332        Self {
333            result_var: self.result_var.clone(),
334            source_pattern: self.source_pattern.clone(),
335            extract_field: self.extract_field.clone(),
336            source_conditions: self.source_conditions.clone(),
337            function: self.function.clone_box(),
338        }
339    }
340}
341
342impl std::fmt::Debug for AccumulatePattern {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        f.debug_struct("AccumulatePattern")
345            .field("result_var", &self.result_var)
346            .field("source_pattern", &self.source_pattern)
347            .field("extract_field", &self.extract_field)
348            .field("source_conditions", &self.source_conditions)
349            .field("function", &self.function.name())
350            .finish()
351    }
352}
353
354impl AccumulatePattern {
355    /// Create a new accumulate pattern
356    pub fn new(
357        result_var: String,
358        source_pattern: String,
359        extract_field: String,
360        function: Box<dyn AccumulateFunction>,
361    ) -> Self {
362        Self {
363            result_var,
364            source_pattern,
365            extract_field,
366            source_conditions: Vec::new(),
367            function,
368        }
369    }
370
371    /// Add a condition to the source pattern
372    pub fn with_condition(mut self, condition: String) -> Self {
373        self.source_conditions.push(condition);
374        self
375    }
376}
377
378// ============================================================================
379// Accumulate Function Registry
380// ============================================================================
381
382/// Registry of available accumulate functions
383pub struct AccumulateFunctionRegistry {
384    functions: HashMap<String, Box<dyn AccumulateFunction>>,
385}
386
387impl AccumulateFunctionRegistry {
388    /// Create a new registry with built-in functions
389    pub fn new() -> Self {
390        let mut registry = Self {
391            functions: HashMap::new(),
392        };
393
394        // Register built-in functions
395        registry.register(Box::new(SumFunction));
396        registry.register(Box::new(CountFunction));
397        registry.register(Box::new(AverageFunction));
398        registry.register(Box::new(MinFunction));
399        registry.register(Box::new(MaxFunction));
400
401        registry
402    }
403
404    /// Register a custom accumulate function
405    pub fn register(&mut self, function: Box<dyn AccumulateFunction>) {
406        self.functions.insert(function.name().to_string(), function);
407    }
408
409    /// Get a function by name
410    pub fn get(&self, name: &str) -> Option<Box<dyn AccumulateFunction>> {
411        self.functions.get(name).map(|f| f.clone_box())
412    }
413
414    /// Get all available function names
415    pub fn available_functions(&self) -> Vec<String> {
416        self.functions.keys().cloned().collect()
417    }
418}
419
420impl Default for AccumulateFunctionRegistry {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the payload of a `FactValue::Float`, failing the test otherwise.
    fn expect_float(value: FactValue) -> f64 {
        match value {
            FactValue::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    /// Extracts the payload of a `FactValue::Integer`, failing the test otherwise.
    fn expect_integer(value: FactValue) -> i64 {
        match value {
            FactValue::Integer(i) => i,
            _ => panic!("Expected Integer"),
        }
    }

    /// Custom function used to exercise registration of non-built-in functions.
    #[derive(Debug, Clone)]
    struct TallyFunction;

    impl AccumulateFunction for TallyFunction {
        fn init(&self) -> Box<dyn AccumulateState> {
            CountFunction.init()
        }

        fn name(&self) -> &str {
            "tally"
        }

        fn clone_box(&self) -> Box<dyn AccumulateFunction> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_sum_function() {
        let sum = SumFunction;
        let mut state = sum.init();

        state.accumulate(&FactValue::Integer(10));
        state.accumulate(&FactValue::Integer(20));
        state.accumulate(&FactValue::Float(15.5));

        assert_eq!(expect_float(state.get_result()), 45.5);
    }

    #[test]
    fn test_sum_ignores_non_numeric_values() {
        let mut state = SumFunction.init();
        state.accumulate(&FactValue::Integer(1));
        state.accumulate(&FactValue::String("ignored".to_string()));
        state.accumulate(&FactValue::Boolean(true));
        state.accumulate(&FactValue::Float(2.5));

        assert_eq!(expect_float(state.get_result()), 3.5);
    }

    #[test]
    fn test_count_function() {
        let count = CountFunction;
        let mut state = count.init();

        state.accumulate(&FactValue::Integer(10));
        state.accumulate(&FactValue::String("test".to_string()));
        state.accumulate(&FactValue::Boolean(true));

        assert_eq!(expect_integer(state.get_result()), 3);
    }

    #[test]
    fn test_average_function() {
        let avg = AverageFunction;
        let mut state = avg.init();

        state.accumulate(&FactValue::Integer(10));
        state.accumulate(&FactValue::Integer(20));
        state.accumulate(&FactValue::Integer(30));

        assert_eq!(expect_float(state.get_result()), 20.0);
    }

    #[test]
    fn test_average_skips_non_numeric_values() {
        let mut state = AverageFunction.init();
        state.accumulate(&FactValue::Integer(10));
        state.accumulate(&FactValue::String("x".to_string()));
        state.accumulate(&FactValue::Integer(20));

        // The string must not count towards the divisor.
        assert_eq!(expect_float(state.get_result()), 15.0);
    }

    #[test]
    fn test_empty_accumulators() {
        assert_eq!(expect_float(SumFunction.init().get_result()), 0.0);
        assert_eq!(expect_integer(CountFunction.init().get_result()), 0);
        assert_eq!(expect_float(AverageFunction.init().get_result()), 0.0);
        assert_eq!(expect_float(MinFunction.init().get_result()), 0.0);
        assert_eq!(expect_float(MaxFunction.init().get_result()), 0.0);
    }

    #[test]
    fn test_min_max_functions() {
        let min = MinFunction;
        let max = MaxFunction;

        let mut min_state = min.init();
        let mut max_state = max.init();

        for value in &[FactValue::Integer(15), FactValue::Integer(5), FactValue::Integer(25)] {
            min_state.accumulate(value);
            max_state.accumulate(value);
        }

        assert_eq!(expect_float(min_state.get_result()), 5.0);
        assert_eq!(expect_float(max_state.get_result()), 25.0);
    }

    #[test]
    fn test_min_max_mixed_numeric_types() {
        let mut min_state = MinFunction.init();
        let mut max_state = MaxFunction.init();

        for value in &[
            FactValue::Float(-1.5),
            FactValue::Integer(3),
            FactValue::String("x".to_string()),
        ] {
            min_state.accumulate(value);
            max_state.accumulate(value);
        }

        assert_eq!(expect_float(min_state.get_result()), -1.5);
        assert_eq!(expect_float(max_state.get_result()), 3.0);
    }

    #[test]
    fn test_reset_clears_state() {
        let mut min_state = MinFunction.init();
        min_state.accumulate(&FactValue::Integer(7));
        min_state.reset();
        assert_eq!(expect_float(min_state.get_result()), 0.0);
        // A stale minimum of 7 would win over 9 if reset had not cleared it.
        min_state.accumulate(&FactValue::Integer(9));
        assert_eq!(expect_float(min_state.get_result()), 9.0);

        let mut max_state = MaxFunction.init();
        max_state.accumulate(&FactValue::Integer(9));
        max_state.reset();
        max_state.accumulate(&FactValue::Integer(7));
        assert_eq!(expect_float(max_state.get_result()), 7.0);

        let mut count_state = CountFunction.init();
        count_state.accumulate(&FactValue::Integer(1));
        count_state.reset();
        assert_eq!(expect_integer(count_state.get_result()), 0);

        let mut avg_state = AverageFunction.init();
        avg_state.accumulate(&FactValue::Integer(4));
        avg_state.reset();
        avg_state.accumulate(&FactValue::Integer(2));
        assert_eq!(expect_float(avg_state.get_result()), 2.0);

        let mut sum_state = SumFunction.init();
        sum_state.accumulate(&FactValue::Integer(4));
        sum_state.reset();
        assert_eq!(expect_float(sum_state.get_result()), 0.0);
    }

    #[test]
    fn test_clone_box_snapshots_state() {
        let mut state = SumFunction.init();
        state.accumulate(&FactValue::Integer(5));
        let snapshot = state.clone_box();
        state.accumulate(&FactValue::Integer(5));

        assert_eq!(expect_float(snapshot.get_result()), 5.0);
        assert_eq!(expect_float(state.get_result()), 10.0);
    }

    #[test]
    fn test_accumulate_pattern_builder_clone_and_debug() {
        let pattern = AccumulatePattern::new(
            "$total".to_string(),
            "Order".to_string(),
            "amount".to_string(),
            Box::new(SumFunction),
        )
        .with_condition("status == \"completed\"".to_string());

        assert_eq!(
            pattern.source_conditions,
            vec!["status == \"completed\"".to_string()]
        );

        let copy = pattern.clone();
        assert_eq!(copy.result_var, "$total");
        assert_eq!(copy.source_pattern, "Order");
        assert_eq!(copy.extract_field, "amount");
        assert_eq!(copy.source_conditions, pattern.source_conditions);
        assert_eq!(copy.function.name(), "sum");

        let rendered = format!("{:?}", pattern);
        assert!(rendered.contains("\"sum\""));
    }

    #[test]
    fn test_registry() {
        let registry = AccumulateFunctionRegistry::new();

        assert!(registry.get("sum").is_some());
        assert!(registry.get("count").is_some());
        assert!(registry.get("average").is_some());
        assert!(registry.get("min").is_some());
        assert!(registry.get("max").is_some());
        assert!(registry.get("unknown").is_none());

        let functions = registry.available_functions();
        assert_eq!(functions.len(), 5);
    }

    #[test]
    fn test_registry_get_returns_named_function() {
        let registry = AccumulateFunctionRegistry::default();
        for &name in &["sum", "count", "average", "min", "max"] {
            let function = registry.get(name).expect("built-in should be registered");
            assert_eq!(function.name(), name);
        }
    }

    #[test]
    fn test_registry_register() {
        let mut registry = AccumulateFunctionRegistry::new();

        // Re-registering under an existing name replaces rather than duplicates.
        registry.register(Box::new(SumFunction));
        assert_eq!(registry.available_functions().len(), 5);

        registry.register(Box::new(TallyFunction));
        assert_eq!(registry.available_functions().len(), 6);

        let tally = registry.get("tally").expect("tally was just registered");
        assert_eq!(tally.name(), "tally");
        let mut state = tally.init();
        state.accumulate(&FactValue::Boolean(false));
        assert_eq!(expect_integer(state.get_result()), 1);
    }
}