1use super::ast::*;
46use super::bridge::ExecutionResult;
47use super::error::SqlResult;
48use rayon::prelude::*;
49use sochdb_core::SochValue;
50use std::collections::{HashMap, HashSet};
51
/// Minimum number of input rows before `accumulate_fast` splits the work into rayon
/// chunks; smaller inputs are accumulated on the calling thread.
const PARALLEL_THRESHOLD: usize = 100_000;
54
/// Aggregate functions understood by the aggregate executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggFn {
    /// `COUNT(*)`, `COUNT(x)` and `COUNT(DISTINCT x)`.
    Count,
    /// `SUM(x)`.
    Sum,
    /// `AVG(x)` (also spelled `MEAN`).
    Avg,
    /// `MIN(x)`.
    Min,
    /// `MAX(x)`.
    Max,
    /// `MEDIAN(x)`.
    Median,
    /// Sample standard deviation (`STDDEV`, `STDDEV_SAMP`, `STDEV`, `SD`).
    Stddev,
}

impl AggFn {
    /// Resolves a SQL function name to an aggregate, ignoring ASCII case and accepting the
    /// common aliases listed below. Returns `None` for non-aggregate names such as `UPPER`.
    pub fn from_name(name: &str) -> Option<Self> {
        const ALIASES: &[(&str, AggFn)] = &[
            ("COUNT", AggFn::Count),
            ("SUM", AggFn::Sum),
            ("AVG", AggFn::Avg),
            ("MEAN", AggFn::Avg),
            ("MIN", AggFn::Min),
            ("MAX", AggFn::Max),
            ("MEDIAN", AggFn::Median),
            ("STDDEV", AggFn::Stddev),
            ("STDDEV_SAMP", AggFn::Stddev),
            ("STDEV", AggFn::Stddev),
            ("SD", AggFn::Stddev),
        ];
        ALIASES
            .iter()
            .find(|(alias, _)| alias.eq_ignore_ascii_case(name))
            .map(|&(_, func)| func)
    }
}
86
/// One distinct aggregate call collected from the query (SELECT list, HAVING, ORDER BY).
#[derive(Debug, Clone)]
struct AggSpec {
    /// Canonical key produced by `render_agg_key` (e.g. `sum(v1)`); it deduplicates specs
    /// and is also the name under which the finalized value is stored in each output row.
    key: String,
    /// Which aggregate to compute.
    func: AggFn,
    /// Argument expression; `None` for `COUNT(*)` and for calls with no arguments.
    arg: Option<Expr>,
    /// Whether the call was written with `DISTINCT`.
    distinct: bool,
}
98
99pub fn is_aggregate_query(select: &SelectStmt) -> bool {
101 if !select.group_by.is_empty() {
102 return true;
103 }
104 select
105 .columns
106 .iter()
107 .any(|item| matches!(item, SelectItem::Expr { expr, .. } if contains_aggregate(expr)))
108}
109
110fn contains_aggregate(expr: &Expr) -> bool {
112 match expr {
113 Expr::Function(f) => {
114 AggFn::from_name(f.name.name()).is_some() || f.args.iter().any(contains_aggregate)
115 }
116 Expr::BinaryOp { left, right, .. } => contains_aggregate(left) || contains_aggregate(right),
117 Expr::UnaryOp { expr, .. } => contains_aggregate(expr),
118 Expr::IsNull { expr, .. } => contains_aggregate(expr),
119 Expr::Case {
120 operand,
121 conditions,
122 else_result,
123 } => {
124 operand.as_deref().map(contains_aggregate).unwrap_or(false)
125 || conditions
126 .iter()
127 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
128 || else_result
129 .as_deref()
130 .map(contains_aggregate)
131 .unwrap_or(false)
132 }
133 _ => false,
134 }
135}
136
137fn collect_agg_specs(select: &SelectStmt) -> Vec<AggSpec> {
139 let mut specs: Vec<AggSpec> = Vec::new();
140 let mut seen: HashSet<String> = HashSet::new();
141
142 let walk = |expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>| {
143 collect_from_expr(expr, specs, seen);
144 };
145
146 for item in &select.columns {
147 if let SelectItem::Expr { expr, .. } = item {
148 walk(expr, &mut specs, &mut seen);
149 }
150 }
151 if let Some(h) = &select.having {
152 walk(h, &mut specs, &mut seen);
153 }
154 for ob in &select.order_by {
155 walk(&ob.expr, &mut specs, &mut seen);
156 }
157 specs
158}
159
160fn collect_from_expr(expr: &Expr, specs: &mut Vec<AggSpec>, seen: &mut HashSet<String>) {
161 match expr {
162 Expr::Function(f) => {
163 if let Some(func) = AggFn::from_name(f.name.name()) {
164 let arg = f.args.first().cloned();
165 let is_star = matches!(arg.as_ref(), Some(Expr::Column(c)) if c.column == "*");
166 let arg = if is_star { None } else { arg };
167 let key = render_agg_key(func, arg.as_ref(), f.distinct);
168 if seen.insert(key.clone()) {
169 specs.push(AggSpec {
170 key,
171 func,
172 arg,
173 distinct: f.distinct,
174 });
175 }
176 } else {
177 for a in &f.args {
178 collect_from_expr(a, specs, seen);
179 }
180 }
181 }
182 Expr::BinaryOp { left, right, .. } => {
183 collect_from_expr(left, specs, seen);
184 collect_from_expr(right, specs, seen);
185 }
186 Expr::UnaryOp { expr, .. } => collect_from_expr(expr, specs, seen),
187 Expr::IsNull { expr, .. } => collect_from_expr(expr, specs, seen),
188 Expr::Case {
189 operand,
190 conditions,
191 else_result,
192 } => {
193 if let Some(op) = operand {
194 collect_from_expr(op, specs, seen);
195 }
196 for (w, t) in conditions {
197 collect_from_expr(w, specs, seen);
198 collect_from_expr(t, specs, seen);
199 }
200 if let Some(e) = else_result {
201 collect_from_expr(e, specs, seen);
202 }
203 }
204 _ => {}
205 }
206}
207
208fn render_agg_key(func: AggFn, arg: Option<&Expr>, distinct: bool) -> String {
211 let fname = match func {
212 AggFn::Count => "count",
213 AggFn::Sum => "sum",
214 AggFn::Avg => "avg",
215 AggFn::Min => "min",
216 AggFn::Max => "max",
217 AggFn::Median => "median",
218 AggFn::Stddev => "stddev",
219 };
220 let arg_s = match arg {
221 None => "*".to_string(),
222 Some(e) => render_expr_name(e),
223 };
224 if distinct {
225 format!("{}(distinct {})", fname, arg_s)
226 } else {
227 format!("{}({})", fname, arg_s)
228 }
229}
230
231pub fn render_expr_name(expr: &Expr) -> String {
234 match expr {
235 Expr::Column(c) => {
236 if let Some(t) = &c.table {
237 format!("{}.{}", t, c.column)
238 } else {
239 c.column.clone()
240 }
241 }
242 Expr::Literal(Literal::Integer(n)) => n.to_string(),
243 Expr::Literal(Literal::Float(f)) => f.to_string(),
244 Expr::Literal(Literal::String(s)) => format!("'{}'", s),
245 Expr::Literal(Literal::Boolean(b)) => b.to_string(),
246 Expr::Literal(Literal::Null) => "null".to_string(),
247 Expr::Function(f) => {
248 if let Some(func) = AggFn::from_name(f.name.name()) {
249 let arg = f.args.first();
250 let is_star = matches!(arg, Some(Expr::Column(c)) if c.column == "*");
251 render_agg_key(func, if is_star { None } else { arg }, f.distinct)
252 } else {
253 let args: Vec<String> = f.args.iter().map(render_expr_name).collect();
254 format!("{}({})", f.name.name().to_lowercase(), args.join(", "))
255 }
256 }
257 Expr::BinaryOp { left, op, right } => format!(
258 "{} {} {}",
259 render_expr_name(left),
260 binary_op_symbol(op),
261 render_expr_name(right)
262 ),
263 Expr::UnaryOp { op, expr } => match op {
264 UnaryOperator::Minus => format!("-{}", render_expr_name(expr)),
265 UnaryOperator::Plus => render_expr_name(expr),
266 UnaryOperator::Not => format!("not {}", render_expr_name(expr)),
267 UnaryOperator::BitNot => format!("~{}", render_expr_name(expr)),
268 },
269 _ => "expr".to_string(),
270 }
271}
272
/// Returns the display symbol for `op`, as used by `render_expr_name`.
///
/// Operators without a dedicated rendering all become `?`.
// NOTE(review): because of that shared `?`, two aggregate arguments that differ only in
// such an operator render to the same key — confirm which operators the parser can emit.
fn binary_op_symbol(op: &BinaryOperator) -> &'static str {
    match op {
        BinaryOperator::Plus => "+",
        BinaryOperator::Minus => "-",
        BinaryOperator::Multiply => "*",
        BinaryOperator::Divide => "/",
        BinaryOperator::Modulo => "%",
        BinaryOperator::Eq => "=",
        BinaryOperator::Ne => "<>",
        BinaryOperator::Lt => "<",
        BinaryOperator::Le => "<=",
        BinaryOperator::Gt => ">",
        BinaryOperator::Ge => ">=",
        BinaryOperator::And => "and",
        BinaryOperator::Or => "or",
        _ => "?",
    }
}
291
292fn eval_scalar(expr: &Expr, row: &HashMap<String, SochValue>, params: &[SochValue]) -> SochValue {
302 match expr {
303 Expr::Column(c) => {
304 if let Some(t) = &c.table {
305 let qualified = format!("{}.{}", t, c.column);
306 if let Some(v) = row.get(&qualified) {
307 return v.clone();
308 }
309 }
310 row.get(&c.column).cloned().unwrap_or(SochValue::Null)
311 }
312 Expr::Literal(lit) => literal_to_value(lit),
313 Expr::Placeholder(idx) => params
314 .get((*idx as usize).saturating_sub(1))
315 .cloned()
316 .unwrap_or(SochValue::Null),
317 Expr::Function(f) => {
318 let key = render_expr_name(&Expr::Function(f.clone()));
321 row.get(&key).cloned().unwrap_or(SochValue::Null)
322 }
323 Expr::BinaryOp { left, op, right } => {
324 let l = eval_scalar(left, row, params);
325 let r = eval_scalar(right, row, params);
326 eval_binary(&l, op, &r)
327 }
328 Expr::UnaryOp { op, expr } => {
329 let v = eval_scalar(expr, row, params);
330 match op {
331 UnaryOperator::Minus => match v {
332 SochValue::Int(i) => SochValue::Int(-i),
333 SochValue::Float(f) => SochValue::Float(-f),
334 _ => SochValue::Null,
335 },
336 UnaryOperator::Plus => v,
337 UnaryOperator::Not => match v {
338 SochValue::Bool(b) => SochValue::Bool(!b),
339 _ => SochValue::Null,
340 },
341 UnaryOperator::BitNot => match v {
342 SochValue::Int(i) => SochValue::Int(!i),
343 _ => SochValue::Null,
344 },
345 }
346 }
347 Expr::IsNull { expr, negated } => {
348 let v = eval_scalar(expr, row, params);
349 let is_null = v.is_null();
350 SochValue::Bool(if *negated { !is_null } else { is_null })
351 }
352 _ => SochValue::Null,
353 }
354}
355
356fn literal_to_value(lit: &Literal) -> SochValue {
357 match lit {
358 Literal::Integer(i) => SochValue::Int(*i),
359 Literal::Float(f) => SochValue::Float(*f),
360 Literal::String(s) => SochValue::Text(s.clone()),
361 Literal::Boolean(b) => SochValue::Bool(*b),
362 Literal::Null => SochValue::Null,
363 _ => SochValue::Null,
364 }
365}
366
367fn numeric(v: &SochValue) -> Option<f64> {
368 match v {
369 SochValue::Int(i) => Some(*i as f64),
370 SochValue::UInt(u) => Some(*u as f64),
371 SochValue::Float(f) => Some(*f),
372 SochValue::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
373 _ => None,
374 }
375}
376
377fn eval_binary(l: &SochValue, op: &BinaryOperator, r: &SochValue) -> SochValue {
378 use BinaryOperator::*;
379 match op {
380 Plus | Minus | Multiply | Divide | Modulo => {
381 if let (SochValue::Int(a), SochValue::Int(b)) = (l, r) {
383 return match op {
384 Plus => SochValue::Int(a.wrapping_add(*b)),
385 Minus => SochValue::Int(a.wrapping_sub(*b)),
386 Multiply => SochValue::Int(a.wrapping_mul(*b)),
387 Divide => {
388 if *b == 0 {
389 SochValue::Null
390 } else {
391 SochValue::Float(*a as f64 / *b as f64)
392 }
393 }
394 Modulo => {
395 if *b == 0 {
396 SochValue::Null
397 } else {
398 SochValue::Int(a % b)
399 }
400 }
401 _ => unreachable!(),
402 };
403 }
404 let (a, b) = match (numeric(l), numeric(r)) {
405 (Some(a), Some(b)) => (a, b),
406 _ => return SochValue::Null,
407 };
408 match op {
409 Plus => SochValue::Float(a + b),
410 Minus => SochValue::Float(a - b),
411 Multiply => SochValue::Float(a * b),
412 Divide => {
413 if b == 0.0 {
414 SochValue::Null
415 } else {
416 SochValue::Float(a / b)
417 }
418 }
419 Modulo => {
420 if b == 0.0 {
421 SochValue::Null
422 } else {
423 SochValue::Float(a % b)
424 }
425 }
426 _ => unreachable!(),
427 }
428 }
429 Eq | Ne | Lt | Le | Gt | Ge => {
430 if l.is_null() || r.is_null() {
431 return SochValue::Null;
432 }
433 let ord = compare_values(l, r);
434 let b = match op {
435 Eq => ord == std::cmp::Ordering::Equal,
436 Ne => ord != std::cmp::Ordering::Equal,
437 Lt => ord == std::cmp::Ordering::Less,
438 Le => ord != std::cmp::Ordering::Greater,
439 Gt => ord == std::cmp::Ordering::Greater,
440 Ge => ord != std::cmp::Ordering::Less,
441 _ => unreachable!(),
442 };
443 SochValue::Bool(b)
444 }
445 And => match (as_bool(l), as_bool(r)) {
446 (Some(a), Some(b)) => SochValue::Bool(a && b),
447 _ => SochValue::Null,
448 },
449 Or => match (as_bool(l), as_bool(r)) {
450 (Some(a), Some(b)) => SochValue::Bool(a || b),
451 _ => SochValue::Null,
452 },
453 _ => SochValue::Null,
454 }
455}
456
457fn as_bool(v: &SochValue) -> Option<bool> {
458 match v {
459 SochValue::Bool(b) => Some(*b),
460 SochValue::Int(i) => Some(*i != 0),
461 SochValue::Null => None,
462 _ => None,
463 }
464}
465
466pub fn compare_values(a: &SochValue, b: &SochValue) -> std::cmp::Ordering {
468 use std::cmp::Ordering;
469 match (numeric(a), numeric(b)) {
470 (Some(x), Some(y)) => return x.partial_cmp(&y).unwrap_or(Ordering::Equal),
471 _ => {}
472 }
473 match (a, b) {
474 (SochValue::Text(x), SochValue::Text(y)) => x.cmp(y),
475 (SochValue::Null, SochValue::Null) => Ordering::Equal,
476 (SochValue::Null, _) => Ordering::Less,
477 (_, SochValue::Null) => Ordering::Greater,
478 _ => Ordering::Equal,
479 }
480}
481
482fn key_repr(v: &SochValue) -> String {
485 match v {
486 SochValue::Null => "\u{0}N".to_string(),
487 SochValue::Int(i) => format!("i{}", i),
488 SochValue::UInt(u) => format!("i{}", u),
489 SochValue::Float(f) => {
490 if f.fract() == 0.0 && f.abs() < 9.0e15 {
491 format!("i{}", *f as i64)
492 } else {
493 format!("f{}", f)
494 }
495 }
496 SochValue::Text(s) => format!("s{}", s),
497 SochValue::Bool(b) => format!("b{}", b),
498 other => format!("{:?}", other),
499 }
500}
501
502#[derive(Debug)]
507enum Acc {
508 CountStar(u64),
509 Count(u64),
510 CountDistinct(HashSet<String>),
511 Sum {
513 int: i64,
514 float: f64,
515 saw_float: bool,
516 saw_any: bool,
517 overflowed: bool,
518 },
519 Avg {
520 sum: f64,
521 n: u64,
522 },
523 Min(Option<SochValue>),
524 Max(Option<SochValue>),
525 Median(Vec<f64>),
526 Stddev {
528 n: u64,
529 mean: f64,
530 m2: f64,
531 },
532}
533
534impl Acc {
535 fn new(spec: &AggSpec) -> Self {
536 match (spec.func, spec.arg.is_some(), spec.distinct) {
537 (AggFn::Count, false, _) => Acc::CountStar(0),
538 (AggFn::Count, true, true) => Acc::CountDistinct(HashSet::new()),
539 (AggFn::Count, true, false) => Acc::Count(0),
540 (AggFn::Sum, _, _) => Acc::Sum {
541 int: 0,
542 float: 0.0,
543 saw_float: false,
544 saw_any: false,
545 overflowed: false,
546 },
547 (AggFn::Avg, _, _) => Acc::Avg { sum: 0.0, n: 0 },
548 (AggFn::Min, _, _) => Acc::Min(None),
549 (AggFn::Max, _, _) => Acc::Max(None),
550 (AggFn::Median, _, _) => Acc::Median(Vec::new()),
551 (AggFn::Stddev, _, _) => Acc::Stddev {
552 n: 0,
553 mean: 0.0,
554 m2: 0.0,
555 },
556 }
557 }
558
559 fn update(&mut self, val: Option<&SochValue>) {
561 match self {
562 Acc::CountStar(n) => *n += 1,
563 Acc::Count(n) => {
564 if let Some(v) = val {
565 if !v.is_null() {
566 *n += 1;
567 }
568 }
569 }
570 Acc::CountDistinct(set) => {
571 if let Some(v) = val {
572 if !v.is_null() {
573 set.insert(key_repr(v));
574 }
575 }
576 }
577 Acc::Sum {
578 int,
579 float,
580 saw_float,
581 saw_any,
582 overflowed,
583 } => {
584 let Some(v) = val else { return };
585 match v {
586 SochValue::Int(i) => {
587 *saw_any = true;
588 match int.checked_add(*i) {
589 Some(s) => *int = s,
590 None => *overflowed = true,
591 }
592 *float += *i as f64;
593 }
594 SochValue::UInt(u) => {
595 *saw_any = true;
596 match int.checked_add(*u as i64) {
597 Some(s) => *int = s,
598 None => *overflowed = true,
599 }
600 *float += *u as f64;
601 }
602 SochValue::Float(f) => {
603 *saw_any = true;
604 *saw_float = true;
605 *float += *f;
606 }
607 _ => {}
608 }
609 }
610 Acc::Avg { sum, n } => {
611 if let Some(x) = val.and_then(numeric) {
612 *sum += x;
613 *n += 1;
614 }
615 }
616 Acc::Min(cur) => {
617 let Some(v) = val else { return };
618 if v.is_null() {
619 return;
620 }
621 match cur {
622 None => *cur = Some(v.clone()),
623 Some(c) => {
624 if compare_values(v, c) == std::cmp::Ordering::Less {
625 *cur = Some(v.clone());
626 }
627 }
628 }
629 }
630 Acc::Max(cur) => {
631 let Some(v) = val else { return };
632 if v.is_null() {
633 return;
634 }
635 match cur {
636 None => *cur = Some(v.clone()),
637 Some(c) => {
638 if compare_values(v, c) == std::cmp::Ordering::Greater {
639 *cur = Some(v.clone());
640 }
641 }
642 }
643 }
644 Acc::Median(vals) => {
645 if let Some(x) = val.and_then(numeric) {
646 vals.push(x);
647 }
648 }
649 Acc::Stddev { n, mean, m2 } => {
650 if let Some(x) = val.and_then(numeric) {
651 *n += 1;
652 let delta = x - *mean;
653 *mean += delta / *n as f64;
654 let delta2 = x - *mean;
655 *m2 += delta * delta2;
656 }
657 }
658 }
659 }
660
661 fn merge(&mut self, other: Acc) {
664 match (self, other) {
665 (Acc::CountStar(a), Acc::CountStar(b)) => *a += b,
666 (Acc::Count(a), Acc::Count(b)) => *a += b,
667 (Acc::CountDistinct(a), Acc::CountDistinct(b)) => a.extend(b),
668 (
669 Acc::Sum {
670 int,
671 float,
672 saw_float,
673 saw_any,
674 overflowed,
675 },
676 Acc::Sum {
677 int: i2,
678 float: f2,
679 saw_float: sf2,
680 saw_any: sa2,
681 overflowed: of2,
682 },
683 ) => {
684 match int.checked_add(i2) {
685 Some(s) => *int = s,
686 None => *overflowed = true,
687 }
688 *float += f2;
689 *saw_float |= sf2;
690 *saw_any |= sa2;
691 *overflowed |= of2;
692 }
693 (Acc::Avg { sum, n }, Acc::Avg { sum: s2, n: n2 }) => {
694 *sum += s2;
695 *n += n2;
696 }
697 (Acc::Min(a), Acc::Min(Some(b))) => match a {
698 None => *a = Some(b),
699 Some(cur) => {
700 if compare_values(&b, cur) == std::cmp::Ordering::Less {
701 *a = Some(b);
702 }
703 }
704 },
705 (Acc::Max(a), Acc::Max(Some(b))) => match a {
706 None => *a = Some(b),
707 Some(cur) => {
708 if compare_values(&b, cur) == std::cmp::Ordering::Greater {
709 *a = Some(b);
710 }
711 }
712 },
713 (Acc::Min(_), Acc::Min(None)) | (Acc::Max(_), Acc::Max(None)) => {}
714 (Acc::Median(a), Acc::Median(b)) => a.extend(b),
715 (
716 Acc::Stddev { n, mean, m2 },
717 Acc::Stddev {
718 n: nb,
719 mean: mb,
720 m2: m2b,
721 },
722 ) => {
723 if nb > 0 {
725 if *n == 0 {
726 *n = nb;
727 *mean = mb;
728 *m2 = m2b;
729 } else {
730 let na = *n as f64;
731 let nbf = nb as f64;
732 let delta = mb - *mean;
733 let total = na + nbf;
734 *mean += delta * nbf / total;
735 *m2 += m2b + delta * delta * na * nbf / total;
736 *n += nb;
737 }
738 }
739 }
740 _ => unreachable!("mismatched accumulator merge"),
741 }
742 }
743
744 fn finalize(self) -> SochValue {
745 match self {
746 Acc::CountStar(n) | Acc::Count(n) => SochValue::Int(n as i64),
747 Acc::CountDistinct(set) => SochValue::Int(set.len() as i64),
748 Acc::Sum {
749 int,
750 float,
751 saw_float,
752 saw_any,
753 overflowed,
754 } => {
755 if !saw_any {
756 SochValue::Null
757 } else if saw_float || overflowed {
758 SochValue::Float(float)
759 } else {
760 SochValue::Int(int)
761 }
762 }
763 Acc::Avg { sum, n } => {
764 if n == 0 {
765 SochValue::Null
766 } else {
767 SochValue::Float(sum / n as f64)
768 }
769 }
770 Acc::Min(v) | Acc::Max(v) => v.unwrap_or(SochValue::Null),
771 Acc::Median(mut vals) => {
772 if vals.is_empty() {
773 return SochValue::Null;
774 }
775 let mid = vals.len() / 2;
776 if vals.len() % 2 == 1 {
777 let (_, m, _) = vals.select_nth_unstable_by(mid, |a, b| {
778 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
779 });
780 SochValue::Float(*m)
781 } else {
782 let (lo, hi_first, _) = vals.select_nth_unstable_by(mid, |a, b| {
784 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
785 });
786 let lo_max = lo.iter().copied().fold(f64::NEG_INFINITY, f64::max);
787 SochValue::Float((lo_max + *hi_first) / 2.0)
788 }
789 }
790 Acc::Stddev { n, m2, .. } => {
791 if n < 2 {
792 SochValue::Null
793 } else {
794 SochValue::Float((m2 / (n - 1) as f64).sqrt())
795 }
796 }
797 }
798 }
799}
800
/// Accumulation state for one group.
struct GroupState {
    /// GROUP BY values for this group, in GROUP BY order (empty when ungrouped).
    key_values: Vec<SochValue>,
    /// Clone of the first input row that fell into this group. It becomes the base of the
    /// output row; group keys and finalized aggregates are inserted on top of it.
    first_row: HashMap<String, SochValue>,
    /// One accumulator per collected aggregate spec, in spec order.
    accs: Vec<Acc>,
}
810
811#[derive(Debug, Clone, PartialEq, Eq, Hash)]
818enum KeyAtom<'a> {
819 Null,
820 Int(i64),
821 FBits(u64),
823 Str(&'a str),
824 Bool(bool),
825}
826
827impl<'a> KeyAtom<'a> {
828 fn from_value(v: &'a SochValue) -> Self {
829 match v {
830 SochValue::Null => KeyAtom::Null,
831 SochValue::Int(i) => KeyAtom::Int(*i),
832 SochValue::UInt(u) => KeyAtom::Int(*u as i64),
833 SochValue::Float(f) => {
834 if f.fract() == 0.0 && f.abs() < 9.0e15 {
835 KeyAtom::Int(*f as i64)
836 } else if f.is_nan() {
837 KeyAtom::FBits(f64::NAN.to_bits())
838 } else {
839 KeyAtom::FBits(f.to_bits())
840 }
841 }
842 SochValue::Text(s) => KeyAtom::Str(s.as_str()),
843 SochValue::Bool(b) => KeyAtom::Bool(*b),
844 _ => KeyAtom::Null,
845 }
846 }
847}
848
849#[derive(Debug, Clone, PartialEq, Eq, Hash)]
850enum GroupKey<'a> {
851 Empty,
852 One(KeyAtom<'a>),
853 Many(Vec<KeyAtom<'a>>),
854}
855
/// Shared `Null` that `col_get` returns by reference for columns missing from a row.
static NULL_VALUE: SochValue = SochValue::Null;
857
858#[inline]
860fn col_get<'r>(row: &'r HashMap<String, SochValue>, col: &PlainCol) -> &'r SochValue {
861 if let Some(q) = &col.qualified {
862 if let Some(v) = row.get(q) {
863 return v;
864 }
865 }
866 row.get(&col.name).unwrap_or(&NULL_VALUE)
867}
868
/// Precomputed row-map keys for a bare column reference, so the fast path does no
/// per-row string formatting.
struct PlainCol {
    /// Bare column name.
    name: String,
    /// `table.column`, when the reference was table-qualified.
    qualified: Option<String>,
}
874
875fn as_plain_col(expr: &Expr) -> Option<PlainCol> {
876 match expr {
877 Expr::Column(c) => Some(PlainCol {
878 name: c.column.clone(),
879 qualified: c.table.as_ref().map(|t| format!("{}.{}", t, c.column)),
880 }),
881 _ => None,
882 }
883}
884
885fn make_group_key<'r>(
887 row: &'r HashMap<String, SochValue>,
888 group_cols: &[PlainCol],
889) -> GroupKey<'r> {
890 match group_cols.len() {
891 0 => GroupKey::Empty,
892 1 => GroupKey::One(KeyAtom::from_value(col_get(row, &group_cols[0]))),
893 _ => GroupKey::Many(
894 group_cols
895 .iter()
896 .map(|c| KeyAtom::from_value(col_get(row, c)))
897 .collect(),
898 ),
899 }
900}
901
/// Fast grouping path. It applies only when every GROUP BY expression and every aggregate
/// argument is a bare column reference; otherwise it returns `None` and the caller falls
/// back to generic expression evaluation.
///
/// Group keys borrow directly from the input rows (`GroupKey<'a>`). Inputs with at least
/// `PARALLEL_THRESHOLD` rows are split into chunks that rayon accumulates in parallel. The
/// partial results are then merged in chunk order, so in both cases groups come out in
/// first-appearance order.
fn accumulate_fast<'a>(
    select: &SelectStmt,
    specs: &[AggSpec],
    rows: &'a [HashMap<String, SochValue>],
) -> Option<Vec<GroupState>> {
    // Any GROUP BY expression that is not a plain column makes the whole path bail out.
    let group_cols: Vec<PlainCol> = select
        .group_by
        .iter()
        .map(as_plain_col)
        .collect::<Option<Vec<_>>>()?;
    // One entry per spec. A non-column argument fails the whole path (outer `None`); the
    // inner `None` marks argument-less specs such as `COUNT(*)`.
    let arg_cols: Vec<Option<PlainCol>> = specs
        .iter()
        .map(|s| match &s.arg {
            None => Some(None),
            Some(e) => as_plain_col(e).map(Some),
        })
        .collect::<Option<Vec<_>>>()?;

    // Accumulates one slice of rows. Groups are kept in first-appearance order (`order`),
    // with a hash index from key to position.
    let accumulate_chunk =
        |chunk: &'a [HashMap<String, SochValue>]| -> Vec<(GroupKey<'a>, GroupState)> {
            let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
            let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
            for row in chunk {
                let key = make_group_key(row, &group_cols);
                let idx = match index.get(&key) {
                    Some(&i) => i,
                    None => {
                        let state = GroupState {
                            key_values: group_cols
                                .iter()
                                .map(|c| col_get(row, c).clone())
                                .collect(),
                            first_row: row.clone(),
                            accs: specs.iter().map(Acc::new).collect(),
                        };
                        order.push((key.clone(), state));
                        index.insert(key, order.len() - 1);
                        order.len() - 1
                    }
                };
                let accs = &mut order[idx].1.accs;
                for (acc, arg) in accs.iter_mut().zip(arg_cols.iter()) {
                    match arg {
                        None => acc.update(None),
                        Some(col) => acc.update(Some(col_get(row, col))),
                    }
                }
            }
            order
        };

    let merged: Vec<(GroupKey<'a>, GroupState)> = if rows.len() >= PARALLEL_THRESHOLD {
        // Aim for about four chunks per thread, but never fewer than 16,384 rows per chunk.
        let n_threads = rayon::current_num_threads().max(1);
        let chunk_size = (rows.len() / (n_threads * 4)).max(16_384);
        let partials: Vec<Vec<(GroupKey<'a>, GroupState)>> =
            rows.par_chunks(chunk_size).map(accumulate_chunk).collect();
        // Merge the partials in chunk order. A group first seen in a later chunk keeps that
        // chunk's state, including its `first_row`.
        let mut order: Vec<(GroupKey<'a>, GroupState)> = Vec::new();
        let mut index: HashMap<GroupKey<'a>, usize> = HashMap::new();
        for partial in partials {
            for (key, state) in partial {
                match index.get(&key) {
                    Some(&i) => {
                        let dst = &mut order[i].1;
                        for (a, b) in dst.accs.iter_mut().zip(state.accs.into_iter()) {
                            a.merge(b);
                        }
                    }
                    None => {
                        order.push((key.clone(), state));
                        index.insert(key, order.len() - 1);
                    }
                }
            }
        }
        order
    } else {
        accumulate_chunk(rows)
    };

    // Keys borrow from `rows` and are not needed by the caller; return only the states.
    Some(merged.into_iter().map(|(_, s)| s).collect())
}
990
991pub fn execute_aggregate(
996 select: &SelectStmt,
997 rows: &[HashMap<String, SochValue>],
998 params: &[SochValue],
999 limit: Option<usize>,
1000 offset: Option<usize>,
1001) -> SqlResult<ExecutionResult> {
1002 let specs = collect_agg_specs(select);
1003 let grouped = !select.group_by.is_empty();
1004
1005 let mut order: Vec<GroupState> = match accumulate_fast(select, &specs, rows) {
1010 Some(states) => states,
1011 None => {
1012 let mut order: Vec<GroupState> = Vec::new();
1013 let mut index: HashMap<Vec<String>, usize> = HashMap::new();
1014
1015 for row in rows {
1016 let key_values: Vec<SochValue> = select
1017 .group_by
1018 .iter()
1019 .map(|e| eval_scalar(e, row, params))
1020 .collect();
1021 let hash_key: Vec<String> = key_values.iter().map(key_repr).collect();
1022
1023 let idx = match index.get(&hash_key) {
1024 Some(&i) => i,
1025 None => {
1026 let state = GroupState {
1027 key_values,
1028 first_row: row.clone(),
1029 accs: specs.iter().map(Acc::new).collect(),
1030 };
1031 order.push(state);
1032 index.insert(hash_key, order.len() - 1);
1033 order.len() - 1
1034 }
1035 };
1036
1037 let state = &mut order[idx];
1038 for (acc, spec) in state.accs.iter_mut().zip(specs.iter()) {
1039 match &spec.arg {
1040 None => acc.update(None),
1041 Some(arg) => {
1042 let v = eval_scalar(arg, row, params);
1043 acc.update(Some(&v));
1044 }
1045 }
1046 }
1047 }
1048 order
1049 }
1050 };
1051
1052 if !grouped && order.is_empty() {
1054 order.push(GroupState {
1055 key_values: Vec::new(),
1056 first_row: HashMap::new(),
1057 accs: specs.iter().map(Acc::new).collect(),
1058 });
1059 }
1060
1061 let group_names: Vec<String> = select.group_by.iter().map(render_expr_name).collect();
1063
1064 let mut out_rows: Vec<HashMap<String, SochValue>> = Vec::with_capacity(order.len());
1065 for state in order {
1066 let mut row = state.first_row;
1069 for (name, val) in group_names.iter().zip(state.key_values.into_iter()) {
1070 row.insert(name.clone(), val);
1071 }
1072 for (spec, acc) in specs.iter().zip(state.accs.into_iter()) {
1073 row.insert(spec.key.clone(), acc.finalize());
1074 }
1075 out_rows.push(row);
1076 }
1077
1078 if let Some(having) = &select.having {
1080 out_rows.retain(|row| matches!(eval_scalar(having, row, params), SochValue::Bool(true)));
1081 }
1082
1083 if !select.order_by.is_empty() {
1085 let alias_map: Vec<(String, Expr)> = select
1087 .columns
1088 .iter()
1089 .filter_map(|item| match item {
1090 SelectItem::Expr {
1091 expr,
1092 alias: Some(a),
1093 } => Some((a.clone(), expr.clone())),
1094 _ => None,
1095 })
1096 .collect();
1097 for row in &mut out_rows {
1098 for (alias, expr) in &alias_map {
1099 if !row.contains_key(alias) {
1100 let v = eval_scalar(expr, row, params);
1101 row.insert(alias.clone(), v);
1102 }
1103 }
1104 }
1105 out_rows.sort_by(|a, b| {
1106 for item in &select.order_by {
1107 let va = eval_scalar(&item.expr, a, params);
1108 let vb = eval_scalar(&item.expr, b, params);
1109 let mut cmp = compare_values(&va, &vb);
1110 if !item.asc {
1111 cmp = cmp.reverse();
1112 }
1113 if cmp != std::cmp::Ordering::Equal {
1114 return cmp;
1115 }
1116 }
1117 std::cmp::Ordering::Equal
1118 });
1119 }
1120
1121 if let Some(off) = offset {
1123 if off > 0 {
1124 out_rows.drain(..off.min(out_rows.len()));
1125 }
1126 }
1127 if let Some(lim) = limit {
1128 out_rows.truncate(lim);
1129 }
1130
1131 let mut columns: Vec<String> = Vec::new();
1133 let mut projections: Vec<(String, Expr)> = Vec::new();
1134 for item in &select.columns {
1135 match item {
1136 SelectItem::Wildcard | SelectItem::QualifiedWildcard(_) => {
1137 for name in &group_names {
1139 columns.push(name.clone());
1140 projections.push((name.clone(), Expr::Column(ColumnRef::new(name.clone()))));
1141 }
1142 for spec in &specs {
1143 columns.push(spec.key.clone());
1144 projections.push((
1145 spec.key.clone(),
1146 Expr::Column(ColumnRef::new(spec.key.clone())),
1147 ));
1148 }
1149 }
1150 SelectItem::Expr { expr, alias } => {
1151 let name = alias.clone().unwrap_or_else(|| render_expr_name(expr));
1152 columns.push(name.clone());
1153 projections.push((name, expr.clone()));
1154 }
1155 }
1156 }
1157
1158 let projected: Vec<HashMap<String, SochValue>> = out_rows
1159 .into_iter()
1160 .map(|row| {
1161 let mut out = HashMap::with_capacity(projections.len());
1162 for (name, expr) in &projections {
1163 let v = eval_scalar(expr, &row, params);
1164 out.insert(name.clone(), v);
1165 }
1166 out
1167 })
1168 .collect();
1169
1170 Ok(ExecutionResult::Rows {
1171 columns,
1172 rows: projected,
1173 })
1174}
1175
1176#[cfg(test)]
1177mod tests {
1178 use super::super::bridge::{SqlBridge, SqlConnection};
1179 use super::*;
1180
    /// Builds the call expression `name(arg)`, where `arg` is a bare column reference.
    /// The call has no DISTINCT, FILTER or OVER clause.
    fn fcall(name: &str, arg: &str) -> Expr {
        Expr::Function(FunctionCall {
            name: ObjectName::new(name),
            args: vec![Expr::Column(ColumnRef::new(arg))],
            distinct: false,
            filter: None,
            over: None,
        })
    }
1190
    /// Aggregate names resolve case-insensitively (including aliases); non-aggregates do not.
    #[test]
    fn agg_fn_recognition() {
        assert_eq!(AggFn::from_name("median"), Some(AggFn::Median));
        assert_eq!(AggFn::from_name("STDDEV"), Some(AggFn::Stddev));
        assert_eq!(AggFn::from_name("stddev_samp"), Some(AggFn::Stddev));
        assert_eq!(AggFn::from_name("upper"), None);
    }
1198
    /// Aggregate calls render to lower-case canonical keys regardless of the spelling used.
    #[test]
    fn canonical_keys() {
        assert_eq!(render_expr_name(&fcall("SUM", "v1")), "sum(v1)");
        assert_eq!(render_expr_name(&fcall("Median", "v3")), "median(v3)");
    }
1204
    /// In-memory `SqlConnection` test double: each table is a list of rows keyed by
    /// column name.
    struct DataConn {
        tables: HashMap<String, Vec<HashMap<String, SochValue>>>,
    }
1213
    impl DataConn {
        /// Creates a connection with no tables.
        fn new() -> Self {
            Self {
                tables: HashMap::new(),
            }
        }

        /// Registers (or replaces) table `name`, turning each value vector into a row by
        /// zipping it with `cols`. Because `zip` stops at the shorter side, surplus values
        /// or surplus column names are dropped.
        fn with_table(mut self, name: &str, cols: &[&str], rows: Vec<Vec<SochValue>>) -> Self {
            let rows = rows
                .into_iter()
                .map(|vals| {
                    cols.iter()
                        .map(|c| c.to_string())
                        .zip(vals.into_iter())
                        .collect::<HashMap<_, _>>()
                })
                .collect();
            self.tables.insert(name.to_string(), rows);
            self
        }
    }
1235
    /// Minimal `SqlConnection` for aggregate tests.
    ///
    /// Reads return every stored row of the table. `select` ignores projection, WHERE,
    /// ORDER BY, LIMIT and OFFSET, leaving that work to the bridge. Writes and DDL are
    /// accepted but have no effect.
    impl SqlConnection for DataConn {
        fn select(
            &self,
            table: &str,
            _: &[String],
            _where_clause: Option<&Expr>,
            _: &[OrderByItem],
            _: Option<usize>,
            _: Option<usize>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            // Unknown tables yield no rows rather than an error.
            let rows = self.tables.get(table).cloned().unwrap_or_default();
            Ok(ExecutionResult::Rows {
                columns: vec![],
                rows,
            })
        }
        fn insert(
            &mut self,
            _: &str,
            _: Option<&[String]>,
            _: &[Vec<Expr>],
            _: Option<&OnConflict>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn update(
            &mut self,
            _: &str,
            _: &[Assignment],
            _: Option<&Expr>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn delete(
            &mut self,
            _: &str,
            _: Option<&Expr>,
            _: &[SochValue],
        ) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::RowsAffected(0))
        }
        fn create_table(&mut self, _: &CreateTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn drop_table(&mut self, _: &DropTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn create_index(&mut self, _: &CreateIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn drop_index(&mut self, _: &DropIndexStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn alter_table(&mut self, _: &AlterTableStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::Ok)
        }
        fn begin(&mut self, _: &BeginStmt) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }
        fn commit(&mut self) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }
        fn rollback(&mut self, _: Option<&str>) -> SqlResult<ExecutionResult> {
            Ok(ExecutionResult::TransactionOk)
        }
        fn table_exists(&self, t: &str) -> SqlResult<bool> {
            Ok(self.tables.contains_key(t))
        }
        fn index_exists(&self, _: &str) -> SqlResult<bool> {
            Ok(false)
        }
        fn scan_all(
            &self,
            table: &str,
            _: &[String],
        ) -> SqlResult<Vec<HashMap<String, SochValue>>> {
            Ok(self.tables.get(table).cloned().unwrap_or_default())
        }
        /// Evaluates a join predicate with the aggregate module's scalar evaluator.
        /// A NULL result counts as "no match"; any other non-boolean result yields `None`.
        fn eval_join_predicate(
            &self,
            expr: &Expr,
            row: &HashMap<String, SochValue>,
            params: &[SochValue],
        ) -> Option<bool> {
            match eval_scalar(expr, row, params) {
                SochValue::Bool(b) => Some(b),
                SochValue::Null => Some(false),
                _ => None,
            }
        }
    }
1331
    /// Shorthand for `SochValue::Int`.
    fn i(v: i64) -> SochValue {
        SochValue::Int(v)
    }
    /// Shorthand for `SochValue::Float`.
    fn f(v: f64) -> SochValue {
        SochValue::Float(v)
    }
    /// Shorthand for `SochValue::Text`.
    fn t(v: &str) -> SochValue {
        SochValue::Text(v.to_string())
    }
1341
    /// Bridge over a four-row table `x` whose column names follow the group-by benchmark
    /// layout (`id1`, `id3`, `v1`, `v2`, `v3`). It has two `id1` groups and two `id3` groups.
    fn bench_bridge() -> SqlBridge<DataConn> {
        let conn = DataConn::new().with_table(
            "x",
            &["id1", "id3", "v1", "v2", "v3"],
            vec![
                vec![t("id001"), t("id0000001"), i(1), i(10), f(1.0)],
                vec![t("id001"), t("id0000002"), i(2), i(20), f(2.0)],
                vec![t("id002"), t("id0000001"), i(3), i(30), f(3.0)],
                vec![t("id002"), t("id0000002"), i(4), i(40), f(4.0)],
            ],
        );
        SqlBridge::new(conn)
    }
1356
    /// Extracts the rows of a `Rows` result, panicking on any other result kind.
    fn rows_of(result: ExecutionResult) -> Vec<HashMap<String, SochValue>> {
        match result {
            ExecutionResult::Rows { rows, .. } => rows,
            other => panic!("expected rows, got {:?}", other),
        }
    }
1363
    /// Returns column `k` of `row`, panicking with the full row when the column is missing.
    fn get<'a>(row: &'a HashMap<String, SochValue>, k: &str) -> &'a SochValue {
        row.get(k)
            .unwrap_or_else(|| panic!("column '{}' missing from {:?}", k, row))
    }
1368
1369 #[test]
1370 fn groupby_sum_q1_shape() {
1371 let mut b = bench_bridge();
1373 let rows = rows_of(
1374 b.execute("SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1 ORDER BY id1")
1375 .unwrap(),
1376 );
1377 assert_eq!(rows.len(), 2);
1378 assert_eq!(get(&rows[0], "id1"), &t("id001"));
1379 assert_eq!(get(&rows[0], "v1"), &i(3));
1380 assert_eq!(get(&rows[1], "id1"), &t("id002"));
1381 assert_eq!(get(&rows[1], "v1"), &i(7));
1382 }
1383
1384 #[test]
1385 fn groupby_multi_key_mean() {
1386 let mut b = bench_bridge();
1388 let rows = rows_of(
1389 b.execute("SELECT id1, id3, avg(v1) AS m FROM x GROUP BY id1, id3 ORDER BY id1, id3")
1390 .unwrap(),
1391 );
1392 assert_eq!(rows.len(), 4);
1393 assert_eq!(get(&rows[0], "m"), &f(1.0));
1394 assert_eq!(get(&rows[3], "m"), &f(4.0));
1395 }
1396
1397 #[test]
1398 fn median_and_stddev() {
1399 let mut b = bench_bridge();
1402 let rows = rows_of(
1403 b.execute("SELECT median(v3) AS med, stddev(v3) AS sd FROM x")
1404 .unwrap(),
1405 );
1406 assert_eq!(rows.len(), 1);
1407 assert_eq!(get(&rows[0], "med"), &f(2.5));
1408 match get(&rows[0], "sd") {
1409 SochValue::Float(sd) => {
1410 assert!((sd - (5.0f64 / 3.0).sqrt()).abs() < 1e-12, "sd={}", sd)
1411 }
1412 other => panic!("expected float sd, got {:?}", other),
1413 }
1414 }
1415
1416 #[test]
1417 fn median_odd_count() {
1418 let conn =
1419 DataConn::new().with_table("t", &["v"], vec![vec![f(5.0)], vec![f(1.0)], vec![f(3.0)]]);
1420 let mut b = SqlBridge::new(conn);
1421 let rows = rows_of(b.execute("SELECT median(v) AS m FROM t").unwrap());
1422 assert_eq!(get(&rows[0], "m"), &f(3.0));
1423 }
1424
1425 #[test]
1426 fn range_expression_q9_shape() {
1427 let mut b = bench_bridge();
1429 let rows = rows_of(
1430 b.execute(
1431 "SELECT id3, max(v1) - min(v2) AS range_v1_v2 FROM x GROUP BY id3 ORDER BY id3",
1432 )
1433 .unwrap(),
1434 );
1435 assert_eq!(rows.len(), 2);
1436 assert_eq!(get(&rows[0], "range_v1_v2"), &i(-7));
1438 assert_eq!(get(&rows[1], "range_v1_v2"), &i(-16));
1440 }
1441
1442 #[test]
1443 fn count_star_vs_count_col_with_nulls() {
1444 let conn = DataConn::new().with_table(
1445 "t",
1446 &["g", "v"],
1447 vec![
1448 vec![t("a"), i(1)],
1449 vec![t("a"), SochValue::Null],
1450 vec![t("b"), i(2)],
1451 ],
1452 );
1453 let mut b = SqlBridge::new(conn);
1454 let rows = rows_of(
1455 b.execute("SELECT g, count(*) AS n, count(v) AS nv FROM t GROUP BY g ORDER BY g")
1456 .unwrap(),
1457 );
1458 assert_eq!(rows.len(), 2);
1459 assert_eq!(get(&rows[0], "n"), &i(2));
1460 assert_eq!(get(&rows[0], "nv"), &i(1));
1461 assert_eq!(get(&rows[1], "n"), &i(1));
1462 assert_eq!(get(&rows[1], "nv"), &i(1));
1463 }
1464
1465 #[test]
1466 fn count_distinct() {
1467 let mut b = bench_bridge();
1469 let rows = rows_of(
1470 b.execute("SELECT id3, count(DISTINCT id1) AS u FROM x GROUP BY id3 ORDER BY id3")
1471 .unwrap(),
1472 );
1473 assert_eq!(rows.len(), 2);
1474 assert_eq!(get(&rows[0], "u"), &i(2));
1475 assert_eq!(get(&rows[1], "u"), &i(2));
1476 }
1477
1478 #[test]
1479 fn having_filters_groups() {
1480 let mut b = bench_bridge();
1481 let rows = rows_of(
1482 b.execute("SELECT id1, sum(v1) AS s FROM x GROUP BY id1 HAVING sum(v1) > 5")
1483 .unwrap(),
1484 );
1485 assert_eq!(rows.len(), 1);
1486 assert_eq!(get(&rows[0], "id1"), &t("id002"));
1487 assert_eq!(get(&rows[0], "s"), &i(7));
1488 }
1489
1490 #[test]
1491 fn order_by_aggregate_desc_with_limit() {
1492 let mut b = bench_bridge();
1493 let rows = rows_of(
1494 b.execute("SELECT id1, sum(v1) AS s FROM x GROUP BY id1 ORDER BY s DESC LIMIT 1")
1495 .unwrap(),
1496 );
1497 assert_eq!(rows.len(), 1);
1498 assert_eq!(get(&rows[0], "id1"), &t("id002"));
1499 }
1500
1501 #[test]
1502 fn ungrouped_aggregate_over_empty_table() {
1503 let conn = DataConn::new().with_table("e", &["v"], vec![]);
1504 let mut b = SqlBridge::new(conn);
1505 let rows = rows_of(
1506 b.execute("SELECT count(*) AS n, sum(v) AS s FROM e")
1507 .unwrap(),
1508 );
1509 assert_eq!(rows.len(), 1, "ungrouped agg over empty input = one row");
1510 assert_eq!(get(&rows[0], "n"), &i(0));
1511 assert_eq!(get(&rows[0], "s"), &SochValue::Null);
1512 }
1513
1514 #[test]
1515 fn grouped_aggregate_over_empty_table_yields_no_rows() {
1516 let conn = DataConn::new().with_table("e", &["g", "v"], vec![]);
1517 let mut b = SqlBridge::new(conn);
1518 let rows = rows_of(
1519 b.execute("SELECT g, sum(v) AS s FROM e GROUP BY g")
1520 .unwrap(),
1521 );
1522 assert!(rows.is_empty());
1523 }
1524
1525 #[test]
1526 fn sum_overflow_promotes_to_float() {
1527 let conn =
1528 DataConn::new().with_table("t", &["v"], vec![vec![i(i64::MAX)], vec![i(i64::MAX)]]);
1529 let mut b = SqlBridge::new(conn);
1530 let rows = rows_of(b.execute("SELECT sum(v) AS s FROM t").unwrap());
1531 match get(&rows[0], "s") {
1532 SochValue::Float(v) => assert!(*v > 1.8e19),
1533 other => panic!("expected float after overflow, got {:?}", other),
1534 }
1535 }
1536
1537 #[test]
1538 fn aggregate_after_join() {
1539 let conn = DataConn::new()
1541 .with_table(
1542 "a",
1543 &["id", "v"],
1544 vec![
1545 vec![t("k1"), i(1)],
1546 vec![t("k1"), i(2)],
1547 vec![t("k2"), i(3)],
1548 ],
1549 )
1550 .with_table(
1551 "b",
1552 &["id", "w"],
1553 vec![vec![t("k1"), i(10)], vec![t("k2"), i(20)]],
1554 );
1555 let mut br = SqlBridge::new(conn);
1556 let rows = rows_of(
1557 br.execute(
1558 "SELECT a.id, sum(a.v) AS sv, sum(b.w) AS sw \
1559 FROM a JOIN b ON a.id = b.id GROUP BY a.id ORDER BY a.id",
1560 )
1561 .unwrap(),
1562 );
1563 assert_eq!(rows.len(), 2);
1564 assert_eq!(get(&rows[0], "sv"), &i(3));
1565 assert_eq!(get(&rows[0], "sw"), &i(20)); assert_eq!(get(&rows[1], "sv"), &i(3));
1567 assert_eq!(get(&rows[1], "sw"), &i(20));
1568 }
1569
1570 #[test]
1571 fn lowercase_function_names_parse() {
1572 let mut b = bench_bridge();
1574 assert!(b.execute("SELECT id1, sum(v1) FROM x GROUP BY id1").is_ok());
1575 assert!(b.execute("SELECT median(v3) FROM x").is_ok());
1576 assert!(b.execute("SELECT stddev(v3) FROM x").is_ok());
1577 }
1578
1579 #[test]
1580 fn parallel_path_matches_reference_computation() {
1581 let n: usize = 150_000;
1585 let groups = 7usize;
1586 let mut data: Vec<Vec<SochValue>> = Vec::with_capacity(n);
1587 for idx in 0..n {
1588 data.push(vec![
1589 t(&format!("g{}", idx % groups)),
1590 f((idx * 31 % 1000) as f64 / 4.0),
1591 ]);
1592 }
1593 let mut per_group: Vec<Vec<f64>> = vec![Vec::new(); groups];
1595 for idx in 0..n {
1596 per_group[idx % groups].push((idx * 31 % 1000) as f64 / 4.0);
1597 }
1598
1599 let conn = DataConn::new().with_table("big", &["g", "v"], data);
1600 let mut b = SqlBridge::new(conn);
1601 let rows = rows_of(
1602 b.execute(
1603 "SELECT g, count(*) AS n, sum(v) AS s, avg(v) AS m, \
1604 median(v) AS med, stddev(v) AS sd FROM big GROUP BY g ORDER BY g",
1605 )
1606 .unwrap(),
1607 );
1608 assert_eq!(rows.len(), groups);
1609
1610 for (gi, row) in rows.iter().enumerate() {
1611 let vals = &per_group[gi];
1612 let cnt = vals.len() as f64;
1613 let sum: f64 = vals.iter().sum();
1614 let mean = sum / cnt;
1615 let var = vals.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / (cnt - 1.0);
1616 let mut sorted = vals.clone();
1617 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1618 let med = if sorted.len() % 2 == 1 {
1619 sorted[sorted.len() / 2]
1620 } else {
1621 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
1622 };
1623
1624 assert_eq!(get(row, "g"), &t(&format!("g{}", gi)));
1625 assert_eq!(get(row, "n"), &i(vals.len() as i64));
1626 match get(row, "s") {
1627 SochValue::Float(v) => assert!((v - sum).abs() < 1e-6, "sum"),
1628 other => panic!("sum type {:?}", other),
1629 }
1630 match get(row, "m") {
1631 SochValue::Float(v) => assert!((v - mean).abs() < 1e-9, "mean"),
1632 other => panic!("mean type {:?}", other),
1633 }
1634 match get(row, "med") {
1635 SochValue::Float(v) => assert!((v - med).abs() < 1e-9, "median"),
1636 other => panic!("median type {:?}", other),
1637 }
1638 match get(row, "sd") {
1639 SochValue::Float(v) => {
1640 assert!((v - var.sqrt()).abs() < 1e-9, "sd {} vs {}", v, var.sqrt())
1641 }
1642 other => panic!("sd type {:?}", other),
1643 }
1644 }
1645 }
1646
1647 #[test]
1648 fn unaliased_aggregate_column_name_is_canonical() {
1649 let mut b = bench_bridge();
1650 let result = b
1651 .execute("SELECT id1, sum(v1) FROM x GROUP BY id1")
1652 .unwrap();
1653 let cols = result.columns().unwrap().clone();
1654 assert!(cols.contains(&"sum(v1)".to_string()), "cols={:?}", cols);
1655 }
1656}