vibesql_executor/select/columnar/batch/
operations.rs

1//! Batch manipulation operations
2//!
3//! This module contains methods for accessing and manipulating
4//! `ColumnarBatch` and `ColumnArray` instances.
5
6#![allow(clippy::needless_range_loop)]
7
8use std::sync::Arc;
9
10use ahash::AHashSet;
11
12use crate::errors::ExecutorError;
13use vibesql_storage::Row;
14use vibesql_types::{DataType, Date, SqlValue, Time, Timestamp};
15
16use super::types::{ColumnArray, ColumnarBatch};
17
18impl ColumnarBatch {
19    /// Get the number of rows in this batch
20    pub fn row_count(&self) -> usize {
21        self.row_count
22    }
23
24    /// Get the number of columns in this batch
25    pub fn column_count(&self) -> usize {
26        self.columns.len()
27    }
28
29    /// Get a reference to a column array
30    pub fn column(&self, index: usize) -> Option<&ColumnArray> {
31        self.columns.get(index)
32    }
33
34    /// Get a mutable reference to a column array
35    pub fn column_mut(&mut self, index: usize) -> Option<&mut ColumnArray> {
36        self.columns.get_mut(index)
37    }
38
39    /// Add a column to the batch
40    pub fn add_column(&mut self, column: ColumnArray) -> Result<(), ExecutorError> {
41        // Verify column has correct length
42        let col_len = column.len();
43        if self.row_count > 0 && col_len != self.row_count {
44            return Err(ExecutorError::ColumnarLengthMismatch {
45                context: "add_column".to_string(),
46                expected: self.row_count,
47                actual: col_len,
48            });
49        }
50
51        if self.row_count == 0 {
52            self.row_count = col_len;
53        }
54
55        self.columns.push(column);
56        Ok(())
57    }
58
59    /// Set column names (for debugging)
60    pub fn set_column_names(&mut self, names: Vec<String>) {
61        self.column_names = Some(names);
62    }
63
64    /// Get column names
65    pub fn column_names(&self) -> Option<&[String]> {
66        self.column_names.as_deref()
67    }
68
69    /// Get column index by name
70    pub fn column_index_by_name(&self, name: &str) -> Option<usize> {
71        self.column_names.as_ref()?.iter().position(|n| n == name)
72    }
73
74    /// Get a value at a specific (row, column) position
75    pub fn get_value(&self, row_idx: usize, col_idx: usize) -> Result<SqlValue, ExecutorError> {
76        let column = self.column(col_idx).ok_or(ExecutorError::ColumnarColumnNotFound {
77            column_index: col_idx,
78            batch_columns: self.columns.len(),
79        })?;
80        column.get_value(row_idx)
81    }
82
83    /// Convert columnar batch back to row-oriented storage
84    pub fn to_rows(&self) -> Result<Vec<Row>, ExecutorError> {
85        let mut rows = Vec::with_capacity(self.row_count);
86
87        for row_idx in 0..self.row_count {
88            let mut values = Vec::with_capacity(self.columns.len());
89
90            for column in &self.columns {
91                let value = column.get_value(row_idx)?;
92                values.push(value);
93            }
94
95            rows.push(Row::new(values));
96        }
97
98        Ok(rows)
99    }
100
101    /// Deduplicate rows in the batch, returning a new batch with unique rows only
102    ///
103    /// Uses hash-based deduplication on all columns, preserving insertion order.
104    /// This implements DISTINCT semantics: NULL == NULL for uniqueness purposes.
105    ///
106    /// # Performance
107    ///
108    /// - O(n) time complexity where n is the number of rows
109    /// - Uses AHashSet for efficient duplicate detection
110    /// - Preserves the first occurrence of each unique row combination
111    ///
112    /// # Example
113    ///
114    /// ```text
115    /// // Original batch:
116    /// // [1, "A"], [2, "B"], [1, "A"], [3, "C"]
117    ///
118    /// // After deduplicate():
119    /// // [1, "A"], [2, "B"], [3, "C"]
120    /// ```
121    pub fn deduplicate(&self) -> Result<Self, ExecutorError> {
122        if self.row_count == 0 {
123            return Ok(self.clone());
124        }
125
126        // Track which row indices to keep (first occurrence of each unique row)
127        let mut seen: AHashSet<Vec<SqlValue>> = AHashSet::with_capacity(self.row_count);
128        let mut keep_indices: Vec<usize> = Vec::with_capacity(self.row_count);
129
130        for row_idx in 0..self.row_count {
131            // Extract all column values for this row as a key
132            let mut row_key = Vec::with_capacity(self.columns.len());
133            for col in &self.columns {
134                let value = col.get_value(row_idx)?;
135                row_key.push(value);
136            }
137
138            // If this row hasn't been seen, keep it
139            if seen.insert(row_key) {
140                keep_indices.push(row_idx);
141            }
142        }
143
144        // If no duplicates found, return self unchanged
145        if keep_indices.len() == self.row_count {
146            return Ok(self.clone());
147        }
148
149        log::debug!(
150            "Columnar deduplicate: {} rows -> {} unique rows",
151            self.row_count,
152            keep_indices.len()
153        );
154
155        // Build new batch with only the unique rows
156        self.select_rows(&keep_indices)
157    }
158
159    /// Select specific rows from the batch by index, returning a new batch
160    ///
161    /// # Arguments
162    ///
163    /// * `indices` - Row indices to select (must be valid for this batch)
164    ///
165    /// # Returns
166    ///
167    /// A new ColumnarBatch containing only the selected rows
168    pub fn select_rows(&self, indices: &[usize]) -> Result<Self, ExecutorError> {
169        if indices.is_empty() {
170            return Self::empty(self.columns.len());
171        }
172
173        let new_row_count = indices.len();
174        let mut new_columns = Vec::with_capacity(self.columns.len());
175
176        for column in &self.columns {
177            let new_column = column.select_rows(indices)?;
178            new_columns.push(new_column);
179        }
180
181        Ok(Self {
182            row_count: new_row_count,
183            columns: new_columns,
184            column_names: self.column_names.clone(),
185        })
186    }
187}
188
189impl ColumnArray {
190    /// Select specific rows from the column by index, returning a new column
191    fn select_rows(&self, indices: &[usize]) -> Result<Self, ExecutorError> {
192        match self {
193            Self::Int64(values, nulls) => {
194                let new_values: Vec<i64> = indices.iter().map(|&i| values[i]).collect();
195                let new_nulls = nulls
196                    .as_ref()
197                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
198                Ok(Self::Int64(Arc::new(new_values), new_nulls))
199            }
200            Self::Int32(values, nulls) => {
201                let new_values: Vec<i32> = indices.iter().map(|&i| values[i]).collect();
202                let new_nulls = nulls
203                    .as_ref()
204                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
205                Ok(Self::Int32(Arc::new(new_values), new_nulls))
206            }
207            Self::Float64(values, nulls) => {
208                let new_values: Vec<f64> = indices.iter().map(|&i| values[i]).collect();
209                let new_nulls = nulls
210                    .as_ref()
211                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
212                Ok(Self::Float64(Arc::new(new_values), new_nulls))
213            }
214            Self::Float32(values, nulls) => {
215                let new_values: Vec<f32> = indices.iter().map(|&i| values[i]).collect();
216                let new_nulls = nulls
217                    .as_ref()
218                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
219                Ok(Self::Float32(Arc::new(new_values), new_nulls))
220            }
221            Self::String(values, nulls) => {
222                let new_values: Vec<Arc<str>> = indices.iter().map(|&i| values[i].clone()).collect();
223                let new_nulls = nulls
224                    .as_ref()
225                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
226                Ok(Self::String(Arc::new(new_values), new_nulls))
227            }
228            Self::FixedString(values, nulls) => {
229                let new_values: Vec<Arc<str>> = indices.iter().map(|&i| values[i].clone()).collect();
230                let new_nulls = nulls
231                    .as_ref()
232                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
233                Ok(Self::FixedString(Arc::new(new_values), new_nulls))
234            }
235            Self::Date(values, nulls) => {
236                let new_values: Vec<i32> = indices.iter().map(|&i| values[i]).collect();
237                let new_nulls = nulls
238                    .as_ref()
239                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
240                Ok(Self::Date(Arc::new(new_values), new_nulls))
241            }
242            Self::Timestamp(values, nulls) => {
243                let new_values: Vec<i64> = indices.iter().map(|&i| values[i]).collect();
244                let new_nulls = nulls
245                    .as_ref()
246                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
247                Ok(Self::Timestamp(Arc::new(new_values), new_nulls))
248            }
249            Self::Boolean(values, nulls) => {
250                let new_values: Vec<u8> = indices.iter().map(|&i| values[i]).collect();
251                let new_nulls = nulls
252                    .as_ref()
253                    .map(|n| Arc::new(indices.iter().map(|&i| n[i]).collect::<Vec<_>>()));
254                Ok(Self::Boolean(Arc::new(new_values), new_nulls))
255            }
256            Self::Mixed(values) => {
257                let new_values: Vec<SqlValue> =
258                    indices.iter().map(|&i| values[i].clone()).collect();
259                Ok(Self::Mixed(Arc::new(new_values)))
260            }
261        }
262    }
263
264    /// Get the number of values in this column
265    pub fn len(&self) -> usize {
266        match self {
267            Self::Int64(v, _) => v.len(),
268            Self::Int32(v, _) => v.len(),
269            Self::Float64(v, _) => v.len(),
270            Self::Float32(v, _) => v.len(),
271            Self::String(v, _) => v.len(),
272            Self::FixedString(v, _) => v.len(),
273            Self::Date(v, _) => v.len(),
274            Self::Timestamp(v, _) => v.len(),
275            Self::Boolean(v, _) => v.len(),
276            Self::Mixed(v) => v.len(),
277        }
278    }
279
280    /// Check if column is empty
281    pub fn is_empty(&self) -> bool {
282        self.len() == 0
283    }
284
285    /// Get a value at the specified index as SqlValue
286    pub fn get_value(&self, index: usize) -> Result<SqlValue, ExecutorError> {
287        match self {
288            Self::Int64(values, nulls) => {
289                if let Some(null_mask) = nulls {
290                    if null_mask.get(index).copied().unwrap_or(false) {
291                        return Ok(SqlValue::Null);
292                    }
293                }
294                values
295                    .get(index)
296                    .map(|v| SqlValue::Integer(*v))
297                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
298            }
299
300            Self::Float64(values, nulls) => {
301                if let Some(null_mask) = nulls {
302                    if null_mask.get(index).copied().unwrap_or(false) {
303                        return Ok(SqlValue::Null);
304                    }
305                }
306                values
307                    .get(index)
308                    .map(|v| SqlValue::Double(*v))
309                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
310            }
311
312            Self::String(values, nulls) => {
313                if let Some(null_mask) = nulls {
314                    if null_mask.get(index).copied().unwrap_or(false) {
315                        return Ok(SqlValue::Null);
316                    }
317                }
318                values
319                    .get(index)
320                    .map(|v| SqlValue::Varchar(arcstr::ArcStr::from(v.as_ref())))
321                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
322            }
323
324            Self::Boolean(values, nulls) => {
325                if let Some(null_mask) = nulls {
326                    if null_mask.get(index).copied().unwrap_or(false) {
327                        return Ok(SqlValue::Null);
328                    }
329                }
330                values
331                    .get(index)
332                    .map(|v| SqlValue::Boolean(*v != 0))
333                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
334            }
335
336            Self::Mixed(values) => {
337                values.get(index).cloned().ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
338            }
339
340            Self::Int32(values, nulls) => {
341                if let Some(null_mask) = nulls {
342                    if null_mask.get(index).copied().unwrap_or(false) {
343                        return Ok(SqlValue::Null);
344                    }
345                }
346                values
347                    .get(index)
348                    .map(|v| SqlValue::Integer(*v as i64))
349                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
350            }
351
352            Self::Float32(values, nulls) => {
353                if let Some(null_mask) = nulls {
354                    if null_mask.get(index).copied().unwrap_or(false) {
355                        return Ok(SqlValue::Null);
356                    }
357                }
358                values
359                    .get(index)
360                    .map(|v| SqlValue::Real(*v))
361                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
362            }
363
364            Self::FixedString(values, nulls) => {
365                if let Some(null_mask) = nulls {
366                    if null_mask.get(index).copied().unwrap_or(false) {
367                        return Ok(SqlValue::Null);
368                    }
369                }
370                values
371                    .get(index)
372                    .map(|v| SqlValue::Character(arcstr::ArcStr::from(v.as_ref())))
373                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
374            }
375
376            Self::Date(values, nulls) => {
377                if let Some(null_mask) = nulls {
378                    if null_mask.get(index).copied().unwrap_or(false) {
379                        return Ok(SqlValue::Null);
380                    }
381                }
382                values
383                    .get(index)
384                    .map(|v| SqlValue::Date(days_since_epoch_to_date(*v)))
385                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
386            }
387
388            Self::Timestamp(values, nulls) => {
389                if let Some(null_mask) = nulls {
390                    if null_mask.get(index).copied().unwrap_or(false) {
391                        return Ok(SqlValue::Null);
392                    }
393                }
394                values
395                    .get(index)
396                    .map(|v| SqlValue::Timestamp(microseconds_to_timestamp(*v)))
397                    .ok_or(ExecutorError::ColumnIndexOutOfBounds { index })
398            }
399        }
400    }
401
402    /// Get the data type of this column
403    pub fn data_type(&self) -> DataType {
404        match self {
405            Self::Int64(_, _) => DataType::Integer,
406            Self::Int32(_, _) => DataType::Integer,
407            Self::Float64(_, _) => DataType::DoublePrecision,
408            Self::Float32(_, _) => DataType::Real,
409            Self::String(_, _) => DataType::Varchar { max_length: None },
410            Self::FixedString(_, _) => DataType::Character { length: 255 },
411            Self::Date(_, _) => DataType::Date,
412            Self::Timestamp(_, _) => DataType::Timestamp { with_timezone: false },
413            Self::Boolean(_, _) => DataType::Boolean,
414            Self::Mixed(_) => DataType::Varchar { max_length: None }, // fallback
415        }
416    }
417
418    /// Get raw i64 slice (for SIMD operations)
419    pub fn as_i64(&self) -> Option<(&[i64], Option<&[bool]>)> {
420        match self {
421            Self::Int64(values, nulls) => {
422                Some((values.as_slice(), nulls.as_ref().map(|n| n.as_slice())))
423            }
424            _ => None,
425        }
426    }
427
428    /// Get raw f64 slice (for SIMD operations)
429    pub fn as_f64(&self) -> Option<(&[f64], Option<&[bool]>)> {
430        match self {
431            Self::Float64(values, nulls) => {
432                Some((values.as_slice(), nulls.as_ref().map(|n| n.as_slice())))
433            }
434            _ => None,
435        }
436    }
437}
438
439/// Convert days since Unix epoch to Date
440fn days_since_epoch_to_date(days: i32) -> Date {
441    // Simplified conversion: start from 1970-01-01 and count forward
442    let mut year = 1970;
443    let mut remaining_days = days;
444
445    // Handle years
446    loop {
447        let year_days =
448            if year % 4 == 0 && (year % 100 != 0 || year % 400 == 0) { 366 } else { 365 };
449        if remaining_days < year_days {
450            break;
451        }
452        remaining_days -= year_days;
453        year += 1;
454    }
455
456    // Handle months
457    let is_leap = year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);
458    let month_lengths = if is_leap {
459        [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
460    } else {
461        [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
462    };
463
464    let mut month = 1;
465    for &days_in_month in &month_lengths {
466        if remaining_days < days_in_month {
467            break;
468        }
469        remaining_days -= days_in_month;
470        month += 1;
471    }
472
473    let day = remaining_days + 1;
474
475    Date::new(year, month as u8, day as u8).unwrap_or_else(|_| Date::new(1970, 1, 1).unwrap())
476}
477
478/// Convert microseconds since Unix epoch to Timestamp
479fn microseconds_to_timestamp(micros: i64) -> Timestamp {
480    let days = (micros / 86_400_000_000) as i32;
481    let remaining_micros = micros % 86_400_000_000;
482
483    let date = days_since_epoch_to_date(days);
484
485    let hours = (remaining_micros / 3_600_000_000) as u8;
486    let remaining_micros = remaining_micros % 3_600_000_000;
487    let minutes = (remaining_micros / 60_000_000) as u8;
488    let remaining_micros = remaining_micros % 60_000_000;
489    let seconds = (remaining_micros / 1_000_000) as u8;
490    let nanoseconds = ((remaining_micros % 1_000_000) * 1_000) as u32;
491
492    let time = Time::new(hours, minutes, seconds, nanoseconds)
493        .unwrap_or_else(|_| Time::new(0, 0, 0, 0).unwrap());
494
495    Timestamp::new(date, time)
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a VARCHAR `SqlValue`.
    fn varchar(text: &str) -> SqlValue {
        SqlValue::Varchar(arcstr::ArcStr::from(text))
    }

    /// Build an `(INTEGER, VARCHAR)` row.
    fn int_text_row(id: i64, text: &str) -> Row {
        Row::new(vec![SqlValue::Integer(id), varchar(text)])
    }

    /// Build an `(INTEGER, DOUBLE)` row.
    fn int_double_row(id: i64, amount: f64) -> Row {
        Row::new(vec![SqlValue::Integer(id), SqlValue::Double(amount)])
    }

    #[test]
    fn test_batch_to_rows_roundtrip() {
        let input = vec![int_double_row(1, 10.5), int_double_row(2, 20.5)];

        let batch = ColumnarBatch::from_rows(&input).unwrap();
        let output = batch.to_rows().unwrap();

        // Rows -> batch -> rows must reproduce every cell.
        assert_eq!(output.len(), input.len());
        for (expected, actual) in input.iter().zip(output.iter()) {
            assert_eq!(expected.len(), actual.len());
            for col in 0..expected.len() {
                assert_eq!(expected.get(col), actual.get(col));
            }
        }
    }

    #[test]
    fn test_simd_column_access() {
        let rows =
            vec![int_double_row(1, 10.5), int_double_row(2, 20.5), int_double_row(3, 30.5)];
        let batch = ColumnarBatch::from_rows(&rows).unwrap();

        // The integer column exposes a raw i64 slice with no null mask.
        match batch.column(0).unwrap().as_i64() {
            Some((values, nulls)) => {
                assert_eq!(values, &[1, 2, 3]);
                assert!(nulls.is_none());
            }
            None => panic!("Expected i64 slice"),
        }

        // The double column exposes a raw f64 slice with no null mask.
        match batch.column(1).unwrap().as_f64() {
            Some((values, nulls)) => {
                assert_eq!(values, &[10.5, 20.5, 30.5]);
                assert!(nulls.is_none());
            }
            None => panic!("Expected f64 slice"),
        }
    }

    #[test]
    fn test_deduplicate_with_duplicates() {
        let rows = vec![
            int_text_row(1, "A"),
            int_text_row(2, "B"),
            int_text_row(1, "A"), // duplicate of row 0
            int_text_row(3, "C"),
            int_text_row(2, "B"), // duplicate of row 1
        ];

        let batch = ColumnarBatch::from_rows(&rows).unwrap();
        assert_eq!(batch.row_count(), 5);

        let unique = batch.deduplicate().unwrap();
        assert_eq!(unique.row_count(), 3);

        // First occurrences survive, in their original order.
        let unique_rows = unique.to_rows().unwrap();
        let expected = [(1, "A"), (2, "B"), (3, "C")];
        for (idx, &(id, text)) in expected.iter().enumerate() {
            assert_eq!(unique_rows[idx].get(0), Some(&SqlValue::Integer(id)));
            assert_eq!(unique_rows[idx].get(1), Some(&varchar(text)));
        }
    }

    #[test]
    fn test_deduplicate_no_duplicates() {
        // Every row is already unique, so nothing is dropped.
        let rows: Vec<Row> =
            (1..=3).map(|id| Row::new(vec![SqlValue::Integer(id)])).collect();

        let unique = ColumnarBatch::from_rows(&rows).unwrap().deduplicate().unwrap();

        assert_eq!(unique.row_count(), 3);
    }

    #[test]
    fn test_deduplicate_with_nulls() {
        // NULL must equal NULL when deciding whether two rows are duplicates.
        let rows = vec![
            Row::new(vec![SqlValue::Null, varchar("A")]),
            int_text_row(1, "B"),
            Row::new(vec![SqlValue::Null, varchar("A")]), // duplicate of row 0
        ];

        let unique = ColumnarBatch::from_rows(&rows).unwrap().deduplicate().unwrap();

        assert_eq!(unique.row_count(), 2);
    }

    #[test]
    fn test_deduplicate_empty_batch() {
        let unique = ColumnarBatch::new(2).deduplicate().unwrap();
        assert_eq!(unique.row_count(), 0);
    }

    #[test]
    fn test_select_rows() {
        let rows = vec![
            int_text_row(1, "A"),
            int_text_row(2, "B"),
            int_text_row(3, "C"),
            int_text_row(4, "D"),
        ];
        let batch = ColumnarBatch::from_rows(&rows).unwrap();

        // Pick the rows at indices 0 and 2.
        let picked = batch.select_rows(&[0, 2]).unwrap();
        assert_eq!(picked.row_count(), 2);

        let picked_rows = picked.to_rows().unwrap();
        assert_eq!(picked_rows[0].get(0), Some(&SqlValue::Integer(1)));
        assert_eq!(picked_rows[1].get(0), Some(&SqlValue::Integer(3)));
    }
}