rust_rule_engine/backward/
aggregation.rs

//! Aggregation support for backward-chaining queries.
//!
//! Provides the aggregate functions COUNT, SUM, AVG, MIN, MAX, FIRST and LAST
//! for use in backward-chaining queries.
//!
//! ## Example
//!
//! ```ignore
//! use rust_rule_engine::backward::*;
//!
//! let mut engine = BackwardEngine::new(kb);
//!
//! // Count all employees
//! let count = engine.query_aggregate(
//!     "count(?x) WHERE employee(?x)",
//!     &mut facts
//! )?;
//!
//! // Sum of all salaries
//! let total = engine.query_aggregate(
//!     "sum(?salary) WHERE salary(?name, ?salary)",
//!     &mut facts
//! )?;
//!
//! // Average salary
//! let avg = engine.query_aggregate(
//!     "avg(?salary) WHERE salary(?name, ?salary) AND ?salary > 50000",
//!     &mut facts
//! )?;
//! ```
31
32use crate::types::Value;
33use crate::errors::{Result, RuleEngineError};
34use super::search::Solution;
35use std::collections::HashMap;
36
/// The aggregate operations available to backward-chaining queries.
///
/// Field-based variants carry the name of the query variable to aggregate
/// over, stored without its leading `?` (so `sum(?amount)` parses to
/// `Sum("amount")`).
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateFunction {
    /// Number of solutions found.
    Count,

    /// Sum of the numeric values bound to the named variable.
    Sum(String),

    /// Arithmetic mean of the numeric values bound to the named variable.
    Avg(String),

    /// Smallest numeric value bound to the named variable.
    Min(String),

    /// Largest numeric value bound to the named variable.
    Max(String),

    /// A value taken from the first solution.
    First,

    /// A value taken from the last solution.
    Last,
}

impl AggregateFunction {
    /// Return the variable name a field-based aggregate operates on.
    ///
    /// Yields `Some(name)` for `Sum`, `Avg`, `Min` and `Max`, and `None` for
    /// `Count`, `First` and `Last`, which carry no variable.
    pub fn field_name(&self) -> Option<&str> {
        let field = match self {
            Self::Sum(name) | Self::Avg(name) | Self::Min(name) | Self::Max(name) => name,
            Self::Count | Self::First | Self::Last => return None,
        };
        Some(field.as_str())
    }
}
74
75/// Parsed aggregate query
76#[derive(Debug, Clone)]
77pub struct AggregateQuery {
78    /// The aggregate function to apply
79    pub function: AggregateFunction,
80
81    /// The goal pattern to match
82    pub pattern: String,
83
84    /// Optional filter condition
85    pub filter: Option<String>,
86}
87
88impl AggregateQuery {
89    /// Create a new aggregate query
90    pub fn new(function: AggregateFunction, pattern: String) -> Self {
91        Self {
92            function,
93            pattern,
94            filter: None,
95        }
96    }
97
98    /// Add a filter condition
99    pub fn with_filter(mut self, filter: String) -> Self {
100        self.filter = Some(filter);
101        self
102    }
103}
104
105/// Parse an aggregate query string
106///
107/// Supported formats:
108/// - `count(?x) WHERE pattern`
109/// - `sum(?field) WHERE pattern`
110/// - `avg(?field) WHERE pattern AND ?field > 100`
111/// - `min(?field) WHERE pattern`
112/// - `max(?field) WHERE pattern`
113pub fn parse_aggregate_query(query: &str) -> Result<AggregateQuery> {
114    let query = query.trim();
115
116    // Split on WHERE keyword
117    let parts: Vec<&str> = query.splitn(2, " WHERE ").collect();
118    if parts.len() != 2 {
119        return Err(RuleEngineError::ParseError {
120            message: format!("Invalid aggregate query format. Expected: 'function(?var) WHERE pattern'. Got: '{}'", query),
121        });
122    }
123
124    let func_part = parts[0].trim();
125    let pattern_part = parts[1].trim();
126
127    // Parse function and variable
128    let (func_name, var_name) = parse_function_call(func_part)?;
129
130    // Create aggregate function
131    let function = match func_name.to_lowercase().as_str() {
132        "count" => AggregateFunction::Count,
133        "sum" => {
134            if var_name.is_empty() {
135                return Err(RuleEngineError::ParseError {
136                    message: "sum() requires a variable, e.g., sum(?amount)".to_string(),
137                });
138            }
139            AggregateFunction::Sum(var_name.to_string())
140        }
141        "avg" => {
142            if var_name.is_empty() {
143                return Err(RuleEngineError::ParseError {
144                    message: "avg() requires a variable, e.g., avg(?salary)".to_string(),
145                });
146            }
147            AggregateFunction::Avg(var_name.to_string())
148        }
149        "min" => {
150            if var_name.is_empty() {
151                return Err(RuleEngineError::ParseError {
152                    message: "min() requires a variable, e.g., min(?price)".to_string(),
153                });
154            }
155            AggregateFunction::Min(var_name.to_string())
156        }
157        "max" => {
158            if var_name.is_empty() {
159                return Err(RuleEngineError::ParseError {
160                    message: "max() requires a variable, e.g., max(?score)".to_string(),
161                });
162            }
163            AggregateFunction::Max(var_name.to_string())
164        }
165        "first" => AggregateFunction::First,
166        "last" => AggregateFunction::Last,
167        _ => {
168            return Err(RuleEngineError::ParseError {
169                message: format!("Unknown aggregate function: '{}'. Supported: count, sum, avg, min, max, first, last", func_name),
170            });
171        }
172    };
173
174    // Split pattern and filter (on AND)
175    let (pattern, filter) = if pattern_part.contains(" AND ") {
176        let parts: Vec<&str> = pattern_part.splitn(2, " AND ").collect();
177        (parts[0].trim().to_string(), Some(parts[1].trim().to_string()))
178    } else {
179        (pattern_part.to_string(), None)
180    };
181
182    Ok(AggregateQuery {
183        function,
184        pattern,
185        filter,
186    })
187}
188
189/// Parse a function call like "count(?x)" or "sum(?amount)"
190fn parse_function_call(s: &str) -> Result<(String, String)> {
191    let s = s.trim();
192
193    // Find opening parenthesis
194    let open_idx = s.find('(').ok_or_else(|| RuleEngineError::ParseError {
195        message: format!("Expected '(' in function call: '{}'", s),
196    })?;
197
198    // Find closing parenthesis
199    let close_idx = s.rfind(')').ok_or_else(|| RuleEngineError::ParseError {
200        message: format!("Expected ')' in function call: '{}'", s),
201    })?;
202
203    if close_idx <= open_idx {
204        return Err(RuleEngineError::ParseError {
205            message: format!("Invalid function call syntax: '{}'", s),
206        });
207    }
208
209    let func_name = s[..open_idx].trim().to_string();
210    let var_name = s[open_idx + 1..close_idx].trim().to_string();
211
212    // Remove leading ? from variable name if present
213    let var_name = if var_name.starts_with('?') {
214        var_name[1..].to_string()
215    } else {
216        var_name
217    };
218
219    Ok((func_name, var_name))
220}
221
222/// Apply aggregate function to solutions
223pub fn apply_aggregate(
224    function: &AggregateFunction,
225    solutions: &[Solution],
226) -> Result<Value> {
227    if solutions.is_empty() {
228        // Return appropriate zero value
229        return Ok(match function {
230            AggregateFunction::Count => Value::Integer(0),
231            AggregateFunction::Sum(_) => Value::Number(0.0),
232            AggregateFunction::Avg(_) => Value::Number(0.0),
233            AggregateFunction::Min(_) => Value::Null,
234            AggregateFunction::Max(_) => Value::Null,
235            AggregateFunction::First => Value::Null,
236            AggregateFunction::Last => Value::Null,
237        });
238    }
239
240    match function {
241        AggregateFunction::Count => {
242            Ok(Value::Integer(solutions.len() as i64))
243        }
244
245        AggregateFunction::Sum(field) => {
246            let sum: f64 = solutions.iter()
247                .filter_map(|s| s.bindings.get(field))
248                .filter_map(|v| value_to_float(v).ok())
249                .sum();
250            Ok(Value::Number(sum))
251        }
252
253        AggregateFunction::Avg(field) => {
254            let values: Vec<f64> = solutions.iter()
255                .filter_map(|s| s.bindings.get(field))
256                .filter_map(|v| value_to_float(v).ok())
257                .collect();
258
259            if values.is_empty() {
260                Ok(Value::Number(0.0))
261            } else {
262                let sum: f64 = values.iter().sum();
263                Ok(Value::Number(sum / values.len() as f64))
264            }
265        }
266
267        AggregateFunction::Min(field) => {
268            let min = solutions.iter()
269                .filter_map(|s| s.bindings.get(field))
270                .filter_map(|v| value_to_float(v).ok())
271                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273            Ok(min.map(Value::Number).unwrap_or(Value::Null))
274        }
275
276        AggregateFunction::Max(field) => {
277            let max = solutions.iter()
278                .filter_map(|s| s.bindings.get(field))
279                .filter_map(|v| value_to_float(v).ok())
280                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
281
282            Ok(max.map(Value::Number).unwrap_or(Value::Null))
283        }
284
285        AggregateFunction::First => {
286            Ok(solutions.first()
287                .and_then(|s| {
288                    // Return the first non-null binding
289                    s.bindings.values().next().cloned()
290                })
291                .unwrap_or(Value::Null))
292        }
293
294        AggregateFunction::Last => {
295            Ok(solutions.last()
296                .and_then(|s| {
297                    // Return the last non-null binding
298                    s.bindings.values().last().cloned()
299                })
300                .unwrap_or(Value::Null))
301        }
302    }
303}
304
305/// Convert a Value to f64 for numeric aggregations
306fn value_to_float(value: &Value) -> Result<f64> {
307    match value {
308        Value::Number(n) => Ok(*n),
309        Value::Integer(i) => Ok(*i as f64),
310        Value::String(s) => s.parse::<f64>().map_err(|_| {
311            RuleEngineError::EvaluationError {
312                message: format!("Cannot convert '{}' to number", s),
313            }
314        }),
315        _ => Err(RuleEngineError::EvaluationError {
316            message: format!("Cannot aggregate non-numeric value: {:?}", value),
317        }),
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one `Solution` per entry in `values`, each binding `field` to that value.
    fn solutions_binding(field: &str, values: Vec<Value>) -> Vec<Solution> {
        values
            .into_iter()
            .map(|value| {
                let mut bindings = HashMap::new();
                bindings.insert(field.to_string(), value);
                Solution { path: vec![], bindings }
            })
            .collect()
    }

    #[test]
    fn test_parse_count_query() {
        let parsed = parse_aggregate_query("count(?x) WHERE employee(?x)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Count);
        assert_eq!(parsed.pattern, "employee(?x)");
        assert_eq!(parsed.filter, None);
    }

    #[test]
    fn test_parse_sum_query() {
        let parsed =
            parse_aggregate_query("sum(?amount) WHERE purchase(?item, ?amount)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Sum("amount".to_string()));
        assert_eq!(parsed.pattern, "purchase(?item, ?amount)");
    }

    #[test]
    fn test_parse_avg_with_filter() {
        let parsed = parse_aggregate_query(
            "avg(?salary) WHERE salary(?name, ?salary) AND ?salary > 50000",
        )
        .unwrap();

        assert_eq!(parsed.function, AggregateFunction::Avg("salary".to_string()));
        assert_eq!(parsed.pattern, "salary(?name, ?salary)");
        assert_eq!(parsed.filter.as_deref(), Some("?salary > 50000"));
    }

    #[test]
    fn test_parse_min_query() {
        let parsed = parse_aggregate_query("min(?price) WHERE product(?name, ?price)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Min("price".to_string()));
    }

    #[test]
    fn test_parse_max_query() {
        let parsed = parse_aggregate_query("max(?score) WHERE student(?name, ?score)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Max("score".to_string()));
    }

    #[test]
    fn test_parse_invalid_query() {
        // No WHERE clause at all.
        assert!(parse_aggregate_query("count(?x)").is_err());
    }

    #[test]
    fn test_parse_unknown_function() {
        assert!(parse_aggregate_query("unknown(?x) WHERE test(?x)").is_err());
    }

    #[test]
    fn test_apply_count() {
        let solutions: Vec<Solution> = (0..3)
            .map(|_| Solution { path: vec![], bindings: HashMap::new() })
            .collect();

        let counted = apply_aggregate(&AggregateFunction::Count, &solutions).unwrap();
        assert_eq!(counted, Value::Integer(3));
    }

    #[test]
    fn test_apply_sum() {
        let solutions = solutions_binding(
            "amount",
            vec![Value::Number(100.0), Value::Number(200.0), Value::Number(300.0)],
        );

        let total =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap();
        assert_eq!(total, Value::Number(600.0));
    }

    #[test]
    fn test_apply_avg() {
        let solutions = solutions_binding(
            "score",
            vec![Value::Integer(80), Value::Integer(90), Value::Integer(100)],
        );

        let mean =
            apply_aggregate(&AggregateFunction::Avg("score".to_string()), &solutions).unwrap();
        assert_eq!(mean, Value::Number(90.0));
    }

    #[test]
    fn test_apply_min() {
        let solutions = solutions_binding(
            "price",
            vec![Value::Number(99.99), Value::Number(49.99), Value::Number(149.99)],
        );

        let lowest =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap();
        assert_eq!(lowest, Value::Number(49.99));
    }

    #[test]
    fn test_apply_max() {
        let solutions = solutions_binding(
            "price",
            vec![Value::Number(99.99), Value::Number(49.99), Value::Number(149.99)],
        );

        let highest =
            apply_aggregate(&AggregateFunction::Max("price".to_string()), &solutions).unwrap();
        assert_eq!(highest, Value::Number(149.99));
    }

    #[test]
    fn test_apply_empty_solutions() {
        let solutions: Vec<Solution> = Vec::new();

        assert_eq!(
            apply_aggregate(&AggregateFunction::Count, &solutions).unwrap(),
            Value::Integer(0)
        );
        assert_eq!(
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap(),
            Value::Number(0.0)
        );
        assert_eq!(
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap(),
            Value::Null
        );
    }
}