rust_rule_engine/streaming/aggregator.rs

//! Stream Aggregation Functions
//!
//! Provides various aggregation operations for streaming data analysis.

use crate::streaming::event::StreamEvent;
use crate::streaming::window::TimeWindow;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Type of aggregation to perform
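///
/// # Example
///
/// A minimal sketch of constructing field-based variants; the `"amount"`
/// field name is illustrative:
///
/// ```ignore
/// let avg = AggregationType::Average { field: "amount".to_string() };
/// let p95 = AggregationType::Percentile { field: "amount".to_string(), percentile: 95.0 };
/// ```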
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AggregationType {
    /// Count number of events
    Count,
    /// Sum numeric values
    Sum { field: String },
    /// Calculate average
    Average { field: String },
    /// Find minimum value
    Min { field: String },
    /// Find maximum value
    Max { field: String },
    /// Count distinct values
    CountDistinct { field: String },
    /// Calculate standard deviation
    StdDev { field: String },
    /// Calculate percentile
    Percentile { field: String, percentile: f64 },
    /// First event in window
    First,
    /// Last event in window
    Last,
    /// Count by category
    CountBy { field: String },
}

/// Result of an aggregation operation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationResult {
    /// Numeric result
    Number(f64),
    /// String result
    Text(String),
    /// Boolean result
    Boolean(bool),
    /// Map of category counts
    CountMap(HashMap<String, usize>),
    /// No result (empty data)
    None,
}

impl AggregationResult {
    /// Convert to numeric value if possible
    pub fn as_number(&self) -> Option<f64> {
        match self {
            AggregationResult::Number(n) => Some(*n),
            _ => None,
        }
    }

    /// Convert to string if possible
    pub fn as_string(&self) -> Option<&str> {
        match self {
            AggregationResult::Text(s) => Some(s),
            _ => None,
        }
    }

    /// Convert to boolean if possible
    pub fn as_boolean(&self) -> Option<bool> {
        match self {
            AggregationResult::Boolean(b) => Some(*b),
            _ => None,
        }
    }
}

/// Aggregator for performing calculations on event streams
#[derive(Debug)]
#[allow(dead_code)]
pub struct Aggregator {
    /// Type of aggregation
    aggregation_type: AggregationType,
    /// Field to aggregate on (if applicable)
    _field: Option<String>,
}

impl Aggregator {
    /// Create a new aggregator
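    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the tests at the bottom of this file
    /// (`events` is any `&[StreamEvent]` whose data carries a numeric
    /// `"value"` field):
    ///
    /// ```ignore
    /// let aggregator = Aggregator::new(AggregationType::Sum {
    ///     field: "value".to_string(),
    /// });
    /// let result = aggregator.aggregate_events(&events);
    /// assert_eq!(result.as_number(), Some(10.0)); // for values 0.0..=4.0
    /// ```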
    pub fn new(aggregation_type: AggregationType) -> Self {
        let _field = match &aggregation_type {
            AggregationType::Sum { field }
            | AggregationType::Average { field }
            | AggregationType::Min { field }
            | AggregationType::Max { field }
            | AggregationType::CountDistinct { field }
            | AggregationType::StdDev { field }
            | AggregationType::Percentile { field, .. }
            | AggregationType::CountBy { field } => Some(field.clone()),
            _ => None,
        };

        Self {
            aggregation_type,
            _field,
        }
    }

    /// Perform aggregation on a time window
    pub fn aggregate(&self, window: &TimeWindow) -> AggregationResult {
        let events = window.events();

        match &self.aggregation_type {
            AggregationType::Count => AggregationResult::Number(events.len() as f64),

            AggregationType::Sum { field } => {
                let sum = window.sum(field);
                AggregationResult::Number(sum)
            }

            AggregationType::Average { field } => match window.average(field) {
                Some(avg) => AggregationResult::Number(avg),
                None => AggregationResult::None,
            },

            AggregationType::Min { field } => match window.min(field) {
                Some(min) => AggregationResult::Number(min),
                None => AggregationResult::None,
            },

            AggregationType::Max { field } => match window.max(field) {
                Some(max) => AggregationResult::Number(max),
                None => AggregationResult::None,
            },

            AggregationType::CountDistinct { field } => {
                let distinct_count = self.count_distinct_values(events, field);
                AggregationResult::Number(distinct_count as f64)
            }

            AggregationType::StdDev { field } => {
                let std_dev = self.calculate_std_dev(events, field);
                match std_dev {
                    Some(sd) => AggregationResult::Number(sd),
                    None => AggregationResult::None,
                }
            }

            AggregationType::Percentile { field, percentile } => {
                let value = self.calculate_percentile(events, field, *percentile);
                match value {
                    Some(v) => AggregationResult::Number(v),
                    None => AggregationResult::None,
                }
            }

            AggregationType::First => match events.front() {
                Some(event) => AggregationResult::Text(event.id.clone()),
                None => AggregationResult::None,
            },

            AggregationType::Last => match events.back() {
                Some(event) => AggregationResult::Text(event.id.clone()),
                None => AggregationResult::None,
            },

            AggregationType::CountBy { field } => {
                let counts = self.count_by_field(events, field);
                AggregationResult::CountMap(counts)
            }
        }
    }

    /// Perform aggregation on a slice of events
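    ///
    /// Only `Count`, `Sum`, and `Average` are computed directly on a slice;
    /// the remaining variants need window context and currently yield
    /// `AggregationResult::None`.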
    pub fn aggregate_events(&self, events: &[StreamEvent]) -> AggregationResult {
        match &self.aggregation_type {
            AggregationType::Count => AggregationResult::Number(events.len() as f64),

            AggregationType::Sum { field } => {
                let sum: f64 = events.iter().filter_map(|e| e.get_numeric(field)).sum();
                AggregationResult::Number(sum)
            }

            AggregationType::Average { field } => {
                let values: Vec<f64> = events.iter().filter_map(|e| e.get_numeric(field)).collect();

                if values.is_empty() {
                    AggregationResult::None
                } else {
                    let avg = values.iter().sum::<f64>() / values.len() as f64;
                    AggregationResult::Number(avg)
                }
            }

            _ => {
                // Other aggregation types need window context (ordering,
                // distinct tracking, category counts); use `aggregate`
                // with a TimeWindow instead of a raw event slice.
                AggregationResult::None
            }
        }
    }

    /// Count distinct values in a field
    fn count_distinct_values(
        &self,
        events: &std::collections::VecDeque<StreamEvent>,
        field: &str,
    ) -> usize {
        let mut seen = std::collections::HashSet::new();

        for event in events {
            if let Some(value) = event.data.get(field) {
                seen.insert(format!("{:?}", value));
            }
        }

        seen.len()
    }

    /// Calculate standard deviation
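    ///
    /// Uses the population formula `sqrt(Σ(x - μ)² / n)` and returns `None`
    /// when fewer than two numeric values are present.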
    fn calculate_std_dev(
        &self,
        events: &std::collections::VecDeque<StreamEvent>,
        field: &str,
    ) -> Option<f64> {
        let values: Vec<f64> = events.iter().filter_map(|e| e.get_numeric(field)).collect();

        if values.len() < 2 {
            return None;
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;

        Some(variance.sqrt())
    }

    /// Calculate percentile
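    ///
    /// Rounds to the nearest index rather than interpolating: the index is
    /// `round(percentile / 100 * (n - 1))`. For five sorted values
    /// `[10.0, 20.0, 30.0, 40.0, 50.0]`, the 50th percentile picks index
    /// `round(0.5 * 4) = 2`, i.e. `30.0`.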
    fn calculate_percentile(
        &self,
        events: &std::collections::VecDeque<StreamEvent>,
        field: &str,
        percentile: f64,
    ) -> Option<f64> {
        let mut values: Vec<f64> = events.iter().filter_map(|e| e.get_numeric(field)).collect();

        if values.is_empty() {
            return None;
        }

        // total_cmp imposes a total order on f64, so a NaN value cannot panic the sort
        values.sort_by(|a, b| a.total_cmp(b));

        let index = (percentile / 100.0 * (values.len() - 1) as f64).round() as usize;
        values.get(index).copied()
    }

    /// Count occurrences by field value
    fn count_by_field(
        &self,
        events: &std::collections::VecDeque<StreamEvent>,
        field: &str,
    ) -> HashMap<String, usize> {
        let mut counts = HashMap::new();

        for event in events {
            if let Some(value) = event.data.get(field) {
                let key = match value {
                    crate::types::Value::String(s) => s.clone(),
                    crate::types::Value::Number(n) => n.to_string(),
                    crate::types::Value::Integer(i) => i.to_string(),
                    crate::types::Value::Boolean(b) => b.to_string(),
                    _ => format!("{:?}", value),
                };

                *counts.entry(key).or_insert(0) += 1;
            }
        }

        counts
    }
}

/// Stream analytics helper for complex aggregations
#[derive(Debug)]
pub struct StreamAnalytics {
    /// Cache of recent calculations
    cache: HashMap<String, (u64, AggregationResult)>,
    /// Cache TTL in milliseconds
    cache_ttl: u64,
}

impl StreamAnalytics {
    /// Create new stream analytics instance
    pub fn new(cache_ttl_ms: u64) -> Self {
        Self {
            cache: HashMap::new(),
            cache_ttl: cache_ttl_ms,
        }
    }

    /// Perform cached aggregation
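    ///
    /// # Example
    ///
    /// A minimal sketch (`window` is any populated `TimeWindow`; the
    /// timestamps are illustrative epoch milliseconds):
    ///
    /// ```ignore
    /// let mut analytics = StreamAnalytics::new(5_000); // 5 s cache TTL
    /// let aggregator = Aggregator::new(AggregationType::Count);
    /// let first = analytics.aggregate_cached("count", &window, &aggregator, 1_000);
    /// // A second call within the TTL returns the cached result unchanged.
    /// let cached = analytics.aggregate_cached("count", &window, &aggregator, 2_000);
    /// assert_eq!(first.as_number(), cached.as_number());
    /// ```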
    pub fn aggregate_cached(
        &mut self,
        key: &str,
        window: &TimeWindow,
        aggregator: &Aggregator,
        current_time: u64,
    ) -> AggregationResult {
        // Check cache (saturating_sub guards against u64 underflow if timestamps skew)
        if let Some((timestamp, result)) = self.cache.get(key) {
            if current_time.saturating_sub(*timestamp) < self.cache_ttl {
                return result.clone();
            }
        }

        // Calculate new result
        let result = aggregator.aggregate(window);
        self.cache
            .insert(key.to_string(), (current_time, result.clone()));

        // Clean old cache entries
        self.cache
            .retain(|_, (timestamp, _)| current_time.saturating_sub(*timestamp) < self.cache_ttl);

        result
    }

    /// Calculate moving average over multiple windows
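    ///
    /// Pools sums and event counts across the most recent `window_count`
    /// windows, so a window with more events weighs proportionally more
    /// than it would in a plain average of per-window averages.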
    pub fn moving_average(
        &self,
        windows: &[TimeWindow],
        field: &str,
        window_count: usize,
    ) -> Option<f64> {
        if windows.is_empty() {
            return None;
        }

        let recent_windows = if windows.len() > window_count {
            &windows[windows.len() - window_count..]
        } else {
            windows
        };

        let total_sum: f64 = recent_windows.iter().map(|w| w.sum(field)).sum();

        let total_count: usize = recent_windows.iter().map(|w| w.count()).sum();

        if total_count == 0 {
            None
        } else {
            Some(total_sum / total_count as f64)
        }
    }

    /// Detect anomalies using z-score
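    ///
    /// Each value in the newest window is scored as `z = (x - μ) / σ`, where
    /// `μ` and `σ` are computed over all earlier windows; IDs of events with
    /// `|z| > threshold` are returned.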
    pub fn detect_anomalies(
        &self,
        windows: &[TimeWindow],
        field: &str,
        threshold: f64,
    ) -> Vec<String> {
        if windows.len() < 3 {
            return Vec::new();
        }

        // Calculate baseline statistics from historical windows
        let historical_windows = &windows[..windows.len() - 1];
        let values: Vec<f64> = historical_windows
            .iter()
            .flat_map(|w| w.events().iter().filter_map(|e| e.get_numeric(field)))
            .collect();

        if values.len() < 10 {
            return Vec::new();
        }

        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
        let std_dev = variance.sqrt();

        // A constant baseline has no spread, so z-scores would divide by zero
        if std_dev == 0.0 {
            return Vec::new();
        }

        // Check current window for anomalies
        let current_window = windows.last().unwrap();
        let mut anomalies = Vec::new();

        for event in current_window.events() {
            if let Some(value) = event.get_numeric(field) {
                let z_score = (value - mean) / std_dev;
                if z_score.abs() > threshold {
                    anomalies.push(event.id.clone());
                }
            }
        }

        anomalies
    }

    /// Calculate trend direction
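    ///
    /// Splits the per-window averages into an earlier and a later half and
    /// compares their means; a relative change beyond ±5% is reported as a
    /// trend, anything smaller as `Stable`.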
    pub fn calculate_trend(&self, windows: &[TimeWindow], field: &str) -> TrendDirection {
        if windows.len() < 2 {
            return TrendDirection::Stable;
        }

        let averages: Vec<f64> = windows.iter().filter_map(|w| w.average(field)).collect();

        if averages.len() < 2 {
            return TrendDirection::Stable;
        }

        let first_half = &averages[..averages.len() / 2];
        let second_half = &averages[averages.len() / 2..];

        let first_avg = first_half.iter().sum::<f64>() / first_half.len() as f64;
        let second_avg = second_half.iter().sum::<f64>() / second_half.len() as f64;

        // A zero baseline would make the relative change undefined
        if first_avg == 0.0 {
            return TrendDirection::Stable;
        }

        let change_percent = ((second_avg - first_avg) / first_avg) * 100.0;

        if change_percent > 5.0 {
            TrendDirection::Increasing
        } else if change_percent < -5.0 {
            TrendDirection::Decreasing
        } else {
            TrendDirection::Stable
        }
    }
}

/// Direction of trend analysis
#[derive(Debug, Clone, PartialEq)]
pub enum TrendDirection {
    /// Values are increasing
    Increasing,
    /// Values are decreasing
    Decreasing,
    /// Values are stable
    Stable,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event::StreamEvent;
    use crate::types::Value;
    use std::collections::HashMap;

    #[test]
    fn test_count_aggregation() {
        let aggregator = Aggregator::new(AggregationType::Count);
        let events = create_test_events(5);

        let result = aggregator.aggregate_events(&events);
        assert_eq!(result.as_number(), Some(5.0));
    }

    #[test]
    fn test_sum_aggregation() {
        let aggregator = Aggregator::new(AggregationType::Sum {
            field: "value".to_string(),
        });
        let events = create_test_events(5);

        let result = aggregator.aggregate_events(&events);
        assert_eq!(result.as_number(), Some(10.0)); // 0+1+2+3+4
    }

    #[test]
    fn test_average_aggregation() {
        let aggregator = Aggregator::new(AggregationType::Average {
            field: "value".to_string(),
        });
        let events = create_test_events(5);

        let result = aggregator.aggregate_events(&events);
        assert_eq!(result.as_number(), Some(2.0));
    }
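
    // Added sketches: the percentile and count-by helpers are private, but
    // same-module tests can reach them; expected values follow from
    // create_test_events, which stores "value" as 0.0..count.
    #[test]
    fn test_percentile_helper() {
        let aggregator = Aggregator::new(AggregationType::Percentile {
            field: "value".to_string(),
            percentile: 50.0,
        });
        let events: std::collections::VecDeque<StreamEvent> = create_test_events(5).into();

        let median = aggregator.calculate_percentile(&events, "value", 50.0);
        assert_eq!(median, Some(2.0)); // sorted 0..=4, index round(0.5 * 4) = 2
    }

    #[test]
    fn test_count_by_field_helper() {
        let aggregator = Aggregator::new(AggregationType::CountBy {
            field: "value".to_string(),
        });
        let events: std::collections::VecDeque<StreamEvent> = create_test_events(3).into();

        let counts = aggregator.count_by_field(&events, "value");
        assert_eq!(counts.len(), 3); // keys "0", "1", "2"
        assert_eq!(counts.get("1"), Some(&1));
    }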

    fn create_test_events(count: usize) -> Vec<StreamEvent> {
        (0..count)
            .map(|i| {
                let mut data = HashMap::new();
                data.insert("value".to_string(), Value::Number(i as f64));
                StreamEvent::new("TestEvent", data, "test")
            })
            .collect()
    }
}