1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::where_ast::{ComparisonOp, WhereExpr, WhereValue};
3use anyhow::Result;
4
5pub struct WhereEvaluator<'a> {
7 table: &'a DataTable,
8 column_indices: Vec<usize>,
9}
10
11impl<'a> WhereEvaluator<'a> {
12 #[must_use]
13 pub fn new(table: &'a DataTable) -> Self {
14 let column_indices = (0..table.column_count()).collect();
15 Self {
16 table,
17 column_indices,
18 }
19 }
20
21 pub fn evaluate(&self, expr: &WhereExpr, row_index: usize) -> Result<bool> {
23 match expr {
24 WhereExpr::And(left, right) => {
25 Ok(self.evaluate(left, row_index)? && self.evaluate(right, row_index)?)
26 }
27 WhereExpr::Or(left, right) => {
28 Ok(self.evaluate(left, row_index)? || self.evaluate(right, row_index)?)
29 }
30 WhereExpr::Not(inner) => Ok(!self.evaluate(inner, row_index)?),
31 WhereExpr::Equal(column, value) => {
32 self.evaluate_comparison(column, value, row_index, ComparisonOp::Equal)
33 }
34 WhereExpr::NotEqual(column, value) => {
35 self.evaluate_comparison(column, value, row_index, ComparisonOp::NotEqual)
36 }
37 WhereExpr::GreaterThan(column, value) => {
38 self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThan)
39 }
40 WhereExpr::GreaterThanOrEqual(column, value) => {
41 self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThanOrEqual)
42 }
43 WhereExpr::LessThan(column, value) => {
44 self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThan)
45 }
46 WhereExpr::LessThanOrEqual(column, value) => {
47 self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThanOrEqual)
48 }
49 WhereExpr::Between(column, lower, upper) => {
50 self.evaluate_between(column, lower, upper, row_index)
51 }
52 WhereExpr::In(column, values) => self.evaluate_in(column, values, row_index, false),
53 WhereExpr::NotIn(column, values) => {
54 Ok(!self.evaluate_in(column, values, row_index, false)?)
55 }
56 WhereExpr::InIgnoreCase(column, values) => {
57 self.evaluate_in(column, values, row_index, true)
58 }
59 WhereExpr::NotInIgnoreCase(column, values) => {
60 Ok(!self.evaluate_in(column, values, row_index, true)?)
61 }
62 WhereExpr::Like(column, pattern) => self.evaluate_like(column, pattern, row_index),
63 WhereExpr::IsNull(column) => self.evaluate_is_null(column, row_index, true),
64 WhereExpr::IsNotNull(column) => self.evaluate_is_null(column, row_index, false),
65 WhereExpr::Contains(column, substring) => self.evaluate_string_method(
66 column,
67 substring,
68 row_index,
69 StringMethod::Contains,
70 false,
71 ),
72 WhereExpr::StartsWith(column, prefix) => self.evaluate_string_method(
73 column,
74 prefix,
75 row_index,
76 StringMethod::StartsWith,
77 false,
78 ),
79 WhereExpr::EndsWith(column, suffix) => self.evaluate_string_method(
80 column,
81 suffix,
82 row_index,
83 StringMethod::EndsWith,
84 false,
85 ),
86 WhereExpr::ContainsIgnoreCase(column, substring) => self.evaluate_string_method(
87 column,
88 substring,
89 row_index,
90 StringMethod::Contains,
91 true,
92 ),
93 WhereExpr::StartsWithIgnoreCase(column, prefix) => self.evaluate_string_method(
94 column,
95 prefix,
96 row_index,
97 StringMethod::StartsWith,
98 true,
99 ),
100 WhereExpr::EndsWithIgnoreCase(column, suffix) => {
101 self.evaluate_string_method(column, suffix, row_index, StringMethod::EndsWith, true)
102 }
103 WhereExpr::ToLower(column, op, value) => {
104 self.evaluate_case_conversion(column, value, row_index, op, true)
105 }
106 WhereExpr::ToUpper(column, op, value) => {
107 self.evaluate_case_conversion(column, value, row_index, op, false)
108 }
109 WhereExpr::IsNullOrEmpty(column) => self.evaluate_is_null_or_empty(column, row_index),
110 WhereExpr::Length(column, op, length) => {
111 self.evaluate_length(column, *length, row_index, op)
112 }
113 }
114 }
115
116 fn get_column_index(&self, column: &str) -> Result<usize> {
117 let columns = self.table.column_names();
118 columns
119 .iter()
120 .position(|c| c.eq_ignore_ascii_case(column))
121 .ok_or_else(|| anyhow::anyhow!("Column '{}' not found", column))
122 }
123
124 fn get_cell_value(&self, column: &str, row_index: usize) -> Result<Option<DataValue>> {
125 let col_index = self.get_column_index(column)?;
126 Ok(self.table.get_value(row_index, col_index).cloned())
127 }
128
129 fn evaluate_comparison(
130 &self,
131 column: &str,
132 value: &WhereValue,
133 row_index: usize,
134 op: ComparisonOp,
135 ) -> Result<bool> {
136 let cell_value = self.get_cell_value(column, row_index)?;
137
138 match cell_value {
139 None | Some(DataValue::Null) => Ok(false),
140 Some(data_val) => {
141 let result = match (&data_val, value) {
142 (DataValue::Integer(a), WhereValue::Number(b)) => {
144 compare_numbers(*a as f64, *b, &op)
145 }
146 (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, &op),
147 (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, &op),
149 (DataValue::InternedString(a), WhereValue::String(b)) => {
150 compare_strings(a, b, &op)
151 }
152 (DataValue::String(a), WhereValue::Number(b)) => {
154 if let Ok(a_num) = a.parse::<f64>() {
155 compare_numbers(a_num, *b, &op)
156 } else {
157 false
158 }
159 }
160 (DataValue::InternedString(a), WhereValue::Number(b)) => {
161 if let Ok(a_num) = a.parse::<f64>() {
162 compare_numbers(a_num, *b, &op)
163 } else {
164 false
165 }
166 }
167 (DataValue::Integer(a), WhereValue::String(b)) => {
168 if let Ok(b_num) = b.parse::<f64>() {
169 compare_numbers(*a as f64, b_num, &op)
170 } else {
171 false
172 }
173 }
174 (DataValue::Float(a), WhereValue::String(b)) => {
175 if let Ok(b_num) = b.parse::<f64>() {
176 compare_numbers(*a, b_num, &op)
177 } else {
178 false
179 }
180 }
181 (DataValue::Boolean(a), WhereValue::String(b)) => {
183 let b_bool = b.eq_ignore_ascii_case("true");
184 compare_bools(*a, b_bool, &op)
185 }
186 (_, WhereValue::Null) => {
188 matches!(op, ComparisonOp::NotEqual)
189 }
190 _ => false,
191 };
192 Ok(result)
193 }
194 }
195 }
196
197 fn evaluate_between(
198 &self,
199 column: &str,
200 lower: &WhereValue,
201 upper: &WhereValue,
202 row_index: usize,
203 ) -> Result<bool> {
204 let cell_value = self.get_cell_value(column, row_index)?;
205
206 match cell_value {
207 None | Some(DataValue::Null) => Ok(false),
208 Some(data_val) => {
209 let ge_lower =
210 self.compare_value(&data_val, lower, &ComparisonOp::GreaterThanOrEqual);
211 let le_upper = self.compare_value(&data_val, upper, &ComparisonOp::LessThanOrEqual);
212 Ok(ge_lower && le_upper)
213 }
214 }
215 }
216
217 fn evaluate_in(
218 &self,
219 column: &str,
220 values: &[WhereValue],
221 row_index: usize,
222 ignore_case: bool,
223 ) -> Result<bool> {
224 let cell_value = self.get_cell_value(column, row_index)?;
225
226 match cell_value {
227 None | Some(DataValue::Null) => Ok(false),
228 Some(data_val) => {
229 for value in values {
230 if ignore_case {
231 if self.compare_ignore_case(&data_val, value) {
232 return Ok(true);
233 }
234 } else if self.compare_value(&data_val, value, &ComparisonOp::Equal) {
235 return Ok(true);
236 }
237 }
238 Ok(false)
239 }
240 }
241 }
242
243 fn evaluate_like(&self, column: &str, pattern: &str, row_index: usize) -> Result<bool> {
244 let cell_value = self.get_cell_value(column, row_index)?;
245
246 match cell_value {
247 Some(DataValue::String(s)) => {
248 let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
250
251 let regex = regex::RegexBuilder::new(&format!("^{regex_pattern}$"))
253 .case_insensitive(true)
254 .build()
255 .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
256
257 Ok(regex.is_match(&s))
258 }
259 Some(DataValue::InternedString(s)) => {
260 let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
262
263 let regex = regex::RegexBuilder::new(&format!("^{regex_pattern}$"))
265 .case_insensitive(true)
266 .build()
267 .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
268
269 Ok(regex.is_match(&s))
270 }
271 _ => Ok(false),
272 }
273 }
274
275 fn evaluate_is_null(&self, column: &str, row_index: usize, expect_null: bool) -> Result<bool> {
276 let cell_value = self.get_cell_value(column, row_index)?;
277 let is_null = matches!(cell_value, None | Some(DataValue::Null));
278 Ok(is_null == expect_null)
279 }
280
281 fn evaluate_is_null_or_empty(&self, column: &str, row_index: usize) -> Result<bool> {
282 let cell_value = self.get_cell_value(column, row_index)?;
283 Ok(match cell_value {
284 None | Some(DataValue::Null) => true,
285 Some(DataValue::String(s)) => s.is_empty(),
286 Some(DataValue::InternedString(s)) => s.is_empty(),
287 _ => false,
288 })
289 }
290
291 fn evaluate_string_method(
292 &self,
293 column: &str,
294 pattern: &str,
295 row_index: usize,
296 method: StringMethod,
297 ignore_case: bool,
298 ) -> Result<bool> {
299 let cell_value = self.get_cell_value(column, row_index)?;
300
301 match cell_value {
302 Some(DataValue::String(s)) => {
303 let (s, pattern) = if ignore_case {
304 (s.to_lowercase(), pattern.to_lowercase())
305 } else {
306 (s, pattern.to_string())
307 };
308
309 Ok(match method {
310 StringMethod::Contains => s.contains(&pattern),
311 StringMethod::StartsWith => s.starts_with(&pattern),
312 StringMethod::EndsWith => s.ends_with(&pattern),
313 })
314 }
315 Some(DataValue::InternedString(s)) => {
316 let (s, pattern) = if ignore_case {
317 (s.to_lowercase(), pattern.to_lowercase())
318 } else {
319 (s.as_ref().clone(), pattern.to_string())
320 };
321
322 Ok(match method {
323 StringMethod::Contains => s.contains(&pattern),
324 StringMethod::StartsWith => s.starts_with(&pattern),
325 StringMethod::EndsWith => s.ends_with(&pattern),
326 })
327 }
328 _ => Ok(false),
329 }
330 }
331
332 fn evaluate_case_conversion(
333 &self,
334 column: &str,
335 value: &str,
336 row_index: usize,
337 op: &ComparisonOp,
338 to_lower: bool,
339 ) -> Result<bool> {
340 let cell_value = self.get_cell_value(column, row_index)?;
341
342 match cell_value {
343 Some(DataValue::String(s)) => {
344 let converted = if to_lower {
345 s.to_lowercase()
346 } else {
347 s.to_uppercase()
348 };
349 Ok(compare_strings(&converted, value, op))
350 }
351 Some(DataValue::InternedString(s)) => {
352 let converted = if to_lower {
353 s.to_lowercase()
354 } else {
355 s.to_uppercase()
356 };
357 Ok(compare_strings(&converted, value, op))
358 }
359 _ => Ok(false),
360 }
361 }
362
363 fn evaluate_length(
364 &self,
365 column: &str,
366 length: i64,
367 row_index: usize,
368 op: &ComparisonOp,
369 ) -> Result<bool> {
370 let cell_value = self.get_cell_value(column, row_index)?;
371
372 match cell_value {
373 Some(DataValue::String(s)) => {
374 let len = s.len() as i64;
375 Ok(compare_numbers(len as f64, length as f64, op))
376 }
377 Some(DataValue::InternedString(s)) => {
378 let len = s.len() as i64;
379 Ok(compare_numbers(len as f64, length as f64, op))
380 }
381 _ => Ok(false),
382 }
383 }
384
385 fn compare_value(
386 &self,
387 data_val: &DataValue,
388 where_val: &WhereValue,
389 op: &ComparisonOp,
390 ) -> bool {
391 match (data_val, where_val) {
392 (DataValue::Integer(a), WhereValue::Number(b)) => compare_numbers(*a as f64, *b, op),
393 (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, op),
394 (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, op),
395 (DataValue::InternedString(a), WhereValue::String(b)) => compare_strings(a, b, op),
396 _ => false,
397 }
398 }
399
400 fn compare_ignore_case(&self, data_val: &DataValue, where_val: &WhereValue) -> bool {
401 match (data_val, where_val) {
402 (DataValue::String(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
403 (DataValue::InternedString(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
404 _ => self.compare_value(data_val, where_val, &ComparisonOp::Equal),
405 }
406 }
407}
408
/// Substring test applied to a text cell.
enum StringMethod {
    /// The pattern occurs anywhere in the text.
    Contains,
    /// The text begins with the pattern.
    StartsWith,
    /// The text ends with the pattern.
    EndsWith,
}
414
415fn compare_numbers(a: f64, b: f64, op: &ComparisonOp) -> bool {
416 match op {
417 ComparisonOp::Equal => (a - b).abs() < f64::EPSILON,
418 ComparisonOp::NotEqual => (a - b).abs() >= f64::EPSILON,
419 ComparisonOp::GreaterThan => a > b,
420 ComparisonOp::GreaterThanOrEqual => a >= b,
421 ComparisonOp::LessThan => a < b,
422 ComparisonOp::LessThanOrEqual => a <= b,
423 }
424}
425
426fn compare_strings(a: &str, b: &str, op: &ComparisonOp) -> bool {
427 match op {
428 ComparisonOp::Equal => a == b,
429 ComparisonOp::NotEqual => a != b,
430 ComparisonOp::GreaterThan => a > b,
431 ComparisonOp::GreaterThanOrEqual => a >= b,
432 ComparisonOp::LessThan => a < b,
433 ComparisonOp::LessThanOrEqual => a <= b,
434 }
435}
436
437fn compare_bools(a: bool, b: bool, op: &ComparisonOp) -> bool {
438 match op {
439 ComparisonOp::Equal => a == b,
440 ComparisonOp::NotEqual => a != b,
441 _ => false,
442 }
443}