sql_cli/data/
where_evaluator.rs

1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::where_ast::{ComparisonOp, WhereExpr, WhereValue};
3use anyhow::Result;
4
5/// Evaluates WHERE clause expressions against DataTable rows
6pub struct WhereEvaluator<'a> {
7    table: &'a DataTable,
8    column_indices: Vec<usize>,
9}
10
11impl<'a> WhereEvaluator<'a> {
12    pub fn new(table: &'a DataTable) -> Self {
13        let column_indices = (0..table.column_count()).collect();
14        Self {
15            table,
16            column_indices,
17        }
18    }
19
20    /// Evaluate a WHERE expression for a specific row
21    pub fn evaluate(&self, expr: &WhereExpr, row_index: usize) -> Result<bool> {
22        match expr {
23            WhereExpr::And(left, right) => {
24                Ok(self.evaluate(left, row_index)? && self.evaluate(right, row_index)?)
25            }
26            WhereExpr::Or(left, right) => {
27                Ok(self.evaluate(left, row_index)? || self.evaluate(right, row_index)?)
28            }
29            WhereExpr::Not(inner) => Ok(!self.evaluate(inner, row_index)?),
30            WhereExpr::Equal(column, value) => {
31                self.evaluate_comparison(column, value, row_index, ComparisonOp::Equal)
32            }
33            WhereExpr::NotEqual(column, value) => {
34                self.evaluate_comparison(column, value, row_index, ComparisonOp::NotEqual)
35            }
36            WhereExpr::GreaterThan(column, value) => {
37                self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThan)
38            }
39            WhereExpr::GreaterThanOrEqual(column, value) => {
40                self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThanOrEqual)
41            }
42            WhereExpr::LessThan(column, value) => {
43                self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThan)
44            }
45            WhereExpr::LessThanOrEqual(column, value) => {
46                self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThanOrEqual)
47            }
48            WhereExpr::Between(column, lower, upper) => {
49                self.evaluate_between(column, lower, upper, row_index)
50            }
51            WhereExpr::In(column, values) => self.evaluate_in(column, values, row_index, false),
52            WhereExpr::NotIn(column, values) => {
53                Ok(!self.evaluate_in(column, values, row_index, false)?)
54            }
55            WhereExpr::InIgnoreCase(column, values) => {
56                self.evaluate_in(column, values, row_index, true)
57            }
58            WhereExpr::NotInIgnoreCase(column, values) => {
59                Ok(!self.evaluate_in(column, values, row_index, true)?)
60            }
61            WhereExpr::Like(column, pattern) => self.evaluate_like(column, pattern, row_index),
62            WhereExpr::IsNull(column) => self.evaluate_is_null(column, row_index, true),
63            WhereExpr::IsNotNull(column) => self.evaluate_is_null(column, row_index, false),
64            WhereExpr::Contains(column, substring) => self.evaluate_string_method(
65                column,
66                substring,
67                row_index,
68                StringMethod::Contains,
69                false,
70            ),
71            WhereExpr::StartsWith(column, prefix) => self.evaluate_string_method(
72                column,
73                prefix,
74                row_index,
75                StringMethod::StartsWith,
76                false,
77            ),
78            WhereExpr::EndsWith(column, suffix) => self.evaluate_string_method(
79                column,
80                suffix,
81                row_index,
82                StringMethod::EndsWith,
83                false,
84            ),
85            WhereExpr::ContainsIgnoreCase(column, substring) => self.evaluate_string_method(
86                column,
87                substring,
88                row_index,
89                StringMethod::Contains,
90                true,
91            ),
92            WhereExpr::StartsWithIgnoreCase(column, prefix) => self.evaluate_string_method(
93                column,
94                prefix,
95                row_index,
96                StringMethod::StartsWith,
97                true,
98            ),
99            WhereExpr::EndsWithIgnoreCase(column, suffix) => {
100                self.evaluate_string_method(column, suffix, row_index, StringMethod::EndsWith, true)
101            }
102            WhereExpr::ToLower(column, op, value) => {
103                self.evaluate_case_conversion(column, value, row_index, op, true)
104            }
105            WhereExpr::ToUpper(column, op, value) => {
106                self.evaluate_case_conversion(column, value, row_index, op, false)
107            }
108            WhereExpr::IsNullOrEmpty(column) => self.evaluate_is_null_or_empty(column, row_index),
109            WhereExpr::Length(column, op, length) => {
110                self.evaluate_length(column, *length, row_index, op)
111            }
112        }
113    }
114
115    fn get_column_index(&self, column: &str) -> Result<usize> {
116        let columns = self.table.column_names();
117        columns
118            .iter()
119            .position(|c| c.eq_ignore_ascii_case(column))
120            .ok_or_else(|| anyhow::anyhow!("Column '{}' not found", column))
121    }
122
123    fn get_cell_value(&self, column: &str, row_index: usize) -> Result<Option<DataValue>> {
124        let col_index = self.get_column_index(column)?;
125        Ok(self.table.get_value(row_index, col_index).cloned())
126    }
127
128    fn evaluate_comparison(
129        &self,
130        column: &str,
131        value: &WhereValue,
132        row_index: usize,
133        op: ComparisonOp,
134    ) -> Result<bool> {
135        let cell_value = self.get_cell_value(column, row_index)?;
136
137        match cell_value {
138            None | Some(DataValue::Null) => Ok(false),
139            Some(data_val) => {
140                let result = match (&data_val, value) {
141                    // Number comparisons
142                    (DataValue::Integer(a), WhereValue::Number(b)) => {
143                        compare_numbers(*a as f64, *b, &op)
144                    }
145                    (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, &op),
146                    // String comparisons
147                    (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, &op),
148                    (DataValue::InternedString(a), WhereValue::String(b)) => {
149                        compare_strings(a, b, &op)
150                    }
151                    // String to number coercion
152                    (DataValue::String(a), WhereValue::Number(b)) => {
153                        if let Ok(a_num) = a.parse::<f64>() {
154                            compare_numbers(a_num, *b, &op)
155                        } else {
156                            false
157                        }
158                    }
159                    (DataValue::InternedString(a), WhereValue::Number(b)) => {
160                        if let Ok(a_num) = a.parse::<f64>() {
161                            compare_numbers(a_num, *b, &op)
162                        } else {
163                            false
164                        }
165                    }
166                    (DataValue::Integer(a), WhereValue::String(b)) => {
167                        if let Ok(b_num) = b.parse::<f64>() {
168                            compare_numbers(*a as f64, b_num, &op)
169                        } else {
170                            false
171                        }
172                    }
173                    (DataValue::Float(a), WhereValue::String(b)) => {
174                        if let Ok(b_num) = b.parse::<f64>() {
175                            compare_numbers(*a, b_num, &op)
176                        } else {
177                            false
178                        }
179                    }
180                    // Boolean comparisons
181                    (DataValue::Boolean(a), WhereValue::String(b)) => {
182                        let b_bool = b.eq_ignore_ascii_case("true");
183                        compare_bools(*a, b_bool, &op)
184                    }
185                    // Null comparisons
186                    (_, WhereValue::Null) => {
187                        matches!(op, ComparisonOp::NotEqual)
188                    }
189                    _ => false,
190                };
191                Ok(result)
192            }
193        }
194    }
195
196    fn evaluate_between(
197        &self,
198        column: &str,
199        lower: &WhereValue,
200        upper: &WhereValue,
201        row_index: usize,
202    ) -> Result<bool> {
203        let cell_value = self.get_cell_value(column, row_index)?;
204
205        match cell_value {
206            None | Some(DataValue::Null) => Ok(false),
207            Some(data_val) => {
208                let ge_lower =
209                    self.compare_value(&data_val, lower, &ComparisonOp::GreaterThanOrEqual);
210                let le_upper = self.compare_value(&data_val, upper, &ComparisonOp::LessThanOrEqual);
211                Ok(ge_lower && le_upper)
212            }
213        }
214    }
215
216    fn evaluate_in(
217        &self,
218        column: &str,
219        values: &[WhereValue],
220        row_index: usize,
221        ignore_case: bool,
222    ) -> Result<bool> {
223        let cell_value = self.get_cell_value(column, row_index)?;
224
225        match cell_value {
226            None | Some(DataValue::Null) => Ok(false),
227            Some(data_val) => {
228                for value in values {
229                    if ignore_case {
230                        if self.compare_ignore_case(&data_val, value) {
231                            return Ok(true);
232                        }
233                    } else if self.compare_value(&data_val, value, &ComparisonOp::Equal) {
234                        return Ok(true);
235                    }
236                }
237                Ok(false)
238            }
239        }
240    }
241
242    fn evaluate_like(&self, column: &str, pattern: &str, row_index: usize) -> Result<bool> {
243        let cell_value = self.get_cell_value(column, row_index)?;
244
245        match cell_value {
246            Some(DataValue::String(s)) => {
247                // Convert SQL LIKE pattern to regex
248                let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
249
250                // Use case-insensitive matching
251                let regex = regex::RegexBuilder::new(&format!("^{}$", regex_pattern))
252                    .case_insensitive(true)
253                    .build()
254                    .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
255
256                Ok(regex.is_match(&s))
257            }
258            Some(DataValue::InternedString(s)) => {
259                // Convert SQL LIKE pattern to regex
260                let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
261
262                // Use case-insensitive matching
263                let regex = regex::RegexBuilder::new(&format!("^{}$", regex_pattern))
264                    .case_insensitive(true)
265                    .build()
266                    .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
267
268                Ok(regex.is_match(&s))
269            }
270            _ => Ok(false),
271        }
272    }
273
274    fn evaluate_is_null(&self, column: &str, row_index: usize, expect_null: bool) -> Result<bool> {
275        let cell_value = self.get_cell_value(column, row_index)?;
276        let is_null = matches!(cell_value, None | Some(DataValue::Null));
277        Ok(is_null == expect_null)
278    }
279
280    fn evaluate_is_null_or_empty(&self, column: &str, row_index: usize) -> Result<bool> {
281        let cell_value = self.get_cell_value(column, row_index)?;
282        Ok(match cell_value {
283            None | Some(DataValue::Null) => true,
284            Some(DataValue::String(s)) => s.is_empty(),
285            Some(DataValue::InternedString(s)) => s.is_empty(),
286            _ => false,
287        })
288    }
289
290    fn evaluate_string_method(
291        &self,
292        column: &str,
293        pattern: &str,
294        row_index: usize,
295        method: StringMethod,
296        ignore_case: bool,
297    ) -> Result<bool> {
298        let cell_value = self.get_cell_value(column, row_index)?;
299
300        match cell_value {
301            Some(DataValue::String(s)) => {
302                let (s, pattern) = if ignore_case {
303                    (s.to_lowercase(), pattern.to_lowercase())
304                } else {
305                    (s, pattern.to_string())
306                };
307
308                Ok(match method {
309                    StringMethod::Contains => s.contains(&pattern),
310                    StringMethod::StartsWith => s.starts_with(&pattern),
311                    StringMethod::EndsWith => s.ends_with(&pattern),
312                })
313            }
314            Some(DataValue::InternedString(s)) => {
315                let (s, pattern) = if ignore_case {
316                    (s.to_lowercase(), pattern.to_lowercase())
317                } else {
318                    (s.as_ref().clone(), pattern.to_string())
319                };
320
321                Ok(match method {
322                    StringMethod::Contains => s.contains(&pattern),
323                    StringMethod::StartsWith => s.starts_with(&pattern),
324                    StringMethod::EndsWith => s.ends_with(&pattern),
325                })
326            }
327            _ => Ok(false),
328        }
329    }
330
331    fn evaluate_case_conversion(
332        &self,
333        column: &str,
334        value: &str,
335        row_index: usize,
336        op: &ComparisonOp,
337        to_lower: bool,
338    ) -> Result<bool> {
339        let cell_value = self.get_cell_value(column, row_index)?;
340
341        match cell_value {
342            Some(DataValue::String(s)) => {
343                let converted = if to_lower {
344                    s.to_lowercase()
345                } else {
346                    s.to_uppercase()
347                };
348                Ok(compare_strings(&converted, value, op))
349            }
350            Some(DataValue::InternedString(s)) => {
351                let converted = if to_lower {
352                    s.to_lowercase()
353                } else {
354                    s.to_uppercase()
355                };
356                Ok(compare_strings(&converted, value, op))
357            }
358            _ => Ok(false),
359        }
360    }
361
362    fn evaluate_length(
363        &self,
364        column: &str,
365        length: i64,
366        row_index: usize,
367        op: &ComparisonOp,
368    ) -> Result<bool> {
369        let cell_value = self.get_cell_value(column, row_index)?;
370
371        match cell_value {
372            Some(DataValue::String(s)) => {
373                let len = s.len() as i64;
374                Ok(compare_numbers(len as f64, length as f64, op))
375            }
376            Some(DataValue::InternedString(s)) => {
377                let len = s.len() as i64;
378                Ok(compare_numbers(len as f64, length as f64, op))
379            }
380            _ => Ok(false),
381        }
382    }
383
384    fn compare_value(
385        &self,
386        data_val: &DataValue,
387        where_val: &WhereValue,
388        op: &ComparisonOp,
389    ) -> bool {
390        match (data_val, where_val) {
391            (DataValue::Integer(a), WhereValue::Number(b)) => compare_numbers(*a as f64, *b, op),
392            (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, op),
393            (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, op),
394            (DataValue::InternedString(a), WhereValue::String(b)) => compare_strings(a, b, op),
395            _ => false,
396        }
397    }
398
399    fn compare_ignore_case(&self, data_val: &DataValue, where_val: &WhereValue) -> bool {
400        match (data_val, where_val) {
401            (DataValue::String(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
402            (DataValue::InternedString(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
403            _ => self.compare_value(data_val, where_val, &ComparisonOp::Equal),
404        }
405    }
406}
407
/// Substring-style test selected by the string-method WHERE predicates.
enum StringMethod {
    /// The pattern occurs anywhere in the cell text.
    Contains,
    /// The cell text begins with the pattern.
    StartsWith,
    /// The cell text ends with the pattern.
    EndsWith,
}
413
414fn compare_numbers(a: f64, b: f64, op: &ComparisonOp) -> bool {
415    match op {
416        ComparisonOp::Equal => (a - b).abs() < f64::EPSILON,
417        ComparisonOp::NotEqual => (a - b).abs() >= f64::EPSILON,
418        ComparisonOp::GreaterThan => a > b,
419        ComparisonOp::GreaterThanOrEqual => a >= b,
420        ComparisonOp::LessThan => a < b,
421        ComparisonOp::LessThanOrEqual => a <= b,
422    }
423}
424
425fn compare_strings(a: &str, b: &str, op: &ComparisonOp) -> bool {
426    match op {
427        ComparisonOp::Equal => a == b,
428        ComparisonOp::NotEqual => a != b,
429        ComparisonOp::GreaterThan => a > b,
430        ComparisonOp::GreaterThanOrEqual => a >= b,
431        ComparisonOp::LessThan => a < b,
432        ComparisonOp::LessThanOrEqual => a <= b,
433    }
434}
435
436fn compare_bools(a: bool, b: bool, op: &ComparisonOp) -> bool {
437    match op {
438        ComparisonOp::Equal => a == b,
439        ComparisonOp::NotEqual => a != b,
440        _ => false,
441    }
442}