vibesql_executor/select/columnar/aggregate/
mod.rs

//! Columnar aggregation - high-performance aggregate computation
//!
//! This module provides efficient aggregate operations on columnar data,
//! organized into several sub-modules:
//!
//! - [`functions`] - Core aggregate function implementations (SUM, COUNT, AVG, MIN, MAX)
//! - [`expression`] - Expression aggregates (e.g., SUM(a * b))
//! - [`group_by`] - Hash-based GROUP BY aggregation
//!
//! The entry points defined directly in this module are
//! [`compute_columnar_aggregate`] (one column of a `ColumnarScan`),
//! [`compute_multiple_aggregates`] (several aggregates over `Row`s) and
//! [`compute_aggregates_from_batch`] (several aggregates over a `ColumnarBatch`).
//!
//! The public API maintains backward compatibility while the implementation
//! is split across focused modules for better maintainability.
12
13mod expression;
14mod functions;
15mod group_by;
16
17use crate::errors::ExecutorError;
18use crate::schema::CombinedSchema;
19use vibesql_ast::Expression;
20use vibesql_storage::Row;
21use vibesql_types::SqlValue;
22
23use super::batch::ColumnarBatch;
24use super::scan::ColumnarScan;
25
26// Re-export public types and functions to maintain API compatibility
27pub use expression::evaluate_expression_to_column;
28pub use expression::evaluate_expression_with_cached_column;
29pub use expression::extract_aggregates;
30pub use group_by::columnar_group_by;
31// SIMD-accelerated GROUP BY for ColumnarBatch - used in native columnar execution path
32pub use group_by::columnar_group_by_batch;
33
/// The aggregate function applied by the columnar engine.
///
/// Each variant corresponds to the SQL aggregate of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// `SUM(x)`
    Sum,
    /// `COUNT(x)`, or `COUNT(*)` when paired with [`AggregateSource::CountStar`]
    Count,
    /// `AVG(x)`
    Avg,
    /// `MIN(x)`
    Min,
    /// `MAX(x)`
    Max,
}
43
44/// Source of data for an aggregate - either a simple column or an expression
45#[derive(Debug, Clone)]
46pub enum AggregateSource {
47    /// Simple column reference (fast path) - just a column index
48    Column(usize),
49    /// Complex expression that needs evaluation (e.g., a * b)
50    Expression(Expression),
51    /// COUNT(*) - count all rows (not tied to any specific column)
52    CountStar,
53}
54
55/// A complete aggregate specification
56#[derive(Debug, Clone)]
57pub struct AggregateSpec {
58    pub op: AggregateOp,
59    pub source: AggregateSource,
60}
61
62/// Compute an aggregate over a column with optional filtering
63///
64/// This is the core columnar aggregation function that processes
65/// columns directly without materializing Row objects.
66///
67/// Automatically detects when SIMD optimization is available (Int64/Float64 columns)
68/// and falls back to scalar implementation for other types.
69///
70/// # Arguments
71///
72/// * `scan` - Columnar scan over the data
73/// * `column_idx` - Index of the column to aggregate
74/// * `op` - Aggregate operation (SUM, COUNT, etc.)
75/// * `filter_bitmap` - Optional bitmap of which rows to include
76///
77/// # Returns
78///
79/// The aggregated SqlValue
80pub fn compute_columnar_aggregate(
81    scan: &ColumnarScan,
82    column_idx: usize,
83    op: AggregateOp,
84    filter_bitmap: Option<&[bool]>,
85) -> Result<SqlValue, ExecutorError> {
86    // Try SIMD path for numeric columns (5-10x speedup via LLVM auto-vectorization)
87    {
88        use super::simd_aggregate::{
89            can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64,
90        };
91
92        // Detect if column is SIMD-compatible
93        if let Some(is_integer) = can_use_simd_for_column(scan, column_idx) {
94            // Use SIMD implementation for Int64/Float64 columns
95            return if is_integer {
96                simd_aggregate_i64(scan, column_idx, op, filter_bitmap)
97            } else {
98                simd_aggregate_f64(scan, column_idx, op, filter_bitmap)
99            };
100        }
101        // Fall through to scalar path for non-SIMD types
102    }
103
104    // Scalar fallback path (always available, used for String, Date, etc.)
105    functions::compute_columnar_aggregate_impl(scan, column_idx, op, filter_bitmap)
106}
107
108/// Compute multiple aggregates in a single pass over the data
109///
110/// This is more efficient than computing each aggregate separately
111/// as it only scans the data once.
112pub fn compute_multiple_aggregates(
113    rows: &[Row],
114    aggregates: &[AggregateSpec],
115    filter_bitmap: Option<&[bool]>,
116    schema: Option<&CombinedSchema>,
117) -> Result<Vec<SqlValue>, ExecutorError> {
118    let scan = ColumnarScan::new(rows);
119    let mut results = Vec::with_capacity(aggregates.len());
120
121    for spec in aggregates {
122        let result = match &spec.source {
123            // Fast path: direct column aggregation
124            AggregateSource::Column(column_idx) => {
125                compute_columnar_aggregate(&scan, *column_idx, spec.op, filter_bitmap)?
126            }
127            // Expression path: evaluate expression for each row, then aggregate
128            AggregateSource::Expression(expr) => {
129                let schema = schema.ok_or_else(|| {
130                    ExecutorError::UnsupportedExpression(
131                        "Schema required for expression aggregates".to_string(),
132                    )
133                })?;
134                expression::compute_expression_aggregate(
135                    rows,
136                    expr,
137                    spec.op,
138                    filter_bitmap,
139                    schema,
140                )?
141            }
142            // COUNT(*) path: count all rows
143            AggregateSource::CountStar => functions::compute_count(&scan, filter_bitmap)?,
144        };
145        results.push(result);
146    }
147
148    Ok(results)
149}
150
151/// Compute aggregates directly from a ColumnarBatch (no row conversion)
152///
153/// This is a high-performance path that eliminates the overhead of converting
154/// ColumnarBatch back to rows before aggregation. It works directly on the
155/// typed column arrays for maximum efficiency.
156///
157/// # Arguments
158///
159/// * `batch` - The ColumnarBatch to aggregate over (typically filtered)
160/// * `aggregates` - List of aggregate specifications to compute
161/// * `schema` - Optional schema for expression aggregates
162///
163/// # Performance
164///
165/// This function provides 20-30% speedup over the row-based path by:
166/// - Avoiding `batch.to_rows()` conversion overhead (~10-15ms for large batches)
167/// - Working directly on typed arrays (Int64, Float64) with SIMD operations
168/// - Reducing memory allocations
169///
170/// # Returns
171///
172/// Vector of aggregate results in the same order as the input aggregates
173pub fn compute_aggregates_from_batch(
174    batch: &ColumnarBatch,
175    aggregates: &[AggregateSpec],
176    schema: Option<&CombinedSchema>,
177) -> Result<Vec<SqlValue>, ExecutorError> {
178    // Handle empty batch
179    if batch.row_count() == 0 {
180        return Ok(aggregates
181            .iter()
182            .map(|spec| match spec.op {
183                AggregateOp::Count => SqlValue::Integer(0),
184                _ => SqlValue::Null,
185            })
186            .collect());
187    }
188
189    let mut results = Vec::with_capacity(aggregates.len());
190
191    for spec in aggregates {
192        let result = match &spec.source {
193            // Fast path: direct batch aggregation (no row conversion)
194            AggregateSource::Column(column_idx) => {
195                functions::compute_batch_aggregate(batch, *column_idx, spec.op)?
196            }
197            // Expression path: SIMD-accelerated evaluation directly on batch columns
198            // This eliminates the ~10-15ms overhead of batch.to_rows() for large batches
199            AggregateSource::Expression(expr) => {
200                let schema = schema.ok_or_else(|| {
201                    ExecutorError::UnsupportedExpression(
202                        "Schema required for expression aggregates".to_string(),
203                    )
204                })?;
205                // Try batch-native expression evaluation first
206                // Falls back to row-based evaluation if column types aren't supported
207                match expression::compute_batch_expression_aggregate(batch, expr, spec.op, schema) {
208                    Ok(value) => value,
209                    Err(ExecutorError::UnsupportedExpression(_)) => {
210                        // Fall back to row-based for unsupported column types (Mixed, Date, etc.)
211                        let rows = batch.to_rows()?;
212                        expression::compute_expression_aggregate(
213                            &rows, expr, spec.op, None, schema,
214                        )?
215                    }
216                    Err(other) => return Err(other),
217                }
218            }
219            // COUNT(*) path: just count rows in batch
220            AggregateSource::CountStar => SqlValue::Integer(batch.row_count() as i64),
221        };
222        results.push(result);
223    }
224
225    Ok(results)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    // ------------------------------------------------------------------
    // Fixtures and helpers
    // ------------------------------------------------------------------

    /// Three `(Integer, Double)` rows: (10, 1.5), (20, 2.5), (30, 3.5).
    fn make_test_rows() -> Vec<Row> {
        vec![
            Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
            Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
            Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
        ]
    }

    /// Builds a single-table schema named `test` from `(column name, type)` pairs.
    fn make_schema(columns: Vec<(&str, DataType)>) -> CombinedSchema {
        let columns = columns
            .into_iter()
            .map(|(name, data_type)| ColumnSchema::new(name.to_string(), data_type, false))
            .collect::<Vec<_>>();
        let table = TableSchema::new("test".to_string(), columns);
        CombinedSchema::from_table("test".to_string(), table)
    }

    /// An unqualified column reference.
    fn column_ref(name: &str) -> Expression {
        Expression::ColumnRef { table: None, column: name.to_string() }
    }

    /// A call of aggregate function `name` over `args`.
    fn aggregate_call(name: &str, distinct: bool, args: Vec<Expression>) -> Expression {
        Expression::AggregateFunction { name: name.to_string(), distinct, args }
    }

    /// Sorts grouped output by its first column (the group key) so assertions
    /// do not depend on the iteration order of the hash-based grouping.
    fn sort_by_group_key(rows: &mut [Row]) {
        rows.sort_by(|a, b| {
            let a_key = a.get(0).expect("row should have a group key column");
            let b_key = b.get(0).expect("row should have a group key column");
            a_key.partial_cmp(b_key).expect("group keys should be comparable")
        });
    }

    /// Asserts that `value` is a `Double` within 0.001 of `expected`.
    fn assert_double(value: Option<&SqlValue>, expected: f64) {
        match value {
            Some(SqlValue::Double(actual)) => assert!(
                (actual - expected).abs() < 0.001,
                "expected Double({}), got Double({})",
                expected,
                actual
            ),
            other => panic!("expected Double({}), got {:?}", expected, other),
        }
    }

    // ------------------------------------------------------------------
    // Scalar aggregate functions
    // ------------------------------------------------------------------

    #[test]
    fn test_sum_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        let result = functions::compute_sum(&scan, 0, None).unwrap();
        assert_eq!(result, SqlValue::Integer(60));

        let result = functions::compute_sum(&scan, 1, None).unwrap();
        assert_double(Some(&result), 7.5);
    }

    #[test]
    fn test_count_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        let result = functions::compute_count(&scan, None).unwrap();
        assert_eq!(result, SqlValue::Integer(3));
    }

    #[test]
    fn test_sum_with_filter() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);
        let filter = vec![true, false, true]; // Include rows 0 and 2

        let result = functions::compute_sum(&scan, 0, Some(&filter)).unwrap();
        assert_eq!(result, SqlValue::Integer(40));
    }

    #[test]
    fn test_multiple_aggregates() {
        let rows = make_test_rows();
        let aggregates = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
            AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
        ];

        let results = compute_multiple_aggregates(&rows, &aggregates, None, None).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], SqlValue::Integer(60));
        assert_double(Some(&results[1]), 2.5);
    }

    // ------------------------------------------------------------------
    // Aggregate extraction
    // ------------------------------------------------------------------

    #[test]
    fn test_extract_aggregates_simple() {
        let combined_schema = make_schema(vec![
            ("col1", DataType::Integer),
            ("col2", DataType::DoublePrecision),
        ]);

        // SUM(col1)
        let exprs = vec![aggregate_call("SUM", false, vec![column_ref("col1")])];
        let aggregates =
            extract_aggregates(&exprs, &combined_schema).expect("SUM(col1) should be supported");
        assert_eq!(aggregates.len(), 1);
        assert_eq!(aggregates[0].op, AggregateOp::Sum);
        assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));

        // COUNT(*)
        let exprs = vec![aggregate_call("COUNT", false, vec![Expression::Wildcard])];
        let aggregates =
            extract_aggregates(&exprs, &combined_schema).expect("COUNT(*) should be supported");
        assert_eq!(aggregates.len(), 1);
        assert_eq!(aggregates[0].op, AggregateOp::Count);
        assert!(matches!(aggregates[0].source, AggregateSource::CountStar));

        // SUM(col1), AVG(col2)
        let exprs = vec![
            aggregate_call("SUM", false, vec![column_ref("col1")]),
            aggregate_call("AVG", false, vec![column_ref("col2")]),
        ];
        let aggregates = extract_aggregates(&exprs, &combined_schema)
            .expect("SUM(col1), AVG(col2) should be supported");
        assert_eq!(aggregates.len(), 2);
        assert_eq!(aggregates[0].op, AggregateOp::Sum);
        assert!(matches!(aggregates[0].source, AggregateSource::Column(0)));
        assert_eq!(aggregates[1].op, AggregateOp::Avg);
        assert!(matches!(aggregates[1].source, AggregateSource::Column(1)));
    }

    #[test]
    fn test_extract_aggregates_unsupported() {
        let combined_schema = make_schema(vec![("col1", DataType::Integer)]);

        // DISTINCT aggregate (should return None)
        let exprs = vec![aggregate_call("SUM", true, vec![column_ref("col1")])];
        assert!(extract_aggregates(&exprs, &combined_schema).is_none());

        // Non-aggregate expression (should return None)
        let exprs = vec![column_ref("col1")];
        assert!(extract_aggregates(&exprs, &combined_schema).is_none());

        // Subquery inside an aggregate (should return None - not supported)
        let subquery = vibesql_ast::SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![],
            into_table: None,
            into_variables: None,
            from: None,
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
        };
        let exprs = vec![aggregate_call(
            "SUM",
            false,
            vec![Expression::ScalarSubquery(Box::new(subquery))],
        )];
        assert!(extract_aggregates(&exprs, &combined_schema).is_none());
    }

    #[test]
    fn test_extract_aggregates_with_expression() {
        let combined_schema = make_schema(vec![
            ("price", DataType::DoublePrecision),
            ("discount", DataType::DoublePrecision),
        ]);

        // SUM(price * discount) - simple binary operation
        let product = Expression::BinaryOp {
            left: Box::new(column_ref("price")),
            op: vibesql_ast::BinaryOperator::Multiply,
            right: Box::new(column_ref("discount")),
        };
        let exprs = vec![aggregate_call("SUM", false, vec![product])];

        let aggregates = extract_aggregates(&exprs, &combined_schema)
            .expect("SUM(price * discount) should be supported");
        assert_eq!(aggregates.len(), 1);
        assert_eq!(aggregates[0].op, AggregateOp::Sum);
        assert!(matches!(aggregates[0].source, AggregateSource::Expression(_)));
    }

    // ------------------------------------------------------------------
    // GROUP BY
    // ------------------------------------------------------------------

    #[test]
    fn test_columnar_group_by_simple() {
        // SELECT status, SUM(amount) FROM test GROUP BY status
        let rows = vec![
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(100.0)]),
            Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(200.0)]),
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(150.0)]),
            Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(50.0)]),
        ];
        let group_cols = vec![0]; // Group by status
        let agg_cols = vec![(1, AggregateOp::Sum)]; // SUM(amount)

        let mut groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Two groups: A and B
        assert_eq!(groups.len(), 2);
        sort_by_group_key(&mut groups);

        // A: 100 + 150 = 250
        assert_eq!(groups[0].get(0), Some(&SqlValue::Varchar("A".to_string())));
        assert_double(groups[0].get(1), 250.0);

        // B: 200 + 50 = 250
        assert_eq!(groups[1].get(0), Some(&SqlValue::Varchar("B".to_string())));
        assert_double(groups[1].get(1), 250.0);
    }

    #[test]
    fn test_columnar_group_by_multiple_group_keys() {
        // SELECT status, category, COUNT(status) FROM test GROUP BY status, category
        let rows = vec![
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Integer(1)]),
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Integer(2)]),
            Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Integer(1)]),
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Integer(1)]),
            Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Integer(2)]),
        ];
        let group_cols = vec![0, 1]; // Group by status, category
        // The (column, op) form has no row-count variant, so count through
        // column 0; it holds no NULLs here, so this counts every row.
        let agg_cols = vec![(0, AggregateOp::Count)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Four groups: (A,1), (A,2), (B,1), (B,2)
        assert_eq!(groups.len(), 4);

        for row in &groups {
            let status = row.get(0).unwrap();
            let category = row.get(1).unwrap();
            let count = row.get(2).unwrap();

            match (status, category) {
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "A" => {
                    assert_eq!(count, &SqlValue::Integer(2)); // Two rows with A,1
                }
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "A" => {
                    assert_eq!(count, &SqlValue::Integer(1)); // One row with A,2
                }
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s == "B" => {
                    assert_eq!(count, &SqlValue::Integer(1)); // One row with B,1
                }
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s == "B" => {
                    assert_eq!(count, &SqlValue::Integer(1)); // One row with B,2
                }
                _ => panic!("Unexpected group key: {:?}, {:?}", status, category),
            }
        }
    }

    #[test]
    fn test_columnar_group_by_multiple_aggregates() {
        // SELECT category, SUM(price), AVG(quantity), COUNT(category)
        // FROM test GROUP BY category
        let rows = vec![
            Row::new(vec![SqlValue::Integer(1), SqlValue::Double(100.0), SqlValue::Integer(10)]),
            Row::new(vec![SqlValue::Integer(2), SqlValue::Double(200.0), SqlValue::Integer(20)]),
            Row::new(vec![SqlValue::Integer(1), SqlValue::Double(150.0), SqlValue::Integer(15)]),
        ];
        let group_cols = vec![0]; // Group by category
        let agg_cols = vec![
            (1, AggregateOp::Sum),   // SUM(price)
            (2, AggregateOp::Avg),   // AVG(quantity)
            (0, AggregateOp::Count), // COUNT(category); no NULLs, so one per row
        ];

        let mut groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        assert_eq!(groups.len(), 2);
        sort_by_group_key(&mut groups);

        // Category 1: SUM=250, AVG=12.5, COUNT=2
        assert_eq!(groups[0].get(0), Some(&SqlValue::Integer(1)));
        assert_double(groups[0].get(1), 250.0);
        assert_double(groups[0].get(2), 12.5);
        assert_eq!(groups[0].get(3), Some(&SqlValue::Integer(2)));

        // Category 2: SUM=200, AVG=20.0, COUNT=1
        assert_eq!(groups[1].get(0), Some(&SqlValue::Integer(2)));
        assert_double(groups[1].get(1), 200.0);
        assert_double(groups[1].get(2), 20.0);
        assert_eq!(groups[1].get(3), Some(&SqlValue::Integer(1)));
    }

    #[test]
    fn test_columnar_group_by_with_filter() {
        // SELECT status, SUM(amount) FROM test WHERE amount > 100 GROUP BY status
        let rows = vec![
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(100.0)]),
            Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(200.0)]),
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(150.0)]),
            Row::new(vec![SqlValue::Varchar("B".to_string()), SqlValue::Double(50.0)]),
        ];
        // amount > 100 holds for rows 1 and 2 only
        let filter = vec![false, true, true, false];
        let group_cols = vec![0]; // Group by status
        let agg_cols = vec![(1, AggregateOp::Sum)]; // SUM(amount)

        let mut groups =
            columnar_group_by(&rows, &group_cols, &agg_cols, Some(&filter)).unwrap();

        // Still two groups, each with a single surviving row
        assert_eq!(groups.len(), 2);
        sort_by_group_key(&mut groups);

        // A: only row 2 (150.0) passes the filter
        assert_eq!(groups[0].get(0), Some(&SqlValue::Varchar("A".to_string())));
        assert_double(groups[0].get(1), 150.0);

        // B: only row 1 (200.0) passes the filter
        assert_eq!(groups[1].get(0), Some(&SqlValue::Varchar("B".to_string())));
        assert_double(groups[1].get(1), 200.0);
    }

    #[test]
    fn test_columnar_group_by_empty_input() {
        let rows: Vec<Row> = vec![];
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Empty input produces no groups
        assert_eq!(groups.len(), 0);
    }

    #[test]
    fn test_columnar_group_by_null_in_group_key() {
        // NULL group keys must form their own group
        let rows = vec![
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(100.0)]),
            Row::new(vec![SqlValue::Null, SqlValue::Double(200.0)]),
            Row::new(vec![SqlValue::Varchar("A".to_string()), SqlValue::Double(150.0)]),
            Row::new(vec![SqlValue::Null, SqlValue::Double(50.0)]),
        ];
        let group_cols = vec![0]; // Group by first column
        let agg_cols = vec![(1, AggregateOp::Sum)]; // SUM

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Two groups: "A" and NULL
        assert_eq!(groups.len(), 2);

        let a_group = groups
            .iter()
            .find(|r| matches!(r.get(0), Some(SqlValue::Varchar(s)) if s == "A"))
            .expect("group \"A\" should be present");
        let null_group = groups
            .iter()
            .find(|r| matches!(r.get(0), Some(SqlValue::Null)))
            .expect("NULL group should be present");

        // "A": 100 + 150 = 250
        assert_double(a_group.get(1), 250.0);

        // NULL: 200 + 50 = 250
        assert_double(null_group.get(1), 250.0);
    }
}