// sql_cli/data/computed_view.rs

use crate::data::datatable::{DataTable, DataValue};
use crate::sql::recursive_parser::SqlExpression;
use std::collections::HashMap;
use std::sync::Arc;
4
5/// Represents a column in a computed view - either original or derived
6#[derive(Debug, Clone)]
7pub enum ViewColumn {
8    /// Direct reference to a column in the original table
9    Original {
10        source_index: usize, // Index in the original DataTable
11        name: String,        // May be aliased from original
12    },
13    /// Computed/derived column with cached results
14    Derived {
15        name: String,
16        expression: SqlExpression,
17        cached_values: Vec<DataValue>, // Pre-computed for all visible rows
18    },
19}
20
21/// A view over a DataTable that can contain both original and computed columns
22/// This is query-scoped - exists only for the duration of one query result
23#[derive(Debug, Clone)]
24pub struct ComputedDataView {
25    /// Reference to the original source table (never modified)
26    source_table: Arc<DataTable>,
27
28    /// Column definitions (mix of original and derived)
29    columns: Vec<ViewColumn>,
30
31    /// Which rows from the source table are visible (after WHERE clause filtering)
32    /// Indices refer to rows in source_table
33    visible_rows: Vec<usize>,
34}
35
36impl ComputedDataView {
37    /// Create a new computed view with specified columns and visible rows
38    pub fn new(
39        source_table: Arc<DataTable>,
40        columns: Vec<ViewColumn>,
41        visible_rows: Vec<usize>,
42    ) -> Self {
43        Self {
44            source_table,
45            columns,
46            visible_rows,
47        }
48    }
49
50    /// Get the number of visible rows
51    pub fn row_count(&self) -> usize {
52        self.visible_rows.len()
53    }
54
55    /// Get the number of columns (original + derived)
56    pub fn column_count(&self) -> usize {
57        self.columns.len()
58    }
59
60    /// Get column names
61    pub fn column_names(&self) -> Vec<String> {
62        self.columns
63            .iter()
64            .map(|col| match col {
65                ViewColumn::Original { name, .. } => name.clone(),
66                ViewColumn::Derived { name, .. } => name.clone(),
67            })
68            .collect()
69    }
70
71    /// Get a value at a specific row and column
72    pub fn get_value(&self, row_idx: usize, col_idx: usize) -> Option<DataValue> {
73        // Check bounds
74        if row_idx >= self.visible_rows.len() || col_idx >= self.columns.len() {
75            return None;
76        }
77
78        match &self.columns[col_idx] {
79            ViewColumn::Original { source_index, .. } => {
80                // Get the actual row index in the source table
81                let source_row_idx = self.visible_rows[row_idx];
82
83                // Get value from original table
84                self.source_table
85                    .get_row(source_row_idx)
86                    .and_then(|row| row.get(*source_index))
87                    .cloned()
88            }
89            ViewColumn::Derived { cached_values, .. } => {
90                // Return pre-computed value
91                cached_values.get(row_idx).cloned()
92            }
93        }
94    }
95
96    /// Get all values for a row
97    pub fn get_row_values(&self, row_idx: usize) -> Option<Vec<DataValue>> {
98        if row_idx >= self.visible_rows.len() {
99            return None;
100        }
101
102        let mut values = Vec::new();
103        for col_idx in 0..self.columns.len() {
104            values.push(self.get_value(row_idx, col_idx)?);
105        }
106        Some(values)
107    }
108
109    /// Get the underlying source table (for reference, not modification)
110    pub fn source_table(&self) -> &Arc<DataTable> {
111        &self.source_table
112    }
113
114    /// Get the visible row indices (useful for debugging)
115    pub fn visible_rows(&self) -> &[usize] {
116        &self.visible_rows
117    }
118
119    /// Check if a column is derived
120    pub fn is_derived_column(&self, col_idx: usize) -> bool {
121        matches!(self.columns.get(col_idx), Some(ViewColumn::Derived { .. }))
122    }
123
124    /// Create a simple view showing all columns from source (no computations)
125    pub fn from_source_all_columns(source: Arc<DataTable>) -> Self {
126        let columns: Vec<ViewColumn> = source
127            .column_names()
128            .into_iter()
129            .enumerate()
130            .map(|(idx, name)| ViewColumn::Original {
131                source_index: idx,
132                name,
133            })
134            .collect();
135
136        let visible_rows: Vec<usize> = (0..source.row_count()).collect();
137
138        Self::new(source, columns, visible_rows)
139    }
140
141    /// Create a view with filtered rows (WHERE clause applied)
142    pub fn with_filtered_rows(mut self, row_indices: Vec<usize>) -> Self {
143        self.visible_rows = row_indices;
144
145        // Update cached values for derived columns
146        for col in &mut self.columns {
147            if let ViewColumn::Derived { cached_values, .. } = col {
148                // Filter cached values to match new row set
149                let mut new_cache = Vec::new();
150                for &row_idx in &self.visible_rows {
151                    if row_idx < cached_values.len() {
152                        new_cache.push(cached_values[row_idx].clone());
153                    }
154                }
155                *cached_values = new_cache;
156            }
157        }
158
159        self
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Two-column table (`a`: integers, `b`: floats) with rows (10, 2.5) and (20, 3.5).
    fn create_test_table() -> Arc<DataTable> {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
            ]))
            .unwrap();

        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(20),
                DataValue::Float(3.5),
            ]))
            .unwrap();

        Arc::new(table)
    }

    /// Original column `a` followed by a derived column `doubled` whose cache
    /// holds one value per source row (view rows `[0, 1]`).
    fn mixed_columns() -> Vec<ViewColumn> {
        vec![
            ViewColumn::Original {
                source_index: 0,
                name: "a".to_string(),
            },
            ViewColumn::Derived {
                name: "doubled".to_string(),
                expression: SqlExpression::Column("a".to_string()), // Placeholder
                cached_values: vec![
                    DataValue::Integer(20), // 10 * 2
                    DataValue::Integer(40), // 20 * 2
                ],
            },
        ]
    }

    #[test]
    fn test_original_columns_view() {
        let table = create_test_table();
        let view = ComputedDataView::from_source_all_columns(table);

        assert_eq!(view.row_count(), 2);
        assert_eq!(view.column_count(), 2);
        assert_eq!(view.column_names(), vec!["a", "b"]);

        // Check values
        assert_eq!(view.get_value(0, 0), Some(DataValue::Integer(10)));
        assert_eq!(view.get_value(0, 1), Some(DataValue::Float(2.5)));
        assert_eq!(view.get_value(1, 0), Some(DataValue::Integer(20)));
    }

    #[test]
    fn test_mixed_columns() {
        let table = create_test_table();
        let view = ComputedDataView::new(table, mixed_columns(), vec![0, 1]);

        assert_eq!(view.column_count(), 2);
        assert_eq!(view.column_names(), vec!["a", "doubled"]);

        // Original column
        assert_eq!(view.get_value(0, 0), Some(DataValue::Integer(10)));

        // Derived column
        assert_eq!(view.get_value(0, 1), Some(DataValue::Integer(20)));
        assert_eq!(view.get_value(1, 1), Some(DataValue::Integer(40)));
    }

    #[test]
    fn test_filtered_rows() {
        let table = create_test_table();
        let view = ComputedDataView::from_source_all_columns(table).with_filtered_rows(vec![1]); // Only second row visible

        assert_eq!(view.row_count(), 1);
        assert_eq!(view.get_value(0, 0), Some(DataValue::Integer(20)));
        assert_eq!(view.visible_rows(), &[1][..]);
    }

    #[test]
    fn test_filtered_rows_keeps_derived_values_aligned() {
        let table = create_test_table();
        let view = ComputedDataView::new(table, mixed_columns(), vec![0, 1])
            .with_filtered_rows(vec![1]);

        assert_eq!(view.row_count(), 1);
        assert_eq!(view.get_value(0, 0), Some(DataValue::Integer(20)));
        assert_eq!(view.get_value(0, 1), Some(DataValue::Integer(40)));
    }

    #[test]
    fn test_out_of_bounds_returns_none() {
        let table = create_test_table();
        let view = ComputedDataView::from_source_all_columns(table);

        assert_eq!(view.get_value(2, 0), None);
        assert_eq!(view.get_value(0, 2), None);
        assert_eq!(view.get_row_values(2), None);
    }

    #[test]
    fn test_get_row_values() {
        let table = create_test_table();
        let view = ComputedDataView::new(table, mixed_columns(), vec![0, 1]);

        assert_eq!(
            view.get_row_values(1),
            Some(vec![DataValue::Integer(20), DataValue::Integer(40)])
        );
    }

    #[test]
    fn test_is_derived_column() {
        let table = create_test_table();
        let view = ComputedDataView::new(table, mixed_columns(), vec![0, 1]);

        assert!(!view.is_derived_column(0));
        assert!(view.is_derived_column(1));
        assert!(!view.is_derived_column(2));
    }
}
246}