Skip to main content

virtual_frame/
expr.rs

1//! Expression system — predicates for filter, computed columns for mutate.
2//!
//! DExpr is a simple expression tree: column references, literals, binary
//! operations, and the boolean connectives NOT/AND/OR. It is evaluated against
//! a DataFrame row-by-row, or column-by-column for simple `column <cmp> literal`
//! predicates (the columnar fast path).
6
7use crate::column::Column;
8use crate::dataframe::DataFrame;
9
/// Data expression — used in filter predicates and mutate computations.
///
/// Nodes can be built directly or with the `col` / `binop` helpers, and are
/// evaluated with `eval_expr_row` or, for simple comparison predicates,
/// `try_eval_predicate_columnar`.
///
/// `PartialEq` is structural; because `LitFloat` holds an `f64`, a NaN literal
/// is never equal to itself (hence no `Eq`).
#[derive(Debug, Clone, PartialEq)]
pub enum DExpr {
    /// Column reference: `col("name")`.
    Col(String),
    /// Literal integer.
    LitInt(i64),
    /// Literal float.
    LitFloat(f64),
    /// Literal bool.
    LitBool(bool),
    /// Literal string.
    LitStr(String),
    /// Binary operation `left op right` (arithmetic or comparison).
    BinOp {
        op: BinOp,
        left: Box<DExpr>,
        right: Box<DExpr>,
    },
    /// Logical NOT; the operand must evaluate to `Bool`.
    Not(Box<DExpr>),
    /// Logical AND of two boolean expressions.
    And(Box<DExpr>, Box<DExpr>),
    /// Logical OR of two boolean expressions.
    Or(Box<DExpr>, Box<DExpr>),
}

/// Binary operators supported in expressions.
///
/// `Add`, `Sub`, `Mul` and `Div` are arithmetic; the remaining variants are
/// comparisons that evaluate to a boolean. The enum is a plain `Copy` tag, so
/// it also supports equality and hashing (e.g. `op == BinOp::Add`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Equality (`==`).
    Eq,
    /// Inequality (`!=`).
    Ne,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
}
51
/// A scalar produced by evaluating a `DExpr` at one row.
#[derive(Debug, Clone)]
pub enum ExprValue {
    Int(i64),
    Float(f64),
    Bool(bool),
    Str(String),
}

impl ExprValue {
    /// Name of this value's variant, as used in evaluation error messages.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::Int(_) => "Int",
            Self::Float(_) => "Float",
            Self::Bool(_) => "Bool",
            Self::Str(_) => "Str",
        }
    }

    /// Numeric view of the value.
    ///
    /// Int is widened to `f64` (values beyond 2^53 may round), Float is returned
    /// as-is, Bool maps to `1.0` / `0.0`, and Str has no numeric view (`None`).
    pub fn as_f64(&self) -> Option<f64> {
        let number = match *self {
            Self::Int(i) => i as f64,
            Self::Float(f) => f,
            Self::Bool(b) => f64::from(u8::from(b)),
            Self::Str(_) => return None,
        };
        Some(number)
    }

    /// The boolean payload; `None` for every non-Bool value (no coercion).
    pub fn as_bool(&self) -> Option<bool> {
        if let Self::Bool(b) = *self {
            Some(b)
        } else {
            None
        }
    }
}
87
88/// Evaluate a DExpr at a single row of a DataFrame.
89pub fn eval_expr_row(df: &DataFrame, expr: &DExpr, row: usize) -> Result<ExprValue, String> {
90    match expr {
91        DExpr::Col(name) => {
92            let col = df
93                .get_column(name)
94                .ok_or_else(|| format!("column `{}` not found", name))?;
95            Ok(match col {
96                Column::Int(v) => ExprValue::Int(v[row]),
97                Column::Float(v) => ExprValue::Float(v[row]),
98                Column::Str(v) => ExprValue::Str(v[row].clone()),
99                Column::Bool(v) => ExprValue::Bool(v[row]),
100            })
101        }
102        DExpr::LitInt(v) => Ok(ExprValue::Int(*v)),
103        DExpr::LitFloat(v) => Ok(ExprValue::Float(*v)),
104        DExpr::LitBool(v) => Ok(ExprValue::Bool(*v)),
105        DExpr::LitStr(v) => Ok(ExprValue::Str(v.clone())),
106        DExpr::BinOp { op, left, right } => {
107            let lv = eval_expr_row(df, left, row)?;
108            let rv = eval_expr_row(df, right, row)?;
109            eval_binop(*op, &lv, &rv)
110        }
111        DExpr::Not(inner) => {
112            let v = eval_expr_row(df, inner, row)?;
113            match v {
114                ExprValue::Bool(b) => Ok(ExprValue::Bool(!b)),
115                _ => Err(format!("NOT requires Bool, got {}", v.type_name())),
116            }
117        }
118        DExpr::And(a, b) => {
119            let av = eval_expr_row(df, a, row)?;
120            let bv = eval_expr_row(df, b, row)?;
121            match (av, bv) {
122                (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x && y)),
123                _ => Err("AND requires two Bool operands".into()),
124            }
125        }
126        DExpr::Or(a, b) => {
127            let av = eval_expr_row(df, a, row)?;
128            let bv = eval_expr_row(df, b, row)?;
129            match (av, bv) {
130                (ExprValue::Bool(x), ExprValue::Bool(y)) => Ok(ExprValue::Bool(x || y)),
131                _ => Err("OR requires two Bool operands".into()),
132            }
133        }
134    }
135}
136
137fn eval_binop(op: BinOp, lv: &ExprValue, rv: &ExprValue) -> Result<ExprValue, String> {
138    match op {
139        // Comparison operators
140        BinOp::Eq => Ok(ExprValue::Bool(cmp_values(lv, rv) == Some(std::cmp::Ordering::Equal))),
141        BinOp::Ne => Ok(ExprValue::Bool(cmp_values(lv, rv) != Some(std::cmp::Ordering::Equal))),
142        BinOp::Lt => Ok(ExprValue::Bool(cmp_values(lv, rv) == Some(std::cmp::Ordering::Less))),
143        BinOp::Le => Ok(ExprValue::Bool(matches!(
144            cmp_values(lv, rv),
145            Some(std::cmp::Ordering::Less) | Some(std::cmp::Ordering::Equal)
146        ))),
147        BinOp::Gt => Ok(ExprValue::Bool(cmp_values(lv, rv) == Some(std::cmp::Ordering::Greater))),
148        BinOp::Ge => Ok(ExprValue::Bool(matches!(
149            cmp_values(lv, rv),
150            Some(std::cmp::Ordering::Greater) | Some(std::cmp::Ordering::Equal)
151        ))),
152        // Arithmetic operators
153        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div => {
154            let l = lv.as_f64().ok_or_else(|| {
155                format!("arithmetic requires numeric types, got {}", lv.type_name())
156            })?;
157            let r = rv.as_f64().ok_or_else(|| {
158                format!("arithmetic requires numeric types, got {}", rv.type_name())
159            })?;
160            let result = match op {
161                BinOp::Add => l + r,
162                BinOp::Sub => l - r,
163                BinOp::Mul => l * r,
164                BinOp::Div => l / r,
165                _ => unreachable!(),
166            };
167            Ok(ExprValue::Float(result))
168        }
169    }
170}
171
172fn cmp_values(a: &ExprValue, b: &ExprValue) -> Option<std::cmp::Ordering> {
173    match (a, b) {
174        (ExprValue::Int(x), ExprValue::Int(y)) => Some(x.cmp(y)),
175        (ExprValue::Float(x), ExprValue::Float(y)) => x.partial_cmp(y),
176        (ExprValue::Int(x), ExprValue::Float(y)) => (*x as f64).partial_cmp(y),
177        (ExprValue::Float(x), ExprValue::Int(y)) => x.partial_cmp(&(*y as f64)),
178        (ExprValue::Str(x), ExprValue::Str(y)) => Some(x.cmp(y)),
179        (ExprValue::Bool(x), ExprValue::Bool(y)) => Some(x.cmp(y)),
180        (ExprValue::Str(x), ExprValue::Int(y)) => Some(x.cmp(&y.to_string())),
181        (ExprValue::Int(x), ExprValue::Str(y)) => Some(x.to_string().cmp(y)),
182        _ => None,
183    }
184}
185
186/// Try to evaluate a predicate in columnar mode (fast path).
187///
188/// For simple predicates like `col("x") > 5`, this scans the column
189/// directly instead of evaluating row-by-row. Returns None if the
190/// expression is too complex for the columnar fast path.
191pub fn try_eval_predicate_columnar(
192    df: &DataFrame,
193    expr: &DExpr,
194    current_mask: &crate::bitmask::BitMask,
195) -> Option<crate::bitmask::BitMask> {
196    match expr {
197        DExpr::BinOp { op, left, right } => {
198            // Only handle Col op Literal patterns
199            let (col_name, lit, flip) = match (left.as_ref(), right.as_ref()) {
200                (DExpr::Col(name), lit) if is_literal(lit) => (name.as_str(), lit, false),
201                (lit, DExpr::Col(name)) if is_literal(lit) => (name.as_str(), lit, true),
202                _ => return None,
203            };
204
205            let col = df.get_column(col_name)?;
206            let nrows = df.nrows();
207            let mut new_words = current_mask.words.clone();
208
209            match (col, lit) {
210                (Column::Int(data), DExpr::LitInt(val)) => {
211                    for row in current_mask.iter_set() {
212                        let (l, r) = if flip {
213                            (*val, data[row])
214                        } else {
215                            (data[row], *val)
216                        };
217                        if !cmp_i64(*op, l, r) {
218                            new_words[row / 64] &= !(1u64 << (row % 64));
219                        }
220                    }
221                }
222                (Column::Float(data), DExpr::LitFloat(val)) => {
223                    for row in current_mask.iter_set() {
224                        let (l, r) = if flip {
225                            (*val, data[row])
226                        } else {
227                            (data[row], *val)
228                        };
229                        if !cmp_f64(*op, l, r) {
230                            new_words[row / 64] &= !(1u64 << (row % 64));
231                        }
232                    }
233                }
234                (Column::Int(data), DExpr::LitFloat(val)) => {
235                    for row in current_mask.iter_set() {
236                        let (l, r) = if flip {
237                            (*val, data[row] as f64)
238                        } else {
239                            (data[row] as f64, *val)
240                        };
241                        if !cmp_f64(*op, l, r) {
242                            new_words[row / 64] &= !(1u64 << (row % 64));
243                        }
244                    }
245                }
246                (Column::Str(data), DExpr::LitStr(val)) => {
247                    for row in current_mask.iter_set() {
248                        let pass = if flip {
249                            cmp_str(*op, val, &data[row])
250                        } else {
251                            cmp_str(*op, &data[row], val)
252                        };
253                        if !pass {
254                            new_words[row / 64] &= !(1u64 << (row % 64));
255                        }
256                    }
257                }
258                _ => return None,
259            }
260
261            Some(crate::bitmask::BitMask {
262                words: new_words,
263                nrows,
264            })
265        }
266        _ => None,
267    }
268}
269
270fn is_literal(expr: &DExpr) -> bool {
271    matches!(
272        expr,
273        DExpr::LitInt(_) | DExpr::LitFloat(_) | DExpr::LitBool(_) | DExpr::LitStr(_)
274    )
275}
276
277#[inline]
278fn cmp_i64(op: BinOp, l: i64, r: i64) -> bool {
279    match op {
280        BinOp::Eq => l == r,
281        BinOp::Ne => l != r,
282        BinOp::Lt => l < r,
283        BinOp::Le => l <= r,
284        BinOp::Gt => l > r,
285        BinOp::Ge => l >= r,
286        _ => false,
287    }
288}
289
290#[inline]
291fn cmp_f64(op: BinOp, l: f64, r: f64) -> bool {
292    match op {
293        BinOp::Eq => l == r,
294        BinOp::Ne => l != r,
295        BinOp::Lt => l < r,
296        BinOp::Le => l <= r,
297        BinOp::Gt => l > r,
298        BinOp::Ge => l >= r,
299        _ => false,
300    }
301}
302
303#[inline]
304fn cmp_str(op: BinOp, l: &str, r: &str) -> bool {
305    match op {
306        BinOp::Eq => l == r,
307        BinOp::Ne => l != r,
308        BinOp::Lt => l < r,
309        BinOp::Le => l <= r,
310        BinOp::Gt => l > r,
311        BinOp::Ge => l >= r,
312        _ => false,
313    }
314}
315
316// ── Builder helpers (for Python API ergonomics) ──────────────────────────
317
318/// Create a column reference expression.
319pub fn col(name: &str) -> DExpr {
320    DExpr::Col(name.to_string())
321}
322
323/// Create a binary operation expression.
324pub fn binop(op: BinOp, left: DExpr, right: DExpr) -> DExpr {
325    DExpr::BinOp {
326        op,
327        left: Box::new(left),
328        right: Box::new(right),
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-wise evaluation of `x > 15` over x = [10, 20, 30].
    #[test]
    fn test_eval_comparison() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![10, 20, 30]))])
            .unwrap();
        let x_gt_15 = binop(BinOp::Gt, col("x"), DExpr::LitInt(15));
        let at = |row| eval_expr_row(&df, &x_gt_15, row).unwrap().as_bool();
        assert_eq!(at(0), Some(false)); // 10 > 15 is false
        assert_eq!(at(1), Some(true)); // 20 > 15 is true
    }

    /// Columnar fast path for `x > 3` with all five rows initially selected.
    #[test]
    fn test_columnar_fast_path() {
        let df = DataFrame::from_columns(vec![("x".into(), Column::Int(vec![1, 2, 3, 4, 5]))])
            .unwrap();
        let selected_before = crate::bitmask::BitMask::all_true(5);
        let x_gt_3 = binop(BinOp::Gt, col("x"), DExpr::LitInt(3));
        let selected_after =
            try_eval_predicate_columnar(&df, &x_gt_3, &selected_before).unwrap();
        let kept: Vec<usize> = selected_after.iter_set().collect();
        assert_eq!(kept, vec![3, 4]); // x = 4 and x = 5
    }
}