vibesql_executor/select/columnar/
mod.rs

1//! Columnar execution for high-performance aggregation queries
2//!
3//! This module implements column-oriented query execution that avoids
4//! materializing full Row objects during table scans, providing 8-10x speedup
5//! for aggregation-heavy workloads.
6//!
7//! ## Architecture
8//!
9//! Instead of:
10//! ```text
11//! TableScan → Row{Vec<SqlValue>} → Filter(Row) → Aggregate(Row) → Vec<Row>
12//! ```
13//!
14//! We use:
15//! ```text
16//! TableScan → ColumnRefs → Filter(native types) → Aggregate → Row
17//! ```
18//!
19//! ## Benefits
20//!
21//! - **Zero-copy**: Work with `&SqlValue` references instead of cloning
22//! - **Cache-friendly**: Access contiguous column data instead of scattered row data
23//! - **Type-specialized**: Skip SqlValue enum matching overhead for filters/aggregates
24//! - **Minimal allocations**: Only allocate result rows, not intermediate data
25//!
26//! ## Usage
27//!
28//! This path is automatically selected for simple aggregate queries that:
29//! - Have a single table scan (no JOINs)
30//! - Use simple WHERE predicates
31//! - Compute aggregates (SUM, COUNT, AVG, MIN, MAX)
32//! - Don't use window functions or complex subqueries
33
34mod aggregate;
35pub mod batch;
36mod executor;
37pub mod filter;
38mod scan;
39mod string_ops;
40
41// Auto-vectorized SIMD operations - replaces the `wide` crate dependency
42// See simd_ops.rs for documentation on why these patterns are structured this way
43pub mod simd_ops;
44
45mod simd_aggregate;
46pub mod simd_filter;
47mod simd_join;
48
49pub use aggregate::{
50    columnar_group_by, columnar_group_by_batch, compute_aggregates_from_batch,
51    compute_multiple_aggregates, evaluate_expression_to_column,
52    evaluate_expression_with_cached_column, extract_aggregates, AggregateOp, AggregateSource,
53    AggregateSpec,
54};
55pub use batch::{ColumnArray, ColumnarBatch};
56pub use executor::execute_columnar_batch;
57pub use filter::{
58    apply_columnar_filter, apply_columnar_filter_simd_streaming, create_filter_bitmap,
59    create_filter_bitmap_tree, evaluate_predicate_tree, extract_column_predicates,
60    extract_predicate_tree, ColumnPredicate, PredicateTree,
61};
62use log;
63pub use scan::ColumnarScan;
64pub use simd_aggregate::{can_use_simd_for_column, simd_aggregate_f64, simd_aggregate_i64};
65pub use simd_filter::{
66    simd_create_filter_mask, simd_create_filter_mask_auto, simd_create_filter_mask_packed,
67    simd_filter_batch, simd_filter_to_indices,
68};
69#[cfg(feature = "parallel")]
70pub use simd_filter::{simd_create_filter_mask_parallel, simd_filter_batch_parallel};
71pub use simd_join::columnar_hash_join_inner;
72pub use simd_ops::PackedMask;
73use vibesql_storage::Row;
74use vibesql_types::SqlValue;
75
76use crate::{errors::ExecutorError, schema::CombinedSchema};
77
78/// Execute a columnar aggregate query with filtering
79///
80/// This is a simplified entry point for columnar execution that demonstrates
81/// the full pipeline: scan → filter → aggregate.
82///
83/// # Arguments
84///
85/// * `rows` - Input rows to process
86/// * `predicates` - Column predicates for filtering (optional)
87/// * `aggregates` - List of (column_index, aggregate_op) pairs to compute
88///
89/// # Returns
90///
91/// A single Row containing the computed aggregate values
92///
93/// # Example
94///
95/// ```text
96/// // Compute SUM(col0), AVG(col1) WHERE col2 < 100
97/// let predicates = vec![
98///     ColumnPredicate::LessThan {
99///         column_idx: 2,
100///         value: SqlValue::Integer(100),
101///     },
102/// ];
103/// let aggregates = vec![
104///     (0, AggregateOp::Sum),
105///     (1, AggregateOp::Avg),
106/// ];
107///
108/// let result = execute_columnar_aggregate(&rows, &predicates, &aggregates)?;
109/// ```
110///
111/// Note: This function provides SIMD-accelerated filtering and aggregation through
112/// LLVM auto-vectorization of batch-native operations.
113pub fn execute_columnar_aggregate(
114    rows: &[Row],
115    predicates: &[ColumnPredicate],
116    aggregates: &[aggregate::AggregateSpec],
117    schema: Option<&CombinedSchema>,
118) -> Result<Vec<Row>, ExecutorError> {
119    // Early return for empty input
120    // SQL standard: COUNT returns 0 for empty input, other aggregates return NULL
121    if rows.is_empty() {
122        let values: Vec<SqlValue> = aggregates
123            .iter()
124            .map(|spec| match spec.op {
125                aggregate::AggregateOp::Count => SqlValue::Integer(0),
126                _ => SqlValue::Null,
127            })
128            .collect();
129        return Ok(vec![Row::new(values)]);
130    }
131
132    // Phase 1: Convert to columnar batch for SIMD acceleration
133    #[cfg(feature = "profile-q6")]
134    let batch_start = std::time::Instant::now();
135
136    let batch = ColumnarBatch::from_rows(rows)?;
137
138    #[cfg(feature = "profile-q6")]
139    {
140        let batch_time = batch_start.elapsed();
141        eprintln!("[PROFILE-Q6]   Phase 1 - Convert to batch: {:?}", batch_time);
142    }
143
144    // Phase 2: Apply SIMD-accelerated filtering
145    #[cfg(feature = "profile-q6")]
146    let filter_start = std::time::Instant::now();
147
148    let filtered_batch =
149        if predicates.is_empty() { batch.clone() } else { simd_filter_batch(&batch, predicates)? };
150
151    #[cfg(feature = "profile-q6")]
152    {
153        let filter_time = filter_start.elapsed();
154        eprintln!(
155            "[PROFILE-Q6]   Phase 2 - SIMD filter: {:?} ({}/{} rows passed)",
156            filter_time,
157            filtered_batch.row_count(),
158            rows.len()
159        );
160    }
161
162    // Phase 3: Compute aggregates directly on batch (no row conversion!)
163    #[cfg(feature = "profile-q6")]
164    let agg_start = std::time::Instant::now();
165
166    // Use batch-native aggregation to avoid to_rows() conversion overhead
167    let results = compute_aggregates_from_batch(&filtered_batch, aggregates, schema)?;
168
169    #[cfg(feature = "profile-q6")]
170    {
171        let agg_time = agg_start.elapsed();
172        eprintln!(
173            "[PROFILE-Q6]   Phase 3 - Batch-native aggregate: {:?} ({} aggregates)",
174            agg_time,
175            aggregates.len()
176        );
177    }
178
179    // Return as single row
180    Ok(vec![Row::new(results)])
181}
182
183/// Fast single-pass aggregate on rows - avoids batch conversion overhead
184///
185/// This function performs filtering and aggregation in a single pass over the input rows,
186/// without converting to columnar format. It's 3-5x faster than `execute_columnar_aggregate`
187/// for queries that come from row-based storage.
188///
189/// # Use Cases
190///
191/// Best suited for:
192/// - Simple aggregate queries without GROUP BY
193/// - When data arrives as Vec<Row> (not native columnar)
194/// - TPC-H style queries: SUM(price * discount) WHERE ...
195///
196/// # Arguments
197///
198/// * `rows` - Input rows to process
199/// * `predicates` - Column predicates for filtering
200/// * `aggregates` - Aggregate specifications
201///
202/// # Returns
203///
204/// A single Row containing the computed aggregate values
205pub fn fast_aggregate_on_rows(
206    rows: &[Row],
207    predicates: &[ColumnPredicate],
208    aggregates: &[aggregate::AggregateSpec],
209) -> Result<Vec<Row>, ExecutorError> {
210    use aggregate::{AggregateOp, AggregateSource};
211
212    // Early return for empty input
213    if rows.is_empty() {
214        let values: Vec<SqlValue> = aggregates
215            .iter()
216            .map(|spec| match spec.op {
217                AggregateOp::Count => SqlValue::Integer(0),
218                _ => SqlValue::Null,
219            })
220            .collect();
221        return Ok(vec![Row::new(values)]);
222    }
223
224    // Initialize accumulators for each aggregate
225    struct Accumulator {
226        sum_f64: f64,
227        sum_i64: i64,
228        count: i64,
229        min_f64: Option<f64>,
230        max_f64: Option<f64>,
231        min_i64: Option<i64>,
232        max_i64: Option<i64>,
233        is_integer: bool,
234    }
235
236    let mut accumulators: Vec<Accumulator> = aggregates
237        .iter()
238        .map(|_| Accumulator {
239            sum_f64: 0.0,
240            sum_i64: 0,
241            count: 0,
242            min_f64: None,
243            max_f64: None,
244            min_i64: None,
245            max_i64: None,
246            is_integer: true,
247        })
248        .collect();
249
250    // Single pass: filter and accumulate
251    for row in rows {
252        // Check all predicates
253        let passes_filter = predicates.iter().all(|pred| evaluate_predicate(row, pred));
254
255        if !passes_filter {
256            continue;
257        }
258
259        // Accumulate values for each aggregate
260        for (i, spec) in aggregates.iter().enumerate() {
261            let acc = &mut accumulators[i];
262
263            match &spec.source {
264                AggregateSource::CountStar => {
265                    acc.count += 1;
266                }
267                AggregateSource::Column(col_idx) => {
268                    if let Some(value) = row.get(*col_idx) {
269                        if !matches!(value, SqlValue::Null) {
270                            acc.count += 1;
271                            match value {
272                                SqlValue::Integer(v) => {
273                                    acc.sum_i64 += v;
274                                    acc.sum_f64 += *v as f64;
275                                    acc.min_i64 = Some(acc.min_i64.map_or(*v, |m| m.min(*v)));
276                                    acc.max_i64 = Some(acc.max_i64.map_or(*v, |m| m.max(*v)));
277                                    acc.min_f64 =
278                                        Some(acc.min_f64.map_or(*v as f64, |m| m.min(*v as f64)));
279                                    acc.max_f64 =
280                                        Some(acc.max_f64.map_or(*v as f64, |m| m.max(*v as f64)));
281                                }
282                                SqlValue::Double(v) => {
283                                    acc.is_integer = false;
284                                    acc.sum_f64 += v;
285                                    acc.min_f64 = Some(acc.min_f64.map_or(*v, |m| m.min(*v)));
286                                    acc.max_f64 = Some(acc.max_f64.map_or(*v, |m| m.max(*v)));
287                                }
288                                SqlValue::Float(v) => {
289                                    acc.is_integer = false;
290                                    acc.sum_f64 += *v as f64;
291                                    acc.min_f64 =
292                                        Some(acc.min_f64.map_or(*v as f64, |m| m.min(*v as f64)));
293                                    acc.max_f64 =
294                                        Some(acc.max_f64.map_or(*v as f64, |m| m.max(*v as f64)));
295                                }
296                                SqlValue::Bigint(v) => {
297                                    acc.sum_i64 += v;
298                                    acc.sum_f64 += *v as f64;
299                                    acc.min_i64 = Some(acc.min_i64.map_or(*v, |m| m.min(*v)));
300                                    acc.max_i64 = Some(acc.max_i64.map_or(*v, |m| m.max(*v)));
301                                    acc.min_f64 =
302                                        Some(acc.min_f64.map_or(*v as f64, |m| m.min(*v as f64)));
303                                    acc.max_f64 =
304                                        Some(acc.max_f64.map_or(*v as f64, |m| m.max(*v as f64)));
305                                }
306                                SqlValue::Numeric(v) => {
307                                    acc.is_integer = false;
308                                    acc.sum_f64 += v;
309                                    acc.min_f64 = Some(acc.min_f64.map_or(*v, |m| m.min(*v)));
310                                    acc.max_f64 = Some(acc.max_f64.map_or(*v, |m| m.max(*v)));
311                                }
312                                _ => {}
313                            }
314                        }
315                    }
316                }
317                AggregateSource::Expression(expr) => {
318                    // For expression aggregates like SUM(a * b), evaluate the expression
319                    // This is a simplified evaluator for common binary operations
320                    if let Some(value) = eval_simple_expression(row, expr) {
321                        acc.count += 1;
322                        acc.is_integer = false;
323                        acc.sum_f64 += value;
324                        acc.min_f64 = Some(acc.min_f64.map_or(value, |m| m.min(value)));
325                        acc.max_f64 = Some(acc.max_f64.map_or(value, |m| m.max(value)));
326                    }
327                }
328            }
329        }
330    }
331
332    // Build result row from accumulators
333    let values: Vec<SqlValue> = aggregates
334        .iter()
335        .zip(accumulators.iter())
336        .map(|(spec, acc)| match spec.op {
337            AggregateOp::Count => SqlValue::Integer(acc.count),
338            AggregateOp::Sum => {
339                // SQLite's SUM() preserves integer type for integer inputs
340                if acc.count == 0 {
341                    SqlValue::Null
342                } else if acc.is_integer {
343                    SqlValue::Integer(acc.sum_i64)
344                } else {
345                    SqlValue::Double(acc.sum_f64)
346                }
347            }
348            AggregateOp::Avg => {
349                if acc.count == 0 {
350                    SqlValue::Null
351                } else {
352                    SqlValue::Double(acc.sum_f64 / acc.count as f64)
353                }
354            }
355            AggregateOp::Min => {
356                if acc.is_integer {
357                    acc.min_i64.map(SqlValue::Integer).unwrap_or(SqlValue::Null)
358                } else {
359                    acc.min_f64.map(SqlValue::Double).unwrap_or(SqlValue::Null)
360                }
361            }
362            AggregateOp::Max => {
363                if acc.is_integer {
364                    acc.max_i64.map(SqlValue::Integer).unwrap_or(SqlValue::Null)
365                } else {
366                    acc.max_f64.map(SqlValue::Double).unwrap_or(SqlValue::Null)
367                }
368            }
369        })
370        .collect();
371
372    Ok(vec![Row::new(values)])
373}
374
375/// Evaluate a column predicate against a row
376fn evaluate_predicate(row: &Row, predicate: &ColumnPredicate) -> bool {
377    match predicate {
378        ColumnPredicate::LessThan { column_idx, value } => row
379            .get(*column_idx)
380            .map(|v| compare_values(v, value) == std::cmp::Ordering::Less)
381            .unwrap_or(false),
382        ColumnPredicate::LessThanOrEqual { column_idx, value } => row
383            .get(*column_idx)
384            .map(|v| compare_values(v, value) != std::cmp::Ordering::Greater)
385            .unwrap_or(false),
386        ColumnPredicate::GreaterThan { column_idx, value } => row
387            .get(*column_idx)
388            .map(|v| compare_values(v, value) == std::cmp::Ordering::Greater)
389            .unwrap_or(false),
390        ColumnPredicate::GreaterThanOrEqual { column_idx, value } => row
391            .get(*column_idx)
392            .map(|v| compare_values(v, value) != std::cmp::Ordering::Less)
393            .unwrap_or(false),
394        ColumnPredicate::Equal { column_idx, value } => row
395            .get(*column_idx)
396            .map(|v| compare_values(v, value) == std::cmp::Ordering::Equal)
397            .unwrap_or(false),
398        ColumnPredicate::NotEqual { column_idx, value } => row
399            .get(*column_idx)
400            .map(|v| compare_values(v, value) != std::cmp::Ordering::Equal)
401            .unwrap_or(false),
402        ColumnPredicate::Between { column_idx, low, high } => row
403            .get(*column_idx)
404            .map(|v| {
405                compare_values(v, low) != std::cmp::Ordering::Less
406                    && compare_values(v, high) != std::cmp::Ordering::Greater
407            })
408            .unwrap_or(false),
409        ColumnPredicate::Like { column_idx, pattern, negated } => {
410            // Simple LIKE pattern matching for row-based fast path
411            let matches = row
412                .get(*column_idx)
413                .map(|v| {
414                    if let SqlValue::Varchar(s) = v {
415                        // Convert SQL LIKE pattern to simple check
416                        // This is a simplified version - full LIKE support is in simd_filter
417                        let pattern_str = pattern.as_str();
418                        if let Some(inner) =
419                            pattern_str.strip_prefix('%').and_then(|s| s.strip_suffix('%'))
420                        {
421                            s.contains(inner)
422                        } else if let Some(suffix) = pattern_str.strip_prefix('%') {
423                            s.ends_with(suffix)
424                        } else if let Some(prefix) = pattern_str.strip_suffix('%') {
425                            s.starts_with(prefix)
426                        } else {
427                            &**s == pattern_str
428                        }
429                    } else {
430                        false
431                    }
432                })
433                .unwrap_or(false);
434            if *negated {
435                !matches
436            } else {
437                matches
438            }
439        }
440        ColumnPredicate::InList { column_idx, values, negated, use_strict_type_ordering } => {
441            // Check if column value matches any value in the list
442            let matches = row
443                .get(*column_idx)
444                .map(|v| {
445                    values.iter().any(|list_val| {
446                        if *use_strict_type_ordering {
447                            // Use strict type ordering - no coercion
448                            strict_type_equal(v, list_val)
449                        } else {
450                            compare_values(v, list_val) == std::cmp::Ordering::Equal
451                        }
452                    })
453                })
454                .unwrap_or(false);
455            if *negated {
456                !matches
457            } else {
458                matches
459            }
460        }
461        ColumnPredicate::ColumnCompare { left_column_idx, op, right_column_idx } => {
462            // Column-to-column comparison
463            let left_val = row.get(*left_column_idx);
464            let right_val = row.get(*right_column_idx);
465            match (left_val, right_val) {
466                (Some(l), Some(r)) => {
467                    use std::cmp::Ordering;
468                    let cmp = compare_values(l, r);
469                    match op {
470                        filter::CompareOp::LessThan => cmp == Ordering::Less,
471                        filter::CompareOp::GreaterThan => cmp == Ordering::Greater,
472                        filter::CompareOp::LessThanOrEqual => cmp != Ordering::Greater,
473                        filter::CompareOp::GreaterThanOrEqual => cmp != Ordering::Less,
474                        filter::CompareOp::Equal => cmp == Ordering::Equal,
475                        filter::CompareOp::NotEqual => cmp != Ordering::Equal,
476                    }
477                }
478                _ => false, // NULL comparison returns false
479            }
480        }
481    }
482}
483
484/// Compare two SqlValues
485fn compare_values(a: &SqlValue, b: &SqlValue) -> std::cmp::Ordering {
486    use std::cmp::Ordering;
487
488    match (a, b) {
489        (SqlValue::Integer(a), SqlValue::Integer(b)) => a.cmp(b),
490        (SqlValue::Bigint(a), SqlValue::Bigint(b)) => a.cmp(b),
491        (SqlValue::Double(a), SqlValue::Double(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
492        (SqlValue::Float(a), SqlValue::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
493        // Cross-type comparisons
494        (SqlValue::Integer(a), SqlValue::Double(b)) => {
495            (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal)
496        }
497        (SqlValue::Double(a), SqlValue::Integer(b)) => {
498            a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal)
499        }
500        (SqlValue::Integer(a), SqlValue::Bigint(b)) => (*a).cmp(b),
501        (SqlValue::Bigint(a), SqlValue::Integer(b)) => a.cmp(&{ *b }),
502        // String comparisons
503        (SqlValue::Varchar(a), SqlValue::Varchar(b)) => a.cmp(b),
504        // Date comparisons
505        (SqlValue::Date(a), SqlValue::Date(b)) => {
506            // Compare year, month, day
507            match a.year.cmp(&b.year) {
508                Ordering::Equal => match a.month.cmp(&b.month) {
509                    Ordering::Equal => a.day.cmp(&b.day),
510                    other => other,
511                },
512                other => other,
513            }
514        }
515        // NULL handling
516        (SqlValue::Null, _) | (_, SqlValue::Null) => Ordering::Equal, /* NULL comparisons are */
517        // undefined
518        _ => Ordering::Equal, // Incompatible types
519    }
520}
521
522/// Strict equality comparison without string-to-number coercion
523///
524/// Returns true if both values can be considered equal WITHOUT coercing
525/// string values to numbers. This is used for columns with NONE or INTEGER
526/// affinity in IN expressions, where string values should NOT be coerced.
527///
528/// Numeric types (Integer, Float, Real, Double, etc.) can still be compared
529/// with each other since they're all in the same storage class.
530fn strict_type_equal(a: &SqlValue, b: &SqlValue) -> bool {
531    // NULL never equals anything
532    if matches!(a, SqlValue::Null) || matches!(b, SqlValue::Null) {
533        return false;
534    }
535
536    // Helper to check if a value is a string type
537    let is_string = |v: &SqlValue| matches!(v, SqlValue::Varchar(_) | SqlValue::Character(_));
538
539    // Helper to check if a value is a numeric type
540    let is_numeric = |v: &SqlValue| {
541        matches!(
542            v,
543            SqlValue::Integer(_)
544                | SqlValue::Bigint(_)
545                | SqlValue::Smallint(_)
546                | SqlValue::Float(_)
547                | SqlValue::Real(_)
548                | SqlValue::Double(_)
549                | SqlValue::Numeric(_)
550        )
551    };
552
553    // Different storage classes (string vs numeric) - NOT equal without coercion
554    if (is_string(a) && is_numeric(b)) || (is_numeric(a) && is_string(b)) {
555        return false;
556    }
557
558    // Same storage class - use normal comparison
559    match (a, b) {
560        // Same integer types
561        (SqlValue::Integer(x), SqlValue::Integer(y)) => x == y,
562        (SqlValue::Bigint(x), SqlValue::Bigint(y)) => x == y,
563        (SqlValue::Smallint(x), SqlValue::Smallint(y)) => x == y,
564
565        // Same string types (VARCHAR and CHAR are compatible)
566        (SqlValue::Varchar(x), SqlValue::Varchar(y))
567        | (SqlValue::Character(x), SqlValue::Character(y)) => x == y,
568        (SqlValue::Varchar(x), SqlValue::Character(y))
569        | (SqlValue::Character(x), SqlValue::Varchar(y)) => x == y,
570
571        // Same float types
572        (SqlValue::Float(x), SqlValue::Float(y)) => (x - y).abs() < f32::EPSILON,
573        (SqlValue::Double(x), SqlValue::Double(y)) => (x - y).abs() < f64::EPSILON,
574        (SqlValue::Real(x), SqlValue::Real(y)) => (x - y).abs() < f64::EPSILON,
575
576        // Cross-type numeric comparisons (all numeric types are comparable)
577        (SqlValue::Integer(x), SqlValue::Bigint(y))
578        | (SqlValue::Bigint(y), SqlValue::Integer(x)) => *x as i64 == *y,
579
580        // Float types with integer
581        (SqlValue::Float(x), SqlValue::Integer(y))
582        | (SqlValue::Integer(y), SqlValue::Float(x)) => (*x as f64 - *y as f64).abs() < f64::EPSILON,
583        (SqlValue::Double(x), SqlValue::Integer(y))
584        | (SqlValue::Integer(y), SqlValue::Double(x)) => (*x - *y as f64).abs() < f64::EPSILON,
585        (SqlValue::Real(x), SqlValue::Integer(y))
586        | (SqlValue::Integer(y), SqlValue::Real(x)) => (*x - *y as f64).abs() < f64::EPSILON,
587
588        // Mixed float types
589        (SqlValue::Float(x), SqlValue::Double(y))
590        | (SqlValue::Double(y), SqlValue::Float(x)) => (*x as f64 - *y).abs() < f64::EPSILON,
591        (SqlValue::Float(x), SqlValue::Real(y))
592        | (SqlValue::Real(y), SqlValue::Float(x)) => (*x as f64 - *y).abs() < f64::EPSILON,
593        (SqlValue::Double(x), SqlValue::Real(y))
594        | (SqlValue::Real(y), SqlValue::Double(x)) => (*x - *y).abs() < f64::EPSILON,
595
596        // Numeric type (f64) with integer types
597        (SqlValue::Numeric(x), SqlValue::Integer(y))
598        | (SqlValue::Integer(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
599        (SqlValue::Numeric(x), SqlValue::Bigint(y))
600        | (SqlValue::Bigint(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
601        (SqlValue::Numeric(x), SqlValue::Smallint(y))
602        | (SqlValue::Smallint(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
603
604        // Numeric with float types
605        (SqlValue::Numeric(x), SqlValue::Float(y))
606        | (SqlValue::Float(y), SqlValue::Numeric(x)) => (*x - *y as f64).abs() < f64::EPSILON,
607        (SqlValue::Numeric(x), SqlValue::Double(y))
608        | (SqlValue::Double(y), SqlValue::Numeric(x)) => (*x - *y).abs() < f64::EPSILON,
609        (SqlValue::Numeric(x), SqlValue::Real(y))
610        | (SqlValue::Real(y), SqlValue::Numeric(x)) => (*x - *y).abs() < f64::EPSILON,
611        (SqlValue::Numeric(x), SqlValue::Numeric(y)) => (*x - *y).abs() < f64::EPSILON,
612
613        // Different storage classes not handled above - NOT equal
614        _ => false,
615    }
616}
617
618/// Evaluate a simple expression (for expression aggregates)
619#[allow(clippy::only_used_in_recursion)]
620fn eval_simple_expression(row: &Row, expr: &vibesql_ast::Expression) -> Option<f64> {
621    use vibesql_ast::{BinaryOperator, Expression};
622
623    match expr {
624        Expression::BinaryOp { left, op, right } => {
625            let left_val = eval_simple_expression(row, left)?;
626            let right_val = eval_simple_expression(row, right)?;
627            match op {
628                BinaryOperator::Multiply => Some(left_val * right_val),
629                BinaryOperator::Divide => Some(left_val / right_val),
630                BinaryOperator::Plus => Some(left_val + right_val),
631                BinaryOperator::Minus => Some(left_val - right_val),
632                _ => None,
633            }
634        }
635        Expression::ColumnRef(col_id) => {
636            // Cannot resolve column names without a schema - return None to skip
637            // the fast path and fall back to the columnar execution path which
638            // properly handles expression aggregates with schema resolution.
639            log::debug!(
640                "fast_aggregate_on_rows: ColumnRef '{}' requires schema resolution, skipping fast path",
641                col_id.column_canonical()
642            );
643            None
644        }
645        Expression::Literal(val) => match val {
646            SqlValue::Integer(v) => Some(*v as f64),
647            SqlValue::Double(v) => Some(*v),
648            SqlValue::Float(v) => Some(*v as f64),
649            SqlValue::Bigint(v) => Some(*v as f64),
650            SqlValue::Numeric(v) => Some(*v),
651            _ => None,
652        },
653        _ => None,
654    }
655}
656
657/// Execute a query using columnar processing (AST-based interface)
658///
659/// This is the entry point for columnar execution that accepts AST expressions
660/// and converts them to the columnar execution pipeline.
661///
662/// # Arguments
663///
664/// * `rows` - The rows to process
665/// * `filter` - Optional WHERE clause expression
666/// * `aggregates` - SELECT list aggregate expressions
667/// * `schema` - Schema for resolving column names to indices
668///
669/// # Returns
670///
671/// Some(Result) if the query can be optimized using columnar execution,
672/// None if the expressions are too complex for columnar optimization.
673///
674/// Note: This function uses LLVM auto-vectorization for vectorized execution.
675pub fn execute_columnar(
676    rows: &[Row],
677    filter: Option<&vibesql_ast::Expression>,
678    aggregates: &[vibesql_ast::Expression],
679    schema: &CombinedSchema,
680) -> Option<Result<Vec<Row>, ExecutorError>> {
681    log::debug!("  Executing columnar query with {} rows", rows.len());
682
683    // Extract column predicates from filter expression
684    let predicates = if let Some(filter_expr) = filter {
685        match extract_column_predicates(filter_expr, schema) {
686            Some(preds) => {
687                log::debug!("    ✓ Extracted {} column predicates for SIMD filtering", preds.len());
688                preds
689            }
690            None => {
691                log::debug!("    ✗ Filter too complex for columnar optimization");
692                return None; // Too complex for columnar optimization
693            }
694        }
695    } else {
696        log::debug!("    No filter predicates");
697        vec![] // No filter
698    };
699
700    // Extract aggregates from SELECT list
701    let agg_specs = match extract_aggregates(aggregates, schema) {
702        Some(specs) => {
703            log::debug!("    ✓ Extracted {} aggregate operations", specs.len());
704            for (i, spec) in specs.iter().enumerate() {
705                log::debug!("      Aggregate {}: {:?}", i + 1, spec.op);
706            }
707            specs
708        }
709        None => {
710            log::debug!("    ✗ Aggregates too complex for columnar optimization");
711            return None; // Too complex for columnar optimization
712        }
713    };
714
715    // Call the simplified interface, passing schema if any aggregates use expressions
716    let needs_schema = agg_specs
717        .iter()
718        .any(|spec| matches!(spec.source, aggregate::AggregateSource::Expression(_)));
719    let schema_ref = if needs_schema { Some(schema) } else { None };
720
721    log::debug!("    Executing SIMD-accelerated columnar aggregation");
722    Some(execute_columnar_aggregate(rows, &predicates, &agg_specs, schema_ref))
723}
724
725#[cfg(test)]
726mod tests {
727    use vibesql_types::Date;
728
729    use super::*;
730
731    /// Test the full columnar pipeline: filter + aggregation
732    #[test]
733    fn test_columnar_pipeline_filtered_sum() {
734        // Create test data: TPC-H Q6 style query
735        // SELECT SUM(l_extendedprice * l_discount)
736        // WHERE l_shipdate >= '1994-01-01'
737        //   AND l_shipdate < '1995-01-01'
738        //   AND l_discount BETWEEN 0.05 AND 0.07
739        //   AND l_quantity < 24
740
741        let rows = vec![
742            Row::new(vec![
743                SqlValue::Integer(10),   // quantity
744                SqlValue::Double(100.0), // extendedprice
745                SqlValue::Double(0.06),  // discount
746                SqlValue::Date(Date::new(1994, 6, 1).unwrap()),
747            ]),
748            Row::new(vec![
749                SqlValue::Integer(25), // quantity (filtered out: > 24)
750                SqlValue::Double(200.0),
751                SqlValue::Double(0.06),
752                SqlValue::Date(Date::new(1994, 7, 1).unwrap()),
753            ]),
754            Row::new(vec![
755                SqlValue::Integer(15), // quantity
756                SqlValue::Double(150.0),
757                SqlValue::Double(0.05), // discount
758                SqlValue::Date(Date::new(1994, 8, 1).unwrap()),
759            ]),
760            Row::new(vec![
761                SqlValue::Integer(20), // quantity
762                SqlValue::Double(180.0),
763                SqlValue::Double(0.08), // discount (filtered out: > 0.07)
764                SqlValue::Date(Date::new(1994, 9, 1).unwrap()),
765            ]),
766        ];
767
768        // Predicates: quantity < 24 AND discount BETWEEN 0.05 AND 0.07
769        let predicates = vec![
770            ColumnPredicate::LessThan { column_idx: 0, value: SqlValue::Integer(24) },
771            ColumnPredicate::Between {
772                column_idx: 2,
773                low: SqlValue::Double(0.05),
774                high: SqlValue::Double(0.07),
775            },
776        ];
777
778        // Aggregates: SUM(extendedprice), COUNT(*)
779        let aggregates = vec![
780            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(1) }, /* SUM(extendedprice) */
781            AggregateSpec { op: AggregateOp::Count, source: AggregateSource::Column(0) }, /* COUNT(*) */
782        ];
783
784        let result = execute_columnar_aggregate(&rows, &predicates, &aggregates, None).unwrap();
785
786        assert_eq!(result.len(), 1);
787        let result_row = &result[0];
788
789        // Only rows 0 and 2 pass the filter (quantity < 24 AND discount in range)
790        // SUM(extendedprice) = 100.0 + 150.0 = 250.0
791        assert!(
792            matches!(result_row.get(0), Some(&SqlValue::Double(sum)) if (sum - 250.0).abs() < 0.001)
793        );
794        // COUNT(*) = 2
795        assert_eq!(result_row.get(1), Some(&SqlValue::Integer(2)));
796    }
797
798    /// Test columnar execution with no filtering
799    #[test]
800    fn test_columnar_pipeline_no_filter() {
801        let rows = vec![
802            Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
803            Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
804            Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
805        ];
806
807        let predicates = vec![];
808        let aggregates = vec![
809            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
810            AggregateSpec { op: AggregateOp::Avg, source: AggregateSource::Column(1) },
811            AggregateSpec { op: AggregateOp::Max, source: AggregateSource::Column(0) },
812        ];
813
814        let result = execute_columnar_aggregate(&rows, &predicates, &aggregates, None).unwrap();
815
816        assert_eq!(result.len(), 1);
817        let result_row = &result[0];
818
819        // SUM(col0) = 60 (SQLite's SUM() returns REAL)
820        assert_eq!(result_row.get(0), Some(&SqlValue::Integer(60)));
821        // AVG(col1) = 2.5
822        assert!(
823            matches!(result_row.get(1), Some(&SqlValue::Double(avg)) if (avg - 2.5).abs() < 0.001)
824        );
825        // MAX(col0) = 30
826        assert_eq!(result_row.get(2), Some(&SqlValue::Integer(30)));
827    }
828
829    /// Test columnar execution with empty result set
830    #[test]
831    fn test_columnar_pipeline_empty_result() {
832        let rows =
833            vec![Row::new(vec![SqlValue::Integer(100)]), Row::new(vec![SqlValue::Integer(200)])];
834
835        // Filter that matches nothing
836        let predicates =
837            vec![ColumnPredicate::LessThan { column_idx: 0, value: SqlValue::Integer(50) }];
838
839        let aggregates = vec![
840            AggregateSpec { op: AggregateOp::Sum, source: AggregateSource::Column(0) },
841            AggregateSpec { op: AggregateOp::Count, source: AggregateSource::Column(0) },
842        ];
843
844        let result = execute_columnar_aggregate(&rows, &predicates, &aggregates, None).unwrap();
845
846        assert_eq!(result.len(), 1);
847        let result_row = &result[0];
848
849        // SUM of empty set = NULL
850        assert_eq!(result_row.get(0), Some(&SqlValue::Null));
851        // COUNT of empty set = 0
852        assert_eq!(result_row.get(1), Some(&SqlValue::Integer(0)));
853    }
854
855    // AST Integration Tests
856
857    use vibesql_ast::{BinaryOperator, Expression};
858    use vibesql_catalog::{ColumnSchema, TableSchema};
859    use vibesql_types::DataType;
860
861    use crate::schema::CombinedSchema;
862
863    fn make_test_schema() -> CombinedSchema {
864        let schema = TableSchema::new(
865            "test".to_string(),
866            vec![
867                ColumnSchema::new("quantity".to_string(), DataType::Integer, false),
868                ColumnSchema::new("price".to_string(), DataType::DoublePrecision, false),
869            ],
870        );
871        CombinedSchema::from_table("test".to_string(), schema)
872    }
873
874    fn make_test_rows_for_ast() -> Vec<Row> {
875        vec![
876            Row::new(vec![SqlValue::Integer(10), SqlValue::Double(1.5)]),
877            Row::new(vec![SqlValue::Integer(20), SqlValue::Double(2.5)]),
878            Row::new(vec![SqlValue::Integer(30), SqlValue::Double(3.5)]),
879            Row::new(vec![SqlValue::Integer(40), SqlValue::Double(4.5)]),
880        ]
881    }
882
883    #[test]
884    fn test_execute_columnar_simple_aggregate() {
885        let rows = make_test_rows_for_ast();
886        let schema = make_test_schema();
887
888        // SELECT SUM(price) FROM test
889        let aggregates = vec![Expression::AggregateFunction {
890            name: vibesql_ast::FunctionIdentifier::new("SUM"),
891            distinct: false,
892            args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
893                "price", false,
894            ))],
895            order_by: None,
896            filter: None,
897        }];
898
899        let result = execute_columnar(&rows, None, &aggregates, &schema);
900        assert!(result.is_some());
901
902        let rows_result = result.unwrap();
903        assert!(rows_result.is_ok());
904
905        let result_rows = rows_result.unwrap();
906        assert_eq!(result_rows.len(), 1);
907        assert_eq!(result_rows[0].len(), 1);
908
909        // Sum should be 1.5 + 2.5 + 3.5 + 4.5 = 12.0
910        if let Some(SqlValue::Double(sum)) = result_rows[0].get(0) {
911            assert!((sum - 12.0).abs() < 0.001);
912        } else {
913            panic!("Expected Numeric value for SUM");
914        }
915    }
916
917    #[test]
918    fn test_execute_columnar_with_filter() {
919        let rows = make_test_rows_for_ast();
920        let schema = make_test_schema();
921
922        // SELECT SUM(price) FROM test WHERE quantity < 25
923        let filter = Expression::BinaryOp {
924            left: Box::new(Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
925                "quantity", false,
926            ))),
927            op: BinaryOperator::LessThan,
928            right: Box::new(Expression::Literal(SqlValue::Integer(25))),
929        };
930
931        let aggregates = vec![Expression::AggregateFunction {
932            name: vibesql_ast::FunctionIdentifier::new("SUM"),
933            distinct: false,
934            args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
935                "price", false,
936            ))],
937            order_by: None,
938            filter: None,
939        }];
940
941        let result = execute_columnar(&rows, Some(&filter), &aggregates, &schema);
942        assert!(result.is_some());
943
944        let rows_result = result.unwrap();
945        assert!(rows_result.is_ok());
946
947        let result_rows = rows_result.unwrap();
948        assert_eq!(result_rows.len(), 1);
949        assert_eq!(result_rows[0].len(), 1);
950
951        // Sum of rows where quantity < 25: 1.5 + 2.5 = 4.0
952        if let Some(SqlValue::Double(sum)) = result_rows[0].get(0) {
953            assert!((sum - 4.0).abs() < 0.001);
954        } else {
955            panic!("Expected Numeric value for SUM");
956        }
957    }
958
959    #[test]
960    fn test_execute_columnar_multiple_aggregates() {
961        let rows = make_test_rows_for_ast();
962        let schema = make_test_schema();
963
964        // SELECT SUM(price), COUNT(*), AVG(quantity) FROM test
965        let aggregates = vec![
966            Expression::AggregateFunction {
967                name: vibesql_ast::FunctionIdentifier::new("SUM"),
968                distinct: false,
969                args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
970                    "price", false,
971                ))],
972                order_by: None,
973                filter: None,
974            },
975            Expression::AggregateFunction {
976                name: vibesql_ast::FunctionIdentifier::new("COUNT"),
977                distinct: false,
978                args: vec![Expression::Wildcard],
979                order_by: None,
980                filter: None,
981            },
982            Expression::AggregateFunction {
983                name: vibesql_ast::FunctionIdentifier::new("AVG"),
984                distinct: false,
985                args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
986                    "quantity", false,
987                ))],
988                order_by: None,
989                filter: None,
990            },
991        ];
992
993        let result = execute_columnar(&rows, None, &aggregates, &schema);
994        assert!(result.is_some());
995
996        let rows_result = result.unwrap();
997        assert!(rows_result.is_ok());
998
999        let result_rows = rows_result.unwrap();
1000        assert_eq!(result_rows.len(), 1);
1001        assert_eq!(result_rows[0].len(), 3);
1002
1003        // Check SUM(price) = 12.0
1004        if let Some(SqlValue::Double(sum)) = result_rows[0].get(0) {
1005            assert!((sum - 12.0).abs() < 0.001);
1006        } else {
1007            panic!("Expected Numeric value for SUM");
1008        }
1009
1010        // Check COUNT(*) = 4
1011        assert_eq!(result_rows[0].get(1), Some(&SqlValue::Integer(4)));
1012
1013        // Check AVG(quantity) = (10+20+30+40)/4 = 25.0
1014        if let Some(SqlValue::Double(avg)) = result_rows[0].get(2) {
1015            assert!((avg - 25.0).abs() < 0.001);
1016        } else {
1017            panic!("Expected Double value for AVG");
1018        }
1019    }
1020
1021    #[test]
1022    fn test_execute_columnar_unsupported_distinct() {
1023        let rows = make_test_rows_for_ast();
1024        let schema = make_test_schema();
1025
1026        // SELECT SUM(DISTINCT price) FROM test - should return None
1027        let aggregates = vec![Expression::AggregateFunction {
1028            name: vibesql_ast::FunctionIdentifier::new("SUM"),
1029            distinct: true,
1030            args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
1031                "price", false,
1032            ))],
1033            order_by: None,
1034            filter: None,
1035        }];
1036
1037        let result = execute_columnar(&rows, None, &aggregates, &schema);
1038        assert!(result.is_none());
1039    }
1040
1041    #[test]
1042    fn test_execute_columnar_unsupported_complex_filter() {
1043        let rows = make_test_rows_for_ast();
1044        let schema = make_test_schema();
1045
1046        // SELECT SUM(price) FROM test WHERE quantity IN (SELECT ...) - unsupported
1047        let filter = Expression::ScalarSubquery(Box::new(vibesql_ast::SelectStmt {
1048            with_clause: None,
1049            distinct: false,
1050            select_list: vec![],
1051            into_table: None,
1052            into_variables: None,
1053            from: None,
1054            where_clause: None,
1055            group_by: None,
1056            having: None,
1057            order_by: None,
1058            limit: None,
1059            offset: None,
1060            set_operation: None,
1061            values: None,
1062        }));
1063
1064        let aggregates = vec![Expression::AggregateFunction {
1065            name: vibesql_ast::FunctionIdentifier::new("SUM"),
1066            distinct: false,
1067            args: vec![Expression::ColumnRef(vibesql_ast::ColumnIdentifier::simple(
1068                "price", false,
1069            ))],
1070            order_by: None,
1071            filter: None,
1072        }];
1073
1074        let result = execute_columnar(&rows, Some(&filter), &aggregates, &schema);
1075        assert!(result.is_none());
1076    }
1077}