1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Suffix of the per-variable identity column used for node variables
/// (a node variable `n` is addressed as the column `n._vid`).
const COL_VID: &str = "_vid";
/// Suffix of the per-variable identity column used for edge variables (`r._eid`).
const COL_EID: &str = "_eid";
/// Suffix of the column probed with `array_has` when checking labels on a
/// non-edge variable (`n._labels`).
const COL_LABELS: &str = "_labels";
/// Suffix of the column compared against the label when checking the type of
/// an edge variable (`r._type`).
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub(crate) fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub(crate) fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub(crate) fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Inserts `:00` seconds into a datetime string of the shape
/// `..........THH:MM[suffix]` that has no seconds component.
///
/// Returns `None` when the string is shorter than 16 bytes, has no `T` at
/// byte 10, does not have `DD:DD` (digits) at bytes 11..16, or already has a
/// `:` at byte 16 (i.e. seconds are present). The first ten bytes (the date
/// part) are not validated. Anything after byte 16 (offset, zone, ...) is
/// preserved after the inserted seconds.
pub(crate) fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();

    // Need at least `YYYY-MM-DDTHH:MM` and the `T` separator in place.
    if bytes.len() < 16 || bytes[10] != b'T' {
        return None;
    }

    let has_hour_minute = match &bytes[11..16] {
        [h1, h2, b':', m1, m2] => {
            h1.is_ascii_digit()
                && h2.is_ascii_digit()
                && m1.is_ascii_digit()
                && m2.is_ascii_digit()
        }
        _ => false,
    };
    if !has_hour_minute {
        return None;
    }

    // A colon right after the minutes means seconds are already there.
    if bytes.get(16) == Some(&b':') {
        return None;
    }

    // Byte 15 is an ASCII digit, so index 16 is always a char boundary.
    let (head, tail) = s.split_at(16);
    let mut normalized = String::with_capacity(s.len() + 3);
    normalized.push_str(head);
    normalized.push_str(":00");
    normalized.push_str(tail);
    Some(normalized)
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
/// Names of scalar functions whose call expressions `is_cypher_list_expr`
/// recognizes as list-producing.
const CYPHER_LIST_FUNCS: &[&str] = &[
    "_make_cypher_list",
    "_cypher_list_concat",
    "_cypher_list_append",
];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// Classification of a bound query variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node variable.
    Node,
    /// A single (fixed-length) edge variable.
    Edge,
    /// An edge variable bound by a variable-length pattern.
    EdgeList,
    /// A path variable.
    Path,
}

impl VariableKind {
    /// Returns the kind for an edge variable: [`VariableKind::EdgeList`] when
    /// the pattern is variable-length, otherwise [`VariableKind::Edge`].
    pub fn edge_for(is_variable_length: bool) -> Self {
        if !is_variable_length {
            return Self::Edge;
        }
        Self::EdgeList
    }
}
246
247pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
282 match expr {
283 Expr::PatternComprehension { .. } => Err(anyhow!(
284 "Pattern comprehensions require fallback executor (graph traversal)"
285 )),
286 Expr::Wildcard => Ok(DfExpr::Literal(
290 datafusion::common::ScalarValue::Int32(Some(1)),
291 None,
292 )),
293
294 Expr::Variable(name) => {
295 if let Some(ctx) = context
300 && ctx.variable_kinds.contains_key(name)
301 {
302 return Ok(DfExpr::Column(Column::from_name(name)));
303 }
304
305 if let Some(ctx) = context
311 && let Some(value) = ctx.outer_values.get(name)
312 {
313 return value_to_scalar(value).map(lit);
314 }
315
316 if let Some(ctx) = context
321 && let Some(value) = ctx.parameters.get(name)
322 {
323 match value {
326 Value::List(values) if name.ends_with("._vid") => {
327 let literals = values
329 .iter()
330 .map(|v| value_to_scalar(v).map(lit))
331 .collect::<Result<Vec<_>>>()?;
332 return Ok(DfExpr::InList(InList {
333 expr: Box::new(DfExpr::Column(Column::from_name(name))),
334 list: literals,
335 negated: false,
336 }));
337 }
338 other_value => return value_to_scalar(other_value).map(lit),
339 }
340 }
341
342 Ok(DfExpr::Column(Column::from_name(name)))
345 }
346
347 Expr::Property(base, prop) => translate_property_access(base, prop, context),
348
349 Expr::ArrayIndex { array, index } => {
350 if let Ok(var_name) = extract_variable_name(array)
353 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
354 {
355 let col_name = format!("{}.{}", var_name, prop_name);
356 return Ok(DfExpr::Column(Column::from_name(col_name)));
357 }
358
359 let array_expr = cypher_expr_to_df(array, context)?;
360 let index_expr = cypher_expr_to_df(index, context)?;
361
362 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
364 }
365
366 Expr::ArraySlice { array, start, end } => {
367 let array_expr = cypher_expr_to_df(array, context)?;
371
372 let start_expr = match start {
373 Some(s) => cypher_expr_to_df(s, context)?,
374 None => lit(0i64),
375 };
376
377 let end_expr = match end {
378 Some(e) => cypher_expr_to_df(e, context)?,
379 None => lit(i64::MAX),
380 };
381
382 Ok(dummy_udf_expr(
385 "_cypher_list_slice",
386 vec![array_expr, start_expr, end_expr],
387 ))
388 }
389
390 Expr::Parameter(name) => {
391 if let Some(ctx) = context
393 && let Some(value) = ctx.parameters.get(name)
394 {
395 return value_to_scalar(value).map(lit);
396 }
397 Err(anyhow!("Unresolved parameter: ${}", name))
398 }
399
400 Expr::Literal(value) => {
401 let scalar = cypher_literal_to_scalar(value)?;
402 Ok(lit(scalar))
403 }
404
405 Expr::List(items) => translate_list_literal(items, context),
406
407 Expr::Map(entries) => {
408 if entries.is_empty() {
409 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
411 Default::default(),
412 ));
413 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
414 }
415 let mut args = Vec::with_capacity(entries.len() * 2);
418 for (key, val_expr) in entries {
419 args.push(lit(key.clone()));
420 args.push(cypher_expr_to_df(val_expr, context)?);
421 }
422 Ok(datafusion::functions::expr_fn::named_struct(args))
423 }
424
425 Expr::IsNull(inner) => translate_null_check(inner, context, true),
426
427 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
428
429 Expr::IsUnique(_) => {
430 Err(anyhow!(
432 "IS UNIQUE can only be used in constraint definitions"
433 ))
434 }
435
436 Expr::FunctionCall {
437 name,
438 args,
439 distinct,
440 window_spec,
441 } => {
442 if window_spec.is_some() {
445 let col_name = expr.to_string_repr();
447 Ok(col(&col_name))
448 } else {
449 translate_function_call(name, args, *distinct, context)
450 }
451 }
452
453 Expr::In { expr, list } => translate_in_expression(expr, list, context),
454
455 Expr::BinaryOp { left, op, right } => {
456 let left_expr = cypher_expr_to_df(left, context)?;
457 let right_expr = cypher_expr_to_df(right, context)?;
458 translate_binary_op(left_expr, op, right_expr)
459 }
460
461 Expr::UnaryOp { op, expr: inner } => {
462 let inner_expr = cypher_expr_to_df(inner, context)?;
463 match op {
464 UnaryOp::Not => Ok(inner_expr.not()),
465 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
466 }
467 }
468
469 Expr::Case {
470 expr,
471 when_then,
472 else_expr,
473 } => translate_case_expression(expr, when_then, else_expr, context),
474
475 Expr::Reduce { .. } => Err(anyhow!(
476 "Reduce expressions not yet supported in DataFusion translation"
477 )),
478
479 Expr::Exists { .. } => Err(anyhow!(
480 "EXISTS subqueries are handled by the physical expression compiler, \
481 not the DataFusion logical expression translator"
482 )),
483
484 Expr::CountSubquery(_) => Err(anyhow!(
485 "Count subqueries not yet supported in DataFusion translation"
486 )),
487
488 Expr::CollectSubquery(_) => Err(anyhow!(
489 "COLLECT subqueries not yet supported in DataFusion translation"
490 )),
491
492 Expr::Quantifier { .. } => {
493 Err(anyhow!(
498 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
499 via CypherPhysicalExprCompiler"
500 ))
501 }
502
503 Expr::ListComprehension { .. } => {
504 Err(anyhow!(
516 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
517 ))
518 }
519
520 Expr::ValidAt { .. } => {
521 Err(anyhow!(
524 "VALID_AT expression should have been transformed to function call in planner"
525 ))
526 }
527
528 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
529
530 Expr::LabelCheck { expr, labels } => {
531 if let Expr::Variable(var) = expr.as_ref() {
532 let is_edge = context
534 .and_then(|ctx| ctx.variable_kinds.get(var))
535 .is_some_and(|k| matches!(k, VariableKind::Edge));
536
537 if is_edge {
538 if labels.len() > 1 {
542 Ok(lit(false))
543 } else {
544 let type_col =
545 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
546 Ok(DfExpr::Case(datafusion::logical_expr::Case {
548 expr: None,
549 when_then_expr: vec![(
550 Box::new(type_col.clone().is_null()),
551 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
552 )],
553 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
554 }))
555 }
556 } else {
557 let labels_col =
559 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
560 let checks = labels
561 .iter()
562 .map(|label| {
563 datafusion::functions_nested::expr_fn::array_has(
564 labels_col.clone(),
565 lit(label.clone()),
566 )
567 })
568 .reduce(|acc, check| acc.and(check));
569 Ok(DfExpr::Case(datafusion::logical_expr::Case {
571 expr: None,
572 when_then_expr: vec![(
573 Box::new(labels_col.is_null()),
574 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
575 )],
576 else_expr: Some(Box::new(checks.unwrap())),
577 }))
578 }
579 } else {
580 Err(anyhow!(
581 "LabelCheck on non-variable expression not yet supported in DataFusion"
582 ))
583 }
584 }
585 }
586}
587
/// Context consulted while translating Cypher expressions to DataFusion.
#[derive(Debug, Clone)]
pub struct TranslationContext {
    /// Parameter values by name. Looked up for `$param` expressions and also
    /// under variable / `var.prop` names during variable and property
    /// translation.
    pub parameters: std::collections::HashMap<String, Value>,

    /// Values substituted as literals for variables that are not listed in
    /// `variable_kinds`; checked before `parameters`.
    /// NOTE(review): presumably bindings from an enclosing query scope — confirm.
    pub outer_values: std::collections::HashMap<String, Value>,

    /// Variable name -> label, populated via `with_variable_label`.
    pub variable_labels: std::collections::HashMap<String, String>,

    /// Variable name -> kind; drives how variables, null checks, `IN` and
    /// label checks pick their columns.
    pub variable_kinds: std::collections::HashMap<String, VariableKind>,

    /// NOTE(review): not read in the visible part of this file; presumably
    /// names of variables known to be nodes — confirm.
    pub node_variable_hints: Vec<String>,

    /// NOTE(review): not read in the visible part of this file; presumably
    /// names of edge variables involved in mutations — confirm.
    pub mutation_edge_hints: Vec<String>,

    /// Timestamp associated with the statement; `Default` sets it to the
    /// current UTC time when the context is created.
    pub statement_time: chrono::DateTime<chrono::Utc>,
}
621
622impl Default for TranslationContext {
623 fn default() -> Self {
624 Self {
625 parameters: std::collections::HashMap::new(),
626 outer_values: std::collections::HashMap::new(),
627 variable_labels: std::collections::HashMap::new(),
628 variable_kinds: std::collections::HashMap::new(),
629 node_variable_hints: Vec::new(),
630 mutation_edge_hints: Vec::new(),
631 statement_time: chrono::Utc::now(),
632 }
633 }
634}
635
636impl TranslationContext {
637 pub fn new() -> Self {
639 Self::default()
640 }
641
642 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
644 self.parameters.insert(name.into(), value);
645 self
646 }
647
648 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
650 self.variable_labels.insert(var.into(), label.into());
651 self
652 }
653}
654
655fn extract_variable_name(expr: &Expr) -> Result<String> {
657 match expr {
658 Expr::Variable(name) => Ok(name.clone()),
659 Expr::Property(base, _) => extract_variable_name(base),
660 _ => Err(anyhow!(
661 "Cannot extract variable name from expression: {:?}",
662 expr
663 )),
664 }
665}
666
667fn translate_null_check(
669 inner: &Expr,
670 context: Option<&TranslationContext>,
671 is_null: bool,
672) -> Result<DfExpr> {
673 if let Expr::Variable(var) = inner
674 && let Some(ctx) = context
675 && let Some(kind) = ctx.variable_kinds.get(var)
676 {
677 let col_name = match kind {
678 VariableKind::Node => format!("{}.{}", var, COL_VID),
679 VariableKind::Edge => format!("{}.{}", var, COL_EID),
680 VariableKind::Path | VariableKind::EdgeList => var.clone(),
681 };
682 let col_expr = DfExpr::Column(Column::from_name(col_name));
683 return Ok(if is_null {
684 col_expr.is_null()
685 } else {
686 col_expr.is_not_null()
687 });
688 }
689
690 let inner_expr = cypher_expr_to_df(inner, context)?;
691 Ok(if is_null {
692 inner_expr.is_null()
693 } else {
694 inner_expr.is_not_null()
695 })
696}
697
698fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
703 if crate::query::datetime::is_duration_accessor(prop) {
704 Some(dummy_udf_expr(
705 "_duration_property",
706 vec![base_expr, lit(prop.to_string())],
707 ))
708 } else if crate::query::datetime::is_temporal_accessor(prop) {
709 Some(dummy_udf_expr(
710 "_temporal_property",
711 vec![base_expr, lit(prop.to_string())],
712 ))
713 } else {
714 None
715 }
716}
717
718fn translate_property_access(
720 base: &Expr,
721 prop: &str,
722 context: Option<&TranslationContext>,
723) -> Result<DfExpr> {
724 if let Ok(var_name) = extract_variable_name(base) {
725 let is_graph_entity = context
726 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
727 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
728
729 if !is_graph_entity
730 && let Some(expr) =
731 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
732 {
733 return Ok(expr);
734 }
735
736 let col_name = format!("{}.{}", var_name, prop);
737
738 if let Some(ctx) = context
741 && let Some(value) = ctx.parameters.get(&col_name)
742 {
743 match value {
746 Value::List(values) if col_name.ends_with("._vid") => {
747 let literals = values
748 .iter()
749 .map(|v| value_to_scalar(v).map(lit))
750 .collect::<Result<Vec<_>>>()?;
751 return Ok(DfExpr::InList(InList {
752 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
753 list: literals,
754 negated: false,
755 }));
756 }
757 other_value => return value_to_scalar(other_value).map(lit),
758 }
759 }
760
761 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
764 let base_expr = cypher_expr_to_df(base, context)?;
765 return Ok(dummy_udf_expr(
766 "index",
767 vec![base_expr, lit(prop.to_string())],
768 ));
769 }
770
771 if is_graph_entity {
772 Ok(DfExpr::Column(Column::from_name(col_name)))
773 } else {
774 let base_expr = DfExpr::Column(Column::from_name(var_name));
775 Ok(dummy_udf_expr(
776 "index",
777 vec![base_expr, lit(prop.to_string())],
778 ))
779 }
780 } else {
781 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
783 return Ok(expr);
784 }
785
786 if let Expr::Parameter(param_name) = base {
788 if let Some(ctx) = context
789 && let Some(value) = ctx.parameters.get(param_name)
790 {
791 if let Value::Map(map) = value {
792 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
793 return value_to_scalar(&extracted).map(lit);
794 }
795 return Ok(lit(ScalarValue::Null));
796 }
797 return Err(anyhow!("Unresolved parameter: ${}", param_name));
798 }
799
800 let base_expr = cypher_expr_to_df(base, context)?;
801 Ok(dummy_udf_expr(
802 "index",
803 vec![base_expr, lit(prop.to_string())],
804 ))
805 }
806}
807
808fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
810 let mut has_string = false;
812 let mut has_bool = false;
813 let mut has_list = false;
814 let mut has_map = false;
815 let mut has_numeric = false;
816 let mut has_graph_entity = false;
817 let mut has_temporal = false;
818
819 for item in items {
820 match item {
821 Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
822 has_numeric = true
823 }
824 Expr::Literal(CypherLiteral::String(_)) => has_string = true,
825 Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
826 Expr::List(_) => has_list = true,
827 Expr::Map(_) => has_map = true,
828 Expr::Variable(name) => {
831 if context
832 .and_then(|ctx| ctx.variable_kinds.get(name))
833 .is_some()
834 {
835 has_graph_entity = true;
836 }
837 }
838 Expr::FunctionCall { name, .. } => {
841 let upper = name.to_uppercase();
842 if matches!(
843 upper.as_str(),
844 "DATE"
845 | "TIME"
846 | "LOCALTIME"
847 | "LOCALDATETIME"
848 | "DATETIME"
849 | "DURATION"
850 | "DATE.TRUNCATE"
851 | "TIME.TRUNCATE"
852 | "DATETIME.TRUNCATE"
853 | "LOCALDATETIME.TRUNCATE"
854 | "LOCALTIME.TRUNCATE"
855 ) {
856 has_temporal = true;
857 }
858 }
859 _ => {}
861 }
862 }
863
864 let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;
866
867 if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
870 if let Some(json_array) = try_items_to_json(items) {
872 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
873 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
874 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
875 }
876 let df_args: Vec<DfExpr> = items
878 .iter()
879 .map(|item| cypher_expr_to_df(item, context))
880 .collect::<Result<_>>()?;
881 return Ok(dummy_udf_expr("_make_cypher_list", df_args));
882 }
883
884 let mut df_args = Vec::with_capacity(items.len());
887 let mut has_float = false;
888 let mut has_int = false;
889 let mut has_other = false;
890
891 for item in items {
892 match item {
893 Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
894 Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
895 _ => has_other = true,
896 }
897 df_args.push(cypher_expr_to_df(item, context)?);
898 }
899
900 if df_args.is_empty() {
901 let empty_arr =
903 ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
904 Ok(lit(ScalarValue::List(empty_arr)))
905 } else if has_float && has_int && !has_other {
906 let promoted_args = df_args
908 .into_iter()
909 .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
910 .collect();
911 Ok(datafusion::functions_nested::expr_fn::make_array(
912 promoted_args,
913 ))
914 } else {
915 let non_null_type = df_args.iter().find_map(|e| {
919 if let DfExpr::Literal(sv, _) = e {
920 let dt = sv.data_type();
921 if dt != datafusion::arrow::datatypes::DataType::Null {
922 return Some(dt);
923 }
924 }
925 None
926 });
927 if let Some(ref target_type) = non_null_type {
928 let coerced = df_args
929 .into_iter()
930 .map(|e| {
931 if matches!(&e, DfExpr::Literal(sv, _) if sv.data_type() == datafusion::arrow::datatypes::DataType::Null)
932 {
933 cast_expr(e, target_type.clone())
934 } else {
935 e
936 }
937 })
938 .collect();
939 Ok(datafusion::functions_nested::expr_fn::make_array(coerced))
940 } else {
941 Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
942 }
943 }
944}
945
946fn translate_in_expression(
948 expr: &Expr,
949 list: &Expr,
950 context: Option<&TranslationContext>,
951) -> Result<DfExpr> {
952 let left_expr = if let Expr::Variable(var) = expr
957 && let Some(ctx) = context
958 && let Some(kind) = ctx.variable_kinds.get(var)
959 {
960 match kind {
961 VariableKind::Node | VariableKind::Edge => {
962 let id_col = match kind {
963 VariableKind::Node => COL_VID,
964 VariableKind::Edge => COL_EID,
965 _ => unreachable!(),
966 };
967 cast_expr(
968 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
969 datafusion::arrow::datatypes::DataType::Int64,
970 )
971 }
972 _ => cypher_expr_to_df(expr, context)?,
973 }
974 } else {
975 cypher_expr_to_df(expr, context)?
976 };
977
978 if let Expr::List(items) = list {
983 if let Some(json_array) = try_items_to_json(items) {
984 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
986 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
987 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
988 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
989 } else {
990 let expanded: Vec<DfExpr> = items
992 .iter()
993 .map(|item| cypher_expr_to_df(item, context))
994 .collect::<Result<Vec<_>>>()?;
995 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
996 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
997 }
998 } else {
999 let right_expr = cypher_expr_to_df(list, context)?;
1000
1001 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
1006 return Ok(lit(ScalarValue::Boolean(None)));
1007 }
1008
1009 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
1010 }
1011}
1012
1013fn translate_case_expression(
1015 operand: &Option<Box<Expr>>,
1016 when_then: &[(Expr, Expr)],
1017 else_expr: &Option<Box<Expr>>,
1018 context: Option<&TranslationContext>,
1019) -> Result<DfExpr> {
1020 let mut case_builder = if let Some(match_expr) = operand {
1021 let match_df = cypher_expr_to_df(match_expr, context)?;
1022 datafusion::logical_expr::case(match_df)
1023 } else {
1024 datafusion::logical_expr::when(
1025 cypher_expr_to_df(&when_then[0].0, context)?,
1026 cypher_expr_to_df(&when_then[0].1, context)?,
1027 )
1028 };
1029
1030 let start_idx = if operand.is_some() { 0 } else { 1 };
1031 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1032 let when_df = cypher_expr_to_df(when_expr, context)?;
1033 let then_df = cypher_expr_to_df(then_expr, context)?;
1034 case_builder = case_builder.when(when_df, then_df);
1035 }
1036
1037 if let Some(else_e) = else_expr {
1038 let else_df = cypher_expr_to_df(else_e, context)?;
1039 Ok(case_builder.otherwise(else_df)?)
1040 } else {
1041 Ok(case_builder.end()?)
1042 }
1043}
1044
1045fn translate_map_projection(
1047 base: &Expr,
1048 items: &[MapProjectionItem],
1049 context: Option<&TranslationContext>,
1050) -> Result<DfExpr> {
1051 let mut args = Vec::new();
1052 for item in items {
1053 match item {
1054 MapProjectionItem::Property(prop) => {
1055 args.push(lit(prop.clone()));
1056 let prop_expr = cypher_expr_to_df(
1057 &Expr::Property(Box::new(base.clone()), prop.clone()),
1058 context,
1059 )?;
1060 args.push(prop_expr);
1061 }
1062 MapProjectionItem::LiteralEntry(key, expr) => {
1063 args.push(lit(key.clone()));
1064 args.push(cypher_expr_to_df(expr, context)?);
1065 }
1066 MapProjectionItem::Variable(var) => {
1067 args.push(lit(var.clone()));
1068 args.push(DfExpr::Column(Column::from_name(var)));
1069 }
1070 MapProjectionItem::AllProperties => {
1071 args.push(lit("__all__"));
1072 args.push(cypher_expr_to_df(base, context)?);
1073 }
1074 }
1075 }
1076 Ok(dummy_udf_expr("_map_project", args))
1077}
1078
1079fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1082 match expr {
1083 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1084 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1085 Expr::Literal(CypherLiteral::Integer(i)) => {
1086 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1087 }
1088 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1089 .map(serde_json::Value::Number)
1090 .or(Some(serde_json::Value::Null)),
1091 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1092 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1093 Expr::Map(entries) => {
1094 let mut map = serde_json::Map::new();
1095 for (k, v) in entries {
1096 map.insert(k.clone(), try_expr_to_json(v)?);
1097 }
1098 Some(serde_json::Value::Object(map))
1099 }
1100 _ => None,
1101 }
1102}
1103
1104fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1106 items.iter().map(try_expr_to_json).collect()
1107}
1108
1109fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1111 match lit {
1112 CypherLiteral::Null => Ok(ScalarValue::Null),
1113 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1114 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1115 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1116 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1117 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1118 }
1119}
1120
1121fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1123 match value {
1124 Value::Null => Ok(ScalarValue::Null),
1125 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1126 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1127 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1128 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1129 Value::List(items) => {
1130 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1132 let scalars = scalars?;
1133
1134 let data_type = infer_common_scalar_type(&scalars);
1136
1137 let typed_scalars: Vec<ScalarValue> = scalars
1139 .into_iter()
1140 .map(|s| {
1141 if matches!(s, ScalarValue::Null) {
1142 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1143 }
1144
1145 match (s, &data_type) {
1146 (
1147 ScalarValue::Int64(Some(v)),
1148 datafusion::arrow::datatypes::DataType::Float64,
1149 ) => ScalarValue::Float64(Some(v as f64)),
1150 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1151 let s_str = s.to_string();
1153 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1154 }
1155 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1156 if matches!(s, ScalarValue::Utf8(_)) {
1158 s
1159 } else {
1160 ScalarValue::Utf8(Some(s.to_string()))
1161 }
1162 }
1163 (s, _) => s,
1164 }
1165 })
1166 .collect();
1167
1168 if typed_scalars.is_empty() {
1170 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1171 &[],
1172 &data_type,
1173 )))
1174 } else {
1175 Ok(ScalarValue::List(ScalarValue::new_list(
1176 &typed_scalars,
1177 &data_type,
1178 true,
1179 )))
1180 }
1181 }
1182 Value::Map(map) => {
1183 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1186 entries.sort_by_key(|(k, _)| *k);
1187
1188 if entries.is_empty() {
1189 return Ok(ScalarValue::Struct(Arc::new(
1190 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1191 )));
1192 }
1193
1194 let mut fields_arrays = Vec::with_capacity(entries.len());
1195
1196 for (k, v) in entries {
1197 let scalar = value_to_scalar(v)?;
1198 let dt = scalar.data_type();
1199 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1200 let array = scalar.to_array()?;
1201 fields_arrays.push((field, array));
1202 }
1203
1204 Ok(ScalarValue::Struct(Arc::new(
1205 datafusion::arrow::array::StructArray::from(fields_arrays),
1206 )))
1207 }
1208 Value::Temporal(tv) => {
1209 use uni_common::TemporalValue;
1210 match tv {
1211 TemporalValue::Date { days_since_epoch } => {
1212 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1213 }
1214 TemporalValue::LocalTime {
1215 nanos_since_midnight,
1216 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1217 TemporalValue::Time {
1218 nanos_since_midnight,
1219 offset_seconds,
1220 } => {
1221 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1223 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1224
1225 let nanos_arr =
1226 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1227 as ArrayRef;
1228 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1229
1230 let fields = Fields::from(vec![
1231 Field::new(
1232 "nanos_since_midnight",
1233 ArrowDataType::Time64(TimeUnit::Nanosecond),
1234 true,
1235 ),
1236 Field::new("offset_seconds", ArrowDataType::Int32, true),
1237 ]);
1238
1239 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1240 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1241 }
1242 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1243 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1244 ),
1245 TemporalValue::DateTime {
1246 nanos_since_epoch,
1247 offset_seconds,
1248 timezone_name,
1249 } => {
1250 use arrow::array::{
1252 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1253 };
1254 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1255
1256 let nanos_arr =
1257 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1258 as ArrayRef;
1259 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1260 let tz_arr =
1261 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1262
1263 let fields = Fields::from(vec![
1264 Field::new(
1265 "nanos_since_epoch",
1266 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1267 true,
1268 ),
1269 Field::new("offset_seconds", ArrowDataType::Int32, true),
1270 Field::new("timezone_name", ArrowDataType::Utf8, true),
1271 ]);
1272
1273 let struct_arr =
1274 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1275 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1276 }
1277 TemporalValue::Duration {
1278 months,
1279 days,
1280 nanos,
1281 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1282 arrow::datatypes::IntervalMonthDayNano {
1283 months: *months as i32,
1284 days: *days as i32,
1285 nanoseconds: *nanos,
1286 },
1287 ))),
1288 }
1289 }
1290 Value::Vector(v) => {
1291 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1293 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1294 }
1295 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1296 other => {
1298 let json_val: serde_json::Value = other.clone().into();
1299 let json_str = serde_json::to_string(&json_val)
1300 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1301 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1302 }
1303 }
1304}
1305
1306fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1308 match op {
1309 BinaryOp::Eq => Ok(left.eq(right)),
1313 BinaryOp::NotEq => Ok(left.not_eq(right)),
1314 BinaryOp::Lt => Ok(left.lt(right)),
1315 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1316 BinaryOp::Gt => Ok(left.gt(right)),
1317 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1318
1319 BinaryOp::And => Ok(left.and(right)),
1321 BinaryOp::Or => Ok(left.or(right)),
1322 BinaryOp::Xor => {
1323 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1325 }
1326
1327 BinaryOp::Add => {
1329 if is_list_expr(&left) || is_list_expr(&right) {
1330 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1331 } else {
1332 Ok(left + right)
1333 }
1334 }
1335 BinaryOp::Sub => Ok(left - right),
1336 BinaryOp::Mul => Ok(left * right),
1337 BinaryOp::Div => Ok(left / right),
1338 BinaryOp::Mod => Ok(left % right),
1339 BinaryOp::Pow => {
1340 let left_f = datafusion::logical_expr::cast(
1343 left,
1344 datafusion::arrow::datatypes::DataType::Float64,
1345 );
1346 let right_f = datafusion::logical_expr::cast(
1347 right,
1348 datafusion::arrow::datatypes::DataType::Float64,
1349 );
1350 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1351 }
1352
1353 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1355 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1356 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1357
1358 BinaryOp::Regex => {
1359 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1360 }
1361
1362 BinaryOp::ApproxEq => Err(anyhow!(
1363 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1364 )),
1365 }
1366}
1367
/// Bails out of the enclosing translator with `Some(Err(..))` when fewer
/// arguments than required were supplied.
///
/// The literal-`1` arm delegates to `require_arg`; any other count goes through
/// `require_args`. Only usable inside functions returning `Option<Result<_>>`.
macro_rules! check_args {
    (1, $df_args:expr, $name:expr) => {
        match require_arg($df_args, $name) {
            Ok(()) => {}
            Err(err) => return Some(Err(err)),
        }
    };
    ($n:expr, $df_args:expr, $name:expr) => {
        match require_args($df_args, $n, $name) {
            Ok(()) => {}
            Err(err) => return Some(Err(err)),
        }
    };
}
1384
1385fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1388 if df_args.len() < count {
1389 let noun = if count == 1 { "argument" } else { "arguments" };
1390 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1391 }
1392 Ok(())
1393}
1394
1395fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1397 require_args(df_args, 1, func_name)
1398}
1399
1400fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1402 df_args[0].clone()
1403}
1404
1405pub(crate) fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1407 DfExpr::Cast(datafusion::logical_expr::Cast {
1408 expr: Box::new(expr),
1409 data_type,
1410 })
1411}
1412
1413pub(crate) fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1419 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1420 Arc::new(crate::query::df_udfs::create_cypher_list_to_cv_udf()),
1421 vec![expr],
1422 ))
1423}
1424
1425pub(crate) fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1429 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1430 Arc::new(crate::query::df_udfs::create_cypher_scalar_to_cv_udf()),
1431 vec![expr],
1432 ))
1433}
1434
1435fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1437 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1438 Box::new(left),
1439 op,
1440 Box::new(right),
1441 ))
1442}
1443
1444pub(crate) fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1449 use datafusion::logical_expr::Operator;
1450 match op {
1451 Operator::Eq => Some("_cypher_equal"),
1452 Operator::NotEq => Some("_cypher_not_equal"),
1453 Operator::Lt => Some("_cypher_lt"),
1454 Operator::LtEq => Some("_cypher_lt_eq"),
1455 Operator::Gt => Some("_cypher_gt"),
1456 Operator::GtEq => Some("_cypher_gt_eq"),
1457 _ => None,
1458 }
1459}
1460
1461fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1463 use datafusion::logical_expr::Operator;
1464 match op {
1465 Operator::Plus => Some("_cypher_add"),
1466 Operator::Minus => Some("_cypher_sub"),
1467 Operator::Multiply => Some("_cypher_mul"),
1468 Operator::Divide => Some("_cypher_div"),
1469 Operator::Modulo => Some("_cypher_mod"),
1470 _ => None,
1471 }
1472}
1473
1474fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1479where
1480 F: FnOnce(DfExpr) -> DfExpr,
1481{
1482 require_arg(df_args, func_name)?;
1483 Ok(math_fn(cast_expr(
1484 first_arg(df_args),
1485 datafusion::arrow::datatypes::DataType::Float64,
1486 )))
1487}
1488
1489fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1491 if distinct {
1492 expr.distinct()
1493 .build()
1494 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1495 } else {
1496 Ok(expr)
1497 }
1498}
1499
1500fn translate_aggregate_function(
1502 name_upper: &str,
1503 df_args: &[DfExpr],
1504 distinct: bool,
1505) -> Option<Result<DfExpr>> {
1506 match name_upper {
1507 "COUNT" => {
1508 let expr = if df_args.is_empty() {
1509 datafusion::functions_aggregate::count::count(lit(1i64))
1510 } else {
1511 datafusion::functions_aggregate::count::count(first_arg(df_args))
1512 };
1513 Some(maybe_distinct(expr, distinct, "COUNT"))
1514 }
1515 "SUM" => {
1516 check_args!(1, df_args, "SUM");
1517 let udaf = Arc::new(crate::query::df_udfs::create_cypher_sum_udaf());
1518 Some(maybe_distinct(
1519 udaf.call(vec![first_arg(df_args)]),
1520 distinct,
1521 "SUM",
1522 ))
1523 }
1524 "AVG" => {
1525 check_args!(1, df_args, "AVG");
1526 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1527 let expr = datafusion::functions_aggregate::average::avg(coerced);
1528 Some(maybe_distinct(expr, distinct, "AVG"))
1529 }
1530 "MIN" => {
1531 check_args!(1, df_args, "MIN");
1532 let udaf = Arc::new(crate::query::df_udfs::create_cypher_min_udaf());
1533 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1534 }
1535 "MAX" => {
1536 check_args!(1, df_args, "MAX");
1537 let udaf = Arc::new(crate::query::df_udfs::create_cypher_max_udaf());
1538 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1539 }
1540 "PERCENTILEDISC" => {
1541 if df_args.len() != 2 {
1542 return Some(Err(anyhow!(
1543 "percentileDisc() requires exactly 2 arguments"
1544 )));
1545 }
1546 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1547 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_disc_udaf());
1548 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1549 }
1550 "PERCENTILECONT" => {
1551 if df_args.len() != 2 {
1552 return Some(Err(anyhow!(
1553 "percentileCont() requires exactly 2 arguments"
1554 )));
1555 }
1556 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1557 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_cont_udaf());
1558 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1559 }
1560 "COLLECT" => {
1561 check_args!(1, df_args, "COLLECT");
1562 Some(Ok(crate::query::df_udfs::create_cypher_collect_expr(
1563 first_arg(df_args),
1564 distinct,
1565 )))
1566 }
1567 _ => None,
1568 }
1569}
1570
1571fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1574 match name_upper {
1575 "TOSTRING" => {
1576 check_args!(1, df_args, "toString");
1577 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1578 }
1579 "TOINTEGER" | "TOINT" => {
1580 check_args!(1, df_args, "toInteger");
1581 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1582 }
1583 "TOFLOAT" => {
1584 check_args!(1, df_args, "toFloat");
1585 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1586 }
1587 "TOBOOLEAN" | "TOBOOL" => {
1588 check_args!(1, df_args, "toBoolean");
1589 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1590 }
1591 "UPPER" | "TOUPPER" => {
1592 check_args!(1, df_args, "upper");
1593 Some(Ok(datafusion::functions::string::expr_fn::upper(
1594 first_arg(df_args),
1595 )))
1596 }
1597 "LOWER" | "TOLOWER" => {
1598 check_args!(1, df_args, "lower");
1599 Some(Ok(datafusion::functions::string::expr_fn::lower(
1600 first_arg(df_args),
1601 )))
1602 }
1603 "SUBSTRING" => {
1604 check_args!(2, df_args, "substring");
1605 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1606 }
1607 "TRIM" => {
1608 check_args!(1, df_args, "TRIM");
1609 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1610 first_arg(df_args),
1611 ])))
1612 }
1613 "LTRIM" => {
1614 check_args!(1, df_args, "LTRIM");
1615 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1616 first_arg(df_args),
1617 ])))
1618 }
1619 "RTRIM" => {
1620 check_args!(1, df_args, "RTRIM");
1621 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1622 first_arg(df_args),
1623 ])))
1624 }
1625 "LEFT" => {
1626 check_args!(2, df_args, "left");
1627 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1628 df_args[0].clone(),
1629 df_args[1].clone(),
1630 )))
1631 }
1632 "RIGHT" => {
1633 check_args!(2, df_args, "right");
1634 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1635 df_args[0].clone(),
1636 df_args[1].clone(),
1637 )))
1638 }
1639 "REPLACE" => {
1640 check_args!(3, df_args, "replace");
1641 Some(Ok(datafusion::functions::string::expr_fn::replace(
1642 df_args[0].clone(),
1643 df_args[1].clone(),
1644 df_args[2].clone(),
1645 )))
1646 }
1647 "REVERSE" => {
1648 check_args!(1, df_args, "reverse");
1649 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1650 }
1651 "SPLIT" => {
1652 check_args!(2, df_args, "split");
1653 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1654 }
1655 "SIZE" | "LENGTH" => {
1656 check_args!(1, df_args, name_upper);
1657 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1658 }
1659 _ => None,
1660 }
1661}
1662
1663fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1666 use datafusion::functions::math::expr_fn;
1667
1668 let unary_f64 =
1670 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1671
1672 match name_upper {
1673 "ABS" => {
1674 check_args!(1, df_args, "abs");
1675 Some(Ok(crate::query::df_udfs::cypher_abs_expr(first_arg(
1679 df_args,
1680 ))))
1681 }
1682 "CEIL" | "CEILING" => {
1683 check_args!(1, df_args, "ceil");
1684 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1685 }
1686 "FLOOR" => {
1687 check_args!(1, df_args, "floor");
1688 Some(Ok(expr_fn::floor(first_arg(df_args))))
1689 }
1690 "ROUND" => {
1691 check_args!(1, df_args, "round");
1692 let args = if df_args.len() == 1 {
1693 vec![first_arg(df_args)]
1694 } else {
1695 vec![df_args[0].clone(), df_args[1].clone()]
1696 };
1697 Some(Ok(expr_fn::round(args)))
1698 }
1699 "SIGN" => {
1700 check_args!(1, df_args, "sign");
1701 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1702 Some(Ok(expr_fn::signum(coerced)))
1703 }
1704 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1705 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1706 "LOG10" => unary_f64("log10", expr_fn::log10),
1707 "EXP" => unary_f64("exp", expr_fn::exp),
1708 "SIN" => unary_f64("sin", expr_fn::sin),
1709 "COS" => unary_f64("cos", expr_fn::cos),
1710 "TAN" => unary_f64("tan", expr_fn::tan),
1711 "ASIN" => unary_f64("asin", expr_fn::asin),
1712 "ACOS" => unary_f64("acos", expr_fn::acos),
1713 "ATAN" => unary_f64("atan", expr_fn::atan),
1714 "ATAN2" => {
1715 check_args!(2, df_args, "atan2");
1716 let cast_f64 =
1717 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1718 Some(Ok(expr_fn::atan2(
1719 cast_f64(df_args[0].clone()),
1720 cast_f64(df_args[1].clone()),
1721 )))
1722 }
1723 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1724 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1725 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1726 _ => None,
1727 }
1728}
1729
1730fn translate_temporal_function(
1733 name_upper: &str,
1734 name: &str,
1735 df_args: &[DfExpr],
1736 context: Option<&TranslationContext>,
1737) -> Option<Result<DfExpr>> {
1738 match name_upper {
1739 "DATE"
1740 | "TIME"
1741 | "LOCALTIME"
1742 | "LOCALDATETIME"
1743 | "DATETIME"
1744 | "DURATION"
1745 | "YEAR"
1746 | "MONTH"
1747 | "DAY"
1748 | "HOUR"
1749 | "MINUTE"
1750 | "SECOND"
1751 | "DURATION.BETWEEN"
1752 | "DURATION.INMONTHS"
1753 | "DURATION.INDAYS"
1754 | "DURATION.INSECONDS"
1755 | "DATETIME.FROMEPOCH"
1756 | "DATETIME.FROMEPOCHMILLIS"
1757 | "DATE.TRUNCATE"
1758 | "TIME.TRUNCATE"
1759 | "DATETIME.TRUNCATE"
1760 | "LOCALDATETIME.TRUNCATE"
1761 | "LOCALTIME.TRUNCATE"
1762 | "DATETIME.TRANSACTION"
1763 | "DATETIME.STATEMENT"
1764 | "DATETIME.REALTIME"
1765 | "DATE.TRANSACTION"
1766 | "DATE.STATEMENT"
1767 | "DATE.REALTIME"
1768 | "TIME.TRANSACTION"
1769 | "TIME.STATEMENT"
1770 | "TIME.REALTIME"
1771 | "LOCALTIME.TRANSACTION"
1772 | "LOCALTIME.STATEMENT"
1773 | "LOCALTIME.REALTIME"
1774 | "LOCALDATETIME.TRANSACTION"
1775 | "LOCALDATETIME.STATEMENT"
1776 | "LOCALDATETIME.REALTIME" => {
1777 let stmt_time = context.map(|c| c.statement_time);
1781 if can_constant_fold(name_upper, df_args)
1782 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1783 {
1784 return Some(Ok(folded));
1785 }
1786 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1787 }
1788 _ => None,
1789 }
1790}
1791
1792fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1794 if name.contains("REALTIME") {
1796 return false;
1797 }
1798 if args.is_empty() {
1806 return matches!(
1807 name,
1808 "DATE"
1809 | "TIME"
1810 | "LOCALTIME"
1811 | "LOCALDATETIME"
1812 | "DATETIME"
1813 | "DATE.STATEMENT"
1814 | "TIME.STATEMENT"
1815 | "LOCALTIME.STATEMENT"
1816 | "LOCALDATETIME.STATEMENT"
1817 | "DATETIME.STATEMENT"
1818 | "DATE.TRANSACTION"
1819 | "TIME.TRANSACTION"
1820 | "LOCALTIME.TRANSACTION"
1821 | "LOCALDATETIME.TRANSACTION"
1822 | "DATETIME.TRANSACTION"
1823 );
1824 }
1825 args.iter().all(is_constant_expr)
1827}
1828
1829fn is_constant_expr(expr: &DfExpr) -> bool {
1831 match expr {
1832 DfExpr::Literal(_, _) => true,
1833 DfExpr::ScalarFunction(func) => {
1834 func.args.iter().all(is_constant_expr)
1836 }
1837 _ => false,
1838 }
1839}
1840
1841fn try_constant_fold_temporal(
1847 name: &str,
1848 args: &[DfExpr],
1849 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1850) -> Result<DfExpr> {
1851 let val_args: Vec<Value> = args
1853 .iter()
1854 .map(extract_constant_value)
1855 .collect::<Result<_>>()?;
1856
1857 let result = if val_args.is_empty() {
1859 if let Some(frozen) = stmt_time {
1860 crate::query::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1861 } else {
1862 crate::query::datetime::eval_datetime_function(name, &val_args)?
1863 }
1864 } else {
1865 crate::query::datetime::eval_datetime_function(name, &val_args)?
1866 };
1867
1868 let scalar = value_to_scalar(&result)?;
1870 Ok(DfExpr::Literal(scalar, None))
1871}
1872
1873fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1875 use crate::query::df_udfs::scalar_to_value;
1876 match expr {
1877 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1878 DfExpr::ScalarFunction(func) => {
1879 let mut map = std::collections::HashMap::new();
1882 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1883 for chunk in pairs.chunks(2) {
1884 if let [key_expr, val_expr] = chunk {
1885 let key = match key_expr {
1887 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1888 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1889 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1890 };
1891 let val = extract_constant_value(val_expr)?;
1892 map.insert(key, val);
1893 } else {
1894 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1895 }
1896 }
1897 Ok(Value::Map(map))
1898 }
1899 _ => Err(anyhow::anyhow!(
1900 "Cannot extract constant value from expression"
1901 )),
1902 }
1903}
1904
1905fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1908 match name_upper {
1909 "HEAD" => {
1910 check_args!(1, df_args, "head");
1911 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1912 }
1913 "LAST" => {
1914 check_args!(1, df_args, "last");
1915 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1916 }
1917 "TAIL" => {
1918 check_args!(1, df_args, "tail");
1919 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1920 }
1921 "RANGE" => {
1922 check_args!(2, df_args, "range");
1923 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1924 }
1925 _ => None,
1926 }
1927}
1928
1929fn translate_graph_function(
1932 name_upper: &str,
1933 name: &str,
1934 df_args: &[DfExpr],
1935 args: &[Expr],
1936 context: Option<&TranslationContext>,
1937) -> Option<Result<DfExpr>> {
1938 match name_upper {
1939 "ID" => {
1940 if let Some(Expr::Variable(var)) = args.first() {
1943 let is_edge = context.is_some_and(|ctx| {
1944 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1945 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1946 });
1947 let id_suffix = if is_edge { COL_EID } else { COL_VID };
1948 Some(Ok(DfExpr::Column(Column::from_name(format!(
1949 "{}.{}",
1950 var, id_suffix
1951 )))))
1952 } else {
1953 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
1954 }
1955 }
1956 "LABELS" | "KEYS" => {
1957 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1962 }
1963 "TYPE" => {
1964 if let Some(Expr::Variable(var)) = args.first()
1968 && let Some(ctx) = context
1969 && let Some(label) = ctx.variable_labels.get(var)
1970 {
1971 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
1974 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
1975 expr: None,
1976 when_then_expr: vec![(
1977 Box::new(eid_col.is_not_null()),
1978 Box::new(lit(label.clone())),
1979 )],
1980 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
1981 })));
1982 }
1983 if let Some(Expr::Variable(var)) = args.first()
1987 && context
1988 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
1989 {
1990 return Some(Ok(DfExpr::Column(Column::from_name(format!(
1991 "{}.{}",
1992 var, COL_TYPE
1993 )))));
1994 }
1995 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
1996 }
1997 "PROPERTIES" => {
1998 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2001 }
2002 "UNI.TEMPORAL.VALIDAT" => {
2003 if let (
2006 Some(Expr::Variable(var)),
2007 Some(Expr::Literal(CypherLiteral::String(start_prop))),
2008 Some(Expr::Literal(CypherLiteral::String(end_prop))),
2009 Some(ts_expr),
2010 ) = (args.first(), args.get(1), args.get(2), args.get(3))
2011 {
2012 let start_col =
2013 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
2014 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
2015 let ts = match cypher_expr_to_df(ts_expr, context) {
2016 Ok(ts) => ts,
2017 Err(e) => return Some(Err(e)),
2018 };
2019
2020 let start_check = start_col.lt_eq(ts.clone());
2022 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
2024 let end_after = end_col.gt(ts);
2025 let end_check = end_null.or(end_after);
2026
2027 Some(Ok(start_check.and(end_check)))
2028 } else {
2029 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2031 }
2032 }
2033 "STARTNODE" | "ENDNODE" => {
2034 let mut udf_args = df_args.to_vec();
2037 let mut seen = std::collections::HashSet::new();
2038 if let Some(ctx) = context {
2039 for (var, kind) in &ctx.variable_kinds {
2041 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2042 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2043 }
2044 }
2045 for var in &ctx.node_variable_hints {
2048 if seen.insert(var.clone()) {
2049 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2050 }
2051 }
2052 }
2053 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2054 }
2055 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2056 "HASLABEL" => {
2057 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2058 return Some(Err(e));
2059 }
2060 if let Some(Expr::Variable(var)) = args.first() {
2062 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2063 let labels_col =
2065 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2066 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2067 labels_col,
2068 lit(label.clone()),
2069 )))
2070 } else {
2071 Some(Err(anyhow::anyhow!(
2073 "hasLabel requires string literal as second argument for DataFusion translation"
2074 )))
2075 }
2076 } else {
2077 Some(Err(anyhow::anyhow!(
2079 "hasLabel requires variable as first argument for DataFusion translation"
2080 )))
2081 }
2082 }
2083 _ => None,
2084 }
2085}
2086
2087fn translate_function_call(
2089 name: &str,
2090 args: &[Expr],
2091 distinct: bool,
2092 context: Option<&TranslationContext>,
2093) -> Result<DfExpr> {
2094 let df_args: Vec<DfExpr> = args
2095 .iter()
2096 .map(|arg| cypher_expr_to_df(arg, context))
2097 .collect::<Result<Vec<_>>>()?;
2098
2099 let name_upper = name.to_uppercase();
2100
2101 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2105 return result;
2106 }
2107
2108 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2109 return result;
2110 }
2111
2112 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2113 return result;
2114 }
2115
2116 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2117 return result;
2118 }
2119
2120 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2121 return result;
2122 }
2123
2124 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2125 return result;
2126 }
2127
2128 match name_upper.as_str() {
2130 "COALESCE" => {
2131 require_arg(&df_args, "coalesce")?;
2132 if df_args.len() == 1 {
2137 return Ok(df_args.into_iter().next().unwrap());
2138 }
2139 let n = df_args.len();
2140 let (init, last) = df_args.split_at(n - 1);
2141 let mut builder = datafusion::logical_expr::conditional_expressions::CaseBuilder::new(
2142 None,
2143 vec![],
2144 vec![],
2145 None,
2146 );
2147 for arg in init {
2148 builder.when(arg.clone().is_not_null(), arg.clone());
2149 }
2150 return Ok(builder.otherwise(last[0].clone())?);
2151 }
2152 "NULLIF" => {
2153 require_args(&df_args, 2, "nullif")?;
2154 return Ok(datafusion::functions::expr_fn::nullif(
2155 df_args[0].clone(),
2156 df_args[1].clone(),
2157 ));
2158 }
2159 _ => {}
2160 }
2161
2162 match name_upper.as_str() {
2164 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
2165 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2166 }
2167 _ => {}
2168 }
2169
2170 Ok(dummy_udf_expr(name, df_args))
2172}
2173
2174#[derive(Debug)]
2179struct DummyUdf {
2180 name: String,
2181 signature: datafusion::logical_expr::Signature,
2182 ret_type: datafusion::arrow::datatypes::DataType,
2183}
2184
2185impl DummyUdf {
2186 fn new(name: String) -> Self {
2187 let ret_type = dummy_udf_return_type(&name);
2188 Self {
2189 name,
2190 signature: datafusion::logical_expr::Signature::variadic_any(
2191 datafusion::logical_expr::Volatility::Immutable,
2192 ),
2193 ret_type,
2194 }
2195 }
2196}
2197
2198fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2211 use datafusion::arrow::datatypes::DataType;
2212 match name {
2213 "_cypher_add"
2217 | "_cypher_sub"
2218 | "_cypher_mul"
2219 | "_cypher_div"
2220 | "_cypher_mod"
2221 | "_cypher_list_concat"
2222 | "_cypher_list_append"
2223 | "_make_cypher_list"
2224 | "_map_project"
2225 | "_cypher_list_to_cv"
2226 | "_cypher_tail" => DataType::LargeBinary,
2227 _ => DataType::Null,
2231 }
2232}
2233
2234impl PartialEq for DummyUdf {
2235 fn eq(&self, other: &Self) -> bool {
2236 self.name == other.name
2237 }
2238}
2239
2240impl Eq for DummyUdf {}
2241
2242impl Hash for DummyUdf {
2243 fn hash<H: Hasher>(&self, state: &mut H) {
2244 self.name.hash(state);
2245 }
2246}
2247
2248pub(crate) fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2250 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2251 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2252 DummyUdf::new(name.to_lowercase()),
2253 )),
2254 args,
2255 })
2256}
2257
2258impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2259 fn as_any(&self) -> &dyn std::any::Any {
2260 self
2261 }
2262
2263 fn name(&self) -> &str {
2264 &self.name
2265 }
2266
2267 fn signature(&self) -> &datafusion::logical_expr::Signature {
2268 &self.signature
2269 }
2270
2271 fn return_type(
2272 &self,
2273 _arg_types: &[datafusion::arrow::datatypes::DataType],
2274 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2275 Ok(self.ret_type.clone())
2278 }
2279
2280 fn invoke_with_args(
2281 &self,
2282 _args: ScalarFunctionArgs,
2283 ) -> datafusion::error::Result<ColumnarValue> {
2284 Err(datafusion::error::DataFusionError::Plan(format!(
2285 "UDF '{}' is not registered. Register it via SessionContext.",
2286 self.name
2287 )))
2288 }
2289}
2290
2291pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2295 let mut properties = Vec::new();
2296 collect_properties_recursive(expr, &mut properties);
2297 properties.sort();
2298 properties.dedup();
2299 properties
2300}
2301
2302fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
2303 match expr {
2304 Expr::PatternComprehension { .. } => {}
2305 Expr::Property(base, prop) => {
2306 if let Ok(var_name) = extract_variable_name(base) {
2307 properties.push((var_name, prop.clone()));
2308 }
2309 collect_properties_recursive(base, properties);
2310 }
2311 Expr::ArrayIndex { array, index } => {
2312 if let Ok(var_name) = extract_variable_name(array)
2313 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
2314 {
2315 properties.push((var_name, prop_name.clone()));
2316 }
2317 collect_properties_recursive(array, properties);
2318 collect_properties_recursive(index, properties);
2319 }
2320 Expr::ArraySlice { array, start, end } => {
2321 collect_properties_recursive(array, properties);
2322 if let Some(s) = start {
2323 collect_properties_recursive(s, properties);
2324 }
2325 if let Some(e) = end {
2326 collect_properties_recursive(e, properties);
2327 }
2328 }
2329 Expr::List(items) => {
2330 for item in items {
2331 collect_properties_recursive(item, properties);
2332 }
2333 }
2334 Expr::Map(entries) => {
2335 for (_, value) in entries {
2336 collect_properties_recursive(value, properties);
2337 }
2338 }
2339 Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
2340 collect_properties_recursive(inner, properties);
2341 }
2342 Expr::FunctionCall { args, .. } => {
2343 for arg in args {
2344 collect_properties_recursive(arg, properties);
2345 }
2346 }
2347 Expr::BinaryOp { left, right, .. } => {
2348 collect_properties_recursive(left, properties);
2349 collect_properties_recursive(right, properties);
2350 }
2351 Expr::UnaryOp { expr, .. } => {
2352 collect_properties_recursive(expr, properties);
2353 }
2354 Expr::Case {
2355 expr,
2356 when_then,
2357 else_expr,
2358 } => {
2359 if let Some(e) = expr {
2360 collect_properties_recursive(e, properties);
2361 }
2362 for (when_e, then_e) in when_then {
2363 collect_properties_recursive(when_e, properties);
2364 collect_properties_recursive(then_e, properties);
2365 }
2366 if let Some(e) = else_expr {
2367 collect_properties_recursive(e, properties);
2368 }
2369 }
2370 Expr::Reduce {
2371 init, list, expr, ..
2372 } => {
2373 collect_properties_recursive(init, properties);
2374 collect_properties_recursive(list, properties);
2375 collect_properties_recursive(expr, properties);
2376 }
2377 Expr::Quantifier {
2378 list, predicate, ..
2379 } => {
2380 collect_properties_recursive(list, properties);
2381 collect_properties_recursive(predicate, properties);
2382 }
2383 Expr::ListComprehension {
2384 list,
2385 where_clause,
2386 map_expr,
2387 ..
2388 } => {
2389 collect_properties_recursive(list, properties);
2390 if let Some(filter) = where_clause {
2391 collect_properties_recursive(filter, properties);
2392 }
2393 collect_properties_recursive(map_expr, properties);
2394 }
2395 Expr::In { expr, list } => {
2396 collect_properties_recursive(expr, properties);
2397 collect_properties_recursive(list, properties);
2398 }
2399 Expr::ValidAt {
2400 entity, timestamp, ..
2401 } => {
2402 collect_properties_recursive(entity, properties);
2403 collect_properties_recursive(timestamp, properties);
2404 }
2405 Expr::MapProjection { base, items } => {
2406 collect_properties_recursive(base, properties);
2407 for item in items {
2408 match item {
2409 uni_cypher::ast::MapProjectionItem::Property(prop) => {
2410 if let Ok(var_name) = extract_variable_name(base) {
2411 properties.push((var_name, prop.clone()));
2412 }
2413 }
2414 uni_cypher::ast::MapProjectionItem::AllProperties => {
2415 if let Ok(var_name) = extract_variable_name(base) {
2416 properties.push((var_name, "*".to_string()));
2417 }
2418 }
2419 uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
2420 collect_properties_recursive(expr, properties);
2421 }
2422 uni_cypher::ast::MapProjectionItem::Variable(_) => {}
2423 }
2424 }
2425 }
2426 Expr::LabelCheck { expr, .. } => {
2427 collect_properties_recursive(expr, properties);
2428 }
2429 Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
2431 Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
2432 }
2433}
2434
2435pub fn wider_numeric_type(
2442 a: &datafusion::arrow::datatypes::DataType,
2443 b: &datafusion::arrow::datatypes::DataType,
2444) -> datafusion::arrow::datatypes::DataType {
2445 use datafusion::arrow::datatypes::DataType;
2446
2447 fn numeric_rank(dt: &DataType) -> u8 {
2448 match dt {
2449 DataType::Int8 | DataType::UInt8 => 1,
2450 DataType::Int16 | DataType::UInt16 => 2,
2451 DataType::Int32 | DataType::UInt32 => 3,
2452 DataType::Int64 | DataType::UInt64 => 4,
2453 DataType::Float16 => 5,
2454 DataType::Float32 => 6,
2455 DataType::Float64 => 7,
2456 _ => 0,
2457 }
2458 }
2459
2460 if numeric_rank(a) >= numeric_rank(b) {
2461 a.clone()
2462 } else {
2463 b.clone()
2464 }
2465}
2466
2467fn resolve_column_type_fallback(
2473 expr: &DfExpr,
2474 schema: &datafusion::common::DFSchema,
2475) -> Option<datafusion::arrow::datatypes::DataType> {
2476 if let DfExpr::Column(col) = expr {
2477 let col_name = &col.name;
2478 for (_, field) in schema.iter() {
2480 if field.name() == col_name {
2481 return Some(field.data_type().clone());
2482 }
2483 }
2484 }
2485 None
2486}
2487
2488fn contains_division(expr: &DfExpr) -> bool {
2491 match expr {
2492 DfExpr::BinaryExpr(b) => {
2493 b.op == datafusion::logical_expr::Operator::Divide
2494 || contains_division(&b.left)
2495 || contains_division(&b.right)
2496 }
2497 DfExpr::Cast(c) => contains_division(&c.expr),
2498 DfExpr::TryCast(c) => contains_division(&c.expr),
2499 _ => false,
2500 }
2501}
2502
2503pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2509 use datafusion::arrow::datatypes::DataType;
2510 use datafusion::logical_expr::ExprSchemable;
2511
2512 match expr {
2513 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2514 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2515 DfExpr::Case(case) => coerce_case_expr(case, schema),
2516 DfExpr::InList(in_list) => {
2517 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2518 let coerced_list = in_list
2519 .list
2520 .iter()
2521 .map(|e| apply_type_coercion(e, schema))
2522 .collect::<Result<Vec<_>>>()?;
2523 let expr_type = coerced_expr
2524 .get_type(schema)
2525 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2526 crate::query::cypher_type_coerce::build_cypher_in_list(
2527 coerced_expr,
2528 &expr_type,
2529 coerced_list,
2530 in_list.negated,
2531 schema,
2532 )
2533 }
2534 DfExpr::Not(inner) => {
2535 let coerced_inner = apply_type_coercion(inner, schema)?;
2536 let inner_type = coerced_inner.get_type(schema).ok();
2537 let final_inner = if inner_type
2538 .as_ref()
2539 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2540 {
2541 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2542 } else if inner_type
2543 .as_ref()
2544 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2545 {
2546 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2547 } else {
2548 coerced_inner
2549 };
2550 Ok(DfExpr::Not(Box::new(final_inner)))
2551 }
2552 DfExpr::IsNull(inner) => {
2553 let coerced_inner = apply_type_coercion(inner, schema)?;
2554 Ok(coerced_inner.is_null())
2555 }
2556 DfExpr::IsNotNull(inner) => {
2557 let coerced_inner = apply_type_coercion(inner, schema)?;
2558 Ok(coerced_inner.is_not_null())
2559 }
2560 DfExpr::Negative(inner) => {
2561 let coerced_inner = apply_type_coercion(inner, schema)?;
2562 let inner_type = coerced_inner.get_type(schema).ok();
2563 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2564 Ok(dummy_udf_expr(
2565 "_cypher_mul",
2566 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2567 ))
2568 } else {
2569 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2570 }
2571 }
2572 DfExpr::Cast(cast) => {
2573 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2574 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2575 Box::new(coerced_inner),
2576 cast.data_type.clone(),
2577 )))
2578 }
2579 DfExpr::TryCast(cast) => {
2580 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2581 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2582 Box::new(coerced_inner),
2583 cast.data_type.clone(),
2584 )))
2585 }
2586 DfExpr::Alias(alias) => {
2587 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2588 Ok(coerced_inner.alias(alias.name.clone()))
2589 }
2590 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2591 _ => Ok(expr.clone()),
2592 }
2593}
2594
2595fn coerce_logical_operands(
2597 left: DfExpr,
2598 right: DfExpr,
2599 op: datafusion::logical_expr::Operator,
2600 schema: &datafusion::common::DFSchema,
2601) -> Option<DfExpr> {
2602 use datafusion::arrow::datatypes::DataType;
2603 use datafusion::logical_expr::ExprSchemable;
2604
2605 if !matches!(
2606 op,
2607 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2608 ) {
2609 return None;
2610 }
2611 let left_type = left.get_type(schema).ok();
2612 let right_type = right.get_type(schema).ok();
2613 let left_needs_cast = left_type
2614 .as_ref()
2615 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2616 let right_needs_cast = right_type
2617 .as_ref()
2618 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2619 let left_is_lb = left_type
2620 .as_ref()
2621 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2622 let right_is_lb = right_type
2623 .as_ref()
2624 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2625 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2626 return None;
2627 }
2628 let coerced_left = if left_is_lb {
2629 dummy_udf_expr("_cv_to_bool", vec![left])
2630 } else if left_needs_cast {
2631 datafusion::logical_expr::cast(left, DataType::Boolean)
2632 } else {
2633 left
2634 };
2635 let coerced_right = if right_is_lb {
2636 dummy_udf_expr("_cv_to_bool", vec![right])
2637 } else if right_needs_cast {
2638 datafusion::logical_expr::cast(right, DataType::Boolean)
2639 } else {
2640 right
2641 };
2642 Some(binary_expr(coerced_left, op, coerced_right))
2643}
2644
2645#[expect(
2648 clippy::too_many_arguments,
2649 reason = "Binary coercion needs all context"
2650)]
2651fn coerce_large_binary_ops(
2652 left: &DfExpr,
2653 right: &DfExpr,
2654 left_type: &datafusion::arrow::datatypes::DataType,
2655 right_type: &datafusion::arrow::datatypes::DataType,
2656 left_is_null: bool,
2657 op: datafusion::logical_expr::Operator,
2658 is_comparison: bool,
2659 is_arithmetic: bool,
2660) -> Option<Result<DfExpr>> {
2661 use datafusion::arrow::datatypes::DataType;
2662 use datafusion::logical_expr::Operator;
2663
2664 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2665 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2666
2667 if op == Operator::Plus {
2668 if left_is_lb && right_is_lb {
2669 return Some(Ok(dummy_udf_expr(
2670 "_cypher_add",
2671 vec![left.clone(), right.clone()],
2672 )));
2673 }
2674 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2675 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2676 if left_is_native_list && right_is_native_list {
2677 return Some(Ok(dummy_udf_expr(
2678 "_cypher_list_concat",
2679 vec![left.clone(), right.clone()],
2680 )));
2681 }
2682 if left_is_native_list || right_is_native_list {
2683 return Some(Ok(dummy_udf_expr(
2684 "_cypher_list_append",
2685 vec![left.clone(), right.clone()],
2686 )));
2687 }
2688 }
2689
2690 if (left_is_lb || right_is_lb) && is_comparison {
2691 if let Some(udf_name) = comparison_udf_name(op) {
2692 return Some(Ok(dummy_udf_expr(
2693 udf_name,
2694 vec![left.clone(), right.clone()],
2695 )));
2696 }
2697 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2698 }
2699
2700 if (left_is_lb || right_is_lb) && is_arithmetic {
2701 let udf_name =
2702 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2703 return Some(Ok(dummy_udf_expr(
2704 udf_name,
2705 vec![left.clone(), right.clone()],
2706 )));
2707 }
2708
2709 None
2710}
2711
2712fn coerce_temporal_comparisons(
2714 left: DfExpr,
2715 right: DfExpr,
2716 left_type: &datafusion::arrow::datatypes::DataType,
2717 right_type: &datafusion::arrow::datatypes::DataType,
2718 op: datafusion::logical_expr::Operator,
2719 is_comparison: bool,
2720) -> Option<DfExpr> {
2721 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2722 use datafusion::logical_expr::Operator;
2723
2724 if !is_comparison {
2725 return None;
2726 }
2727
2728 if uni_common::core::schema::is_datetime_struct(left_type)
2730 && uni_common::core::schema::is_datetime_struct(right_type)
2731 {
2732 return Some(binary_expr(
2733 extract_datetime_nanos(left),
2734 op,
2735 extract_datetime_nanos(right),
2736 ));
2737 }
2738
2739 if uni_common::core::schema::is_time_struct(left_type)
2741 && uni_common::core::schema::is_time_struct(right_type)
2742 {
2743 return Some(binary_expr(
2744 extract_time_nanos(left),
2745 op,
2746 extract_time_nanos(right),
2747 ));
2748 }
2749
2750 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2752 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2753
2754 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2755 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2756 {
2757 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2758 extract_datetime_nanos(left)
2759 } else {
2760 left
2761 };
2762 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2763 extract_datetime_nanos(right)
2764 } else {
2765 right
2766 };
2767 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2768 return Some(binary_expr(
2769 cast_expr(left_nanos, ts_type.clone()),
2770 op,
2771 cast_expr(right_nanos, ts_type),
2772 ));
2773 }
2774
2775 let left_is_duration = matches!(left_type, DataType::Interval(_));
2779 let right_is_duration = matches!(right_type, DataType::Interval(_));
2780 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2781 || uni_common::core::schema::is_time_struct(left_type)
2782 || matches!(
2783 left_type,
2784 DataType::Timestamp(_, _)
2785 | DataType::Date32
2786 | DataType::Date64
2787 | DataType::Time32(_)
2788 | DataType::Time64(_)
2789 );
2790 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2791 || uni_common::core::schema::is_time_struct(right_type)
2792 || matches!(
2793 right_type,
2794 DataType::Timestamp(_, _)
2795 | DataType::Date32
2796 | DataType::Date64
2797 | DataType::Time32(_)
2798 | DataType::Time64(_)
2799 );
2800
2801 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2802 {
2803 return Some(match op {
2804 Operator::Eq => lit(false),
2805 Operator::NotEq => lit(true),
2806 _ => lit(ScalarValue::Boolean(None)),
2807 });
2808 }
2809
2810 None
2811}
2812
2813fn coerce_mismatched_types(
2816 left: DfExpr,
2817 right: DfExpr,
2818 left_type: &datafusion::arrow::datatypes::DataType,
2819 right_type: &datafusion::arrow::datatypes::DataType,
2820 op: datafusion::logical_expr::Operator,
2821 is_comparison: bool,
2822) -> Option<Result<DfExpr>> {
2823 use datafusion::arrow::datatypes::DataType;
2824 use datafusion::logical_expr::Operator;
2825
2826 if left_type == right_type {
2827 return None;
2828 }
2829
2830 if left_type.is_numeric() && right_type.is_numeric() {
2832 if left_type == &DataType::Int64
2833 && right_type == &DataType::UInt64
2834 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2835 {
2836 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2837 return Some(Ok(binary_expr(coerced_left, op, right)));
2838 }
2839 if left_type == &DataType::UInt64
2840 && right_type == &DataType::Int64
2841 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2842 {
2843 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2844 return Some(Ok(binary_expr(left, op, coerced_right)));
2845 }
2846 let target = wider_numeric_type(left_type, right_type);
2847 let coerced_left = if *left_type != target {
2848 datafusion::logical_expr::cast(left, target.clone())
2849 } else {
2850 left
2851 };
2852 let coerced_right = if *right_type != target {
2853 datafusion::logical_expr::cast(right, target)
2854 } else {
2855 right
2856 };
2857 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2858 }
2859
2860 if is_comparison {
2862 match (left_type, right_type) {
2863 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2864 let right = normalize_datetime_literal(right);
2865 return Some(Ok(binary_expr(
2866 left,
2867 op,
2868 datafusion::logical_expr::cast(right, ts.clone()),
2869 )));
2870 }
2871 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2872 let left = normalize_datetime_literal(left);
2873 return Some(Ok(binary_expr(
2874 datafusion::logical_expr::cast(left, ts.clone()),
2875 op,
2876 right,
2877 )));
2878 }
2879 _ => {}
2880 }
2881 }
2882
2883 if is_comparison
2885 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2886 {
2887 let l_inner = l_field.data_type();
2888 let r_inner = r_field.data_type();
2889 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2890 let target_inner = wider_numeric_type(l_inner, r_inner);
2891 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2892 "item",
2893 target_inner,
2894 true,
2895 )));
2896 return Some(Ok(binary_expr(
2897 datafusion::logical_expr::cast(left, target_type.clone()),
2898 op,
2899 datafusion::logical_expr::cast(right, target_type),
2900 )));
2901 }
2902 }
2903
2904 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2906 if op == Operator::Plus {
2907 return Some(crate::query::cypher_type_coerce::build_cypher_plus(
2908 left, left_type, right, right_type,
2909 ));
2910 }
2911 if is_comparison {
2912 return Some(Ok(
2913 crate::query::cypher_type_coerce::build_cypher_comparison(
2914 left, left_type, right, right_type, op,
2915 ),
2916 ));
2917 }
2918 }
2919
2920 None
2921}
2922
2923fn coerce_list_comparisons(
2925 left: DfExpr,
2926 right: DfExpr,
2927 left_type: &datafusion::arrow::datatypes::DataType,
2928 right_type: &datafusion::arrow::datatypes::DataType,
2929 op: datafusion::logical_expr::Operator,
2930 is_comparison: bool,
2931) -> Option<DfExpr> {
2932 use datafusion::arrow::datatypes::DataType;
2933 use datafusion::logical_expr::Operator;
2934
2935 if !is_comparison {
2936 return None;
2937 }
2938
2939 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2940 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2941
2942 if left_is_list
2944 && right_is_list
2945 && matches!(
2946 op,
2947 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
2948 )
2949 {
2950 let op_str = match op {
2951 Operator::Lt => "lt",
2952 Operator::LtEq => "lteq",
2953 Operator::Gt => "gt",
2954 Operator::GtEq => "gteq",
2955 _ => unreachable!(),
2956 };
2957 return Some(dummy_udf_expr(
2958 "_cypher_list_compare",
2959 vec![left, right, lit(op_str)],
2960 ));
2961 }
2962
2963 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
2965 let udf_name =
2966 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
2967 return Some(dummy_udf_expr(udf_name, vec![left, right]));
2968 }
2969
2970 if (left_is_list != right_is_list)
2972 && !matches!(left_type, DataType::Null)
2973 && !matches!(right_type, DataType::Null)
2974 {
2975 return Some(match op {
2976 Operator::Eq => lit(false),
2977 Operator::NotEq => lit(true),
2978 _ => lit(ScalarValue::Boolean(None)),
2979 });
2980 }
2981
2982 None
2983}
2984
2985fn coerce_binary_expr(
2987 binary: &datafusion::logical_expr::expr::BinaryExpr,
2988 schema: &datafusion::common::DFSchema,
2989) -> Result<DfExpr> {
2990 use datafusion::arrow::datatypes::DataType;
2991 use datafusion::logical_expr::ExprSchemable;
2992 use datafusion::logical_expr::Operator;
2993
2994 let left = apply_type_coercion(&binary.left, schema)?;
2995 let right = apply_type_coercion(&binary.right, schema)?;
2996
2997 let is_comparison = matches!(
2998 binary.op,
2999 Operator::Eq
3000 | Operator::NotEq
3001 | Operator::Lt
3002 | Operator::LtEq
3003 | Operator::Gt
3004 | Operator::GtEq
3005 );
3006 let is_arithmetic = matches!(
3007 binary.op,
3008 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
3009 );
3010
3011 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
3013 return Ok(result);
3014 }
3015
3016 if is_comparison || is_arithmetic {
3017 let left_type = match left.get_type(schema) {
3018 Ok(t) => t,
3019 Err(e) => {
3020 if let Some(t) = resolve_column_type_fallback(&left, schema) {
3021 t
3022 } else {
3023 log::warn!("Failed to get left type in binary expr: {}", e);
3024 return Ok(binary_expr(left, binary.op, right));
3025 }
3026 }
3027 };
3028 let right_type = match right.get_type(schema) {
3029 Ok(t) => t,
3030 Err(e) => {
3031 if let Some(t) = resolve_column_type_fallback(&right, schema) {
3032 t
3033 } else {
3034 log::warn!("Failed to get right type in binary expr: {}", e);
3035 return Ok(binary_expr(left, binary.op, right));
3036 }
3037 }
3038 };
3039
3040 let left_is_null = left_type.is_null();
3042 let right_is_null = right_type.is_null();
3043 if left_is_null && right_is_null {
3044 return Ok(lit(ScalarValue::Boolean(None)));
3045 }
3046 if left_is_null || right_is_null {
3047 let target = if left_is_null {
3048 &right_type
3049 } else {
3050 &left_type
3051 };
3052 if !matches!(target, DataType::LargeBinary) {
3053 let coerced_left = if left_is_null {
3054 datafusion::logical_expr::cast(left, target.clone())
3055 } else {
3056 left
3057 };
3058 let coerced_right = if right_is_null {
3059 datafusion::logical_expr::cast(right, target.clone())
3060 } else {
3061 right
3062 };
3063 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3064 }
3065 }
3066
3067 if let Some(result) = coerce_large_binary_ops(
3069 &left,
3070 &right,
3071 &left_type,
3072 &right_type,
3073 left_is_null,
3074 binary.op,
3075 is_comparison,
3076 is_arithmetic,
3077 ) {
3078 return result;
3079 }
3080
3081 if let Some(result) = coerce_temporal_comparisons(
3083 left.clone(),
3084 right.clone(),
3085 &left_type,
3086 &right_type,
3087 binary.op,
3088 is_comparison,
3089 ) {
3090 return Ok(result);
3091 }
3092
3093 let either_struct =
3095 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3096 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3097 || matches!(left_type, DataType::Struct(_)))
3098 && (matches!(right_type, DataType::LargeBinary)
3099 || matches!(right_type, DataType::Struct(_)));
3100 if is_comparison && either_struct && either_lb_or_struct {
3101 if let Some(udf_name) = comparison_udf_name(binary.op) {
3102 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3103 }
3104 return Ok(lit(ScalarValue::Boolean(None)));
3105 }
3106
3107 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3109 let udf_name = comparison_udf_name(binary.op)
3110 .expect("is_comparison guarantees a valid comparison operator");
3111 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3112 }
3113
3114 if binary.op == Operator::Plus
3116 && (crate::query::cypher_type_coerce::is_string_type(&left_type)
3117 || crate::query::cypher_type_coerce::is_string_type(&right_type))
3118 && is_primitive_type(&left_type)
3119 && is_primitive_type(&right_type)
3120 {
3121 return crate::query::cypher_type_coerce::build_cypher_plus(
3122 left,
3123 &left_type,
3124 right,
3125 &right_type,
3126 );
3127 }
3128
3129 if let Some(result) = coerce_mismatched_types(
3131 left.clone(),
3132 right.clone(),
3133 &left_type,
3134 &right_type,
3135 binary.op,
3136 is_comparison,
3137 ) {
3138 return result;
3139 }
3140
3141 if let Some(result) = coerce_list_comparisons(
3143 left.clone(),
3144 right.clone(),
3145 &left_type,
3146 &right_type,
3147 binary.op,
3148 is_comparison,
3149 ) {
3150 return Ok(result);
3151 }
3152 }
3153
3154 Ok(binary_expr(left, binary.op, right))
3155}
3156
3157fn coerce_scalar_function(
3159 func: &datafusion::logical_expr::expr::ScalarFunction,
3160 schema: &datafusion::common::DFSchema,
3161) -> Result<DfExpr> {
3162 use datafusion::arrow::datatypes::DataType;
3163 use datafusion::logical_expr::ExprSchemable;
3164
3165 let coerced_args: Vec<DfExpr> = func
3166 .args
3167 .iter()
3168 .map(|a| apply_type_coercion(a, schema))
3169 .collect::<Result<Vec<_>>>()?;
3170
3171 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3172 let types: Vec<_> = coerced_args
3173 .iter()
3174 .filter_map(|a| a.get_type(schema).ok())
3175 .collect();
3176 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3177 if has_mixed_types {
3178 let all_string_like = types
3182 .iter()
3183 .all(|t| matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Null));
3184 let unified_args: Vec<DfExpr> = if all_string_like {
3185 coerced_args
3186 .into_iter()
3187 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3188 .collect()
3189 } else {
3190 coerced_args
3192 .into_iter()
3193 .zip(types.iter())
3194 .map(|(arg, t)| match t {
3195 DataType::LargeBinary | DataType::Null => arg,
3196 DataType::List(_) | DataType::LargeList(_) => {
3197 list_to_large_binary_expr(arg)
3198 }
3199 _ => scalar_to_large_binary_expr(arg),
3200 })
3201 .collect()
3202 };
3203 return Ok(DfExpr::ScalarFunction(
3204 datafusion::logical_expr::expr::ScalarFunction {
3205 func: func.func.clone(),
3206 args: unified_args,
3207 },
3208 ));
3209 }
3210 }
3211
3212 Ok(DfExpr::ScalarFunction(
3213 datafusion::logical_expr::expr::ScalarFunction {
3214 func: func.func.clone(),
3215 args: coerced_args,
3216 },
3217 ))
3218}
3219
3220fn coerce_case_expr(
3223 case: &datafusion::logical_expr::expr::Case,
3224 schema: &datafusion::common::DFSchema,
3225) -> Result<DfExpr> {
3226 use datafusion::arrow::datatypes::DataType;
3227 use datafusion::logical_expr::ExprSchemable;
3228
3229 let coerced_operand = case
3230 .expr
3231 .as_ref()
3232 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3233 .transpose()?;
3234 let coerced_when_then = case
3235 .when_then_expr
3236 .iter()
3237 .map(|(w, t)| {
3238 let cw = apply_type_coercion(w, schema)?;
3239 let cw = match cw.get_type(schema).ok() {
3240 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3241 _ => cw,
3242 };
3243 let ct = apply_type_coercion(t, schema)?;
3244 Ok((Box::new(cw), Box::new(ct)))
3245 })
3246 .collect::<Result<Vec<_>>>()?;
3247 let coerced_else = case
3248 .else_expr
3249 .as_ref()
3250 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3251 .transpose()?;
3252
3253 let mut result_case = if let Some(operand) = coerced_operand {
3254 crate::query::cypher_type_coerce::rewrite_simple_case_to_generic(
3255 *operand,
3256 coerced_when_then,
3257 coerced_else,
3258 schema,
3259 )?
3260 } else {
3261 datafusion::logical_expr::expr::Case {
3262 expr: None,
3263 when_then_expr: coerced_when_then,
3264 else_expr: coerced_else,
3265 }
3266 };
3267
3268 crate::query::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3269
3270 Ok(DfExpr::Case(result_case))
3271}
3272
3273fn coerce_aggregate_function(
3275 agg: &datafusion::logical_expr::expr::AggregateFunction,
3276 schema: &datafusion::common::DFSchema,
3277) -> Result<DfExpr> {
3278 let coerced_args: Vec<DfExpr> = agg
3279 .params
3280 .args
3281 .iter()
3282 .map(|a| apply_type_coercion(a, schema))
3283 .collect::<Result<Vec<_>>>()?;
3284 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3285 .params
3286 .order_by
3287 .iter()
3288 .map(|s| {
3289 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3290 Ok(datafusion::logical_expr::SortExpr {
3291 expr: coerced_expr,
3292 asc: s.asc,
3293 nulls_first: s.nulls_first,
3294 })
3295 })
3296 .collect::<Result<Vec<_>>>()?;
3297 let coerced_filter = agg
3298 .params
3299 .filter
3300 .as_ref()
3301 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3302 .transpose()?;
3303 Ok(DfExpr::AggregateFunction(
3304 datafusion::logical_expr::expr::AggregateFunction {
3305 func: agg.func.clone(),
3306 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3307 args: coerced_args,
3308 distinct: agg.params.distinct,
3309 filter: coerced_filter,
3310 order_by: coerced_order_by,
3311 null_treatment: agg.params.null_treatment,
3312 },
3313 },
3314 ))
3315}
3316
3317#[cfg(test)]
3318mod tests {
3319 use super::*;
3320 use arrow_array::{
3321 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3322 };
3323 use uni_common::TemporalValue;
3324 #[test]
3325 fn test_literal_translation() {
3326 let expr = Expr::Literal(CypherLiteral::Integer(42));
3327 let result = cypher_expr_to_df(&expr, None).unwrap();
3328 let s = format!("{:?}", result);
3329 assert!(s.contains("Literal"));
3331 assert!(s.contains("Int64(42)"));
3332 }
3333
3334 #[test]
3335 fn test_property_access_no_context_uses_index() {
3336 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3338 let result = cypher_expr_to_df(&expr, None).unwrap();
3339 let s = format!("{}", result);
3340 assert!(
3341 s.contains("index"),
3342 "expected index UDF for non-graph variable, got: {s}"
3343 );
3344 }
3345
3346 #[test]
3347 fn test_comparison_operator() {
3348 let expr = Expr::BinaryOp {
3349 left: Box::new(Expr::Property(
3350 Box::new(Expr::Variable("n".to_string())),
3351 "age".to_string(),
3352 )),
3353 op: BinaryOp::Gt,
3354 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3355 };
3356 let result = cypher_expr_to_df(&expr, None).unwrap();
3357 let s = format!("{:?}", result);
3359 assert!(s.contains("age"));
3360 assert!(s.contains("30"));
3361 }
3362
3363 #[test]
3364 fn test_boolean_operators() {
3365 let expr = Expr::BinaryOp {
3366 left: Box::new(Expr::BinaryOp {
3367 left: Box::new(Expr::Property(
3368 Box::new(Expr::Variable("n".to_string())),
3369 "age".to_string(),
3370 )),
3371 op: BinaryOp::Gt,
3372 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3373 }),
3374 op: BinaryOp::And,
3375 right: Box::new(Expr::BinaryOp {
3376 left: Box::new(Expr::Property(
3377 Box::new(Expr::Variable("n".to_string())),
3378 "active".to_string(),
3379 )),
3380 op: BinaryOp::Eq,
3381 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3382 }),
3383 };
3384 let result = cypher_expr_to_df(&expr, None).unwrap();
3385 let s = format!("{:?}", result);
3386 assert!(s.contains("And"));
3387 }
3388
3389 #[test]
3390 fn test_is_null() {
3391 let expr = Expr::IsNull(Box::new(Expr::Property(
3392 Box::new(Expr::Variable("n".to_string())),
3393 "email".to_string(),
3394 )));
3395 let result = cypher_expr_to_df(&expr, None).unwrap();
3396 let s = format!("{:?}", result);
3397 assert!(s.contains("IsNull"));
3398 }
3399
3400 #[test]
3401 fn test_collect_properties() {
3402 let expr = Expr::BinaryOp {
3403 left: Box::new(Expr::Property(
3404 Box::new(Expr::Variable("n".to_string())),
3405 "name".to_string(),
3406 )),
3407 op: BinaryOp::Eq,
3408 right: Box::new(Expr::Property(
3409 Box::new(Expr::Variable("m".to_string())),
3410 "name".to_string(),
3411 )),
3412 };
3413
3414 let props = collect_properties(&expr);
3415 assert_eq!(props.len(), 2);
3416 assert!(props.contains(&("m".to_string(), "name".to_string())));
3417 assert!(props.contains(&("n".to_string(), "name".to_string())));
3418 }
3419
3420 #[test]
3421 fn test_function_call() {
3422 let expr = Expr::FunctionCall {
3423 name: "count".to_string(),
3424 args: vec![Expr::Wildcard],
3425 distinct: false,
3426 window_spec: None,
3427 };
3428 let result = cypher_expr_to_df(&expr, None).unwrap();
3429 let s = format!("{:?}", result);
3430 assert!(s.to_lowercase().contains("count"));
3431 }
3432
3433 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3438 use datafusion::logical_expr::Operator;
3439
3440 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3442 let fields: Vec<_> = cols
3443 .iter()
3444 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3445 .collect();
3446 let schema = Schema::new(fields);
3447 datafusion::common::DFSchema::try_from(schema).unwrap()
3448 }
3449
3450 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3452 let s = format!("{}", expr);
3453 s.contains(name)
3454 }
3455
3456 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3458 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3459 }
3460
3461 #[test]
3462 fn test_coercion_lb_eq_int64() {
3463 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3464 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3465 Box::new(col("lb")),
3466 Operator::Eq,
3467 Box::new(col("i")),
3468 ));
3469 let result = apply_type_coercion(&expr, &schema).unwrap();
3470 assert!(
3472 contains_udf(&result, "_cypher_equal"),
3473 "expected _cypher_equal, got: {result}"
3474 );
3475 }
3476
3477 #[test]
3478 fn test_coercion_lb_noteq_int64() {
3479 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3480 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3481 Box::new(col("lb")),
3482 Operator::NotEq,
3483 Box::new(col("i")),
3484 ));
3485 let result = apply_type_coercion(&expr, &schema).unwrap();
3486 assert!(contains_udf(&result, "_cypher_not_equal"));
3488 }
3489
3490 #[test]
3491 fn test_coercion_lb_lt_int64() {
3492 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3493 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3494 Box::new(col("lb")),
3495 Operator::Lt,
3496 Box::new(col("i")),
3497 ));
3498 let result = apply_type_coercion(&expr, &schema).unwrap();
3499 assert!(contains_udf(&result, "_cypher_lt"));
3501 }
3502
3503 #[test]
3504 fn test_coercion_lb_eq_float64() {
3505 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3506 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3507 Box::new(col("lb")),
3508 Operator::Eq,
3509 Box::new(col("f")),
3510 ));
3511 let result = apply_type_coercion(&expr, &schema).unwrap();
3512 assert!(contains_udf(&result, "_cypher_equal"));
3514 }
3515
3516 #[test]
3517 fn test_coercion_lb_eq_utf8() {
3518 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3519 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3520 Box::new(col("lb")),
3521 Operator::Eq,
3522 Box::new(col("s")),
3523 ));
3524 let result = apply_type_coercion(&expr, &schema).unwrap();
3525 assert!(contains_udf(&result, "_cypher_equal"));
3527 }
3528
3529 #[test]
3530 fn test_coercion_lb_eq_bool() {
3531 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3532 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3533 Box::new(col("lb")),
3534 Operator::Eq,
3535 Box::new(col("b")),
3536 ));
3537 let result = apply_type_coercion(&expr, &schema).unwrap();
3538 assert!(contains_udf(&result, "_cypher_equal"));
3540 }
3541
3542 #[test]
3543 fn test_coercion_int64_eq_lb() {
3544 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3546 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3547 Box::new(col("i")),
3548 Operator::Eq,
3549 Box::new(col("lb")),
3550 ));
3551 let result = apply_type_coercion(&expr, &schema).unwrap();
3552 assert!(contains_udf(&result, "_cypher_equal"));
3554 }
3555
3556 #[test]
3557 fn test_coercion_float64_gt_lb() {
3558 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3559 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3560 Box::new(col("f")),
3561 Operator::Gt,
3562 Box::new(col("lb")),
3563 ));
3564 let result = apply_type_coercion(&expr, &schema).unwrap();
3565 assert!(contains_udf(&result, "_cypher_gt"));
3567 }
3568
3569 #[test]
3570 fn test_coercion_both_lb_eq() {
3571 let schema = make_schema(&[
3572 ("lb1", DataType::LargeBinary),
3573 ("lb2", DataType::LargeBinary),
3574 ]);
3575 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3576 Box::new(col("lb1")),
3577 Operator::Eq,
3578 Box::new(col("lb2")),
3579 ));
3580 let result = apply_type_coercion(&expr, &schema).unwrap();
3581 assert!(contains_udf(&result, "_cypher_equal"));
3582 }
3583
3584 #[test]
3585 fn test_coercion_both_lb_lt() {
3586 let schema = make_schema(&[
3587 ("lb1", DataType::LargeBinary),
3588 ("lb2", DataType::LargeBinary),
3589 ]);
3590 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3591 Box::new(col("lb1")),
3592 Operator::Lt,
3593 Box::new(col("lb2")),
3594 ));
3595 let result = apply_type_coercion(&expr, &schema).unwrap();
3596 assert!(contains_udf(&result, "_cypher_lt"));
3597 }
3598
3599 #[test]
3600 fn test_coercion_both_lb_noteq() {
3601 let schema = make_schema(&[
3602 ("lb1", DataType::LargeBinary),
3603 ("lb2", DataType::LargeBinary),
3604 ]);
3605 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3606 Box::new(col("lb1")),
3607 Operator::NotEq,
3608 Box::new(col("lb2")),
3609 ));
3610 let result = apply_type_coercion(&expr, &schema).unwrap();
3611 assert!(contains_udf(&result, "_cypher_not_equal"));
3612 }
3613
3614 #[test]
3615 fn test_coercion_lb_plus_int64() {
3616 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3617 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3618 Box::new(col("lb")),
3619 Operator::Plus,
3620 Box::new(col("i")),
3621 ));
3622 let result = apply_type_coercion(&expr, &schema).unwrap();
3623 assert!(contains_udf(&result, "_cypher_add"));
3624 }
3625
3626 #[test]
3627 fn test_coercion_lb_minus_int64() {
3628 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3629 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3630 Box::new(col("lb")),
3631 Operator::Minus,
3632 Box::new(col("i")),
3633 ));
3634 let result = apply_type_coercion(&expr, &schema).unwrap();
3635 assert!(contains_udf(&result, "_cypher_sub"));
3636 }
3637
/// `lb * f` with a `LargeBinary` left operand and a `Float64` right operand:
/// the coerced expression must contain the `_cypher_mul` UDF.
#[test]
fn test_coercion_lb_multiply_float64() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [("lb", DataType::LargeBinary), ("f", DataType::Float64)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("lb"));
    let rhs = Box::new(col("f"));
    let product = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Multiply, rhs));

    let coerced = apply_type_coercion(&product, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_mul"));
}
3649
/// `i + lb` with the `Int64` operand on the left and the `LargeBinary`
/// operand on the right: the coerced expression must contain `_cypher_add`.
#[test]
fn test_coercion_int64_plus_lb() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [("i", DataType::Int64), ("lb", DataType::LargeBinary)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("i"));
    let rhs = Box::new(col("lb"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let coerced = apply_type_coercion(&sum, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3661
/// `lb + s` with a `LargeBinary` left operand and a `Utf8` right operand:
/// the coerced expression must contain the `_cypher_add` UDF.
#[test]
fn test_coercion_lb_plus_utf8() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [("lb", DataType::LargeBinary), ("s", DataType::Utf8)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("lb"));
    let rhs = Box::new(col("s"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let coerced = apply_type_coercion(&sum, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3675
/// `NULL AND b` with a `Boolean` column `b`: the displayed result must mention
/// either `CAST` or `Boolean`, and the top-level operator must remain `AND`.
#[test]
fn test_coercion_and_null_bool() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("b", DataType::Boolean)]);

    let null_side = Box::new(lit(ScalarValue::Null));
    let bool_side = Box::new(col("b"));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(null_side, Operator::And, bool_side));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    let s = coerced.to_string();
    let mentions_cast = s.contains("CAST") || s.contains("Boolean");
    assert!(mentions_cast, "expected cast to Boolean, got: {s}");
    assert!(is_binary_op(&coerced, Operator::And));
}
3694
/// `b AND NULL` with a `Boolean` column `b`: the coerced expression must still
/// be a binary `AND`.
#[test]
fn test_coercion_bool_and_null() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("b", DataType::Boolean)]);

    let bool_side = Box::new(col("b"));
    let null_side = Box::new(lit(ScalarValue::Null));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(bool_side, Operator::And, null_side));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
}
3706
/// `NULL OR b` with a `Boolean` column `b`: the coerced expression must still
/// be a binary `OR`.
#[test]
fn test_coercion_or_null_bool() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[("b", DataType::Boolean)]);

    let null_side = Box::new(lit(ScalarValue::Null));
    let bool_side = Box::new(col("b"));
    let disjunction = DfExpr::BinaryExpr(BinaryExpr::new(null_side, Operator::Or, bool_side));

    let coerced = apply_type_coercion(&disjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::Or));
}
3718
/// `NULL AND NULL` against an empty schema: coercion must succeed and the
/// result must still be a binary `AND`.
#[test]
fn test_coercion_null_and_null() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let schema = make_schema(&[]);

    let left_null = Box::new(lit(ScalarValue::Null));
    let right_null = Box::new(lit(ScalarValue::Null));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(left_null, Operator::And, right_null));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
}
3730
/// `a AND b` over two `Boolean` columns: the result must remain a binary
/// `AND` and its displayed form must not contain `CAST`.
#[test]
fn test_coercion_bool_and_bool_noop() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [("a", DataType::Boolean), ("b", DataType::Boolean)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("a"));
    let rhs = Box::new(col("b"));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::And, rhs));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));

    let s = coerced.to_string();
    let has_cast = s.contains("CAST");
    assert!(!has_cast, "should not contain CAST: {s}");
}
3745
/// `CASE WHEN lb = 42 THEN 'a' ELSE 'b' END` with a `LargeBinary` column `lb`:
/// the displayed result must contain `_cypher_equal`, showing that the WHEN
/// condition was rewritten.
#[test]
fn test_coercion_case_when_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    let schema = make_schema(&[("lb", DataType::LargeBinary)]);

    let condition = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Eq,
        Box::new(lit(42_i64)),
    ));
    let branches = vec![(Box::new(condition), Box::new(lit("a")))];
    let fallback = Some(Box::new(lit("b")));
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: fallback,
    });

    let s = apply_type_coercion(&case_expr, &schema)
        .unwrap()
        .to_string();
    let rewritten = s.contains("_cypher_equal");
    assert!(rewritten, "CASE WHEN should have _cypher_equal, got: {s}");
}
3768
/// `CASE WHEN true THEN lb + 1 ELSE 0 END` with a `LargeBinary` column `lb`:
/// the displayed result must contain `_cypher_add`, showing that the THEN
/// branch was rewritten.
#[test]
fn test_coercion_case_then_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    let schema = make_schema(&[("lb", DataType::LargeBinary)]);

    let then_branch = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(lit(1_i64)),
    ));
    let branches = vec![(Box::new(lit(true)), Box::new(then_branch))];
    let fallback = Some(Box::new(lit(0_i64)));
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: fallback,
    });

    let s = apply_type_coercion(&case_expr, &schema)
        .unwrap()
        .to_string();
    let rewritten = s.contains("_cypher_add");
    assert!(rewritten, "CASE THEN should have _cypher_add, got: {s}");
}
3790
/// `CASE WHEN true THEN 1 ELSE lb + 2 END` with a `LargeBinary` column `lb`:
/// the displayed result must contain `_cypher_add`, showing that the ELSE
/// branch was rewritten.
#[test]
fn test_coercion_case_else_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    let schema = make_schema(&[("lb", DataType::LargeBinary)]);

    let else_branch = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(lit(2_i64)),
    ));
    let branches = vec![(Box::new(lit(true)), Box::new(lit(1_i64)))];
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: Some(Box::new(else_branch)),
    });

    let s = apply_type_coercion(&case_expr, &schema)
        .unwrap()
        .to_string();
    let rewritten = s.contains("_cypher_add");
    assert!(rewritten, "CASE ELSE should have _cypher_add, got: {s}");
}
3812
/// `a = b` over two `Int64` columns: the result must remain a binary `=` and
/// its displayed form must not contain `_cypher_value`.
#[test]
fn test_coercion_int64_eq_int64_noop() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [("a", DataType::Int64), ("b", DataType::Int64)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("a"));
    let rhs = Box::new(col("b"));
    let equality = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Eq, rhs));

    let coerced = apply_type_coercion(&equality, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::Eq));

    let s = coerced.to_string();
    let decodes = s.contains("_cypher_value");
    assert!(!decodes, "should not contain cypher_value decode: {s}");
}
3829
/// `lb1 + lb2` over two `LargeBinary` columns: the coerced expression must
/// contain the `_cypher_add` UDF.
#[test]
fn test_coercion_both_lb_plus() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [
        ("lb1", DataType::LargeBinary),
        ("lb2", DataType::LargeBinary),
    ];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("lb1"));
    let rhs = Box::new(col("lb2"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let result = apply_type_coercion(&sum, &schema).unwrap();
    let uses_add = contains_udf(&result, "_cypher_add");
    assert!(uses_add, "expected _cypher_add, got: {result}");
}
3848
/// `lst + i`, where `lst` is a native Arrow `List<Int32>` column and `i` is
/// an `Int32` column: the coerced expression must contain the
/// `_cypher_list_append` UDF.
#[test]
fn test_coercion_native_list_plus_scalar() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let item_field = Field::new("item", DataType::Int32, true);
    let list_type = DataType::List(Arc::new(item_field));
    let columns = [("lst", list_type), ("i", DataType::Int32)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("lst"));
    let rhs = Box::new(col("i"));
    let append = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let result = apply_type_coercion(&append, &schema).unwrap();
    let uses_append = contains_udf(&result, "_cypher_list_append");
    assert!(uses_append, "expected _cypher_list_append, got: {result}");
}
3870
/// `lb + i` with a `LargeBinary` left operand and an `Int64` right operand
/// is expected to coerce to an expression containing `_cypher_add`.
#[test]
fn test_coercion_lb_plus_int64_unchanged() {
    use datafusion::logical_expr::expr::BinaryExpr;

    let columns = [("lb", DataType::LargeBinary), ("i", DataType::Int64)];
    let schema = make_schema(&columns);

    let lhs = Box::new(col("lb"));
    let rhs = Box::new(col("i"));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let result = apply_type_coercion(&sum, &schema).unwrap();
    let uses_add = contains_udf(&result, "_cypher_add");
    assert!(uses_add, "expected _cypher_add, got: {result}");
}
3886
/// A list literal mixing a variable, an integer and a string must translate
/// (without a translation context) to an expression whose displayed form
/// contains `_make_cypher_list`.
#[test]
fn test_mixed_list_with_variables_compiles() {
    let elements = vec![
        Expr::Variable("n".to_string()),
        Expr::Literal(CypherLiteral::Integer(1)),
        Expr::Literal(CypherLiteral::String("hello".to_string())),
    ];
    let list_expr = Expr::List(elements);

    let s = cypher_expr_to_df(&list_expr, None).unwrap().to_string();
    let builds_list = s.contains("_make_cypher_list");
    assert!(builds_list, "expected _make_cypher_list UDF call, got: {s}");
}
3906
/// A list made only of literals of mixed types (integer, string, bool) must
/// translate directly to a `DfExpr::Literal`.
#[test]
fn test_literal_only_mixed_list_uses_cv_fastpath() {
    let elements = vec![
        Expr::Literal(CypherLiteral::Integer(1)),
        Expr::Literal(CypherLiteral::String("hi".to_string())),
        Expr::Literal(CypherLiteral::Bool(true)),
    ];
    let list_expr = Expr::List(elements);

    let result = cypher_expr_to_df(&list_expr, None).unwrap();
    let is_literal = matches!(result, DfExpr::Literal(..));
    assert!(
        is_literal,
        "expected Literal (CypherValue fast path), got: {result}"
    );
}
3921
/// `1 IN ['1', 2]` (a literal list of mixed types) must translate to an
/// expression whose displayed form contains `_cypher_in`.
#[test]
fn test_in_mixed_literal_list_uses_cypher_in() {
    let needle = Expr::Literal(CypherLiteral::Integer(1));
    let haystack = Expr::List(vec![
        Expr::Literal(CypherLiteral::String("1".to_string())),
        Expr::Literal(CypherLiteral::Integer(2)),
    ]);
    let membership = Expr::In {
        expr: Box::new(needle),
        list: Box::new(haystack),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    let uses_cypher_in = s.contains("_cypher_in");
    assert!(
        uses_cypher_in,
        "expected _cypher_in UDF for mixed-type IN list, got: {s}"
    );
}
3943
/// `1 IN [2, 3]` (a literal list of integers only) must also translate to an
/// expression whose displayed form contains `_cypher_in`.
#[test]
fn test_in_homogeneous_literal_list_uses_cypher_in() {
    let needle = Expr::Literal(CypherLiteral::Integer(1));
    let haystack = Expr::List(vec![
        Expr::Literal(CypherLiteral::Integer(2)),
        Expr::Literal(CypherLiteral::Integer(3)),
    ]);
    let membership = Expr::In {
        expr: Box::new(needle),
        list: Box::new(haystack),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    let uses_cypher_in = s.contains("_cypher_in");
    assert!(
        uses_cypher_in,
        "expected _cypher_in UDF for homogeneous IN list, got: {s}"
    );
}
3961
/// `1 IN [x, 2]`, where the list contains a variable, must translate to an
/// expression whose displayed form contains both `_cypher_in` and
/// `_make_cypher_list`.
#[test]
fn test_in_list_with_variables_uses_make_cypher_list() {
    let needle = Expr::Literal(CypherLiteral::Integer(1));
    let haystack = Expr::List(vec![
        Expr::Variable("x".to_string()),
        Expr::Literal(CypherLiteral::Integer(2)),
    ]);
    let membership = Expr::In {
        expr: Box::new(needle),
        list: Box::new(haystack),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();

    let uses_cypher_in = s.contains("_cypher_in");
    assert!(uses_cypher_in, "expected _cypher_in UDF, got: {s}");

    let builds_list = s.contains("_make_cypher_list");
    assert!(
        builds_list,
        "expected _make_cypher_list for variable-containing list, got: {s}"
    );
}
3983
/// With `n` registered as `VariableKind::Node` in the translation context,
/// `n.name` must translate to something whose debug output mentions both
/// `Column` and `n.name`.
#[test]
fn test_property_on_graph_entity_uses_column() {
    let mut ctx = TranslationContext::new();
    let variable = "n".to_string();
    ctx.variable_kinds.insert(variable, VariableKind::Node);

    let base = Box::new(Expr::Variable("n".to_string()));
    let property = Expr::Property(base, "name".to_string());

    let translated = cypher_expr_to_df(&property, Some(&ctx)).unwrap();
    let s = format!("{translated:?}");
    let is_flat_column = s.contains("Column") && s.contains("n.name");
    assert!(
        is_flat_column,
        "expected flat column 'n.name' for graph entity, got: {s}"
    );
}
4006
/// With a fresh translation context (no variable kinds registered),
/// `map.name` must translate to an expression whose displayed form contains
/// `index`.
#[test]
fn test_property_on_non_graph_var_uses_index() {
    let ctx = TranslationContext::new();

    let base = Box::new(Expr::Variable("map".to_string()));
    let property = Expr::Property(base, "name".to_string());

    let s = cypher_expr_to_df(&property, Some(&ctx))
        .unwrap()
        .to_string();
    let uses_index = s.contains("index");
    assert!(
        uses_index,
        "expected index UDF for non-graph variable, got: {s}"
    );
}
4023
/// A `Value::Map` holding a single entry (`"k" -> Int(1)`) must convert to a
/// `ScalarValue::Struct`.
#[test]
fn test_value_to_scalar_non_empty_map_becomes_struct() {
    let entries = std::collections::HashMap::from([("k".to_string(), Value::Int(1))]);
    let input = Value::Map(entries);

    let scalar = value_to_scalar(&input).unwrap();
    let is_struct = matches!(scalar, ScalarValue::Struct(_));
    assert!(is_struct, "expected Struct scalar for map input");
}
4034
/// An empty `Value::Map` must also convert to a `ScalarValue::Struct`.
#[test]
fn test_value_to_scalar_empty_map_becomes_struct() {
    let input = Value::Map(Default::default());

    let scalar = value_to_scalar(&input).unwrap();
    let is_struct = matches!(scalar, ScalarValue::Struct(_));
    assert!(is_struct, "empty map should produce an empty Struct scalar");
}
4043
/// `Value::Null` must convert to the untyped `ScalarValue::Null`.
#[test]
fn test_value_to_scalar_null_is_untyped_null() {
    let input = Value::Null;

    let scalar = value_to_scalar(&input).unwrap();
    let is_untyped_null = matches!(scalar, ScalarValue::Null);
    assert!(
        is_untyped_null,
        "expected untyped Null scalar for Value::Null"
    );
}
4052
/// A `TemporalValue::DateTime` must convert to a single-row, three-field
/// struct scalar.
///
/// Checks the field names (`nanos_since_epoch`, `offset_seconds`,
/// `timezone_name`), the concrete Arrow array type of each column, and that
/// each column holds the input value at row 0.
#[test]
fn test_value_to_scalar_datetime_produces_struct() {
    let datetime = Value::Temporal(TemporalValue::DateTime {
        nanos_since_epoch: 441_763_200_000_000_000,
        offset_seconds: 3600,
        timezone_name: Some("Europe/Paris".to_string()),
    });

    let struct_arr = match value_to_scalar(&datetime).unwrap() {
        ScalarValue::Struct(arr) => arr,
        other => panic!("Expected ScalarValue::Struct for DateTime, got {:?}", other),
    };

    // Overall shape of the struct.
    assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
    assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");

    // Field names, in declaration order.
    let fields = struct_arr.fields();
    assert_eq!(fields[0].name(), "nanos_since_epoch");
    assert_eq!(fields[1].name(), "offset_seconds");
    assert_eq!(fields[2].name(), "timezone_name");

    let nanos_col = struct_arr.column(0);
    let offset_col = struct_arr.column(1);
    let tz_col = struct_arr.column(2);

    // Each column must have the expected concrete type and carry the input value.
    let Some(nanos_arr) = nanos_col
        .as_any()
        .downcast_ref::<TimestampNanosecondArray>()
    else {
        panic!("Expected TimestampNanosecondArray for nanos field");
    };
    assert_eq!(nanos_arr.value(0), 441_763_200_000_000_000);

    let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() else {
        panic!("Expected Int32Array for offset field");
    };
    assert_eq!(offset_arr.value(0), 3600);

    let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() else {
        panic!("Expected StringArray for timezone_name field");
    };
    assert_eq!(tz_arr.value(0), "Europe/Paris");
}
4107
/// A `TemporalValue::DateTime` with `timezone_name: None` must still produce
/// a three-column struct, with a null entry in the third (string) column.
#[test]
fn test_value_to_scalar_datetime_with_null_timezone() {
    let datetime = Value::Temporal(TemporalValue::DateTime {
        nanos_since_epoch: 1_704_067_200_000_000_000,
        offset_seconds: -18000,
        timezone_name: None,
    });

    let scalar = value_to_scalar(&datetime).unwrap();
    let ScalarValue::Struct(struct_arr) = scalar else {
        panic!("Expected ScalarValue::Struct for DateTime");
    };
    assert_eq!(struct_arr.num_columns(), 3);

    let tz_col = struct_arr.column(2);
    let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() else {
        panic!("Expected StringArray for timezone_name field");
    };
    assert!(tz_arr.is_null(0), "expected null timezone_name");
}
4133
/// A `TemporalValue::Time` must convert to a single-row, two-field struct
/// scalar.
///
/// Checks the field names (`nanos_since_midnight`, `offset_seconds`), the
/// concrete Arrow array type of each column, and that each column holds the
/// input value at row 0.
#[test]
fn test_value_to_scalar_time_produces_struct() {
    let time = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 37_845_000_000_000,
        offset_seconds: 3600,
    });

    let struct_arr = match value_to_scalar(&time).unwrap() {
        ScalarValue::Struct(arr) => arr,
        other => panic!("Expected ScalarValue::Struct for Time, got {:?}", other),
    };

    // Overall shape of the struct.
    assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
    assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");

    // Field names, in declaration order.
    let fields = struct_arr.fields();
    assert_eq!(fields[0].name(), "nanos_since_midnight");
    assert_eq!(fields[1].name(), "offset_seconds");

    let nanos_col = struct_arr.column(0);
    let offset_col = struct_arr.column(1);

    // Each column must have the expected concrete type and carry the input value.
    let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() else {
        panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
    };
    assert_eq!(nanos_arr.value(0), 37_845_000_000_000);

    let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() else {
        panic!("Expected Int32Array for offset field");
    };
    assert_eq!(offset_arr.value(0), 3600);
}
4173
/// A `TemporalValue::Time` at midnight with zero offset must convert to a
/// struct whose first column is a `Time64NanosecondArray` holding `0`.
#[test]
fn test_value_to_scalar_time_boundary_values() {
    let midnight = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 0,
        offset_seconds: 0,
    });

    let scalar = value_to_scalar(&midnight).unwrap();
    let ScalarValue::Struct(struct_arr) = scalar else {
        panic!("Expected ScalarValue::Struct for Time");
    };

    let nanos_col = struct_arr.column(0);
    let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() else {
        panic!("Expected Time64NanosecondArray");
    };
    assert_eq!(nanos_arr.value(0), 0);
}
4195}