Skip to main content

zerodds_sql_filter/
evaluator.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2026 ZeroDDS Contributors
3
4//! Expression evaluator.
5//!
6//! Evaluation is strict — a type mismatch (e.g. String < Int) returns
7//! `EvalError::TypeMismatch`. The caller decides whether to treat that
8//! as `filter denies` or as `filter error`.
9
10use alloc::string::String;
11
12use crate::ast::{CmpOp, Expr, Operand, Value};
13
14/// Row abstraction: access to the fields of a sample.
15///
16/// Implementers map dotted paths (`a.b.c`) to the matching values. For
17/// simple flat structs a `HashMap<String, Value>` lookup is enough.
18pub trait RowAccess {
19    /// Returns the value of the field path if present. `None` = no field
20    /// with that name.
21    fn get(&self, path: &str) -> Option<Value>;
22}
23
/// Reasons why evaluating a filter expression against a row can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalError {
    /// The row has no field under the referenced path.
    UnknownField(String),
    /// The expression references parameter `%N`, but the parameter slice has
    /// no entry at index `N`.
    MissingParam(u32),
    /// The operator cannot be applied to the operand types; the payload
    /// describes the offending comparison.
    TypeMismatch(String),
}

impl core::fmt::Display for EvalError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            EvalError::MissingParam(i) => write!(f, "missing parameter %{i}"),
            EvalError::TypeMismatch(m) => write!(f, "type mismatch: {m}"),
            EvalError::UnknownField(n) => write!(f, "unknown field: {n}"),
        }
    }
}

// `std::error::Error` integration is only available with the `std` feature.
#[cfg(feature = "std")]
impl std::error::Error for EvalError {}
47
48impl Expr {
49    /// Evaluates the expression against a row + parameter slice.
50    ///
51    /// # Errors
52    /// Siehe [`EvalError`].
53    pub fn evaluate<R: RowAccess>(&self, row: &R, params: &[Value]) -> Result<bool, EvalError> {
54        match self {
55            Self::And(a, b) => Ok(a.evaluate(row, params)? && b.evaluate(row, params)?),
56            Self::Or(a, b) => Ok(a.evaluate(row, params)? || b.evaluate(row, params)?),
57            Self::Not(inner) => Ok(!inner.evaluate(row, params)?),
58            Self::Cmp { lhs, op, rhs } => {
59                let l = resolve_operand(lhs, row, params)?;
60                let r = resolve_operand(rhs, row, params)?;
61                cmp(&l, *op, &r)
62            }
63            Self::Between {
64                field,
65                low,
66                high,
67                negated,
68            } => {
69                let f = resolve_operand(field, row, params)?;
70                let lo = resolve_operand(low, row, params)?;
71                let hi = resolve_operand(high, row, params)?;
72                let in_range = cmp(&f, CmpOp::Ge, &lo)? && cmp(&f, CmpOp::Le, &hi)?;
73                Ok(if *negated { !in_range } else { in_range })
74            }
75        }
76    }
77}
78
79fn resolve_operand<R: RowAccess>(
80    op: &Operand,
81    row: &R,
82    params: &[Value],
83) -> Result<Value, EvalError> {
84    match op {
85        Operand::Literal(v) => Ok(v.clone()),
86        Operand::Field(name) => row
87            .get(name)
88            .ok_or_else(|| EvalError::UnknownField(name.clone())),
89        Operand::Param(i) => params
90            .get(*i as usize)
91            .cloned()
92            .ok_or(EvalError::MissingParam(*i)),
93    }
94}
95
96fn cmp(lhs: &Value, op: CmpOp, rhs: &Value) -> Result<bool, EvalError> {
97    // Numeric promotion for int/float comparisons.
98    if let (Some(l), Some(r)) = (as_f64(lhs), as_f64(rhs)) {
99        return Ok(match op {
100            CmpOp::Eq => (l - r).abs() < f64::EPSILON,
101            CmpOp::Neq => (l - r).abs() >= f64::EPSILON,
102            CmpOp::Lt => l < r,
103            CmpOp::Le => l <= r,
104            CmpOp::Gt => l > r,
105            CmpOp::Ge => l >= r,
106            CmpOp::Like => {
107                return Err(EvalError::TypeMismatch("LIKE only for String".into()));
108            }
109        });
110    }
111
112    match (lhs, rhs, op) {
113        (Value::String(a), Value::String(b), CmpOp::Eq) => Ok(a == b),
114        (Value::String(a), Value::String(b), CmpOp::Neq) => Ok(a != b),
115        (Value::String(a), Value::String(b), CmpOp::Lt) => Ok(a < b),
116        (Value::String(a), Value::String(b), CmpOp::Le) => Ok(a <= b),
117        (Value::String(a), Value::String(b), CmpOp::Gt) => Ok(a > b),
118        (Value::String(a), Value::String(b), CmpOp::Ge) => Ok(a >= b),
119        (Value::String(a), Value::String(b), CmpOp::Like) => Ok(like_match(a, b)),
120        (Value::Bool(a), Value::Bool(b), CmpOp::Eq) => Ok(a == b),
121        (Value::Bool(a), Value::Bool(b), CmpOp::Neq) => Ok(a != b),
122        (a, b, op) => Err(EvalError::TypeMismatch(alloc::format!(
123            "{a:?} {op:?} {b:?}"
124        ))),
125    }
126}
127
128fn as_f64(v: &Value) -> Option<f64> {
129    match v {
130        #[allow(clippy::cast_precision_loss)]
131        Value::Int(n) => Some(*n as f64),
132        Value::Float(f) => Some(*f),
133        _ => None,
134    }
135}
136
/// SQL-92 `LIKE` match: `%` matches any run of zero or more characters and
/// `_` matches exactly one character. Every other pattern character must
/// match literally (case-sensitive, compared per Unicode scalar value).
///
/// No escape character is supported (spec §B.2.1 does not require one), so
/// `%` and `_` in the pattern always act as wildcards and cannot be matched
/// literally.
///
/// Runs in O(|s| · |pat|) time but keeps only one DP row of |pat| + 1 flags
/// (plus a scratch row), so memory stays O(|pat|) even for very long sample
/// strings.
fn like_match(s: &str, pat: &str) -> bool {
    let pattern: alloc::vec::Vec<char> = pat.chars().collect();
    let n = pattern.len();

    // `prev[j]`: the text consumed so far matches `pattern[..j]`.
    // For the empty text, only pattern prefixes made entirely of `%` match.
    let mut prev = alloc::vec![false; n + 1];
    prev[0] = true;
    for j in 1..=n {
        prev[j] = prev[j - 1] && pattern[j - 1] == '%';
    }

    let mut cur = alloc::vec![false; n + 1];
    for sc in s.chars() {
        // A non-empty text never matches the empty pattern.
        cur[0] = false;
        for j in 1..=n {
            cur[j] = match pattern[j - 1] {
                // `%` either absorbs `sc` (stay at j) or matches nothing (j - 1).
                '%' => prev[j] || cur[j - 1],
                '_' => prev[j - 1],
                pc => pc == sc && prev[j - 1],
            };
        }
        core::mem::swap(&mut prev, &mut cur);
    }
    prev[n]
}
166
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::parser::parse;
    use alloc::collections::BTreeMap;

    /// Minimal in-memory row keyed by field path.
    struct MapRow(BTreeMap<String, Value>);

    impl RowAccess for MapRow {
        fn get(&self, path: &str) -> Option<Value> {
            self.0.get(path).cloned()
        }
    }

    /// Builds a [`MapRow`] from `(field, value)` pairs; later duplicates win.
    fn row(pairs: &[(&str, Value)]) -> MapRow {
        MapRow(
            pairs
                .iter()
                .map(|(name, value)| (String::from(*name), value.clone()))
                .collect(),
        )
    }

    /// Parses `filter` (panicking on a parse error) and evaluates it against `r`.
    fn eval(filter: &str, r: &MapRow, params: &[Value]) -> Result<bool, EvalError> {
        parse(filter).unwrap().evaluate(r, params)
    }

    #[test]
    fn evaluates_string_eq() {
        let r = row(&[("color", Value::String("RED".into()))]);
        assert_eq!(eval("color = 'RED'", &r, &[]), Ok(true));
    }

    #[test]
    fn evaluates_int_compare() {
        let r = row(&[("x", Value::Int(42))]);
        assert_eq!(eval("x > 10 AND x <= 100", &r, &[]), Ok(true));
    }

    #[test]
    fn evaluates_float_int_cross() {
        // Int field against a float literal: both sides are promoted to f64.
        let r = row(&[("x", Value::Int(3))]);
        assert_eq!(eval("x < 3.5", &r, &[]), Ok(true));
    }

    #[test]
    fn evaluates_boolean_not_or() {
        let r = row(&[("x", Value::Int(1)), ("y", Value::Int(2))]);
        assert_eq!(eval("NOT (x = 0 OR y = 0)", &r, &[]), Ok(true));
    }

    #[test]
    fn evaluates_param() {
        let r = row(&[("color", Value::String("BLUE".into()))]);
        let params = [Value::String("BLUE".into())];
        assert_eq!(eval("color = %0", &r, &params), Ok(true));
    }

    #[test]
    fn missing_param_is_error() {
        let r = row(&[("color", Value::String("BLUE".into()))]);
        assert_eq!(eval("color = %0", &r, &[]), Err(EvalError::MissingParam(0)));
    }

    #[test]
    fn unknown_field_is_error() {
        let r = row(&[("x", Value::Int(1))]);
        let outcome = eval("missing = 1", &r, &[]);
        assert!(matches!(outcome, Err(EvalError::UnknownField(_))));
    }

    #[test]
    fn like_wildcards() {
        let cases = [
            ("name LIKE 'foo%'", "foobar", true),
            ("name LIKE 'foo%'", "barfoo", false),
            ("name LIKE 'a_c'", "abc", true),
            ("name LIKE 'a_c'", "abbc", false),
        ];
        for (filter, name, expected) in cases {
            let r = row(&[("name", Value::String(name.into()))]);
            assert_eq!(eval(filter, &r, &[]), Ok(expected), "{filter} on {name}");
        }
    }

    #[test]
    fn like_on_non_string_rejected() {
        let r = row(&[("x", Value::Int(5))]);
        let outcome = eval("x LIKE 5", &r, &[]);
        assert!(matches!(outcome, Err(EvalError::TypeMismatch(_))));
    }
}