vibesql_executor/select/columnar/aggregate/
mod.rs

1//! Columnar aggregation - high-performance aggregate computation
2//!
3//! This module provides efficient aggregate operations on columnar data,
4//! organized into several sub-modules:
5//!
6//! - [`functions`] - Core aggregate function implementations (SUM, COUNT, AVG, MIN, MAX)
7//! - [`expression`] - Expression aggregates (e.g., SUM(a * b))
8//! - [`group_by`] - Hash-based GROUP BY aggregation
9//!
10//! The public API maintains backward compatibility while the implementation
11//! is split across focused modules for better maintainability.
12
13mod expression;
14mod functions;
15mod group_by;
16
17// Re-export public types and functions to maintain API compatibility
18pub use expression::{
19    evaluate_expression_to_column, evaluate_expression_with_cached_column, extract_aggregates,
20};
21pub use group_by::columnar_group_by;
22// SIMD-accelerated GROUP BY for ColumnarBatch - used in native columnar execution path
23pub use group_by::columnar_group_by_batch;
24use vibesql_ast::Expression;
25use vibesql_storage::Row;
26use vibesql_types::SqlValue;
27
28use super::{batch::ColumnarBatch, scan::ColumnarScan};
29use crate::{errors::ExecutorError, schema::CombinedSchema};
30
/// The aggregate function applied to a column or expression.
///
/// `COUNT(*)` is expressed as `AggregateOp::Count` paired with
/// `AggregateSource::CountStar`; the remaining variants map one-to-one onto the
/// SQL aggregates of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateOp {
    /// `SUM(...)`
    Sum,
    /// `COUNT(...)`, including `COUNT(*)`
    Count,
    /// `AVG(...)`
    Avg,
    /// `MIN(...)`
    Min,
    /// `MAX(...)`
    Max,
}
40
41/// Source of data for an aggregate - either a simple column or an expression
42#[derive(Debug, Clone, PartialEq)]
43pub enum AggregateSource {
44    /// Simple column reference (fast path) - just a column index
45    Column(usize),
46    /// Complex expression that needs evaluation (e.g., a * b)
47    Expression(Expression),
48    /// COUNT(*) - count all rows (not tied to any specific column)
49    CountStar,
50}
51
52/// A complete aggregate specification
53#[derive(Debug, Clone)]
54pub struct AggregateSpec {
55    pub op: AggregateOp,
56    pub source: AggregateSource,
57}
58
59/// Compute an aggregate over a column with optional filtering
60///
61/// This is the core columnar aggregation function that processes
62/// columns directly without materializing Row objects.
63///
64/// Automatically detects when SIMD optimization is available (Int64/Float64 columns)
65/// and falls back to scalar implementation for other types.
66///
67/// # Arguments
68///
69/// * `scan` - Columnar scan over the data
70/// * `column_idx` - Index of the column to aggregate
71/// * `op` - Aggregate operation (SUM, COUNT, etc.)
72/// * `filter_bitmap` - Optional bitmap of which rows to include
73///
74/// # Returns
75///
76/// The aggregated SqlValue
77pub fn compute_columnar_aggregate(
78    scan: &ColumnarScan,
79    column_idx: usize,
80    op: AggregateOp,
81    filter_bitmap: Option<&[bool]>,
82) -> Result<SqlValue, ExecutorError> {
83    // Try SIMD path for numeric columns (5-10x speedup via LLVM auto-vectorization)
84    {
85        use super::simd_aggregate::{
86            can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64,
87        };
88
89        // Detect if column is SIMD-compatible
90        if let Some(is_integer) = can_use_simd_for_column(scan, column_idx) {
91            // Use SIMD implementation for Int64/Float64 columns
92            return if is_integer {
93                simd_aggregate_i64(scan, column_idx, op, filter_bitmap)
94            } else {
95                simd_aggregate_f64(scan, column_idx, op, filter_bitmap)
96            };
97        }
98        // Fall through to scalar path for non-SIMD types
99    }
100
101    // Scalar fallback path (always available, used for String, Date, etc.)
102    functions::compute_columnar_aggregate_impl(scan, column_idx, op, filter_bitmap)
103}
104
105/// Compute multiple aggregates in a single pass over the data
106///
107/// This is more efficient than computing each aggregate separately
108/// as it only scans the data once.
109pub fn compute_multiple_aggregates(
110    rows: &[Row],
111    aggregates: &[AggregateSpec],
112    filter_bitmap: Option<&[bool]>,
113    schema: Option<&CombinedSchema>,
114) -> Result<Vec<SqlValue>, ExecutorError> {
115    let scan = ColumnarScan::new(rows);
116    let mut results = Vec::with_capacity(aggregates.len());
117
118    for spec in aggregates {
119        let result = match &spec.source {
120            // Fast path: direct column aggregation
121            AggregateSource::Column(column_idx) => {
122                compute_columnar_aggregate(&scan, *column_idx, spec.op, filter_bitmap)?
123            }
124            // Expression path: evaluate expression for each row, then aggregate
125            AggregateSource::Expression(expr) => {
126                let schema = schema.ok_or_else(|| {
127                    ExecutorError::UnsupportedExpression(
128                        "Schema required for expression aggregates".to_string(),
129                    )
130                })?;
131                expression::compute_expression_aggregate(
132                    rows,
133                    expr,
134                    spec.op,
135                    filter_bitmap,
136                    schema,
137                )?
138            }
139            // COUNT(*) path: count all rows
140            AggregateSource::CountStar => functions::compute_count(&scan, filter_bitmap)?,
141        };
142        results.push(result);
143    }
144
145    Ok(results)
146}
147
148/// Compute aggregates directly from a ColumnarBatch (no row conversion)
149///
150/// This is a high-performance path that eliminates the overhead of converting
151/// ColumnarBatch back to rows before aggregation. It works directly on the
152/// typed column arrays for maximum efficiency.
153///
154/// # Arguments
155///
156/// * `batch` - The ColumnarBatch to aggregate over (typically filtered)
157/// * `aggregates` - List of aggregate specifications to compute
158/// * `schema` - Optional schema for expression aggregates
159///
160/// # Performance
161///
162/// This function provides 20-30% speedup over the row-based path by:
163/// - Avoiding `batch.to_rows()` conversion overhead (~10-15ms for large batches)
164/// - Working directly on typed arrays (Int64, Float64) with SIMD operations
165/// - Reducing memory allocations
166///
167/// # Returns
168///
169/// Vector of aggregate results in the same order as the input aggregates
170pub fn compute_aggregates_from_batch(
171    batch: &ColumnarBatch,
172    aggregates: &[AggregateSpec],
173    schema: Option<&CombinedSchema>,
174) -> Result<Vec<SqlValue>, ExecutorError> {
175    // Handle empty batch
176    if batch.row_count() == 0 {
177        return Ok(aggregates
178            .iter()
179            .map(|spec| match spec.op {
180                AggregateOp::Count => SqlValue::Integer(0),
181                _ => SqlValue::Null,
182            })
183            .collect());
184    }
185
186    let mut results = Vec::with_capacity(aggregates.len());
187
188    for spec in aggregates {
189        let result = match &spec.source {
190            // Fast path: direct batch aggregation (no row conversion)
191            AggregateSource::Column(column_idx) => {
192                functions::compute_batch_aggregate(batch, *column_idx, spec.op)?
193            }
194            // Expression path: SIMD-accelerated evaluation directly on batch columns
195            // This eliminates the ~10-15ms overhead of batch.to_rows() for large batches
196            AggregateSource::Expression(expr) => {
197                let schema = schema.ok_or_else(|| {
198                    ExecutorError::UnsupportedExpression(
199                        "Schema required for expression aggregates".to_string(),
200                    )
201                })?;
202                // Try batch-native expression evaluation first
203                // Falls back to row-based evaluation if column types aren't supported
204                match expression::compute_batch_expression_aggregate(batch, expr, spec.op, schema) {
205                    Ok(value) => value,
206                    Err(ExecutorError::UnsupportedExpression(_)) => {
207                        // Fall back to row-based for unsupported column types (Mixed, Date, etc.)
208                        let rows = batch.to_rows()?;
209                        expression::compute_expression_aggregate(
210                            &rows, expr, spec.op, None, schema,
211                        )?
212                    }
213                    Err(other) => return Err(other),
214                }
215            }
216            // COUNT(*) path: just count rows in batch
217            AggregateSource::CountStar => SqlValue::Integer(batch.row_count() as i64),
218        };
219        results.push(result);
220    }
221
222    Ok(results)
223}
224
#[cfg(test)]
mod tests {
    use vibesql_catalog::{ColumnSchema, TableSchema};
    use vibesql_types::DataType;

    use super::*;

    /// Three `(Integer, Double)` rows: (10, 1.5), (20, 2.5), (30, 3.5).
    fn make_test_rows() -> Vec<Row> {
        [(10, 1.5), (20, 2.5), (30, 3.5)]
            .iter()
            .map(|&(int, double)| Row::new(vec![SqlValue::Integer(int), SqlValue::Double(double)]))
            .collect()
    }

    /// `(status, amount)` rows: A=100, B=200, A=150, B=50.
    fn status_amount_rows() -> Vec<Row> {
        vec![
            Row::new(vec![varchar("A"), SqlValue::Double(100.0)]),
            Row::new(vec![varchar("B"), SqlValue::Double(200.0)]),
            Row::new(vec![varchar("A"), SqlValue::Double(150.0)]),
            Row::new(vec![varchar("B"), SqlValue::Double(50.0)]),
        ]
    }

    fn varchar(text: &'static str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(text))
    }

    /// Asserts that `value` is a `Double` within 0.001 of `expected`.
    fn assert_double(value: Option<&SqlValue>, expected: f64) {
        assert!(
            matches!(value, Some(&SqlValue::Double(actual)) if (actual - expected).abs() < 0.001),
            "expected Double({}), got {:?}",
            expected,
            value
        );
    }

    /// Sorts result rows by their first (group key) column for deterministic checks.
    fn sorted_by_first_column(mut rows: Vec<Row>) -> Vec<Row> {
        rows.sort_by(|left, right| {
            left.get(0).unwrap().partial_cmp(right.get(0).unwrap()).unwrap()
        });
        rows
    }

    /// Builds a single-table schema named "test" with non-nullable columns.
    fn schema_with_columns(columns: Vec<(&'static str, DataType)>) -> CombinedSchema {
        let column_schemas = columns
            .into_iter()
            .map(|(name, data_type)| ColumnSchema::new(name.to_string(), data_type, false))
            .collect::<Vec<_>>();
        let table = TableSchema::new("test".to_string(), column_schemas);
        CombinedSchema::from_table("test".to_string(), table)
    }

    fn column(name: &'static str) -> Expression {
        Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(name, false))
    }

    fn aggregate_call(name: &'static str, distinct: bool, arg: Expression) -> Expression {
        Expression::AggregateFunction {
            name: vibesql_ast::FunctionIdentifier::new(name),
            distinct,
            args: vec![arg],
            order_by: None,
            filter: None,
        }
    }

    #[test]
    fn test_sum_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        assert_eq!(functions::compute_sum(&scan, 0, None).unwrap(), SqlValue::Integer(60));

        let double_sum = functions::compute_sum(&scan, 1, None).unwrap();
        assert_double(Some(&double_sum), 7.5);
    }

    #[test]
    fn test_count_aggregate() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);

        assert_eq!(functions::compute_count(&scan, None).unwrap(), SqlValue::Integer(3));
    }

    #[test]
    fn test_sum_with_filter() {
        let rows = make_test_rows();
        let scan = ColumnarScan::new(&rows);
        // Keep the first and last rows: 10 + 30.
        let keep = vec![true, false, true];

        assert_eq!(functions::compute_sum(&scan, 0, Some(&keep)).unwrap(), SqlValue::Integer(40));
    }

    #[test]
    fn test_multiple_aggregates() {
        let rows = make_test_rows();
        let specs = vec![
            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
            AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
        ];

        let results = compute_multiple_aggregates(&rows, &specs, None, None).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], SqlValue::Integer(60));
        assert_double(results.get(1), 2.5);
    }

    #[test]
    fn test_extract_aggregates_simple() {
        let schema = schema_with_columns(vec![
            ("col1", DataType::Integer),
            ("col2", DataType::DoublePrecision),
        ]);

        // SUM(col1) -> direct aggregate on column 0
        let exprs = vec![aggregate_call("SUM", false, column("col1"))];
        let specs = extract_aggregates(&exprs, &schema).expect("SUM(col1) should be supported");
        assert_eq!(specs.len(), 1);
        assert!(matches!(specs[0].op, AggregateOp::Sum));
        assert!(matches!(specs[0].source, AggregateSource::Column(0)));

        // COUNT(*) -> Count over CountStar
        let exprs = vec![aggregate_call("COUNT", false, Expression::Wildcard)];
        let specs = extract_aggregates(&exprs, &schema).expect("COUNT(*) should be supported");
        assert_eq!(specs.len(), 1);
        assert!(matches!(specs[0].op, AggregateOp::Count));
        assert!(matches!(specs[0].source, AggregateSource::CountStar));

        // SUM(col1), AVG(col2) -> two column aggregates, in order
        let exprs = vec![
            aggregate_call("SUM", false, column("col1")),
            aggregate_call("AVG", false, column("col2")),
        ];
        let specs = extract_aggregates(&exprs, &schema).expect("SUM + AVG should be supported");
        assert_eq!(specs.len(), 2);
        assert!(matches!(specs[0].op, AggregateOp::Sum));
        assert!(matches!(specs[0].source, AggregateSource::Column(0)));
        assert!(matches!(specs[1].op, AggregateOp::Avg));
        assert!(matches!(specs[1].source, AggregateSource::Column(1)));
    }

    #[test]
    fn test_extract_aggregates_unsupported() {
        let schema = schema_with_columns(vec![("col1", DataType::Integer)]);

        // DISTINCT aggregates are rejected
        let distinct_sum = vec![aggregate_call("SUM", true, column("col1"))];
        assert!(extract_aggregates(&distinct_sum, &schema).is_none());

        // A bare column is not an aggregate at all
        let bare_column = vec![column("col1")];
        assert!(extract_aggregates(&bare_column, &schema).is_none());

        // A subquery inside an aggregate argument is not supported
        let empty_select = vibesql_ast::SelectStmt {
            with_clause: None,
            distinct: false,
            select_list: vec![],
            into_table: None,
            into_variables: None,
            from: None,
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            offset: None,
            set_operation: None,
            values: None,
        };
        let subquery_sum = vec![aggregate_call(
            "SUM",
            false,
            Expression::ScalarSubquery(Box::new(empty_select)),
        )];
        assert!(extract_aggregates(&subquery_sum, &schema).is_none());
    }

    #[test]
    fn test_extract_aggregates_with_expression() {
        let schema = schema_with_columns(vec![
            ("price", DataType::DoublePrecision),
            ("discount", DataType::DoublePrecision),
        ]);

        // SUM(price * discount) -> expression aggregate
        let product = Expression::BinaryOp {
            left: Box::new(column("price")),
            op: vibesql_ast::BinaryOperator::Multiply,
            right: Box::new(column("discount")),
        };
        let exprs = vec![aggregate_call("SUM", false, product)];

        let specs = extract_aggregates(&exprs, &schema).expect("SUM(a * b) should be supported");
        assert_eq!(specs.len(), 1);
        assert!(matches!(specs[0].op, AggregateOp::Sum));
        assert!(matches!(specs[0].source, AggregateSource::Expression(_)));
    }

    // GROUP BY tests

    #[test]
    fn test_columnar_group_by_simple() {
        // SELECT status, SUM(amount) FROM test GROUP BY status
        let rows = status_amount_rows();
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
        assert_eq!(groups.len(), 2);
        let groups = sorted_by_first_column(groups);

        // A: 100 + 150
        assert_eq!(groups[0].get(0), Some(&varchar("A")));
        assert_double(groups[0].get(1), 250.0);

        // B: 200 + 50
        assert_eq!(groups[1].get(0), Some(&varchar("B")));
        assert_double(groups[1].get(1), 250.0);
    }

    #[test]
    fn test_columnar_group_by_multiple_group_keys() {
        // SELECT status, category, COUNT(*) FROM test GROUP BY status, category
        let rows: Vec<Row> = [("A", 1), ("A", 2), ("B", 1), ("A", 1), ("B", 2)]
            .iter()
            .map(|&(status, category)| Row::new(vec![varchar(status), SqlValue::Integer(category)]))
            .collect();
        let group_cols = vec![0, 1];
        let agg_cols = vec![(0, AggregateOp::Count)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Four distinct keys: (A,1), (A,2), (B,1), (B,2)
        assert_eq!(groups.len(), 4);

        for group in &groups {
            let status = group.get(0).unwrap();
            let category = group.get(1).unwrap();
            let expected = match (status, category) {
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s.as_str() == "A" => 2,
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s.as_str() == "A" => 1,
                (SqlValue::Varchar(s), SqlValue::Integer(1)) if s.as_str() == "B" => 1,
                (SqlValue::Varchar(s), SqlValue::Integer(2)) if s.as_str() == "B" => 1,
                _ => panic!("Unexpected group key: {:?}, {:?}", status, category),
            };
            assert_eq!(group.get(2), Some(&SqlValue::Integer(expected)));
        }
    }

    #[test]
    fn test_columnar_group_by_multiple_aggregates() {
        // SELECT category, SUM(price), AVG(quantity), COUNT(*) FROM test GROUP BY category
        let rows: Vec<Row> = [(1, 100.0, 10), (2, 200.0, 20), (1, 150.0, 15)]
            .iter()
            .map(|&(category, price, quantity)| {
                Row::new(vec![
                    SqlValue::Integer(category),
                    SqlValue::Double(price),
                    SqlValue::Integer(quantity),
                ])
            })
            .collect();
        let group_cols = vec![0];
        let agg_cols = vec![
            (1, AggregateOp::Sum),   // SUM(price)
            (2, AggregateOp::Avg),   // AVG(quantity)
            (0, AggregateOp::Count), // COUNT(*)
        ];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();
        assert_eq!(groups.len(), 2);
        let groups = sorted_by_first_column(groups);

        // Category 1: SUM=250, AVG=12.5, COUNT=2
        assert_eq!(groups[0].get(0), Some(&SqlValue::Integer(1)));
        assert_double(groups[0].get(1), 250.0);
        assert_double(groups[0].get(2), 12.5);
        assert_eq!(groups[0].get(3), Some(&SqlValue::Integer(2)));

        // Category 2: SUM=200, AVG=20.0, COUNT=1
        assert_eq!(groups[1].get(0), Some(&SqlValue::Integer(2)));
        assert_double(groups[1].get(1), 200.0);
        assert_double(groups[1].get(2), 20.0);
        assert_eq!(groups[1].get(3), Some(&SqlValue::Integer(1)));
    }

    #[test]
    fn test_columnar_group_by_with_filter() {
        // SELECT status, SUM(amount) FROM test WHERE amount > 100 GROUP BY status
        let rows = status_amount_rows();
        // Only rows 1 (B, 200) and 2 (A, 150) satisfy amount > 100.
        let filter = vec![false, true, true, false];
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, Some(&filter)).unwrap();
        assert_eq!(groups.len(), 2);
        let groups = sorted_by_first_column(groups);

        assert_eq!(groups[0].get(0), Some(&varchar("A")));
        assert_double(groups[0].get(1), 150.0);

        assert_eq!(groups[1].get(0), Some(&varchar("B")));
        assert_double(groups[1].get(1), 200.0);
    }

    #[test]
    fn test_columnar_group_by_empty_input() {
        let rows: Vec<Row> = Vec::new();
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // No input rows -> no groups
        assert_eq!(groups.len(), 0);
    }

    #[test]
    fn test_columnar_group_by_null_in_group_key() {
        // NULL group keys must form their own group.
        let rows = vec![
            Row::new(vec![varchar("A"), SqlValue::Double(100.0)]),
            Row::new(vec![SqlValue::Null, SqlValue::Double(200.0)]),
            Row::new(vec![varchar("A"), SqlValue::Double(150.0)]),
            Row::new(vec![SqlValue::Null, SqlValue::Double(50.0)]),
        ];
        let group_cols = vec![0];
        let agg_cols = vec![(1, AggregateOp::Sum)];

        let groups = columnar_group_by(&rows, &group_cols, &agg_cols, None).unwrap();

        // Two groups: "A" and NULL
        assert_eq!(groups.len(), 2);

        let a_group = groups
            .iter()
            .find(|row| matches!(row.get(0), Some(SqlValue::Varchar(s)) if s.as_str() == "A"))
            .expect("missing group for \"A\"");
        let null_group = groups
            .iter()
            .find(|row| matches!(row.get(0), Some(SqlValue::Null)))
            .expect("missing NULL group");

        // A: 100 + 150; NULL: 200 + 50
        assert_double(a_group.get(1), 250.0);
        assert_double(null_group.get(1), 250.0);
    }
}