rust_rule_engine/backward/
aggregation.rs

//! Aggregation support for backward chaining queries
//!
//! Provides the aggregate functions COUNT, SUM, AVG, MIN, MAX, FIRST and LAST
//! for use in backward chaining queries.
//!
//! ## Example
//!
//! ```ignore
//! use rust_rule_engine::backward::*;
//!
//! let mut engine = BackwardEngine::new(kb);
//!
//! // Count all employees
//! let count = engine.query_aggregate(
//!     "count(?x) WHERE employee(?x)",
//!     &mut facts
//! )?;
//!
//! // Sum of all salaries
//! let total = engine.query_aggregate(
//!     "sum(?salary) WHERE salary(?name, ?salary)",
//!     &mut facts
//! )?;
//!
//! // Average salary
//! let avg = engine.query_aggregate(
//!     "avg(?salary) WHERE salary(?name, ?salary) AND ?salary > 50000",
//!     &mut facts
//! )?;
//! ```
31
32use super::search::Solution;
33use crate::errors::{Result, RuleEngineError};
34use crate::types::Value;
35
/// Aggregate functions that can be evaluated over the solutions of a query.
///
/// The field-based variants carry the name of the variable to aggregate,
/// stored without its leading `?`.
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateFunction {
    /// Number of solutions.
    Count,

    /// Sum of the named variable across all solutions.
    Sum(String),

    /// Arithmetic mean of the named variable across all solutions.
    Avg(String),

    /// Smallest value of the named variable.
    Min(String),

    /// Largest value of the named variable.
    Max(String),

    /// A binding taken from the first solution.
    First,

    /// A binding taken from the last solution.
    Last,
}

impl AggregateFunction {
    /// Name of the aggregated variable for field-based aggregates.
    ///
    /// Returns `Some(name)` for `Sum`, `Avg`, `Min` and `Max`, and `None` for
    /// `Count`, `First` and `Last`, which carry no variable.
    pub fn field_name(&self) -> Option<&str> {
        // Every variant is listed explicitly so that adding a new variant
        // forces a decision here instead of silently falling into `None`.
        match self {
            Self::Sum(field) | Self::Avg(field) | Self::Min(field) | Self::Max(field) => {
                Some(field.as_str())
            }
            Self::Count | Self::First | Self::Last => None,
        }
    }
}
73
74/// Parsed aggregate query
75#[derive(Debug, Clone)]
76pub struct AggregateQuery {
77    /// The aggregate function to apply
78    pub function: AggregateFunction,
79
80    /// The goal pattern to match
81    pub pattern: String,
82
83    /// Optional filter condition
84    pub filter: Option<String>,
85}
86
87impl AggregateQuery {
88    /// Create a new aggregate query
89    pub fn new(function: AggregateFunction, pattern: String) -> Self {
90        Self {
91            function,
92            pattern,
93            filter: None,
94        }
95    }
96
97    /// Add a filter condition
98    pub fn with_filter(mut self, filter: String) -> Self {
99        self.filter = Some(filter);
100        self
101    }
102}
103
104/// Parse an aggregate query string
105///
106/// Supported formats:
107/// - `count(?x) WHERE pattern`
108/// - `sum(?field) WHERE pattern`
109/// - `avg(?field) WHERE pattern AND ?field > 100`
110/// - `min(?field) WHERE pattern`
111/// - `max(?field) WHERE pattern`
112pub fn parse_aggregate_query(query: &str) -> Result<AggregateQuery> {
113    let query = query.trim();
114
115    // Split on WHERE keyword
116    let parts: Vec<&str> = query.splitn(2, " WHERE ").collect();
117    if parts.len() != 2 {
118        return Err(RuleEngineError::ParseError {
119            message: format!("Invalid aggregate query format. Expected: 'function(?var) WHERE pattern'. Got: '{}'", query),
120        });
121    }
122
123    let func_part = parts[0].trim();
124    let pattern_part = parts[1].trim();
125
126    // Parse function and variable
127    let (func_name, var_name) = parse_function_call(func_part)?;
128
129    // Create aggregate function
130    let function = match func_name.to_lowercase().as_str() {
131        "count" => AggregateFunction::Count,
132        "sum" => {
133            if var_name.is_empty() {
134                return Err(RuleEngineError::ParseError {
135                    message: "sum() requires a variable, e.g., sum(?amount)".to_string(),
136                });
137            }
138            AggregateFunction::Sum(var_name.to_string())
139        }
140        "avg" => {
141            if var_name.is_empty() {
142                return Err(RuleEngineError::ParseError {
143                    message: "avg() requires a variable, e.g., avg(?salary)".to_string(),
144                });
145            }
146            AggregateFunction::Avg(var_name.to_string())
147        }
148        "min" => {
149            if var_name.is_empty() {
150                return Err(RuleEngineError::ParseError {
151                    message: "min() requires a variable, e.g., min(?price)".to_string(),
152                });
153            }
154            AggregateFunction::Min(var_name.to_string())
155        }
156        "max" => {
157            if var_name.is_empty() {
158                return Err(RuleEngineError::ParseError {
159                    message: "max() requires a variable, e.g., max(?score)".to_string(),
160                });
161            }
162            AggregateFunction::Max(var_name.to_string())
163        }
164        "first" => AggregateFunction::First,
165        "last" => AggregateFunction::Last,
166        _ => {
167            return Err(RuleEngineError::ParseError {
168                message: format!("Unknown aggregate function: '{}'. Supported: count, sum, avg, min, max, first, last", func_name),
169            });
170        }
171    };
172
173    // Split pattern and filter (on AND)
174    let (pattern, filter) = if pattern_part.contains(" AND ") {
175        let parts: Vec<&str> = pattern_part.splitn(2, " AND ").collect();
176        (
177            parts[0].trim().to_string(),
178            Some(parts[1].trim().to_string()),
179        )
180    } else {
181        (pattern_part.to_string(), None)
182    };
183
184    Ok(AggregateQuery {
185        function,
186        pattern,
187        filter,
188    })
189}
190
191/// Parse a function call like "count(?x)" or "sum(?amount)"
192fn parse_function_call(s: &str) -> Result<(String, String)> {
193    let s = s.trim();
194
195    // Find opening parenthesis
196    let open_idx = s.find('(').ok_or_else(|| RuleEngineError::ParseError {
197        message: format!("Expected '(' in function call: '{}'", s),
198    })?;
199
200    // Find closing parenthesis
201    let close_idx = s.rfind(')').ok_or_else(|| RuleEngineError::ParseError {
202        message: format!("Expected ')' in function call: '{}'", s),
203    })?;
204
205    if close_idx <= open_idx {
206        return Err(RuleEngineError::ParseError {
207            message: format!("Invalid function call syntax: '{}'", s),
208        });
209    }
210
211    let func_name = s[..open_idx].trim().to_string();
212    let var_name = s[open_idx + 1..close_idx].trim().to_string();
213
214    // Remove leading ? from variable name if present
215    let var_name = if let Some(stripped) = var_name.strip_prefix('?') {
216        stripped.to_string()
217    } else {
218        var_name
219    };
220
221    Ok((func_name, var_name))
222}
223
224/// Apply aggregate function to solutions
225pub fn apply_aggregate(function: &AggregateFunction, solutions: &[Solution]) -> Result<Value> {
226    if solutions.is_empty() {
227        // Return appropriate zero value
228        return Ok(match function {
229            AggregateFunction::Count => Value::Integer(0),
230            AggregateFunction::Sum(_) => Value::Number(0.0),
231            AggregateFunction::Avg(_) => Value::Number(0.0),
232            AggregateFunction::Min(_) => Value::Null,
233            AggregateFunction::Max(_) => Value::Null,
234            AggregateFunction::First => Value::Null,
235            AggregateFunction::Last => Value::Null,
236        });
237    }
238
239    match function {
240        AggregateFunction::Count => Ok(Value::Integer(solutions.len() as i64)),
241
242        AggregateFunction::Sum(field) => {
243            let sum: f64 = solutions
244                .iter()
245                .filter_map(|s| s.bindings.get(field))
246                .filter_map(|v| value_to_float(v).ok())
247                .sum();
248            Ok(Value::Number(sum))
249        }
250
251        AggregateFunction::Avg(field) => {
252            let values: Vec<f64> = solutions
253                .iter()
254                .filter_map(|s| s.bindings.get(field))
255                .filter_map(|v| value_to_float(v).ok())
256                .collect();
257
258            if values.is_empty() {
259                Ok(Value::Number(0.0))
260            } else {
261                let sum: f64 = values.iter().sum();
262                Ok(Value::Number(sum / values.len() as f64))
263            }
264        }
265
266        AggregateFunction::Min(field) => {
267            let min = solutions
268                .iter()
269                .filter_map(|s| s.bindings.get(field))
270                .filter_map(|v| value_to_float(v).ok())
271                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273            Ok(min.map(Value::Number).unwrap_or(Value::Null))
274        }
275
276        AggregateFunction::Max(field) => {
277            let max = solutions
278                .iter()
279                .filter_map(|s| s.bindings.get(field))
280                .filter_map(|v| value_to_float(v).ok())
281                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
282
283            Ok(max.map(Value::Number).unwrap_or(Value::Null))
284        }
285
286        AggregateFunction::First => {
287            Ok(solutions
288                .first()
289                .and_then(|s| {
290                    // Return the first non-null binding
291                    s.bindings.values().next().cloned()
292                })
293                .unwrap_or(Value::Null))
294        }
295
296        AggregateFunction::Last => {
297            Ok(solutions
298                .last()
299                .and_then(|s| {
300                    // Return the last non-null binding
301                    s.bindings.values().last().cloned()
302                })
303                .unwrap_or(Value::Null))
304        }
305    }
306}
307
308/// Convert a Value to f64 for numeric aggregations
309fn value_to_float(value: &Value) -> Result<f64> {
310    match value {
311        Value::Number(n) => Ok(*n),
312        Value::Integer(i) => Ok(*i as f64),
313        Value::String(s) => s
314            .parse::<f64>()
315            .map_err(|_| RuleEngineError::EvaluationError {
316                message: format!("Cannot convert '{}' to number", s),
317            }),
318        _ => Err(RuleEngineError::EvaluationError {
319            message: format!("Cannot aggregate non-numeric value: {:?}", value),
320        }),
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Wrap a set of bindings in a solution with an empty `path`.
    fn solution(bindings: HashMap<String, Value>) -> Solution {
        Solution {
            path: vec![],
            bindings,
        }
    }

    /// Build one solution per value, each binding `field` to that value.
    fn solutions_for(field: &str, values: Vec<Value>) -> Vec<Solution> {
        values
            .into_iter()
            .map(|value| {
                let mut bindings = HashMap::new();
                bindings.insert(field.to_string(), value);
                solution(bindings)
            })
            .collect()
    }

    /// The three prices shared by the min and max tests.
    fn price_solutions() -> Vec<Solution> {
        solutions_for(
            "price",
            vec![
                Value::Number(99.99),
                Value::Number(49.99),
                Value::Number(149.99),
            ],
        )
    }

    #[test]
    fn test_parse_count_query() {
        let parsed = parse_aggregate_query("count(?x) WHERE employee(?x)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Count);
        assert_eq!(parsed.pattern, "employee(?x)");
        assert_eq!(parsed.filter, None);
    }

    #[test]
    fn test_parse_sum_query() {
        let parsed = parse_aggregate_query("sum(?amount) WHERE purchase(?item, ?amount)").unwrap();

        assert_eq!(
            parsed.function,
            AggregateFunction::Sum("amount".to_string())
        );
        assert_eq!(parsed.pattern, "purchase(?item, ?amount)");
    }

    #[test]
    fn test_parse_avg_with_filter() {
        let parsed = parse_aggregate_query(
            "avg(?salary) WHERE salary(?name, ?salary) AND ?salary > 50000",
        )
        .unwrap();

        assert_eq!(
            parsed.function,
            AggregateFunction::Avg("salary".to_string())
        );
        assert_eq!(parsed.pattern, "salary(?name, ?salary)");
        assert_eq!(parsed.filter, Some("?salary > 50000".to_string()));
    }

    #[test]
    fn test_parse_min_query() {
        let parsed = parse_aggregate_query("min(?price) WHERE product(?name, ?price)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Min("price".to_string()));
    }

    #[test]
    fn test_parse_max_query() {
        let parsed = parse_aggregate_query("max(?score) WHERE student(?name, ?score)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Max("score".to_string()));
    }

    #[test]
    fn test_parse_invalid_query() {
        // The WHERE clause is mandatory.
        assert!(parse_aggregate_query("count(?x)").is_err());
    }

    #[test]
    fn test_parse_unknown_function() {
        assert!(parse_aggregate_query("unknown(?x) WHERE test(?x)").is_err());
    }

    #[test]
    fn test_apply_count() {
        // Counting only looks at the number of solutions, so empty bindings suffice.
        let solutions: Vec<Solution> = (0..3).map(|_| solution(HashMap::new())).collect();

        let counted = apply_aggregate(&AggregateFunction::Count, &solutions).unwrap();
        assert_eq!(counted, Value::Integer(3));
    }

    #[test]
    fn test_apply_sum() {
        let solutions = solutions_for(
            "amount",
            vec![
                Value::Number(100.0),
                Value::Number(200.0),
                Value::Number(300.0),
            ],
        );

        let total =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap();
        assert_eq!(total, Value::Number(600.0));
    }

    #[test]
    fn test_apply_avg() {
        let solutions = solutions_for(
            "score",
            vec![Value::Integer(80), Value::Integer(90), Value::Integer(100)],
        );

        let mean =
            apply_aggregate(&AggregateFunction::Avg("score".to_string()), &solutions).unwrap();
        assert_eq!(mean, Value::Number(90.0));
    }

    #[test]
    fn test_apply_min() {
        let solutions = price_solutions();

        let lowest =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap();
        assert_eq!(lowest, Value::Number(49.99));
    }

    #[test]
    fn test_apply_max() {
        let solutions = price_solutions();

        let highest =
            apply_aggregate(&AggregateFunction::Max("price".to_string()), &solutions).unwrap();
        assert_eq!(highest, Value::Number(149.99));
    }

    #[test]
    fn test_apply_empty_solutions() {
        let solutions: Vec<Solution> = Vec::new();

        let count = apply_aggregate(&AggregateFunction::Count, &solutions).unwrap();
        assert_eq!(count, Value::Integer(0));

        let sum =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap();
        assert_eq!(sum, Value::Number(0.0));

        let min =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap();
        assert_eq!(min, Value::Null);
    }
}
554}