1use crate::data::datatable::{DataTable, DataValue};
2use crate::sql::where_ast::{ComparisonOp, WhereExpr, WhereValue};
3use anyhow::Result;
4
5pub struct WhereEvaluator<'a> {
7 table: &'a DataTable,
8 column_indices: Vec<usize>,
9}
10
11impl<'a> WhereEvaluator<'a> {
12 pub fn new(table: &'a DataTable) -> Self {
13 let column_indices = (0..table.column_count()).collect();
14 Self {
15 table,
16 column_indices,
17 }
18 }
19
20 pub fn evaluate(&self, expr: &WhereExpr, row_index: usize) -> Result<bool> {
22 match expr {
23 WhereExpr::And(left, right) => {
24 Ok(self.evaluate(left, row_index)? && self.evaluate(right, row_index)?)
25 }
26 WhereExpr::Or(left, right) => {
27 Ok(self.evaluate(left, row_index)? || self.evaluate(right, row_index)?)
28 }
29 WhereExpr::Not(inner) => Ok(!self.evaluate(inner, row_index)?),
30 WhereExpr::Equal(column, value) => {
31 self.evaluate_comparison(column, value, row_index, ComparisonOp::Equal)
32 }
33 WhereExpr::NotEqual(column, value) => {
34 self.evaluate_comparison(column, value, row_index, ComparisonOp::NotEqual)
35 }
36 WhereExpr::GreaterThan(column, value) => {
37 self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThan)
38 }
39 WhereExpr::GreaterThanOrEqual(column, value) => {
40 self.evaluate_comparison(column, value, row_index, ComparisonOp::GreaterThanOrEqual)
41 }
42 WhereExpr::LessThan(column, value) => {
43 self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThan)
44 }
45 WhereExpr::LessThanOrEqual(column, value) => {
46 self.evaluate_comparison(column, value, row_index, ComparisonOp::LessThanOrEqual)
47 }
48 WhereExpr::Between(column, lower, upper) => {
49 self.evaluate_between(column, lower, upper, row_index)
50 }
51 WhereExpr::In(column, values) => self.evaluate_in(column, values, row_index, false),
52 WhereExpr::NotIn(column, values) => {
53 Ok(!self.evaluate_in(column, values, row_index, false)?)
54 }
55 WhereExpr::InIgnoreCase(column, values) => {
56 self.evaluate_in(column, values, row_index, true)
57 }
58 WhereExpr::NotInIgnoreCase(column, values) => {
59 Ok(!self.evaluate_in(column, values, row_index, true)?)
60 }
61 WhereExpr::Like(column, pattern) => self.evaluate_like(column, pattern, row_index),
62 WhereExpr::IsNull(column) => self.evaluate_is_null(column, row_index, true),
63 WhereExpr::IsNotNull(column) => self.evaluate_is_null(column, row_index, false),
64 WhereExpr::Contains(column, substring) => self.evaluate_string_method(
65 column,
66 substring,
67 row_index,
68 StringMethod::Contains,
69 false,
70 ),
71 WhereExpr::StartsWith(column, prefix) => self.evaluate_string_method(
72 column,
73 prefix,
74 row_index,
75 StringMethod::StartsWith,
76 false,
77 ),
78 WhereExpr::EndsWith(column, suffix) => self.evaluate_string_method(
79 column,
80 suffix,
81 row_index,
82 StringMethod::EndsWith,
83 false,
84 ),
85 WhereExpr::ContainsIgnoreCase(column, substring) => self.evaluate_string_method(
86 column,
87 substring,
88 row_index,
89 StringMethod::Contains,
90 true,
91 ),
92 WhereExpr::StartsWithIgnoreCase(column, prefix) => self.evaluate_string_method(
93 column,
94 prefix,
95 row_index,
96 StringMethod::StartsWith,
97 true,
98 ),
99 WhereExpr::EndsWithIgnoreCase(column, suffix) => {
100 self.evaluate_string_method(column, suffix, row_index, StringMethod::EndsWith, true)
101 }
102 WhereExpr::ToLower(column, op, value) => {
103 self.evaluate_case_conversion(column, value, row_index, op, true)
104 }
105 WhereExpr::ToUpper(column, op, value) => {
106 self.evaluate_case_conversion(column, value, row_index, op, false)
107 }
108 WhereExpr::IsNullOrEmpty(column) => self.evaluate_is_null_or_empty(column, row_index),
109 WhereExpr::Length(column, op, length) => {
110 self.evaluate_length(column, *length, row_index, op)
111 }
112 }
113 }
114
115 fn get_column_index(&self, column: &str) -> Result<usize> {
116 let columns = self.table.column_names();
117 columns
118 .iter()
119 .position(|c| c.eq_ignore_ascii_case(column))
120 .ok_or_else(|| anyhow::anyhow!("Column '{}' not found", column))
121 }
122
123 fn get_cell_value(&self, column: &str, row_index: usize) -> Result<Option<DataValue>> {
124 let col_index = self.get_column_index(column)?;
125 Ok(self.table.get_value(row_index, col_index).cloned())
126 }
127
128 fn evaluate_comparison(
129 &self,
130 column: &str,
131 value: &WhereValue,
132 row_index: usize,
133 op: ComparisonOp,
134 ) -> Result<bool> {
135 let cell_value = self.get_cell_value(column, row_index)?;
136
137 match cell_value {
138 None | Some(DataValue::Null) => Ok(false),
139 Some(data_val) => {
140 let result = match (&data_val, value) {
141 (DataValue::Integer(a), WhereValue::Number(b)) => {
143 compare_numbers(*a as f64, *b, &op)
144 }
145 (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, &op),
146 (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, &op),
148 (DataValue::InternedString(a), WhereValue::String(b)) => {
149 compare_strings(a, b, &op)
150 }
151 (DataValue::String(a), WhereValue::Number(b)) => {
153 if let Ok(a_num) = a.parse::<f64>() {
154 compare_numbers(a_num, *b, &op)
155 } else {
156 false
157 }
158 }
159 (DataValue::InternedString(a), WhereValue::Number(b)) => {
160 if let Ok(a_num) = a.parse::<f64>() {
161 compare_numbers(a_num, *b, &op)
162 } else {
163 false
164 }
165 }
166 (DataValue::Integer(a), WhereValue::String(b)) => {
167 if let Ok(b_num) = b.parse::<f64>() {
168 compare_numbers(*a as f64, b_num, &op)
169 } else {
170 false
171 }
172 }
173 (DataValue::Float(a), WhereValue::String(b)) => {
174 if let Ok(b_num) = b.parse::<f64>() {
175 compare_numbers(*a, b_num, &op)
176 } else {
177 false
178 }
179 }
180 (DataValue::Boolean(a), WhereValue::String(b)) => {
182 let b_bool = b.eq_ignore_ascii_case("true");
183 compare_bools(*a, b_bool, &op)
184 }
185 (_, WhereValue::Null) => {
187 matches!(op, ComparisonOp::NotEqual)
188 }
189 _ => false,
190 };
191 Ok(result)
192 }
193 }
194 }
195
196 fn evaluate_between(
197 &self,
198 column: &str,
199 lower: &WhereValue,
200 upper: &WhereValue,
201 row_index: usize,
202 ) -> Result<bool> {
203 let cell_value = self.get_cell_value(column, row_index)?;
204
205 match cell_value {
206 None | Some(DataValue::Null) => Ok(false),
207 Some(data_val) => {
208 let ge_lower =
209 self.compare_value(&data_val, lower, &ComparisonOp::GreaterThanOrEqual);
210 let le_upper = self.compare_value(&data_val, upper, &ComparisonOp::LessThanOrEqual);
211 Ok(ge_lower && le_upper)
212 }
213 }
214 }
215
216 fn evaluate_in(
217 &self,
218 column: &str,
219 values: &[WhereValue],
220 row_index: usize,
221 ignore_case: bool,
222 ) -> Result<bool> {
223 let cell_value = self.get_cell_value(column, row_index)?;
224
225 match cell_value {
226 None | Some(DataValue::Null) => Ok(false),
227 Some(data_val) => {
228 for value in values {
229 if ignore_case {
230 if self.compare_ignore_case(&data_val, value) {
231 return Ok(true);
232 }
233 } else if self.compare_value(&data_val, value, &ComparisonOp::Equal) {
234 return Ok(true);
235 }
236 }
237 Ok(false)
238 }
239 }
240 }
241
242 fn evaluate_like(&self, column: &str, pattern: &str, row_index: usize) -> Result<bool> {
243 let cell_value = self.get_cell_value(column, row_index)?;
244
245 match cell_value {
246 Some(DataValue::String(s)) => {
247 let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
249
250 let regex = regex::RegexBuilder::new(&format!("^{}$", regex_pattern))
252 .case_insensitive(true)
253 .build()
254 .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
255
256 Ok(regex.is_match(&s))
257 }
258 Some(DataValue::InternedString(s)) => {
259 let regex_pattern = pattern.replace('%', ".*").replace('_', ".");
261
262 let regex = regex::RegexBuilder::new(&format!("^{}$", regex_pattern))
264 .case_insensitive(true)
265 .build()
266 .map_err(|e| anyhow::anyhow!("Invalid LIKE pattern: {}", e))?;
267
268 Ok(regex.is_match(&s))
269 }
270 _ => Ok(false),
271 }
272 }
273
274 fn evaluate_is_null(&self, column: &str, row_index: usize, expect_null: bool) -> Result<bool> {
275 let cell_value = self.get_cell_value(column, row_index)?;
276 let is_null = matches!(cell_value, None | Some(DataValue::Null));
277 Ok(is_null == expect_null)
278 }
279
280 fn evaluate_is_null_or_empty(&self, column: &str, row_index: usize) -> Result<bool> {
281 let cell_value = self.get_cell_value(column, row_index)?;
282 Ok(match cell_value {
283 None | Some(DataValue::Null) => true,
284 Some(DataValue::String(s)) => s.is_empty(),
285 Some(DataValue::InternedString(s)) => s.is_empty(),
286 _ => false,
287 })
288 }
289
290 fn evaluate_string_method(
291 &self,
292 column: &str,
293 pattern: &str,
294 row_index: usize,
295 method: StringMethod,
296 ignore_case: bool,
297 ) -> Result<bool> {
298 let cell_value = self.get_cell_value(column, row_index)?;
299
300 match cell_value {
301 Some(DataValue::String(s)) => {
302 let (s, pattern) = if ignore_case {
303 (s.to_lowercase(), pattern.to_lowercase())
304 } else {
305 (s, pattern.to_string())
306 };
307
308 Ok(match method {
309 StringMethod::Contains => s.contains(&pattern),
310 StringMethod::StartsWith => s.starts_with(&pattern),
311 StringMethod::EndsWith => s.ends_with(&pattern),
312 })
313 }
314 Some(DataValue::InternedString(s)) => {
315 let (s, pattern) = if ignore_case {
316 (s.to_lowercase(), pattern.to_lowercase())
317 } else {
318 (s.as_ref().clone(), pattern.to_string())
319 };
320
321 Ok(match method {
322 StringMethod::Contains => s.contains(&pattern),
323 StringMethod::StartsWith => s.starts_with(&pattern),
324 StringMethod::EndsWith => s.ends_with(&pattern),
325 })
326 }
327 _ => Ok(false),
328 }
329 }
330
331 fn evaluate_case_conversion(
332 &self,
333 column: &str,
334 value: &str,
335 row_index: usize,
336 op: &ComparisonOp,
337 to_lower: bool,
338 ) -> Result<bool> {
339 let cell_value = self.get_cell_value(column, row_index)?;
340
341 match cell_value {
342 Some(DataValue::String(s)) => {
343 let converted = if to_lower {
344 s.to_lowercase()
345 } else {
346 s.to_uppercase()
347 };
348 Ok(compare_strings(&converted, value, op))
349 }
350 Some(DataValue::InternedString(s)) => {
351 let converted = if to_lower {
352 s.to_lowercase()
353 } else {
354 s.to_uppercase()
355 };
356 Ok(compare_strings(&converted, value, op))
357 }
358 _ => Ok(false),
359 }
360 }
361
362 fn evaluate_length(
363 &self,
364 column: &str,
365 length: i64,
366 row_index: usize,
367 op: &ComparisonOp,
368 ) -> Result<bool> {
369 let cell_value = self.get_cell_value(column, row_index)?;
370
371 match cell_value {
372 Some(DataValue::String(s)) => {
373 let len = s.len() as i64;
374 Ok(compare_numbers(len as f64, length as f64, op))
375 }
376 Some(DataValue::InternedString(s)) => {
377 let len = s.len() as i64;
378 Ok(compare_numbers(len as f64, length as f64, op))
379 }
380 _ => Ok(false),
381 }
382 }
383
384 fn compare_value(
385 &self,
386 data_val: &DataValue,
387 where_val: &WhereValue,
388 op: &ComparisonOp,
389 ) -> bool {
390 match (data_val, where_val) {
391 (DataValue::Integer(a), WhereValue::Number(b)) => compare_numbers(*a as f64, *b, op),
392 (DataValue::Float(a), WhereValue::Number(b)) => compare_numbers(*a, *b, op),
393 (DataValue::String(a), WhereValue::String(b)) => compare_strings(a, b, op),
394 (DataValue::InternedString(a), WhereValue::String(b)) => compare_strings(a, b, op),
395 _ => false,
396 }
397 }
398
399 fn compare_ignore_case(&self, data_val: &DataValue, where_val: &WhereValue) -> bool {
400 match (data_val, where_val) {
401 (DataValue::String(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
402 (DataValue::InternedString(a), WhereValue::String(b)) => a.eq_ignore_ascii_case(b),
403 _ => self.compare_value(data_val, where_val, &ComparisonOp::Equal),
404 }
405 }
406}
407
/// Substring test selected by the string-method `WHERE` expressions.
enum StringMethod {
    /// The cell text contains the pattern anywhere.
    Contains,
    /// The cell text begins with the pattern.
    StartsWith,
    /// The cell text ends with the pattern.
    EndsWith,
}
413
414fn compare_numbers(a: f64, b: f64, op: &ComparisonOp) -> bool {
415 match op {
416 ComparisonOp::Equal => (a - b).abs() < f64::EPSILON,
417 ComparisonOp::NotEqual => (a - b).abs() >= f64::EPSILON,
418 ComparisonOp::GreaterThan => a > b,
419 ComparisonOp::GreaterThanOrEqual => a >= b,
420 ComparisonOp::LessThan => a < b,
421 ComparisonOp::LessThanOrEqual => a <= b,
422 }
423}
424
425fn compare_strings(a: &str, b: &str, op: &ComparisonOp) -> bool {
426 match op {
427 ComparisonOp::Equal => a == b,
428 ComparisonOp::NotEqual => a != b,
429 ComparisonOp::GreaterThan => a > b,
430 ComparisonOp::GreaterThanOrEqual => a >= b,
431 ComparisonOp::LessThan => a < b,
432 ComparisonOp::LessThanOrEqual => a <= b,
433 }
434}
435
436fn compare_bools(a: bool, b: bool, op: &ComparisonOp) -> bool {
437 match op {
438 ComparisonOp::Equal => a == b,
439 ComparisonOp::NotEqual => a != b,
440 _ => false,
441 }
442}