Skip to main content

vortex_datafusion/convert/
exprs.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_schema::DataType;
7use arrow_schema::Schema;
8use datafusion_common::Result as DFResult;
9use datafusion_common::exec_datafusion_err;
10use datafusion_common::tree_node::TreeNode;
11use datafusion_common::tree_node::TreeNodeRecursion;
12use datafusion_expr::Operator as DFOperator;
13use datafusion_functions::core::getfield::GetFieldFunc;
14use datafusion_physical_expr::PhysicalExpr;
15use datafusion_physical_expr::ScalarFunctionExpr;
16use datafusion_physical_expr::projection::ProjectionExpr;
17use datafusion_physical_expr::projection::ProjectionExprs;
18use datafusion_physical_expr::utils::collect_columns;
19use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
20use datafusion_physical_plan::expressions as df_expr;
21use itertools::Itertools;
22use vortex::dtype::DType;
23use vortex::dtype::Nullability;
24use vortex::dtype::arrow::FromArrowType;
25use vortex::expr::Expression;
26use vortex::expr::and_collect;
27use vortex::expr::cast;
28use vortex::expr::get_item;
29use vortex::expr::is_not_null;
30use vortex::expr::is_null;
31use vortex::expr::list_contains;
32use vortex::expr::lit;
33use vortex::expr::nested_case_when;
34use vortex::expr::not;
35use vortex::expr::pack;
36use vortex::expr::root;
37use vortex::scalar::Scalar;
38use vortex::scalar_fn::ScalarFnVTableExt;
39use vortex::scalar_fn::fns::binary::Binary;
40use vortex::scalar_fn::fns::like::Like;
41use vortex::scalar_fn::fns::like::LikeOptions;
42use vortex::scalar_fn::fns::operators::Operator;
43
44use crate::convert::FromDataFusion;
45
/// Result of splitting a projection into Vortex expressions and leftover DataFusion projections.
pub struct ProcessedProjection {
    /// Vortex expression evaluated by the scan. The implementations in this module build it
    /// with `pack`, one named field per column or pushed-down expression.
    pub scan_projection: Expression,
    /// Projection that DataFusion still has to apply on top of the scan output.
    pub leftover_projection: ProjectionExprs,
}
51
52/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
53pub(crate) fn make_vortex_predicate(
54    expr_convertor: &dyn ExpressionConvertor,
55    predicate: &[Arc<dyn PhysicalExpr>],
56) -> DFResult<Option<Expression>> {
57    let exprs = predicate
58        .iter()
59        .map(|e| expr_convertor.convert(e.as_ref()))
60        .collect::<DFResult<Vec<_>>>()?;
61
62    Ok(and_collect(exprs))
63}
64
65/// Trait for converting DataFusion expressions to Vortex ones.
66pub trait ExpressionConvertor: Send + Sync {
67    /// Can an expression be pushed down given a specific schema
68    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
69
70    /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
71    fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;
72
73    /// Split a projection into Vortex expressions that can be pushed down and leftover
74    /// DataFusion projections that need to be evaluated after the scan.
75    fn split_projection(
76        &self,
77        source_projection: ProjectionExprs,
78        input_schema: &Schema,
79        output_schema: &Schema,
80    ) -> DFResult<ProcessedProjection>;
81
82    /// Create a projection that reads only the required columns without pushing down
83    /// any expressions. All projection logic is applied after the scan.
84    fn no_pushdown_projection(
85        &self,
86        source_projection: ProjectionExprs,
87        input_schema: &Schema,
88    ) -> DFResult<ProcessedProjection> {
89        // Get all unique column indices referenced by the projection
90        let column_indices = source_projection.column_indices();
91
92        // Create scan projection that reads the required columns
93        let scan_columns: Vec<(String, Expression)> = column_indices
94            .into_iter()
95            .map(|idx| {
96                let field = input_schema.field(idx);
97                let name = field.name().clone();
98                (name.clone(), get_item(name, root()))
99            })
100            .collect();
101
102        Ok(ProcessedProjection {
103            scan_projection: pack(scan_columns, Nullability::NonNullable),
104            leftover_projection: source_projection,
105        })
106    }
107}
108
/// The default [`ExpressionConvertor`].
///
/// Carries no state; its behavior comes entirely from the methods implemented for it and the
/// free helper functions in this module.
#[derive(Default)]
pub struct DefaultExpressionConvertor {}
112
113impl DefaultExpressionConvertor {
114    /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
115    fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
116        if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
117        {
118            // DataFusion's GetFieldFunc flattens nested field access into a single call
119            // with multiple field name arguments. For example, `outer.inner.leaf` becomes
120            // get_field(Column("outer"), "inner", "leaf"). We build a chain of get_item
121            // calls for each field name in the path.
122            let (source_expr, field_names) = get_field_fn
123                .args()
124                .split_first()
125                .ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;
126
127            let mut result = self.convert(source_expr.as_ref())?;
128            for expr in field_names {
129                let field_name = expr
130                    .downcast_ref::<df_expr::Literal>()
131                    .ok_or_else(|| exec_datafusion_err!("get_field field name must be a literal"))?
132                    .value()
133                    .try_as_str()
134                    .flatten()
135                    .ok_or_else(|| {
136                        exec_datafusion_err!("get_field field name must be a UTF-8 string")
137                    })?;
138                result = get_item(field_name.to_string(), result);
139            }
140            return Ok(result);
141        }
142
143        Err(exec_datafusion_err!(
144            "Unsupported ScalarFunctionExpr: {}",
145            scalar_fn.name()
146        ))
147    }
148
149    /// Attempts to convert a DataFusion CaseExpr to a Vortex expression.
150    fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
151        // DataFusion CaseExpr has:
152        // - expr(): Optional base expression (for "CASE expr WHEN ..." form)
153        // - when_then_expr(): Vec of (when, then) pairs
154        // - else_expr(): Optional else expression
155
156        // We don't support the "CASE expr WHEN value1 THEN result1" form yet
157        if case_expr.expr().is_some() {
158            return Err(exec_datafusion_err!(
159                "CASE expr WHEN form is not yet supported, only searched CASE is supported"
160            ));
161        }
162
163        let when_then_pairs = case_expr.when_then_expr();
164        if when_then_pairs.is_empty() {
165            return Err(exec_datafusion_err!(
166                "CASE expression must have at least one WHEN clause"
167            ));
168        }
169
170        // Convert all when/then pairs to (condition, value) tuples
171        let mut pairs = Vec::with_capacity(when_then_pairs.len());
172        for (when_expr, then_expr) in when_then_pairs {
173            let condition = self.convert(when_expr.as_ref())?;
174            let value = self.convert(then_expr.as_ref())?;
175            pairs.push((condition, value));
176        }
177
178        // Convert optional else expression
179        let else_value = case_expr
180            .else_expr()
181            .map(|e| self.convert(e.as_ref()))
182            .transpose()?;
183
184        // Build a single n-ary CASE WHEN expression from DataFusion WHEN/THEN pairs
185        Ok(nested_case_when(pairs, else_value))
186    }
187}
188
189impl ExpressionConvertor for DefaultExpressionConvertor {
190    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
191        can_be_pushed_down_impl(expr, schema)
192    }
193
194    fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
195        // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
196        //  for that node, up to any `and` or `or` node.
197        if let Some(binary_expr) = df.downcast_ref::<df_expr::BinaryExpr>() {
198            let left = self.convert(binary_expr.left().as_ref())?;
199            let right = self.convert(binary_expr.right().as_ref())?;
200            let operator = try_operator_from_df(binary_expr.op())?;
201
202            return Ok(Binary.new_expr(operator, [left, right]));
203        }
204
205        if let Some(col_expr) = df.downcast_ref::<df_expr::Column>() {
206            return Ok(get_item(col_expr.name().to_owned(), root()));
207        }
208
209        if let Some(like) = df.downcast_ref::<df_expr::LikeExpr>() {
210            let child = self.convert(like.expr().as_ref())?;
211            let pattern = self.convert(like.pattern().as_ref())?;
212            return Ok(Like.new_expr(
213                LikeOptions {
214                    negated: like.negated(),
215                    case_insensitive: like.case_insensitive(),
216                },
217                [child, pattern],
218            ));
219        }
220
221        if let Some(literal) = df.downcast_ref::<df_expr::Literal>() {
222            let value = Scalar::from_df(literal.value());
223            return Ok(lit(value));
224        }
225
226        if let Some(cast_expr) = df.downcast_ref::<df_expr::CastExpr>() {
227            let cast_dtype = DType::from_arrow(cast_expr.target_field().as_ref());
228            let child = self.convert(cast_expr.expr().as_ref())?;
229            return Ok(cast(child, cast_dtype));
230        }
231
232        if let Some(is_null_expr) = df.downcast_ref::<df_expr::IsNullExpr>() {
233            let arg = self.convert(is_null_expr.arg().as_ref())?;
234            return Ok(is_null(arg));
235        }
236
237        if let Some(is_not_null_expr) = df.downcast_ref::<df_expr::IsNotNullExpr>() {
238            let arg = self.convert(is_not_null_expr.arg().as_ref())?;
239            return Ok(is_not_null(arg));
240        }
241
242        if let Some(in_list) = df.downcast_ref::<df_expr::InListExpr>() {
243            let value = self.convert(in_list.expr().as_ref())?;
244            let list_elements: Vec<_> = in_list
245                .list()
246                .iter()
247                .map(|e| {
248                    if let Some(lit) = e.downcast_ref::<df_expr::Literal>() {
249                        Ok(Scalar::from_df(lit.value()))
250                    } else {
251                        Err(exec_datafusion_err!("Failed to cast sub-expression"))
252                    }
253                })
254                .try_collect()?;
255
256            let list = Scalar::list(
257                list_elements[0].dtype().clone(),
258                list_elements,
259                Nullability::Nullable,
260            );
261            let expr = list_contains(lit(list), value);
262
263            return Ok(if in_list.negated() { not(expr) } else { expr });
264        }
265
266        if let Some(scalar_fn) = df.downcast_ref::<ScalarFunctionExpr>() {
267            return self.try_convert_scalar_function(scalar_fn);
268        }
269
270        if let Some(case_expr) = df.downcast_ref::<df_expr::CaseExpr>() {
271            return self.try_convert_case_expr(case_expr);
272        }
273
274        Err(exec_datafusion_err!(
275            "Couldn't convert DataFusion physical {df} expression to a vortex expression"
276        ))
277    }
278
279    fn split_projection(
280        &self,
281        source_projection: ProjectionExprs,
282        input_schema: &Schema,
283        output_schema: &Schema,
284    ) -> DFResult<ProcessedProjection> {
285        let mut scan_projection = vec![];
286        let mut leftover_projection: Vec<ProjectionExpr> = vec![];
287
288        for projection_expr in source_projection.iter() {
289            let r = projection_expr.expr.apply(|node| {
290                // We only pull column children of scalar functions that we can't push into the scan.
291                if let Some(scalar_fn_expr) = node.downcast_ref::<ScalarFunctionExpr>()
292                    && !can_scalar_fn_be_pushed_down(scalar_fn_expr)
293                {
294                    scan_projection.extend(
295                        collect_columns(node)
296                            .into_iter()
297                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
298                    );
299
300                    leftover_projection.push(projection_expr.clone());
301                    return Ok(TreeNodeRecursion::Stop);
302                }
303
304                // DataFusion assumes different decimal types can be coerced.
305                // Vortex expects a perfect match so we don't push it down.
306                if let Some(binary_expr) = node.downcast_ref::<df_expr::BinaryExpr>()
307                    && binary_expr.op().is_numerical_operators()
308                    && (is_decimal(&binary_expr.left().data_type(input_schema)?)
309                        && is_decimal(&binary_expr.right().data_type(input_schema)?))
310                {
311                    scan_projection.extend(
312                        collect_columns(node)
313                            .into_iter()
314                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
315                    );
316
317                    leftover_projection.push(projection_expr.clone());
318                    return Ok(TreeNodeRecursion::Stop);
319                }
320
321                Ok(TreeNodeRecursion::Continue)
322            })?;
323
324            // if we didn't stop early
325            if matches!(r, TreeNodeRecursion::Continue) {
326                scan_projection.push((
327                    projection_expr.alias.clone(),
328                    self.convert(projection_expr.expr.as_ref())?,
329                ));
330                leftover_projection.push(ProjectionExpr {
331                    expr: Arc::new(df_expr::Column::new_with_schema(
332                        projection_expr.alias.as_str(),
333                        output_schema,
334                    )?),
335                    alias: projection_expr.alias.clone(),
336                });
337            }
338        }
339
340        Ok(ProcessedProjection {
341            scan_projection: pack(scan_projection, Nullability::NonNullable),
342            leftover_projection: leftover_projection.into(),
343        })
344    }
345}
346
347fn try_operator_from_df(value: &DFOperator) -> DFResult<Operator> {
348    match value {
349        DFOperator::Eq => Ok(Operator::Eq),
350        DFOperator::NotEq => Ok(Operator::NotEq),
351        DFOperator::Lt => Ok(Operator::Lt),
352        DFOperator::LtEq => Ok(Operator::Lte),
353        DFOperator::Gt => Ok(Operator::Gt),
354        DFOperator::GtEq => Ok(Operator::Gte),
355        DFOperator::And => Ok(Operator::And),
356        DFOperator::Or => Ok(Operator::Or),
357        DFOperator::Plus => Ok(Operator::Add),
358        DFOperator::Minus => Ok(Operator::Sub),
359        DFOperator::Multiply => Ok(Operator::Mul),
360        DFOperator::Divide => Ok(Operator::Div),
361        DFOperator::IsDistinctFrom
362        | DFOperator::IsNotDistinctFrom
363        | DFOperator::RegexMatch
364        | DFOperator::RegexIMatch
365        | DFOperator::RegexNotMatch
366        | DFOperator::RegexNotIMatch
367        | DFOperator::LikeMatch
368        | DFOperator::ILikeMatch
369        | DFOperator::NotLikeMatch
370        | DFOperator::NotILikeMatch
371        | DFOperator::BitwiseAnd
372        | DFOperator::BitwiseOr
373        | DFOperator::BitwiseXor
374        | DFOperator::BitwiseShiftRight
375        | DFOperator::BitwiseShiftLeft
376        | DFOperator::StringConcat
377        | DFOperator::AtArrow
378        | DFOperator::ArrowAt
379        | DFOperator::Modulo
380        | DFOperator::Arrow
381        | DFOperator::LongArrow
382        | DFOperator::HashArrow
383        | DFOperator::HashLongArrow
384        | DFOperator::AtAt
385        | DFOperator::IntegerDivide
386        | DFOperator::HashMinus
387        | DFOperator::AtQuestion
388        | DFOperator::Question
389        | DFOperator::QuestionAnd
390        | DFOperator::QuestionPipe
391        | DFOperator::Colon => {
392            tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
393            Err(exec_datafusion_err!(
394                "Unsupported datafusion operator {value}"
395            ))
396        }
397    }
398}
399
400fn can_be_pushed_down_impl(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
401    // We currently do not support pushdown of dynamic expressions in DF.
402    // See issue: https://github.com/vortex-data/vortex/issues/4034
403    if is_dynamic_physical_expr(expr) {
404        return false;
405    }
406
407    if let Some(binary) = expr.downcast_ref::<df_expr::BinaryExpr>() {
408        can_binary_be_pushed_down(binary, schema)
409    } else if let Some(col) = expr.downcast_ref::<df_expr::Column>() {
410        schema
411            .field_with_name(col.name())
412            .ok()
413            .is_some_and(|field| supported_data_types(field.data_type()))
414    } else if let Some(like) = expr.downcast_ref::<df_expr::LikeExpr>() {
415        can_be_pushed_down_impl(like.expr(), schema)
416            && can_be_pushed_down_impl(like.pattern(), schema)
417    } else if let Some(lit) = expr.downcast_ref::<df_expr::Literal>() {
418        supported_data_types(&lit.value().data_type())
419    } else if let Some(cast_expr) = expr.downcast_ref::<df_expr::CastExpr>() {
420        // CastExpr child must be an expression type that convert() can handle
421        is_convertible_expr(cast_expr.expr())
422    } else if let Some(is_null) = expr.downcast_ref::<df_expr::IsNullExpr>() {
423        can_be_pushed_down_impl(is_null.arg(), schema)
424    } else if let Some(is_not_null) = expr.downcast_ref::<df_expr::IsNotNullExpr>() {
425        can_be_pushed_down_impl(is_not_null.arg(), schema)
426    } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
427        can_be_pushed_down_impl(in_list.expr(), schema)
428            && in_list
429                .list()
430                .iter()
431                .all(|e| can_be_pushed_down_impl(e, schema))
432    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
433        can_scalar_fn_be_pushed_down(scalar_fn)
434    } else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
435        can_case_be_pushed_down(case_expr, schema)
436    } else {
437        tracing::debug!(%expr, "DataFusion expression can't be pushed down");
438        false
439    }
440}
441
442/// Checks if an expression type is one that convert() can handle.
443/// This is less restrictive than can_be_pushed_down since it only checks
444/// expression types, not data type support.
445fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
446    // Expression types that convert() handles
447    expr.downcast_ref::<df_expr::BinaryExpr>().is_some()
448        || expr.downcast_ref::<df_expr::Column>().is_some()
449        || expr.downcast_ref::<df_expr::LikeExpr>().is_some()
450        || expr.downcast_ref::<df_expr::Literal>().is_some()
451        || expr
452            .downcast_ref::<df_expr::CastExpr>()
453            .is_some_and(|e| is_convertible_expr(e.expr()))
454        || expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
455        || expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
456        || expr.downcast_ref::<df_expr::InListExpr>().is_some()
457        || expr
458            .downcast_ref::<ScalarFunctionExpr>()
459            .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
460}
461
462fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
463    let is_op_supported = try_operator_from_df(binary.op()).is_ok();
464    is_op_supported
465        && can_be_pushed_down_impl(binary.left(), schema)
466        && can_be_pushed_down_impl(binary.right(), schema)
467}
468
469fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool {
470    // We only support the "searched CASE" form (CASE WHEN cond THEN result ...)
471    // not the "simple CASE" form (CASE expr WHEN value THEN result ...)
472    if case_expr.expr().is_some() {
473        return false;
474    }
475
476    // Check all when/then pairs
477    for (when_expr, then_expr) in case_expr.when_then_expr() {
478        if !can_be_pushed_down_impl(when_expr, schema)
479            || !can_be_pushed_down_impl(then_expr, schema)
480        {
481            return false;
482        }
483    }
484
485    // Check the optional else clause
486    if let Some(else_expr) = case_expr.else_expr()
487        && !can_be_pushed_down_impl(else_expr, schema)
488    {
489        return false;
490    }
491
492    true
493}
494
495fn supported_data_types(dt: &DataType) -> bool {
496    use DataType::*;
497
498    // For dictionary types, check if the value type is supported.
499    if let Dictionary(_, value_type) = dt {
500        return supported_data_types(value_type.as_ref());
501    }
502
503    let is_supported = dt.is_null()
504        || dt.is_numeric()
505        || matches!(
506            dt,
507            Boolean
508                | Utf8
509                | LargeUtf8
510                | Utf8View
511                | Binary
512                | LargeBinary
513                | BinaryView
514                | Date32
515                | Date64
516                | Timestamp(_, _)
517                | Time32(_)
518                | Time64(_)
519        );
520
521    if !is_supported {
522        tracing::debug!("DataFusion data type {dt:?} is not supported");
523    }
524
525    is_supported
526}
527
528/// Checks if a scalar function can be pushed down.
529/// Currently only GetFieldFunc is supported.
530fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
531    ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
532}
533
534// TODO(adam): Replace with `DataType::is_decimal` once its released.
535fn is_decimal(dt: &DataType) -> bool {
536    matches!(
537        dt,
538        DataType::Decimal32(_, _)
539            | DataType::Decimal64(_, _)
540            | DataType::Decimal128(_, _)
541            | DataType::Decimal256(_, _)
542    )
543}
544
545#[cfg(test)]
546mod tests {
547    use std::sync::Arc;
548
549    use arrow_schema::DataType;
550    use arrow_schema::Field;
551    use arrow_schema::Schema;
552    use arrow_schema::TimeUnit as ArrowTimeUnit;
553    use datafusion::arrow::array::AsArray;
554    use datafusion::arrow::datatypes::Int32Type;
555    use datafusion_common::ScalarValue;
556    use datafusion_expr::Operator as DFOperator;
557    use datafusion_physical_expr::PhysicalExpr;
558    use datafusion_physical_plan::expressions as df_expr;
559    use insta::assert_snapshot;
560    use rstest::rstest;
561
562    use super::*;
563    use crate::common_tests::TestSessionContext;
564
565    #[rstest::fixture]
566    fn test_schema() -> Schema {
567        Schema::new(vec![
568            Field::new("id", DataType::Int32, false),
569            Field::new("name", DataType::Utf8, true),
570            Field::new("score", DataType::Float64, true),
571            Field::new("active", DataType::Boolean, false),
572            Field::new(
573                "created_at",
574                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
575                true,
576            ),
577            Field::new(
578                "unsupported_list",
579                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
580                true,
581            ),
582        ])
583    }
584
585    #[test]
586    fn test_make_vortex_predicate_empty() {
587        let expr_convertor = DefaultExpressionConvertor::default();
588        let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
589        assert!(result.is_none());
590    }
591
592    #[test]
593    fn test_make_vortex_predicate_single() {
594        let expr_convertor = DefaultExpressionConvertor::default();
595        let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
596        let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
597        assert!(result.is_some());
598    }
599
600    #[test]
601    fn test_make_vortex_predicate_multiple() {
602        let expr_convertor = DefaultExpressionConvertor::default();
603        let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
604        let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
605        let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
606        assert!(result.is_some());
607        // Result should be an AND expression combining the two columns
608    }
609
610    #[rstest]
611    #[case::eq(DFOperator::Eq, Operator::Eq)]
612    #[case::not_eq(DFOperator::NotEq, Operator::NotEq)]
613    #[case::lt(DFOperator::Lt, Operator::Lt)]
614    #[case::lte(DFOperator::LtEq, Operator::Lte)]
615    #[case::gt(DFOperator::Gt, Operator::Gt)]
616    #[case::gte(DFOperator::GtEq, Operator::Gte)]
617    #[case::and(DFOperator::And, Operator::And)]
618    #[case::or(DFOperator::Or, Operator::Or)]
619    #[case::plus(DFOperator::Plus, Operator::Add)]
620    #[case::plus(DFOperator::Minus, Operator::Sub)]
621    #[case::plus(DFOperator::Multiply, Operator::Mul)]
622    #[case::plus(DFOperator::Divide, Operator::Div)]
623    fn test_operator_conversion_supported(
624        #[case] df_op: DFOperator,
625        #[case] expected_vortex_op: Operator,
626    ) {
627        let result = try_operator_from_df(&df_op).unwrap();
628        assert_eq!(result, expected_vortex_op);
629    }
630
631    #[rstest]
632    #[case::modulo(DFOperator::Modulo)]
633    #[case::bitwise_and(DFOperator::BitwiseAnd)]
634    #[case::regex_match(DFOperator::RegexMatch)]
635    #[case::like_match(DFOperator::LikeMatch)]
636    fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
637        let result = try_operator_from_df(&df_op);
638        assert!(result.is_err());
639        assert!(
640            result
641                .unwrap_err()
642                .to_string()
643                .contains("Unsupported datafusion operator")
644        );
645    }
646
647    #[test]
648    fn test_expr_from_df_column() {
649        let col_expr = df_expr::Column::new("test_column", 0);
650        let result = DefaultExpressionConvertor::default()
651            .convert(&col_expr)
652            .unwrap();
653
654        assert_snapshot!(result.display_tree().to_string(), @r"
655        vortex.get_item(test_column)
656        └── input: vortex.root()
657        ");
658    }
659
660    #[test]
661    fn test_expr_from_df_literal() {
662        let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
663        let result = DefaultExpressionConvertor::default()
664            .convert(&literal_expr)
665            .unwrap();
666
667        assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
668    }
669
670    #[test]
671    fn test_expr_from_df_binary() {
672        let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
673        let right =
674            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
675        let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);
676
677        let result = DefaultExpressionConvertor::default()
678            .convert(&binary_expr)
679            .unwrap();
680
681        assert_snapshot!(result.display_tree().to_string(), @r"
682        vortex.binary(=)
683        ├── lhs: vortex.get_item(left)
684        │   └── input: vortex.root()
685        └── rhs: vortex.literal(42i32)
686        ");
687    }
688
689    #[rstest]
690    #[case::like_normal(false, false)]
691    #[case::like_negated(true, false)]
692    #[case::like_case_insensitive(false, true)]
693    #[case::like_negated_case_insensitive(true, true)]
694    fn test_expr_from_df_like(#[case] negated: bool, #[case] case_insensitive: bool) {
695        let expr = Arc::new(df_expr::Column::new("text_col", 0)) as Arc<dyn PhysicalExpr>;
696        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
697            "test%".to_string(),
698        )))) as Arc<dyn PhysicalExpr>;
699        let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);
700
701        let result = DefaultExpressionConvertor::default()
702            .convert(&like_expr)
703            .unwrap();
704        let like_opts = result.as_::<Like>();
705        assert_eq!(
706            like_opts,
707            &LikeOptions {
708                negated,
709                case_insensitive
710            }
711        );
712    }
713
714    #[rstest]
715    // Supported types
716    #[case::null(DataType::Null, true)]
717    #[case::boolean(DataType::Boolean, true)]
718    #[case::int8(DataType::Int8, true)]
719    #[case::int16(DataType::Int16, true)]
720    #[case::int32(DataType::Int32, true)]
721    #[case::int64(DataType::Int64, true)]
722    #[case::uint8(DataType::UInt8, true)]
723    #[case::uint16(DataType::UInt16, true)]
724    #[case::uint32(DataType::UInt32, true)]
725    #[case::uint64(DataType::UInt64, true)]
726    #[case::float32(DataType::Float32, true)]
727    #[case::float64(DataType::Float64, true)]
728    #[case::utf8(DataType::Utf8, true)]
729    #[case::utf8_view(DataType::Utf8View, true)]
730    #[case::binary(DataType::Binary, true)]
731    #[case::binary_view(DataType::BinaryView, true)]
732    #[case::date32(DataType::Date32, true)]
733    #[case::date64(DataType::Date64, true)]
734    #[case::timestamp_ms(DataType::Timestamp(ArrowTimeUnit::Millisecond, None), true)]
735    #[case::timestamp_us(
736        DataType::Timestamp(ArrowTimeUnit::Microsecond, Some(Arc::from("UTC"))),
737        true
738    )]
739    #[case::time32_s(DataType::Time32(ArrowTimeUnit::Second), true)]
740    #[case::time64_ns(DataType::Time64(ArrowTimeUnit::Nanosecond), true)]
741    // Unsupported types
742    #[case::list(
743        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
744        false
745    )]
746    #[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()
747    ), false)]
748    // Dictionary types - should be supported if value type is supported
749    #[case::dict_utf8(
750        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
751        true
752    )]
753    #[case::dict_int32(
754        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
755        true
756    )]
757    #[case::dict_unsupported(
758        DataType::Dictionary(
759            Box::new(DataType::UInt32),
760            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
761        ),
762        false
763    )]
764    fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
765        assert_eq!(supported_data_types(&data_type), expected);
766    }
767
768    #[rstest]
769    fn test_can_be_pushed_down_column_supported(test_schema: Schema) {
770        let col_expr = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
771
772        assert!(can_be_pushed_down_impl(&col_expr, &test_schema));
773    }
774
775    #[rstest]
776    fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
777        let col_expr =
778            Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
779
780        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
781    }
782
783    #[rstest]
784    fn test_can_be_pushed_down_column_not_found(test_schema: Schema) {
785        let col_expr = Arc::new(df_expr::Column::new("nonexistent", 99)) as Arc<dyn PhysicalExpr>;
786
787        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
788    }
789
790    #[rstest]
791    fn test_can_be_pushed_down_literal_supported(test_schema: Schema) {
792        let lit_expr =
793            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
794
795        assert!(can_be_pushed_down_impl(&lit_expr, &test_schema));
796    }
797
798    #[rstest]
799    fn test_can_be_pushed_down_literal_unsupported(test_schema: Schema) {
800        // Use a simpler unsupported type - Duration is not supported
801        let unsupported_literal = ScalarValue::DurationSecond(Some(42));
802        let lit_expr =
803            Arc::new(df_expr::Literal::new(unsupported_literal)) as Arc<dyn PhysicalExpr>;
804
805        assert!(!can_be_pushed_down_impl(&lit_expr, &test_schema));
806    }
807
808    #[rstest]
809    fn test_can_be_pushed_down_binary_supported(test_schema: Schema) {
810        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
811        let right =
812            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
813        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
814            as Arc<dyn PhysicalExpr>;
815
816        assert!(can_be_pushed_down_impl(&binary_expr, &test_schema));
817    }
818
819    #[rstest]
820    fn test_can_be_pushed_down_binary_unsupported_operator(test_schema: Schema) {
821        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
822        let right =
823            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
824        let binary_expr = Arc::new(df_expr::BinaryExpr::new(
825            left,
826            DFOperator::AtQuestion,
827            right,
828        )) as Arc<dyn PhysicalExpr>;
829
830        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
831    }
832
833    #[rstest]
834    fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
835        let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
836        let right =
837            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
838        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
839            as Arc<dyn PhysicalExpr>;
840
841        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
842    }
843
844    #[rstest]
845    fn test_can_be_pushed_down_like_supported(test_schema: Schema) {
846        let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
847        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
848            "test%".to_string(),
849        )))) as Arc<dyn PhysicalExpr>;
850        let like_expr =
851            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
852
853        assert!(can_be_pushed_down_impl(&like_expr, &test_schema));
854    }
855
856    #[rstest]
857    fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
858        let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
859        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
860            "test%".to_string(),
861        )))) as Arc<dyn PhysicalExpr>;
862        let like_expr =
863            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
864
865        assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
866    }
867
868    // https://github.com/vortex-data/vortex/issues/6211
869    #[tokio::test]
870    async fn test_cast_int_to_string() -> anyhow::Result<()> {
871        let ctx = TestSessionContext::default();
872
873        ctx.session
874            .sql(r#"copy (select 1 as id) to 'example.vortex'"#)
875            .await?
876            .show()
877            .await?;
878
879        ctx.session
880            .sql(r#"select cast(id as string) as sid from 'example.vortex' where id > 0"#)
881            .await?
882            .show()
883            .await?;
884
885        ctx.session
886            .sql(r#"select id from 'example.vortex' where cast (id as string) == '1'"#)
887            .await?
888            .show()
889            .await?;
890
891        // This fails as it pushes string cast to the scan
892        ctx.session
893            .sql(r#"select cast(id as string) from 'example.vortex'"#)
894            .await?
895            .collect()
896            .await?;
897
898        Ok(())
899    }
900
901    /// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion
902    /// matches the result of applying the converted Vortex expression.
903    #[test]
904    fn test_case_when_datafusion_vortex_equivalence() {
905        use datafusion::arrow::array::Int32Array;
906        use datafusion::arrow::array::RecordBatch;
907        use datafusion_physical_expr::expressions::CaseExpr;
908        use vortex::VortexSessionDefault;
909        use vortex::array::ArrayRef;
910        use vortex::array::Canonical;
911        use vortex::array::VortexSessionExecute as _;
912        use vortex::array::arrow::FromArrowArray;
913        use vortex::session::VortexSession;
914
915        // Create test data
916        let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
917        let schema = Arc::new(Schema::new(vec![Field::new(
918            "value",
919            DataType::Int32,
920            false,
921        )]));
922        let batch = RecordBatch::try_new(schema, vec![values]).unwrap();
923
924        // Build a DataFusion CASE expression:
925        // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
926        let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
927        let lit_10 =
928            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
929        let lit_5 =
930            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
931        let lit_100 =
932            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
933        let lit_50 =
934            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
935        let lit_0 =
936            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;
937
938        // WHEN value > 10 THEN 100
939        let when1 = Arc::new(df_expr::BinaryExpr::new(
940            Arc::clone(&col_value),
941            DFOperator::Gt,
942            lit_10,
943        )) as Arc<dyn PhysicalExpr>;
944        // WHEN value > 5 THEN 50
945        let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
946            as Arc<dyn PhysicalExpr>;
947
948        let case_expr =
949            CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();
950
951        // Apply DataFusion expression
952        let df_result = case_expr.evaluate(&batch).unwrap();
953        let df_array = df_result.into_array(batch.num_rows()).unwrap();
954
955        // Convert to Vortex expression
956        let expr_convertor = DefaultExpressionConvertor::default();
957        let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();
958
959        // Convert batch to Vortex array
960        let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();
961
962        // Apply Vortex expression
963        let session = VortexSession::default();
964        let mut ctx = session.create_execution_ctx();
965        let vortex_result = vortex_array
966            .apply(&vortex_expr)
967            .unwrap()
968            .execute::<Canonical>(&mut ctx)
969            .unwrap();
970
971        // Convert back to Arrow for comparison
972        let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();
973
974        // Convert DataFusion result to Vec for comparison
975        let df_as_arrow: Vec<i32> = df_array.as_primitive::<Int32Type>().values().to_vec();
976
977        // Compare results
978        // Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20]
979        // value=1: not > 10, not > 5 -> ELSE 0
980        // value=5: not > 10, not > 5 -> ELSE 0
981        // value=10: not > 10, > 5 -> 50
982        // value=15: > 10 -> 100
983        // value=20: > 10 -> 100
984        assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
985        assert_eq!(vortex_as_arrow, df_as_arrow);
986    }
987}