1use std::sync::Arc;
5
6use arrow_schema::DataType;
7use arrow_schema::Schema;
8use datafusion_common::Result as DFResult;
9use datafusion_common::exec_datafusion_err;
10use datafusion_common::tree_node::TreeNode;
11use datafusion_common::tree_node::TreeNodeRecursion;
12use datafusion_expr::Operator as DFOperator;
13use datafusion_functions::core::getfield::GetFieldFunc;
14use datafusion_physical_expr::PhysicalExpr;
15use datafusion_physical_expr::ScalarFunctionExpr;
16use datafusion_physical_expr::projection::ProjectionExpr;
17use datafusion_physical_expr::projection::ProjectionExprs;
18use datafusion_physical_expr::utils::collect_columns;
19use datafusion_physical_expr_common::physical_expr::is_dynamic_physical_expr;
20use datafusion_physical_plan::expressions as df_expr;
21use itertools::Itertools;
22use vortex::dtype::DType;
23use vortex::dtype::Nullability;
24use vortex::dtype::arrow::FromArrowType;
25use vortex::expr::Expression;
26use vortex::expr::and_collect;
27use vortex::expr::cast;
28use vortex::expr::get_item;
29use vortex::expr::is_not_null;
30use vortex::expr::is_null;
31use vortex::expr::list_contains;
32use vortex::expr::lit;
33use vortex::expr::nested_case_when;
34use vortex::expr::not;
35use vortex::expr::pack;
36use vortex::expr::root;
37use vortex::scalar::Scalar;
38use vortex::scalar_fn::ScalarFnVTableExt;
39use vortex::scalar_fn::fns::binary::Binary;
40use vortex::scalar_fn::fns::like::Like;
41use vortex::scalar_fn::fns::like::LikeOptions;
42use vortex::scalar_fn::fns::operators::Operator;
43
44use crate::convert::FromDataFusion;
45
/// Result of splitting a DataFusion projection between the Vortex scan and DataFusion.
pub struct ProcessedProjection {
    /// Expression evaluated by the Vortex scan. The functions in this module always build it
    /// with `pack`, i.e. a non-nullable struct of named fields.
    pub scan_projection: Expression,
    /// Projection left for DataFusion; presumably applied by the caller on top of the scan's
    /// output — confirm against the call site.
    pub leftover_projection: ProjectionExprs,
}
51
52pub(crate) fn make_vortex_predicate(
54 expr_convertor: &dyn ExpressionConvertor,
55 predicate: &[Arc<dyn PhysicalExpr>],
56) -> DFResult<Option<Expression>> {
57 let exprs = predicate
58 .iter()
59 .map(|e| expr_convertor.convert(e.as_ref()))
60 .collect::<DFResult<Vec<_>>>()?;
61
62 Ok(and_collect(exprs))
63}
64
65pub trait ExpressionConvertor: Send + Sync {
67 fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
69
70 fn convert(&self, expr: &dyn PhysicalExpr) -> DFResult<Expression>;
72
73 fn split_projection(
76 &self,
77 source_projection: ProjectionExprs,
78 input_schema: &Schema,
79 output_schema: &Schema,
80 ) -> DFResult<ProcessedProjection>;
81
82 fn no_pushdown_projection(
85 &self,
86 source_projection: ProjectionExprs,
87 input_schema: &Schema,
88 ) -> DFResult<ProcessedProjection> {
89 let column_indices = source_projection.column_indices();
91
92 let scan_columns: Vec<(String, Expression)> = column_indices
94 .into_iter()
95 .map(|idx| {
96 let field = input_schema.field(idx);
97 let name = field.name().clone();
98 (name.clone(), get_item(name, root()))
99 })
100 .collect();
101
102 Ok(ProcessedProjection {
103 scan_projection: pack(scan_columns, Nullability::NonNullable),
104 leftover_projection: source_projection,
105 })
106 }
107}
108
/// The built-in [`ExpressionConvertor`] implementation. It carries no state, so
/// `DefaultExpressionConvertor::default()` is all that is needed to construct one.
#[derive(Default)]
pub struct DefaultExpressionConvertor {}
112
113impl DefaultExpressionConvertor {
114 fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
116 if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
117 {
118 let (source_expr, field_names) = get_field_fn
123 .args()
124 .split_first()
125 .ok_or_else(|| exec_datafusion_err!("get_field missing source expression"))?;
126
127 let mut result = self.convert(source_expr.as_ref())?;
128 for expr in field_names {
129 let field_name = expr
130 .downcast_ref::<df_expr::Literal>()
131 .ok_or_else(|| exec_datafusion_err!("get_field field name must be a literal"))?
132 .value()
133 .try_as_str()
134 .flatten()
135 .ok_or_else(|| {
136 exec_datafusion_err!("get_field field name must be a UTF-8 string")
137 })?;
138 result = get_item(field_name.to_string(), result);
139 }
140 return Ok(result);
141 }
142
143 Err(exec_datafusion_err!(
144 "Unsupported ScalarFunctionExpr: {}",
145 scalar_fn.name()
146 ))
147 }
148
149 fn try_convert_case_expr(&self, case_expr: &df_expr::CaseExpr) -> DFResult<Expression> {
151 if case_expr.expr().is_some() {
158 return Err(exec_datafusion_err!(
159 "CASE expr WHEN form is not yet supported, only searched CASE is supported"
160 ));
161 }
162
163 let when_then_pairs = case_expr.when_then_expr();
164 if when_then_pairs.is_empty() {
165 return Err(exec_datafusion_err!(
166 "CASE expression must have at least one WHEN clause"
167 ));
168 }
169
170 let mut pairs = Vec::with_capacity(when_then_pairs.len());
172 for (when_expr, then_expr) in when_then_pairs {
173 let condition = self.convert(when_expr.as_ref())?;
174 let value = self.convert(then_expr.as_ref())?;
175 pairs.push((condition, value));
176 }
177
178 let else_value = case_expr
180 .else_expr()
181 .map(|e| self.convert(e.as_ref()))
182 .transpose()?;
183
184 Ok(nested_case_when(pairs, else_value))
186 }
187}
188
189impl ExpressionConvertor for DefaultExpressionConvertor {
190 fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
191 can_be_pushed_down_impl(expr, schema)
192 }
193
194 fn convert(&self, df: &dyn PhysicalExpr) -> DFResult<Expression> {
195 if let Some(binary_expr) = df.downcast_ref::<df_expr::BinaryExpr>() {
198 let left = self.convert(binary_expr.left().as_ref())?;
199 let right = self.convert(binary_expr.right().as_ref())?;
200 let operator = try_operator_from_df(binary_expr.op())?;
201
202 return Ok(Binary.new_expr(operator, [left, right]));
203 }
204
205 if let Some(col_expr) = df.downcast_ref::<df_expr::Column>() {
206 return Ok(get_item(col_expr.name().to_owned(), root()));
207 }
208
209 if let Some(like) = df.downcast_ref::<df_expr::LikeExpr>() {
210 let child = self.convert(like.expr().as_ref())?;
211 let pattern = self.convert(like.pattern().as_ref())?;
212 return Ok(Like.new_expr(
213 LikeOptions {
214 negated: like.negated(),
215 case_insensitive: like.case_insensitive(),
216 },
217 [child, pattern],
218 ));
219 }
220
221 if let Some(literal) = df.downcast_ref::<df_expr::Literal>() {
222 let value = Scalar::from_df(literal.value());
223 return Ok(lit(value));
224 }
225
226 if let Some(cast_expr) = df.downcast_ref::<df_expr::CastExpr>() {
227 let cast_dtype = DType::from_arrow(cast_expr.target_field().as_ref());
228 let child = self.convert(cast_expr.expr().as_ref())?;
229 return Ok(cast(child, cast_dtype));
230 }
231
232 if let Some(is_null_expr) = df.downcast_ref::<df_expr::IsNullExpr>() {
233 let arg = self.convert(is_null_expr.arg().as_ref())?;
234 return Ok(is_null(arg));
235 }
236
237 if let Some(is_not_null_expr) = df.downcast_ref::<df_expr::IsNotNullExpr>() {
238 let arg = self.convert(is_not_null_expr.arg().as_ref())?;
239 return Ok(is_not_null(arg));
240 }
241
242 if let Some(in_list) = df.downcast_ref::<df_expr::InListExpr>() {
243 let value = self.convert(in_list.expr().as_ref())?;
244 let list_elements: Vec<_> = in_list
245 .list()
246 .iter()
247 .map(|e| {
248 if let Some(lit) = e.downcast_ref::<df_expr::Literal>() {
249 Ok(Scalar::from_df(lit.value()))
250 } else {
251 Err(exec_datafusion_err!("Failed to cast sub-expression"))
252 }
253 })
254 .try_collect()?;
255
256 let list = Scalar::list(
257 list_elements[0].dtype().clone(),
258 list_elements,
259 Nullability::Nullable,
260 );
261 let expr = list_contains(lit(list), value);
262
263 return Ok(if in_list.negated() { not(expr) } else { expr });
264 }
265
266 if let Some(scalar_fn) = df.downcast_ref::<ScalarFunctionExpr>() {
267 return self.try_convert_scalar_function(scalar_fn);
268 }
269
270 if let Some(case_expr) = df.downcast_ref::<df_expr::CaseExpr>() {
271 return self.try_convert_case_expr(case_expr);
272 }
273
274 Err(exec_datafusion_err!(
275 "Couldn't convert DataFusion physical {df} expression to a vortex expression"
276 ))
277 }
278
279 fn split_projection(
280 &self,
281 source_projection: ProjectionExprs,
282 input_schema: &Schema,
283 output_schema: &Schema,
284 ) -> DFResult<ProcessedProjection> {
285 let mut scan_projection = vec![];
286 let mut leftover_projection: Vec<ProjectionExpr> = vec![];
287
288 for projection_expr in source_projection.iter() {
289 let r = projection_expr.expr.apply(|node| {
290 if let Some(scalar_fn_expr) = node.downcast_ref::<ScalarFunctionExpr>()
292 && !can_scalar_fn_be_pushed_down(scalar_fn_expr)
293 {
294 scan_projection.extend(
295 collect_columns(node)
296 .into_iter()
297 .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
298 );
299
300 leftover_projection.push(projection_expr.clone());
301 return Ok(TreeNodeRecursion::Stop);
302 }
303
304 if let Some(binary_expr) = node.downcast_ref::<df_expr::BinaryExpr>()
307 && binary_expr.op().is_numerical_operators()
308 && (is_decimal(&binary_expr.left().data_type(input_schema)?)
309 && is_decimal(&binary_expr.right().data_type(input_schema)?))
310 {
311 scan_projection.extend(
312 collect_columns(node)
313 .into_iter()
314 .map(|c| (c.name().to_string(), get_item(c.name(), root()))),
315 );
316
317 leftover_projection.push(projection_expr.clone());
318 return Ok(TreeNodeRecursion::Stop);
319 }
320
321 Ok(TreeNodeRecursion::Continue)
322 })?;
323
324 if matches!(r, TreeNodeRecursion::Continue) {
326 scan_projection.push((
327 projection_expr.alias.clone(),
328 self.convert(projection_expr.expr.as_ref())?,
329 ));
330 leftover_projection.push(ProjectionExpr {
331 expr: Arc::new(df_expr::Column::new_with_schema(
332 projection_expr.alias.as_str(),
333 output_schema,
334 )?),
335 alias: projection_expr.alias.clone(),
336 });
337 }
338 }
339
340 Ok(ProcessedProjection {
341 scan_projection: pack(scan_projection, Nullability::NonNullable),
342 leftover_projection: leftover_projection.into(),
343 })
344 }
345}
346
347fn try_operator_from_df(value: &DFOperator) -> DFResult<Operator> {
348 match value {
349 DFOperator::Eq => Ok(Operator::Eq),
350 DFOperator::NotEq => Ok(Operator::NotEq),
351 DFOperator::Lt => Ok(Operator::Lt),
352 DFOperator::LtEq => Ok(Operator::Lte),
353 DFOperator::Gt => Ok(Operator::Gt),
354 DFOperator::GtEq => Ok(Operator::Gte),
355 DFOperator::And => Ok(Operator::And),
356 DFOperator::Or => Ok(Operator::Or),
357 DFOperator::Plus => Ok(Operator::Add),
358 DFOperator::Minus => Ok(Operator::Sub),
359 DFOperator::Multiply => Ok(Operator::Mul),
360 DFOperator::Divide => Ok(Operator::Div),
361 DFOperator::IsDistinctFrom
362 | DFOperator::IsNotDistinctFrom
363 | DFOperator::RegexMatch
364 | DFOperator::RegexIMatch
365 | DFOperator::RegexNotMatch
366 | DFOperator::RegexNotIMatch
367 | DFOperator::LikeMatch
368 | DFOperator::ILikeMatch
369 | DFOperator::NotLikeMatch
370 | DFOperator::NotILikeMatch
371 | DFOperator::BitwiseAnd
372 | DFOperator::BitwiseOr
373 | DFOperator::BitwiseXor
374 | DFOperator::BitwiseShiftRight
375 | DFOperator::BitwiseShiftLeft
376 | DFOperator::StringConcat
377 | DFOperator::AtArrow
378 | DFOperator::ArrowAt
379 | DFOperator::Modulo
380 | DFOperator::Arrow
381 | DFOperator::LongArrow
382 | DFOperator::HashArrow
383 | DFOperator::HashLongArrow
384 | DFOperator::AtAt
385 | DFOperator::IntegerDivide
386 | DFOperator::HashMinus
387 | DFOperator::AtQuestion
388 | DFOperator::Question
389 | DFOperator::QuestionAnd
390 | DFOperator::QuestionPipe
391 | DFOperator::Colon => {
392 tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
393 Err(exec_datafusion_err!(
394 "Unsupported datafusion operator {value}"
395 ))
396 }
397 }
398}
399
400fn can_be_pushed_down_impl(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
401 if is_dynamic_physical_expr(expr) {
404 return false;
405 }
406
407 if let Some(binary) = expr.downcast_ref::<df_expr::BinaryExpr>() {
408 can_binary_be_pushed_down(binary, schema)
409 } else if let Some(col) = expr.downcast_ref::<df_expr::Column>() {
410 schema
411 .field_with_name(col.name())
412 .ok()
413 .is_some_and(|field| supported_data_types(field.data_type()))
414 } else if let Some(like) = expr.downcast_ref::<df_expr::LikeExpr>() {
415 can_be_pushed_down_impl(like.expr(), schema)
416 && can_be_pushed_down_impl(like.pattern(), schema)
417 } else if let Some(lit) = expr.downcast_ref::<df_expr::Literal>() {
418 supported_data_types(&lit.value().data_type())
419 } else if let Some(cast_expr) = expr.downcast_ref::<df_expr::CastExpr>() {
420 is_convertible_expr(cast_expr.expr())
422 } else if let Some(is_null) = expr.downcast_ref::<df_expr::IsNullExpr>() {
423 can_be_pushed_down_impl(is_null.arg(), schema)
424 } else if let Some(is_not_null) = expr.downcast_ref::<df_expr::IsNotNullExpr>() {
425 can_be_pushed_down_impl(is_not_null.arg(), schema)
426 } else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
427 can_be_pushed_down_impl(in_list.expr(), schema)
428 && in_list
429 .list()
430 .iter()
431 .all(|e| can_be_pushed_down_impl(e, schema))
432 } else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
433 can_scalar_fn_be_pushed_down(scalar_fn)
434 } else if let Some(case_expr) = expr.downcast_ref::<df_expr::CaseExpr>() {
435 can_case_be_pushed_down(case_expr, schema)
436 } else {
437 tracing::debug!(%expr, "DataFusion expression can't be pushed down");
438 false
439 }
440}
441
442fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
446 expr.downcast_ref::<df_expr::BinaryExpr>().is_some()
448 || expr.downcast_ref::<df_expr::Column>().is_some()
449 || expr.downcast_ref::<df_expr::LikeExpr>().is_some()
450 || expr.downcast_ref::<df_expr::Literal>().is_some()
451 || expr
452 .downcast_ref::<df_expr::CastExpr>()
453 .is_some_and(|e| is_convertible_expr(e.expr()))
454 || expr.downcast_ref::<df_expr::IsNullExpr>().is_some()
455 || expr.downcast_ref::<df_expr::IsNotNullExpr>().is_some()
456 || expr.downcast_ref::<df_expr::InListExpr>().is_some()
457 || expr
458 .downcast_ref::<ScalarFunctionExpr>()
459 .is_some_and(|sf| ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some())
460}
461
462fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
463 let is_op_supported = try_operator_from_df(binary.op()).is_ok();
464 is_op_supported
465 && can_be_pushed_down_impl(binary.left(), schema)
466 && can_be_pushed_down_impl(binary.right(), schema)
467}
468
469fn can_case_be_pushed_down(case_expr: &df_expr::CaseExpr, schema: &Schema) -> bool {
470 if case_expr.expr().is_some() {
473 return false;
474 }
475
476 for (when_expr, then_expr) in case_expr.when_then_expr() {
478 if !can_be_pushed_down_impl(when_expr, schema)
479 || !can_be_pushed_down_impl(then_expr, schema)
480 {
481 return false;
482 }
483 }
484
485 if let Some(else_expr) = case_expr.else_expr()
487 && !can_be_pushed_down_impl(else_expr, schema)
488 {
489 return false;
490 }
491
492 true
493}
494
495fn supported_data_types(dt: &DataType) -> bool {
496 use DataType::*;
497
498 if let Dictionary(_, value_type) = dt {
500 return supported_data_types(value_type.as_ref());
501 }
502
503 let is_supported = dt.is_null()
504 || dt.is_numeric()
505 || matches!(
506 dt,
507 Boolean
508 | Utf8
509 | LargeUtf8
510 | Utf8View
511 | Binary
512 | LargeBinary
513 | BinaryView
514 | Date32
515 | Date64
516 | Timestamp(_, _)
517 | Time32(_)
518 | Time64(_)
519 );
520
521 if !is_supported {
522 tracing::debug!("DataFusion data type {dt:?} is not supported");
523 }
524
525 is_supported
526}
527
528fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
531 ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
532}
533
534fn is_decimal(dt: &DataType) -> bool {
536 matches!(
537 dt,
538 DataType::Decimal32(_, _)
539 | DataType::Decimal64(_, _)
540 | DataType::Decimal128(_, _)
541 | DataType::Decimal256(_, _)
542 )
543}
544
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::Schema;
    use arrow_schema::TimeUnit as ArrowTimeUnit;
    use datafusion::arrow::array::AsArray;
    use datafusion::arrow::datatypes::Int32Type;
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator as DFOperator;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_plan::expressions as df_expr;
    use insta::assert_snapshot;
    use rstest::rstest;

    use super::*;
    use crate::common_tests::TestSessionContext;

    /// Schema mixing supported column types with one unsupported (list) column.
    #[rstest::fixture]
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("active", DataType::Boolean, false),
            Field::new(
                "created_at",
                DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                true,
            ),
            Field::new(
                "unsupported_list",
                DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
                true,
            ),
        ])
    }

    #[test]
    fn test_make_vortex_predicate_empty() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_make_vortex_predicate_single() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
        let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn test_make_vortex_predicate_multiple() {
        let expr_convertor = DefaultExpressionConvertor::default();
        let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
        let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
        let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
        assert!(result.is_some());
    }

    // Each case gets its own descriptive name (Minus/Multiply/Divide were previously all
    // labelled `plus`, which made failures misleading).
    #[rstest]
    #[case::eq(DFOperator::Eq, Operator::Eq)]
    #[case::not_eq(DFOperator::NotEq, Operator::NotEq)]
    #[case::lt(DFOperator::Lt, Operator::Lt)]
    #[case::lte(DFOperator::LtEq, Operator::Lte)]
    #[case::gt(DFOperator::Gt, Operator::Gt)]
    #[case::gte(DFOperator::GtEq, Operator::Gte)]
    #[case::and(DFOperator::And, Operator::And)]
    #[case::or(DFOperator::Or, Operator::Or)]
    #[case::plus(DFOperator::Plus, Operator::Add)]
    #[case::minus(DFOperator::Minus, Operator::Sub)]
    #[case::multiply(DFOperator::Multiply, Operator::Mul)]
    #[case::divide(DFOperator::Divide, Operator::Div)]
    fn test_operator_conversion_supported(
        #[case] df_op: DFOperator,
        #[case] expected_vortex_op: Operator,
    ) {
        let result = try_operator_from_df(&df_op).unwrap();
        assert_eq!(result, expected_vortex_op);
    }

    #[rstest]
    #[case::modulo(DFOperator::Modulo)]
    #[case::bitwise_and(DFOperator::BitwiseAnd)]
    #[case::regex_match(DFOperator::RegexMatch)]
    #[case::like_match(DFOperator::LikeMatch)]
    fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
        let result = try_operator_from_df(&df_op);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Unsupported datafusion operator")
        );
    }

    #[test]
    fn test_expr_from_df_column() {
        let col_expr = df_expr::Column::new("test_column", 0);
        let result = DefaultExpressionConvertor::default()
            .convert(&col_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @r"
        vortex.get_item(test_column)
        └── input: vortex.root()
        ");
    }

    #[test]
    fn test_expr_from_df_literal() {
        let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
        let result = DefaultExpressionConvertor::default()
            .convert(&literal_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
    }

    #[test]
    fn test_expr_from_df_binary() {
        let left = Arc::new(df_expr::Column::new("left", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);

        let result = DefaultExpressionConvertor::default()
            .convert(&binary_expr)
            .unwrap();

        assert_snapshot!(result.display_tree().to_string(), @r"
        vortex.binary(=)
        ├── lhs: vortex.get_item(left)
        │   └── input: vortex.root()
        └── rhs: vortex.literal(42i32)
        ");
    }

    #[rstest]
    #[case::like_normal(false, false)]
    #[case::like_negated(true, false)]
    #[case::like_case_insensitive(false, true)]
    #[case::like_negated_case_insensitive(true, true)]
    fn test_expr_from_df_like(#[case] negated: bool, #[case] case_insensitive: bool) {
        let expr = Arc::new(df_expr::Column::new("text_col", 0)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);

        let result = DefaultExpressionConvertor::default()
            .convert(&like_expr)
            .unwrap();
        let like_opts = result.as_::<Like>();
        assert_eq!(
            like_opts,
            &LikeOptions {
                negated,
                case_insensitive
            }
        );
    }

    #[rstest]
    #[case::null(DataType::Null, true)]
    #[case::boolean(DataType::Boolean, true)]
    #[case::int8(DataType::Int8, true)]
    #[case::int16(DataType::Int16, true)]
    #[case::int32(DataType::Int32, true)]
    #[case::int64(DataType::Int64, true)]
    #[case::uint8(DataType::UInt8, true)]
    #[case::uint16(DataType::UInt16, true)]
    #[case::uint32(DataType::UInt32, true)]
    #[case::uint64(DataType::UInt64, true)]
    #[case::float32(DataType::Float32, true)]
    #[case::float64(DataType::Float64, true)]
    #[case::utf8(DataType::Utf8, true)]
    #[case::utf8_view(DataType::Utf8View, true)]
    #[case::binary(DataType::Binary, true)]
    #[case::binary_view(DataType::BinaryView, true)]
    #[case::date32(DataType::Date32, true)]
    #[case::date64(DataType::Date64, true)]
    #[case::timestamp_ms(DataType::Timestamp(ArrowTimeUnit::Millisecond, None), true)]
    #[case::timestamp_us(
        DataType::Timestamp(ArrowTimeUnit::Microsecond, Some(Arc::from("UTC"))),
        true
    )]
    #[case::time32_s(DataType::Time32(ArrowTimeUnit::Second), true)]
    #[case::time64_ns(DataType::Time64(ArrowTimeUnit::Nanosecond), true)]
    #[case::list(
        DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
        false
    )]
    #[case::struct_type(
        DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()),
        false
    )]
    #[case::dict_utf8(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
        true
    )]
    #[case::dict_int32(
        DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Int32)),
        true
    )]
    #[case::dict_unsupported(
        DataType::Dictionary(
            Box::new(DataType::UInt32),
            Box::new(DataType::List(Arc::new(Field::new("item", DataType::Int32, true))))
        ),
        false
    )]
    fn test_supported_data_types(#[case] data_type: DataType, #[case] expected: bool) {
        assert_eq!(supported_data_types(&data_type), expected);
    }

    #[rstest]
    fn test_can_be_pushed_down_column_supported(test_schema: Schema) {
        let col_expr = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&col_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
        let col_expr =
            Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_column_not_found(test_schema: Schema) {
        let col_expr = Arc::new(df_expr::Column::new("nonexistent", 99)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_literal_supported(test_schema: Schema) {
        let lit_expr =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&lit_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_literal_unsupported(test_schema: Schema) {
        // Durations are not in the supported type list.
        let unsupported_literal = ScalarValue::DurationSecond(Some(42));
        let lit_expr =
            Arc::new(df_expr::Literal::new(unsupported_literal)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&lit_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_binary_supported(test_schema: Schema) {
        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
            as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&binary_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_binary_unsupported_operator(test_schema: Schema) {
        let left = Arc::new(df_expr::Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = Arc::new(df_expr::BinaryExpr::new(
            left,
            DFOperator::AtQuestion,
            right,
        )) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
        let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
        let right =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
        let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
            as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&binary_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_like_supported(test_schema: Schema) {
        let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr =
            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;

        assert!(can_be_pushed_down_impl(&like_expr, &test_schema));
    }

    #[rstest]
    fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
        let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
        let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
            "test%".to_string(),
        )))) as Arc<dyn PhysicalExpr>;
        let like_expr =
            Arc::new(df_expr::LikeExpr::new(false, false, expr, pattern)) as Arc<dyn PhysicalExpr>;

        assert!(!can_be_pushed_down_impl(&like_expr, &test_schema));
    }

    // End-to-end: casts in projections and filters over a Vortex file must not fail.
    #[tokio::test]
    async fn test_cast_int_to_string() -> anyhow::Result<()> {
        let ctx = TestSessionContext::default();

        ctx.session
            .sql(r#"copy (select 1 as id) to 'example.vortex'"#)
            .await?
            .show()
            .await?;

        ctx.session
            .sql(r#"select cast(id as string) as sid from 'example.vortex' where id > 0"#)
            .await?
            .show()
            .await?;

        ctx.session
            .sql(r#"select id from 'example.vortex' where cast (id as string) == '1'"#)
            .await?
            .show()
            .await?;

        ctx.session
            .sql(r#"select cast(id as string) from 'example.vortex'"#)
            .await?
            .collect()
            .await?;

        Ok(())
    }

    // Evaluates the same searched CASE with DataFusion and with the converted Vortex
    // expression, and checks both produce identical values.
    #[test]
    fn test_case_when_datafusion_vortex_equivalence() {
        use datafusion::arrow::array::Int32Array;
        use datafusion::arrow::array::RecordBatch;
        use datafusion_physical_expr::expressions::CaseExpr;
        use vortex::VortexSessionDefault;
        use vortex::array::ArrayRef;
        use vortex::array::Canonical;
        use vortex::array::VortexSessionExecute as _;
        use vortex::array::arrow::FromArrowArray;
        use vortex::session::VortexSession;

        // Input batch: a single non-nullable Int32 column.
        let values = Arc::new(Int32Array::from(vec![1, 5, 10, 15, 20]));
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int32,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![values]).unwrap();

        // CASE WHEN value > 10 THEN 100 WHEN value > 5 THEN 50 ELSE 0 END
        let col_value = Arc::new(df_expr::Column::new("value", 0)) as Arc<dyn PhysicalExpr>;
        let lit_10 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(10)))) as Arc<dyn PhysicalExpr>;
        let lit_5 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(5)))) as Arc<dyn PhysicalExpr>;
        let lit_100 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(100)))) as Arc<dyn PhysicalExpr>;
        let lit_50 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(50)))) as Arc<dyn PhysicalExpr>;
        let lit_0 =
            Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(0)))) as Arc<dyn PhysicalExpr>;

        let when1 = Arc::new(df_expr::BinaryExpr::new(
            Arc::clone(&col_value),
            DFOperator::Gt,
            lit_10,
        )) as Arc<dyn PhysicalExpr>;
        let when2 = Arc::new(df_expr::BinaryExpr::new(col_value, DFOperator::Gt, lit_5))
            as Arc<dyn PhysicalExpr>;

        let case_expr =
            CaseExpr::try_new(None, vec![(when1, lit_100), (when2, lit_50)], Some(lit_0)).unwrap();

        // DataFusion evaluation.
        let df_result = case_expr.evaluate(&batch).unwrap();
        let df_array = df_result.into_array(batch.num_rows()).unwrap();

        // Vortex evaluation of the converted expression over the same data.
        let expr_convertor = DefaultExpressionConvertor::default();
        let vortex_expr = expr_convertor.try_convert_case_expr(&case_expr).unwrap();

        let vortex_array: ArrayRef = ArrayRef::from_arrow(&batch, false).unwrap();

        let session = VortexSession::default();
        let mut ctx = session.create_execution_ctx();
        let vortex_result = vortex_array
            .apply(&vortex_expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap();

        let vortex_as_arrow = vortex_result.into_primitive().as_slice::<i32>().to_vec();

        let df_as_arrow: Vec<i32> = df_array.as_primitive::<Int32Type>().values().to_vec();

        assert_eq!(df_as_arrow, vec![0, 0, 50, 100, 100]);
        assert_eq!(vortex_as_arrow, df_as_arrow);
    }
}