1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Column-name suffix of a node variable's id column (`<var>._vid`).
const COL_VID: &str = "_vid";
/// Column-name suffix of an edge variable's id column (`<var>._eid`).
const COL_EID: &str = "_eid";
/// Column-name suffix of the column holding a node variable's labels (`<var>._labels`).
const COL_LABELS: &str = "_labels";
/// Column-name suffix of the column holding an edge variable's type (`<var>._type`).
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub(crate) fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub(crate) fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub(crate) fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Inserts a `:00` seconds component into a datetime string shaped like
/// `..........THH:MM[suffix]`.
///
/// Returns `None` (meaning "leave the input alone") when the string is shorter
/// than 16 bytes, byte 10 is not `T`, bytes 11..16 are not `DD:DD` with ASCII
/// digits, or a `:` already follows the minutes. The first ten bytes are not
/// inspected.
pub(crate) fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();

    // `get` doubles as the length check: anything shorter than 16 bytes bails out.
    let [b'T', h1, h2, b':', m1, m2] = bytes.get(10..16)? else {
        return None;
    };
    if ![h1, h2, m1, m2].iter().all(|b| b.is_ascii_digit()) {
        return None;
    }

    // A colon right after the minutes means a seconds component is already present.
    if bytes.get(16) == Some(&b':') {
        return None;
    }

    // Byte 15 is an ASCII digit, so offset 16 is always a char boundary.
    let (head, tail) = s.split_at(16);
    Some(format!("{head}:00{tail}"))
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
193const CYPHER_LIST_FUNCS: &[&str] = &[
195 "_make_cypher_list",
196 "_cypher_list_concat",
197 "_cypher_list_append",
198];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// What kind of graph object a query variable is bound to.
///
/// The translator uses this to choose which physical column represents the
/// variable (for example, a node is identified by its `_vid` column and an
/// edge by its `_eid` column when testing for NULL).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node variable.
    Node,
    /// A single-edge variable.
    Edge,
    /// An edge variable bound by a variable-length pattern (see [`VariableKind::edge_for`]).
    EdgeList,
    /// A path variable.
    Path,
}

impl VariableKind {
    /// Returns the kind for an edge variable: [`VariableKind::EdgeList`] when
    /// the pattern is variable-length, [`VariableKind::Edge`] otherwise.
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
246
247pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
282 match expr {
283 Expr::PatternComprehension { .. } => Err(anyhow!(
284 "Pattern comprehensions require fallback executor (graph traversal)"
285 )),
286 #[expect(deprecated)]
289 Expr::Wildcard => Ok(DfExpr::Wildcard {
290 qualifier: None,
291 options: Default::default(),
292 }),
293
294 Expr::Variable(name) => {
295 if let Some(ctx) = context
300 && ctx.variable_kinds.contains_key(name)
301 {
302 return Ok(DfExpr::Column(Column::from_name(name)));
303 }
304
305 if let Some(ctx) = context
311 && let Some(value) = ctx.outer_values.get(name)
312 {
313 return value_to_scalar(value).map(lit);
314 }
315
316 if let Some(ctx) = context
321 && let Some(value) = ctx.parameters.get(name)
322 {
323 match value {
326 Value::List(values) if name.ends_with("._vid") => {
327 let literals = values
329 .iter()
330 .map(|v| value_to_scalar(v).map(lit))
331 .collect::<Result<Vec<_>>>()?;
332 return Ok(DfExpr::InList(InList {
333 expr: Box::new(DfExpr::Column(Column::from_name(name))),
334 list: literals,
335 negated: false,
336 }));
337 }
338 other_value => return value_to_scalar(other_value).map(lit),
339 }
340 }
341
342 Ok(DfExpr::Column(Column::from_name(name)))
345 }
346
347 Expr::Property(base, prop) => translate_property_access(base, prop, context),
348
349 Expr::ArrayIndex { array, index } => {
350 if let Ok(var_name) = extract_variable_name(array)
353 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
354 {
355 let col_name = format!("{}.{}", var_name, prop_name);
356 return Ok(DfExpr::Column(Column::from_name(col_name)));
357 }
358
359 let array_expr = cypher_expr_to_df(array, context)?;
360 let index_expr = cypher_expr_to_df(index, context)?;
361
362 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
364 }
365
366 Expr::ArraySlice { array, start, end } => {
367 let array_expr = cypher_expr_to_df(array, context)?;
371
372 let start_expr = match start {
373 Some(s) => cypher_expr_to_df(s, context)?,
374 None => lit(0i64),
375 };
376
377 let end_expr = match end {
378 Some(e) => cypher_expr_to_df(e, context)?,
379 None => lit(i64::MAX),
380 };
381
382 Ok(dummy_udf_expr(
385 "_cypher_list_slice",
386 vec![array_expr, start_expr, end_expr],
387 ))
388 }
389
390 Expr::Parameter(name) => {
391 if let Some(ctx) = context
393 && let Some(value) = ctx.parameters.get(name)
394 {
395 return value_to_scalar(value).map(lit);
396 }
397 Err(anyhow!("Unresolved parameter: ${}", name))
398 }
399
400 Expr::Literal(value) => {
401 let scalar = cypher_literal_to_scalar(value)?;
402 Ok(lit(scalar))
403 }
404
405 Expr::List(items) => translate_list_literal(items, context),
406
407 Expr::Map(entries) => {
408 if entries.is_empty() {
409 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
411 Default::default(),
412 ));
413 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
414 }
415 let mut args = Vec::with_capacity(entries.len() * 2);
418 for (key, val_expr) in entries {
419 args.push(lit(key.clone()));
420 args.push(cypher_expr_to_df(val_expr, context)?);
421 }
422 Ok(datafusion::functions::expr_fn::named_struct(args))
423 }
424
425 Expr::IsNull(inner) => translate_null_check(inner, context, true),
426
427 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
428
429 Expr::IsUnique(_) => {
430 Err(anyhow!(
432 "IS UNIQUE can only be used in constraint definitions"
433 ))
434 }
435
436 Expr::FunctionCall {
437 name,
438 args,
439 distinct,
440 window_spec,
441 } => {
442 if window_spec.is_some() {
445 let col_name = expr.to_string_repr();
447 Ok(col(&col_name))
448 } else {
449 translate_function_call(name, args, *distinct, context)
450 }
451 }
452
453 Expr::In { expr, list } => translate_in_expression(expr, list, context),
454
455 Expr::BinaryOp { left, op, right } => {
456 let left_expr = cypher_expr_to_df(left, context)?;
457 let right_expr = cypher_expr_to_df(right, context)?;
458 translate_binary_op(left_expr, op, right_expr)
459 }
460
461 Expr::UnaryOp { op, expr: inner } => {
462 let inner_expr = cypher_expr_to_df(inner, context)?;
463 match op {
464 UnaryOp::Not => Ok(inner_expr.not()),
465 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
466 }
467 }
468
469 Expr::Case {
470 expr,
471 when_then,
472 else_expr,
473 } => translate_case_expression(expr, when_then, else_expr, context),
474
475 Expr::Reduce { .. } => Err(anyhow!(
476 "Reduce expressions not yet supported in DataFusion translation"
477 )),
478
479 Expr::Exists { .. } => Err(anyhow!(
480 "EXISTS subqueries are handled by the physical expression compiler, \
481 not the DataFusion logical expression translator"
482 )),
483
484 Expr::CountSubquery(_) => Err(anyhow!(
485 "Count subqueries not yet supported in DataFusion translation"
486 )),
487
488 Expr::CollectSubquery(_) => Err(anyhow!(
489 "COLLECT subqueries not yet supported in DataFusion translation"
490 )),
491
492 Expr::Quantifier { .. } => {
493 Err(anyhow!(
498 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
499 via CypherPhysicalExprCompiler"
500 ))
501 }
502
503 Expr::ListComprehension { .. } => {
504 Err(anyhow!(
516 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
517 ))
518 }
519
520 Expr::ValidAt { .. } => {
521 Err(anyhow!(
524 "VALID_AT expression should have been transformed to function call in planner"
525 ))
526 }
527
528 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
529
530 Expr::LabelCheck { expr, labels } => {
531 if let Expr::Variable(var) = expr.as_ref() {
532 let is_edge = context
534 .and_then(|ctx| ctx.variable_kinds.get(var))
535 .is_some_and(|k| matches!(k, VariableKind::Edge));
536
537 if is_edge {
538 if labels.len() > 1 {
542 Ok(lit(false))
543 } else {
544 let type_col =
545 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
546 Ok(DfExpr::Case(datafusion::logical_expr::Case {
548 expr: None,
549 when_then_expr: vec![(
550 Box::new(type_col.clone().is_null()),
551 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
552 )],
553 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
554 }))
555 }
556 } else {
557 let labels_col =
559 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
560 let checks = labels
561 .iter()
562 .map(|label| {
563 datafusion::functions_nested::expr_fn::array_has(
564 labels_col.clone(),
565 lit(label.clone()),
566 )
567 })
568 .reduce(|acc, check| acc.and(check));
569 Ok(DfExpr::Case(datafusion::logical_expr::Case {
571 expr: None,
572 when_then_expr: vec![(
573 Box::new(labels_col.is_null()),
574 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
575 )],
576 else_expr: Some(Box::new(checks.unwrap())),
577 }))
578 }
579 } else {
580 Err(anyhow!(
581 "LabelCheck on non-variable expression not yet supported in DataFusion"
582 ))
583 }
584 }
585 }
586}
587
/// State consulted while translating Cypher expressions into DataFusion expressions.
#[derive(Debug, Clone)]
pub struct TranslationContext {
    /// Values looked up by name during translation (`$param` names, and also
    /// variable / `var.prop` names). A hit is inlined as a literal, or as an
    /// `IN` list when a list value is stored under a key ending in `._vid`.
    pub parameters: std::collections::HashMap<String, Value>,

    /// Values looked up by variable name; a hit is inlined as a literal.
    /// NOTE(review): presumably bindings from an enclosing query scope — confirm.
    pub outer_values: std::collections::HashMap<String, Value>,

    /// Variable name -> label, populated through `with_variable_label`.
    /// NOTE(review): consumers are not visible in this part of the file.
    pub variable_labels: std::collections::HashMap<String, String>,

    /// Variable name -> [`VariableKind`]; decides which column represents the
    /// variable in null checks, label checks, `IN` and property access.
    pub variable_kinds: std::collections::HashMap<String, VariableKind>,

    /// NOTE(review): presumably names of variables known to be nodes — usage
    /// is not visible in this part of the file; confirm.
    pub node_variable_hints: Vec<String>,

    /// NOTE(review): presumably names of edge variables introduced by
    /// mutations — usage is not visible in this part of the file; confirm.
    pub mutation_edge_hints: Vec<String>,

    /// UTC timestamp; `Default` initializes it to `chrono::Utc::now()`.
    /// NOTE(review): presumably the fixed "now" shared by temporal functions
    /// within one statement — confirm.
    pub statement_time: chrono::DateTime<chrono::Utc>,
}
621
622impl Default for TranslationContext {
623 fn default() -> Self {
624 Self {
625 parameters: std::collections::HashMap::new(),
626 outer_values: std::collections::HashMap::new(),
627 variable_labels: std::collections::HashMap::new(),
628 variable_kinds: std::collections::HashMap::new(),
629 node_variable_hints: Vec::new(),
630 mutation_edge_hints: Vec::new(),
631 statement_time: chrono::Utc::now(),
632 }
633 }
634}
635
636impl TranslationContext {
637 pub fn new() -> Self {
639 Self::default()
640 }
641
642 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
644 self.parameters.insert(name.into(), value);
645 self
646 }
647
648 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
650 self.variable_labels.insert(var.into(), label.into());
651 self
652 }
653}
654
655fn extract_variable_name(expr: &Expr) -> Result<String> {
657 match expr {
658 Expr::Variable(name) => Ok(name.clone()),
659 Expr::Property(base, _) => extract_variable_name(base),
660 _ => Err(anyhow!(
661 "Cannot extract variable name from expression: {:?}",
662 expr
663 )),
664 }
665}
666
667fn translate_null_check(
669 inner: &Expr,
670 context: Option<&TranslationContext>,
671 is_null: bool,
672) -> Result<DfExpr> {
673 if let Expr::Variable(var) = inner
674 && let Some(ctx) = context
675 && let Some(kind) = ctx.variable_kinds.get(var)
676 {
677 let col_name = match kind {
678 VariableKind::Node => format!("{}.{}", var, COL_VID),
679 VariableKind::Edge => format!("{}.{}", var, COL_EID),
680 VariableKind::Path | VariableKind::EdgeList => var.clone(),
681 };
682 let col_expr = DfExpr::Column(Column::from_name(col_name));
683 return Ok(if is_null {
684 col_expr.is_null()
685 } else {
686 col_expr.is_not_null()
687 });
688 }
689
690 let inner_expr = cypher_expr_to_df(inner, context)?;
691 Ok(if is_null {
692 inner_expr.is_null()
693 } else {
694 inner_expr.is_not_null()
695 })
696}
697
698fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
703 if crate::query::datetime::is_duration_accessor(prop) {
704 Some(dummy_udf_expr(
705 "_duration_property",
706 vec![base_expr, lit(prop.to_string())],
707 ))
708 } else if crate::query::datetime::is_temporal_accessor(prop) {
709 Some(dummy_udf_expr(
710 "_temporal_property",
711 vec![base_expr, lit(prop.to_string())],
712 ))
713 } else {
714 None
715 }
716}
717
718fn translate_property_access(
720 base: &Expr,
721 prop: &str,
722 context: Option<&TranslationContext>,
723) -> Result<DfExpr> {
724 if let Ok(var_name) = extract_variable_name(base) {
725 let is_graph_entity = context
726 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
727 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
728
729 if !is_graph_entity
730 && let Some(expr) =
731 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
732 {
733 return Ok(expr);
734 }
735
736 let col_name = format!("{}.{}", var_name, prop);
737
738 if let Some(ctx) = context
741 && let Some(value) = ctx.parameters.get(&col_name)
742 {
743 match value {
746 Value::List(values) if col_name.ends_with("._vid") => {
747 let literals = values
748 .iter()
749 .map(|v| value_to_scalar(v).map(lit))
750 .collect::<Result<Vec<_>>>()?;
751 return Ok(DfExpr::InList(InList {
752 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
753 list: literals,
754 negated: false,
755 }));
756 }
757 other_value => return value_to_scalar(other_value).map(lit),
758 }
759 }
760
761 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
764 let base_expr = cypher_expr_to_df(base, context)?;
765 return Ok(dummy_udf_expr(
766 "index",
767 vec![base_expr, lit(prop.to_string())],
768 ));
769 }
770
771 if is_graph_entity {
772 Ok(DfExpr::Column(Column::from_name(col_name)))
773 } else {
774 let base_expr = DfExpr::Column(Column::from_name(var_name));
775 Ok(dummy_udf_expr(
776 "index",
777 vec![base_expr, lit(prop.to_string())],
778 ))
779 }
780 } else {
781 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
783 return Ok(expr);
784 }
785
786 if let Expr::Parameter(param_name) = base {
788 if let Some(ctx) = context
789 && let Some(value) = ctx.parameters.get(param_name)
790 {
791 if let Value::Map(map) = value {
792 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
793 return value_to_scalar(&extracted).map(lit);
794 }
795 return Ok(lit(ScalarValue::Null));
796 }
797 return Err(anyhow!("Unresolved parameter: ${}", param_name));
798 }
799
800 let base_expr = cypher_expr_to_df(base, context)?;
801 Ok(dummy_udf_expr(
802 "index",
803 vec![base_expr, lit(prop.to_string())],
804 ))
805 }
806}
807
808fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
810 let mut has_string = false;
812 let mut has_bool = false;
813 let mut has_list = false;
814 let mut has_map = false;
815 let mut has_numeric = false;
816 let mut has_graph_entity = false;
817 let mut has_temporal = false;
818
819 for item in items {
820 match item {
821 Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
822 has_numeric = true
823 }
824 Expr::Literal(CypherLiteral::String(_)) => has_string = true,
825 Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
826 Expr::List(_) => has_list = true,
827 Expr::Map(_) => has_map = true,
828 Expr::Variable(name) => {
831 if context
832 .and_then(|ctx| ctx.variable_kinds.get(name))
833 .is_some()
834 {
835 has_graph_entity = true;
836 }
837 }
838 Expr::FunctionCall { name, .. } => {
841 let upper = name.to_uppercase();
842 if matches!(
843 upper.as_str(),
844 "DATE"
845 | "TIME"
846 | "LOCALTIME"
847 | "LOCALDATETIME"
848 | "DATETIME"
849 | "DURATION"
850 | "DATE.TRUNCATE"
851 | "TIME.TRUNCATE"
852 | "DATETIME.TRUNCATE"
853 | "LOCALDATETIME.TRUNCATE"
854 | "LOCALTIME.TRUNCATE"
855 ) {
856 has_temporal = true;
857 }
858 }
859 _ => {}
861 }
862 }
863
864 let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;
866
867 if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
870 if let Some(json_array) = try_items_to_json(items) {
872 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
873 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
874 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
875 }
876 let df_args: Vec<DfExpr> = items
878 .iter()
879 .map(|item| cypher_expr_to_df(item, context))
880 .collect::<Result<_>>()?;
881 return Ok(dummy_udf_expr("_make_cypher_list", df_args));
882 }
883
884 let mut df_args = Vec::with_capacity(items.len());
887 let mut has_float = false;
888 let mut has_int = false;
889 let mut has_other = false;
890
891 for item in items {
892 match item {
893 Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
894 Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
895 _ => has_other = true,
896 }
897 df_args.push(cypher_expr_to_df(item, context)?);
898 }
899
900 if df_args.is_empty() {
901 let empty_arr =
903 ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
904 Ok(lit(ScalarValue::List(empty_arr)))
905 } else if has_float && has_int && !has_other {
906 let promoted_args = df_args
908 .into_iter()
909 .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
910 .collect();
911 Ok(datafusion::functions_nested::expr_fn::make_array(
912 promoted_args,
913 ))
914 } else {
915 Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
916 }
917}
918
919fn translate_in_expression(
921 expr: &Expr,
922 list: &Expr,
923 context: Option<&TranslationContext>,
924) -> Result<DfExpr> {
925 let left_expr = if let Expr::Variable(var) = expr
930 && let Some(ctx) = context
931 && let Some(kind) = ctx.variable_kinds.get(var)
932 {
933 match kind {
934 VariableKind::Node | VariableKind::Edge => {
935 let id_col = match kind {
936 VariableKind::Node => COL_VID,
937 VariableKind::Edge => COL_EID,
938 _ => unreachable!(),
939 };
940 cast_expr(
941 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
942 datafusion::arrow::datatypes::DataType::Int64,
943 )
944 }
945 _ => cypher_expr_to_df(expr, context)?,
946 }
947 } else {
948 cypher_expr_to_df(expr, context)?
949 };
950
951 if let Expr::List(items) = list {
956 if let Some(json_array) = try_items_to_json(items) {
957 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
959 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
960 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
961 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
962 } else {
963 let expanded: Vec<DfExpr> = items
965 .iter()
966 .map(|item| cypher_expr_to_df(item, context))
967 .collect::<Result<Vec<_>>>()?;
968 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
969 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
970 }
971 } else {
972 let right_expr = cypher_expr_to_df(list, context)?;
973
974 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
979 return Ok(lit(ScalarValue::Boolean(None)));
980 }
981
982 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
983 }
984}
985
986fn translate_case_expression(
988 operand: &Option<Box<Expr>>,
989 when_then: &[(Expr, Expr)],
990 else_expr: &Option<Box<Expr>>,
991 context: Option<&TranslationContext>,
992) -> Result<DfExpr> {
993 let mut case_builder = if let Some(match_expr) = operand {
994 let match_df = cypher_expr_to_df(match_expr, context)?;
995 datafusion::logical_expr::case(match_df)
996 } else {
997 datafusion::logical_expr::when(
998 cypher_expr_to_df(&when_then[0].0, context)?,
999 cypher_expr_to_df(&when_then[0].1, context)?,
1000 )
1001 };
1002
1003 let start_idx = if operand.is_some() { 0 } else { 1 };
1004 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1005 let when_df = cypher_expr_to_df(when_expr, context)?;
1006 let then_df = cypher_expr_to_df(then_expr, context)?;
1007 case_builder = case_builder.when(when_df, then_df);
1008 }
1009
1010 if let Some(else_e) = else_expr {
1011 let else_df = cypher_expr_to_df(else_e, context)?;
1012 Ok(case_builder.otherwise(else_df)?)
1013 } else {
1014 Ok(case_builder.end()?)
1015 }
1016}
1017
1018fn translate_map_projection(
1020 base: &Expr,
1021 items: &[MapProjectionItem],
1022 context: Option<&TranslationContext>,
1023) -> Result<DfExpr> {
1024 let mut args = Vec::new();
1025 for item in items {
1026 match item {
1027 MapProjectionItem::Property(prop) => {
1028 args.push(lit(prop.clone()));
1029 let prop_expr = cypher_expr_to_df(
1030 &Expr::Property(Box::new(base.clone()), prop.clone()),
1031 context,
1032 )?;
1033 args.push(prop_expr);
1034 }
1035 MapProjectionItem::LiteralEntry(key, expr) => {
1036 args.push(lit(key.clone()));
1037 args.push(cypher_expr_to_df(expr, context)?);
1038 }
1039 MapProjectionItem::Variable(var) => {
1040 args.push(lit(var.clone()));
1041 args.push(DfExpr::Column(Column::from_name(var)));
1042 }
1043 MapProjectionItem::AllProperties => {
1044 args.push(lit("__all__"));
1045 args.push(cypher_expr_to_df(base, context)?);
1046 }
1047 }
1048 }
1049 Ok(dummy_udf_expr("_map_project", args))
1050}
1051
1052fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1055 match expr {
1056 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1057 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1058 Expr::Literal(CypherLiteral::Integer(i)) => {
1059 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1060 }
1061 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1062 .map(serde_json::Value::Number)
1063 .or(Some(serde_json::Value::Null)),
1064 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1065 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1066 Expr::Map(entries) => {
1067 let mut map = serde_json::Map::new();
1068 for (k, v) in entries {
1069 map.insert(k.clone(), try_expr_to_json(v)?);
1070 }
1071 Some(serde_json::Value::Object(map))
1072 }
1073 _ => None,
1074 }
1075}
1076
1077fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1079 items.iter().map(try_expr_to_json).collect()
1080}
1081
1082fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1084 match lit {
1085 CypherLiteral::Null => Ok(ScalarValue::Null),
1086 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1087 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1088 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1089 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1090 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1091 }
1092}
1093
1094fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1096 match value {
1097 Value::Null => Ok(ScalarValue::Null),
1098 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1099 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1100 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1101 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1102 Value::List(items) => {
1103 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1105 let scalars = scalars?;
1106
1107 let data_type = infer_common_scalar_type(&scalars);
1109
1110 let typed_scalars: Vec<ScalarValue> = scalars
1112 .into_iter()
1113 .map(|s| {
1114 if matches!(s, ScalarValue::Null) {
1115 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1116 }
1117
1118 match (s, &data_type) {
1119 (
1120 ScalarValue::Int64(Some(v)),
1121 datafusion::arrow::datatypes::DataType::Float64,
1122 ) => ScalarValue::Float64(Some(v as f64)),
1123 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1124 let s_str = s.to_string();
1126 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1127 }
1128 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1129 if matches!(s, ScalarValue::Utf8(_)) {
1131 s
1132 } else {
1133 ScalarValue::Utf8(Some(s.to_string()))
1134 }
1135 }
1136 (s, _) => s,
1137 }
1138 })
1139 .collect();
1140
1141 if typed_scalars.is_empty() {
1143 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1144 &[],
1145 &data_type,
1146 )))
1147 } else {
1148 Ok(ScalarValue::List(ScalarValue::new_list(
1149 &typed_scalars,
1150 &data_type,
1151 true,
1152 )))
1153 }
1154 }
1155 Value::Map(map) => {
1156 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1159 entries.sort_by_key(|(k, _)| *k);
1160
1161 if entries.is_empty() {
1162 return Ok(ScalarValue::Struct(Arc::new(
1163 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1164 )));
1165 }
1166
1167 let mut fields_arrays = Vec::with_capacity(entries.len());
1168
1169 for (k, v) in entries {
1170 let scalar = value_to_scalar(v)?;
1171 let dt = scalar.data_type();
1172 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1173 let array = scalar.to_array()?;
1174 fields_arrays.push((field, array));
1175 }
1176
1177 Ok(ScalarValue::Struct(Arc::new(
1178 datafusion::arrow::array::StructArray::from(fields_arrays),
1179 )))
1180 }
1181 Value::Temporal(tv) => {
1182 use uni_common::TemporalValue;
1183 match tv {
1184 TemporalValue::Date { days_since_epoch } => {
1185 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1186 }
1187 TemporalValue::LocalTime {
1188 nanos_since_midnight,
1189 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1190 TemporalValue::Time {
1191 nanos_since_midnight,
1192 offset_seconds,
1193 } => {
1194 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1196 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1197
1198 let nanos_arr =
1199 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1200 as ArrayRef;
1201 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1202
1203 let fields = Fields::from(vec![
1204 Field::new(
1205 "nanos_since_midnight",
1206 ArrowDataType::Time64(TimeUnit::Nanosecond),
1207 true,
1208 ),
1209 Field::new("offset_seconds", ArrowDataType::Int32, true),
1210 ]);
1211
1212 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1213 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1214 }
1215 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1216 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1217 ),
1218 TemporalValue::DateTime {
1219 nanos_since_epoch,
1220 offset_seconds,
1221 timezone_name,
1222 } => {
1223 use arrow::array::{
1225 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1226 };
1227 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1228
1229 let nanos_arr =
1230 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1231 as ArrayRef;
1232 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1233 let tz_arr =
1234 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1235
1236 let fields = Fields::from(vec![
1237 Field::new(
1238 "nanos_since_epoch",
1239 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1240 true,
1241 ),
1242 Field::new("offset_seconds", ArrowDataType::Int32, true),
1243 Field::new("timezone_name", ArrowDataType::Utf8, true),
1244 ]);
1245
1246 let struct_arr =
1247 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1248 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1249 }
1250 TemporalValue::Duration {
1251 months,
1252 days,
1253 nanos,
1254 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1255 arrow::datatypes::IntervalMonthDayNano {
1256 months: *months as i32,
1257 days: *days as i32,
1258 nanoseconds: *nanos,
1259 },
1260 ))),
1261 }
1262 }
1263 Value::Vector(v) => {
1264 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1266 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1267 }
1268 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1269 other => {
1271 let json_val: serde_json::Value = other.clone().into();
1272 let json_str = serde_json::to_string(&json_val)
1273 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1274 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1275 }
1276 }
1277}
1278
1279fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1281 match op {
1282 BinaryOp::Eq => Ok(left.eq(right)),
1286 BinaryOp::NotEq => Ok(left.not_eq(right)),
1287 BinaryOp::Lt => Ok(left.lt(right)),
1288 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1289 BinaryOp::Gt => Ok(left.gt(right)),
1290 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1291
1292 BinaryOp::And => Ok(left.and(right)),
1294 BinaryOp::Or => Ok(left.or(right)),
1295 BinaryOp::Xor => {
1296 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1298 }
1299
1300 BinaryOp::Add => {
1302 if is_list_expr(&left) || is_list_expr(&right) {
1303 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1304 } else {
1305 Ok(left + right)
1306 }
1307 }
1308 BinaryOp::Sub => Ok(left - right),
1309 BinaryOp::Mul => Ok(left * right),
1310 BinaryOp::Div => Ok(left / right),
1311 BinaryOp::Mod => Ok(left % right),
1312 BinaryOp::Pow => {
1313 let left_f = datafusion::logical_expr::cast(
1316 left,
1317 datafusion::arrow::datatypes::DataType::Float64,
1318 );
1319 let right_f = datafusion::logical_expr::cast(
1320 right,
1321 datafusion::arrow::datatypes::DataType::Float64,
1322 );
1323 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1324 }
1325
1326 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1328 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1329 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1330
1331 BinaryOp::Regex => {
1332 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1333 }
1334
1335 BinaryOp::ApproxEq => Err(anyhow!(
1336 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1337 )),
1338 }
1339}
1340
/// Argument-count guard for translators that return `Option<Result<DfExpr>>`.
///
/// `check_args!(1, args, name)` validates through `require_arg`; any other
/// count goes through `require_args`. On failure the enclosing function
/// returns `Some(Err(..))` immediately.
macro_rules! check_args {
    (1, $df_args:expr, $name:expr) => {
        match require_arg($df_args, $name) {
            Ok(()) => {}
            Err(err) => return Some(Err(err)),
        }
    };
    ($n:expr, $df_args:expr, $name:expr) => {
        match require_args($df_args, $n, $name) {
            Ok(()) => {}
            Err(err) => return Some(Err(err)),
        }
    };
}
1357
1358fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1361 if df_args.len() < count {
1362 let noun = if count == 1 { "argument" } else { "arguments" };
1363 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1364 }
1365 Ok(())
1366}
1367
1368fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1370 require_args(df_args, 1, func_name)
1371}
1372
1373fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1375 df_args[0].clone()
1376}
1377
1378pub(crate) fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1380 DfExpr::Cast(datafusion::logical_expr::Cast {
1381 expr: Box::new(expr),
1382 data_type,
1383 })
1384}
1385
1386pub(crate) fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1392 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1393 Arc::new(crate::query::df_udfs::create_cypher_list_to_cv_udf()),
1394 vec![expr],
1395 ))
1396}
1397
1398pub(crate) fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1402 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1403 Arc::new(crate::query::df_udfs::create_cypher_scalar_to_cv_udf()),
1404 vec![expr],
1405 ))
1406}
1407
1408fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1410 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1411 Box::new(left),
1412 op,
1413 Box::new(right),
1414 ))
1415}
1416
1417pub(crate) fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1422 use datafusion::logical_expr::Operator;
1423 match op {
1424 Operator::Eq => Some("_cypher_equal"),
1425 Operator::NotEq => Some("_cypher_not_equal"),
1426 Operator::Lt => Some("_cypher_lt"),
1427 Operator::LtEq => Some("_cypher_lt_eq"),
1428 Operator::Gt => Some("_cypher_gt"),
1429 Operator::GtEq => Some("_cypher_gt_eq"),
1430 _ => None,
1431 }
1432}
1433
1434fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1436 use datafusion::logical_expr::Operator;
1437 match op {
1438 Operator::Plus => Some("_cypher_add"),
1439 Operator::Minus => Some("_cypher_sub"),
1440 Operator::Multiply => Some("_cypher_mul"),
1441 Operator::Divide => Some("_cypher_div"),
1442 Operator::Modulo => Some("_cypher_mod"),
1443 _ => None,
1444 }
1445}
1446
1447fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1452where
1453 F: FnOnce(DfExpr) -> DfExpr,
1454{
1455 require_arg(df_args, func_name)?;
1456 Ok(math_fn(cast_expr(
1457 first_arg(df_args),
1458 datafusion::arrow::datatypes::DataType::Float64,
1459 )))
1460}
1461
1462fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1464 if distinct {
1465 expr.distinct()
1466 .build()
1467 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1468 } else {
1469 Ok(expr)
1470 }
1471}
1472
1473fn translate_aggregate_function(
1475 name_upper: &str,
1476 df_args: &[DfExpr],
1477 distinct: bool,
1478) -> Option<Result<DfExpr>> {
1479 match name_upper {
1480 "COUNT" => {
1481 let expr = if df_args.is_empty() {
1482 datafusion::functions_aggregate::count::count(lit(1i64))
1483 } else {
1484 datafusion::functions_aggregate::count::count(first_arg(df_args))
1485 };
1486 Some(maybe_distinct(expr, distinct, "COUNT"))
1487 }
1488 "SUM" => {
1489 check_args!(1, df_args, "SUM");
1490 let udaf = Arc::new(crate::query::df_udfs::create_cypher_sum_udaf());
1491 Some(maybe_distinct(
1492 udaf.call(vec![first_arg(df_args)]),
1493 distinct,
1494 "SUM",
1495 ))
1496 }
1497 "AVG" => {
1498 check_args!(1, df_args, "AVG");
1499 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1500 let expr = datafusion::functions_aggregate::average::avg(coerced);
1501 Some(maybe_distinct(expr, distinct, "AVG"))
1502 }
1503 "MIN" => {
1504 check_args!(1, df_args, "MIN");
1505 let udaf = Arc::new(crate::query::df_udfs::create_cypher_min_udaf());
1506 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1507 }
1508 "MAX" => {
1509 check_args!(1, df_args, "MAX");
1510 let udaf = Arc::new(crate::query::df_udfs::create_cypher_max_udaf());
1511 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1512 }
1513 "PERCENTILEDISC" => {
1514 if df_args.len() != 2 {
1515 return Some(Err(anyhow!(
1516 "percentileDisc() requires exactly 2 arguments"
1517 )));
1518 }
1519 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1520 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_disc_udaf());
1521 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1522 }
1523 "PERCENTILECONT" => {
1524 if df_args.len() != 2 {
1525 return Some(Err(anyhow!(
1526 "percentileCont() requires exactly 2 arguments"
1527 )));
1528 }
1529 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1530 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_cont_udaf());
1531 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1532 }
1533 "COLLECT" => {
1534 check_args!(1, df_args, "COLLECT");
1535 Some(Ok(crate::query::df_udfs::create_cypher_collect_expr(
1536 first_arg(df_args),
1537 distinct,
1538 )))
1539 }
1540 _ => None,
1541 }
1542}
1543
1544fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1547 match name_upper {
1548 "TOSTRING" => {
1549 check_args!(1, df_args, "toString");
1550 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1551 }
1552 "TOINTEGER" | "TOINT" => {
1553 check_args!(1, df_args, "toInteger");
1554 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1555 }
1556 "TOFLOAT" => {
1557 check_args!(1, df_args, "toFloat");
1558 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1559 }
1560 "TOBOOLEAN" | "TOBOOL" => {
1561 check_args!(1, df_args, "toBoolean");
1562 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1563 }
1564 "UPPER" | "TOUPPER" => {
1565 check_args!(1, df_args, "upper");
1566 Some(Ok(datafusion::functions::string::expr_fn::upper(
1567 first_arg(df_args),
1568 )))
1569 }
1570 "LOWER" | "TOLOWER" => {
1571 check_args!(1, df_args, "lower");
1572 Some(Ok(datafusion::functions::string::expr_fn::lower(
1573 first_arg(df_args),
1574 )))
1575 }
1576 "SUBSTRING" => {
1577 check_args!(2, df_args, "substring");
1578 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1579 }
1580 "TRIM" => {
1581 check_args!(1, df_args, "TRIM");
1582 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1583 first_arg(df_args),
1584 ])))
1585 }
1586 "LTRIM" => {
1587 check_args!(1, df_args, "LTRIM");
1588 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1589 first_arg(df_args),
1590 ])))
1591 }
1592 "RTRIM" => {
1593 check_args!(1, df_args, "RTRIM");
1594 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1595 first_arg(df_args),
1596 ])))
1597 }
1598 "LEFT" => {
1599 check_args!(2, df_args, "left");
1600 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1601 df_args[0].clone(),
1602 df_args[1].clone(),
1603 )))
1604 }
1605 "RIGHT" => {
1606 check_args!(2, df_args, "right");
1607 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1608 df_args[0].clone(),
1609 df_args[1].clone(),
1610 )))
1611 }
1612 "REPLACE" => {
1613 check_args!(3, df_args, "replace");
1614 Some(Ok(datafusion::functions::string::expr_fn::replace(
1615 df_args[0].clone(),
1616 df_args[1].clone(),
1617 df_args[2].clone(),
1618 )))
1619 }
1620 "REVERSE" => {
1621 check_args!(1, df_args, "reverse");
1622 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1623 }
1624 "SPLIT" => {
1625 check_args!(2, df_args, "split");
1626 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1627 }
1628 "SIZE" | "LENGTH" => {
1629 check_args!(1, df_args, name_upper);
1630 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1631 }
1632 _ => None,
1633 }
1634}
1635
1636fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1639 use datafusion::functions::math::expr_fn;
1640
1641 let unary_f64 =
1643 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1644
1645 match name_upper {
1646 "ABS" => {
1647 check_args!(1, df_args, "abs");
1648 Some(Ok(crate::query::df_udfs::cypher_abs_expr(first_arg(
1652 df_args,
1653 ))))
1654 }
1655 "CEIL" | "CEILING" => {
1656 check_args!(1, df_args, "ceil");
1657 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1658 }
1659 "FLOOR" => {
1660 check_args!(1, df_args, "floor");
1661 Some(Ok(expr_fn::floor(first_arg(df_args))))
1662 }
1663 "ROUND" => {
1664 check_args!(1, df_args, "round");
1665 let args = if df_args.len() == 1 {
1666 vec![first_arg(df_args)]
1667 } else {
1668 vec![df_args[0].clone(), df_args[1].clone()]
1669 };
1670 Some(Ok(expr_fn::round(args)))
1671 }
1672 "SIGN" => {
1673 check_args!(1, df_args, "sign");
1674 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1675 Some(Ok(expr_fn::signum(coerced)))
1676 }
1677 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1678 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1679 "LOG10" => unary_f64("log10", expr_fn::log10),
1680 "EXP" => unary_f64("exp", expr_fn::exp),
1681 "SIN" => unary_f64("sin", expr_fn::sin),
1682 "COS" => unary_f64("cos", expr_fn::cos),
1683 "TAN" => unary_f64("tan", expr_fn::tan),
1684 "ASIN" => unary_f64("asin", expr_fn::asin),
1685 "ACOS" => unary_f64("acos", expr_fn::acos),
1686 "ATAN" => unary_f64("atan", expr_fn::atan),
1687 "ATAN2" => {
1688 check_args!(2, df_args, "atan2");
1689 let cast_f64 =
1690 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1691 Some(Ok(expr_fn::atan2(
1692 cast_f64(df_args[0].clone()),
1693 cast_f64(df_args[1].clone()),
1694 )))
1695 }
1696 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1697 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1698 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1699 _ => None,
1700 }
1701}
1702
1703fn translate_temporal_function(
1706 name_upper: &str,
1707 name: &str,
1708 df_args: &[DfExpr],
1709 context: Option<&TranslationContext>,
1710) -> Option<Result<DfExpr>> {
1711 match name_upper {
1712 "DATE"
1713 | "TIME"
1714 | "LOCALTIME"
1715 | "LOCALDATETIME"
1716 | "DATETIME"
1717 | "DURATION"
1718 | "YEAR"
1719 | "MONTH"
1720 | "DAY"
1721 | "HOUR"
1722 | "MINUTE"
1723 | "SECOND"
1724 | "DURATION.BETWEEN"
1725 | "DURATION.INMONTHS"
1726 | "DURATION.INDAYS"
1727 | "DURATION.INSECONDS"
1728 | "DATETIME.FROMEPOCH"
1729 | "DATETIME.FROMEPOCHMILLIS"
1730 | "DATE.TRUNCATE"
1731 | "TIME.TRUNCATE"
1732 | "DATETIME.TRUNCATE"
1733 | "LOCALDATETIME.TRUNCATE"
1734 | "LOCALTIME.TRUNCATE"
1735 | "DATETIME.TRANSACTION"
1736 | "DATETIME.STATEMENT"
1737 | "DATETIME.REALTIME"
1738 | "DATE.TRANSACTION"
1739 | "DATE.STATEMENT"
1740 | "DATE.REALTIME"
1741 | "TIME.TRANSACTION"
1742 | "TIME.STATEMENT"
1743 | "TIME.REALTIME"
1744 | "LOCALTIME.TRANSACTION"
1745 | "LOCALTIME.STATEMENT"
1746 | "LOCALTIME.REALTIME"
1747 | "LOCALDATETIME.TRANSACTION"
1748 | "LOCALDATETIME.STATEMENT"
1749 | "LOCALDATETIME.REALTIME" => {
1750 let stmt_time = context.map(|c| c.statement_time);
1754 if can_constant_fold(name_upper, df_args)
1755 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1756 {
1757 return Some(Ok(folded));
1758 }
1759 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1760 }
1761 _ => None,
1762 }
1763}
1764
1765fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1767 if name.contains("REALTIME") {
1769 return false;
1770 }
1771 if args.is_empty() {
1779 return matches!(
1780 name,
1781 "DATE"
1782 | "TIME"
1783 | "LOCALTIME"
1784 | "LOCALDATETIME"
1785 | "DATETIME"
1786 | "DATE.STATEMENT"
1787 | "TIME.STATEMENT"
1788 | "LOCALTIME.STATEMENT"
1789 | "LOCALDATETIME.STATEMENT"
1790 | "DATETIME.STATEMENT"
1791 | "DATE.TRANSACTION"
1792 | "TIME.TRANSACTION"
1793 | "LOCALTIME.TRANSACTION"
1794 | "LOCALDATETIME.TRANSACTION"
1795 | "DATETIME.TRANSACTION"
1796 );
1797 }
1798 args.iter().all(is_constant_expr)
1800}
1801
1802fn is_constant_expr(expr: &DfExpr) -> bool {
1804 match expr {
1805 DfExpr::Literal(_, _) => true,
1806 DfExpr::ScalarFunction(func) => {
1807 func.args.iter().all(is_constant_expr)
1809 }
1810 _ => false,
1811 }
1812}
1813
1814fn try_constant_fold_temporal(
1820 name: &str,
1821 args: &[DfExpr],
1822 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1823) -> Result<DfExpr> {
1824 let val_args: Vec<Value> = args
1826 .iter()
1827 .map(extract_constant_value)
1828 .collect::<Result<_>>()?;
1829
1830 let result = if val_args.is_empty() {
1832 if let Some(frozen) = stmt_time {
1833 crate::query::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1834 } else {
1835 crate::query::datetime::eval_datetime_function(name, &val_args)?
1836 }
1837 } else {
1838 crate::query::datetime::eval_datetime_function(name, &val_args)?
1839 };
1840
1841 let scalar = value_to_scalar(&result)?;
1843 Ok(DfExpr::Literal(scalar, None))
1844}
1845
1846fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1848 use crate::query::df_udfs::scalar_to_value;
1849 match expr {
1850 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1851 DfExpr::ScalarFunction(func) => {
1852 let mut map = std::collections::HashMap::new();
1855 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1856 for chunk in pairs.chunks(2) {
1857 if let [key_expr, val_expr] = chunk {
1858 let key = match key_expr {
1860 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1861 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1862 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1863 };
1864 let val = extract_constant_value(val_expr)?;
1865 map.insert(key, val);
1866 } else {
1867 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1868 }
1869 }
1870 Ok(Value::Map(map))
1871 }
1872 _ => Err(anyhow::anyhow!(
1873 "Cannot extract constant value from expression"
1874 )),
1875 }
1876}
1877
1878fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1881 match name_upper {
1882 "HEAD" => {
1883 check_args!(1, df_args, "head");
1884 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1885 }
1886 "LAST" => {
1887 check_args!(1, df_args, "last");
1888 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1889 }
1890 "TAIL" => {
1891 check_args!(1, df_args, "tail");
1892 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1893 }
1894 "RANGE" => {
1895 check_args!(2, df_args, "range");
1896 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1897 }
1898 _ => None,
1899 }
1900}
1901
1902fn translate_graph_function(
1905 name_upper: &str,
1906 name: &str,
1907 df_args: &[DfExpr],
1908 args: &[Expr],
1909 context: Option<&TranslationContext>,
1910) -> Option<Result<DfExpr>> {
1911 match name_upper {
1912 "ID" => {
1913 if let Some(Expr::Variable(var)) = args.first() {
1916 let is_edge = context.is_some_and(|ctx| {
1917 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1918 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1919 });
1920 let id_suffix = if is_edge { COL_EID } else { COL_VID };
1921 Some(Ok(DfExpr::Column(Column::from_name(format!(
1922 "{}.{}",
1923 var, id_suffix
1924 )))))
1925 } else {
1926 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
1927 }
1928 }
1929 "LABELS" | "KEYS" => {
1930 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1935 }
1936 "TYPE" => {
1937 if let Some(Expr::Variable(var)) = args.first()
1941 && let Some(ctx) = context
1942 && let Some(label) = ctx.variable_labels.get(var)
1943 {
1944 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
1947 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
1948 expr: None,
1949 when_then_expr: vec![(
1950 Box::new(eid_col.is_not_null()),
1951 Box::new(lit(label.clone())),
1952 )],
1953 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
1954 })));
1955 }
1956 if let Some(Expr::Variable(var)) = args.first()
1960 && context
1961 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
1962 {
1963 return Some(Ok(DfExpr::Column(Column::from_name(format!(
1964 "{}.{}",
1965 var, COL_TYPE
1966 )))));
1967 }
1968 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
1969 }
1970 "PROPERTIES" => {
1971 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1974 }
1975 "UNI.TEMPORAL.VALIDAT" => {
1976 if let (
1979 Some(Expr::Variable(var)),
1980 Some(Expr::Literal(CypherLiteral::String(start_prop))),
1981 Some(Expr::Literal(CypherLiteral::String(end_prop))),
1982 Some(ts_expr),
1983 ) = (args.first(), args.get(1), args.get(2), args.get(3))
1984 {
1985 let start_col =
1986 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
1987 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
1988 let ts = match cypher_expr_to_df(ts_expr, context) {
1989 Ok(ts) => ts,
1990 Err(e) => return Some(Err(e)),
1991 };
1992
1993 let start_check = start_col.lt_eq(ts.clone());
1995 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
1997 let end_after = end_col.gt(ts);
1998 let end_check = end_null.or(end_after);
1999
2000 Some(Ok(start_check.and(end_check)))
2001 } else {
2002 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2004 }
2005 }
2006 "STARTNODE" | "ENDNODE" => {
2007 let mut udf_args = df_args.to_vec();
2010 let mut seen = std::collections::HashSet::new();
2011 if let Some(ctx) = context {
2012 for (var, kind) in &ctx.variable_kinds {
2014 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2015 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2016 }
2017 }
2018 for var in &ctx.node_variable_hints {
2021 if seen.insert(var.clone()) {
2022 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2023 }
2024 }
2025 }
2026 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2027 }
2028 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2029 "HASLABEL" => {
2030 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2031 return Some(Err(e));
2032 }
2033 if let Some(Expr::Variable(var)) = args.first() {
2035 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2036 let labels_col =
2038 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2039 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2040 labels_col,
2041 lit(label.clone()),
2042 )))
2043 } else {
2044 Some(Err(anyhow::anyhow!(
2046 "hasLabel requires string literal as second argument for DataFusion translation"
2047 )))
2048 }
2049 } else {
2050 Some(Err(anyhow::anyhow!(
2052 "hasLabel requires variable as first argument for DataFusion translation"
2053 )))
2054 }
2055 }
2056 _ => None,
2057 }
2058}
2059
2060fn translate_function_call(
2062 name: &str,
2063 args: &[Expr],
2064 distinct: bool,
2065 context: Option<&TranslationContext>,
2066) -> Result<DfExpr> {
2067 let df_args: Vec<DfExpr> = args
2068 .iter()
2069 .map(|arg| cypher_expr_to_df(arg, context))
2070 .collect::<Result<Vec<_>>>()?;
2071
2072 let name_upper = name.to_uppercase();
2073
2074 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2078 return result;
2079 }
2080
2081 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2082 return result;
2083 }
2084
2085 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2086 return result;
2087 }
2088
2089 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2090 return result;
2091 }
2092
2093 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2094 return result;
2095 }
2096
2097 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2098 return result;
2099 }
2100
2101 match name_upper.as_str() {
2103 "COALESCE" => {
2104 require_arg(&df_args, "coalesce")?;
2105 return Ok(datafusion::functions::expr_fn::coalesce(df_args));
2106 }
2107 "NULLIF" => {
2108 require_args(&df_args, 2, "nullif")?;
2109 return Ok(datafusion::functions::expr_fn::nullif(
2110 df_args[0].clone(),
2111 df_args[1].clone(),
2112 ));
2113 }
2114 _ => {}
2115 }
2116
2117 match name_upper.as_str() {
2119 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
2120 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2121 }
2122 _ => {}
2123 }
2124
2125 Ok(dummy_udf_expr(name, df_args))
2127}
2128
/// Placeholder scalar UDF used while building logical expressions.
///
/// It carries only a name, a signature and a fixed return type; it is not
/// meant to be executed.
#[derive(Debug)]
struct DummyUdf {
    /// Function name this placeholder stands for.
    name: String,
    /// Signature reported to DataFusion.
    signature: datafusion::logical_expr::Signature,
    /// Return type reported to DataFusion, independent of argument types.
    ret_type: datafusion::arrow::datatypes::DataType,
}
2139
2140impl DummyUdf {
2141 fn new(name: String) -> Self {
2142 let ret_type = dummy_udf_return_type(&name);
2143 Self {
2144 name,
2145 signature: datafusion::logical_expr::Signature::variadic_any(
2146 datafusion::logical_expr::Volatility::Immutable,
2147 ),
2148 ret_type,
2149 }
2150 }
2151}
2152
2153fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2166 use datafusion::arrow::datatypes::DataType;
2167 match name {
2168 "_cypher_add"
2172 | "_cypher_sub"
2173 | "_cypher_mul"
2174 | "_cypher_div"
2175 | "_cypher_mod"
2176 | "_cypher_list_concat"
2177 | "_cypher_list_append"
2178 | "_make_cypher_list"
2179 | "_map_project"
2180 | "_cypher_list_to_cv"
2181 | "_cypher_tail" => DataType::LargeBinary,
2182 _ => DataType::Null,
2186 }
2187}
2188
2189impl PartialEq for DummyUdf {
2190 fn eq(&self, other: &Self) -> bool {
2191 self.name == other.name
2192 }
2193}
2194
// Name-based equality is a full equivalence relation, so `Eq` holds.
impl Eq for DummyUdf {}
2196
2197impl Hash for DummyUdf {
2198 fn hash<H: Hasher>(&self, state: &mut H) {
2199 self.name.hash(state);
2200 }
2201}
2202
2203pub(crate) fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2205 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2206 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2207 DummyUdf::new(name.to_lowercase()),
2208 )),
2209 args,
2210 })
2211}
2212
2213impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2214 fn as_any(&self) -> &dyn std::any::Any {
2215 self
2216 }
2217
2218 fn name(&self) -> &str {
2219 &self.name
2220 }
2221
2222 fn signature(&self) -> &datafusion::logical_expr::Signature {
2223 &self.signature
2224 }
2225
2226 fn return_type(
2227 &self,
2228 _arg_types: &[datafusion::arrow::datatypes::DataType],
2229 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2230 Ok(self.ret_type.clone())
2233 }
2234
2235 fn invoke_with_args(
2236 &self,
2237 _args: ScalarFunctionArgs,
2238 ) -> datafusion::error::Result<ColumnarValue> {
2239 Err(datafusion::error::DataFusionError::Plan(format!(
2240 "UDF '{}' is not registered. Register it via SessionContext.",
2241 self.name
2242 )))
2243 }
2244}
2245
2246pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2250 let mut properties = Vec::new();
2251 collect_properties_recursive(expr, &mut properties);
2252 properties.sort();
2253 properties.dedup();
2254 properties
2255}
2256
2257fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
2258 match expr {
2259 Expr::PatternComprehension { .. } => {}
2260 Expr::Property(base, prop) => {
2261 if let Ok(var_name) = extract_variable_name(base) {
2262 properties.push((var_name, prop.clone()));
2263 }
2264 collect_properties_recursive(base, properties);
2265 }
2266 Expr::ArrayIndex { array, index } => {
2267 if let Ok(var_name) = extract_variable_name(array)
2268 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
2269 {
2270 properties.push((var_name, prop_name.clone()));
2271 }
2272 collect_properties_recursive(array, properties);
2273 collect_properties_recursive(index, properties);
2274 }
2275 Expr::ArraySlice { array, start, end } => {
2276 collect_properties_recursive(array, properties);
2277 if let Some(s) = start {
2278 collect_properties_recursive(s, properties);
2279 }
2280 if let Some(e) = end {
2281 collect_properties_recursive(e, properties);
2282 }
2283 }
2284 Expr::List(items) => {
2285 for item in items {
2286 collect_properties_recursive(item, properties);
2287 }
2288 }
2289 Expr::Map(entries) => {
2290 for (_, value) in entries {
2291 collect_properties_recursive(value, properties);
2292 }
2293 }
2294 Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
2295 collect_properties_recursive(inner, properties);
2296 }
2297 Expr::FunctionCall { args, .. } => {
2298 for arg in args {
2299 collect_properties_recursive(arg, properties);
2300 }
2301 }
2302 Expr::BinaryOp { left, right, .. } => {
2303 collect_properties_recursive(left, properties);
2304 collect_properties_recursive(right, properties);
2305 }
2306 Expr::UnaryOp { expr, .. } => {
2307 collect_properties_recursive(expr, properties);
2308 }
2309 Expr::Case {
2310 expr,
2311 when_then,
2312 else_expr,
2313 } => {
2314 if let Some(e) = expr {
2315 collect_properties_recursive(e, properties);
2316 }
2317 for (when_e, then_e) in when_then {
2318 collect_properties_recursive(when_e, properties);
2319 collect_properties_recursive(then_e, properties);
2320 }
2321 if let Some(e) = else_expr {
2322 collect_properties_recursive(e, properties);
2323 }
2324 }
2325 Expr::Reduce {
2326 init, list, expr, ..
2327 } => {
2328 collect_properties_recursive(init, properties);
2329 collect_properties_recursive(list, properties);
2330 collect_properties_recursive(expr, properties);
2331 }
2332 Expr::Quantifier {
2333 list, predicate, ..
2334 } => {
2335 collect_properties_recursive(list, properties);
2336 collect_properties_recursive(predicate, properties);
2337 }
2338 Expr::ListComprehension {
2339 list,
2340 where_clause,
2341 map_expr,
2342 ..
2343 } => {
2344 collect_properties_recursive(list, properties);
2345 if let Some(filter) = where_clause {
2346 collect_properties_recursive(filter, properties);
2347 }
2348 collect_properties_recursive(map_expr, properties);
2349 }
2350 Expr::In { expr, list } => {
2351 collect_properties_recursive(expr, properties);
2352 collect_properties_recursive(list, properties);
2353 }
2354 Expr::ValidAt {
2355 entity, timestamp, ..
2356 } => {
2357 collect_properties_recursive(entity, properties);
2358 collect_properties_recursive(timestamp, properties);
2359 }
2360 Expr::MapProjection { base, items } => {
2361 collect_properties_recursive(base, properties);
2362 for item in items {
2363 match item {
2364 uni_cypher::ast::MapProjectionItem::Property(prop) => {
2365 if let Ok(var_name) = extract_variable_name(base) {
2366 properties.push((var_name, prop.clone()));
2367 }
2368 }
2369 uni_cypher::ast::MapProjectionItem::AllProperties => {
2370 if let Ok(var_name) = extract_variable_name(base) {
2371 properties.push((var_name, "*".to_string()));
2372 }
2373 }
2374 uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
2375 collect_properties_recursive(expr, properties);
2376 }
2377 uni_cypher::ast::MapProjectionItem::Variable(_) => {}
2378 }
2379 }
2380 }
2381 Expr::LabelCheck { expr, .. } => {
2382 collect_properties_recursive(expr, properties);
2383 }
2384 Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
2386 Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
2387 }
2388}
2389
2390pub fn wider_numeric_type(
2397 a: &datafusion::arrow::datatypes::DataType,
2398 b: &datafusion::arrow::datatypes::DataType,
2399) -> datafusion::arrow::datatypes::DataType {
2400 use datafusion::arrow::datatypes::DataType;
2401
2402 fn numeric_rank(dt: &DataType) -> u8 {
2403 match dt {
2404 DataType::Int8 | DataType::UInt8 => 1,
2405 DataType::Int16 | DataType::UInt16 => 2,
2406 DataType::Int32 | DataType::UInt32 => 3,
2407 DataType::Int64 | DataType::UInt64 => 4,
2408 DataType::Float16 => 5,
2409 DataType::Float32 => 6,
2410 DataType::Float64 => 7,
2411 _ => 0,
2412 }
2413 }
2414
2415 if numeric_rank(a) >= numeric_rank(b) {
2416 a.clone()
2417 } else {
2418 b.clone()
2419 }
2420}
2421
2422fn resolve_column_type_fallback(
2428 expr: &DfExpr,
2429 schema: &datafusion::common::DFSchema,
2430) -> Option<datafusion::arrow::datatypes::DataType> {
2431 if let DfExpr::Column(col) = expr {
2432 let col_name = &col.name;
2433 for (_, field) in schema.iter() {
2435 if field.name() == col_name {
2436 return Some(field.data_type().clone());
2437 }
2438 }
2439 }
2440 None
2441}
2442
2443fn contains_division(expr: &DfExpr) -> bool {
2446 match expr {
2447 DfExpr::BinaryExpr(b) => {
2448 b.op == datafusion::logical_expr::Operator::Divide
2449 || contains_division(&b.left)
2450 || contains_division(&b.right)
2451 }
2452 DfExpr::Cast(c) => contains_division(&c.expr),
2453 DfExpr::TryCast(c) => contains_division(&c.expr),
2454 _ => false,
2455 }
2456}
2457
2458pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2464 use datafusion::arrow::datatypes::DataType;
2465 use datafusion::logical_expr::ExprSchemable;
2466
2467 match expr {
2468 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2469 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2470 DfExpr::Case(case) => coerce_case_expr(case, schema),
2471 DfExpr::InList(in_list) => {
2472 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2473 let coerced_list = in_list
2474 .list
2475 .iter()
2476 .map(|e| apply_type_coercion(e, schema))
2477 .collect::<Result<Vec<_>>>()?;
2478 let expr_type = coerced_expr
2479 .get_type(schema)
2480 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2481 crate::query::cypher_type_coerce::build_cypher_in_list(
2482 coerced_expr,
2483 &expr_type,
2484 coerced_list,
2485 in_list.negated,
2486 schema,
2487 )
2488 }
2489 DfExpr::Not(inner) => {
2490 let coerced_inner = apply_type_coercion(inner, schema)?;
2491 let inner_type = coerced_inner.get_type(schema).ok();
2492 let final_inner = if inner_type
2493 .as_ref()
2494 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2495 {
2496 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2497 } else if inner_type
2498 .as_ref()
2499 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2500 {
2501 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2502 } else {
2503 coerced_inner
2504 };
2505 Ok(DfExpr::Not(Box::new(final_inner)))
2506 }
2507 DfExpr::IsNull(inner) => {
2508 let coerced_inner = apply_type_coercion(inner, schema)?;
2509 Ok(coerced_inner.is_null())
2510 }
2511 DfExpr::IsNotNull(inner) => {
2512 let coerced_inner = apply_type_coercion(inner, schema)?;
2513 Ok(coerced_inner.is_not_null())
2514 }
2515 DfExpr::Negative(inner) => {
2516 let coerced_inner = apply_type_coercion(inner, schema)?;
2517 let inner_type = coerced_inner.get_type(schema).ok();
2518 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2519 Ok(dummy_udf_expr(
2520 "_cypher_mul",
2521 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2522 ))
2523 } else {
2524 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2525 }
2526 }
2527 DfExpr::Cast(cast) => {
2528 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2529 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2530 Box::new(coerced_inner),
2531 cast.data_type.clone(),
2532 )))
2533 }
2534 DfExpr::TryCast(cast) => {
2535 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2536 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2537 Box::new(coerced_inner),
2538 cast.data_type.clone(),
2539 )))
2540 }
2541 DfExpr::Alias(alias) => {
2542 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2543 Ok(coerced_inner.alias(alias.name.clone()))
2544 }
2545 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2546 _ => Ok(expr.clone()),
2547 }
2548}
2549
2550fn coerce_logical_operands(
2552 left: DfExpr,
2553 right: DfExpr,
2554 op: datafusion::logical_expr::Operator,
2555 schema: &datafusion::common::DFSchema,
2556) -> Option<DfExpr> {
2557 use datafusion::arrow::datatypes::DataType;
2558 use datafusion::logical_expr::ExprSchemable;
2559
2560 if !matches!(
2561 op,
2562 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2563 ) {
2564 return None;
2565 }
2566 let left_type = left.get_type(schema).ok();
2567 let right_type = right.get_type(schema).ok();
2568 let left_needs_cast = left_type
2569 .as_ref()
2570 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2571 let right_needs_cast = right_type
2572 .as_ref()
2573 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2574 let left_is_lb = left_type
2575 .as_ref()
2576 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2577 let right_is_lb = right_type
2578 .as_ref()
2579 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2580 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2581 return None;
2582 }
2583 let coerced_left = if left_is_lb {
2584 dummy_udf_expr("_cv_to_bool", vec![left])
2585 } else if left_needs_cast {
2586 datafusion::logical_expr::cast(left, DataType::Boolean)
2587 } else {
2588 left
2589 };
2590 let coerced_right = if right_is_lb {
2591 dummy_udf_expr("_cv_to_bool", vec![right])
2592 } else if right_needs_cast {
2593 datafusion::logical_expr::cast(right, DataType::Boolean)
2594 } else {
2595 right
2596 };
2597 Some(binary_expr(coerced_left, op, coerced_right))
2598}
2599
2600#[expect(
2603 clippy::too_many_arguments,
2604 reason = "Binary coercion needs all context"
2605)]
2606fn coerce_large_binary_ops(
2607 left: &DfExpr,
2608 right: &DfExpr,
2609 left_type: &datafusion::arrow::datatypes::DataType,
2610 right_type: &datafusion::arrow::datatypes::DataType,
2611 left_is_null: bool,
2612 op: datafusion::logical_expr::Operator,
2613 is_comparison: bool,
2614 is_arithmetic: bool,
2615) -> Option<Result<DfExpr>> {
2616 use datafusion::arrow::datatypes::DataType;
2617 use datafusion::logical_expr::Operator;
2618
2619 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2620 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2621
2622 if op == Operator::Plus {
2623 if left_is_lb && right_is_lb {
2624 return Some(Ok(dummy_udf_expr(
2625 "_cypher_add",
2626 vec![left.clone(), right.clone()],
2627 )));
2628 }
2629 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2630 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2631 if left_is_native_list && right_is_native_list {
2632 return Some(Ok(dummy_udf_expr(
2633 "_cypher_list_concat",
2634 vec![left.clone(), right.clone()],
2635 )));
2636 }
2637 if left_is_native_list || right_is_native_list {
2638 return Some(Ok(dummy_udf_expr(
2639 "_cypher_list_append",
2640 vec![left.clone(), right.clone()],
2641 )));
2642 }
2643 }
2644
2645 if (left_is_lb || right_is_lb) && is_comparison {
2646 if let Some(udf_name) = comparison_udf_name(op) {
2647 return Some(Ok(dummy_udf_expr(
2648 udf_name,
2649 vec![left.clone(), right.clone()],
2650 )));
2651 }
2652 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2653 }
2654
2655 if (left_is_lb || right_is_lb) && is_arithmetic {
2656 let udf_name =
2657 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2658 return Some(Ok(dummy_udf_expr(
2659 udf_name,
2660 vec![left.clone(), right.clone()],
2661 )));
2662 }
2663
2664 None
2665}
2666
2667fn coerce_temporal_comparisons(
2669 left: DfExpr,
2670 right: DfExpr,
2671 left_type: &datafusion::arrow::datatypes::DataType,
2672 right_type: &datafusion::arrow::datatypes::DataType,
2673 op: datafusion::logical_expr::Operator,
2674 is_comparison: bool,
2675) -> Option<DfExpr> {
2676 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2677 use datafusion::logical_expr::Operator;
2678
2679 if !is_comparison {
2680 return None;
2681 }
2682
2683 if uni_common::core::schema::is_datetime_struct(left_type)
2685 && uni_common::core::schema::is_datetime_struct(right_type)
2686 {
2687 return Some(binary_expr(
2688 extract_datetime_nanos(left),
2689 op,
2690 extract_datetime_nanos(right),
2691 ));
2692 }
2693
2694 if uni_common::core::schema::is_time_struct(left_type)
2696 && uni_common::core::schema::is_time_struct(right_type)
2697 {
2698 return Some(binary_expr(
2699 extract_time_nanos(left),
2700 op,
2701 extract_time_nanos(right),
2702 ));
2703 }
2704
2705 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2707 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2708
2709 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2710 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2711 {
2712 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2713 extract_datetime_nanos(left)
2714 } else {
2715 left
2716 };
2717 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2718 extract_datetime_nanos(right)
2719 } else {
2720 right
2721 };
2722 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2723 return Some(binary_expr(
2724 cast_expr(left_nanos, ts_type.clone()),
2725 op,
2726 cast_expr(right_nanos, ts_type),
2727 ));
2728 }
2729
2730 let left_is_duration = matches!(left_type, DataType::Interval(_));
2734 let right_is_duration = matches!(right_type, DataType::Interval(_));
2735 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2736 || uni_common::core::schema::is_time_struct(left_type)
2737 || matches!(
2738 left_type,
2739 DataType::Timestamp(_, _)
2740 | DataType::Date32
2741 | DataType::Date64
2742 | DataType::Time32(_)
2743 | DataType::Time64(_)
2744 );
2745 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2746 || uni_common::core::schema::is_time_struct(right_type)
2747 || matches!(
2748 right_type,
2749 DataType::Timestamp(_, _)
2750 | DataType::Date32
2751 | DataType::Date64
2752 | DataType::Time32(_)
2753 | DataType::Time64(_)
2754 );
2755
2756 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2757 {
2758 return Some(match op {
2759 Operator::Eq => lit(false),
2760 Operator::NotEq => lit(true),
2761 _ => lit(ScalarValue::Boolean(None)),
2762 });
2763 }
2764
2765 None
2766}
2767
2768fn coerce_mismatched_types(
2771 left: DfExpr,
2772 right: DfExpr,
2773 left_type: &datafusion::arrow::datatypes::DataType,
2774 right_type: &datafusion::arrow::datatypes::DataType,
2775 op: datafusion::logical_expr::Operator,
2776 is_comparison: bool,
2777) -> Option<Result<DfExpr>> {
2778 use datafusion::arrow::datatypes::DataType;
2779 use datafusion::logical_expr::Operator;
2780
2781 if left_type == right_type {
2782 return None;
2783 }
2784
2785 if left_type.is_numeric() && right_type.is_numeric() {
2787 if left_type == &DataType::Int64
2788 && right_type == &DataType::UInt64
2789 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2790 {
2791 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2792 return Some(Ok(binary_expr(coerced_left, op, right)));
2793 }
2794 if left_type == &DataType::UInt64
2795 && right_type == &DataType::Int64
2796 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2797 {
2798 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2799 return Some(Ok(binary_expr(left, op, coerced_right)));
2800 }
2801 let target = wider_numeric_type(left_type, right_type);
2802 let coerced_left = if *left_type != target {
2803 datafusion::logical_expr::cast(left, target.clone())
2804 } else {
2805 left
2806 };
2807 let coerced_right = if *right_type != target {
2808 datafusion::logical_expr::cast(right, target)
2809 } else {
2810 right
2811 };
2812 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2813 }
2814
2815 if is_comparison {
2817 match (left_type, right_type) {
2818 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2819 let right = normalize_datetime_literal(right);
2820 return Some(Ok(binary_expr(
2821 left,
2822 op,
2823 datafusion::logical_expr::cast(right, ts.clone()),
2824 )));
2825 }
2826 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2827 let left = normalize_datetime_literal(left);
2828 return Some(Ok(binary_expr(
2829 datafusion::logical_expr::cast(left, ts.clone()),
2830 op,
2831 right,
2832 )));
2833 }
2834 _ => {}
2835 }
2836 }
2837
2838 if is_comparison
2840 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2841 {
2842 let l_inner = l_field.data_type();
2843 let r_inner = r_field.data_type();
2844 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2845 let target_inner = wider_numeric_type(l_inner, r_inner);
2846 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2847 "item",
2848 target_inner,
2849 true,
2850 )));
2851 return Some(Ok(binary_expr(
2852 datafusion::logical_expr::cast(left, target_type.clone()),
2853 op,
2854 datafusion::logical_expr::cast(right, target_type),
2855 )));
2856 }
2857 }
2858
2859 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2861 if op == Operator::Plus {
2862 return Some(crate::query::cypher_type_coerce::build_cypher_plus(
2863 left, left_type, right, right_type,
2864 ));
2865 }
2866 if is_comparison {
2867 return Some(Ok(
2868 crate::query::cypher_type_coerce::build_cypher_comparison(
2869 left, left_type, right, right_type, op,
2870 ),
2871 ));
2872 }
2873 }
2874
2875 None
2876}
2877
2878fn coerce_list_comparisons(
2880 left: DfExpr,
2881 right: DfExpr,
2882 left_type: &datafusion::arrow::datatypes::DataType,
2883 right_type: &datafusion::arrow::datatypes::DataType,
2884 op: datafusion::logical_expr::Operator,
2885 is_comparison: bool,
2886) -> Option<DfExpr> {
2887 use datafusion::arrow::datatypes::DataType;
2888 use datafusion::logical_expr::Operator;
2889
2890 if !is_comparison {
2891 return None;
2892 }
2893
2894 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2895 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2896
2897 if left_is_list
2899 && right_is_list
2900 && matches!(
2901 op,
2902 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
2903 )
2904 {
2905 let op_str = match op {
2906 Operator::Lt => "lt",
2907 Operator::LtEq => "lteq",
2908 Operator::Gt => "gt",
2909 Operator::GtEq => "gteq",
2910 _ => unreachable!(),
2911 };
2912 return Some(dummy_udf_expr(
2913 "_cypher_list_compare",
2914 vec![left, right, lit(op_str)],
2915 ));
2916 }
2917
2918 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
2920 let udf_name =
2921 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
2922 return Some(dummy_udf_expr(udf_name, vec![left, right]));
2923 }
2924
2925 if (left_is_list != right_is_list)
2927 && !matches!(left_type, DataType::Null)
2928 && !matches!(right_type, DataType::Null)
2929 {
2930 return Some(match op {
2931 Operator::Eq => lit(false),
2932 Operator::NotEq => lit(true),
2933 _ => lit(ScalarValue::Boolean(None)),
2934 });
2935 }
2936
2937 None
2938}
2939
2940fn coerce_binary_expr(
2942 binary: &datafusion::logical_expr::expr::BinaryExpr,
2943 schema: &datafusion::common::DFSchema,
2944) -> Result<DfExpr> {
2945 use datafusion::arrow::datatypes::DataType;
2946 use datafusion::logical_expr::ExprSchemable;
2947 use datafusion::logical_expr::Operator;
2948
2949 let left = apply_type_coercion(&binary.left, schema)?;
2950 let right = apply_type_coercion(&binary.right, schema)?;
2951
2952 let is_comparison = matches!(
2953 binary.op,
2954 Operator::Eq
2955 | Operator::NotEq
2956 | Operator::Lt
2957 | Operator::LtEq
2958 | Operator::Gt
2959 | Operator::GtEq
2960 );
2961 let is_arithmetic = matches!(
2962 binary.op,
2963 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
2964 );
2965
2966 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
2968 return Ok(result);
2969 }
2970
2971 if is_comparison || is_arithmetic {
2972 let left_type = match left.get_type(schema) {
2973 Ok(t) => t,
2974 Err(e) => {
2975 if let Some(t) = resolve_column_type_fallback(&left, schema) {
2976 t
2977 } else {
2978 log::warn!("Failed to get left type in binary expr: {}", e);
2979 return Ok(binary_expr(left, binary.op, right));
2980 }
2981 }
2982 };
2983 let right_type = match right.get_type(schema) {
2984 Ok(t) => t,
2985 Err(e) => {
2986 if let Some(t) = resolve_column_type_fallback(&right, schema) {
2987 t
2988 } else {
2989 log::warn!("Failed to get right type in binary expr: {}", e);
2990 return Ok(binary_expr(left, binary.op, right));
2991 }
2992 }
2993 };
2994
2995 let left_is_null = left_type.is_null();
2997 let right_is_null = right_type.is_null();
2998 if left_is_null && right_is_null {
2999 return Ok(lit(ScalarValue::Boolean(None)));
3000 }
3001 if left_is_null || right_is_null {
3002 let target = if left_is_null {
3003 &right_type
3004 } else {
3005 &left_type
3006 };
3007 if !matches!(target, DataType::LargeBinary) {
3008 let coerced_left = if left_is_null {
3009 datafusion::logical_expr::cast(left, target.clone())
3010 } else {
3011 left
3012 };
3013 let coerced_right = if right_is_null {
3014 datafusion::logical_expr::cast(right, target.clone())
3015 } else {
3016 right
3017 };
3018 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3019 }
3020 }
3021
3022 if let Some(result) = coerce_large_binary_ops(
3024 &left,
3025 &right,
3026 &left_type,
3027 &right_type,
3028 left_is_null,
3029 binary.op,
3030 is_comparison,
3031 is_arithmetic,
3032 ) {
3033 return result;
3034 }
3035
3036 if let Some(result) = coerce_temporal_comparisons(
3038 left.clone(),
3039 right.clone(),
3040 &left_type,
3041 &right_type,
3042 binary.op,
3043 is_comparison,
3044 ) {
3045 return Ok(result);
3046 }
3047
3048 let either_struct =
3050 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3051 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3052 || matches!(left_type, DataType::Struct(_)))
3053 && (matches!(right_type, DataType::LargeBinary)
3054 || matches!(right_type, DataType::Struct(_)));
3055 if is_comparison && either_struct && either_lb_or_struct {
3056 if let Some(udf_name) = comparison_udf_name(binary.op) {
3057 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3058 }
3059 return Ok(lit(ScalarValue::Boolean(None)));
3060 }
3061
3062 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3064 let udf_name = comparison_udf_name(binary.op)
3065 .expect("is_comparison guarantees a valid comparison operator");
3066 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3067 }
3068
3069 if binary.op == Operator::Plus
3071 && (crate::query::cypher_type_coerce::is_string_type(&left_type)
3072 || crate::query::cypher_type_coerce::is_string_type(&right_type))
3073 && is_primitive_type(&left_type)
3074 && is_primitive_type(&right_type)
3075 {
3076 return crate::query::cypher_type_coerce::build_cypher_plus(
3077 left,
3078 &left_type,
3079 right,
3080 &right_type,
3081 );
3082 }
3083
3084 if let Some(result) = coerce_mismatched_types(
3086 left.clone(),
3087 right.clone(),
3088 &left_type,
3089 &right_type,
3090 binary.op,
3091 is_comparison,
3092 ) {
3093 return result;
3094 }
3095
3096 if let Some(result) = coerce_list_comparisons(
3098 left.clone(),
3099 right.clone(),
3100 &left_type,
3101 &right_type,
3102 binary.op,
3103 is_comparison,
3104 ) {
3105 return Ok(result);
3106 }
3107 }
3108
3109 Ok(binary_expr(left, binary.op, right))
3110}
3111
3112fn coerce_scalar_function(
3114 func: &datafusion::logical_expr::expr::ScalarFunction,
3115 schema: &datafusion::common::DFSchema,
3116) -> Result<DfExpr> {
3117 use datafusion::arrow::datatypes::DataType;
3118 use datafusion::logical_expr::ExprSchemable;
3119
3120 let coerced_args: Vec<DfExpr> = func
3121 .args
3122 .iter()
3123 .map(|a| apply_type_coercion(a, schema))
3124 .collect::<Result<Vec<_>>>()?;
3125
3126 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3127 let types: Vec<_> = coerced_args
3128 .iter()
3129 .filter_map(|a| a.get_type(schema).ok())
3130 .collect();
3131 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3132 if has_mixed_types {
3133 let has_large_binary = types.iter().any(|t| matches!(t, DataType::LargeBinary));
3134
3135 if has_large_binary {
3136 let unified_args: Vec<DfExpr> = coerced_args
3137 .into_iter()
3138 .zip(types.iter())
3139 .map(|(arg, t)| match t {
3140 DataType::LargeBinary | DataType::Null => arg,
3141 DataType::List(_) | DataType::LargeList(_) => {
3142 list_to_large_binary_expr(arg)
3143 }
3144 _ => scalar_to_large_binary_expr(arg),
3145 })
3146 .collect();
3147 return Ok(DfExpr::ScalarFunction(
3148 datafusion::logical_expr::expr::ScalarFunction {
3149 func: func.func.clone(),
3150 args: unified_args,
3151 },
3152 ));
3153 }
3154
3155 let all_list_or_lb = types.iter().all(|t| {
3156 matches!(
3157 t,
3158 DataType::Null
3159 | DataType::LargeBinary
3160 | DataType::List(_)
3161 | DataType::LargeList(_)
3162 )
3163 });
3164 if all_list_or_lb {
3165 let unified_args: Vec<DfExpr> = coerced_args
3166 .into_iter()
3167 .zip(types.iter())
3168 .map(|(arg, t)| {
3169 if matches!(t, DataType::List(_) | DataType::LargeList(_)) {
3170 list_to_large_binary_expr(arg)
3171 } else {
3172 arg
3173 }
3174 })
3175 .collect();
3176 return Ok(DfExpr::ScalarFunction(
3177 datafusion::logical_expr::expr::ScalarFunction {
3178 func: func.func.clone(),
3179 args: unified_args,
3180 },
3181 ));
3182 } else {
3183 let unified_args = coerced_args
3184 .into_iter()
3185 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3186 .collect();
3187 return Ok(DfExpr::ScalarFunction(
3188 datafusion::logical_expr::expr::ScalarFunction {
3189 func: func.func.clone(),
3190 args: unified_args,
3191 },
3192 ));
3193 }
3194 }
3195 }
3196
3197 Ok(DfExpr::ScalarFunction(
3198 datafusion::logical_expr::expr::ScalarFunction {
3199 func: func.func.clone(),
3200 args: coerced_args,
3201 },
3202 ))
3203}
3204
3205fn coerce_case_expr(
3208 case: &datafusion::logical_expr::expr::Case,
3209 schema: &datafusion::common::DFSchema,
3210) -> Result<DfExpr> {
3211 use datafusion::arrow::datatypes::DataType;
3212 use datafusion::logical_expr::ExprSchemable;
3213
3214 let coerced_operand = case
3215 .expr
3216 .as_ref()
3217 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3218 .transpose()?;
3219 let coerced_when_then = case
3220 .when_then_expr
3221 .iter()
3222 .map(|(w, t)| {
3223 let cw = apply_type_coercion(w, schema)?;
3224 let cw = match cw.get_type(schema).ok() {
3225 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3226 _ => cw,
3227 };
3228 let ct = apply_type_coercion(t, schema)?;
3229 Ok((Box::new(cw), Box::new(ct)))
3230 })
3231 .collect::<Result<Vec<_>>>()?;
3232 let coerced_else = case
3233 .else_expr
3234 .as_ref()
3235 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3236 .transpose()?;
3237
3238 let mut result_case = if let Some(operand) = coerced_operand {
3239 crate::query::cypher_type_coerce::rewrite_simple_case_to_generic(
3240 *operand,
3241 coerced_when_then,
3242 coerced_else,
3243 schema,
3244 )?
3245 } else {
3246 datafusion::logical_expr::expr::Case {
3247 expr: None,
3248 when_then_expr: coerced_when_then,
3249 else_expr: coerced_else,
3250 }
3251 };
3252
3253 crate::query::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3254
3255 Ok(DfExpr::Case(result_case))
3256}
3257
3258fn coerce_aggregate_function(
3260 agg: &datafusion::logical_expr::expr::AggregateFunction,
3261 schema: &datafusion::common::DFSchema,
3262) -> Result<DfExpr> {
3263 let coerced_args: Vec<DfExpr> = agg
3264 .params
3265 .args
3266 .iter()
3267 .map(|a| apply_type_coercion(a, schema))
3268 .collect::<Result<Vec<_>>>()?;
3269 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3270 .params
3271 .order_by
3272 .iter()
3273 .map(|s| {
3274 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3275 Ok(datafusion::logical_expr::SortExpr {
3276 expr: coerced_expr,
3277 asc: s.asc,
3278 nulls_first: s.nulls_first,
3279 })
3280 })
3281 .collect::<Result<Vec<_>>>()?;
3282 let coerced_filter = agg
3283 .params
3284 .filter
3285 .as_ref()
3286 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3287 .transpose()?;
3288 Ok(DfExpr::AggregateFunction(
3289 datafusion::logical_expr::expr::AggregateFunction {
3290 func: agg.func.clone(),
3291 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3292 args: coerced_args,
3293 distinct: agg.params.distinct,
3294 filter: coerced_filter,
3295 order_by: coerced_order_by,
3296 null_treatment: agg.params.null_treatment,
3297 },
3298 },
3299 ))
3300}
3301
3302#[cfg(test)]
3303mod tests {
3304 use super::*;
3305 use arrow_array::{
3306 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3307 };
3308 use uni_common::TemporalValue;
3309 #[test]
3310 fn test_literal_translation() {
3311 let expr = Expr::Literal(CypherLiteral::Integer(42));
3312 let result = cypher_expr_to_df(&expr, None).unwrap();
3313 let s = format!("{:?}", result);
3314 assert!(s.contains("Literal"));
3316 assert!(s.contains("Int64(42)"));
3317 }
3318
3319 #[test]
3320 fn test_property_access_no_context_uses_index() {
3321 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3323 let result = cypher_expr_to_df(&expr, None).unwrap();
3324 let s = format!("{}", result);
3325 assert!(
3326 s.contains("index"),
3327 "expected index UDF for non-graph variable, got: {s}"
3328 );
3329 }
3330
3331 #[test]
3332 fn test_comparison_operator() {
3333 let expr = Expr::BinaryOp {
3334 left: Box::new(Expr::Property(
3335 Box::new(Expr::Variable("n".to_string())),
3336 "age".to_string(),
3337 )),
3338 op: BinaryOp::Gt,
3339 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3340 };
3341 let result = cypher_expr_to_df(&expr, None).unwrap();
3342 let s = format!("{:?}", result);
3344 assert!(s.contains("age"));
3345 assert!(s.contains("30"));
3346 }
3347
3348 #[test]
3349 fn test_boolean_operators() {
3350 let expr = Expr::BinaryOp {
3351 left: Box::new(Expr::BinaryOp {
3352 left: Box::new(Expr::Property(
3353 Box::new(Expr::Variable("n".to_string())),
3354 "age".to_string(),
3355 )),
3356 op: BinaryOp::Gt,
3357 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3358 }),
3359 op: BinaryOp::And,
3360 right: Box::new(Expr::BinaryOp {
3361 left: Box::new(Expr::Property(
3362 Box::new(Expr::Variable("n".to_string())),
3363 "active".to_string(),
3364 )),
3365 op: BinaryOp::Eq,
3366 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3367 }),
3368 };
3369 let result = cypher_expr_to_df(&expr, None).unwrap();
3370 let s = format!("{:?}", result);
3371 assert!(s.contains("And"));
3372 }
3373
3374 #[test]
3375 fn test_is_null() {
3376 let expr = Expr::IsNull(Box::new(Expr::Property(
3377 Box::new(Expr::Variable("n".to_string())),
3378 "email".to_string(),
3379 )));
3380 let result = cypher_expr_to_df(&expr, None).unwrap();
3381 let s = format!("{:?}", result);
3382 assert!(s.contains("IsNull"));
3383 }
3384
3385 #[test]
3386 fn test_collect_properties() {
3387 let expr = Expr::BinaryOp {
3388 left: Box::new(Expr::Property(
3389 Box::new(Expr::Variable("n".to_string())),
3390 "name".to_string(),
3391 )),
3392 op: BinaryOp::Eq,
3393 right: Box::new(Expr::Property(
3394 Box::new(Expr::Variable("m".to_string())),
3395 "name".to_string(),
3396 )),
3397 };
3398
3399 let props = collect_properties(&expr);
3400 assert_eq!(props.len(), 2);
3401 assert!(props.contains(&("m".to_string(), "name".to_string())));
3402 assert!(props.contains(&("n".to_string(), "name".to_string())));
3403 }
3404
3405 #[test]
3406 fn test_function_call() {
3407 let expr = Expr::FunctionCall {
3408 name: "count".to_string(),
3409 args: vec![Expr::Wildcard],
3410 distinct: false,
3411 window_spec: None,
3412 };
3413 let result = cypher_expr_to_df(&expr, None).unwrap();
3414 let s = format!("{:?}", result);
3415 assert!(s.to_lowercase().contains("count"));
3416 }
3417
3418 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3423 use datafusion::logical_expr::Operator;
3424
3425 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3427 let fields: Vec<_> = cols
3428 .iter()
3429 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3430 .collect();
3431 let schema = Schema::new(fields);
3432 datafusion::common::DFSchema::try_from(schema).unwrap()
3433 }
3434
3435 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3437 let s = format!("{}", expr);
3438 s.contains(name)
3439 }
3440
3441 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3443 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3444 }
3445
3446 #[test]
3447 fn test_coercion_lb_eq_int64() {
3448 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3449 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3450 Box::new(col("lb")),
3451 Operator::Eq,
3452 Box::new(col("i")),
3453 ));
3454 let result = apply_type_coercion(&expr, &schema).unwrap();
3455 assert!(
3457 contains_udf(&result, "_cypher_equal"),
3458 "expected _cypher_equal, got: {result}"
3459 );
3460 }
3461
3462 #[test]
3463 fn test_coercion_lb_noteq_int64() {
3464 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3465 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3466 Box::new(col("lb")),
3467 Operator::NotEq,
3468 Box::new(col("i")),
3469 ));
3470 let result = apply_type_coercion(&expr, &schema).unwrap();
3471 assert!(contains_udf(&result, "_cypher_not_equal"));
3473 }
3474
3475 #[test]
3476 fn test_coercion_lb_lt_int64() {
3477 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3478 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3479 Box::new(col("lb")),
3480 Operator::Lt,
3481 Box::new(col("i")),
3482 ));
3483 let result = apply_type_coercion(&expr, &schema).unwrap();
3484 assert!(contains_udf(&result, "_cypher_lt"));
3486 }
3487
3488 #[test]
3489 fn test_coercion_lb_eq_float64() {
3490 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3491 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3492 Box::new(col("lb")),
3493 Operator::Eq,
3494 Box::new(col("f")),
3495 ));
3496 let result = apply_type_coercion(&expr, &schema).unwrap();
3497 assert!(contains_udf(&result, "_cypher_equal"));
3499 }
3500
3501 #[test]
3502 fn test_coercion_lb_eq_utf8() {
3503 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3504 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3505 Box::new(col("lb")),
3506 Operator::Eq,
3507 Box::new(col("s")),
3508 ));
3509 let result = apply_type_coercion(&expr, &schema).unwrap();
3510 assert!(contains_udf(&result, "_cypher_equal"));
3512 }
3513
3514 #[test]
3515 fn test_coercion_lb_eq_bool() {
3516 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3517 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3518 Box::new(col("lb")),
3519 Operator::Eq,
3520 Box::new(col("b")),
3521 ));
3522 let result = apply_type_coercion(&expr, &schema).unwrap();
3523 assert!(contains_udf(&result, "_cypher_equal"));
3525 }
3526
3527 #[test]
3528 fn test_coercion_int64_eq_lb() {
3529 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3531 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3532 Box::new(col("i")),
3533 Operator::Eq,
3534 Box::new(col("lb")),
3535 ));
3536 let result = apply_type_coercion(&expr, &schema).unwrap();
3537 assert!(contains_udf(&result, "_cypher_equal"));
3539 }
3540
3541 #[test]
3542 fn test_coercion_float64_gt_lb() {
3543 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3544 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3545 Box::new(col("f")),
3546 Operator::Gt,
3547 Box::new(col("lb")),
3548 ));
3549 let result = apply_type_coercion(&expr, &schema).unwrap();
3550 assert!(contains_udf(&result, "_cypher_gt"));
3552 }
3553
3554 #[test]
3555 fn test_coercion_both_lb_eq() {
3556 let schema = make_schema(&[
3557 ("lb1", DataType::LargeBinary),
3558 ("lb2", DataType::LargeBinary),
3559 ]);
3560 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3561 Box::new(col("lb1")),
3562 Operator::Eq,
3563 Box::new(col("lb2")),
3564 ));
3565 let result = apply_type_coercion(&expr, &schema).unwrap();
3566 assert!(contains_udf(&result, "_cypher_equal"));
3567 }
3568
3569 #[test]
3570 fn test_coercion_both_lb_lt() {
3571 let schema = make_schema(&[
3572 ("lb1", DataType::LargeBinary),
3573 ("lb2", DataType::LargeBinary),
3574 ]);
3575 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3576 Box::new(col("lb1")),
3577 Operator::Lt,
3578 Box::new(col("lb2")),
3579 ));
3580 let result = apply_type_coercion(&expr, &schema).unwrap();
3581 assert!(contains_udf(&result, "_cypher_lt"));
3582 }
3583
3584 #[test]
3585 fn test_coercion_both_lb_noteq() {
3586 let schema = make_schema(&[
3587 ("lb1", DataType::LargeBinary),
3588 ("lb2", DataType::LargeBinary),
3589 ]);
3590 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3591 Box::new(col("lb1")),
3592 Operator::NotEq,
3593 Box::new(col("lb2")),
3594 ));
3595 let result = apply_type_coercion(&expr, &schema).unwrap();
3596 assert!(contains_udf(&result, "_cypher_not_equal"));
3597 }
3598
3599 #[test]
3600 fn test_coercion_lb_plus_int64() {
3601 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3602 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3603 Box::new(col("lb")),
3604 Operator::Plus,
3605 Box::new(col("i")),
3606 ));
3607 let result = apply_type_coercion(&expr, &schema).unwrap();
3608 assert!(contains_udf(&result, "_cypher_add"));
3609 }
3610
3611 #[test]
3612 fn test_coercion_lb_minus_int64() {
3613 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3614 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3615 Box::new(col("lb")),
3616 Operator::Minus,
3617 Box::new(col("i")),
3618 ));
3619 let result = apply_type_coercion(&expr, &schema).unwrap();
3620 assert!(contains_udf(&result, "_cypher_sub"));
3621 }
3622
/// `lb * f` (LargeBinary times Float64): the coerced expression is expected
/// to contain the `_cypher_mul` UDF.
#[test]
fn test_coercion_lb_multiply_float64() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);

    let lhs = Box::new(col("lb"));
    let rhs = Box::new(col("f"));
    let product = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Multiply, rhs));

    let coerced = apply_type_coercion(&product, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_mul"));
}
3634
/// `i + lb` (Int64 on the left, LargeBinary on the right): the coerced
/// expression is expected to contain the `_cypher_add` UDF.
#[test]
fn test_coercion_int64_plus_lb() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);

    let lhs = Box::new(col("i"));
    let rhs = Box::new(col("lb"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let coerced = apply_type_coercion(&sum, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3646
/// `lb + s` (LargeBinary plus Utf8): the coerced expression is expected to
/// contain the `_cypher_add` UDF.
#[test]
fn test_coercion_lb_plus_utf8() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);

    let lhs = Box::new(col("lb"));
    let rhs = Box::new(col("s"));
    let concat = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let coerced = apply_type_coercion(&concat, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3660
/// `NULL AND b`: the rendered result must mention `CAST` or `Boolean`, and
/// `is_binary_op` must still report an `AND` expression.
#[test]
fn test_coercion_and_null_bool() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("b", DataType::Boolean)]);

    let null_side = Box::new(lit(ScalarValue::Null));
    let bool_side = Box::new(col("b"));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(null_side, Operator::And, bool_side));

    let result = apply_type_coercion(&conjunction, &schema).unwrap();
    let s = result.to_string();
    let mentions_cast = s.contains("CAST") || s.contains("Boolean");
    assert!(mentions_cast, "expected cast to Boolean, got: {s}");
    assert!(is_binary_op(&result, Operator::And));
}
3679
/// `b AND NULL`: after coercion `is_binary_op` must still report an `AND`
/// expression.
#[test]
fn test_coercion_bool_and_null() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("b", DataType::Boolean)]);

    let bool_side = Box::new(col("b"));
    let null_side = Box::new(lit(ScalarValue::Null));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(bool_side, Operator::And, null_side));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
}
3691
/// `NULL OR b`: after coercion `is_binary_op` must still report an `OR`
/// expression.
#[test]
fn test_coercion_or_null_bool() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("b", DataType::Boolean)]);

    let null_side = Box::new(lit(ScalarValue::Null));
    let bool_side = Box::new(col("b"));
    let disjunction = DfExpr::BinaryExpr(BinaryExpr::new(null_side, Operator::Or, bool_side));

    let coerced = apply_type_coercion(&disjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::Or));
}
3703
/// `NULL AND NULL` against an empty schema: coercion must succeed and
/// `is_binary_op` must still report an `AND` expression.
#[test]
fn test_coercion_null_and_null() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[]);

    let left_null = Box::new(lit(ScalarValue::Null));
    let right_null = Box::new(lit(ScalarValue::Null));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(left_null, Operator::And, right_null));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
}
3715
/// `a AND b` over two Boolean columns: the result must still be reported as
/// an `AND` expression and its rendering must not contain `CAST`.
#[test]
fn test_coercion_bool_and_bool_noop() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);

    let lhs = Box::new(col("a"));
    let rhs = Box::new(col("b"));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::And, rhs));

    let result = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&result, Operator::And));

    let s = result.to_string();
    let has_cast = s.contains("CAST");
    assert!(!has_cast, "should not contain CAST: {s}");
}
3730
/// `CASE WHEN lb = 42 THEN 'a' ELSE 'b' END`: the rendered result must
/// contain `_cypher_equal`, i.e. coercion reached the WHEN condition.
#[test]
fn test_coercion_case_when_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    let schema = make_schema(&[("lb", DataType::LargeBinary)]);

    let when_cond = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Eq,
        Box::new(lit(42_i64)),
    ));
    let branches = vec![(Box::new(when_cond), Box::new(lit("a")))];
    let fallback = Some(Box::new(lit("b")));
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: fallback,
    });

    let s = apply_type_coercion(&case_expr, &schema)
        .unwrap()
        .to_string();
    assert!(
        s.contains("_cypher_equal"),
        "CASE WHEN should have _cypher_equal, got: {s}"
    );
}
3753
/// `CASE WHEN true THEN lb + 1 ELSE 0 END`: the rendered result must contain
/// `_cypher_add`, i.e. coercion reached the THEN branch.
#[test]
fn test_coercion_case_then_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    let schema = make_schema(&[("lb", DataType::LargeBinary)]);

    let then_expr = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(lit(1_i64)),
    ));
    let branches = vec![(Box::new(lit(true)), Box::new(then_expr))];
    let fallback = Some(Box::new(lit(0_i64)));
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: fallback,
    });

    let s = apply_type_coercion(&case_expr, &schema)
        .unwrap()
        .to_string();
    assert!(
        s.contains("_cypher_add"),
        "CASE THEN should have _cypher_add, got: {s}"
    );
}
3775
/// `CASE WHEN true THEN 1 ELSE lb + 2 END`: the rendered result must contain
/// `_cypher_add`, i.e. coercion reached the ELSE branch.
#[test]
fn test_coercion_case_else_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    let schema = make_schema(&[("lb", DataType::LargeBinary)]);

    let else_expr = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(lit(2_i64)),
    ));
    let branches = vec![(Box::new(lit(true)), Box::new(lit(1_i64)))];
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: Some(Box::new(else_expr)),
    });

    let s = apply_type_coercion(&case_expr, &schema)
        .unwrap()
        .to_string();
    assert!(
        s.contains("_cypher_add"),
        "CASE ELSE should have _cypher_add, got: {s}"
    );
}
3797
/// `a = b` over two Int64 columns: the result must still be reported as an
/// `Eq` expression and its rendering must not mention `_cypher_value`.
#[test]
fn test_coercion_int64_eq_int64_noop() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);

    let lhs = Box::new(col("a"));
    let rhs = Box::new(col("b"));
    let equality = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Eq, rhs));

    let result = apply_type_coercion(&equality, &schema).unwrap();
    assert!(is_binary_op(&result, Operator::Eq));

    let s = result.to_string();
    let decodes_value = s.contains("_cypher_value");
    assert!(
        !decodes_value,
        "should not contain cypher_value decode: {s}"
    );
}
3814
/// `lb1 + lb2` over two LargeBinary columns: the coerced expression is
/// expected to contain the `_cypher_add` UDF.
#[test]
fn test_coercion_both_lb_plus() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [
        ("lb1", DataType::LargeBinary),
        ("lb2", DataType::LargeBinary),
    ];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("lb1"));
    let rhs = Box::new(col("lb2"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let result = apply_type_coercion(&sum, &schema).unwrap();
    let uses_add = contains_udf(&result, "_cypher_add");
    assert!(uses_add, "expected _cypher_add, got: {result}");
}
3833
/// `lst + i` where `lst` is a native `List<Int32>` and `i` is Int32: the
/// coerced expression is expected to contain the `_cypher_list_append` UDF.
#[test]
fn test_coercion_native_list_plus_scalar() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let list_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    let schema = make_schema(&[("lst", list_type), ("i", DataType::Int32)]);

    let list_operand = Box::new(col("lst"));
    let scalar_operand = Box::new(col("i"));
    let append = DfExpr::BinaryExpr(BinaryExpr::new(
        list_operand,
        Operator::Plus,
        scalar_operand,
    ));

    let result = apply_type_coercion(&append, &schema).unwrap();
    let uses_append = contains_udf(&result, "_cypher_list_append");
    assert!(uses_append, "expected _cypher_list_append, got: {result}");
}
3855
/// `lb + i` (LargeBinary plus Int64): the coerced expression is expected to
/// contain the `_cypher_add` UDF; on failure the rendered result is reported.
#[test]
fn test_coercion_lb_plus_int64_unchanged() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);

    let lhs = Box::new(col("lb"));
    let rhs = Box::new(col("i"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let result = apply_type_coercion(&sum, &schema).unwrap();
    let uses_add = contains_udf(&result, "_cypher_add");
    assert!(uses_add, "expected _cypher_add, got: {result}");
}
3871
/// A list literal mixing a variable, an integer and a string must translate
/// successfully, and the rendered result must mention `_make_cypher_list`.
#[test]
fn test_mixed_list_with_variables_compiles() {
    let items = vec![
        Expr::Variable("n".to_string()),
        Expr::Literal(CypherLiteral::Integer(1)),
        Expr::Literal(CypherLiteral::String("hello".to_string())),
    ];
    let list_expr = Expr::List(items);

    let s = cypher_expr_to_df(&list_expr, None).unwrap().to_string();
    let builds_list = s.contains("_make_cypher_list");
    assert!(
        builds_list,
        "expected _make_cypher_list UDF call, got: {s}"
    );
}
3891
/// A list made only of literals (integer, string, bool) must translate to a
/// single `DfExpr::Literal`.
#[test]
fn test_literal_only_mixed_list_uses_cv_fastpath() {
    let items = vec![
        Expr::Literal(CypherLiteral::Integer(1)),
        Expr::Literal(CypherLiteral::String("hi".to_string())),
        Expr::Literal(CypherLiteral::Bool(true)),
    ];
    let list_expr = Expr::List(items);

    let result = cypher_expr_to_df(&list_expr, None).unwrap();
    let is_literal = matches!(result, DfExpr::Literal(..));
    assert!(
        is_literal,
        "expected Literal (CypherValue fast path), got: {result}"
    );
}
3906
/// `1 IN ['1', 2]` (mixed-type literal list): the rendered translation must
/// mention `_cypher_in`.
#[test]
fn test_in_mixed_literal_list_uses_cypher_in() {
    let needle = Expr::Literal(CypherLiteral::Integer(1));
    let haystack = Expr::List(vec![
        Expr::Literal(CypherLiteral::String("1".to_string())),
        Expr::Literal(CypherLiteral::Integer(2)),
    ]);
    let membership = Expr::In {
        expr: Box::new(needle),
        list: Box::new(haystack),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    let uses_cypher_in = s.contains("_cypher_in");
    assert!(
        uses_cypher_in,
        "expected _cypher_in UDF for mixed-type IN list, got: {s}"
    );
}
3928
/// `1 IN [2, 3]` (all-integer literal list): the rendered translation must
/// mention `_cypher_in`.
#[test]
fn test_in_homogeneous_literal_list_uses_cypher_in() {
    let needle = Expr::Literal(CypherLiteral::Integer(1));
    let haystack = Expr::List(vec![
        Expr::Literal(CypherLiteral::Integer(2)),
        Expr::Literal(CypherLiteral::Integer(3)),
    ]);
    let membership = Expr::In {
        expr: Box::new(needle),
        list: Box::new(haystack),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    let uses_cypher_in = s.contains("_cypher_in");
    assert!(
        uses_cypher_in,
        "expected _cypher_in UDF for homogeneous IN list, got: {s}"
    );
}
3946
/// `1 IN [x, 2]` (list containing a variable): the rendered translation must
/// mention both `_cypher_in` and `_make_cypher_list`.
#[test]
fn test_in_list_with_variables_uses_make_cypher_list() {
    let needle = Expr::Literal(CypherLiteral::Integer(1));
    let haystack = Expr::List(vec![
        Expr::Variable("x".to_string()),
        Expr::Literal(CypherLiteral::Integer(2)),
    ]);
    let membership = Expr::In {
        expr: Box::new(needle),
        list: Box::new(haystack),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();

    let uses_cypher_in = s.contains("_cypher_in");
    assert!(uses_cypher_in, "expected _cypher_in UDF, got: {s}");

    let builds_list = s.contains("_make_cypher_list");
    assert!(
        builds_list,
        "expected _make_cypher_list for variable-containing list, got: {s}"
    );
}
3968
/// With `n` registered as a node in the translation context, `n.name` must
/// translate to something whose debug rendering contains `Column` and
/// `n.name`.
#[test]
fn test_property_on_graph_entity_uses_column() {
    let mut ctx = TranslationContext::new();
    ctx.variable_kinds
        .insert("n".to_string(), VariableKind::Node);

    let base = Box::new(Expr::Variable("n".to_string()));
    let property = Expr::Property(base, "name".to_string());

    let translated = cypher_expr_to_df(&property, Some(&ctx)).unwrap();
    let s = format!("{translated:?}");
    let is_flat_column = s.contains("Column") && s.contains("n.name");
    assert!(
        is_flat_column,
        "expected flat column 'n.name' for graph entity, got: {s}"
    );
}
3991
/// With a fresh context (no variable kinds registered), `map.name` must
/// translate to something whose rendering mentions `index`.
#[test]
fn test_property_on_non_graph_var_uses_index() {
    let ctx = TranslationContext::new();

    let base = Box::new(Expr::Variable("map".to_string()));
    let property = Expr::Property(base, "name".to_string());

    let s = cypher_expr_to_df(&property, Some(&ctx))
        .unwrap()
        .to_string();
    let uses_index = s.contains("index");
    assert!(
        uses_index,
        "expected index UDF for non-graph variable, got: {s}"
    );
}
4008
/// A one-entry `Value::Map` must convert to `ScalarValue::Struct`.
#[test]
fn test_value_to_scalar_non_empty_map_becomes_struct() {
    let entries = std::collections::HashMap::from([("k".to_string(), Value::Int(1))]);
    let input = Value::Map(entries);

    let scalar = value_to_scalar(&input).unwrap();
    let is_struct = matches!(scalar, ScalarValue::Struct(_));
    assert!(is_struct, "expected Struct scalar for map input");
}
4019
/// An empty `Value::Map` must also convert to `ScalarValue::Struct`.
#[test]
fn test_value_to_scalar_empty_map_becomes_struct() {
    let empty = Value::Map(Default::default());

    let scalar = value_to_scalar(&empty).unwrap();
    let is_struct = matches!(scalar, ScalarValue::Struct(_));
    assert!(is_struct, "empty map should produce an empty Struct scalar");
}
4028
/// `Value::Null` must convert to the untyped `ScalarValue::Null`.
#[test]
fn test_value_to_scalar_null_is_untyped_null() {
    let input = Value::Null;

    let scalar = value_to_scalar(&input).unwrap();
    let is_untyped_null = matches!(scalar, ScalarValue::Null);
    assert!(
        is_untyped_null,
        "expected untyped Null scalar for Value::Null"
    );
}
4037
/// A `DateTime` with a named timezone must convert to a single-row struct
/// scalar with three fields (`nanos_since_epoch`, `offset_seconds`,
/// `timezone_name`) carrying the input values in the expected Arrow array
/// types.
#[test]
fn test_value_to_scalar_datetime_produces_struct() {
    let datetime = Value::Temporal(TemporalValue::DateTime {
        nanos_since_epoch: 441_763_200_000_000_000,
        offset_seconds: 3600,
        timezone_name: Some("Europe/Paris".to_string()),
    });

    let struct_arr = match value_to_scalar(&datetime).unwrap() {
        ScalarValue::Struct(arr) => arr,
        scalar => panic!(
            "Expected ScalarValue::Struct for DateTime, got {:?}",
            scalar
        ),
    };

    assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
    assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");

    // Field order matters: the columns below are looked up by position.
    let fields = struct_arr.fields();
    let expected_names = ["nanos_since_epoch", "offset_seconds", "timezone_name"];
    for (idx, expected) in expected_names.into_iter().enumerate() {
        assert_eq!(fields[idx].name(), expected);
    }

    let nanos_col = struct_arr.column(0);
    let offset_col = struct_arr.column(1);
    let tz_col = struct_arr.column(2);

    let Some(nanos_arr) = nanos_col
        .as_any()
        .downcast_ref::<TimestampNanosecondArray>()
    else {
        panic!("Expected TimestampNanosecondArray for nanos field");
    };
    assert_eq!(nanos_arr.value(0), 441_763_200_000_000_000);

    let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() else {
        panic!("Expected Int32Array for offset field");
    };
    assert_eq!(offset_arr.value(0), 3600);

    let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() else {
        panic!("Expected StringArray for timezone_name field");
    };
    assert_eq!(tz_arr.value(0), "Europe/Paris");
}
4092
/// A `DateTime` without a timezone name must still convert to a three-column
/// struct scalar, with a null entry in the `StringArray` at column 2.
#[test]
fn test_value_to_scalar_datetime_with_null_timezone() {
    let datetime = Value::Temporal(TemporalValue::DateTime {
        nanos_since_epoch: 1_704_067_200_000_000_000,
        offset_seconds: -18_000,
        timezone_name: None,
    });

    let ScalarValue::Struct(struct_arr) = value_to_scalar(&datetime).unwrap() else {
        panic!("Expected ScalarValue::Struct for DateTime");
    };
    assert_eq!(struct_arr.num_columns(), 3);

    let tz_col = struct_arr.column(2);
    let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() else {
        panic!("Expected StringArray for timezone_name field");
    };
    assert!(tz_arr.is_null(0), "expected null timezone_name");
}
4118
/// A `Time` value must convert to a single-row struct scalar with two fields
/// (`nanos_since_midnight`, `offset_seconds`) carrying the input values in
/// the expected Arrow array types.
#[test]
fn test_value_to_scalar_time_produces_struct() {
    let time = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 37_845_000_000_000,
        offset_seconds: 3600,
    });

    let struct_arr = match value_to_scalar(&time).unwrap() {
        ScalarValue::Struct(arr) => arr,
        scalar => panic!("Expected ScalarValue::Struct for Time, got {:?}", scalar),
    };

    assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
    assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");

    // Field order matters: the columns below are looked up by position.
    let fields = struct_arr.fields();
    let expected_names = ["nanos_since_midnight", "offset_seconds"];
    for (idx, expected) in expected_names.into_iter().enumerate() {
        assert_eq!(fields[idx].name(), expected);
    }

    let nanos_col = struct_arr.column(0);
    let offset_col = struct_arr.column(1);

    let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() else {
        panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
    };
    assert_eq!(nanos_arr.value(0), 37_845_000_000_000);

    let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() else {
        panic!("Expected Int32Array for offset field");
    };
    assert_eq!(offset_arr.value(0), 3600);
}
4158
/// Midnight with a zero offset must convert to a struct scalar whose first
/// column is a `Time64NanosecondArray` holding 0.
#[test]
fn test_value_to_scalar_time_boundary_values() {
    let midnight = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 0,
        offset_seconds: 0,
    });

    let ScalarValue::Struct(struct_arr) = value_to_scalar(&midnight).unwrap() else {
        panic!("Expected ScalarValue::Struct for Time");
    };

    let nanos_col = struct_arr.column(0);
    let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() else {
        panic!("Expected Time64NanosecondArray");
    };
    assert_eq!(nanos_arr.value(0), 0);
}
4180}