Skip to main content

vortex_datafusion/convert/
exprs.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_schema::DataType;
7use arrow_schema::Schema;
8use datafusion_common::Result as DFResult;
9use datafusion_common::exec_datafusion_err;
10use datafusion_common::tree_node::TreeNode;
11use datafusion_common::tree_node::TreeNodeRecursion;
12use datafusion_expr::Operator as DFOperator;
13use datafusion_functions::core::getfield::GetFieldFunc;
14use datafusion_physical_expr::PhysicalExpr;
15use datafusion_physical_expr::ScalarFunctionExpr;
16use datafusion_physical_expr::projection::ProjectionExpr;
17use datafusion_physical_expr::projection::ProjectionExprs;
18use datafusion_physical_expr::utils::collect_columns;
19use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
20use datafusion_physical_plan::expressions as df_expr;
21use itertools::Itertools;
22use vortex::dtype::DType;
23use vortex::dtype::Nullability;
24use vortex::dtype::arrow::FromArrowType;
25use vortex::expr::Expression;
26use vortex::expr::and_collect;
27use vortex::expr::cast;
28use vortex::expr::get_item;
29use vortex::expr::is_not_null;
30use vortex::expr::is_null;
31use vortex::expr::list_contains;
32use vortex::expr::lit;
33use vortex::expr::nested_case_when;
34use vortex::expr::not;
35use vortex::expr::pack;
36use vortex::expr::root;
37use vortex::scalar::Scalar;
38use vortex::scalar_fn::ScalarFnVTableExt;
39use vortex::scalar_fn::fns::binary::Binary;
40use vortex::scalar_fn::fns::like::Like;
41use vortex::scalar_fn::fns::like::LikeOptions;
42use vortex::scalar_fn::fns::operators::Operator;
43
44use crate::convert::FromDataFusion;
45
46/// Result of splitting a projection into Vortex expressions and leftover DataFusion projections.
47pub struct ProcessedProjection {
48    pub scan_projection: Expression,
49    pub leftover_projection: ProjectionExprs,
50}
51
52/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
53pub(crate) fn make_vortex_predicate(
54    expr_convertor: &dyn ExpressionConvertor,
55    predicate: &[Arc<dyn PhysicalExpr>],
56) -> DFResult<Option<Expression>> {
57    let exprs = predicate
58        .iter()
59        .map(|e| expr_convertor.convert(e.as_ref()))
60        .collect::<DFResult<Vec<_>>>()?;
61
62    Ok(and_collect(exprs))
63}
64
65/// Trait for converting DataFusion expressions to Vortex ones.
66pub trait ExpressionConvertor: Send + Sync {
67    /// Can an expression be pushed down given a specific schema
68    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
69
70    /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
71    fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;
72
73    /// Split a projection into Vortex expressions that can be pushed down and leftover
74    /// DataFusion projections that need to be evaluated after the scan.
75    fn split_projection(
76        &self,
77        source_projection: ProjectionExprs,
78        input_schema: &Schema,
79        output_schema: &Schema,
80    ) -> DFResult<ProcessedProjection>;
81
82    /// Create a projection that reads only the required columns without pushing down
83    /// any expressions. All projection logic is applied after the scan.
84    fn no_pushdown_projection(
85        &self,
86        source_projection: ProjectionExprs,
87        input_schema: &Schema,
88    ) -> DFResult<ProcessedProjection> {
89        // Get all unique column indices referenced by the projection
90        let column_indices = source_projection.column_indices();
91
92        // Create scan projection that reads the required columns
93        let scan_columns: Vec<(String, Expression)> = column_indices
94            .into_iter()
95            .map(|idx| {
96                let field = input_schema.field(idx);
97                let name = field.name().clone();
98                (name.clone(), get_item(name, root()))
99            })
100            .collect();
101
102        Ok(ProcessedProjection {
103            scan_projection: pack(scan_columns, Nullability::NonNullable),
104            leftover_projection: source_projection,
105        })
106    }
107}
108
/// The default [`ExpressionConvertor`].
///
/// The type carries no state, so it is cheap to construct, clone and print.
#[derive(Debug, Default, Clone)]
pub struct DefaultExpressionConvertor {}
112
113impl DefaultExpressionConvertor {
114    /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
115    fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
116        if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
117        {
118            // DataFusion's GetFieldFunc flattens nested field access into a single call
119            // with multiple field name arguments. For example, `outer.inner.leaf` becomes
120            // get_field(Column("outer"), "inner", "leaf"). We build a chain of get_item
121            // calls for each field name in the path.
122            let (source_expr, field_names) = get_field_fn
123                .args()
124                .split_first()
125                .ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;
126
127            let mut result = self.convert(source_expr.as_ref())?;
128            for expr in field_names {
129                let field_name = expr
130                    .as_any()
131                    .downcast_ref::<df_expr::Literal>()
132                    .ok_or_else(|| exec_datafusion_err!("get_field field name must be a literal"))?
133                    .value()
134                    .try_as_str()
135                    .flatten()
136                    .ok_or_else(|| {
137                        exec_datafusion_err!("get_field field name must be a UTF-8 string")
138                    })?;
139                result = get_item(field_name.to_string(), result);
140            }
141            return Ok(result);
142        }
143
144        Err(exec_datafusion_err!(
145            "Unsupported ScalarFunctionExpr: {}",
146            scalar_fn.name()
147        ))
148    }
149
150    /// Attempts to convert a DataFusion CaseExpr to a Vortex expression.
151    fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
152        // DataFusion CaseExpr has:
153        // - expr(): Optional base expression (for "CASE expr WHEN ..." form)
154        // - when_then_expr(): Vec of (when, then) pairs
155        // - else_expr(): Optional else expression
156
157        // We don't support the "CASE expr WHEN value1 THEN result1" form yet
158        if case_expr.expr().is_some() {
159            return Err(exec_datafusion_err!(
160                "CASE expr WHEN form is not yet supported, only searched CASE is supported"
161            ));
162        }
163
164        let when_then_pairs = case_expr.when_then_expr();
165        if when_then_pairs.is_empty() {
166            return Err(exec_datafusion_err!(
167                "CASE expression must have at least one WHEN clause"
168            ));
169        }
170
171        // Convert all when/then pairs to (condition, value) tuples
172        let mut pairs = Vec::with_capacity(when_then_pairs.len());
173        for (when_expr, then_expr) in when_then_pairs {
174            let condition = self.convert(when_expr.as_ref())?;
175            let value = self.convert(then_expr.as_ref())?;
176            pairs.push((condition, value));
177        }
178
179        // Convert optional else expression
180        let else_value = case_expr
181            .else_expr()
182            .map(|e| self.convert(e.as_ref()))
183            .transpose()?;
184
185        // Build a single n-ary CASE WHEN expression from DataFusion WHEN/THEN pairs
186        Ok(nested_case_when(pairs, else_value))
187    }
188}
189
190impl ExpressionConvertor for DefaultExpressionConvertor {
191    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
192        can_be_pushed_down_impl(expr, schema)
193    }
194
195    fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
196        // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
197        //  for that node, up to any `and` or `or` node.
198        if let Some(binary_expr) = df.as_any().downcast_ref::<df_expr::BinaryExpr>() {
199            let left = self.convert(binary_expr.left().as_ref())?;
200            let right = self.convert(binary_expr.right().as_ref())?;
201            let operator = try_operator_from_df(binary_expr.op())?;
202
203            return Ok(Binary.new_expr(operator, [left, right]));
204        }
205
206        if let Some(col_expr) = df.as_any().downcast_ref::<df_expr::Column>() {
207            return Ok(get_item(col_expr.name().to_owned(), root()));
208        }
209
210        if let Some(like) = df.as_any().downcast_ref::<df_expr::LikeExpr>() {
211            let child = self.convert(like.expr().as_ref())?;
212            let pattern = self.convert(like.pattern().as_ref())?;
213            return Ok(Like.new_expr(
214                LikeOptions {
215                    negated: like.negated(),
216                    case_insensitive: like.case_insensitive(),
217                },
218                [child, pattern],
219            ));
220        }
221
222        if let Some(literal) = df.as_any().downcast_ref::<df_expr::Literal>() {
223            let value = Scalar::from_df(literal.value());
224            return Ok(lit(value));
225        }
226
227        if let Some(cast_expr) = df.as_any().downcast_ref::<df_expr::CastExpr>() {
228            let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable));
229            let child = self.convert(cast_expr.expr().as_ref())?;
230            return Ok(cast(child, cast_dtype));
231        }
232
233        if let Some(cast_col_expr) = df.as_any().downcast_ref::<df_expr::CastColumnExpr>() {
234            let target = cast_col_expr.target_field();
235
236            let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into()));
237            let child = self.convert(cast_col_expr.expr().as_ref())?;
238            return Ok(cast(child, target_dtype));
239        }
240
241        if let Some(is_null_expr) = df.as_any().downcast_ref::<df_expr::IsNullExpr>() {
242            let arg = self.convert(is_null_expr.arg().as_ref())?;
243            return Ok(is_null(arg));
244        }
245
246        if let Some(is_not_null_expr) = df.as_any().downcast_ref::<df_expr::IsNotNullExpr>() {
247            let arg = self.convert(is_not_null_expr.arg().as_ref())?;
248            return Ok(is_not_null(arg));
249        }
250
251        if let Some(in_list) = df.as_any().downcast_ref::<df_expr::InListExpr>() {
252            let value = self.convert(in_list.expr().as_ref())?;
253            let list_elements: Vec<_> = in_list
254                .list()
255                .iter()
256                .map(|e| {
257                    if let Some(lit) = e.as_any().downcast_ref::<df_expr::Literal>() {
258                        Ok(Scalar::from_df(lit.value()))
259                    } else {
260                        Err(exec_datafusion_err!("Failed to cast sub-expression"))
261                    }
262                })
263                .try_collect()?;
264
265            let list = Scalar::list(
266                list_elements[0].dtype().clone(),
267                list_elements,
268                Nullability::Nullable,
269            );
270            let expr = list_contains(lit(list), value);
271
272            return Ok(if in_list.negated() { not(expr) } else { expr });
273        }
274
275        if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
276            return self.try_convert_scalar_function(scalar_fn);
277        }
278
279        if let Some(case_expr) = df.as_any().downcast_ref::<df_expr::CaseExpr>() {
280            return self.try_convert_case_expr(case_expr);
281        }
282
283        Err(exec_datafusion_err!(
284            "Couldn't convert DataFusion physical {df} expression to a vortex expression"
285        ))
286    }
287
288    fn split_projection(
289        &self,
290        source_projection: ProjectionExprs,
291        input_schema: &Schema,
292        output_schema: &Schema,
293    ) -> DFResult<ProcessedProjection> {
294        let mut scan_projection = vec![];
295        let mut leftover_projection: Vec<ProjectionExpr> = vec![];
296
297        for projection_expr in source_projection.iter() {
298            let r = projection_expr.expr.apply(|node| {
299                // We only pull column children of scalar functions that we can't push into the scan.
300                if let Some(scalar_fn_expr) = node.as_any().downcast_ref::<ScalarFunctionExpr>()
301                    && !can_scalar_fn_be_pushed_down(scalar_fn_expr)
302                {
303                    scan_projection.extend(
304                        collect_columns(node)
305                            .into_iter()
306                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
307                    );
308
309                    leftover_projection.push(projection_expr.clone());
310                    return Ok(TreeNodeRecursion::Stop);
311                }
312
313                // DataFusion assumes different decimal types can be coerced.
314                // Vortex expects a perfect match so we don't push it down.
315                if let Some(binary_expr) = node.as_any().downcast_ref::<df_expr::BinaryExpr>()
316                    && binary_expr.op().is_numerical_operators()
317                    && (is_decimal(&binary_expr.left().data_type(input_schema)?)
318                        && is_decimal(&binary_expr.right().data_type(input_schema)?))
319                {
320                    scan_projection.extend(
321                        collect_columns(node)
322                            .into_iter()
323                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
324                    );
325
326                    leftover_projection.push(projection_expr.clone());
327                    return Ok(TreeNodeRecursion::Stop);
328                }
329
330                Ok(TreeNodeRecursion::Continue)
331            })?;
332
333            // if we didn't stop early
334            if matches!(r, TreeNodeRecursion::Continue) {
335                scan_projection.push((
336                    projection_expr.alias.clone(),
337                    self.convert(projection_expr.expr.as_ref())?,
338                ));
339                leftover_projection.push(ProjectionExpr {
340                    expr: Arc::new(df_expr::Column::new_with_schema(
341                        projection_expr.alias.as_str(),
342                        output_schema,
343                    )?),
344                    alias: projection_expr.alias.clone(),
345                });
346            }
347        }
348
349        Ok(ProcessedProjection {
350            scan_projection: pack(scan_projection, Nullability::NonNullable),
351            leftover_projection: leftover_projection.into(),
352        })
353    }
354}
355
356fn try_operator_from_df(value: &DFOperator) -> DFResult<Operator> {
357    match value {
358        DFOperator::Eq => Ok(Operator::Eq),
359        DFOperator::NotEq => Ok(Operator::NotEq),
360        DFOperator::Lt => Ok(Operator::Lt),
361        DFOperator::LtEq => Ok(Operator::Lte),
362        DFOperator::Gt => Ok(Operator::Gt),
363        DFOperator::GtEq => Ok(Operator::Gte),
364        DFOperator::And => Ok(Operator::And),
365        DFOperator::Or => Ok(Operator::Or),
366        DFOperator::Plus => Ok(Operator::Add),
367        DFOperator::Minus => Ok(Operator::Sub),
368        DFOperator::Multiply => Ok(Operator::Mul),
369        DFOperator::Divide => Ok(Operator::Div),
370        DFOperator::IsDistinctFrom
371        | DFOperator::IsNotDistinctFrom
372        | DFOperator::RegexMatch
373        | DFOperator::RegexIMatch
374        | DFOperator::RegexNotMatch
375        | DFOperator::RegexNotIMatch
376        | DFOperator::LikeMatch
377        | DFOperator::ILikeMatch
378        | DFOperator::NotLikeMatch
379        | DFOperator::NotILikeMatch
380        | DFOperator::BitwiseAnd
381        | DFOperator::BitwiseOr
382        | DFOperator::BitwiseXor
383        | DFOperator::BitwiseShiftRight
384        | DFOperator::BitwiseShiftLeft
385        | DFOperator::StringConcat
386        | DFOperator::AtArrow
387        | DFOperator::ArrowAt
388        | DFOperator::Modulo
389        | DFOperator::Arrow
390        | DFOperator::LongArrow
391        | DFOperator::HashArrow
392        | DFOperator::HashLongArrow
393        | DFOperator::AtAt
394        | DFOperator::IntegerDivide
395        | DFOperator::HashMinus
396        | DFOperator::AtQuestion
397        | DFOperator::Question
398        | DFOperator::QuestionAnd
399        | DFOperator::QuestionPipe
400        | DFOperator::Colon => {
401            tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
402            Err(exec_datafusion_err!(
403                "Unsupported datafusion operator {value}"
404            ))
405        }
406    }
407}
408
409fn can_be_pushed_down_impl(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
410    // We currently do not support pushdown of dynamic expressions in DF.
411    // See issue: https://github.com/vortex-data/vortex/issues/4034
412    if is_dynamic_physical_expr(df_expr) {
413        return false;
414    }
415
416    let expr = df_expr.as_any();
417    if let Some(binary) = expr.downcast_ref::<df_expr::BinaryExpr>() {
418        can_binary_be_pushed_down(binary, schema)
419    } else if let Some(col) = expr.downcast_ref::<df_expr::Column>() {
420        schema
421            .field_with_name(col.name())
422            .ok()
423            .is_some_and(|field| supported_data_types(field.data_type()))
424    } else if let Some(like) = expr.downcast_ref::<df_expr::LikeExpr>() {
425        can_be_pushed_down_impl(like.expr(), schema)
426            && can_be_pushed_down_impl(like.pattern(), schema)
427    } else if let Some(lit) = expr.downcast_ref::<df_expr::Literal>() {
428        supported_data_types(&lit.value().data_type())
429    } else if let Some(cast_expr) = expr.downcast_ref::<df_expr::CastExpr>() {
430        // CastExpr child must be an expression type that convert() can handle
431        is_convertible_expr(cast_expr.expr())
432    } else if let Some(cast_col_expr) = expr.downcast_ref::<df_expr::CastColumnExpr>() {
433        // CastColumnExpr child must be an expression type that convert() can handle
434        is_convertible_expr(cast_col_expr.expr())
435    } else if let Some(is_null) = expr.downcast_ref::<df_expr::IsNullExpr>() {
436        can_be_pushed_down_impl(is_null.arg(), schema)
437    } else if let Some(is_not_null) = expr.downcast_ref::<df_expr::IsNotNullExpr>() {
438        can_be_pushed_down_impl(is_not_null.arg(), schema)
439    } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
440        can_be_pushed_down_impl(in_list.expr(), schema)
441            && in_list
442                .list()
443                .iter()
444                .all(|e| can_be_pushed_down_impl(e, schema))
445    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
446        can_scalar_fn_be_pushed_down(scalar_fn)
447    } else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
448        can_case_be_pushed_down(case_expr, schema)
449    } else {
450        tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
451        false
452    }
453}
454
455/// Checks if an expression type is one that convert() can handle.
456/// This is less restrictive than can_be_pushed_down since it only checks
457/// expression types, not data type support.
458fn is_convertible_expr(df_expr: &Arc<dyn PhysicalExpr>) -> bool {
459    let expr = df_expr.as_any();
460
461    // Expression types that convert() handles
462    expr.downcast_ref::<df_expr::BinaryExpr>().is_some()
463        || expr.downcast_ref::<df_expr::Column>().is_some()
464        || expr.downcast_ref::<df_expr::LikeExpr>().is_some()
465        || expr.downcast_ref::<df_expr::Literal>().is_some()
466        || expr
467            .downcast_ref::<df_expr::CastExpr>()
468            .is_some_and(|e| is_convertible_expr(e.expr()))
469        || expr
470            .downcast_ref::<df_expr::CastColumnExpr>()
471            .is_some_and(|e| is_convertible_expr(e.expr()))
472        || expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
473        || expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
474        || expr.downcast_ref::<df_expr::InListExpr>().is_some()
475        || expr
476            .downcast_ref::<ScalarFunctionExpr>()
477            .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
478}
479
480fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
481    let is_op_supported = try_operator_from_df(binary.op()).is_ok();
482    is_op_supported
483        && can_be_pushed_down_impl(binary.left(), schema)
484        && can_be_pushed_down_impl(binary.right(), schema)
485}
486
487fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool {
488    // We only support the "searched CASE" form (CASE WHEN cond THEN result ...)
489    // not the "simple CASE" form (CASE expr WHEN value THEN result ...)
490    if case_expr.expr().is_some() {
491        return false;
492    }
493
494    // Check all when/then pairs
495    for (when_expr, then_expr) in case_expr.when_then_expr() {
496        if !can_be_pushed_down_impl(when_expr, schema)
497            || !can_be_pushed_down_impl(then_expr, schema)
498        {
499            return false;
500        }
501    }
502
503    // Check the optional else clause
504    if let Some(else_expr) = case_expr.else_expr()
505        && !can_be_pushed_down_impl(else_expr, schema)
506    {
507        return false;
508    }
509
510    true
511}
512
513fn supported_data_types(dt: &DataType) -> bool {
514    use DataType::*;
515
516    // For dictionary types, check if the value type is supported.
517    if let Dictionary(_, value_type) = dt {
518        return supported_data_types(value_type.as_ref());
519    }
520
521    let is_supported = dt.is_null()
522        || dt.is_numeric()
523        || matches!(
524            dt,
525            Boolean
526                | Utf8
527                | LargeUtf8
528                | Utf8View
529                | Binary
530                | LargeBinary
531                | BinaryView
532                | Date32
533                | Date64
534                | Timestamp(_, _)
535                | Time32(_)
536                | Time64(_)
537        );
538
539    if !is_supported {
540        tracing::debug!("DataFusion data type {dt:?} is not supported");
541    }
542
543    is_supported
544}
545
546/// Checks if a scalar function can be pushed down.
547/// Currently only GetFieldFunc is supported.
548fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
549    ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
550}
551
552// TODO(adam): Replace with `DataType::is_decimal` once its released.
553fn is_decimal(dt: &DataType) -> bool {
554    matches!(
555        dt,
556        DataType::Decimal32(_, _)
557            | DataType::Decimal64(_, _)
558            | DataType::Decimal128(_, _)
559            | DataType::Decimal256(_, _)
560    )
561}
562
563#[cfg(test)]
564mod tests {
565    use std::sync::Arc;
566
567    use arrow_schema::DataType;
568    use arrow_schema::Field;
569    use arrow_schema::Schema;
570    use arrow_schema::TimeUnit as ArrowTimeUnit;
571    use datafusion_common::ScalarValue;
572    use datafusion_expr::Operator as DFOperator;
573    use datafusion_physical_expr::PhysicalExpr;
574    use datafusion_physical_plan::expressions as df_expr;
575    use insta::assert_snapshot;
576    use rstest::rstest;
577
578    use super::*;
579    use crate::common_tests::TestSessionContext;
580
581    #[rstest::fixture]
582    fn test_schema() -> Schema {
583        Schema::new(vec![
584            Field::new("id", DataType::Int32, false),
585            Field::new("name", DataType::Utf8, true),
586            Field::new("score", DataType::Float64, true),
587            Field::new("active", DataType::Boolean, false),
588            Field::new(
589                "created_at",
590                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
591                true,
592            ),
593            Field::new(
594                "unsupported_list",
595                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
596                true,
597            ),
598        ])
599    }
600
601    #[test]
602    fn test_make_vortex_predicate_empty() {
603        let expr_convertor = DefaultExpressionConvertor::default();
604        let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
605        assert!(result.is_none());
606    }
607
608    #[test]
609    fn test_make_vortex_predicate_single() {
610        let expr_convertor = DefaultExpressionConvertor::default();
611        let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
612        let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
613        assert!(result.is_some());
614    }
615
616    #[test]
617    fn test_make_vortex_predicate_multiple() {
618        let expr_convertor = DefaultExpressionConvertor::default();
619        let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
620        let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
621        let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
622        assert!(result.is_some());
623        // Result should be an AND expression combining the two columns
624    }
625
626    #[rstest]
627    #[case::eq(DFOperator::Eq, Operator::Eq)]
628    #[case::not_eq(DFOperator::NotEq, Operator::NotEq)]
629    #[case::lt(DFOperator::Lt, Operator::Lt)]
630    #[case::lte(DFOperator::LtEq, Operator::Lte)]
631    #[case::gt(DFOperator::Gt, Operator::Gt)]
632    #[case::gte(DFOperator::GtEq, Operator::Gte)]
633    #[case::and(DFOperator::And, Operator::And)]
634    #[case::or(DFOperator::Or, Operator::Or)]
635    #[case::plus(DFOperator::Plus, Operator::Add)]
636    #[case::plus(DFOperator::Minus, Operator::Sub)]
637    #[case::plus(DFOperator::Multiply, Operator::Mul)]
638    #[case::plus(DFOperator::Divide, Operator::Div)]
639    fn test_operator_conversion_supported(
640        #[case] df_op: DFOperator,
641        #[case] expected_vortex_op: Operator,
642    ) {
643        let result = try_operator_from_df(&df_op).unwrap();
644        assert_eq!(result, expected_vortex_op);
645    }
646
647    #[rstest]
648    #[case::modulo(DFOperator::Modulo)]
649    #[case::bitwise_and(DFOperator::BitwiseAnd)]
650    #[case::regex_match(DFOperator::RegexMatch)]
651    #[case::like_match(DFOperator::LikeMatch)]
652    fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
653        let result = try_operator_from_df(&df_op);
654        assert!(result.is_err());
655        assert!(
656            result
657                .unwrap_err()
658                .to_string()
659                .contains("Unsupported datafusion operator")
660        );
661    }
662
663    #[test]
664    fn test_expr_from_df_column() {
665        let col_expr = df_expr::Column::new("test_column", 0);
666        let result = DefaultExpressionConvertor::default()
667            .convert(&col_expr)
668            .unwrap();
669
670        assert_snapshot!(result.display_tree().to_string(), @r"
671        vortex.get_item(test_column)
672        └── input: vortex.root()
673        ");
674    }
675
676    #[test]
677    fn test_expr_from_df_literal() {
678        let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
679        let result = DefaultExpressionConvertor::default()
680            .convert(&literal_expr)
681            .unwrap();
682
683        assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
684    }
685
686    #[test]
687    fn test_expr_from_df_binary() {
688        let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
689        let right =
690            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
691        let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);
692
693        let result = DefaultExpressionConvertor::default()
694            .convert(&binary_expr)
695            .unwrap();
696
697        assert_snapshot!(result.display_tree().to_string(), @r"
698        vortex.binary(=)
699        ├── lhs: vortex.get_item(left)
700        │   └── input: vortex.root()
701        └── rhs: vortex.literal(42i32)
702        ");
703    }
704
705    #[rstest]
706    #[case::like_normal(false, false)]
707    #[case::like_negated(true, false)]
708    #[case::like_case_insensitive(false, true)]
709    #[case::like_negated_case_insensitive(true, true)]
710    fn test_expr_from_df_like(#[case] negated: bool, #[case] case_insensitive: bool) {
711        let expr = Arc::new(df_expr::Column::new("text_col", 0)) as Arc<dyn PhysicalExpr>;
712        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
713            "test%".to_string(),
714        )))) as Arc<dyn PhysicalExpr>;
715        let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);
716
717        let result = DefaultExpressionConvertor::default()
718            .convert(&like_expr)
719            .unwrap();
720        let like_opts = result.as_::<Like>();
721        assert_eq!(
722            like_opts,
723            &LikeOptions {
724                negated,
725                case_insensitive
726            }
727        );
728    }
729
730    #[rstest]
731    // Supported types
732    #[case::null(DataType::Null, true)]
733    #[case::boolean(DataType::Boolean, true)]
734    #[case::int8(DataType::Int8, true)]
735    #[case::int16(DataType::Int16, true)]
736    #[case::int32(DataType::Int32, true)]
737    #[case::int64(DataType::Int64, true)]
738    #[case::uint8(DataType::UInt8, true)]
739    #[case::uint16(DataType::UInt16, true)]
740    #[case::uint32(DataType::UInt32, true)]
741    #[case::uint64(DataType::UInt64, true)]
742    #[case::float32(DataType::Float32, true)]
743    #[case::float64(DataType::Float64, true)]
744    #[case::utf8(DataType::Utf8, true)]
745    #[case::utf8_view(DataType::Utf8View, true)]
746    #[case::binary(DataType::Binary, true)]
747    #[case::binary_view(DataType::BinaryView, true)]
748    #[case::date32(DataType::Date32, true)]
749    #[case::date64(DataType::Date64, true)]
750    #[case::timestamp_ms(DataType::Timestamp(ArrowTimeUnit::Millisecond, None), true)]
751    #[case::timestamp_us(
752        DataType::Timestamp(ArrowTimeUnit::Microsecond, Some(Arc::from("UTC"))),
753        true
754    )]
755    #[case::time32_s(DataType::Time32(ArrowTimeUnit::Second), true)]
756    #[case::time64_ns(DataType::Time64(ArrowTimeUnit::Nanosecond), true)]
757    // Unsupported types
758    #[case::list(
759        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
760        false
761    )]
762    #[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()
763    ), false)]
764    // Dictionary types - should be supported if value type is supported
765    #[case::dict_utf8(
766        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
767        true
768    )]
769    #[case::dict_int32(
770        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
771        true
772    )]
773    #[case::dict_unsupported(
774        DataType::Dictionary(
775            Box::new(DataType::UInt32),
776            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
777        ),
778        false
779    )]
780    fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
781        assert_eq!(supported_data_types(&data_type), expected);
782    }
783
784    #[rstest]
785    fn test_can_be_pushed_down_column_supported(test_schema: Schema) {
786        let col_expr = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
787
788        assert!(can_be_pushed_down_impl(&col_expr, &test_schema));
789    }
790
791    #[rstest]
792    fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
793        let col_expr =
794            Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
795
796        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
797    }
798
799    #[rstest]
800    fn test_can_be_pushed_down_column_not_found(test_schema: Schema) {
801        let col_expr = Arc::new(df_expr::Column::new("nonexistent", 99)) as Arc<dyn PhysicalExpr>;
802
803        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
804    }
805
806    #[rstest]
807    fn test_can_be_pushed_down_literal_supported(test_schema: Schema) {
808        let lit_expr =
809            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
810
811        assert!(can_be_pushed_down_impl(&lit_expr, &test_schema));
812    }
813
814    #[rstest]
815    fn test_can_be_pushed_down_literal_unsupported(test_schema: Schema) {
816        // Use a simpler unsupported type - Duration is not supported
817        let unsupported_literal = ScalarValue::DurationSecond(Some(42));
818        let lit_expr =
819            Arc::new(df_expr::Literal::new(unsupported_literal)) as Arc<dyn PhysicalExpr>;
820
821        assert!(!can_be_pushed_down_impl(&lit_expr, &test_schema));
822    }
823
824    #[rstest]
825    fn test_can_be_pushed_down_binary_supported(test_schema: Schema) {
826        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
827        let right =
828            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
829        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
830            as Arc<dyn PhysicalExpr>;
831
832        assert!(can_be_pushed_down_impl(&binary_expr, &test_schema));
833    }
834
835    #[rstest]
836    fn test_can_be_pushed_down_binary_unsupported_operator(test_schema: Schema) {
837        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
838        let right =
839            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
840        let binary_expr = Arc::new(df_expr::BinaryExpr::new(
841            left,
842            DFOperator::AtQuestion,
843            right,
844        )) as Arc<dyn PhysicalExpr>;
845
846        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
847    }
848
849    #[rstest]
850    fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
851        let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
852        let right =
853            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
854        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
855            as Arc<dyn PhysicalExpr>;
856
857        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
858    }
859
860    #[rstest]
861    fn test_can_be_pushed_down_like_supported(test_schema: Schema) {
862        let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
863        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
864            "test%".to_string(),
865        )))) as Arc<dyn PhysicalExpr>;
866        let like_expr =
867            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
868
869        assert!(can_be_pushed_down_impl(&like_expr, &test_schema));
870    }
871
872    #[rstest]
873    fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
874        let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
875        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
876            "test%".to_string(),
877        )))) as Arc<dyn PhysicalExpr>;
878        let like_expr =
879            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
880
881        assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
882    }
883
884    // https://github.com/vortex-data/vortex/issues/6211
885    #[tokio::test]
886    async fn test_cast_int_to_string() -> anyhow::Result<()> {
887        let ctx = TestSessionContext::default();
888
889        ctx.session
890            .sql(r#"copy (select 1 as id) to 'example.vortex'"#)
891            .await?
892            .show()
893            .await?;
894
895        ctx.session
896            .sql(r#"select cast(id as string) as sid from 'example.vortex' where id > 0"#)
897            .await?
898            .show()
899            .await?;
900
901        ctx.session
902            .sql(r#"select id from 'example.vortex' where cast (id as string) == '1'"#)
903            .await?
904            .show()
905            .await?;
906
907        // This fails as it pushes string cast to the scan
908        ctx.session
909            .sql(r#"select cast(id as string) from 'example.vortex'"#)
910            .await?
911            .collect()
912            .await?;
913
914        Ok(())
915    }
916
917    /// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion
918    /// matches the result of applying the converted Vortex expression.
919    #[test]
920    fn test_case_when_datafusion_vortex_equivalence() {
921        use datafusion::arrow::array::Int32Array;
922        use datafusion::arrow::array::RecordBatch;
923        use datafusion_physical_expr::expressions::CaseExpr;
924        use vortex::VortexSessionDefault;
925        use vortex::array::ArrayRef;
926        use vortex::array::Canonical;
927        use vortex::array::VortexSessionExecute as _;
928        use vortex::array::arrow::FromArrowArray;
929        use vortex::session::VortexSession;
930
931        // Create test data
932        let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
933        let schema = Arc::new(Schema::new(vec![Field::new(
934            "value",
935            DataType::Int32,
936            false,
937        )]));
938        let batch = RecordBatch::try_new(schema, vec![values]).unwrap();
939
940        // Build a DataFusion CASE expression:
941        // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
942        let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
943        let lit_10 =
944            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
945        let lit_5 =
946            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
947        let lit_100 =
948            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
949        let lit_50 =
950            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
951        let lit_0 =
952            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;
953
954        // WHEN value > 10 THEN 100
955        let when1 = Arc::new(df_expr::BinaryExpr::new(
956            Arc::clone(&col_value),
957            DFOperator::Gt,
958            lit_10,
959        )) as Arc<dyn PhysicalExpr>;
960        // WHEN value > 5 THEN 50
961        let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
962            as Arc<dyn PhysicalExpr>;
963
964        let case_expr =
965            CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();
966
967        // Apply DataFusion expression
968        let df_result = case_expr.evaluate(&batch).unwrap();
969        let df_array = df_result.into_array(batch.num_rows()).unwrap();
970
971        // Convert to Vortex expression
972        let expr_convertor = DefaultExpressionConvertor::default();
973        let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();
974
975        // Convert batch to Vortex array
976        let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();
977
978        // Apply Vortex expression
979        let session = VortexSession::default();
980        let mut ctx = session.create_execution_ctx();
981        let vortex_result = vortex_array
982            .apply(&vortex_expr)
983            .unwrap()
984            .execute::<Canonical>(&mut ctx)
985            .unwrap();
986
987        // Convert back to Arrow for comparison
988        let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();
989
990        // Convert DataFusion result to Vec for comparison
991        let df_as_arrow: Vec<i32> = df_array
992            .as_any()
993            .downcast_ref::<Int32Array>()
994            .unwrap()
995            .values()
996            .to_vec();
997
998        // Compare results
999        // Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20]
1000        // value=1: not > 10, not > 5 -> ELSE 0
1001        // value=5: not > 10, not > 5 -> ELSE 0
1002        // value=10: not > 10, > 5 -> 50
1003        // value=15: > 10 -> 100
1004        // value=20: > 10 -> 100
1005        assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
1006        assert_eq!(vortex_as_arrow, df_as_arrow);
1007    }
1008}