1use std::sync::Arc;
5
6use arrow_schema::DataType;
7use arrow_schema::Schema;
8use datafusion_common::Result as DFResult;
9use datafusion_common::exec_datafusion_err;
10use datafusion_common::tree_node::TreeNode;
11use datafusion_common::tree_node::TreeNodeRecursion;
12use datafusion_expr::Operator as DFOperator;
13use datafusion_functions::core::getfield::GetFieldFunc;
14use datafusion_physical_expr::PhysicalExpr;
15use datafusion_physical_expr::ScalarFunctionExpr;
16use datafusion_physical_expr::projection::ProjectionExpr;
17use datafusion_physical_expr::projection::ProjectionExprs;
18use datafusion_physical_expr::utils::collect_columns;
19use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
20use datafusion_physical_plan::expressions as df_expr;
21use itertools::Itertools;
22use vortex::dtype::DType;
23use vortex::dtype::Nullability;
24use vortex::dtype::arrow::FromArrowType;
25use vortex::expr::Expression;
26use vortex::expr::and_collect;
27use vortex::expr::cast;
28use vortex::expr::get_item;
29use vortex::expr::is_null;
30use vortex::expr::list_contains;
31use vortex::expr::lit;
32use vortex::expr::nested_case_when;
33use vortex::expr::not;
34use vortex::expr::pack;
35use vortex::expr::root;
36use vortex::scalar::Scalar;
37use vortex::scalar_fn::ScalarFnVTableExt;
38use vortex::scalar_fn::fns::binary::Binary;
39use vortex::scalar_fn::fns::like::Like;
40use vortex::scalar_fn::fns::like::LikeOptions;
41use vortex::scalar_fn::fns::operators::Operator;
42
43use crate::convert::FromDataFusion;
44
/// Result of splitting a DataFusion projection into a part evaluated inside
/// the vortex scan and a part DataFusion must still apply afterwards.
pub struct ProcessedProjection {
    // Expression evaluated by the vortex scan (a `pack` of named columns /
    // converted expressions).
    pub scan_projection: Expression,
    // Remaining projection DataFusion applies on top of the scan output.
    pub leftover_projection: ProjectionExprs,
}
50
51pub(crate) fn make_vortex_predicate(
53 expr_convertor: &dyn ExpressionConvertor,
54 predicate: &[Arc<dyn PhysicalExpr>],
55) -> DFResult<Option<Expression>> {
56 let exprs = predicate
57 .iter()
58 .map(|e| expr_convertor.convert(e.as_ref()))
59 .collect::<DFResult<Vec<_>>>()?;
60
61 Ok(and_collect(exprs))
62}
63
64pub trait ExpressionConvertor: Send + Sync {
66 fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
68
69 fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;
71
72 fn split_projection(
75 &self,
76 source_projection: ProjectionExprs,
77 input_schema: &Schema,
78 output_schema: &Schema,
79 ) -> DFResult<ProcessedProjection>;
80
81 fn no_pushdown_projection(
84 &self,
85 source_projection: ProjectionExprs,
86 input_schema: &Schema,
87 ) -> DFResult<ProcessedProjection> {
88 let column_indices = source_projection.column_indices();
90
91 let scan_columns: Vec<(String, Expression)> = column_indices
93 .into_iter()
94 .map(|idx| {
95 let field = input_schema.field(idx);
96 let name = field.name().clone();
97 (name.clone(), get_item(name, root()))
98 })
99 .collect();
100
101 Ok(ProcessedProjection {
102 scan_projection: pack(scan_columns, Nullability::NonNullable),
103 leftover_projection: source_projection,
104 })
105 }
106}
107
/// Default [`ExpressionConvertor`] implementation covering the standard
/// DataFusion physical expression node types.
#[derive(Default)]
pub struct DefaultExpressionConvertor {}
111
112impl DefaultExpressionConvertor {
113 fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
115 if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
116 {
117 let (source_expr, field_names) = get_field_fn
122 .args()
123 .split_first()
124 .ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;
125
126 let mut result = self.convert(source_expr.as_ref())?;
127 for expr in field_names {
128 let field_name = expr
129 .as_any()
130 .downcast_ref::<df_expr::Literal>()
131 .ok_or_else(|| exec_datafusion_err!("get_field field name must be a literal"))?
132 .value()
133 .try_as_str()
134 .flatten()
135 .ok_or_else(|| {
136 exec_datafusion_err!("get_field field name must be a UTF-8 string")
137 })?;
138 result = get_item(field_name.to_string(), result);
139 }
140 return Ok(result);
141 }
142
143 Err(exec_datafusion_err!(
144 "Unsupported ScalarFunctionExpr: {}",
145 scalar_fn.name()
146 ))
147 }
148
149 fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
151 if case_expr.expr().is_some() {
158 return Err(exec_datafusion_err!(
159 "CASE expr WHEN form is not yet supported, only searched CASE is supported"
160 ));
161 }
162
163 let when_then_pairs = case_expr.when_then_expr();
164 if when_then_pairs.is_empty() {
165 return Err(exec_datafusion_err!(
166 "CASE expression must have at least one WHEN clause"
167 ));
168 }
169
170 let mut pairs = Vec::with_capacity(when_then_pairs.len());
172 for (when_expr, then_expr) in when_then_pairs {
173 let condition = self.convert(when_expr.as_ref())?;
174 let value = self.convert(then_expr.as_ref())?;
175 pairs.push((condition, value));
176 }
177
178 let else_value = case_expr
180 .else_expr()
181 .map(|e| self.convert(e.as_ref()))
182 .transpose()?;
183
184 Ok(nested_case_when(pairs, else_value))
186 }
187}
188
189impl ExpressionConvertor for DefaultExpressionConvertor {
190 fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
191 can_be_pushed_down_impl(expr, schema)
192 }
193
194 fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
195 if let Some(binary_expr) = df.as_any().downcast_ref::<df_expr::BinaryExpr>() {
198 let left = self.convert(binary_expr.left().as_ref())?;
199 let right = self.convert(binary_expr.right().as_ref())?;
200 let operator = try_operator_from_df(binary_expr.op())?;
201
202 return Ok(Binary.new_expr(operator, [left, right]));
203 }
204
205 if let Some(col_expr) = df.as_any().downcast_ref::<df_expr::Column>() {
206 return Ok(get_item(col_expr.name().to_owned(), root()));
207 }
208
209 if let Some(like) = df.as_any().downcast_ref::<df_expr::LikeExpr>() {
210 let child = self.convert(like.expr().as_ref())?;
211 let pattern = self.convert(like.pattern().as_ref())?;
212 return Ok(Like.new_expr(
213 LikeOptions {
214 negated: like.negated(),
215 case_insensitive: like.case_insensitive(),
216 },
217 [child, pattern],
218 ));
219 }
220
221 if let Some(literal) = df.as_any().downcast_ref::<df_expr::Literal>() {
222 let value = Scalar::from_df(literal.value());
223 return Ok(lit(value));
224 }
225
226 if let Some(cast_expr) = df.as_any().downcast_ref::<df_expr::CastExpr>() {
227 let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable));
228 let child = self.convert(cast_expr.expr().as_ref())?;
229 return Ok(cast(child, cast_dtype));
230 }
231
232 if let Some(cast_col_expr) = df.as_any().downcast_ref::<df_expr::CastColumnExpr>() {
233 let target = cast_col_expr.target_field();
234
235 let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into()));
236 let child = self.convert(cast_col_expr.expr().as_ref())?;
237 return Ok(cast(child, target_dtype));
238 }
239
240 if let Some(is_null_expr) = df.as_any().downcast_ref::<df_expr::IsNullExpr>() {
241 let arg = self.convert(is_null_expr.arg().as_ref())?;
242 return Ok(is_null(arg));
243 }
244
245 if let Some(is_not_null_expr) = df.as_any().downcast_ref::<df_expr::IsNotNullExpr>() {
246 let arg = self.convert(is_not_null_expr.arg().as_ref())?;
247 return Ok(not(is_null(arg)));
248 }
249
250 if let Some(in_list) = df.as_any().downcast_ref::<df_expr::InListExpr>() {
251 let value = self.convert(in_list.expr().as_ref())?;
252 let list_elements: Vec<_> = in_list
253 .list()
254 .iter()
255 .map(|e| {
256 if let Some(lit) = e.as_any().downcast_ref::<df_expr::Literal>() {
257 Ok(Scalar::from_df(lit.value()))
258 } else {
259 Err(exec_datafusion_err!("Failed to cast sub-expression"))
260 }
261 })
262 .try_collect()?;
263
264 let list = Scalar::list(
265 list_elements[0].dtype().clone(),
266 list_elements,
267 Nullability::Nullable,
268 );
269 let expr = list_contains(lit(list), value);
270
271 return Ok(if in_list.negated() { not(expr) } else { expr });
272 }
273
274 if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
275 return self.try_convert_scalar_function(scalar_fn);
276 }
277
278 if let Some(case_expr) = df.as_any().downcast_ref::<df_expr::CaseExpr>() {
279 return self.try_convert_case_expr(case_expr);
280 }
281
282 Err(exec_datafusion_err!(
283 "Couldn't convert DataFusion physical {df} expression to a vortex expression"
284 ))
285 }
286
287 fn split_projection(
288 &self,
289 source_projection: ProjectionExprs,
290 input_schema: &Schema,
291 output_schema: &Schema,
292 ) -> DFResult<ProcessedProjection> {
293 let mut scan_projection = vec![];
294 let mut leftover_projection: Vec<ProjectionExpr> = vec![];
295
296 for projection_expr in source_projection.iter() {
297 let r = projection_expr.expr.apply(|node| {
298 if let Some(scalar_fn_expr) = node.as_any().downcast_ref::<ScalarFunctionExpr>()
300 && !can_scalar_fn_be_pushed_down(scalar_fn_expr)
301 {
302 scan_projection.extend(
303 collect_columns(node)
304 .into_iter()
305 .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
306 );
307
308 leftover_projection.push(projection_expr.clone());
309 return Ok(TreeNodeRecursion::Stop);
310 }
311
312 if let Some(binary_expr) = node.as_any().downcast_ref::<df_expr::BinaryExpr>()
315 && binary_expr.op().is_numerical_operators()
316 && (is_decimal(&binary_expr.left().data_type(input_schema)?)
317 && is_decimal(&binary_expr.right().data_type(input_schema)?))
318 {
319 scan_projection.extend(
320 collect_columns(node)
321 .into_iter()
322 .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
323 );
324
325 leftover_projection.push(projection_expr.clone());
326 return Ok(TreeNodeRecursion::Stop);
327 }
328
329 Ok(TreeNodeRecursion::Continue)
330 })?;
331
332 if matches!(r, TreeNodeRecursion::Continue) {
334 scan_projection.push((
335 projection_expr.alias.clone(),
336 self.convert(projection_expr.expr.as_ref())?,
337 ));
338 leftover_projection.push(ProjectionExpr {
339 expr: Arc::new(df_expr::Column::new_with_schema(
340 projection_expr.alias.as_str(),
341 output_schema,
342 )?),
343 alias: projection_expr.alias.clone(),
344 });
345 }
346 }
347
348 Ok(ProcessedProjection {
349 scan_projection: pack(scan_projection, Nullability::NonNullable),
350 leftover_projection: leftover_projection.into(),
351 })
352 }
353}
354
/// Maps a DataFusion binary operator onto its vortex [`Operator`] equivalent.
///
/// The match is intentionally exhaustive (no `_` arm) so that a new
/// `DFOperator` variant added upstream forces a compile-time decision here.
///
/// # Errors
/// Returns an error for operators with no vortex equivalent; such
/// expressions cannot be pushed down.
fn try_operator_from_df(value: &DFOperator) -> DFResult<Operator> {
    match value {
        DFOperator::Eq => Ok(Operator::Eq),
        DFOperator::NotEq => Ok(Operator::NotEq),
        DFOperator::Lt => Ok(Operator::Lt),
        DFOperator::LtEq => Ok(Operator::Lte),
        DFOperator::Gt => Ok(Operator::Gt),
        DFOperator::GtEq => Ok(Operator::Gte),
        DFOperator::And => Ok(Operator::And),
        DFOperator::Or => Ok(Operator::Or),
        DFOperator::Plus => Ok(Operator::Add),
        DFOperator::Minus => Ok(Operator::Sub),
        DFOperator::Multiply => Ok(Operator::Mul),
        DFOperator::Divide => Ok(Operator::Div),
        // Everything below has no vortex counterpart today.
        DFOperator::IsDistinctFrom
        | DFOperator::IsNotDistinctFrom
        | DFOperator::RegexMatch
        | DFOperator::RegexIMatch
        | DFOperator::RegexNotMatch
        | DFOperator::RegexNotIMatch
        | DFOperator::LikeMatch
        | DFOperator::ILikeMatch
        | DFOperator::NotLikeMatch
        | DFOperator::NotILikeMatch
        | DFOperator::BitwiseAnd
        | DFOperator::BitwiseOr
        | DFOperator::BitwiseXor
        | DFOperator::BitwiseShiftRight
        | DFOperator::BitwiseShiftLeft
        | DFOperator::StringConcat
        | DFOperator::AtArrow
        | DFOperator::ArrowAt
        | DFOperator::Modulo
        | DFOperator::Arrow
        | DFOperator::LongArrow
        | DFOperator::HashArrow
        | DFOperator::HashLongArrow
        | DFOperator::AtAt
        | DFOperator::IntegerDivide
        | DFOperator::HashMinus
        | DFOperator::AtQuestion
        | DFOperator::Question
        | DFOperator::QuestionAnd
        | DFOperator::QuestionPipe => {
            tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
            Err(exec_datafusion_err!(
                "Unsupported datafusion operator {value}"
            ))
        }
    }
}
406
/// Checks whether a DataFusion physical expression can be evaluated entirely
/// by the vortex scan.
///
/// Mirrors the node types handled by [`DefaultExpressionConvertor::convert`],
/// additionally validating that referenced columns and literal values use
/// supported data types.
fn can_be_pushed_down_impl(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
    // Dynamic expressions (updated at runtime, e.g. dynamic filters) can't be
    // frozen into a static vortex expression.
    if is_dynamic_physical_expr(df_expr) {
        return false;
    }

    let expr = df_expr.as_any();
    if let Some(binary) = expr.downcast_ref::<df_expr::BinaryExpr>() {
        can_binary_be_pushed_down(binary, schema)
    } else if let Some(col) = expr.downcast_ref::<df_expr::Column>() {
        // Unknown columns and unsupported column types block pushdown.
        schema
            .field_with_name(col.name())
            .ok()
            .is_some_and(|field| supported_data_types(field.data_type()))
    } else if let Some(like) = expr.downcast_ref::<df_expr::LikeExpr>() {
        can_be_pushed_down_impl(like.expr(), schema)
            && can_be_pushed_down_impl(like.pattern(), schema)
    } else if let Some(lit) = expr.downcast_ref::<df_expr::Literal>() {
        supported_data_types(&lit.value().data_type())
    } else if let Some(cast_expr) = expr.downcast_ref::<df_expr::CastExpr>() {
        // Casts only check that the inner expression is a convertible shape;
        // the cast target type itself is not validated here.
        is_convertible_expr(cast_expr.expr())
    } else if let Some(cast_col_expr) = expr.downcast_ref::<df_expr::CastColumnExpr>() {
        is_convertible_expr(cast_col_expr.expr())
    } else if let Some(is_null) = expr.downcast_ref::<df_expr::IsNullExpr>() {
        can_be_pushed_down_impl(is_null.arg(), schema)
    } else if let Some(is_not_null) = expr.downcast_ref::<df_expr::IsNotNullExpr>() {
        can_be_pushed_down_impl(is_not_null.arg(), schema)
    } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
        // Both the probe value and every list element must be pushable.
        can_be_pushed_down_impl(in_list.expr(), schema)
            && in_list
                .list()
                .iter()
                .all(|e| can_be_pushed_down_impl(e, schema))
    } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
        can_scalar_fn_be_pushed_down(scalar_fn)
    } else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
        can_case_be_pushed_down(case_expr, schema)
    } else {
        tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
        false
    }
}
452
453fn is_convertible_expr(df_expr: &Arc<dyn PhysicalExpr>) -> bool {
457 let expr = df_expr.as_any();
458
459 expr.downcast_ref::<df_expr::BinaryExpr>().is_some()
461 || expr.downcast_ref::<df_expr::Column>().is_some()
462 || expr.downcast_ref::<df_expr::LikeExpr>().is_some()
463 || expr.downcast_ref::<df_expr::Literal>().is_some()
464 || expr
465 .downcast_ref::<df_expr::CastExpr>()
466 .is_some_and(|e| is_convertible_expr(e.expr()))
467 || expr
468 .downcast_ref::<df_expr::CastColumnExpr>()
469 .is_some_and(|e| is_convertible_expr(e.expr()))
470 || expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
471 || expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
472 || expr.downcast_ref::<df_expr::InListExpr>().is_some()
473 || expr
474 .downcast_ref::<ScalarFunctionExpr>()
475 .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
476}
477
478fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
479 let is_op_supported = try_operator_from_df(binary.op()).is_ok();
480 is_op_supported
481 && can_be_pushed_down_impl(binary.left(), schema)
482 && can_be_pushed_down_impl(binary.right(), schema)
483}
484
485fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool {
486 if case_expr.expr().is_some() {
489 return false;
490 }
491
492 for (when_expr, then_expr) in case_expr.when_then_expr() {
494 if !can_be_pushed_down_impl(when_expr, schema)
495 || !can_be_pushed_down_impl(then_expr, schema)
496 {
497 return false;
498 }
499 }
500
501 if let Some(else_expr) = case_expr.else_expr()
503 && !can_be_pushed_down_impl(else_expr, schema)
504 {
505 return false;
506 }
507
508 true
509}
510
511fn supported_data_types(dt: &DataType) -> bool {
512 use DataType::*;
513
514 if let Dictionary(_, value_type) = dt {
516 return supported_data_types(value_type.as_ref());
517 }
518
519 let is_supported = dt.is_null()
520 || dt.is_numeric()
521 || matches!(
522 dt,
523 Boolean
524 | Utf8
525 | LargeUtf8
526 | Utf8View
527 | Binary
528 | LargeBinary
529 | BinaryView
530 | Date32
531 | Date64
532 | Timestamp(_, _)
533 | Time32(_)
534 | Time64(_)
535 );
536
537 if !is_supported {
538 tracing::debug!("DataFusion data type {dt:?} is not supported");
539 }
540
541 is_supported
542}
543
/// Only `get_field` scalar functions can be pushed down; every other scalar
/// function is evaluated by DataFusion after the scan.
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
    ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
}
549
550fn is_decimal(dt: &DataType) -> bool {
552 matches!(
553 dt,
554 DataType::Decimal32(_, _)
555 | DataType::Decimal64(_, _)
556 | DataType::Decimal128(_, _)
557 | DataType::Decimal256(_, _)
558 )
559}
560
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::Schema;
    use arrow_schema::TimeUnit as ArrowTimeUnit;
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator as DFOperator;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_plan::expressions as df_expr;
    use insta::assert_snapshot;
    use rstest::rstest;

    use super::*;
    use crate::common_tests::TestSessionContext;

    /// Schema shared by the pushdown tests: supported scalar columns plus
    /// one unsupported list column.
    #[rstest::fixture]
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("active", DataType::Boolean, false),
            Field::new(
                "created_at",
                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                true,
            ),
            Field::new(
                "unsupported_list",
                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
                true,
            ),
        ])
    }

    #[test]
    fn test_make_vortex_predicate_empty() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_make_vortex_predicate_single() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
        let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn test_make_vortex_predicate_multiple() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
        let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
        let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
        assert!(result.is_some());
    }

    // Case labels fixed: `minus`/`multiply`/`divide` were previously all
    // mislabeled as `plus` (copy-paste), which garbled generated test names.
    #[rstest]
    #[case::eq(DFOperator::Eq, Operator::Eq)]
    #[case::not_eq(DFOperator::NotEq, Operator::NotEq)]
    #[case::lt(DFOperator::Lt, Operator::Lt)]
    #[case::lte(DFOperator::LtEq, Operator::Lte)]
    #[case::gt(DFOperator::Gt, Operator::Gt)]
    #[case::gte(DFOperator::GtEq, Operator::Gte)]
    #[case::and(DFOperator::And, Operator::And)]
    #[case::or(DFOperator::Or, Operator::Or)]
    #[case::plus(DFOperator::Plus, Operator::Add)]
    #[case::minus(DFOperator::Minus, Operator::Sub)]
    #[case::multiply(DFOperator::Multiply, Operator::Mul)]
    #[case::divide(DFOperator::Divide, Operator::Div)]
    fn test_operator_conversion_supported(
        #[case] df_op: DFOperator,
        #[case] expected_vortex_op: Operator,
    ) {
        let result = try_operator_from_df(&df_op).unwrap();
        assert_eq!(result, expected_vortex_op);
    }

    #[rstest]
    #[case::modulo(DFOperator::Modulo)]
    #[case::bitwise_and(DFOperator::BitwiseAnd)]
    #[case::regex_match(DFOperator::RegexMatch)]
    #[case::like_match(DFOperator::LikeMatch)]
    fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
        let result = try_operator_from_df(&df_op);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Unsupported datafusion operator")
        );
    }

    #[test]
    fn test_expr_from_df_column() {
        let col_expr = df_expr::Column::new("test_column", 0);
        let result = DefaultExpressionConvertor::default()
            .convert(&col_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @r"
        vortex.get_item(test_column)
        └── input: vortex.root()
        ");
    }

    #[test]
    fn test_expr_from_df_literal() {
        let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
        let result = DefaultExpressionConvertor::default()
            .convert(&literal_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
    }

    #[test]
    fn test_expr_from_df_binary() {
        let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);

        let result = DefaultExpressionConvertor::default()
            .convert(&binary_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @r"
        vortex.binary(=)
        ├── lhs: vortex.get_item(left)
        │   └── input: vortex.root()
        └── rhs: vortex.literal(42i32)
        ");
    }

    #[rstest]
    #[case::like_normal(false, false)]
    #[case::like_negated(true, false)]
    #[case::like_case_insensitive(false, true)]
    #[case::like_negated_case_insensitive(true, true)]
    fn test_expr_from_df_like(#[case] negated: bool, #[case] case_insensitive: bool) {
        let expr = Arc::new(df_expr::Column::new("text_col", 0)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);

        let result = DefaultExpressionConvertor::default()
            .convert(&like_expr)
            .unwrap();
        let like_opts = result.as_::<Like>();
        assert_eq!(
            like_opts,
            &LikeOptions {
                negated,
                case_insensitive
            }
        );
    }

    #[rstest]
    #[case::null(DataType::Null, true)]
    #[case::boolean(DataType::Boolean, true)]
    #[case::int8(DataType::Int8, true)]
    #[case::int16(DataType::Int16, true)]
    #[case::int32(DataType::Int32, true)]
    #[case::int64(DataType::Int64, true)]
    #[case::uint8(DataType::UInt8, true)]
    #[case::uint16(DataType::UInt16, true)]
    #[case::uint32(DataType::UInt32, true)]
    #[case::uint64(DataType::UInt64, true)]
    #[case::float32(DataType::Float32, true)]
    #[case::float64(DataType::Float64, true)]
    #[case::utf8(DataType::Utf8, true)]
    #[case::utf8_view(DataType::Utf8View, true)]
    #[case::binary(DataType::Binary, true)]
    #[case::binary_view(DataType::BinaryView, true)]
    #[case::date32(DataType::Date32, true)]
    #[case::date64(DataType::Date64, true)]
    #[case::timestamp_ms(DataType::Timestamp(ArrowTimeUnit::Millisecond, None), true)]
    #[case::timestamp_us(
        DataType::Timestamp(ArrowTimeUnit::Microsecond, Some(Arc::from("UTC"))),
        true
    )]
    #[case::time32_s(DataType::Time32(ArrowTimeUnit::Second), true)]
    #[case::time64_ns(DataType::Time64(ArrowTimeUnit::Nanosecond), true)]
    #[case::list(
        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
        false
    )]
    #[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()
    ), false)]
    #[case::dict_utf8(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
        true
    )]
    #[case::dict_int32(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
        true
    )]
    #[case::dict_unsupported(
        DataType::Dictionary(
            Box::new(DataType::UInt32),
            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
        ),
        false
    )]
    fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
        assert_eq!(supported_data_types(&data_type), expected);
    }

    #[rstest]
    fn test_can_be_pushed_down_column_supported(test_schema: Schema) {
        let col_expr = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&col_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
        let col_expr =
            Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_column_not_found(test_schema: Schema) {
        let col_expr = Arc::new(df_expr::Column::new("nonexistent", 99)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_literal_supported(test_schema: Schema) {
        let lit_expr =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&lit_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_literal_unsupported(test_schema: Schema) {
        let unsupported_literal = ScalarValue::DurationSecond(Some(42));
        let lit_expr =
            Arc::new(df_expr::Literal::new(unsupported_literal)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&lit_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_binary_supported(test_schema: Schema) {
        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
            as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&binary_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_binary_unsupported_operator(test_schema: Schema) {
        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = Arc::new(df_expr::BinaryExpr::new(
            left,
            DFOperator::AtQuestion,
            right,
        )) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
        let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
            as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_like_supported(test_schema: Schema) {
        let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr =
            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&like_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
        let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr =
            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
    }

    /// End-to-end smoke test: casts survive a full write/scan round trip.
    #[tokio::test]
    async fn test_cast_int_to_string() -> anyhow::Result<()> {
        let ctx = TestSessionContext::default();

        ctx.session
            .sql(r#"copy (select 1 as id) to 'example.vortex'"#)
            .await?
            .show()
            .await?;

        ctx.session
            .sql(r#"select cast(id as string) as sid from 'example.vortex' where id > 0"#)
            .await?
            .show()
            .await?;

        ctx.session
            .sql(r#"select id from 'example.vortex' where cast (id as string) == '1'"#)
            .await?
            .show()
            .await?;

        ctx.session
            .sql(r#"select cast(id as string) from 'example.vortex'"#)
            .await?
            .collect()
            .await?;

        Ok(())
    }

    /// Verifies that a converted CASE expression produces the same results
    /// when evaluated by vortex as the original does in DataFusion.
    #[test]
    fn test_case_when_datafusion_vortex_equivalence() {
        use datafusion::arrow::array::Int32Array;
        use datafusion::arrow::array::RecordBatch;
        use datafusion_physical_expr::expressions::CaseExpr;
        use vortex::VortexSessionDefault;
        use vortex::array::ArrayRef;
        use vortex::array::Canonical;
        use vortex::array::VortexSessionExecute as _;
        use vortex::array::arrow::FromArrowArray;
        use vortex::session::VortexSession;

        let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int32,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![values]).unwrap();

        let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
        let lit_10 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
        let lit_5 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
        let lit_100 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
        let lit_50 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
        let lit_0 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;

        // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
        let when1 = Arc::new(df_expr::BinaryExpr::new(
            col_value.clone(),
            DFOperator::Gt,
            lit_10,
        )) as Arc<dyn PhysicalExpr>;
        let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
            as Arc<dyn PhysicalExpr>;

        let case_expr =
            CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();

        let df_result = case_expr.evaluate(&batch).unwrap();
        let df_array = df_result.into_array(batch.num_rows()).unwrap();

        let expr_convertor = DefaultExpressionConvertor::default();
        let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();

        let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();

        let session = VortexSession::default();
        let mut ctx = session.create_execution_ctx();
        let vortex_result = vortex_array
            .apply(&vortex_expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap();

        let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();

        let df_as_arrow: Vec<i32> = df_array
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap()
            .values()
            .to_vec();

        assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
        assert_eq!(vortex_as_arrow, df_as_arrow);
    }
}
1006}