// trueno_db/query/executor.rs

//! Query execution engine
//!
//! Executes parsed SQL queries against Arrow storage using GPU/SIMD backends.
//!
//! Toyota Way Principles:
//! - Jidoka: Backend equivalence (GPU == SIMD == Scalar results)
//! - Kaizen: Top-K optimization (O(N log K) vs O(N log N))
//! - Genchi Genbutsu: Cost-based backend selection
use super::{AggregateFunction, OrderDirection, QueryPlan};
use crate::storage::StorageEngine;
use crate::topk::{SortOrder, TopKSelection};
use crate::{Backend, Error, Result};
use arrow::array::{
    Array, ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch,
};
use arrow::compute;
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
20
/// Query executor for parsed SQL queries
pub struct QueryExecutor {
    // Selected compute backend. Dead code per the allow below: `execute`
    // currently delegates to Arrow kernels and does not consult it —
    // presumably reserved for cost-based backend dispatch (see module docs).
    #[allow(dead_code)]
    backend: Backend,
}
26
27impl Default for QueryExecutor {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl QueryExecutor {
34    /// Create a new query executor with cost-based backend selection
35    #[must_use]
36    pub const fn new() -> Self {
37        Self { backend: Backend::CostBased }
38    }
39
40    /// Create executor with forced backend
41    #[must_use]
42    pub const fn with_backend(backend: Backend) -> Self {
43        Self { backend }
44    }
45
46    /// Execute a query plan against storage
47    ///
48    /// # Arguments
49    /// * `plan` - Parsed query plan from `QueryEngine::parse()`
50    /// * `storage` - Storage engine containing the data
51    ///
52    /// # Returns
53    /// Result record batch with query results
54    ///
55    /// # Errors
56    /// Returns error if:
57    /// - Table not found in storage
58    /// - Column not found in schema
59    /// - Data type mismatch
60    /// - Backend execution failure
61    ///
62    /// # Example
63    /// ```rust,no_run
64    /// use trueno_db::query::{QueryEngine, QueryExecutor};
65    /// use trueno_db::storage::StorageEngine;
66    ///
67    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
68    /// let storage = StorageEngine::load_parquet("data/events.parquet")?;
69    /// let engine = QueryEngine::new();
70    /// let executor = QueryExecutor::new();
71    ///
72    /// let plan = engine.parse("SELECT category, SUM(value) FROM events GROUP BY category")?;
73    /// let result = executor.execute(&plan, &storage)?;
74    ///
75    /// println!("Results: {} rows", result.num_rows());
76    /// # Ok(())
77    /// # }
78    /// ```
79    pub fn execute(&self, plan: &QueryPlan, storage: &StorageEngine) -> Result<RecordBatch> {
80        // Get all batches from storage
81        let batches = storage.batches();
82        if batches.is_empty() {
83            return Err(Error::InvalidInput("No data in storage".to_string()));
84        }
85
86        // Combine batches (Phase 1: single table only)
87        let combined = Self::combine_batches(batches)?;
88
89        // Apply WHERE filter
90        let filtered = if let Some(ref filter_expr) = plan.filter {
91            Self::apply_filter(&combined, filter_expr)?
92        } else {
93            combined
94        };
95
96        // Execute aggregations if present
97        let result = if plan.aggregations.is_empty() {
98            // Project columns
99            Self::project_columns(&filtered, &plan.columns)?
100        } else {
101            Self::execute_aggregations(&filtered, plan)?
102        };
103
104        // Apply ORDER BY + LIMIT (Top-K optimization)
105        let result = if !plan.order_by.is_empty() {
106            Self::apply_order_by_limit(&result, plan)?
107        } else if let Some(limit) = plan.limit {
108            // LIMIT without ORDER BY: just slice
109            result.slice(0, limit.min(result.num_rows()))
110        } else {
111            result
112        };
113
114        Ok(result)
115    }
116
117    /// Combine multiple batches into single batch
118    fn combine_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
119        if batches.len() == 1 {
120            return Ok(batches[0].clone());
121        }
122
123        // Use Arrow concat
124        compute::concat_batches(&batches[0].schema(), batches)
125            .map_err(|e| Error::StorageError(format!("Failed to combine batches: {e}")))
126    }
127
128    /// Apply WHERE filter
129    fn apply_filter(batch: &RecordBatch, filter_expr: &str) -> Result<RecordBatch> {
130        // Phase 1: Simple predicates only (column > value, column < value, etc.)
131        // Parse filter expression: "column op value"
132        let parts: Vec<&str> = filter_expr.split_whitespace().collect();
133        if parts.len() < 3 {
134            return Err(Error::ParseError(format!("Invalid filter expression: {filter_expr}")));
135        }
136
137        let column_name = parts[0];
138        let op = parts[1];
139        let value_str = parts.get(2..).unwrap_or(&[]).join(" ");
140
141        // Find column index
142        let schema = batch.schema();
143        let column_index = schema
144            .fields()
145            .iter()
146            .position(|f| f.name() == column_name)
147            .ok_or_else(|| Error::InvalidInput(format!("Column not found: {column_name}")))?;
148
149        let column = batch.column(column_index);
150
151        // Build boolean mask based on data type
152        let mask = match column.data_type() {
153            DataType::Int32 => {
154                let array = column
155                    .as_any()
156                    .downcast_ref::<Int32Array>()
157                    .ok_or_else(|| Error::Other("Failed to downcast to Int32Array".to_string()))?;
158                let value: i32 = value_str
159                    .parse()
160                    .map_err(|_| Error::ParseError(format!("Invalid Int32 value: {value_str}")))?;
161                Self::build_comparison_mask_i32(array, op, value)?
162            }
163            DataType::Int64 => {
164                let array = column
165                    .as_any()
166                    .downcast_ref::<Int64Array>()
167                    .ok_or_else(|| Error::Other("Failed to downcast to Int64Array".to_string()))?;
168                let value: i64 = value_str
169                    .parse()
170                    .map_err(|_| Error::ParseError(format!("Invalid Int64 value: {value_str}")))?;
171                Self::build_comparison_mask_i64(array, op, value)?
172            }
173            DataType::Float32 => {
174                let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
175                    Error::Other("Failed to downcast to Float32Array".to_string())
176                })?;
177                let value: f32 = value_str.parse().map_err(|_| {
178                    Error::ParseError(format!("Invalid Float32 value: {value_str}"))
179                })?;
180                Self::build_comparison_mask_f32(array, op, value)?
181            }
182            DataType::Float64 => {
183                let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
184                    Error::Other("Failed to downcast to Float64Array".to_string())
185                })?;
186                let value: f64 = value_str.parse().map_err(|_| {
187                    Error::ParseError(format!("Invalid Float64 value: {value_str}"))
188                })?;
189                Self::build_comparison_mask_f64(array, op, value)?
190            }
191            dt => {
192                return Err(Error::InvalidInput(format!(
193                    "Filter not supported for data type: {dt:?}"
194                )))
195            }
196        };
197
198        // Apply filter using Arrow compute
199        compute::filter_record_batch(batch, &mask)
200            .map_err(|e| Error::StorageError(format!("Failed to apply filter: {e}")))
201    }
202
203    #[allow(clippy::unnecessary_wraps)]
204    fn build_comparison_mask_i32(
205        array: &Int32Array,
206        op: &str,
207        value: i32,
208    ) -> Result<arrow::array::BooleanArray> {
209        use arrow::array::BooleanArray;
210        let values: Vec<bool> = (0..array.len())
211            .map(|i| {
212                if array.is_null(i) {
213                    false
214                } else {
215                    let v = array.value(i);
216                    match op {
217                        ">" => v > value,
218                        ">=" => v >= value,
219                        "<" => v < value,
220                        "<=" => v <= value,
221                        "=" => v == value,
222                        "!=" | "<>" => v != value,
223                        _ => false,
224                    }
225                }
226            })
227            .collect();
228        Ok(BooleanArray::from(values))
229    }
230
231    #[allow(clippy::unnecessary_wraps)]
232    fn build_comparison_mask_i64(
233        array: &Int64Array,
234        op: &str,
235        value: i64,
236    ) -> Result<arrow::array::BooleanArray> {
237        use arrow::array::BooleanArray;
238        let values: Vec<bool> = (0..array.len())
239            .map(|i| {
240                if array.is_null(i) {
241                    false
242                } else {
243                    let v = array.value(i);
244                    match op {
245                        ">" => v > value,
246                        ">=" => v >= value,
247                        "<" => v < value,
248                        "<=" => v <= value,
249                        "=" => v == value,
250                        "!=" | "<>" => v != value,
251                        _ => false,
252                    }
253                }
254            })
255            .collect();
256        Ok(BooleanArray::from(values))
257    }
258
259    #[allow(clippy::unnecessary_wraps)]
260    fn build_comparison_mask_f32(
261        array: &Float32Array,
262        op: &str,
263        value: f32,
264    ) -> Result<arrow::array::BooleanArray> {
265        use arrow::array::BooleanArray;
266        let values: Vec<bool> = (0..array.len())
267            .map(|i| {
268                if array.is_null(i) {
269                    false
270                } else {
271                    let v = array.value(i);
272                    match op {
273                        ">" => v > value,
274                        ">=" => v >= value,
275                        "<" => v < value,
276                        "<=" => v <= value,
277                        "=" => (v - value).abs() < f32::EPSILON,
278                        "!=" | "<>" => (v - value).abs() >= f32::EPSILON,
279                        _ => false,
280                    }
281                }
282            })
283            .collect();
284        Ok(BooleanArray::from(values))
285    }
286
287    #[allow(clippy::unnecessary_wraps)]
288    fn build_comparison_mask_f64(
289        array: &Float64Array,
290        op: &str,
291        value: f64,
292    ) -> Result<arrow::array::BooleanArray> {
293        use arrow::array::BooleanArray;
294        let values: Vec<bool> = (0..array.len())
295            .map(|i| {
296                if array.is_null(i) {
297                    false
298                } else {
299                    let v = array.value(i);
300                    match op {
301                        ">" => v > value,
302                        ">=" => v >= value,
303                        "<" => v < value,
304                        "<=" => v <= value,
305                        "=" => (v - value).abs() < f64::EPSILON,
306                        "!=" | "<>" => (v - value).abs() >= f64::EPSILON,
307                        _ => false,
308                    }
309                }
310            })
311            .collect();
312        Ok(BooleanArray::from(values))
313    }
314
315    /// Project columns from batch
316    fn project_columns(batch: &RecordBatch, columns: &[String]) -> Result<RecordBatch> {
317        if columns.len() == 1 && columns[0] == "*" {
318            return Ok(batch.clone());
319        }
320
321        let schema = batch.schema();
322        let mut new_columns = Vec::new();
323        let mut new_fields = Vec::new();
324
325        for col_name in columns {
326            let index = schema
327                .fields()
328                .iter()
329                .position(|f| f.name() == col_name)
330                .ok_or_else(|| Error::InvalidInput(format!("Column not found: {col_name}")))?;
331
332            new_columns.push(batch.column(index).clone());
333            new_fields.push(schema.field(index).clone());
334        }
335
336        let new_schema = Arc::new(Schema::new(new_fields));
337        RecordBatch::try_new(new_schema, new_columns)
338            .map_err(|e| Error::StorageError(format!("Failed to project columns: {e}")))
339    }
340
341    /// Execute aggregations
342    fn execute_aggregations(batch: &RecordBatch, plan: &QueryPlan) -> Result<RecordBatch> {
343        // Phase 1: Simple aggregations without GROUP BY
344        if !plan.group_by.is_empty() {
345            return Err(Error::InvalidInput(
346                "GROUP BY aggregations not yet implemented in Phase 1".to_string(),
347            ));
348        }
349
350        let mut result_columns: Vec<ArrayRef> = Vec::new();
351        let mut result_fields: Vec<Field> = Vec::new();
352
353        for (agg_func, col_name, alias) in &plan.aggregations {
354            let result_name = alias.as_deref().unwrap_or(col_name);
355
356            // Find column
357            let schema = batch.schema();
358            let col_index = schema
359                .fields()
360                .iter()
361                .position(|f| f.name() == col_name || col_name == "*")
362                .ok_or_else(|| Error::InvalidInput(format!("Column not found: {col_name}")))?;
363
364            let column = batch.column(col_index);
365
366            // Execute aggregation
367            let (result_value, result_type) =
368                Self::execute_single_aggregation(*agg_func, column, batch.num_rows())?;
369
370            result_columns.push(result_value);
371            result_fields.push(Field::new(result_name, result_type, false));
372        }
373
374        let result_schema = Arc::new(Schema::new(result_fields));
375        RecordBatch::try_new(result_schema, result_columns)
376            .map_err(|e| Error::StorageError(format!("Failed to create result batch: {e}")))
377    }
378
379    /// Execute single aggregation function
380    fn execute_single_aggregation(
381        func: AggregateFunction,
382        column: &ArrayRef,
383        num_rows: usize,
384    ) -> Result<(ArrayRef, DataType)> {
385        match column.data_type() {
386            DataType::Int32 => {
387                let array = column
388                    .as_any()
389                    .downcast_ref::<Int32Array>()
390                    .ok_or_else(|| Error::Other("Failed to downcast to Int32Array".to_string()))?;
391                Self::aggregate_i32(func, array, num_rows)
392            }
393            DataType::Int64 => {
394                let array = column
395                    .as_any()
396                    .downcast_ref::<Int64Array>()
397                    .ok_or_else(|| Error::Other("Failed to downcast to Int64Array".to_string()))?;
398                Self::aggregate_i64(func, array, num_rows)
399            }
400            DataType::Float32 => {
401                let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
402                    Error::Other("Failed to downcast to Float32Array".to_string())
403                })?;
404                Self::aggregate_f32(func, array, num_rows)
405            }
406            DataType::Float64 => {
407                let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
408                    Error::Other("Failed to downcast to Float64Array".to_string())
409                })?;
410                Self::aggregate_f64(func, array, num_rows)
411            }
412            dt => {
413                Err(Error::InvalidInput(format!("Aggregation not supported for data type: {dt:?}")))
414            }
415        }
416    }
417
418    #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
419    fn aggregate_i32(
420        func: AggregateFunction,
421        array: &Int32Array,
422        num_rows: usize,
423    ) -> Result<(ArrayRef, DataType)> {
424        match func {
425            AggregateFunction::Sum => {
426                let sum: i64 = (0..array.len())
427                    .filter(|&i| !array.is_null(i))
428                    .map(|i| i64::from(array.value(i)))
429                    .sum();
430                Ok((Arc::new(Int64Array::from(vec![sum])), DataType::Int64))
431            }
432            AggregateFunction::Avg => {
433                let sum: f64 = (0..array.len())
434                    .filter(|&i| !array.is_null(i))
435                    .map(|i| f64::from(array.value(i)))
436                    .sum();
437                let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
438                let avg = if count > 0 { sum / count as f64 } else { 0.0 };
439                Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
440            }
441            AggregateFunction::Count => {
442                Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
443            }
444            AggregateFunction::Min => {
445                let min = (0..array.len())
446                    .filter(|&i| !array.is_null(i))
447                    .map(|i| array.value(i))
448                    .min()
449                    .unwrap_or(0);
450                Ok((Arc::new(Int32Array::from(vec![min])), DataType::Int32))
451            }
452            AggregateFunction::Max => {
453                let max = (0..array.len())
454                    .filter(|&i| !array.is_null(i))
455                    .map(|i| array.value(i))
456                    .max()
457                    .unwrap_or(0);
458                Ok((Arc::new(Int32Array::from(vec![max])), DataType::Int32))
459            }
460        }
461    }
462
463    #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
464    fn aggregate_i64(
465        func: AggregateFunction,
466        array: &Int64Array,
467        num_rows: usize,
468    ) -> Result<(ArrayRef, DataType)> {
469        match func {
470            AggregateFunction::Sum => {
471                let sum: i64 =
472                    (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
473                Ok((Arc::new(Int64Array::from(vec![sum])), DataType::Int64))
474            }
475            AggregateFunction::Avg => {
476                let sum: f64 = (0..array.len())
477                    .filter(|&i| !array.is_null(i))
478                    .map(|i| array.value(i) as f64)
479                    .sum();
480                let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
481                let avg = if count > 0 { sum / count as f64 } else { 0.0 };
482                Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
483            }
484            AggregateFunction::Count => {
485                Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
486            }
487            AggregateFunction::Min => {
488                let min = (0..array.len())
489                    .filter(|&i| !array.is_null(i))
490                    .map(|i| array.value(i))
491                    .min()
492                    .unwrap_or(0);
493                Ok((Arc::new(Int64Array::from(vec![min])), DataType::Int64))
494            }
495            AggregateFunction::Max => {
496                let max = (0..array.len())
497                    .filter(|&i| !array.is_null(i))
498                    .map(|i| array.value(i))
499                    .max()
500                    .unwrap_or(0);
501                Ok((Arc::new(Int64Array::from(vec![max])), DataType::Int64))
502            }
503        }
504    }
505
506    #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
507    fn aggregate_f32(
508        func: AggregateFunction,
509        array: &Float32Array,
510        num_rows: usize,
511    ) -> Result<(ArrayRef, DataType)> {
512        match func {
513            AggregateFunction::Sum => {
514                let sum: f32 =
515                    (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
516                Ok((Arc::new(Float32Array::from(vec![sum])), DataType::Float32))
517            }
518            AggregateFunction::Avg => {
519                let sum: f64 = (0..array.len())
520                    .filter(|&i| !array.is_null(i))
521                    .map(|i| f64::from(array.value(i)))
522                    .sum();
523                let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
524                let avg = if count > 0 { sum / count as f64 } else { 0.0 };
525                Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
526            }
527            AggregateFunction::Count => {
528                Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
529            }
530            AggregateFunction::Min => {
531                let min = (0..array.len())
532                    .filter(|&i| !array.is_null(i))
533                    .map(|i| array.value(i))
534                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
535                    .unwrap_or(0.0);
536                Ok((Arc::new(Float32Array::from(vec![min])), DataType::Float32))
537            }
538            AggregateFunction::Max => {
539                let max = (0..array.len())
540                    .filter(|&i| !array.is_null(i))
541                    .map(|i| array.value(i))
542                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
543                    .unwrap_or(0.0);
544                Ok((Arc::new(Float32Array::from(vec![max])), DataType::Float32))
545            }
546        }
547    }
548
549    #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
550    fn aggregate_f64(
551        func: AggregateFunction,
552        array: &Float64Array,
553        num_rows: usize,
554    ) -> Result<(ArrayRef, DataType)> {
555        match func {
556            AggregateFunction::Sum => {
557                let sum: f64 =
558                    (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
559                Ok((Arc::new(Float64Array::from(vec![sum])), DataType::Float64))
560            }
561            AggregateFunction::Avg => {
562                let sum: f64 =
563                    (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
564                let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
565                let avg = if count > 0 { sum / count as f64 } else { 0.0 };
566                Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
567            }
568            AggregateFunction::Count => {
569                Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
570            }
571            AggregateFunction::Min => {
572                let min = (0..array.len())
573                    .filter(|&i| !array.is_null(i))
574                    .map(|i| array.value(i))
575                    .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
576                    .unwrap_or(0.0);
577                Ok((Arc::new(Float64Array::from(vec![min])), DataType::Float64))
578            }
579            AggregateFunction::Max => {
580                let max = (0..array.len())
581                    .filter(|&i| !array.is_null(i))
582                    .map(|i| array.value(i))
583                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
584                    .unwrap_or(0.0);
585                Ok((Arc::new(Float64Array::from(vec![max])), DataType::Float64))
586            }
587        }
588    }
589
590    /// Apply ORDER BY + LIMIT using Top-K optimization
591    fn apply_order_by_limit(batch: &RecordBatch, plan: &QueryPlan) -> Result<RecordBatch> {
592        if plan.order_by.is_empty() {
593            return Ok(batch.clone());
594        }
595
596        // Phase 1: Single ORDER BY column only
597        let (col_name, direction) = &plan.order_by[0];
598
599        // Find column index
600        let schema = batch.schema();
601        let col_index = schema
602            .fields()
603            .iter()
604            .position(|f| f.name() == col_name)
605            .ok_or_else(|| Error::InvalidInput(format!("Column not found: {col_name}")))?;
606
607        // Convert OrderDirection to SortOrder
608        let sort_order = match direction {
609            OrderDirection::Asc => SortOrder::Ascending,
610            OrderDirection::Desc => SortOrder::Descending,
611        };
612
613        // Use Top-K if LIMIT is present, otherwise sort all
614        let k = plan.limit.unwrap_or_else(|| batch.num_rows());
615        batch.top_k(col_index, k, sort_order)
616    }
617}