1use super::{AggregateFunction, OrderDirection, QueryPlan};
11use crate::storage::StorageEngine;
12use crate::topk::{SortOrder, TopKSelection};
13use crate::{Backend, Error, Result};
14use arrow::array::{
15 Array, ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch,
16};
17use arrow::compute;
18use arrow::datatypes::{DataType, Field, Schema};
19use std::sync::Arc;
20
/// Executes a parsed [`QueryPlan`] against the record batches held by a
/// [`StorageEngine`].
pub struct QueryExecutor {
    // Execution backend chosen at construction time. The constructors store it,
    // but no execution path in this impl reads it yet, hence the `dead_code` allow.
    #[allow(dead_code)]
    backend: Backend,
}
26
27impl Default for QueryExecutor {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl QueryExecutor {
34 #[must_use]
36 pub const fn new() -> Self {
37 Self { backend: Backend::CostBased }
38 }
39
40 #[must_use]
42 pub const fn with_backend(backend: Backend) -> Self {
43 Self { backend }
44 }
45
46 pub fn execute(&self, plan: &QueryPlan, storage: &StorageEngine) -> Result<RecordBatch> {
80 let batches = storage.batches();
82 if batches.is_empty() {
83 return Err(Error::InvalidInput("No data in storage".to_string()));
84 }
85
86 let combined = Self::combine_batches(batches)?;
88
89 let filtered = if let Some(ref filter_expr) = plan.filter {
91 Self::apply_filter(&combined, filter_expr)?
92 } else {
93 combined
94 };
95
96 let result = if plan.aggregations.is_empty() {
98 Self::project_columns(&filtered, &plan.columns)?
100 } else {
101 Self::execute_aggregations(&filtered, plan)?
102 };
103
104 let result = if !plan.order_by.is_empty() {
106 Self::apply_order_by_limit(&result, plan)?
107 } else if let Some(limit) = plan.limit {
108 result.slice(0, limit.min(result.num_rows()))
110 } else {
111 result
112 };
113
114 Ok(result)
115 }
116
117 fn combine_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
119 if batches.len() == 1 {
120 return Ok(batches[0].clone());
121 }
122
123 compute::concat_batches(&batches[0].schema(), batches)
125 .map_err(|e| Error::StorageError(format!("Failed to combine batches: {e}")))
126 }
127
128 fn apply_filter(batch: &RecordBatch, filter_expr: &str) -> Result<RecordBatch> {
130 let parts: Vec<&str> = filter_expr.split_whitespace().collect();
133 if parts.len() < 3 {
134 return Err(Error::ParseError(format!("Invalid filter expression: {filter_expr}")));
135 }
136
137 let column_name = parts[0];
138 let op = parts[1];
139 let value_str = parts.get(2..).unwrap_or(&[]).join(" ");
140
141 let schema = batch.schema();
143 let column_index = schema
144 .fields()
145 .iter()
146 .position(|f| f.name() == column_name)
147 .ok_or_else(|| Error::InvalidInput(format!("Column not found: {column_name}")))?;
148
149 let column = batch.column(column_index);
150
151 let mask = match column.data_type() {
153 DataType::Int32 => {
154 let array = column
155 .as_any()
156 .downcast_ref::<Int32Array>()
157 .ok_or_else(|| Error::Other("Failed to downcast to Int32Array".to_string()))?;
158 let value: i32 = value_str
159 .parse()
160 .map_err(|_| Error::ParseError(format!("Invalid Int32 value: {value_str}")))?;
161 Self::build_comparison_mask_i32(array, op, value)?
162 }
163 DataType::Int64 => {
164 let array = column
165 .as_any()
166 .downcast_ref::<Int64Array>()
167 .ok_or_else(|| Error::Other("Failed to downcast to Int64Array".to_string()))?;
168 let value: i64 = value_str
169 .parse()
170 .map_err(|_| Error::ParseError(format!("Invalid Int64 value: {value_str}")))?;
171 Self::build_comparison_mask_i64(array, op, value)?
172 }
173 DataType::Float32 => {
174 let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
175 Error::Other("Failed to downcast to Float32Array".to_string())
176 })?;
177 let value: f32 = value_str.parse().map_err(|_| {
178 Error::ParseError(format!("Invalid Float32 value: {value_str}"))
179 })?;
180 Self::build_comparison_mask_f32(array, op, value)?
181 }
182 DataType::Float64 => {
183 let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
184 Error::Other("Failed to downcast to Float64Array".to_string())
185 })?;
186 let value: f64 = value_str.parse().map_err(|_| {
187 Error::ParseError(format!("Invalid Float64 value: {value_str}"))
188 })?;
189 Self::build_comparison_mask_f64(array, op, value)?
190 }
191 dt => {
192 return Err(Error::InvalidInput(format!(
193 "Filter not supported for data type: {dt:?}"
194 )))
195 }
196 };
197
198 compute::filter_record_batch(batch, &mask)
200 .map_err(|e| Error::StorageError(format!("Failed to apply filter: {e}")))
201 }
202
203 #[allow(clippy::unnecessary_wraps)]
204 fn build_comparison_mask_i32(
205 array: &Int32Array,
206 op: &str,
207 value: i32,
208 ) -> Result<arrow::array::BooleanArray> {
209 use arrow::array::BooleanArray;
210 let values: Vec<bool> = (0..array.len())
211 .map(|i| {
212 if array.is_null(i) {
213 false
214 } else {
215 let v = array.value(i);
216 match op {
217 ">" => v > value,
218 ">=" => v >= value,
219 "<" => v < value,
220 "<=" => v <= value,
221 "=" => v == value,
222 "!=" | "<>" => v != value,
223 _ => false,
224 }
225 }
226 })
227 .collect();
228 Ok(BooleanArray::from(values))
229 }
230
231 #[allow(clippy::unnecessary_wraps)]
232 fn build_comparison_mask_i64(
233 array: &Int64Array,
234 op: &str,
235 value: i64,
236 ) -> Result<arrow::array::BooleanArray> {
237 use arrow::array::BooleanArray;
238 let values: Vec<bool> = (0..array.len())
239 .map(|i| {
240 if array.is_null(i) {
241 false
242 } else {
243 let v = array.value(i);
244 match op {
245 ">" => v > value,
246 ">=" => v >= value,
247 "<" => v < value,
248 "<=" => v <= value,
249 "=" => v == value,
250 "!=" | "<>" => v != value,
251 _ => false,
252 }
253 }
254 })
255 .collect();
256 Ok(BooleanArray::from(values))
257 }
258
259 #[allow(clippy::unnecessary_wraps)]
260 fn build_comparison_mask_f32(
261 array: &Float32Array,
262 op: &str,
263 value: f32,
264 ) -> Result<arrow::array::BooleanArray> {
265 use arrow::array::BooleanArray;
266 let values: Vec<bool> = (0..array.len())
267 .map(|i| {
268 if array.is_null(i) {
269 false
270 } else {
271 let v = array.value(i);
272 match op {
273 ">" => v > value,
274 ">=" => v >= value,
275 "<" => v < value,
276 "<=" => v <= value,
277 "=" => (v - value).abs() < f32::EPSILON,
278 "!=" | "<>" => (v - value).abs() >= f32::EPSILON,
279 _ => false,
280 }
281 }
282 })
283 .collect();
284 Ok(BooleanArray::from(values))
285 }
286
287 #[allow(clippy::unnecessary_wraps)]
288 fn build_comparison_mask_f64(
289 array: &Float64Array,
290 op: &str,
291 value: f64,
292 ) -> Result<arrow::array::BooleanArray> {
293 use arrow::array::BooleanArray;
294 let values: Vec<bool> = (0..array.len())
295 .map(|i| {
296 if array.is_null(i) {
297 false
298 } else {
299 let v = array.value(i);
300 match op {
301 ">" => v > value,
302 ">=" => v >= value,
303 "<" => v < value,
304 "<=" => v <= value,
305 "=" => (v - value).abs() < f64::EPSILON,
306 "!=" | "<>" => (v - value).abs() >= f64::EPSILON,
307 _ => false,
308 }
309 }
310 })
311 .collect();
312 Ok(BooleanArray::from(values))
313 }
314
315 fn project_columns(batch: &RecordBatch, columns: &[String]) -> Result<RecordBatch> {
317 if columns.len() == 1 && columns[0] == "*" {
318 return Ok(batch.clone());
319 }
320
321 let schema = batch.schema();
322 let mut new_columns = Vec::new();
323 let mut new_fields = Vec::new();
324
325 for col_name in columns {
326 let index = schema
327 .fields()
328 .iter()
329 .position(|f| f.name() == col_name)
330 .ok_or_else(|| Error::InvalidInput(format!("Column not found: {col_name}")))?;
331
332 new_columns.push(batch.column(index).clone());
333 new_fields.push(schema.field(index).clone());
334 }
335
336 let new_schema = Arc::new(Schema::new(new_fields));
337 RecordBatch::try_new(new_schema, new_columns)
338 .map_err(|e| Error::StorageError(format!("Failed to project columns: {e}")))
339 }
340
341 fn execute_aggregations(batch: &RecordBatch, plan: &QueryPlan) -> Result<RecordBatch> {
343 if !plan.group_by.is_empty() {
345 return Err(Error::InvalidInput(
346 "GROUP BY aggregations not yet implemented in Phase 1".to_string(),
347 ));
348 }
349
350 let mut result_columns: Vec<ArrayRef> = Vec::new();
351 let mut result_fields: Vec<Field> = Vec::new();
352
353 for (agg_func, col_name, alias) in &plan.aggregations {
354 let result_name = alias.as_deref().unwrap_or(col_name);
355
356 let schema = batch.schema();
358 let col_index = schema
359 .fields()
360 .iter()
361 .position(|f| f.name() == col_name || col_name == "*")
362 .ok_or_else(|| Error::InvalidInput(format!("Column not found: {col_name}")))?;
363
364 let column = batch.column(col_index);
365
366 let (result_value, result_type) =
368 Self::execute_single_aggregation(*agg_func, column, batch.num_rows())?;
369
370 result_columns.push(result_value);
371 result_fields.push(Field::new(result_name, result_type, false));
372 }
373
374 let result_schema = Arc::new(Schema::new(result_fields));
375 RecordBatch::try_new(result_schema, result_columns)
376 .map_err(|e| Error::StorageError(format!("Failed to create result batch: {e}")))
377 }
378
379 fn execute_single_aggregation(
381 func: AggregateFunction,
382 column: &ArrayRef,
383 num_rows: usize,
384 ) -> Result<(ArrayRef, DataType)> {
385 match column.data_type() {
386 DataType::Int32 => {
387 let array = column
388 .as_any()
389 .downcast_ref::<Int32Array>()
390 .ok_or_else(|| Error::Other("Failed to downcast to Int32Array".to_string()))?;
391 Self::aggregate_i32(func, array, num_rows)
392 }
393 DataType::Int64 => {
394 let array = column
395 .as_any()
396 .downcast_ref::<Int64Array>()
397 .ok_or_else(|| Error::Other("Failed to downcast to Int64Array".to_string()))?;
398 Self::aggregate_i64(func, array, num_rows)
399 }
400 DataType::Float32 => {
401 let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
402 Error::Other("Failed to downcast to Float32Array".to_string())
403 })?;
404 Self::aggregate_f32(func, array, num_rows)
405 }
406 DataType::Float64 => {
407 let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
408 Error::Other("Failed to downcast to Float64Array".to_string())
409 })?;
410 Self::aggregate_f64(func, array, num_rows)
411 }
412 dt => {
413 Err(Error::InvalidInput(format!("Aggregation not supported for data type: {dt:?}")))
414 }
415 }
416 }
417
418 #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
419 fn aggregate_i32(
420 func: AggregateFunction,
421 array: &Int32Array,
422 num_rows: usize,
423 ) -> Result<(ArrayRef, DataType)> {
424 match func {
425 AggregateFunction::Sum => {
426 let sum: i64 = (0..array.len())
427 .filter(|&i| !array.is_null(i))
428 .map(|i| i64::from(array.value(i)))
429 .sum();
430 Ok((Arc::new(Int64Array::from(vec![sum])), DataType::Int64))
431 }
432 AggregateFunction::Avg => {
433 let sum: f64 = (0..array.len())
434 .filter(|&i| !array.is_null(i))
435 .map(|i| f64::from(array.value(i)))
436 .sum();
437 let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
438 let avg = if count > 0 { sum / count as f64 } else { 0.0 };
439 Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
440 }
441 AggregateFunction::Count => {
442 Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
443 }
444 AggregateFunction::Min => {
445 let min = (0..array.len())
446 .filter(|&i| !array.is_null(i))
447 .map(|i| array.value(i))
448 .min()
449 .unwrap_or(0);
450 Ok((Arc::new(Int32Array::from(vec![min])), DataType::Int32))
451 }
452 AggregateFunction::Max => {
453 let max = (0..array.len())
454 .filter(|&i| !array.is_null(i))
455 .map(|i| array.value(i))
456 .max()
457 .unwrap_or(0);
458 Ok((Arc::new(Int32Array::from(vec![max])), DataType::Int32))
459 }
460 }
461 }
462
463 #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
464 fn aggregate_i64(
465 func: AggregateFunction,
466 array: &Int64Array,
467 num_rows: usize,
468 ) -> Result<(ArrayRef, DataType)> {
469 match func {
470 AggregateFunction::Sum => {
471 let sum: i64 =
472 (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
473 Ok((Arc::new(Int64Array::from(vec![sum])), DataType::Int64))
474 }
475 AggregateFunction::Avg => {
476 let sum: f64 = (0..array.len())
477 .filter(|&i| !array.is_null(i))
478 .map(|i| array.value(i) as f64)
479 .sum();
480 let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
481 let avg = if count > 0 { sum / count as f64 } else { 0.0 };
482 Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
483 }
484 AggregateFunction::Count => {
485 Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
486 }
487 AggregateFunction::Min => {
488 let min = (0..array.len())
489 .filter(|&i| !array.is_null(i))
490 .map(|i| array.value(i))
491 .min()
492 .unwrap_or(0);
493 Ok((Arc::new(Int64Array::from(vec![min])), DataType::Int64))
494 }
495 AggregateFunction::Max => {
496 let max = (0..array.len())
497 .filter(|&i| !array.is_null(i))
498 .map(|i| array.value(i))
499 .max()
500 .unwrap_or(0);
501 Ok((Arc::new(Int64Array::from(vec![max])), DataType::Int64))
502 }
503 }
504 }
505
506 #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
507 fn aggregate_f32(
508 func: AggregateFunction,
509 array: &Float32Array,
510 num_rows: usize,
511 ) -> Result<(ArrayRef, DataType)> {
512 match func {
513 AggregateFunction::Sum => {
514 let sum: f32 =
515 (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
516 Ok((Arc::new(Float32Array::from(vec![sum])), DataType::Float32))
517 }
518 AggregateFunction::Avg => {
519 let sum: f64 = (0..array.len())
520 .filter(|&i| !array.is_null(i))
521 .map(|i| f64::from(array.value(i)))
522 .sum();
523 let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
524 let avg = if count > 0 { sum / count as f64 } else { 0.0 };
525 Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
526 }
527 AggregateFunction::Count => {
528 Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
529 }
530 AggregateFunction::Min => {
531 let min = (0..array.len())
532 .filter(|&i| !array.is_null(i))
533 .map(|i| array.value(i))
534 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
535 .unwrap_or(0.0);
536 Ok((Arc::new(Float32Array::from(vec![min])), DataType::Float32))
537 }
538 AggregateFunction::Max => {
539 let max = (0..array.len())
540 .filter(|&i| !array.is_null(i))
541 .map(|i| array.value(i))
542 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
543 .unwrap_or(0.0);
544 Ok((Arc::new(Float32Array::from(vec![max])), DataType::Float32))
545 }
546 }
547 }
548
549 #[allow(clippy::cast_precision_loss, clippy::cast_possible_wrap, clippy::unnecessary_wraps)]
550 fn aggregate_f64(
551 func: AggregateFunction,
552 array: &Float64Array,
553 num_rows: usize,
554 ) -> Result<(ArrayRef, DataType)> {
555 match func {
556 AggregateFunction::Sum => {
557 let sum: f64 =
558 (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
559 Ok((Arc::new(Float64Array::from(vec![sum])), DataType::Float64))
560 }
561 AggregateFunction::Avg => {
562 let sum: f64 =
563 (0..array.len()).filter(|&i| !array.is_null(i)).map(|i| array.value(i)).sum();
564 let count = (0..array.len()).filter(|&i| !array.is_null(i)).count();
565 let avg = if count > 0 { sum / count as f64 } else { 0.0 };
566 Ok((Arc::new(Float64Array::from(vec![avg])), DataType::Float64))
567 }
568 AggregateFunction::Count => {
569 Ok((Arc::new(Int64Array::from(vec![num_rows as i64])), DataType::Int64))
570 }
571 AggregateFunction::Min => {
572 let min = (0..array.len())
573 .filter(|&i| !array.is_null(i))
574 .map(|i| array.value(i))
575 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
576 .unwrap_or(0.0);
577 Ok((Arc::new(Float64Array::from(vec![min])), DataType::Float64))
578 }
579 AggregateFunction::Max => {
580 let max = (0..array.len())
581 .filter(|&i| !array.is_null(i))
582 .map(|i| array.value(i))
583 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
584 .unwrap_or(0.0);
585 Ok((Arc::new(Float64Array::from(vec![max])), DataType::Float64))
586 }
587 }
588 }
589
590 fn apply_order_by_limit(batch: &RecordBatch, plan: &QueryPlan) -> Result<RecordBatch> {
592 if plan.order_by.is_empty() {
593 return Ok(batch.clone());
594 }
595
596 let (col_name, direction) = &plan.order_by[0];
598
599 let schema = batch.schema();
601 let col_index = schema
602 .fields()
603 .iter()
604 .position(|f| f.name() == col_name)
605 .ok_or_else(|| Error::InvalidInput(format!("Column not found: {col_name}")))?;
606
607 let sort_order = match direction {
609 OrderDirection::Asc => SortOrder::Ascending,
610 OrderDirection::Desc => SortOrder::Descending,
611 };
612
613 let k = plan.limit.unwrap_or_else(|| batch.num_rows());
615 batch.top_k(col_index, k, sort_order)
616 }
617}