Skip to main content

vortex_datafusion/convert/
exprs.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_schema::DataType;
7use arrow_schema::Schema;
8use datafusion_common::Result as DFResult;
9use datafusion_common::exec_datafusion_err;
10use datafusion_common::tree_node::TreeNode;
11use datafusion_common::tree_node::TreeNodeRecursion;
12use datafusion_expr::Operator as DFOperator;
13use datafusion_functions::core::getfield::GetFieldFunc;
14use datafusion_physical_expr::PhysicalExpr;
15use datafusion_physical_expr::ScalarFunctionExpr;
16use datafusion_physical_expr::projection::ProjectionExpr;
17use datafusion_physical_expr::projection::ProjectionExprs;
18use datafusion_physical_expr::utils::collect_columns;
19use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
20use datafusion_physical_plan::expressions as df_expr;
21use itertools::Itertools;
22use vortex::dtype::DType;
23use vortex::dtype::Nullability;
24use vortex::dtype::arrow::FromArrowType;
25use vortex::expr::Expression;
26use vortex::expr::and_collect;
27use vortex::expr::cast;
28use vortex::expr::get_item;
29use vortex::expr::is_null;
30use vortex::expr::list_contains;
31use vortex::expr::lit;
32use vortex::expr::nested_case_when;
33use vortex::expr::not;
34use vortex::expr::pack;
35use vortex::expr::root;
36use vortex::scalar::Scalar;
37use vortex::scalar_fn::ScalarFnVTableExt;
38use vortex::scalar_fn::fns::binary::Binary;
39use vortex::scalar_fn::fns::like::Like;
40use vortex::scalar_fn::fns::like::LikeOptions;
41use vortex::scalar_fn::fns::operators::Operator;
42
43use crate::convert::FromDataFusion;
44
/// Result of splitting a projection into Vortex expressions and leftover DataFusion projections.
pub struct ProcessedProjection {
    /// The Vortex expression evaluated inside the scan (the pushed-down portion).
    pub scan_projection: Expression,
    /// DataFusion projections that must still be applied to the scan's output.
    pub leftover_projection: ProjectionExprs,
}
50
51/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
52pub(crate) fn make_vortex_predicate(
53    expr_convertor: &dyn ExpressionConvertor,
54    predicate: &[Arc<dyn PhysicalExpr>],
55) -> DFResult<Option<Expression>> {
56    let exprs = predicate
57        .iter()
58        .map(|e| expr_convertor.convert(e.as_ref()))
59        .collect::<DFResult<Vec<_>>>()?;
60
61    Ok(and_collect(exprs))
62}
63
64/// Trait for converting DataFusion expressions to Vortex ones.
65pub trait ExpressionConvertor: Send + Sync {
66    /// Can an expression be pushed down given a specific schema
67    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
68
69    /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
70    fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;
71
72    /// Split a projection into Vortex expressions that can be pushed down and leftover
73    /// DataFusion projections that need to be evaluated after the scan.
74    fn split_projection(
75        &self,
76        source_projection: ProjectionExprs,
77        input_schema: &Schema,
78        output_schema: &Schema,
79    ) -> DFResult<ProcessedProjection>;
80
81    /// Create a projection that reads only the required columns without pushing down
82    /// any expressions. All projection logic is applied after the scan.
83    fn no_pushdown_projection(
84        &self,
85        source_projection: ProjectionExprs,
86        input_schema: &Schema,
87    ) -> DFResult<ProcessedProjection> {
88        // Get all unique column indices referenced by the projection
89        let column_indices = source_projection.column_indices();
90
91        // Create scan projection that reads the required columns
92        let scan_columns: Vec<(String, Expression)> = column_indices
93            .into_iter()
94            .map(|idx| {
95                let field = input_schema.field(idx);
96                let name = field.name().clone();
97                (name.clone(), get_item(name, root()))
98            })
99            .collect();
100
101        Ok(ProcessedProjection {
102            scan_projection: pack(scan_columns, Nullability::NonNullable),
103            leftover_projection: source_projection,
104        })
105    }
106}
107
/// The default [`ExpressionConvertor`].
///
/// Converts the subset of DataFusion physical expressions that have a direct
/// Vortex equivalent; anything else is reported as an error by `convert`.
#[derive(Default)]
pub struct DefaultExpressionConvertor {}
111
112impl DefaultExpressionConvertor {
113    /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
114    fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
115        if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
116        {
117            // DataFusion's GetFieldFunc flattens nested field access into a single call
118            // with multiple field name arguments. For example, `outer.inner.leaf` becomes
119            // get_field(Column("outer"), "inner", "leaf"). We build a chain of get_item
120            // calls for each field name in the path.
121            let (source_expr, field_names) = get_field_fn
122                .args()
123                .split_first()
124                .ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;
125
126            let mut result = self.convert(source_expr.as_ref())?;
127            for expr in field_names {
128                let field_name = expr
129                    .as_any()
130                    .downcast_ref::<df_expr::Literal>()
131                    .ok_or_else(|| exec_datafusion_err!("get_field field name must be a literal"))?
132                    .value()
133                    .try_as_str()
134                    .flatten()
135                    .ok_or_else(|| {
136                        exec_datafusion_err!("get_field field name must be a UTF-8 string")
137                    })?;
138                result = get_item(field_name.to_string(), result);
139            }
140            return Ok(result);
141        }
142
143        Err(exec_datafusion_err!(
144            "Unsupported ScalarFunctionExpr: {}",
145            scalar_fn.name()
146        ))
147    }
148
149    /// Attempts to convert a DataFusion CaseExpr to a Vortex expression.
150    fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
151        // DataFusion CaseExpr has:
152        // - expr(): Optional base expression (for "CASE expr WHEN ..." form)
153        // - when_then_expr(): Vec of (when, then) pairs
154        // - else_expr(): Optional else expression
155
156        // We don't support the "CASE expr WHEN value1 THEN result1" form yet
157        if case_expr.expr().is_some() {
158            return Err(exec_datafusion_err!(
159                "CASE expr WHEN form is not yet supported, only searched CASE is supported"
160            ));
161        }
162
163        let when_then_pairs = case_expr.when_then_expr();
164        if when_then_pairs.is_empty() {
165            return Err(exec_datafusion_err!(
166                "CASE expression must have at least one WHEN clause"
167            ));
168        }
169
170        // Convert all when/then pairs to (condition, value) tuples
171        let mut pairs = Vec::with_capacity(when_then_pairs.len());
172        for (when_expr, then_expr) in when_then_pairs {
173            let condition = self.convert(when_expr.as_ref())?;
174            let value = self.convert(then_expr.as_ref())?;
175            pairs.push((condition, value));
176        }
177
178        // Convert optional else expression
179        let else_value = case_expr
180            .else_expr()
181            .map(|e| self.convert(e.as_ref()))
182            .transpose()?;
183
184        // Build a single n-ary CASE WHEN expression from DataFusion WHEN/THEN pairs
185        Ok(nested_case_when(pairs, else_value))
186    }
187}
188
189impl ExpressionConvertor for DefaultExpressionConvertor {
190    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
191        can_be_pushed_down_impl(expr, schema)
192    }
193
194    fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
195        // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
196        //  for that node, up to any `and` or `or` node.
197        if let Some(binary_expr) = df.as_any().downcast_ref::<df_expr::BinaryExpr>() {
198            let left = self.convert(binary_expr.left().as_ref())?;
199            let right = self.convert(binary_expr.right().as_ref())?;
200            let operator = try_operator_from_df(binary_expr.op())?;
201
202            return Ok(Binary.new_expr(operator, [left, right]));
203        }
204
205        if let Some(col_expr) = df.as_any().downcast_ref::<df_expr::Column>() {
206            return Ok(get_item(col_expr.name().to_owned(), root()));
207        }
208
209        if let Some(like) = df.as_any().downcast_ref::<df_expr::LikeExpr>() {
210            let child = self.convert(like.expr().as_ref())?;
211            let pattern = self.convert(like.pattern().as_ref())?;
212            return Ok(Like.new_expr(
213                LikeOptions {
214                    negated: like.negated(),
215                    case_insensitive: like.case_insensitive(),
216                },
217                [child, pattern],
218            ));
219        }
220
221        if let Some(literal) = df.as_any().downcast_ref::<df_expr::Literal>() {
222            let value = Scalar::from_df(literal.value());
223            return Ok(lit(value));
224        }
225
226        if let Some(cast_expr) = df.as_any().downcast_ref::<df_expr::CastExpr>() {
227            let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable));
228            let child = self.convert(cast_expr.expr().as_ref())?;
229            return Ok(cast(child, cast_dtype));
230        }
231
232        if let Some(cast_col_expr) = df.as_any().downcast_ref::<df_expr::CastColumnExpr>() {
233            let target = cast_col_expr.target_field();
234
235            let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into()));
236            let child = self.convert(cast_col_expr.expr().as_ref())?;
237            return Ok(cast(child, target_dtype));
238        }
239
240        if let Some(is_null_expr) = df.as_any().downcast_ref::<df_expr::IsNullExpr>() {
241            let arg = self.convert(is_null_expr.arg().as_ref())?;
242            return Ok(is_null(arg));
243        }
244
245        if let Some(is_not_null_expr) = df.as_any().downcast_ref::<df_expr::IsNotNullExpr>() {
246            let arg = self.convert(is_not_null_expr.arg().as_ref())?;
247            return Ok(not(is_null(arg)));
248        }
249
250        if let Some(in_list) = df.as_any().downcast_ref::<df_expr::InListExpr>() {
251            let value = self.convert(in_list.expr().as_ref())?;
252            let list_elements: Vec<_> = in_list
253                .list()
254                .iter()
255                .map(|e| {
256                    if let Some(lit) = e.as_any().downcast_ref::<df_expr::Literal>() {
257                        Ok(Scalar::from_df(lit.value()))
258                    } else {
259                        Err(exec_datafusion_err!("Failed to cast sub-expression"))
260                    }
261                })
262                .try_collect()?;
263
264            let list = Scalar::list(
265                list_elements[0].dtype().clone(),
266                list_elements,
267                Nullability::Nullable,
268            );
269            let expr = list_contains(lit(list), value);
270
271            return Ok(if in_list.negated() { not(expr) } else { expr });
272        }
273
274        if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
275            return self.try_convert_scalar_function(scalar_fn);
276        }
277
278        if let Some(case_expr) = df.as_any().downcast_ref::<df_expr::CaseExpr>() {
279            return self.try_convert_case_expr(case_expr);
280        }
281
282        Err(exec_datafusion_err!(
283            "Couldn't convert DataFusion physical {df} expression to a vortex expression"
284        ))
285    }
286
287    fn split_projection(
288        &self,
289        source_projection: ProjectionExprs,
290        input_schema: &Schema,
291        output_schema: &Schema,
292    ) -> DFResult<ProcessedProjection> {
293        let mut scan_projection = vec![];
294        let mut leftover_projection: Vec<ProjectionExpr> = vec![];
295
296        for projection_expr in source_projection.iter() {
297            let r = projection_expr.expr.apply(|node| {
298                // We only pull column children of scalar functions that we can't push into the scan.
299                if let Some(scalar_fn_expr) = node.as_any().downcast_ref::<ScalarFunctionExpr>()
300                    && !can_scalar_fn_be_pushed_down(scalar_fn_expr)
301                {
302                    scan_projection.extend(
303                        collect_columns(node)
304                            .into_iter()
305                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
306                    );
307
308                    leftover_projection.push(projection_expr.clone());
309                    return Ok(TreeNodeRecursion::Stop);
310                }
311
312                // DataFusion assumes different decimal types can be coerced.
313                // Vortex expects a perfect match so we don't push it down.
314                if let Some(binary_expr) = node.as_any().downcast_ref::<df_expr::BinaryExpr>()
315                    && binary_expr.op().is_numerical_operators()
316                    && (is_decimal(&binary_expr.left().data_type(input_schema)?)
317                        && is_decimal(&binary_expr.right().data_type(input_schema)?))
318                {
319                    scan_projection.extend(
320                        collect_columns(node)
321                            .into_iter()
322                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
323                    );
324
325                    leftover_projection.push(projection_expr.clone());
326                    return Ok(TreeNodeRecursion::Stop);
327                }
328
329                Ok(TreeNodeRecursion::Continue)
330            })?;
331
332            // if we didn't stop early
333            if matches!(r, TreeNodeRecursion::Continue) {
334                scan_projection.push((
335                    projection_expr.alias.clone(),
336                    self.convert(projection_expr.expr.as_ref())?,
337                ));
338                leftover_projection.push(ProjectionExpr {
339                    expr: Arc::new(df_expr::Column::new_with_schema(
340                        projection_expr.alias.as_str(),
341                        output_schema,
342                    )?),
343                    alias: projection_expr.alias.clone(),
344                });
345            }
346        }
347
348        Ok(ProcessedProjection {
349            scan_projection: pack(scan_projection, Nullability::NonNullable),
350            leftover_projection: leftover_projection.into(),
351        })
352    }
353}
354
355fn try_operator_from_df(value: &DFOperator) -> DFResult<Operator> {
356    match value {
357        DFOperator::Eq => Ok(Operator::Eq),
358        DFOperator::NotEq => Ok(Operator::NotEq),
359        DFOperator::Lt => Ok(Operator::Lt),
360        DFOperator::LtEq => Ok(Operator::Lte),
361        DFOperator::Gt => Ok(Operator::Gt),
362        DFOperator::GtEq => Ok(Operator::Gte),
363        DFOperator::And => Ok(Operator::And),
364        DFOperator::Or => Ok(Operator::Or),
365        DFOperator::Plus => Ok(Operator::Add),
366        DFOperator::Minus => Ok(Operator::Sub),
367        DFOperator::Multiply => Ok(Operator::Mul),
368        DFOperator::Divide => Ok(Operator::Div),
369        DFOperator::IsDistinctFrom
370        | DFOperator::IsNotDistinctFrom
371        | DFOperator::RegexMatch
372        | DFOperator::RegexIMatch
373        | DFOperator::RegexNotMatch
374        | DFOperator::RegexNotIMatch
375        | DFOperator::LikeMatch
376        | DFOperator::ILikeMatch
377        | DFOperator::NotLikeMatch
378        | DFOperator::NotILikeMatch
379        | DFOperator::BitwiseAnd
380        | DFOperator::BitwiseOr
381        | DFOperator::BitwiseXor
382        | DFOperator::BitwiseShiftRight
383        | DFOperator::BitwiseShiftLeft
384        | DFOperator::StringConcat
385        | DFOperator::AtArrow
386        | DFOperator::ArrowAt
387        | DFOperator::Modulo
388        | DFOperator::Arrow
389        | DFOperator::LongArrow
390        | DFOperator::HashArrow
391        | DFOperator::HashLongArrow
392        | DFOperator::AtAt
393        | DFOperator::IntegerDivide
394        | DFOperator::HashMinus
395        | DFOperator::AtQuestion
396        | DFOperator::Question
397        | DFOperator::QuestionAnd
398        | DFOperator::QuestionPipe
399        | DFOperator::Colon => {
400            tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
401            Err(exec_datafusion_err!(
402                "Unsupported datafusion operator {value}"
403            ))
404        }
405    }
406}
407
408fn can_be_pushed_down_impl(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
409    // We currently do not support pushdown of dynamic expressions in DF.
410    // See issue: https://github.com/vortex-data/vortex/issues/4034
411    if is_dynamic_physical_expr(df_expr) {
412        return false;
413    }
414
415    let expr = df_expr.as_any();
416    if let Some(binary) = expr.downcast_ref::<df_expr::BinaryExpr>() {
417        can_binary_be_pushed_down(binary, schema)
418    } else if let Some(col) = expr.downcast_ref::<df_expr::Column>() {
419        schema
420            .field_with_name(col.name())
421            .ok()
422            .is_some_and(|field| supported_data_types(field.data_type()))
423    } else if let Some(like) = expr.downcast_ref::<df_expr::LikeExpr>() {
424        can_be_pushed_down_impl(like.expr(), schema)
425            && can_be_pushed_down_impl(like.pattern(), schema)
426    } else if let Some(lit) = expr.downcast_ref::<df_expr::Literal>() {
427        supported_data_types(&lit.value().data_type())
428    } else if let Some(cast_expr) = expr.downcast_ref::<df_expr::CastExpr>() {
429        // CastExpr child must be an expression type that convert() can handle
430        is_convertible_expr(cast_expr.expr())
431    } else if let Some(cast_col_expr) = expr.downcast_ref::<df_expr::CastColumnExpr>() {
432        // CastColumnExpr child must be an expression type that convert() can handle
433        is_convertible_expr(cast_col_expr.expr())
434    } else if let Some(is_null) = expr.downcast_ref::<df_expr::IsNullExpr>() {
435        can_be_pushed_down_impl(is_null.arg(), schema)
436    } else if let Some(is_not_null) = expr.downcast_ref::<df_expr::IsNotNullExpr>() {
437        can_be_pushed_down_impl(is_not_null.arg(), schema)
438    } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
439        can_be_pushed_down_impl(in_list.expr(), schema)
440            && in_list
441                .list()
442                .iter()
443                .all(|e| can_be_pushed_down_impl(e, schema))
444    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
445        can_scalar_fn_be_pushed_down(scalar_fn)
446    } else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
447        can_case_be_pushed_down(case_expr, schema)
448    } else {
449        tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
450        false
451    }
452}
453
454/// Checks if an expression type is one that convert() can handle.
455/// This is less restrictive than can_be_pushed_down since it only checks
456/// expression types, not data type support.
457fn is_convertible_expr(df_expr: &Arc<dyn PhysicalExpr>) -> bool {
458    let expr = df_expr.as_any();
459
460    // Expression types that convert() handles
461    expr.downcast_ref::<df_expr::BinaryExpr>().is_some()
462        || expr.downcast_ref::<df_expr::Column>().is_some()
463        || expr.downcast_ref::<df_expr::LikeExpr>().is_some()
464        || expr.downcast_ref::<df_expr::Literal>().is_some()
465        || expr
466            .downcast_ref::<df_expr::CastExpr>()
467            .is_some_and(|e| is_convertible_expr(e.expr()))
468        || expr
469            .downcast_ref::<df_expr::CastColumnExpr>()
470            .is_some_and(|e| is_convertible_expr(e.expr()))
471        || expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
472        || expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
473        || expr.downcast_ref::<df_expr::InListExpr>().is_some()
474        || expr
475            .downcast_ref::<ScalarFunctionExpr>()
476            .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
477}
478
479fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
480    let is_op_supported = try_operator_from_df(binary.op()).is_ok();
481    is_op_supported
482        && can_be_pushed_down_impl(binary.left(), schema)
483        && can_be_pushed_down_impl(binary.right(), schema)
484}
485
486fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool {
487    // We only support the "searched CASE" form (CASE WHEN cond THEN result ...)
488    // not the "simple CASE" form (CASE expr WHEN value THEN result ...)
489    if case_expr.expr().is_some() {
490        return false;
491    }
492
493    // Check all when/then pairs
494    for (when_expr, then_expr) in case_expr.when_then_expr() {
495        if !can_be_pushed_down_impl(when_expr, schema)
496            || !can_be_pushed_down_impl(then_expr, schema)
497        {
498            return false;
499        }
500    }
501
502    // Check the optional else clause
503    if let Some(else_expr) = case_expr.else_expr()
504        && !can_be_pushed_down_impl(else_expr, schema)
505    {
506        return false;
507    }
508
509    true
510}
511
512fn supported_data_types(dt: &DataType) -> bool {
513    use DataType::*;
514
515    // For dictionary types, check if the value type is supported.
516    if let Dictionary(_, value_type) = dt {
517        return supported_data_types(value_type.as_ref());
518    }
519
520    let is_supported = dt.is_null()
521        || dt.is_numeric()
522        || matches!(
523            dt,
524            Boolean
525                | Utf8
526                | LargeUtf8
527                | Utf8View
528                | Binary
529                | LargeBinary
530                | BinaryView
531                | Date32
532                | Date64
533                | Timestamp(_, _)
534                | Time32(_)
535                | Time64(_)
536        );
537
538    if !is_supported {
539        tracing::debug!("DataFusion data type {dt:?} is not supported");
540    }
541
542    is_supported
543}
544
/// Checks if a scalar function can be pushed down.
/// Currently only GetFieldFunc is supported.
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
    // `get_field` is the only scalar function with a Vortex mapping (see
    // `try_convert_scalar_function`, which maps it to chained `get_item`s).
    ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
}
550
551// TODO(adam): Replace with `DataType::is_decimal` once its released.
552fn is_decimal(dt: &DataType) -> bool {
553    matches!(
554        dt,
555        DataType::Decimal32(_, _)
556            | DataType::Decimal64(_, _)
557            | DataType::Decimal128(_, _)
558            | DataType::Decimal256(_, _)
559    )
560}
561
562#[cfg(test)]
563mod tests {
564    use std::sync::Arc;
565
566    use arrow_schema::DataType;
567    use arrow_schema::Field;
568    use arrow_schema::Schema;
569    use arrow_schema::TimeUnit as ArrowTimeUnit;
570    use datafusion_common::ScalarValue;
571    use datafusion_expr::Operator as DFOperator;
572    use datafusion_physical_expr::PhysicalExpr;
573    use datafusion_physical_plan::expressions as df_expr;
574    use insta::assert_snapshot;
575    use rstest::rstest;
576
577    use super::*;
578    use crate::common_tests::TestSessionContext;
579
580    #[rstest::fixture]
581    fn test_schema() -> Schema {
582        Schema::new(vec![
583            Field::new("id", DataType::Int32, false),
584            Field::new("name", DataType::Utf8, true),
585            Field::new("score", DataType::Float64, true),
586            Field::new("active", DataType::Boolean, false),
587            Field::new(
588                "created_at",
589                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
590                true,
591            ),
592            Field::new(
593                "unsupported_list",
594                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
595                true,
596            ),
597        ])
598    }
599
600    #[test]
601    fn test_make_vortex_predicate_empty() {
602        let expr_convertor = DefaultExpressionConvertor::default();
603        let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
604        assert!(result.is_none());
605    }
606
607    #[test]
608    fn test_make_vortex_predicate_single() {
609        let expr_convertor = DefaultExpressionConvertor::default();
610        let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
611        let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
612        assert!(result.is_some());
613    }
614
615    #[test]
616    fn test_make_vortex_predicate_multiple() {
617        let expr_convertor = DefaultExpressionConvertor::default();
618        let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
619        let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
620        let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
621        assert!(result.is_some());
622        // Result should be an AND expression combining the two columns
623    }
624
625    #[rstest]
626    #[case::eq(DFOperator::Eq, Operator::Eq)]
627    #[case::not_eq(DFOperator::NotEq, Operator::NotEq)]
628    #[case::lt(DFOperator::Lt, Operator::Lt)]
629    #[case::lte(DFOperator::LtEq, Operator::Lte)]
630    #[case::gt(DFOperator::Gt, Operator::Gt)]
631    #[case::gte(DFOperator::GtEq, Operator::Gte)]
632    #[case::and(DFOperator::And, Operator::And)]
633    #[case::or(DFOperator::Or, Operator::Or)]
634    #[case::plus(DFOperator::Plus, Operator::Add)]
635    #[case::plus(DFOperator::Minus, Operator::Sub)]
636    #[case::plus(DFOperator::Multiply, Operator::Mul)]
637    #[case::plus(DFOperator::Divide, Operator::Div)]
638    fn test_operator_conversion_supported(
639        #[case] df_op: DFOperator,
640        #[case] expected_vortex_op: Operator,
641    ) {
642        let result = try_operator_from_df(&df_op).unwrap();
643        assert_eq!(result, expected_vortex_op);
644    }
645
646    #[rstest]
647    #[case::modulo(DFOperator::Modulo)]
648    #[case::bitwise_and(DFOperator::BitwiseAnd)]
649    #[case::regex_match(DFOperator::RegexMatch)]
650    #[case::like_match(DFOperator::LikeMatch)]
651    fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
652        let result = try_operator_from_df(&df_op);
653        assert!(result.is_err());
654        assert!(
655            result
656                .unwrap_err()
657                .to_string()
658                .contains("Unsupported datafusion operator")
659        );
660    }
661
662    #[test]
663    fn test_expr_from_df_column() {
664        let col_expr = df_expr::Column::new("test_column", 0);
665        let result = DefaultExpressionConvertor::default()
666            .convert(&col_expr)
667            .unwrap();
668
669        assert_snapshot!(result.display_tree().to_string(), @r"
670        vortex.get_item(test_column)
671        └── input: vortex.root()
672        ");
673    }
674
675    #[test]
676    fn test_expr_from_df_literal() {
677        let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
678        let result = DefaultExpressionConvertor::default()
679            .convert(&literal_expr)
680            .unwrap();
681
682        assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
683    }
684
685    #[test]
686    fn test_expr_from_df_binary() {
687        let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
688        let right =
689            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
690        let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);
691
692        let result = DefaultExpressionConvertor::default()
693            .convert(&binary_expr)
694            .unwrap();
695
696        assert_snapshot!(result.display_tree().to_string(), @r"
697        vortex.binary(=)
698        ├── lhs: vortex.get_item(left)
699        │   └── input: vortex.root()
700        └── rhs: vortex.literal(42i32)
701        ");
702    }
703
704    #[rstest]
705    #[case::like_normal(false, false)]
706    #[case::like_negated(true, false)]
707    #[case::like_case_insensitive(false, true)]
708    #[case::like_negated_case_insensitive(true, true)]
709    fn test_expr_from_df_like(#[case] negated: bool, #[case] case_insensitive: bool) {
710        let expr = Arc::new(df_expr::Column::new("text_col", 0)) as Arc<dyn PhysicalExpr>;
711        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
712            "test%".to_string(),
713        )))) as Arc<dyn PhysicalExpr>;
714        let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);
715
716        let result = DefaultExpressionConvertor::default()
717            .convert(&like_expr)
718            .unwrap();
719        let like_opts = result.as_::<Like>();
720        assert_eq!(
721            like_opts,
722            &LikeOptions {
723                negated,
724                case_insensitive
725            }
726        );
727    }
728
    /// Table-driven check of `supported_data_types` across supported scalar types,
    /// unsupported nested types, and dictionary types (which defer to their value type).
    #[rstest]
    // Supported types
    #[case::null(DataType::Null, true)]
    #[case::boolean(DataType::Boolean, true)]
    #[case::int8(DataType::Int8, true)]
    #[case::int16(DataType::Int16, true)]
    #[case::int32(DataType::Int32, true)]
    #[case::int64(DataType::Int64, true)]
    #[case::uint8(DataType::UInt8, true)]
    #[case::uint16(DataType::UInt16, true)]
    #[case::uint32(DataType::UInt32, true)]
    #[case::uint64(DataType::UInt64, true)]
    #[case::float32(DataType::Float32, true)]
    #[case::float64(DataType::Float64, true)]
    #[case::utf8(DataType::Utf8, true)]
    #[case::utf8_view(DataType::Utf8View, true)]
    #[case::binary(DataType::Binary, true)]
    #[case::binary_view(DataType::BinaryView, true)]
    #[case::date32(DataType::Date32, true)]
    #[case::date64(DataType::Date64, true)]
    #[case::timestamp_ms(DataType::Timestamp(ArrowTimeUnit::Millisecond, None), true)]
    #[case::timestamp_us(
        DataType::Timestamp(ArrowTimeUnit::Microsecond, Some(Arc::from("UTC"))),
        true
    )]
    #[case::time32_s(DataType::Time32(ArrowTimeUnit::Second), true)]
    #[case::time64_ns(DataType::Time64(ArrowTimeUnit::Nanosecond), true)]
    // Unsupported types
    #[case::list(
        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
        false
    )]
    #[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()
    ), false)]
    // Dictionary types - should be supported if value type is supported
    #[case::dict_utf8(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
        true
    )]
    #[case::dict_int32(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
        true
    )]
    #[case::dict_unsupported(
        DataType::Dictionary(
            Box::new(DataType::UInt32),
            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
        ),
        false
    )]
    fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
        assert_eq!(supported_data_types(&data_type), expected);
    }
782
783    #[rstest]
784    fn test_can_be_pushed_down_column_supported(test_schema: Schema) {
785        let col_expr = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
786
787        assert!(can_be_pushed_down_impl(&col_expr, &test_schema));
788    }
789
790    #[rstest]
791    fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
792        let col_expr =
793            Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
794
795        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
796    }
797
798    #[rstest]
799    fn test_can_be_pushed_down_column_not_found(test_schema: Schema) {
800        let col_expr = Arc::new(df_expr::Column::new("nonexistent", 99)) as Arc<dyn PhysicalExpr>;
801
802        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
803    }
804
805    #[rstest]
806    fn test_can_be_pushed_down_literal_supported(test_schema: Schema) {
807        let lit_expr =
808            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
809
810        assert!(can_be_pushed_down_impl(&lit_expr, &test_schema));
811    }
812
813    #[rstest]
814    fn test_can_be_pushed_down_literal_unsupported(test_schema: Schema) {
815        // Use a simpler unsupported type - Duration is not supported
816        let unsupported_literal = ScalarValue::DurationSecond(Some(42));
817        let lit_expr =
818            Arc::new(df_expr::Literal::new(unsupported_literal)) as Arc<dyn PhysicalExpr>;
819
820        assert!(!can_be_pushed_down_impl(&lit_expr, &test_schema));
821    }
822
823    #[rstest]
824    fn test_can_be_pushed_down_binary_supported(test_schema: Schema) {
825        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
826        let right =
827            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
828        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
829            as Arc<dyn PhysicalExpr>;
830
831        assert!(can_be_pushed_down_impl(&binary_expr, &test_schema));
832    }
833
834    #[rstest]
835    fn test_can_be_pushed_down_binary_unsupported_operator(test_schema: Schema) {
836        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
837        let right =
838            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
839        let binary_expr = Arc::new(df_expr::BinaryExpr::new(
840            left,
841            DFOperator::AtQuestion,
842            right,
843        )) as Arc<dyn PhysicalExpr>;
844
845        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
846    }
847
848    #[rstest]
849    fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
850        let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
851        let right =
852            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
853        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
854            as Arc<dyn PhysicalExpr>;
855
856        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
857    }
858
859    #[rstest]
860    fn test_can_be_pushed_down_like_supported(test_schema: Schema) {
861        let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
862        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
863            "test%".to_string(),
864        )))) as Arc<dyn PhysicalExpr>;
865        let like_expr =
866            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
867
868        assert!(can_be_pushed_down_impl(&like_expr, &test_schema));
869    }
870
871    #[rstest]
872    fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
873        let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
874        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
875            "test%".to_string(),
876        )))) as Arc<dyn PhysicalExpr>;
877        let like_expr =
878            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
879
880        assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
881    }
882
    /// Regression test for <https://github.com/vortex-data/vortex/issues/6211>:
    /// queries that cast an integer column to string, in projections and in
    /// filters, must execute without error against a Vortex file.
    #[tokio::test]
    async fn test_cast_int_to_string() -> anyhow::Result<()> {
        let ctx = TestSessionContext::default();

        // Write a single-row table (id = 1) to a Vortex file to query against.
        ctx.session
            .sql(r#"copy (select 1 as id) to 'example.vortex'"#)
            .await?
            .show()
            .await?;

        // Cast in the projection, with a separate pushable filter on `id`.
        ctx.session
            .sql(r#"select cast(id as string) as sid from 'example.vortex' where id > 0"#)
            .await?
            .show()
            .await?;

        // Cast inside the filter predicate itself.
        ctx.session
            .sql(r#"select id from 'example.vortex' where cast (id as string) == '1'"#)
            .await?
            .show()
            .await?;

        // Bare cast in the projection. Per the original note this was the
        // failing case, because the string cast was pushed into the scan;
        // the `?` asserts it now succeeds. NOTE(review): confirm the comment
        // is still accurate for the current planner behavior.
        ctx.session
            .sql(r#"select cast(id as string) from 'example.vortex'"#)
            .await?
            .collect()
            .await?;

        Ok(())
    }
915
916    /// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion
917    /// matches the result of applying the converted Vortex expression.
918    #[test]
919    fn test_case_when_datafusion_vortex_equivalence() {
920        use datafusion::arrow::array::Int32Array;
921        use datafusion::arrow::array::RecordBatch;
922        use datafusion_physical_expr::expressions::CaseExpr;
923        use vortex::VortexSessionDefault;
924        use vortex::array::ArrayRef;
925        use vortex::array::Canonical;
926        use vortex::array::VortexSessionExecute as _;
927        use vortex::array::arrow::FromArrowArray;
928        use vortex::session::VortexSession;
929
930        // Create test data
931        let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
932        let schema = Arc::new(Schema::new(vec![Field::new(
933            "value",
934            DataType::Int32,
935            false,
936        )]));
937        let batch = RecordBatch::try_new(schema, vec![values]).unwrap();
938
939        // Build a DataFusion CASE expression:
940        // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
941        let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
942        let lit_10 =
943            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
944        let lit_5 =
945            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
946        let lit_100 =
947            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
948        let lit_50 =
949            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
950        let lit_0 =
951            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;
952
953        // WHEN value > 10 THEN 100
954        let when1 = Arc::new(df_expr::BinaryExpr::new(
955            col_value.clone(),
956            DFOperator::Gt,
957            lit_10,
958        )) as Arc<dyn PhysicalExpr>;
959        // WHEN value > 5 THEN 50
960        let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
961            as Arc<dyn PhysicalExpr>;
962
963        let case_expr =
964            CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();
965
966        // Apply DataFusion expression
967        let df_result = case_expr.evaluate(&batch).unwrap();
968        let df_array = df_result.into_array(batch.num_rows()).unwrap();
969
970        // Convert to Vortex expression
971        let expr_convertor = DefaultExpressionConvertor::default();
972        let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();
973
974        // Convert batch to Vortex array
975        let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();
976
977        // Apply Vortex expression
978        let session = VortexSession::default();
979        let mut ctx = session.create_execution_ctx();
980        let vortex_result = vortex_array
981            .apply(&vortex_expr)
982            .unwrap()
983            .execute::<Canonical>(&mut ctx)
984            .unwrap();
985
986        // Convert back to Arrow for comparison
987        let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();
988
989        // Convert DataFusion result to Vec for comparison
990        let df_as_arrow: Vec<i32> = df_array
991            .as_any()
992            .downcast_ref::<Int32Array>()
993            .unwrap()
994            .values()
995            .to_vec();
996
997        // Compare results
998        // Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20]
999        // value=1: not > 10, not > 5 -> ELSE 0
1000        // value=5: not > 10, not > 5 -> ELSE 0
1001        // value=10: not > 10, > 5 -> 50
1002        // value=15: > 10 -> 100
1003        // value=20: > 10 -> 100
1004        assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
1005        assert_eq!(vortex_as_arrow, df_as_arrow);
1006    }
1007}