1use super::search::Solution;
33use crate::errors::{Result, RuleEngineError};
34use crate::types::Value;
35
/// Aggregate operation that can be evaluated over a set of query solutions.
///
/// Variants that read a specific variable carry that variable's name. When
/// built by `parse_aggregate_query` the name is stored without its leading
/// `?`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Number of solutions.
    Count,

    /// Sum of the numeric values bound to the named variable.
    Sum(String),

    /// Arithmetic mean of the numeric values bound to the named variable.
    Avg(String),

    /// Smallest numeric value bound to the named variable.
    Min(String),

    /// Largest numeric value bound to the named variable.
    Max(String),

    /// A value taken from the first solution.
    First,

    /// A value taken from the last solution.
    Last,
}

impl AggregateFunction {
    /// Returns the name of the variable this aggregate reads, or `None` for
    /// the aggregates that do not target a specific variable (`Count`,
    /// `First`, `Last`).
    pub fn field_name(&self) -> Option<&str> {
        // Every variant is listed explicitly (no `_` arm) so that adding a
        // variant later is a compile error here instead of silently
        // reporting "no field".
        match self {
            AggregateFunction::Sum(field)
            | AggregateFunction::Avg(field)
            | AggregateFunction::Min(field)
            | AggregateFunction::Max(field) => Some(field.as_str()),
            AggregateFunction::Count | AggregateFunction::First | AggregateFunction::Last => None,
        }
    }
}
73
74#[derive(Debug, Clone)]
76pub struct AggregateQuery {
77 pub function: AggregateFunction,
79
80 pub pattern: String,
82
83 pub filter: Option<String>,
85}
86
87impl AggregateQuery {
88 pub fn new(function: AggregateFunction, pattern: String) -> Self {
90 Self {
91 function,
92 pattern,
93 filter: None,
94 }
95 }
96
97 pub fn with_filter(mut self, filter: String) -> Self {
99 self.filter = Some(filter);
100 self
101 }
102}
103
104pub fn parse_aggregate_query(query: &str) -> Result<AggregateQuery> {
113 let query = query.trim();
114
115 let parts: Vec<&str> = query.splitn(2, " WHERE ").collect();
117 if parts.len() != 2 {
118 return Err(RuleEngineError::ParseError {
119 message: format!("Invalid aggregate query format. Expected: 'function(?var) WHERE pattern'. Got: '{}'", query),
120 });
121 }
122
123 let func_part = parts[0].trim();
124 let pattern_part = parts[1].trim();
125
126 let (func_name, var_name) = parse_function_call(func_part)?;
128
129 let function = match func_name.to_lowercase().as_str() {
131 "count" => AggregateFunction::Count,
132 "sum" => {
133 if var_name.is_empty() {
134 return Err(RuleEngineError::ParseError {
135 message: "sum() requires a variable, e.g., sum(?amount)".to_string(),
136 });
137 }
138 AggregateFunction::Sum(var_name.to_string())
139 }
140 "avg" => {
141 if var_name.is_empty() {
142 return Err(RuleEngineError::ParseError {
143 message: "avg() requires a variable, e.g., avg(?salary)".to_string(),
144 });
145 }
146 AggregateFunction::Avg(var_name.to_string())
147 }
148 "min" => {
149 if var_name.is_empty() {
150 return Err(RuleEngineError::ParseError {
151 message: "min() requires a variable, e.g., min(?price)".to_string(),
152 });
153 }
154 AggregateFunction::Min(var_name.to_string())
155 }
156 "max" => {
157 if var_name.is_empty() {
158 return Err(RuleEngineError::ParseError {
159 message: "max() requires a variable, e.g., max(?score)".to_string(),
160 });
161 }
162 AggregateFunction::Max(var_name.to_string())
163 }
164 "first" => AggregateFunction::First,
165 "last" => AggregateFunction::Last,
166 _ => {
167 return Err(RuleEngineError::ParseError {
168 message: format!("Unknown aggregate function: '{}'. Supported: count, sum, avg, min, max, first, last", func_name),
169 });
170 }
171 };
172
173 let (pattern, filter) = if pattern_part.contains(" AND ") {
175 let parts: Vec<&str> = pattern_part.splitn(2, " AND ").collect();
176 (
177 parts[0].trim().to_string(),
178 Some(parts[1].trim().to_string()),
179 )
180 } else {
181 (pattern_part.to_string(), None)
182 };
183
184 Ok(AggregateQuery {
185 function,
186 pattern,
187 filter,
188 })
189}
190
191fn parse_function_call(s: &str) -> Result<(String, String)> {
193 let s = s.trim();
194
195 let open_idx = s.find('(').ok_or_else(|| RuleEngineError::ParseError {
197 message: format!("Expected '(' in function call: '{}'", s),
198 })?;
199
200 let close_idx = s.rfind(')').ok_or_else(|| RuleEngineError::ParseError {
202 message: format!("Expected ')' in function call: '{}'", s),
203 })?;
204
205 if close_idx <= open_idx {
206 return Err(RuleEngineError::ParseError {
207 message: format!("Invalid function call syntax: '{}'", s),
208 });
209 }
210
211 let func_name = s[..open_idx].trim().to_string();
212 let var_name = s[open_idx + 1..close_idx].trim().to_string();
213
214 let var_name = if let Some(stripped) = var_name.strip_prefix('?') {
216 stripped.to_string()
217 } else {
218 var_name
219 };
220
221 Ok((func_name, var_name))
222}
223
224pub fn apply_aggregate(function: &AggregateFunction, solutions: &[Solution]) -> Result<Value> {
226 if solutions.is_empty() {
227 return Ok(match function {
229 AggregateFunction::Count => Value::Integer(0),
230 AggregateFunction::Sum(_) => Value::Number(0.0),
231 AggregateFunction::Avg(_) => Value::Number(0.0),
232 AggregateFunction::Min(_) => Value::Null,
233 AggregateFunction::Max(_) => Value::Null,
234 AggregateFunction::First => Value::Null,
235 AggregateFunction::Last => Value::Null,
236 });
237 }
238
239 match function {
240 AggregateFunction::Count => Ok(Value::Integer(solutions.len() as i64)),
241
242 AggregateFunction::Sum(field) => {
243 let sum: f64 = solutions
244 .iter()
245 .filter_map(|s| s.bindings.get(field))
246 .filter_map(|v| value_to_float(v).ok())
247 .sum();
248 Ok(Value::Number(sum))
249 }
250
251 AggregateFunction::Avg(field) => {
252 let values: Vec<f64> = solutions
253 .iter()
254 .filter_map(|s| s.bindings.get(field))
255 .filter_map(|v| value_to_float(v).ok())
256 .collect();
257
258 if values.is_empty() {
259 Ok(Value::Number(0.0))
260 } else {
261 let sum: f64 = values.iter().sum();
262 Ok(Value::Number(sum / values.len() as f64))
263 }
264 }
265
266 AggregateFunction::Min(field) => {
267 let min = solutions
268 .iter()
269 .filter_map(|s| s.bindings.get(field))
270 .filter_map(|v| value_to_float(v).ok())
271 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
272
273 Ok(min.map(Value::Number).unwrap_or(Value::Null))
274 }
275
276 AggregateFunction::Max(field) => {
277 let max = solutions
278 .iter()
279 .filter_map(|s| s.bindings.get(field))
280 .filter_map(|v| value_to_float(v).ok())
281 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
282
283 Ok(max.map(Value::Number).unwrap_or(Value::Null))
284 }
285
286 AggregateFunction::First => {
287 Ok(solutions
288 .first()
289 .and_then(|s| {
290 s.bindings.values().next().cloned()
292 })
293 .unwrap_or(Value::Null))
294 }
295
296 AggregateFunction::Last => {
297 Ok(solutions
298 .last()
299 .and_then(|s| {
300 s.bindings.values().last().cloned()
302 })
303 .unwrap_or(Value::Null))
304 }
305 }
306}
307
308fn value_to_float(value: &Value) -> Result<f64> {
310 match value {
311 Value::Number(n) => Ok(*n),
312 Value::Integer(i) => Ok(*i as f64),
313 Value::String(s) => s
314 .parse::<f64>()
315 .map_err(|_| RuleEngineError::EvaluationError {
316 message: format!("Cannot convert '{}' to number", s),
317 }),
318 _ => Err(RuleEngineError::EvaluationError {
319 message: format!("Cannot aggregate non-numeric value: {:?}", value),
320 }),
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds one solution per entry of `values`, each with an empty path and
    /// a single binding of `field` to that value.
    fn solutions_binding(field: &str, values: Vec<Value>) -> Vec<Solution> {
        values
            .into_iter()
            .map(|value| {
                let mut bindings = HashMap::new();
                bindings.insert(field.to_string(), value);
                Solution {
                    path: vec![],
                    bindings,
                }
            })
            .collect()
    }

    /// The three prices shared by the min and max tests.
    fn price_solutions() -> Vec<Solution> {
        solutions_binding(
            "price",
            vec![
                Value::Number(99.99),
                Value::Number(49.99),
                Value::Number(149.99),
            ],
        )
    }

    #[test]
    fn test_parse_count_query() {
        let parsed = parse_aggregate_query("count(?x) WHERE employee(?x)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Count);
        assert_eq!(parsed.pattern, "employee(?x)");
        assert_eq!(parsed.filter, None);
    }

    #[test]
    fn test_parse_sum_query() {
        let parsed = parse_aggregate_query("sum(?amount) WHERE purchase(?item, ?amount)").unwrap();

        assert_eq!(
            parsed.function,
            AggregateFunction::Sum("amount".to_string())
        );
        assert_eq!(parsed.pattern, "purchase(?item, ?amount)");
    }

    #[test]
    fn test_parse_avg_with_filter() {
        let parsed =
            parse_aggregate_query("avg(?salary) WHERE salary(?name, ?salary) AND ?salary > 50000")
                .unwrap();

        assert_eq!(
            parsed.function,
            AggregateFunction::Avg("salary".to_string())
        );
        assert_eq!(parsed.pattern, "salary(?name, ?salary)");
        assert_eq!(parsed.filter, Some("?salary > 50000".to_string()));
    }

    #[test]
    fn test_parse_min_query() {
        let parsed = parse_aggregate_query("min(?price) WHERE product(?name, ?price)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Min("price".to_string()));
    }

    #[test]
    fn test_parse_max_query() {
        let parsed = parse_aggregate_query("max(?score) WHERE student(?name, ?score)").unwrap();

        assert_eq!(parsed.function, AggregateFunction::Max("score".to_string()));
    }

    #[test]
    fn test_parse_invalid_query() {
        // No " WHERE " separator, so parsing must fail.
        assert!(parse_aggregate_query("count(?x)").is_err());
    }

    #[test]
    fn test_parse_unknown_function() {
        assert!(parse_aggregate_query("unknown(?x) WHERE test(?x)").is_err());
    }

    #[test]
    fn test_apply_count() {
        // Count only looks at how many solutions there are, so empty
        // bindings are sufficient.
        let solutions: Vec<Solution> = (0..3)
            .map(|_| Solution {
                path: vec![],
                bindings: HashMap::new(),
            })
            .collect();

        let counted = apply_aggregate(&AggregateFunction::Count, &solutions).unwrap();
        assert_eq!(counted, Value::Integer(3));
    }

    #[test]
    fn test_apply_sum() {
        let solutions = solutions_binding(
            "amount",
            vec![
                Value::Number(100.0),
                Value::Number(200.0),
                Value::Number(300.0),
            ],
        );

        let total =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &solutions).unwrap();
        assert_eq!(total, Value::Number(600.0));
    }

    #[test]
    fn test_apply_avg() {
        let solutions = solutions_binding(
            "score",
            vec![Value::Integer(80), Value::Integer(90), Value::Integer(100)],
        );

        let mean =
            apply_aggregate(&AggregateFunction::Avg("score".to_string()), &solutions).unwrap();
        assert_eq!(mean, Value::Number(90.0));
    }

    #[test]
    fn test_apply_min() {
        let solutions = price_solutions();

        let lowest =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &solutions).unwrap();
        assert_eq!(lowest, Value::Number(49.99));
    }

    #[test]
    fn test_apply_max() {
        let solutions = price_solutions();

        let highest =
            apply_aggregate(&AggregateFunction::Max("price".to_string()), &solutions).unwrap();
        assert_eq!(highest, Value::Number(149.99));
    }

    #[test]
    fn test_apply_empty_solutions() {
        let no_solutions: Vec<Solution> = Vec::new();

        let count = apply_aggregate(&AggregateFunction::Count, &no_solutions).unwrap();
        assert_eq!(count, Value::Integer(0));

        let sum =
            apply_aggregate(&AggregateFunction::Sum("amount".to_string()), &no_solutions).unwrap();
        assert_eq!(sum, Value::Number(0.0));

        let min =
            apply_aggregate(&AggregateFunction::Min("price".to_string()), &no_solutions).unwrap();
        assert_eq!(min, Value::Null);
    }
}