sql_cli/data/
where_evaluator.rs

1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::where_ast::{ComparisonOp, WhereExpr, WhereValue};
3use anyhow::Result;
4
5/// Evaluates WHERE clause expressions against `DataTable` rows
6pub struct WhereEvaluator<'a> {
7    table: &'a DataTable,
8    column_indices: Vec<usize>,
9}
10
11impl<'a> WhereEvaluator<'a> {
12    #[must_use]
13    pub fn new(table: &'a DataTable) -> Self {
14        let column_indices = (0..table.column_count()).collect();
15        Self {
16            table,
17            column_indices,
18        }
19    }
20
21    /// Evaluate a WHERE expression for a specific row
22    pub fn evaluate(&self, expr: &WhereExpr, row_index: usize) -> Result<bool> {
23        match expr {
24            WhereExpr::And(left, right) => {
25                Ok(self.evaluate(left, row_index)? && self.evaluate(right, row_index)?)
26            }
27            WhereExpr::Or(left, right) => {
28                Ok(self.evaluate(left, row_index)? || self.evaluate(right, row_index)?)
29            }
30            WhereExpr::Not(inner) => Ok(!self.evaluate(inner, row_index)?),
31            WhereExpr::Equal(column, value) => {
32                self.evaluate_comparison(column, value, row_index, ComparisonOp::Equal)
33            }
34            WhereExpr::NotEqual(column, value) => {
35                self.evaluate_comparison(column, value, row_index, ComparisonOp::NotEqual)
36            }
37            WhereExpr::GreaterThan(column, value) => {
38                self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThan)
39            }
40            WhereExpr::GreaterThanOrEqual(column, value) => {
41                self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThanOrEqual)
42            }
43            WhereExpr::LessThan(column, value) => {
44                self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThan)
45            }
46            WhereExpr::LessThanOrEqual(column, value) => {
47                self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThanOrEqual)
48            }
49            WhereExpr::Between(column, lower, upper) => {
50                self.evaluate_between(column, lower, upper, row_index)
51            }
52            WhereExpr::In(column, values) => self.evaluate_in(column, values, row_index, false),
53            WhereExpr::NotIn(column, values) => {
54                Ok(!self.evaluate_in(column, values, row_index, false)?)
55            }
56            WhereExpr::InIgnoreCase(column, values) => {
57                self.evaluate_in(column, values, row_index, true)
58            }
59            WhereExpr::NotInIgnoreCase(column, values) => {
60                Ok(!self.evaluate_in(column, values, row_index, true)?)
61            }
62            WhereExpr::Like(column, pattern) => self.evaluate_like(column, pattern, row_index),
63            WhereExpr::IsNull(column) => self.evaluate_is_null(column, row_index, true),
64            WhereExpr::IsNotNull(column) => self.evaluate_is_null(column, row_index, false),
65            WhereExpr::Contains(column, substring) => self.evaluate_string_method(
66                column,
67                substring,
68                row_index,
69                StringMethod::Contains,
70                false,
71            ),
72            WhereExpr::StartsWith(column, prefix) => self.evaluate_string_method(
73                column,
74                prefix,
75                row_index,
76                StringMethod::StartsWith,
77                false,
78            ),
79            WhereExpr::EndsWith(column, suffix) => self.evaluate_string_method(
80                column,
81                suffix,
82                row_index,
83                StringMethod::EndsWith,
84                false,
85            ),
86            WhereExpr::ContainsIgnoreCase(column, substring) => self.evaluate_string_method(
87                column,
88                substring,
89                row_index,
90                StringMethod::Contains,
91                true,
92            ),
93            WhereExpr::StartsWithIgnoreCase(column, prefix) => self.evaluate_string_method(
94                column,
95                prefix,
96                row_index,
97                StringMethod::StartsWith,
98                true,
99            ),
100            WhereExpr::EndsWithIgnoreCase(column, suffix) => {
101                self.evaluate_string_method(column, suffix, row_index, StringMethod::EndsWith, true)
102            }
103            WhereExpr::ToLower(column, op, value) => {
104                self.evaluate_case_conversion(column, value, row_index, op, true)
105            }
106            WhereExpr::ToUpper(column, op, value) => {
107                self.evaluate_case_conversion(column, value, row_index, op, false)
108            }
109            WhereExpr::IsNullOrEmpty(column) => self.evaluate_is_null_or_empty(column, row_index),
110            WhereExpr::Length(column, op, length) => {
111                self.evaluate_length(column, *length, row_index, op)
112            }
113        }
114    }
115
116    fn get_column_index(&self, column: &str) -> Result<usize> {
117        let columns = self.table.column_names();
118        columns
119            .iter()
120            .position(|c| c.eq_ignore_ascii_case(column))
121            .ok_or_else(|| anyhow::anyhow!("Column '{}' not found", column))
122    }
123
124    fn get_cell_value(&self, column: &str, row_index: usize) -> Result<Option<DataValue>> {
125        let col_index = self.get_column_index(column)?;
126        Ok(self.table.get_value(row_index, col_index).cloned())
127    }
128
129    fn evaluate_comparison(
130        &self,
131        column: &str,
132        value: &WhereValue,
133        row_index: usize,
134        op: ComparisonOp,
135    ) -> Result<bool> {
136        let cell_value = self.get_cell_value(column, row_index)?;
137
138        match cell_value {
139            None | Some(DataValue::Null) => Ok(false),
140            Some(data_val) => {
141                let result = match (&data_val, value) {
142                    // Number comparisons
143                    (DataValue::Integer(a), WhereValue::Number(b)) => {
144                        compare_numbers(*a as f64, *b, &op)
145                    }
146                    (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, &op),
147                    // String comparisons
148                    (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, &op),
149                    (DataValue::InternedString(a), WhereValue::String(b)) => {
150                        compare_strings(a, b, &op)
151                    }
152                    // String to number coercion
153                    (DataValue::String(a), WhereValue::Number(b)) => {
154                        if let Ok(a_num) = a.parse::<f64>() {
155                            compare_numbers(a_num, *b, &op)
156                        } else {
157                            false
158                        }
159                    }
160                    (DataValue::InternedString(a), WhereValue::Number(b)) => {
161                        if let Ok(a_num) = a.parse::<f64>() {
162                            compare_numbers(a_num, *b, &op)
163                        } else {
164                            false
165                        }
166                    }
167                    (DataValue::Integer(a), WhereValue::String(b)) => {
168                        if let Ok(b_num) = b.parse::<f64>() {
169                            compare_numbers(*a as f64, b_num, &op)
170                        } else {
171                            false
172                        }
173                    }
174                    (DataValue::Float(a), WhereValue::String(b)) => {
175                        if let Ok(b_num) = b.parse::<f64>() {
176                            compare_numbers(*a, b_num, &op)
177                        } else {
178                            false
179                        }
180                    }
181                    // Boolean comparisons
182                    (DataValue::Boolean(a), WhereValue::String(b)) => {
183                        let b_bool = b.eq_ignore_ascii_case("true");
184                        compare_bools(*a, b_bool, &op)
185                    }
186                    // Null comparisons
187                    (_, WhereValue::Null) => {
188                        matches!(op, ComparisonOp::NotEqual)
189                    }
190                    _ => false,
191                };
192                Ok(result)
193            }
194        }
195    }
196
197    fn evaluate_between(
198        &self,
199        column: &str,
200        lower: &WhereValue,
201        upper: &WhereValue,
202        row_index: usize,
203    ) -> Result<bool> {
204        let cell_value = self.get_cell_value(column, row_index)?;
205
206        match cell_value {
207            None | Some(DataValue::Null) => Ok(false),
208            Some(data_val) => {
209                let ge_lower =
210                    self.compare_value(&data_val, lower, &ComparisonOp::GreaterThanOrEqual);
211                let le_upper = self.compare_value(&data_val, upper, &ComparisonOp::LessThanOrEqual);
212                Ok(ge_lower && le_upper)
213            }
214        }
215    }
216
217    fn evaluate_in(
218        &self,
219        column: &str,
220        values: &[WhereValue],
221        row_index: usize,
222        ignore_case: bool,
223    ) -> Result<bool> {
224        let cell_value = self.get_cell_value(column, row_index)?;
225
226        match cell_value {
227            None | Some(DataValue::Null) => Ok(false),
228            Some(data_val) => {
229                for value in values {
230                    if ignore_case {
231                        if self.compare_ignore_case(&data_val, value) {
232                            return Ok(true);
233                        }
234                    } else if self.compare_value(&data_val, value, &ComparisonOp::Equal) {
235                        return Ok(true);
236                    }
237                }
238                Ok(false)
239            }
240        }
241    }
242
243    fn evaluate_like(&self, column: &str, pattern: &str, row_index: usize) -> Result<bool> {
244        let cell_value = self.get_cell_value(column, row_index)?;
245
246        match cell_value {
247            Some(DataValue::String(s)) => {
248                // Convert SQL LIKE pattern to regex
249                let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
250
251                // Use case-insensitive matching
252                let regex = regex::RegexBuilder::new(&format!("^{regex_pattern}$"))
253                    .case_insensitive(true)
254                    .build()
255                    .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
256
257                Ok(regex.is_match(&s))
258            }
259            Some(DataValue::InternedString(s)) => {
260                // Convert SQL LIKE pattern to regex
261                let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
262
263                // Use case-insensitive matching
264                let regex = regex::RegexBuilder::new(&format!("^{regex_pattern}$"))
265                    .case_insensitive(true)
266                    .build()
267                    .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
268
269                Ok(regex.is_match(&s))
270            }
271            _ => Ok(false),
272        }
273    }
274
275    fn evaluate_is_null(&self, column: &str, row_index: usize, expect_null: bool) -> Result<bool> {
276        let cell_value = self.get_cell_value(column, row_index)?;
277        let is_null = matches!(cell_value, None | Some(DataValue::Null));
278        Ok(is_null == expect_null)
279    }
280
281    fn evaluate_is_null_or_empty(&self, column: &str, row_index: usize) -> Result<bool> {
282        let cell_value = self.get_cell_value(column, row_index)?;
283        Ok(match cell_value {
284            None | Some(DataValue::Null) => true,
285            Some(DataValue::String(s)) => s.is_empty(),
286            Some(DataValue::InternedString(s)) => s.is_empty(),
287            _ => false,
288        })
289    }
290
291    fn evaluate_string_method(
292        &self,
293        column: &str,
294        pattern: &str,
295        row_index: usize,
296        method: StringMethod,
297        ignore_case: bool,
298    ) -> Result<bool> {
299        let cell_value = self.get_cell_value(column, row_index)?;
300
301        match cell_value {
302            Some(DataValue::String(s)) => {
303                let (s, pattern) = if ignore_case {
304                    (s.to_lowercase(), pattern.to_lowercase())
305                } else {
306                    (s, pattern.to_string())
307                };
308
309                Ok(match method {
310                    StringMethod::Contains => s.contains(&pattern),
311                    StringMethod::StartsWith => s.starts_with(&pattern),
312                    StringMethod::EndsWith => s.ends_with(&pattern),
313                })
314            }
315            Some(DataValue::InternedString(s)) => {
316                let (s, pattern) = if ignore_case {
317                    (s.to_lowercase(), pattern.to_lowercase())
318                } else {
319                    (s.as_ref().clone(), pattern.to_string())
320                };
321
322                Ok(match method {
323                    StringMethod::Contains => s.contains(&pattern),
324                    StringMethod::StartsWith => s.starts_with(&pattern),
325                    StringMethod::EndsWith => s.ends_with(&pattern),
326                })
327            }
328            _ => Ok(false),
329        }
330    }
331
332    fn evaluate_case_conversion(
333        &self,
334        column: &str,
335        value: &str,
336        row_index: usize,
337        op: &ComparisonOp,
338        to_lower: bool,
339    ) -> Result<bool> {
340        let cell_value = self.get_cell_value(column, row_index)?;
341
342        match cell_value {
343            Some(DataValue::String(s)) => {
344                let converted = if to_lower {
345                    s.to_lowercase()
346                } else {
347                    s.to_uppercase()
348                };
349                Ok(compare_strings(&converted, value, op))
350            }
351            Some(DataValue::InternedString(s)) => {
352                let converted = if to_lower {
353                    s.to_lowercase()
354                } else {
355                    s.to_uppercase()
356                };
357                Ok(compare_strings(&converted, value, op))
358            }
359            _ => Ok(false),
360        }
361    }
362
363    fn evaluate_length(
364        &self,
365        column: &str,
366        length: i64,
367        row_index: usize,
368        op: &ComparisonOp,
369    ) -> Result<bool> {
370        let cell_value = self.get_cell_value(column, row_index)?;
371
372        match cell_value {
373            Some(DataValue::String(s)) => {
374                let len = s.len() as i64;
375                Ok(compare_numbers(len as f64, length as f64, op))
376            }
377            Some(DataValue::InternedString(s)) => {
378                let len = s.len() as i64;
379                Ok(compare_numbers(len as f64, length as f64, op))
380            }
381            _ => Ok(false),
382        }
383    }
384
385    fn compare_value(
386        &self,
387        data_val: &DataValue,
388        where_val: &WhereValue,
389        op: &ComparisonOp,
390    ) -> bool {
391        match (data_val, where_val) {
392            (DataValue::Integer(a), WhereValue::Number(b)) => compare_numbers(*a as f64, *b, op),
393            (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, op),
394            (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, op),
395            (DataValue::InternedString(a), WhereValue::String(b)) => compare_strings(a, b, op),
396            _ => false,
397        }
398    }
399
400    fn compare_ignore_case(&self, data_val: &DataValue, where_val: &WhereValue) -> bool {
401        match (data_val, where_val) {
402            (DataValue::String(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
403            (DataValue::InternedString(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
404            _ => self.compare_value(data_val, where_val, &ComparisonOp::Equal),
405        }
406    }
407}
408
/// Which substring test `evaluate_string_method` performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StringMethod {
    /// The pattern occurs anywhere in the string.
    Contains,
    /// The string begins with the pattern.
    StartsWith,
    /// The string ends with the pattern.
    EndsWith,
}
414
415fn compare_numbers(a: f64, b: f64, op: &ComparisonOp) -> bool {
416    match op {
417        ComparisonOp::Equal => (a - b).abs() < f64::EPSILON,
418        ComparisonOp::NotEqual => (a - b).abs() >= f64::EPSILON,
419        ComparisonOp::GreaterThan => a > b,
420        ComparisonOp::GreaterThanOrEqual => a >= b,
421        ComparisonOp::LessThan => a < b,
422        ComparisonOp::LessThanOrEqual => a <= b,
423    }
424}
425
426fn compare_strings(a: &str, b: &str, op: &ComparisonOp) -> bool {
427    match op {
428        ComparisonOp::Equal => a == b,
429        ComparisonOp::NotEqual => a != b,
430        ComparisonOp::GreaterThan => a > b,
431        ComparisonOp::GreaterThanOrEqual => a >= b,
432        ComparisonOp::LessThan => a < b,
433        ComparisonOp::LessThanOrEqual => a <= b,
434    }
435}
436
437fn compare_bools(a: bool, b: bool, op: &ComparisonOp) -> bool {
438    match op {
439        ComparisonOp::Equal => a == b,
440        ComparisonOp::NotEqual => a != b,
441        _ => false,
442    }
443}