rust_rule_engine/backward/
aggregation.rs1use crate::types::Value;
33use crate::errors::{Result, RuleEngineError};
34use super::search::Solution;
35use std::collections::HashMap;
36
/// Aggregate operation that can be evaluated over a set of query solutions.
///
/// The `String` carried by [`Sum`](Self::Sum), [`Avg`](Self::Avg),
/// [`Min`](Self::Min) and [`Max`](Self::Max) is the name of the bound
/// variable whose values are aggregated (stored without the leading `?`
/// when produced by `parse_aggregate_query`).
///
/// `Eq` and `Hash` are derived so the type can be used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// Number of solutions.
    Count,

    /// Sum of the numeric values bound to the named variable.
    Sum(String),

    /// Arithmetic mean of the numeric values bound to the named variable.
    Avg(String),

    /// Smallest numeric value bound to the named variable.
    Min(String),

    /// Largest numeric value bound to the named variable.
    Max(String),

    /// A binding value taken from the first solution.
    First,

    /// A binding value taken from the last solution.
    Last,
}

impl AggregateFunction {
    /// Returns the variable name carried by the variant.
    ///
    /// `Some(name)` for `Sum`, `Avg`, `Min` and `Max`; `None` for the
    /// variants that carry no variable (`Count`, `First`, `Last`).
    pub fn field_name(&self) -> Option<&str> {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Sum(field) | Self::Avg(field) | Self::Min(field) | Self::Max(field) => {
                Some(field.as_str())
            }
            Self::Count | Self::First | Self::Last => None,
        }
    }
}
74
75#[derive(Debug, Clone)]
77pub struct AggregateQuery {
78 pub function: AggregateFunction,
80
81 pub pattern: String,
83
84 pub filter: Option<String>,
86}
87
88impl AggregateQuery {
89 pub fn new(function: AggregateFunction, pattern: String) -> Self {
91 Self {
92 function,
93 pattern,
94 filter: None,
95 }
96 }
97
98 pub fn with_filter(mut self, filter: String) -> Self {
100 self.filter = Some(filter);
101 self
102 }
103}
104
105pub fn parse_aggregate_query(query: &str) -> Result<AggregateQuery> {
114 let query = query.trim();
115
116 let parts: Vec<&str> = query.splitn(2, " WHERE ").collect();
118 if parts.len() != 2 {
119 return Err(RuleEngineError::ParseError {
120 message: format!("Invalid aggregate query format. Expected: 'function(?var) WHERE pattern'. Got: '{}'", query),
121 });
122 }
123
124 let func_part = parts[0].trim();
125 let pattern_part = parts[1].trim();
126
127 let (func_name, var_name) = parse_function_call(func_part)?;
129
130 let function = match func_name.to_lowercase().as_str() {
132 "count" => AggregateFunction::Count,
133 "sum" => {
134 if var_name.is_empty() {
135 return Err(RuleEngineError::ParseError {
136 message: "sum() requires a variable, e.g., sum(?amount)".to_string(),
137 });
138 }
139 AggregateFunction::Sum(var_name.to_string())
140 }
141 "avg" => {
142 if var_name.is_empty() {
143 return Err(RuleEngineError::ParseError {
144 message: "avg() requires a variable, e.g., avg(?salary)".to_string(),
145 });
146 }
147 AggregateFunction::Avg(var_name.to_string())
148 }
149 "min" => {
150 if var_name.is_empty() {
151 return Err(RuleEngineError::ParseError {
152 message: "min() requires a variable, e.g., min(?price)".to_string(),
153 });
154 }
155 AggregateFunction::Min(var_name.to_string())
156 }
157 "max" => {
158 if var_name.is_empty() {
159 return Err(RuleEngineError::ParseError {
160 message: "max() requires a variable, e.g., max(?score)".to_string(),
161 });
162 }
163 AggregateFunction::Max(var_name.to_string())
164 }
165 "first" => AggregateFunction::First,
166 "last" => AggregateFunction::Last,
167 _ => {
168 return Err(RuleEngineError::ParseError {
169 message: format!("Unknown aggregate function: '{}'. Supported: count, sum, avg, min, max, first, last", func_name),
170 });
171 }
172 };
173
174 let (pattern, filter) = if pattern_part.contains(" AND ") {
176 let parts: Vec<&str> = pattern_part.splitn(2, " AND ").collect();
177 (parts[0].trim().to_string(), Some(parts[1].trim().to_string()))
178 } else {
179 (pattern_part.to_string(), None)
180 };
181
182 Ok(AggregateQuery {
183 function,
184 pattern,
185 filter,
186 })
187}
188
189fn parse_function_call(s: &str) -> Result<(String, String)> {
191 let s = s.trim();
192
193 let open_idx = s.find('(').ok_or_else(|| RuleEngineError::ParseError {
195 message: format!("Expected '(' in function call: '{}'", s),
196 })?;
197
198 let close_idx = s.rfind(')').ok_or_else(|| RuleEngineError::ParseError {
200 message: format!("Expected ')' in function call: '{}'", s),
201 })?;
202
203 if close_idx <= open_idx {
204 return Err(RuleEngineError::ParseError {
205 message: format!("Invalid function call syntax: '{}'", s),
206 });
207 }
208
209 let func_name = s[..open_idx].trim().to_string();
210 let var_name = s[open_idx + 1..close_idx].trim().to_string();
211
212 let var_name = if var_name.starts_with('?') {
214 var_name[1..].to_string()
215 } else {
216 var_name
217 };
218
219 Ok((func_name, var_name))
220}
221
222pub fn apply_aggregate(
224 function: &AggregateFunction,
225 solutions: &[Solution],
226) -> Result<Value> {
227 if solutions.is_empty() {
228 return Ok(match function {
230 AggregateFunction::Count => Value::Integer(0),
231 AggregateFunction::Sum(_) => Value::Number(0.0),
232 AggregateFunction::Avg(_) => Value::Number(0.0),
233 AggregateFunction::Min(_) => Value::Null,
234 AggregateFunction::Max(_) => Value::Null,
235 AggregateFunction::First => Value::Null,
236 AggregateFunction::Last => Value::Null,
237 });
238 }
239
240 match function {
241 AggregateFunction::Count => {
242 Ok(Value::Integer(solutions.len() as i64))
243 }
244
245 AggregateFunction::Sum(field) => {
246 let sum: f64 = solutions.iter()
247 .filter_map(|s| s.bindings.get(field))
248 .filter_map(|v| value_to_float(v).ok())
249 .sum();
250 Ok(Value::Number(sum))
251 }
252
253 AggregateFunction::Avg(field) => {
254 let values: Vec<f64> = solutions.iter()
255 .filter_map(|s| s.bindings.get(field))
256 .filter_map(|v| value_to_float(v).ok())
257 .collect();
258
259 if values.is_empty() {
260 Ok(Value::Number(0.0))
261 } else {
262 let sum: f64 = values.iter().sum();
263 Ok(Value::Number(sum / values.len() as f64))
264 }
265 }
266
267 AggregateFunction::Min(field) => {
268 let min = solutions.iter()
269 .filter_map(|s| s.bindings.get(field))
270 .filter_map(|v| value_to_float(v).ok())
271 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273 Ok(min.map(Value::Number).unwrap_or(Value::Null))
274 }
275
276 AggregateFunction::Max(field) => {
277 let max = solutions.iter()
278 .filter_map(|s| s.bindings.get(field))
279 .filter_map(|v| value_to_float(v).ok())
280 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
281
282 Ok(max.map(Value::Number).unwrap_or(Value::Null))
283 }
284
285 AggregateFunction::First => {
286 Ok(solutions.first()
287 .and_then(|s| {
288 s.bindings.values().next().cloned()
290 })
291 .unwrap_or(Value::Null))
292 }
293
294 AggregateFunction::Last => {
295 Ok(solutions.last()
296 .and_then(|s| {
297 s.bindings.values().last().cloned()
299 })
300 .unwrap_or(Value::Null))
301 }
302 }
303}
304
305fn value_to_float(value: &Value) -> Result<f64> {
307 match value {
308 Value::Number(n) => Ok(*n),
309 Value::Integer(i) => Ok(*i as f64),
310 Value::String(s) => s.parse::<f64>().map_err(|_| {
311 RuleEngineError::EvaluationError {
312 message: format!("Cannot convert '{}' to number", s),
313 }
314 }),
315 _ => Err(RuleEngineError::EvaluationError {
316 message: format!("Cannot aggregate non-numeric value: {:?}", value),
317 }),
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query`, panicking if the parser rejects it.
    fn parse_ok(query: &str) -> AggregateQuery {
        parse_aggregate_query(query).unwrap()
    }

    /// Builds one solution per value, each binding `field` to that value.
    fn solutions_binding(field: &str, values: Vec<Value>) -> Vec<Solution> {
        values
            .into_iter()
            .map(|value| {
                let mut bindings = HashMap::new();
                bindings.insert(field.to_string(), value);
                Solution { path: vec![], bindings }
            })
            .collect()
    }

    /// The three prices shared by the min and max tests.
    fn price_solutions() -> Vec<Solution> {
        solutions_binding(
            "price",
            vec![
                Value::Number(99.99),
                Value::Number(49.99),
                Value::Number(149.99),
            ],
        )
    }

    #[test]
    fn test_parse_count_query() {
        let parsed = parse_ok("count(?x) WHERE employee(?x)");

        assert_eq!(parsed.function, AggregateFunction::Count);
        assert_eq!(parsed.pattern, "employee(?x)");
        assert_eq!(parsed.filter, None);
    }

    #[test]
    fn test_parse_sum_query() {
        let parsed = parse_ok("sum(?amount) WHERE purchase(?item, ?amount)");

        assert_eq!(parsed.function, AggregateFunction::Sum("amount".to_string()));
        assert_eq!(parsed.pattern, "purchase(?item, ?amount)");
    }

    #[test]
    fn test_parse_avg_with_filter() {
        let parsed = parse_ok("avg(?salary) WHERE salary(?name, ?salary) AND ?salary > 50000");

        assert_eq!(parsed.function, AggregateFunction::Avg("salary".to_string()));
        assert_eq!(parsed.pattern, "salary(?name, ?salary)");
        assert_eq!(parsed.filter, Some("?salary > 50000".to_string()));
    }

    #[test]
    fn test_parse_min_query() {
        let parsed = parse_ok("min(?price) WHERE product(?name, ?price)");

        assert_eq!(parsed.function, AggregateFunction::Min("price".to_string()));
    }

    #[test]
    fn test_parse_max_query() {
        let parsed = parse_ok("max(?score) WHERE student(?name, ?score)");

        assert_eq!(parsed.function, AggregateFunction::Max("score".to_string()));
    }

    #[test]
    fn test_parse_invalid_query() {
        // No WHERE clause at all.
        let outcome = parse_aggregate_query("count(?x)");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_parse_unknown_function() {
        let outcome = parse_aggregate_query("unknown(?x) WHERE test(?x)");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_apply_count() {
        // Count only looks at how many solutions there are, not their bindings.
        let solutions: Vec<Solution> = (0..3)
            .map(|_| Solution { path: vec![], bindings: HashMap::new() })
            .collect();

        let counted = apply_aggregate(&AggregateFunction::Count, &solutions).unwrap();
        assert_eq!(counted, Value::Integer(3));
    }

    #[test]
    fn test_apply_sum() {
        let solutions = solutions_binding(
            "amount",
            vec![
                Value::Number(100.0),
                Value::Number(200.0),
                Value::Number(300.0),
            ],
        );

        let total =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap();
        assert_eq!(total, Value::Number(600.0));
    }

    #[test]
    fn test_apply_avg() {
        let solutions = solutions_binding(
            "score",
            vec![Value::Integer(80), Value::Integer(90), Value::Integer(100)],
        );

        let mean =
            apply_aggregate(&AggregateFunction::Avg("score".to_string()), &solutions).unwrap();
        assert_eq!(mean, Value::Number(90.0));
    }

    #[test]
    fn test_apply_min() {
        let solutions = price_solutions();

        let lowest =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap();
        assert_eq!(lowest, Value::Number(49.99));
    }

    #[test]
    fn test_apply_max() {
        let solutions = price_solutions();

        let highest =
            apply_aggregate(&AggregateFunction::Max("price".to_string()), &solutions).unwrap();
        assert_eq!(highest, Value::Number(149.99));
    }

    #[test]
    fn test_apply_empty_solutions() {
        let solutions: Vec<Solution> = Vec::new();

        let count = apply_aggregate(&AggregateFunction::Count, &solutions).unwrap();
        assert_eq!(count, Value::Integer(0));

        let sum =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap();
        assert_eq!(sum, Value::Number(0.0));

        let min =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap();
        assert_eq!(min, Value::Null);
    }
}