// vortex_datafusion/convert/exprs.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_schema::DataType;
7use arrow_schema::Schema;
8use datafusion_common::Result as DFResult;
9use datafusion_common::exec_datafusion_err;
10use datafusion_common::tree_node::TreeNode;
11use datafusion_common::tree_node::TreeNodeRecursion;
12use datafusion_expr::Operator as DFOperator;
13use datafusion_functions::core::getfield::GetFieldFunc;
14use datafusion_physical_expr::PhysicalExpr;
15use datafusion_physical_expr::ScalarFunctionExpr;
16use datafusion_physical_expr::projection::ProjectionExpr;
17use datafusion_physical_expr::projection::ProjectionExprs;
18use datafusion_physical_expr::utils::collect_columns;
19use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
20use datafusion_physical_plan::expressions as df_expr;
21use itertools::Itertools;
22use vortex::dtype::DType;
23use vortex::dtype::Nullability;
24use vortex::dtype::arrow::FromArrowType;
25use vortex::expr::Expression;
26use vortex::expr::and_collect;
27use vortex::expr::cast;
28use vortex::expr::get_item;
29use vortex::expr::is_null;
30use vortex::expr::list_contains;
31use vortex::expr::lit;
32use vortex::expr::nested_case_when;
33use vortex::expr::not;
34use vortex::expr::pack;
35use vortex::expr::root;
36use vortex::scalar::Scalar;
37use vortex::scalar_fn::ScalarFnVTableExt;
38use vortex::scalar_fn::fns::binary::Binary;
39use vortex::scalar_fn::fns::like::Like;
40use vortex::scalar_fn::fns::like::LikeOptions;
41use vortex::scalar_fn::fns::operators::Operator;
42
43use crate::convert::FromDataFusion;
44
/// Result of splitting a projection into Vortex expressions and leftover DataFusion projections.
pub struct ProcessedProjection {
    /// Expression evaluated by the Vortex scan itself (built with `pack` over the needed fields).
    pub scan_projection: Expression,
    /// DataFusion projection expressions that still need to run on the scan output.
    pub leftover_projection: ProjectionExprs,
}
50
51/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
52pub(crate) fn make_vortex_predicate(
53    expr_convertor: &dyn ExpressionConvertor,
54    predicate: &[Arc<dyn PhysicalExpr>],
55) -> DFResult<Option<Expression>> {
56    let exprs = predicate
57        .iter()
58        .map(|e| expr_convertor.convert(e.as_ref()))
59        .collect::<DFResult<Vec<_>>>()?;
60
61    Ok(and_collect(exprs))
62}
63
64/// Trait for converting DataFusion expressions to Vortex ones.
65pub trait ExpressionConvertor: Send + Sync {
66    /// Can an expression be pushed down given a specific schema
67    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
68
69    /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
70    fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;
71
72    /// Split a projection into Vortex expressions that can be pushed down and leftover
73    /// DataFusion projections that need to be evaluated after the scan.
74    fn split_projection(
75        &self,
76        source_projection: ProjectionExprs,
77        input_schema: &Schema,
78        output_schema: &Schema,
79    ) -> DFResult<ProcessedProjection>;
80
81    /// Create a projection that reads only the required columns without pushing down
82    /// any expressions. All projection logic is applied after the scan.
83    fn no_pushdown_projection(
84        &self,
85        source_projection: ProjectionExprs,
86        input_schema: &Schema,
87    ) -> DFResult<ProcessedProjection> {
88        // Get all unique column indices referenced by the projection
89        let column_indices = source_projection.column_indices();
90
91        // Create scan projection that reads the required columns
92        let scan_columns: Vec<(String, Expression)> = column_indices
93            .into_iter()
94            .map(|idx| {
95                let field = input_schema.field(idx);
96                let name = field.name().clone();
97                (name.clone(), get_item(name, root()))
98            })
99            .collect();
100
101        Ok(ProcessedProjection {
102            scan_projection: pack(scan_columns, Nullability::NonNullable),
103            leftover_projection: source_projection,
104        })
105    }
106}
107
/// The default [`ExpressionConvertor`].
///
/// Stateless. Handles the DataFusion physical expressions covered by `convert`:
/// binary operators, columns, literals, casts, LIKE, IS [NOT] NULL, IN lists,
/// `get_field`, and searched CASE.
#[derive(Default)]
pub struct DefaultExpressionConvertor {}
111
112impl DefaultExpressionConvertor {
113    /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
114    fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
115        if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
116        {
117            // DataFusion's GetFieldFunc flattens nested field access into a single call
118            // with multiple field name arguments. For example, `outer.inner.leaf` becomes
119            // get_field(Column("outer"), "inner", "leaf"). We build a chain of get_item
120            // calls for each field name in the path.
121            let (source_expr, field_names) = get_field_fn
122                .args()
123                .split_first()
124                .ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;
125
126            let mut result = self.convert(source_expr.as_ref())?;
127            for expr in field_names {
128                let field_name = expr
129                    .as_any()
130                    .downcast_ref::<df_expr::Literal>()
131                    .ok_or_else(|| exec_datafusion_err!("get_field field name must be a literal"))?
132                    .value()
133                    .try_as_str()
134                    .flatten()
135                    .ok_or_else(|| {
136                        exec_datafusion_err!("get_field field name must be a UTF-8 string")
137                    })?;
138                result = get_item(field_name.to_string(), result);
139            }
140            return Ok(result);
141        }
142
143        Err(exec_datafusion_err!(
144            "Unsupported ScalarFunctionExpr: {}",
145            scalar_fn.name()
146        ))
147    }
148
149    /// Attempts to convert a DataFusion CaseExpr to a Vortex expression.
150    fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
151        // DataFusion CaseExpr has:
152        // - expr(): Optional base expression (for "CASE expr WHEN ..." form)
153        // - when_then_expr(): Vec of (when, then) pairs
154        // - else_expr(): Optional else expression
155
156        // We don't support the "CASE expr WHEN value1 THEN result1" form yet
157        if case_expr.expr().is_some() {
158            return Err(exec_datafusion_err!(
159                "CASE expr WHEN form is not yet supported, only searched CASE is supported"
160            ));
161        }
162
163        let when_then_pairs = case_expr.when_then_expr();
164        if when_then_pairs.is_empty() {
165            return Err(exec_datafusion_err!(
166                "CASE expression must have at least one WHEN clause"
167            ));
168        }
169
170        // Convert all when/then pairs to (condition, value) tuples
171        let mut pairs = Vec::with_capacity(when_then_pairs.len());
172        for (when_expr, then_expr) in when_then_pairs {
173            let condition = self.convert(when_expr.as_ref())?;
174            let value = self.convert(then_expr.as_ref())?;
175            pairs.push((condition, value));
176        }
177
178        // Convert optional else expression
179        let else_value = case_expr
180            .else_expr()
181            .map(|e| self.convert(e.as_ref()))
182            .transpose()?;
183
184        // Build a single n-ary CASE WHEN expression from DataFusion WHEN/THEN pairs
185        Ok(nested_case_when(pairs, else_value))
186    }
187}
188
189impl ExpressionConvertor for DefaultExpressionConvertor {
190    fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
191        can_be_pushed_down_impl(expr, schema)
192    }
193
194    fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
195        // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
196        //  for that node, up to any `and` or `or` node.
197        if let Some(binary_expr) = df.as_any().downcast_ref::<df_expr::BinaryExpr>() {
198            let left = self.convert(binary_expr.left().as_ref())?;
199            let right = self.convert(binary_expr.right().as_ref())?;
200            let operator = try_operator_from_df(binary_expr.op())?;
201
202            return Ok(Binary.new_expr(operator, [left, right]));
203        }
204
205        if let Some(col_expr) = df.as_any().downcast_ref::<df_expr::Column>() {
206            return Ok(get_item(col_expr.name().to_owned(), root()));
207        }
208
209        if let Some(like) = df.as_any().downcast_ref::<df_expr::LikeExpr>() {
210            let child = self.convert(like.expr().as_ref())?;
211            let pattern = self.convert(like.pattern().as_ref())?;
212            return Ok(Like.new_expr(
213                LikeOptions {
214                    negated: like.negated(),
215                    case_insensitive: like.case_insensitive(),
216                },
217                [child, pattern],
218            ));
219        }
220
221        if let Some(literal) = df.as_any().downcast_ref::<df_expr::Literal>() {
222            let value = Scalar::from_df(literal.value());
223            return Ok(lit(value));
224        }
225
226        if let Some(cast_expr) = df.as_any().downcast_ref::<df_expr::CastExpr>() {
227            let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable));
228            let child = self.convert(cast_expr.expr().as_ref())?;
229            return Ok(cast(child, cast_dtype));
230        }
231
232        if let Some(cast_col_expr) = df.as_any().downcast_ref::<df_expr::CastColumnExpr>() {
233            let target = cast_col_expr.target_field();
234
235            let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into()));
236            let child = self.convert(cast_col_expr.expr().as_ref())?;
237            return Ok(cast(child, target_dtype));
238        }
239
240        if let Some(is_null_expr) = df.as_any().downcast_ref::<df_expr::IsNullExpr>() {
241            let arg = self.convert(is_null_expr.arg().as_ref())?;
242            return Ok(is_null(arg));
243        }
244
245        if let Some(is_not_null_expr) = df.as_any().downcast_ref::<df_expr::IsNotNullExpr>() {
246            let arg = self.convert(is_not_null_expr.arg().as_ref())?;
247            return Ok(not(is_null(arg)));
248        }
249
250        if let Some(in_list) = df.as_any().downcast_ref::<df_expr::InListExpr>() {
251            let value = self.convert(in_list.expr().as_ref())?;
252            let list_elements: Vec<_> = in_list
253                .list()
254                .iter()
255                .map(|e| {
256                    if let Some(lit) = e.as_any().downcast_ref::<df_expr::Literal>() {
257                        Ok(Scalar::from_df(lit.value()))
258                    } else {
259                        Err(exec_datafusion_err!("Failed to cast sub-expression"))
260                    }
261                })
262                .try_collect()?;
263
264            let list = Scalar::list(
265                list_elements[0].dtype().clone(),
266                list_elements,
267                Nullability::Nullable,
268            );
269            let expr = list_contains(lit(list), value);
270
271            return Ok(if in_list.negated() { not(expr) } else { expr });
272        }
273
274        if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
275            return self.try_convert_scalar_function(scalar_fn);
276        }
277
278        if let Some(case_expr) = df.as_any().downcast_ref::<df_expr::CaseExpr>() {
279            return self.try_convert_case_expr(case_expr);
280        }
281
282        Err(exec_datafusion_err!(
283            "Couldn't convert DataFusion physical {df} expression to a vortex expression"
284        ))
285    }
286
287    fn split_projection(
288        &self,
289        source_projection: ProjectionExprs,
290        input_schema: &Schema,
291        output_schema: &Schema,
292    ) -> DFResult<ProcessedProjection> {
293        let mut scan_projection = vec![];
294        let mut leftover_projection: Vec<ProjectionExpr> = vec![];
295
296        for projection_expr in source_projection.iter() {
297            let r = projection_expr.expr.apply(|node| {
298                // We only pull column children of scalar functions that we can't push into the scan.
299                if let Some(scalar_fn_expr) = node.as_any().downcast_ref::<ScalarFunctionExpr>()
300                    && !can_scalar_fn_be_pushed_down(scalar_fn_expr)
301                {
302                    scan_projection.extend(
303                        collect_columns(node)
304                            .into_iter()
305                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
306                    );
307
308                    leftover_projection.push(projection_expr.clone());
309                    return Ok(TreeNodeRecursion::Stop);
310                }
311
312                // DataFusion assumes different decimal types can be coerced.
313                // Vortex expects a perfect match so we don't push it down.
314                if let Some(binary_expr) = node.as_any().downcast_ref::<df_expr::BinaryExpr>()
315                    && binary_expr.op().is_numerical_operators()
316                    && (is_decimal(&binary_expr.left().data_type(input_schema)?)
317                        && is_decimal(&binary_expr.right().data_type(input_schema)?))
318                {
319                    scan_projection.extend(
320                        collect_columns(node)
321                            .into_iter()
322                            .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
323                    );
324
325                    leftover_projection.push(projection_expr.clone());
326                    return Ok(TreeNodeRecursion::Stop);
327                }
328
329                Ok(TreeNodeRecursion::Continue)
330            })?;
331
332            // if we didn't stop early
333            if matches!(r, TreeNodeRecursion::Continue) {
334                scan_projection.push((
335                    projection_expr.alias.clone(),
336                    self.convert(projection_expr.expr.as_ref())?,
337                ));
338                leftover_projection.push(ProjectionExpr {
339                    expr: Arc::new(df_expr::Column::new_with_schema(
340                        projection_expr.alias.as_str(),
341                        output_schema,
342                    )?),
343                    alias: projection_expr.alias.clone(),
344                });
345            }
346        }
347
348        Ok(ProcessedProjection {
349            scan_projection: pack(scan_projection, Nullability::NonNullable),
350            leftover_projection: leftover_projection.into(),
351        })
352    }
353}
354
355fn try_operator_from_df(value: &DFOperator) -> DFResult<Operator> {
356    match value {
357        DFOperator::Eq => Ok(Operator::Eq),
358        DFOperator::NotEq => Ok(Operator::NotEq),
359        DFOperator::Lt => Ok(Operator::Lt),
360        DFOperator::LtEq => Ok(Operator::Lte),
361        DFOperator::Gt => Ok(Operator::Gt),
362        DFOperator::GtEq => Ok(Operator::Gte),
363        DFOperator::And => Ok(Operator::And),
364        DFOperator::Or => Ok(Operator::Or),
365        DFOperator::Plus => Ok(Operator::Add),
366        DFOperator::Minus => Ok(Operator::Sub),
367        DFOperator::Multiply => Ok(Operator::Mul),
368        DFOperator::Divide => Ok(Operator::Div),
369        DFOperator::IsDistinctFrom
370        | DFOperator::IsNotDistinctFrom
371        | DFOperator::RegexMatch
372        | DFOperator::RegexIMatch
373        | DFOperator::RegexNotMatch
374        | DFOperator::RegexNotIMatch
375        | DFOperator::LikeMatch
376        | DFOperator::ILikeMatch
377        | DFOperator::NotLikeMatch
378        | DFOperator::NotILikeMatch
379        | DFOperator::BitwiseAnd
380        | DFOperator::BitwiseOr
381        | DFOperator::BitwiseXor
382        | DFOperator::BitwiseShiftRight
383        | DFOperator::BitwiseShiftLeft
384        | DFOperator::StringConcat
385        | DFOperator::AtArrow
386        | DFOperator::ArrowAt
387        | DFOperator::Modulo
388        | DFOperator::Arrow
389        | DFOperator::LongArrow
390        | DFOperator::HashArrow
391        | DFOperator::HashLongArrow
392        | DFOperator::AtAt
393        | DFOperator::IntegerDivide
394        | DFOperator::HashMinus
395        | DFOperator::AtQuestion
396        | DFOperator::Question
397        | DFOperator::QuestionAnd
398        | DFOperator::QuestionPipe => {
399            tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
400            Err(exec_datafusion_err!(
401                "Unsupported datafusion operator {value}"
402            ))
403        }
404    }
405}
406
407fn can_be_pushed_down_impl(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
408    // We currently do not support pushdown of dynamic expressions in DF.
409    // See issue: https://github.com/vortex-data/vortex/issues/4034
410    if is_dynamic_physical_expr(df_expr) {
411        return false;
412    }
413
414    let expr = df_expr.as_any();
415    if let Some(binary) = expr.downcast_ref::<df_expr::BinaryExpr>() {
416        can_binary_be_pushed_down(binary, schema)
417    } else if let Some(col) = expr.downcast_ref::<df_expr::Column>() {
418        schema
419            .field_with_name(col.name())
420            .ok()
421            .is_some_and(|field| supported_data_types(field.data_type()))
422    } else if let Some(like) = expr.downcast_ref::<df_expr::LikeExpr>() {
423        can_be_pushed_down_impl(like.expr(), schema)
424            && can_be_pushed_down_impl(like.pattern(), schema)
425    } else if let Some(lit) = expr.downcast_ref::<df_expr::Literal>() {
426        supported_data_types(&lit.value().data_type())
427    } else if let Some(cast_expr) = expr.downcast_ref::<df_expr::CastExpr>() {
428        // CastExpr child must be an expression type that convert() can handle
429        is_convertible_expr(cast_expr.expr())
430    } else if let Some(cast_col_expr) = expr.downcast_ref::<df_expr::CastColumnExpr>() {
431        // CastColumnExpr child must be an expression type that convert() can handle
432        is_convertible_expr(cast_col_expr.expr())
433    } else if let Some(is_null) = expr.downcast_ref::<df_expr::IsNullExpr>() {
434        can_be_pushed_down_impl(is_null.arg(), schema)
435    } else if let Some(is_not_null) = expr.downcast_ref::<df_expr::IsNotNullExpr>() {
436        can_be_pushed_down_impl(is_not_null.arg(), schema)
437    } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
438        can_be_pushed_down_impl(in_list.expr(), schema)
439            && in_list
440                .list()
441                .iter()
442                .all(|e| can_be_pushed_down_impl(e, schema))
443    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
444        can_scalar_fn_be_pushed_down(scalar_fn)
445    } else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
446        can_case_be_pushed_down(case_expr, schema)
447    } else {
448        tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
449        false
450    }
451}
452
453/// Checks if an expression type is one that convert() can handle.
454/// This is less restrictive than can_be_pushed_down since it only checks
455/// expression types, not data type support.
456fn is_convertible_expr(df_expr: &Arc<dyn PhysicalExpr>) -> bool {
457    let expr = df_expr.as_any();
458
459    // Expression types that convert() handles
460    expr.downcast_ref::<df_expr::BinaryExpr>().is_some()
461        || expr.downcast_ref::<df_expr::Column>().is_some()
462        || expr.downcast_ref::<df_expr::LikeExpr>().is_some()
463        || expr.downcast_ref::<df_expr::Literal>().is_some()
464        || expr
465            .downcast_ref::<df_expr::CastExpr>()
466            .is_some_and(|e| is_convertible_expr(e.expr()))
467        || expr
468            .downcast_ref::<df_expr::CastColumnExpr>()
469            .is_some_and(|e| is_convertible_expr(e.expr()))
470        || expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
471        || expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
472        || expr.downcast_ref::<df_expr::InListExpr>().is_some()
473        || expr
474            .downcast_ref::<ScalarFunctionExpr>()
475            .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
476}
477
478fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
479    let is_op_supported = try_operator_from_df(binary.op()).is_ok();
480    is_op_supported
481        && can_be_pushed_down_impl(binary.left(), schema)
482        && can_be_pushed_down_impl(binary.right(), schema)
483}
484
485fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool {
486    // We only support the "searched CASE" form (CASE WHEN cond THEN result ...)
487    // not the "simple CASE" form (CASE expr WHEN value THEN result ...)
488    if case_expr.expr().is_some() {
489        return false;
490    }
491
492    // Check all when/then pairs
493    for (when_expr, then_expr) in case_expr.when_then_expr() {
494        if !can_be_pushed_down_impl(when_expr, schema)
495            || !can_be_pushed_down_impl(then_expr, schema)
496        {
497            return false;
498        }
499    }
500
501    // Check the optional else clause
502    if let Some(else_expr) = case_expr.else_expr()
503        && !can_be_pushed_down_impl(else_expr, schema)
504    {
505        return false;
506    }
507
508    true
509}
510
511fn supported_data_types(dt: &DataType) -> bool {
512    use DataType::*;
513
514    // For dictionary types, check if the value type is supported.
515    if let Dictionary(_, value_type) = dt {
516        return supported_data_types(value_type.as_ref());
517    }
518
519    let is_supported = dt.is_null()
520        || dt.is_numeric()
521        || matches!(
522            dt,
523            Boolean
524                | Utf8
525                | LargeUtf8
526                | Utf8View
527                | Binary
528                | LargeBinary
529                | BinaryView
530                | Date32
531                | Date64
532                | Timestamp(_, _)
533                | Time32(_)
534                | Time64(_)
535        );
536
537    if !is_supported {
538        tracing::debug!("DataFusion data type {dt:?} is not supported");
539    }
540
541    is_supported
542}
543
/// Checks if a scalar function can be pushed down.
/// Currently only GetFieldFunc is supported.
/// (Argument validity — literal UTF-8 field names — is checked later, in convert().)
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
    ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
}
549
550// TODO(adam): Replace with `DataType::is_decimal` once its released.
551fn is_decimal(dt: &DataType) -> bool {
552    matches!(
553        dt,
554        DataType::Decimal32(_, _)
555            | DataType::Decimal64(_, _)
556            | DataType::Decimal128(_, _)
557            | DataType::Decimal256(_, _)
558    )
559}
560
561#[cfg(test)]
562mod tests {
563    use std::sync::Arc;
564
565    use arrow_schema::DataType;
566    use arrow_schema::Field;
567    use arrow_schema::Schema;
568    use arrow_schema::TimeUnit as ArrowTimeUnit;
569    use datafusion_common::ScalarValue;
570    use datafusion_expr::Operator as DFOperator;
571    use datafusion_physical_expr::PhysicalExpr;
572    use datafusion_physical_plan::expressions as df_expr;
573    use insta::assert_snapshot;
574    use rstest::rstest;
575
576    use super::*;
577    use crate::common_tests::TestSessionContext;
578
    // Shared schema fixture: a mix of pushdown-supported scalar columns plus one
    // unsupported list column ("unsupported_list") for negative tests.
    #[rstest::fixture]
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("active", DataType::Boolean, false),
            Field::new(
                "created_at",
                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                true,
            ),
            Field::new(
                "unsupported_list",
                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
                true,
            ),
        ])
    }
598
    // An empty conjunction converts to None (documented Ok(None)-iff-empty contract).
    #[test]
    fn test_make_vortex_predicate_empty() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
        assert!(result.is_none());
    }
605
    // A single convertible conjunct yields Some(expression).
    #[test]
    fn test_make_vortex_predicate_single() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
        let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
        assert!(result.is_some());
    }
613
    // Multiple conjuncts are combined into a single expression.
    #[test]
    fn test_make_vortex_predicate_multiple() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
        let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
        let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
        assert!(result.is_some());
        // Result should be an AND expression combining the two columns
    }
623
624    #[rstest]
625    #[case::eq(DFOperator::Eq, Operator::Eq)]
626    #[case::not_eq(DFOperator::NotEq, Operator::NotEq)]
627    #[case::lt(DFOperator::Lt, Operator::Lt)]
628    #[case::lte(DFOperator::LtEq, Operator::Lte)]
629    #[case::gt(DFOperator::Gt, Operator::Gt)]
630    #[case::gte(DFOperator::GtEq, Operator::Gte)]
631    #[case::and(DFOperator::And, Operator::And)]
632    #[case::or(DFOperator::Or, Operator::Or)]
633    #[case::plus(DFOperator::Plus, Operator::Add)]
634    #[case::plus(DFOperator::Minus, Operator::Sub)]
635    #[case::plus(DFOperator::Multiply, Operator::Mul)]
636    #[case::plus(DFOperator::Divide, Operator::Div)]
637    fn test_operator_conversion_supported(
638        #[case] df_op: DFOperator,
639        #[case] expected_vortex_op: Operator,
640    ) {
641        let result = try_operator_from_df(&df_op).unwrap();
642        assert_eq!(result, expected_vortex_op);
643    }
644
    // A sample of unmapped operators must produce the "Unsupported" error, not panic.
    #[rstest]
    #[case::modulo(DFOperator::Modulo)]
    #[case::bitwise_and(DFOperator::BitwiseAnd)]
    #[case::regex_match(DFOperator::RegexMatch)]
    #[case::like_match(DFOperator::LikeMatch)]
    fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
        let result = try_operator_from_df(&df_op);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Unsupported datafusion operator")
        );
    }
660
    // A column reference converts to get_item(name) rooted at the scan.
    #[test]
    fn test_expr_from_df_column() {
        let col_expr = df_expr::Column::new("test_column", 0);
        let result = DefaultExpressionConvertor::default()
            .convert(&col_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @r"
        vortex.get_item(test_column)
        └── input: vortex.root()
        ");
    }
673
    // A DataFusion literal converts to a Vortex literal of the same value and width.
    #[test]
    fn test_expr_from_df_literal() {
        let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
        let result = DefaultExpressionConvertor::default()
            .convert(&literal_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
    }
683
    // A binary comparison converts recursively: operator plus converted operands.
    #[test]
    fn test_expr_from_df_binary() {
        let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);

        let result = DefaultExpressionConvertor::default()
            .convert(&binary_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @r"
        vortex.binary(=)
        ├── lhs: vortex.get_item(left)
        │   └── input: vortex.root()
        └── rhs: vortex.literal(42i32)
        ");
    }
702
    // LIKE conversion must preserve both the negated and case_insensitive flags
    // across all four combinations.
    #[rstest]
    #[case::like_normal(false, false)]
    #[case::like_negated(true, false)]
    #[case::like_case_insensitive(false, true)]
    #[case::like_negated_case_insensitive(true, true)]
    fn test_expr_from_df_like(#[case] negated: bool, #[case] case_insensitive: bool) {
        let expr = Arc::new(df_expr::Column::new("text_col", 0)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);

        let result = DefaultExpressionConvertor::default()
            .convert(&like_expr)
            .unwrap();
        let like_opts = result.as_::<Like>();
        assert_eq!(
            like_opts,
            &LikeOptions {
                negated,
                case_insensitive
            }
        );
    }
727
    // Table-driven check of supported_data_types, including the dictionary case where
    // support is decided by the value type.
    #[rstest]
    // Supported types
    #[case::null(DataType::Null, true)]
    #[case::boolean(DataType::Boolean, true)]
    #[case::int8(DataType::Int8, true)]
    #[case::int16(DataType::Int16, true)]
    #[case::int32(DataType::Int32, true)]
    #[case::int64(DataType::Int64, true)]
    #[case::uint8(DataType::UInt8, true)]
    #[case::uint16(DataType::UInt16, true)]
    #[case::uint32(DataType::UInt32, true)]
    #[case::uint64(DataType::UInt64, true)]
    #[case::float32(DataType::Float32, true)]
    #[case::float64(DataType::Float64, true)]
    #[case::utf8(DataType::Utf8, true)]
    #[case::utf8_view(DataType::Utf8View, true)]
    #[case::binary(DataType::Binary, true)]
    #[case::binary_view(DataType::BinaryView, true)]
    #[case::date32(DataType::Date32, true)]
    #[case::date64(DataType::Date64, true)]
    #[case::timestamp_ms(DataType::Timestamp(ArrowTimeUnit::Millisecond, None), true)]
    #[case::timestamp_us(
        DataType::Timestamp(ArrowTimeUnit::Microsecond, Some(Arc::from("UTC"))),
        true
    )]
    #[case::time32_s(DataType::Time32(ArrowTimeUnit::Second), true)]
    #[case::time64_ns(DataType::Time64(ArrowTimeUnit::Nanosecond), true)]
    // Unsupported types
    #[case::list(
        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
        false
    )]
    #[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()
    ), false)]
    // Dictionary types - should be supported if value type is supported
    #[case::dict_utf8(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
        true
    )]
    #[case::dict_int32(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
        true
    )]
    #[case::dict_unsupported(
        DataType::Dictionary(
            Box::new(DataType::UInt32),
            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
        ),
        false
    )]
    fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
        assert_eq!(supported_data_types(&data_type), expected);
    }
781
782    #[rstest]
783    fn test_can_be_pushed_down_column_supported(test_schema: Schema) {
784        let col_expr = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
785
786        assert!(can_be_pushed_down_impl(&col_expr, &test_schema));
787    }
788
789    #[rstest]
790    fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
791        let col_expr =
792            Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
793
794        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
795    }
796
797    #[rstest]
798    fn test_can_be_pushed_down_column_not_found(test_schema: Schema) {
799        let col_expr = Arc::new(df_expr::Column::new("nonexistent", 99)) as Arc<dyn PhysicalExpr>;
800
801        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
802    }
803
804    #[rstest]
805    fn test_can_be_pushed_down_literal_supported(test_schema: Schema) {
806        let lit_expr =
807            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
808
809        assert!(can_be_pushed_down_impl(&lit_expr, &test_schema));
810    }
811
812    #[rstest]
813    fn test_can_be_pushed_down_literal_unsupported(test_schema: Schema) {
814        // Use a simpler unsupported type - Duration is not supported
815        let unsupported_literal = ScalarValue::DurationSecond(Some(42));
816        let lit_expr =
817            Arc::new(df_expr::Literal::new(unsupported_literal)) as Arc<dyn PhysicalExpr>;
818
819        assert!(!can_be_pushed_down_impl(&lit_expr, &test_schema));
820    }
821
822    #[rstest]
823    fn test_can_be_pushed_down_binary_supported(test_schema: Schema) {
824        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
825        let right =
826            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
827        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
828            as Arc<dyn PhysicalExpr>;
829
830        assert!(can_be_pushed_down_impl(&binary_expr, &test_schema));
831    }
832
833    #[rstest]
834    fn test_can_be_pushed_down_binary_unsupported_operator(test_schema: Schema) {
835        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
836        let right =
837            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
838        let binary_expr = Arc::new(df_expr::BinaryExpr::new(
839            left,
840            DFOperator::AtQuestion,
841            right,
842        )) as Arc<dyn PhysicalExpr>;
843
844        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
845    }
846
847    #[rstest]
848    fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
849        let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
850        let right =
851            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
852        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
853            as Arc<dyn PhysicalExpr>;
854
855        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
856    }
857
858    #[rstest]
859    fn test_can_be_pushed_down_like_supported(test_schema: Schema) {
860        let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
861        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
862            "test%".to_string(),
863        )))) as Arc<dyn PhysicalExpr>;
864        let like_expr =
865            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
866
867        assert!(can_be_pushed_down_impl(&like_expr, &test_schema));
868    }
869
870    #[rstest]
871    fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
872        let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
873        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
874            "test%".to_string(),
875        )))) as Arc<dyn PhysicalExpr>;
876        let like_expr =
877            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;
878
879        assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
880    }
881
    // https://github.com/vortex-data/vortex/issues/6211
    /// Regression test: queries that cast an integer column to a string must
    /// work whether the cast appears in the projection, the filter, or both.
    /// Runs end-to-end through a real SQL session against a Vortex file.
    #[tokio::test]
    async fn test_cast_int_to_string() -> anyhow::Result<()> {
        let ctx = TestSessionContext::default();

        // Write a one-row, one-column (id = 1) Vortex file to query against.
        // NOTE(review): writes `example.vortex` into the working directory —
        // confirm the test context isolates the cwd so parallel runs don't clash.
        ctx.session
            .sql(r#"copy (select 1 as id) to 'example.vortex'"#)
            .await?
            .show()
            .await?;

        // Cast in the projection, with a separate pushable filter.
        ctx.session
            .sql(r#"select cast(id as string) as sid from 'example.vortex' where id > 0"#)
            .await?
            .show()
            .await?;

        // Cast inside the filter predicate.
        ctx.session
            .sql(r#"select id from 'example.vortex' where cast (id as string) == '1'"#)
            .await?
            .show()
            .await?;

        // The failing case from the issue: the bare cast-to-string projection
        // gets pushed into the scan.
        ctx.session
            .sql(r#"select cast(id as string) from 'example.vortex'"#)
            .await?
            .collect()
            .await?;

        Ok(())
    }
914
915    /// Test that applying a CASE expression to an Arrow RecordBatch using DataFusion
916    /// matches the result of applying the converted Vortex expression.
917    #[test]
918    fn test_case_when_datafusion_vortex_equivalence() {
919        use datafusion::arrow::array::Int32Array;
920        use datafusion::arrow::array::RecordBatch;
921        use datafusion_physical_expr::expressions::CaseExpr;
922        use vortex::VortexSessionDefault;
923        use vortex::array::ArrayRef;
924        use vortex::array::Canonical;
925        use vortex::array::VortexSessionExecute as _;
926        use vortex::array::arrow::FromArrowArray;
927        use vortex::session::VortexSession;
928
929        // Create test data
930        let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
931        let schema = Arc::new(Schema::new(vec![Field::new(
932            "value",
933            DataType::Int32,
934            false,
935        )]));
936        let batch = RecordBatch::try_new(schema, vec![values]).unwrap();
937
938        // Build a DataFusion CASE expression:
939        // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
940        let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
941        let lit_10 =
942            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
943        let lit_5 =
944            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
945        let lit_100 =
946            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
947        let lit_50 =
948            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
949        let lit_0 =
950            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;
951
952        // WHEN value > 10 THEN 100
953        let when1 = Arc::new(df_expr::BinaryExpr::new(
954            col_value.clone(),
955            DFOperator::Gt,
956            lit_10,
957        )) as Arc<dyn PhysicalExpr>;
958        // WHEN value > 5 THEN 50
959        let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
960            as Arc<dyn PhysicalExpr>;
961
962        let case_expr =
963            CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();
964
965        // Apply DataFusion expression
966        let df_result = case_expr.evaluate(&batch).unwrap();
967        let df_array = df_result.into_array(batch.num_rows()).unwrap();
968
969        // Convert to Vortex expression
970        let expr_convertor = DefaultExpressionConvertor::default();
971        let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();
972
973        // Convert batch to Vortex array
974        let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();
975
976        // Apply Vortex expression
977        let session = VortexSession::default();
978        let mut ctx = session.create_execution_ctx();
979        let vortex_result = vortex_array
980            .apply(&vortex_expr)
981            .unwrap()
982            .execute::<Canonical>(&mut ctx)
983            .unwrap();
984
985        // Convert back to Arrow for comparison
986        let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();
987
988        // Convert DataFusion result to Vec for comparison
989        let df_as_arrow: Vec<i32> = df_array
990            .as_any()
991            .downcast_ref::<Int32Array>()
992            .unwrap()
993            .values()
994            .to_vec();
995
996        // Compare results
997        // Expected: [0, 0, 50, 100, 100] for values [1, 5, 10, 15, 20]
998        // value=1: not > 10, not > 5 -> ELSE 0
999        // value=5: not > 10, not > 5 -> ELSE 0
1000        // value=10: not > 10, > 5 -> 50
1001        // value=15: > 10 -> 100
1002        // value=20: > 10 -> 100
1003        assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
1004        assert_eq!(vortex_as_arrow, df_as_arrow);
1005    }
1006}