1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Column-name suffix holding a node variable's vertex id (`"{var}._vid"`).
const COL_VID: &str = "_vid";
/// Column-name suffix holding an edge variable's edge id (`"{var}._eid"`).
const COL_EID: &str = "_eid";
/// Column-name suffix holding a node variable's label list (`"{var}._labels"`).
const COL_LABELS: &str = "_labels";
/// Column-name suffix holding an edge variable's relationship type (`"{var}._type"`).
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub(crate) fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub(crate) fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub(crate) fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Appends a `:00` seconds component to a datetime string of the form
/// `??????????THH:MM[suffix]` that has no seconds.
///
/// Requires at least 16 bytes, a `'T'` at byte 10, and `HH:MM` digits at bytes
/// 11..16. Returns `None` when the string does not match that shape or when
/// seconds are already present (byte 16 is `':'`). The 10-byte date part is
/// not validated. Any suffix after the minutes (e.g. `Z`, `+01:00`) is kept.
pub(crate) fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();
    if bytes.len() < 16 || bytes[10] != b'T' {
        return None;
    }

    let hh_mm = &bytes[11..16];
    let looks_like_hh_mm = hh_mm[0].is_ascii_digit()
        && hh_mm[1].is_ascii_digit()
        && hh_mm[2] == b':'
        && hh_mm[3].is_ascii_digit()
        && hh_mm[4].is_ascii_digit();
    if !looks_like_hh_mm {
        return None;
    }

    // Seconds already present: nothing to normalize.
    if bytes.get(16) == Some(&b':') {
        return None;
    }

    // Byte 15 is an ASCII digit, so index 16 is always a char boundary.
    let (head, tail) = s.split_at(16);
    Some(format!("{head}:00{tail}"))
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
/// Names of scalar functions whose results `is_cypher_list_expr` treats as
/// Cypher (LargeBinary-encoded) lists.
const CYPHER_LIST_FUNCS: &[&str] = &[
    "_make_cypher_list",
    "_cypher_list_concat",
    "_cypher_list_append",
];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// Kind of graph entity a Cypher variable is bound to.
///
/// The translator uses it to decide which columns back a variable (e.g. the
/// `_vid` column for nodes, `_eid` for edges).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node (vertex).
    Node,
    /// A single relationship from a fixed-length pattern.
    Edge,
    /// The relationship list bound by a variable-length pattern.
    EdgeList,
    /// A named path.
    Path,
}

impl VariableKind {
    /// Returns the kind for a relationship variable: `EdgeList` for a
    /// variable-length pattern, `Edge` otherwise.
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
246
247pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
282 match expr {
283 Expr::PatternComprehension { .. } => Err(anyhow!(
284 "Pattern comprehensions require fallback executor (graph traversal)"
285 )),
286 #[expect(deprecated)]
289 Expr::Wildcard => Ok(DfExpr::Wildcard {
290 qualifier: None,
291 options: Default::default(),
292 }),
293
294 Expr::Variable(name) => {
295 if let Some(ctx) = context
300 && ctx.variable_kinds.contains_key(name)
301 {
302 return Ok(DfExpr::Column(Column::from_name(name)));
303 }
304
305 if let Some(ctx) = context
311 && let Some(value) = ctx.outer_values.get(name)
312 {
313 return value_to_scalar(value).map(lit);
314 }
315
316 if let Some(ctx) = context
321 && let Some(value) = ctx.parameters.get(name)
322 {
323 match value {
326 Value::List(values) if name.ends_with("._vid") => {
327 let literals = values
329 .iter()
330 .map(|v| value_to_scalar(v).map(lit))
331 .collect::<Result<Vec<_>>>()?;
332 return Ok(DfExpr::InList(InList {
333 expr: Box::new(DfExpr::Column(Column::from_name(name))),
334 list: literals,
335 negated: false,
336 }));
337 }
338 other_value => return value_to_scalar(other_value).map(lit),
339 }
340 }
341
342 Ok(DfExpr::Column(Column::from_name(name)))
345 }
346
347 Expr::Property(base, prop) => translate_property_access(base, prop, context),
348
349 Expr::ArrayIndex { array, index } => {
350 if let Ok(var_name) = extract_variable_name(array)
353 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
354 {
355 let col_name = format!("{}.{}", var_name, prop_name);
356 return Ok(DfExpr::Column(Column::from_name(col_name)));
357 }
358
359 let array_expr = cypher_expr_to_df(array, context)?;
360 let index_expr = cypher_expr_to_df(index, context)?;
361
362 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
364 }
365
366 Expr::ArraySlice { array, start, end } => {
367 let array_expr = cypher_expr_to_df(array, context)?;
371
372 let start_expr = match start {
373 Some(s) => cypher_expr_to_df(s, context)?,
374 None => lit(0i64),
375 };
376
377 let end_expr = match end {
378 Some(e) => cypher_expr_to_df(e, context)?,
379 None => lit(i64::MAX),
380 };
381
382 Ok(dummy_udf_expr(
385 "_cypher_list_slice",
386 vec![array_expr, start_expr, end_expr],
387 ))
388 }
389
390 Expr::Parameter(name) => {
391 if let Some(ctx) = context
393 && let Some(value) = ctx.parameters.get(name)
394 {
395 return value_to_scalar(value).map(lit);
396 }
397 Err(anyhow!("Unresolved parameter: ${}", name))
398 }
399
400 Expr::Literal(value) => {
401 let scalar = cypher_literal_to_scalar(value)?;
402 Ok(lit(scalar))
403 }
404
405 Expr::List(items) => translate_list_literal(items, context),
406
407 Expr::Map(entries) => {
408 if entries.is_empty() {
409 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
411 Default::default(),
412 ));
413 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
414 }
415 let mut args = Vec::with_capacity(entries.len() * 2);
418 for (key, val_expr) in entries {
419 args.push(lit(key.clone()));
420 args.push(cypher_expr_to_df(val_expr, context)?);
421 }
422 Ok(datafusion::functions::expr_fn::named_struct(args))
423 }
424
425 Expr::IsNull(inner) => translate_null_check(inner, context, true),
426
427 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
428
429 Expr::IsUnique(_) => {
430 Err(anyhow!(
432 "IS UNIQUE can only be used in constraint definitions"
433 ))
434 }
435
436 Expr::FunctionCall {
437 name,
438 args,
439 distinct,
440 window_spec,
441 } => {
442 if window_spec.is_some() {
445 let col_name = expr.to_string_repr();
447 Ok(col(&col_name))
448 } else {
449 translate_function_call(name, args, *distinct, context)
450 }
451 }
452
453 Expr::In { expr, list } => translate_in_expression(expr, list, context),
454
455 Expr::BinaryOp { left, op, right } => {
456 let left_expr = cypher_expr_to_df(left, context)?;
457 let right_expr = cypher_expr_to_df(right, context)?;
458 translate_binary_op(left_expr, op, right_expr)
459 }
460
461 Expr::UnaryOp { op, expr: inner } => {
462 let inner_expr = cypher_expr_to_df(inner, context)?;
463 match op {
464 UnaryOp::Not => Ok(inner_expr.not()),
465 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
466 }
467 }
468
469 Expr::Case {
470 expr,
471 when_then,
472 else_expr,
473 } => translate_case_expression(expr, when_then, else_expr, context),
474
475 Expr::Reduce { .. } => Err(anyhow!(
476 "Reduce expressions not yet supported in DataFusion translation"
477 )),
478
479 Expr::Exists { .. } => Err(anyhow!(
480 "EXISTS subqueries are handled by the physical expression compiler, \
481 not the DataFusion logical expression translator"
482 )),
483
484 Expr::CountSubquery(_) => Err(anyhow!(
485 "Count subqueries not yet supported in DataFusion translation"
486 )),
487
488 Expr::CollectSubquery(_) => Err(anyhow!(
489 "COLLECT subqueries not yet supported in DataFusion translation"
490 )),
491
492 Expr::Quantifier { .. } => {
493 Err(anyhow!(
498 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
499 via CypherPhysicalExprCompiler"
500 ))
501 }
502
503 Expr::ListComprehension { .. } => {
504 Err(anyhow!(
516 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
517 ))
518 }
519
520 Expr::ValidAt { .. } => {
521 Err(anyhow!(
524 "VALID_AT expression should have been transformed to function call in planner"
525 ))
526 }
527
528 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
529
530 Expr::LabelCheck { expr, labels } => {
531 if let Expr::Variable(var) = expr.as_ref() {
532 let is_edge = context
534 .and_then(|ctx| ctx.variable_kinds.get(var))
535 .is_some_and(|k| matches!(k, VariableKind::Edge));
536
537 if is_edge {
538 if labels.len() > 1 {
542 Ok(lit(false))
543 } else {
544 let type_col =
545 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
546 Ok(DfExpr::Case(datafusion::logical_expr::Case {
548 expr: None,
549 when_then_expr: vec![(
550 Box::new(type_col.clone().is_null()),
551 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
552 )],
553 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
554 }))
555 }
556 } else {
557 let labels_col =
559 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
560 let mut checks: Option<DfExpr> = None;
561 for label in labels {
562 let check = datafusion::functions_nested::expr_fn::array_has(
563 labels_col.clone(),
564 lit(label.clone()),
565 );
566 checks = Some(match checks {
567 Some(prev) => prev.and(check),
568 None => check,
569 });
570 }
571 Ok(DfExpr::Case(datafusion::logical_expr::Case {
573 expr: None,
574 when_then_expr: vec![(
575 Box::new(labels_col.is_null()),
576 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
577 )],
578 else_expr: Some(Box::new(checks.unwrap())),
579 }))
580 }
581 } else {
582 Err(anyhow!(
583 "LabelCheck on non-variable expression not yet supported in DataFusion"
584 ))
585 }
586 }
587 }
588}
589
590#[derive(Debug, Clone)]
594pub struct TranslationContext {
595 pub parameters: std::collections::HashMap<String, Value>,
597
598 pub outer_values: std::collections::HashMap<String, Value>,
602
603 pub variable_labels: std::collections::HashMap<String, String>,
605
606 pub variable_kinds: std::collections::HashMap<String, VariableKind>,
608
609 pub node_variable_hints: Vec<String>,
612
613 pub mutation_edge_hints: Vec<String>,
616
617 pub statement_time: chrono::DateTime<chrono::Utc>,
622}
623
624impl Default for TranslationContext {
625 fn default() -> Self {
626 Self {
627 parameters: std::collections::HashMap::new(),
628 outer_values: std::collections::HashMap::new(),
629 variable_labels: std::collections::HashMap::new(),
630 variable_kinds: std::collections::HashMap::new(),
631 node_variable_hints: Vec::new(),
632 mutation_edge_hints: Vec::new(),
633 statement_time: chrono::Utc::now(),
634 }
635 }
636}
637
638impl TranslationContext {
639 pub fn new() -> Self {
641 Self::default()
642 }
643
644 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
646 self.parameters.insert(name.into(), value);
647 self
648 }
649
650 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
652 self.variable_labels.insert(var.into(), label.into());
653 self
654 }
655}
656
657fn extract_variable_name(expr: &Expr) -> Result<String> {
659 match expr {
660 Expr::Variable(name) => Ok(name.clone()),
661 Expr::Property(base, _) => extract_variable_name(base),
662 _ => Err(anyhow!(
663 "Cannot extract variable name from expression: {:?}",
664 expr
665 )),
666 }
667}
668
669fn translate_null_check(
671 inner: &Expr,
672 context: Option<&TranslationContext>,
673 is_null: bool,
674) -> Result<DfExpr> {
675 if let Expr::Variable(var) = inner
676 && let Some(ctx) = context
677 && let Some(kind) = ctx.variable_kinds.get(var)
678 {
679 let col_name = match kind {
680 VariableKind::Node => format!("{}.{}", var, COL_VID),
681 VariableKind::Edge => format!("{}.{}", var, COL_EID),
682 VariableKind::Path | VariableKind::EdgeList => var.clone(),
683 };
684 let col_expr = DfExpr::Column(Column::from_name(col_name));
685 return Ok(if is_null {
686 col_expr.is_null()
687 } else {
688 col_expr.is_not_null()
689 });
690 }
691
692 let inner_expr = cypher_expr_to_df(inner, context)?;
693 Ok(if is_null {
694 inner_expr.is_null()
695 } else {
696 inner_expr.is_not_null()
697 })
698}
699
700fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
705 if crate::query::datetime::is_duration_accessor(prop) {
706 Some(dummy_udf_expr(
707 "_duration_property",
708 vec![base_expr, lit(prop.to_string())],
709 ))
710 } else if crate::query::datetime::is_temporal_accessor(prop) {
711 Some(dummy_udf_expr(
712 "_temporal_property",
713 vec![base_expr, lit(prop.to_string())],
714 ))
715 } else {
716 None
717 }
718}
719
720fn translate_property_access(
722 base: &Expr,
723 prop: &str,
724 context: Option<&TranslationContext>,
725) -> Result<DfExpr> {
726 if let Ok(var_name) = extract_variable_name(base) {
727 let is_graph_entity = context
728 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
729 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
730
731 if !is_graph_entity
732 && let Some(expr) =
733 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
734 {
735 return Ok(expr);
736 }
737
738 let col_name = format!("{}.{}", var_name, prop);
739
740 if let Some(ctx) = context
743 && let Some(value) = ctx.parameters.get(&col_name)
744 {
745 match value {
748 Value::List(values) if col_name.ends_with("._vid") => {
749 let literals = values
750 .iter()
751 .map(|v| value_to_scalar(v).map(lit))
752 .collect::<Result<Vec<_>>>()?;
753 return Ok(DfExpr::InList(InList {
754 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
755 list: literals,
756 negated: false,
757 }));
758 }
759 other_value => return value_to_scalar(other_value).map(lit),
760 }
761 }
762
763 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
766 let base_expr = cypher_expr_to_df(base, context)?;
767 return Ok(dummy_udf_expr(
768 "index",
769 vec![base_expr, lit(prop.to_string())],
770 ));
771 }
772
773 if is_graph_entity {
774 Ok(DfExpr::Column(Column::from_name(col_name)))
775 } else {
776 let base_expr = DfExpr::Column(Column::from_name(var_name));
777 Ok(dummy_udf_expr(
778 "index",
779 vec![base_expr, lit(prop.to_string())],
780 ))
781 }
782 } else {
783 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
785 return Ok(expr);
786 }
787
788 if let Expr::Parameter(param_name) = base {
790 if let Some(ctx) = context
791 && let Some(value) = ctx.parameters.get(param_name)
792 {
793 if let Value::Map(map) = value {
794 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
795 return value_to_scalar(&extracted).map(lit);
796 }
797 return Ok(lit(ScalarValue::Null));
798 }
799 return Err(anyhow!("Unresolved parameter: ${}", param_name));
800 }
801
802 let base_expr = cypher_expr_to_df(base, context)?;
803 Ok(dummy_udf_expr(
804 "index",
805 vec![base_expr, lit(prop.to_string())],
806 ))
807 }
808}
809
810fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
812 let mut has_string = false;
814 let mut has_bool = false;
815 let mut has_list = false;
816 let mut has_map = false;
817 let mut has_numeric = false;
818 let mut has_graph_entity = false;
819 let mut has_temporal = false;
820
821 for item in items {
822 match item {
823 Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
824 has_numeric = true
825 }
826 Expr::Literal(CypherLiteral::String(_)) => has_string = true,
827 Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
828 Expr::List(_) => has_list = true,
829 Expr::Map(_) => has_map = true,
830 Expr::Variable(name) => {
833 if context
834 .and_then(|ctx| ctx.variable_kinds.get(name))
835 .is_some()
836 {
837 has_graph_entity = true;
838 }
839 }
840 Expr::FunctionCall { name, .. } => {
843 let upper = name.to_uppercase();
844 if matches!(
845 upper.as_str(),
846 "DATE"
847 | "TIME"
848 | "LOCALTIME"
849 | "LOCALDATETIME"
850 | "DATETIME"
851 | "DURATION"
852 | "DATE.TRUNCATE"
853 | "TIME.TRUNCATE"
854 | "DATETIME.TRUNCATE"
855 | "LOCALDATETIME.TRUNCATE"
856 | "LOCALTIME.TRUNCATE"
857 ) {
858 has_temporal = true;
859 }
860 }
861 _ => {}
863 }
864 }
865
866 let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;
868
869 if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
872 if let Some(json_array) = try_items_to_json(items) {
874 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
875 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
876 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
877 }
878 let df_args: Vec<DfExpr> = items
880 .iter()
881 .map(|item| cypher_expr_to_df(item, context))
882 .collect::<Result<_>>()?;
883 return Ok(dummy_udf_expr("_make_cypher_list", df_args));
884 }
885
886 let mut df_args = Vec::with_capacity(items.len());
889 let mut has_float = false;
890 let mut has_int = false;
891 let mut has_other = false;
892
893 for item in items {
894 match item {
895 Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
896 Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
897 _ => has_other = true,
898 }
899 df_args.push(cypher_expr_to_df(item, context)?);
900 }
901
902 if df_args.is_empty() {
903 let empty_arr =
905 ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
906 Ok(lit(ScalarValue::List(empty_arr)))
907 } else if has_float && has_int && !has_other {
908 let promoted_args = df_args
910 .into_iter()
911 .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
912 .collect();
913 Ok(datafusion::functions_nested::expr_fn::make_array(
914 promoted_args,
915 ))
916 } else {
917 Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
918 }
919}
920
921fn translate_in_expression(
923 expr: &Expr,
924 list: &Expr,
925 context: Option<&TranslationContext>,
926) -> Result<DfExpr> {
927 let left_expr = if let Expr::Variable(var) = expr
932 && let Some(ctx) = context
933 && let Some(kind) = ctx.variable_kinds.get(var)
934 {
935 match kind {
936 VariableKind::Node | VariableKind::Edge => {
937 let id_col = match kind {
938 VariableKind::Node => COL_VID,
939 VariableKind::Edge => COL_EID,
940 _ => unreachable!(),
941 };
942 cast_expr(
943 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
944 datafusion::arrow::datatypes::DataType::Int64,
945 )
946 }
947 _ => cypher_expr_to_df(expr, context)?,
948 }
949 } else {
950 cypher_expr_to_df(expr, context)?
951 };
952
953 if let Expr::List(items) = list {
958 if let Some(json_array) = try_items_to_json(items) {
959 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
961 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
962 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
963 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
964 } else {
965 let expanded: Vec<DfExpr> = items
967 .iter()
968 .map(|item| cypher_expr_to_df(item, context))
969 .collect::<Result<Vec<_>>>()?;
970 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
971 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
972 }
973 } else {
974 let right_expr = cypher_expr_to_df(list, context)?;
975
976 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
981 return Ok(lit(ScalarValue::Boolean(None)));
982 }
983
984 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
985 }
986}
987
988fn translate_case_expression(
990 operand: &Option<Box<Expr>>,
991 when_then: &[(Expr, Expr)],
992 else_expr: &Option<Box<Expr>>,
993 context: Option<&TranslationContext>,
994) -> Result<DfExpr> {
995 let mut case_builder = if let Some(match_expr) = operand {
996 let match_df = cypher_expr_to_df(match_expr, context)?;
997 datafusion::logical_expr::case(match_df)
998 } else {
999 datafusion::logical_expr::when(
1000 cypher_expr_to_df(&when_then[0].0, context)?,
1001 cypher_expr_to_df(&when_then[0].1, context)?,
1002 )
1003 };
1004
1005 let start_idx = if operand.is_some() { 0 } else { 1 };
1006 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1007 let when_df = cypher_expr_to_df(when_expr, context)?;
1008 let then_df = cypher_expr_to_df(then_expr, context)?;
1009 case_builder = case_builder.when(when_df, then_df);
1010 }
1011
1012 if let Some(else_e) = else_expr {
1013 let else_df = cypher_expr_to_df(else_e, context)?;
1014 Ok(case_builder.otherwise(else_df)?)
1015 } else {
1016 Ok(case_builder.end()?)
1017 }
1018}
1019
1020fn translate_map_projection(
1022 base: &Expr,
1023 items: &[MapProjectionItem],
1024 context: Option<&TranslationContext>,
1025) -> Result<DfExpr> {
1026 let mut args = Vec::new();
1027 for item in items {
1028 match item {
1029 MapProjectionItem::Property(prop) => {
1030 args.push(lit(prop.clone()));
1031 let prop_expr = cypher_expr_to_df(
1032 &Expr::Property(Box::new(base.clone()), prop.clone()),
1033 context,
1034 )?;
1035 args.push(prop_expr);
1036 }
1037 MapProjectionItem::LiteralEntry(key, expr) => {
1038 args.push(lit(key.clone()));
1039 args.push(cypher_expr_to_df(expr, context)?);
1040 }
1041 MapProjectionItem::Variable(var) => {
1042 args.push(lit(var.clone()));
1043 args.push(DfExpr::Column(Column::from_name(var)));
1044 }
1045 MapProjectionItem::AllProperties => {
1046 args.push(lit("__all__"));
1047 args.push(cypher_expr_to_df(base, context)?);
1048 }
1049 }
1050 }
1051 Ok(dummy_udf_expr("_map_project", args))
1052}
1053
1054fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1057 match expr {
1058 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1059 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1060 Expr::Literal(CypherLiteral::Integer(i)) => {
1061 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1062 }
1063 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1064 .map(serde_json::Value::Number)
1065 .or(Some(serde_json::Value::Null)),
1066 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1067 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1068 Expr::Map(entries) => {
1069 let mut map = serde_json::Map::new();
1070 for (k, v) in entries {
1071 map.insert(k.clone(), try_expr_to_json(v)?);
1072 }
1073 Some(serde_json::Value::Object(map))
1074 }
1075 _ => None,
1076 }
1077}
1078
1079fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1081 items.iter().map(try_expr_to_json).collect()
1082}
1083
1084fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1086 match lit {
1087 CypherLiteral::Null => Ok(ScalarValue::Null),
1088 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1089 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1090 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1091 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1092 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1093 }
1094}
1095
1096fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1098 match value {
1099 Value::Null => Ok(ScalarValue::Null),
1100 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1101 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1102 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1103 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1104 Value::List(items) => {
1105 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1107 let scalars = scalars?;
1108
1109 let data_type = infer_common_scalar_type(&scalars);
1111
1112 let typed_scalars: Vec<ScalarValue> = scalars
1114 .into_iter()
1115 .map(|s| {
1116 if matches!(s, ScalarValue::Null) {
1117 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1118 }
1119
1120 match (s, &data_type) {
1121 (
1122 ScalarValue::Int64(Some(v)),
1123 datafusion::arrow::datatypes::DataType::Float64,
1124 ) => ScalarValue::Float64(Some(v as f64)),
1125 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1126 let s_str = s.to_string();
1128 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1129 }
1130 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1131 if matches!(s, ScalarValue::Utf8(_)) {
1133 s
1134 } else {
1135 ScalarValue::Utf8(Some(s.to_string()))
1136 }
1137 }
1138 (s, _) => s,
1139 }
1140 })
1141 .collect();
1142
1143 if typed_scalars.is_empty() {
1145 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1146 &[],
1147 &data_type,
1148 )))
1149 } else {
1150 Ok(ScalarValue::List(ScalarValue::new_list(
1151 &typed_scalars,
1152 &data_type,
1153 true,
1154 )))
1155 }
1156 }
1157 Value::Map(map) => {
1158 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1161 entries.sort_by_key(|(k, _)| *k);
1162
1163 if entries.is_empty() {
1164 return Ok(ScalarValue::Struct(Arc::new(
1165 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1166 )));
1167 }
1168
1169 let mut fields_arrays = Vec::with_capacity(entries.len());
1170
1171 for (k, v) in entries {
1172 let scalar = value_to_scalar(v)?;
1173 let dt = scalar.data_type();
1174 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1175 let array = scalar.to_array()?;
1176 fields_arrays.push((field, array));
1177 }
1178
1179 Ok(ScalarValue::Struct(Arc::new(
1180 datafusion::arrow::array::StructArray::from(fields_arrays),
1181 )))
1182 }
1183 Value::Temporal(tv) => {
1184 use uni_common::TemporalValue;
1185 match tv {
1186 TemporalValue::Date { days_since_epoch } => {
1187 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1188 }
1189 TemporalValue::LocalTime {
1190 nanos_since_midnight,
1191 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1192 TemporalValue::Time {
1193 nanos_since_midnight,
1194 offset_seconds,
1195 } => {
1196 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1198 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1199
1200 let nanos_arr =
1201 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1202 as ArrayRef;
1203 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1204
1205 let fields = Fields::from(vec![
1206 Field::new(
1207 "nanos_since_midnight",
1208 ArrowDataType::Time64(TimeUnit::Nanosecond),
1209 true,
1210 ),
1211 Field::new("offset_seconds", ArrowDataType::Int32, true),
1212 ]);
1213
1214 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1215 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1216 }
1217 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1218 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1219 ),
1220 TemporalValue::DateTime {
1221 nanos_since_epoch,
1222 offset_seconds,
1223 timezone_name,
1224 } => {
1225 use arrow::array::{
1227 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1228 };
1229 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1230
1231 let nanos_arr =
1232 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1233 as ArrayRef;
1234 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1235 let tz_arr =
1236 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1237
1238 let fields = Fields::from(vec![
1239 Field::new(
1240 "nanos_since_epoch",
1241 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1242 true,
1243 ),
1244 Field::new("offset_seconds", ArrowDataType::Int32, true),
1245 Field::new("timezone_name", ArrowDataType::Utf8, true),
1246 ]);
1247
1248 let struct_arr =
1249 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1250 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1251 }
1252 TemporalValue::Duration {
1253 months,
1254 days,
1255 nanos,
1256 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1257 arrow::datatypes::IntervalMonthDayNano {
1258 months: *months as i32,
1259 days: *days as i32,
1260 nanoseconds: *nanos,
1261 },
1262 ))),
1263 }
1264 }
1265 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1266 other => {
1268 let json_val: serde_json::Value = other.clone().into();
1269 let json_str = serde_json::to_string(&json_val)
1270 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1271 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1272 }
1273 }
1274}
1275
1276fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1278 match op {
1279 BinaryOp::Eq => Ok(left.eq(right)),
1283 BinaryOp::NotEq => Ok(left.not_eq(right)),
1284 BinaryOp::Lt => Ok(left.lt(right)),
1285 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1286 BinaryOp::Gt => Ok(left.gt(right)),
1287 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1288
1289 BinaryOp::And => Ok(left.and(right)),
1291 BinaryOp::Or => Ok(left.or(right)),
1292 BinaryOp::Xor => {
1293 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1295 }
1296
1297 BinaryOp::Add => {
1299 if is_list_expr(&left) || is_list_expr(&right) {
1300 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1301 } else {
1302 Ok(left + right)
1303 }
1304 }
1305 BinaryOp::Sub => Ok(left - right),
1306 BinaryOp::Mul => Ok(left * right),
1307 BinaryOp::Div => Ok(left / right),
1308 BinaryOp::Mod => Ok(left % right),
1309 BinaryOp::Pow => {
1310 let left_f = datafusion::logical_expr::cast(
1313 left,
1314 datafusion::arrow::datatypes::DataType::Float64,
1315 );
1316 let right_f = datafusion::logical_expr::cast(
1317 right,
1318 datafusion::arrow::datatypes::DataType::Float64,
1319 );
1320 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1321 }
1322
1323 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1325 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1326 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1327
1328 BinaryOp::Regex => {
1329 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1330 }
1331
1332 BinaryOp::ApproxEq => Err(anyhow!(
1333 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1334 )),
1335 }
1336}
1337
/// Arity guard for the `translate_*_function` helpers.
///
/// Expands to a call to `require_args($df_args, $n, $name)`. When fewer than `$n` arguments
/// are present, the *enclosing* function returns `Some(Err(..))`, so it may only be used in
/// functions returning `Option<Result<_>>`. A count of `1` behaves exactly like
/// `require_arg`, which is itself `require_args(.., 1, ..)`.
macro_rules! check_args {
    ($n:expr, $df_args:expr, $name:expr) => {
        if let Err(err) = require_args($df_args, $n, $name) {
            return Some(Err(err));
        }
    };
}
1354
1355fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1358 if df_args.len() < count {
1359 let noun = if count == 1 { "argument" } else { "arguments" };
1360 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1361 }
1362 Ok(())
1363}
1364
1365fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1367 require_args(df_args, 1, func_name)
1368}
1369
1370fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1372 df_args[0].clone()
1373}
1374
1375pub(crate) fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1377 DfExpr::Cast(datafusion::logical_expr::Cast {
1378 expr: Box::new(expr),
1379 data_type,
1380 })
1381}
1382
1383pub(crate) fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1389 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1390 Arc::new(crate::query::df_udfs::create_cypher_list_to_cv_udf()),
1391 vec![expr],
1392 ))
1393}
1394
1395pub(crate) fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1399 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1400 Arc::new(crate::query::df_udfs::create_cypher_scalar_to_cv_udf()),
1401 vec![expr],
1402 ))
1403}
1404
1405fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1407 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1408 Box::new(left),
1409 op,
1410 Box::new(right),
1411 ))
1412}
1413
1414pub(crate) fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1419 use datafusion::logical_expr::Operator;
1420 match op {
1421 Operator::Eq => Some("_cypher_equal"),
1422 Operator::NotEq => Some("_cypher_not_equal"),
1423 Operator::Lt => Some("_cypher_lt"),
1424 Operator::LtEq => Some("_cypher_lt_eq"),
1425 Operator::Gt => Some("_cypher_gt"),
1426 Operator::GtEq => Some("_cypher_gt_eq"),
1427 _ => None,
1428 }
1429}
1430
1431fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1433 use datafusion::logical_expr::Operator;
1434 match op {
1435 Operator::Plus => Some("_cypher_add"),
1436 Operator::Minus => Some("_cypher_sub"),
1437 Operator::Multiply => Some("_cypher_mul"),
1438 Operator::Divide => Some("_cypher_div"),
1439 Operator::Modulo => Some("_cypher_mod"),
1440 _ => None,
1441 }
1442}
1443
1444fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1449where
1450 F: FnOnce(DfExpr) -> DfExpr,
1451{
1452 require_arg(df_args, func_name)?;
1453 Ok(math_fn(cast_expr(
1454 first_arg(df_args),
1455 datafusion::arrow::datatypes::DataType::Float64,
1456 )))
1457}
1458
1459fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1461 if distinct {
1462 expr.distinct()
1463 .build()
1464 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1465 } else {
1466 Ok(expr)
1467 }
1468}
1469
1470fn translate_aggregate_function(
1472 name_upper: &str,
1473 df_args: &[DfExpr],
1474 distinct: bool,
1475) -> Option<Result<DfExpr>> {
1476 match name_upper {
1477 "COUNT" => {
1478 let expr = if df_args.is_empty() {
1479 datafusion::functions_aggregate::count::count(lit(1i64))
1480 } else {
1481 datafusion::functions_aggregate::count::count(first_arg(df_args))
1482 };
1483 Some(maybe_distinct(expr, distinct, "COUNT"))
1484 }
1485 "SUM" => {
1486 check_args!(1, df_args, "SUM");
1487 let udaf = Arc::new(crate::query::df_udfs::create_cypher_sum_udaf());
1488 Some(maybe_distinct(
1489 udaf.call(vec![first_arg(df_args)]),
1490 distinct,
1491 "SUM",
1492 ))
1493 }
1494 "AVG" => {
1495 check_args!(1, df_args, "AVG");
1496 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1497 let expr = datafusion::functions_aggregate::average::avg(coerced);
1498 Some(maybe_distinct(expr, distinct, "AVG"))
1499 }
1500 "MIN" => {
1501 check_args!(1, df_args, "MIN");
1502 let udaf = Arc::new(crate::query::df_udfs::create_cypher_min_udaf());
1503 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1504 }
1505 "MAX" => {
1506 check_args!(1, df_args, "MAX");
1507 let udaf = Arc::new(crate::query::df_udfs::create_cypher_max_udaf());
1508 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1509 }
1510 "PERCENTILEDISC" => {
1511 if df_args.len() != 2 {
1512 return Some(Err(anyhow!(
1513 "percentileDisc() requires exactly 2 arguments"
1514 )));
1515 }
1516 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1517 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_disc_udaf());
1518 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1519 }
1520 "PERCENTILECONT" => {
1521 if df_args.len() != 2 {
1522 return Some(Err(anyhow!(
1523 "percentileCont() requires exactly 2 arguments"
1524 )));
1525 }
1526 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1527 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_cont_udaf());
1528 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1529 }
1530 "COLLECT" => {
1531 check_args!(1, df_args, "COLLECT");
1532 Some(Ok(crate::query::df_udfs::create_cypher_collect_expr(
1533 first_arg(df_args),
1534 distinct,
1535 )))
1536 }
1537 _ => None,
1538 }
1539}
1540
1541fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1544 match name_upper {
1545 "TOSTRING" => {
1546 check_args!(1, df_args, "toString");
1547 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1548 }
1549 "TOINTEGER" | "TOINT" => {
1550 check_args!(1, df_args, "toInteger");
1551 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1552 }
1553 "TOFLOAT" => {
1554 check_args!(1, df_args, "toFloat");
1555 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1556 }
1557 "TOBOOLEAN" | "TOBOOL" => {
1558 check_args!(1, df_args, "toBoolean");
1559 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1560 }
1561 "UPPER" | "TOUPPER" => {
1562 check_args!(1, df_args, "upper");
1563 Some(Ok(datafusion::functions::string::expr_fn::upper(
1564 first_arg(df_args),
1565 )))
1566 }
1567 "LOWER" | "TOLOWER" => {
1568 check_args!(1, df_args, "lower");
1569 Some(Ok(datafusion::functions::string::expr_fn::lower(
1570 first_arg(df_args),
1571 )))
1572 }
1573 "SUBSTRING" => {
1574 check_args!(2, df_args, "substring");
1575 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1576 }
1577 "TRIM" => {
1578 check_args!(1, df_args, "TRIM");
1579 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1580 first_arg(df_args),
1581 ])))
1582 }
1583 "LTRIM" => {
1584 check_args!(1, df_args, "LTRIM");
1585 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1586 first_arg(df_args),
1587 ])))
1588 }
1589 "RTRIM" => {
1590 check_args!(1, df_args, "RTRIM");
1591 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1592 first_arg(df_args),
1593 ])))
1594 }
1595 "LEFT" => {
1596 check_args!(2, df_args, "left");
1597 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1598 df_args[0].clone(),
1599 df_args[1].clone(),
1600 )))
1601 }
1602 "RIGHT" => {
1603 check_args!(2, df_args, "right");
1604 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1605 df_args[0].clone(),
1606 df_args[1].clone(),
1607 )))
1608 }
1609 "REPLACE" => {
1610 check_args!(3, df_args, "replace");
1611 Some(Ok(datafusion::functions::string::expr_fn::replace(
1612 df_args[0].clone(),
1613 df_args[1].clone(),
1614 df_args[2].clone(),
1615 )))
1616 }
1617 "REVERSE" => {
1618 check_args!(1, df_args, "reverse");
1619 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1620 }
1621 "SPLIT" => {
1622 check_args!(2, df_args, "split");
1623 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1624 }
1625 "SIZE" | "LENGTH" => {
1626 check_args!(1, df_args, name_upper);
1627 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1628 }
1629 _ => None,
1630 }
1631}
1632
1633fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1636 use datafusion::functions::math::expr_fn;
1637
1638 let unary_f64 =
1640 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1641
1642 match name_upper {
1643 "ABS" => {
1644 check_args!(1, df_args, "abs");
1645 Some(Ok(crate::query::df_udfs::cypher_abs_expr(first_arg(
1649 df_args,
1650 ))))
1651 }
1652 "CEIL" | "CEILING" => {
1653 check_args!(1, df_args, "ceil");
1654 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1655 }
1656 "FLOOR" => {
1657 check_args!(1, df_args, "floor");
1658 Some(Ok(expr_fn::floor(first_arg(df_args))))
1659 }
1660 "ROUND" => {
1661 check_args!(1, df_args, "round");
1662 let args = if df_args.len() == 1 {
1663 vec![first_arg(df_args)]
1664 } else {
1665 vec![df_args[0].clone(), df_args[1].clone()]
1666 };
1667 Some(Ok(expr_fn::round(args)))
1668 }
1669 "SIGN" => {
1670 check_args!(1, df_args, "sign");
1671 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1672 Some(Ok(expr_fn::signum(coerced)))
1673 }
1674 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1675 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1676 "LOG10" => unary_f64("log10", expr_fn::log10),
1677 "EXP" => unary_f64("exp", expr_fn::exp),
1678 "SIN" => unary_f64("sin", expr_fn::sin),
1679 "COS" => unary_f64("cos", expr_fn::cos),
1680 "TAN" => unary_f64("tan", expr_fn::tan),
1681 "ASIN" => unary_f64("asin", expr_fn::asin),
1682 "ACOS" => unary_f64("acos", expr_fn::acos),
1683 "ATAN" => unary_f64("atan", expr_fn::atan),
1684 "ATAN2" => {
1685 check_args!(2, df_args, "atan2");
1686 let cast_f64 =
1687 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1688 Some(Ok(expr_fn::atan2(
1689 cast_f64(df_args[0].clone()),
1690 cast_f64(df_args[1].clone()),
1691 )))
1692 }
1693 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1694 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1695 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1696 _ => None,
1697 }
1698}
1699
1700fn translate_temporal_function(
1703 name_upper: &str,
1704 name: &str,
1705 df_args: &[DfExpr],
1706 context: Option<&TranslationContext>,
1707) -> Option<Result<DfExpr>> {
1708 match name_upper {
1709 "DATE"
1710 | "TIME"
1711 | "LOCALTIME"
1712 | "LOCALDATETIME"
1713 | "DATETIME"
1714 | "DURATION"
1715 | "YEAR"
1716 | "MONTH"
1717 | "DAY"
1718 | "HOUR"
1719 | "MINUTE"
1720 | "SECOND"
1721 | "DURATION.BETWEEN"
1722 | "DURATION.INMONTHS"
1723 | "DURATION.INDAYS"
1724 | "DURATION.INSECONDS"
1725 | "DATETIME.FROMEPOCH"
1726 | "DATETIME.FROMEPOCHMILLIS"
1727 | "DATE.TRUNCATE"
1728 | "TIME.TRUNCATE"
1729 | "DATETIME.TRUNCATE"
1730 | "LOCALDATETIME.TRUNCATE"
1731 | "LOCALTIME.TRUNCATE"
1732 | "DATETIME.TRANSACTION"
1733 | "DATETIME.STATEMENT"
1734 | "DATETIME.REALTIME"
1735 | "DATE.TRANSACTION"
1736 | "DATE.STATEMENT"
1737 | "DATE.REALTIME"
1738 | "TIME.TRANSACTION"
1739 | "TIME.STATEMENT"
1740 | "TIME.REALTIME"
1741 | "LOCALTIME.TRANSACTION"
1742 | "LOCALTIME.STATEMENT"
1743 | "LOCALTIME.REALTIME"
1744 | "LOCALDATETIME.TRANSACTION"
1745 | "LOCALDATETIME.STATEMENT"
1746 | "LOCALDATETIME.REALTIME" => {
1747 let stmt_time = context.map(|c| c.statement_time);
1751 if can_constant_fold(name_upper, df_args)
1752 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1753 {
1754 return Some(Ok(folded));
1755 }
1756 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1757 }
1758 _ => None,
1759 }
1760}
1761
1762fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1764 if name.contains("REALTIME") {
1766 return false;
1767 }
1768 if args.is_empty() {
1776 return matches!(
1777 name,
1778 "DATE"
1779 | "TIME"
1780 | "LOCALTIME"
1781 | "LOCALDATETIME"
1782 | "DATETIME"
1783 | "DATE.STATEMENT"
1784 | "TIME.STATEMENT"
1785 | "LOCALTIME.STATEMENT"
1786 | "LOCALDATETIME.STATEMENT"
1787 | "DATETIME.STATEMENT"
1788 | "DATE.TRANSACTION"
1789 | "TIME.TRANSACTION"
1790 | "LOCALTIME.TRANSACTION"
1791 | "LOCALDATETIME.TRANSACTION"
1792 | "DATETIME.TRANSACTION"
1793 );
1794 }
1795 args.iter().all(is_constant_expr)
1797}
1798
1799fn is_constant_expr(expr: &DfExpr) -> bool {
1801 match expr {
1802 DfExpr::Literal(_, _) => true,
1803 DfExpr::ScalarFunction(func) => {
1804 func.args.iter().all(is_constant_expr)
1806 }
1807 _ => false,
1808 }
1809}
1810
1811fn try_constant_fold_temporal(
1817 name: &str,
1818 args: &[DfExpr],
1819 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1820) -> Result<DfExpr> {
1821 let val_args: Vec<Value> = args
1823 .iter()
1824 .map(extract_constant_value)
1825 .collect::<Result<_>>()?;
1826
1827 let result = if val_args.is_empty() {
1829 if let Some(frozen) = stmt_time {
1830 crate::query::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1831 } else {
1832 crate::query::datetime::eval_datetime_function(name, &val_args)?
1833 }
1834 } else {
1835 crate::query::datetime::eval_datetime_function(name, &val_args)?
1836 };
1837
1838 let scalar = value_to_scalar(&result)?;
1840 Ok(DfExpr::Literal(scalar, None))
1841}
1842
1843fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1845 use crate::query::df_udfs::scalar_to_value;
1846 match expr {
1847 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1848 DfExpr::ScalarFunction(func) => {
1849 let mut map = std::collections::HashMap::new();
1852 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1853 for chunk in pairs.chunks(2) {
1854 if let [key_expr, val_expr] = chunk {
1855 let key = match key_expr {
1857 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1858 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1859 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1860 };
1861 let val = extract_constant_value(val_expr)?;
1862 map.insert(key, val);
1863 } else {
1864 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1865 }
1866 }
1867 Ok(Value::Map(map))
1868 }
1869 _ => Err(anyhow::anyhow!(
1870 "Cannot extract constant value from expression"
1871 )),
1872 }
1873}
1874
1875fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1878 match name_upper {
1879 "HEAD" => {
1880 check_args!(1, df_args, "head");
1881 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1882 }
1883 "LAST" => {
1884 check_args!(1, df_args, "last");
1885 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1886 }
1887 "TAIL" => {
1888 check_args!(1, df_args, "tail");
1889 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1890 }
1891 "RANGE" => {
1892 check_args!(2, df_args, "range");
1893 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1894 }
1895 _ => None,
1896 }
1897}
1898
1899fn translate_graph_function(
1902 name_upper: &str,
1903 name: &str,
1904 df_args: &[DfExpr],
1905 args: &[Expr],
1906 context: Option<&TranslationContext>,
1907) -> Option<Result<DfExpr>> {
1908 match name_upper {
1909 "ID" => {
1910 if let Some(Expr::Variable(var)) = args.first() {
1913 let is_edge = context.is_some_and(|ctx| {
1914 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1915 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1916 });
1917 let id_suffix = if is_edge { COL_EID } else { COL_VID };
1918 Some(Ok(DfExpr::Column(Column::from_name(format!(
1919 "{}.{}",
1920 var, id_suffix
1921 )))))
1922 } else {
1923 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
1924 }
1925 }
1926 "LABELS" | "KEYS" => {
1927 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1932 }
1933 "TYPE" => {
1934 if let Some(Expr::Variable(var)) = args.first()
1938 && let Some(ctx) = context
1939 && let Some(label) = ctx.variable_labels.get(var)
1940 {
1941 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
1944 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
1945 expr: None,
1946 when_then_expr: vec![(
1947 Box::new(eid_col.is_not_null()),
1948 Box::new(lit(label.clone())),
1949 )],
1950 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
1951 })));
1952 }
1953 if let Some(Expr::Variable(var)) = args.first()
1957 && context
1958 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
1959 {
1960 return Some(Ok(DfExpr::Column(Column::from_name(format!(
1961 "{}.{}",
1962 var, COL_TYPE
1963 )))));
1964 }
1965 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
1966 }
1967 "PROPERTIES" => {
1968 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1971 }
1972 "UNI.TEMPORAL.VALIDAT" => {
1973 if let (
1976 Some(Expr::Variable(var)),
1977 Some(Expr::Literal(CypherLiteral::String(start_prop))),
1978 Some(Expr::Literal(CypherLiteral::String(end_prop))),
1979 Some(ts_expr),
1980 ) = (args.first(), args.get(1), args.get(2), args.get(3))
1981 {
1982 let start_col =
1983 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
1984 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
1985 let ts = match cypher_expr_to_df(ts_expr, context) {
1986 Ok(ts) => ts,
1987 Err(e) => return Some(Err(e)),
1988 };
1989
1990 let start_check = start_col.lt_eq(ts.clone());
1992 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
1994 let end_after = end_col.gt(ts);
1995 let end_check = end_null.or(end_after);
1996
1997 Some(Ok(start_check.and(end_check)))
1998 } else {
1999 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2001 }
2002 }
2003 "STARTNODE" | "ENDNODE" => {
2004 let mut udf_args = df_args.to_vec();
2007 let mut seen = std::collections::HashSet::new();
2008 if let Some(ctx) = context {
2009 for (var, kind) in &ctx.variable_kinds {
2011 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2012 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2013 }
2014 }
2015 for var in &ctx.node_variable_hints {
2018 if seen.insert(var.clone()) {
2019 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2020 }
2021 }
2022 }
2023 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2024 }
2025 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2026 "HASLABEL" => {
2027 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2028 return Some(Err(e));
2029 }
2030 if let Some(Expr::Variable(var)) = args.first() {
2032 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2033 let labels_col =
2035 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2036 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2037 labels_col,
2038 lit(label.clone()),
2039 )))
2040 } else {
2041 Some(Err(anyhow::anyhow!(
2043 "hasLabel requires string literal as second argument for DataFusion translation"
2044 )))
2045 }
2046 } else {
2047 Some(Err(anyhow::anyhow!(
2049 "hasLabel requires variable as first argument for DataFusion translation"
2050 )))
2051 }
2052 }
2053 _ => None,
2054 }
2055}
2056
2057fn translate_function_call(
2059 name: &str,
2060 args: &[Expr],
2061 distinct: bool,
2062 context: Option<&TranslationContext>,
2063) -> Result<DfExpr> {
2064 let df_args: Vec<DfExpr> = args
2065 .iter()
2066 .map(|arg| cypher_expr_to_df(arg, context))
2067 .collect::<Result<Vec<_>>>()?;
2068
2069 let name_upper = name.to_uppercase();
2070
2071 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2075 return result;
2076 }
2077
2078 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2079 return result;
2080 }
2081
2082 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2083 return result;
2084 }
2085
2086 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2087 return result;
2088 }
2089
2090 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2091 return result;
2092 }
2093
2094 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2095 return result;
2096 }
2097
2098 match name_upper.as_str() {
2100 "COALESCE" => {
2101 require_arg(&df_args, "coalesce")?;
2102 return Ok(datafusion::functions::expr_fn::coalesce(df_args));
2103 }
2104 "NULLIF" => {
2105 require_args(&df_args, 2, "nullif")?;
2106 return Ok(datafusion::functions::expr_fn::nullif(
2107 df_args[0].clone(),
2108 df_args[1].clone(),
2109 ));
2110 }
2111 _ => {}
2112 }
2113
2114 Ok(dummy_udf_expr(name, df_args))
2116}
2117
2118#[derive(Debug)]
2123struct DummyUdf {
2124 name: String,
2125 signature: datafusion::logical_expr::Signature,
2126 ret_type: datafusion::arrow::datatypes::DataType,
2127}
2128
2129impl DummyUdf {
2130 fn new(name: String) -> Self {
2131 let ret_type = dummy_udf_return_type(&name);
2132 Self {
2133 name,
2134 signature: datafusion::logical_expr::Signature::variadic_any(
2135 datafusion::logical_expr::Volatility::Immutable,
2136 ),
2137 ret_type,
2138 }
2139 }
2140}
2141
2142fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2155 use datafusion::arrow::datatypes::DataType;
2156 match name {
2157 "_cypher_add"
2161 | "_cypher_sub"
2162 | "_cypher_mul"
2163 | "_cypher_div"
2164 | "_cypher_mod"
2165 | "_cypher_list_concat"
2166 | "_cypher_list_append"
2167 | "_make_cypher_list"
2168 | "_map_project"
2169 | "_cypher_list_to_cv"
2170 | "_cypher_tail" => DataType::LargeBinary,
2171 _ => DataType::Null,
2175 }
2176}
2177
2178impl PartialEq for DummyUdf {
2179 fn eq(&self, other: &Self) -> bool {
2180 self.name == other.name
2181 }
2182}
2183
2184impl Eq for DummyUdf {}
2185
2186impl Hash for DummyUdf {
2187 fn hash<H: Hasher>(&self, state: &mut H) {
2188 self.name.hash(state);
2189 }
2190}
2191
2192pub(crate) fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2194 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2195 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2196 DummyUdf::new(name.to_lowercase()),
2197 )),
2198 args,
2199 })
2200}
2201
2202impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2203 fn as_any(&self) -> &dyn std::any::Any {
2204 self
2205 }
2206
2207 fn name(&self) -> &str {
2208 &self.name
2209 }
2210
2211 fn signature(&self) -> &datafusion::logical_expr::Signature {
2212 &self.signature
2213 }
2214
2215 fn return_type(
2216 &self,
2217 _arg_types: &[datafusion::arrow::datatypes::DataType],
2218 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2219 Ok(self.ret_type.clone())
2222 }
2223
2224 fn invoke_with_args(
2225 &self,
2226 _args: ScalarFunctionArgs,
2227 ) -> datafusion::error::Result<ColumnarValue> {
2228 Err(datafusion::error::DataFusionError::Plan(format!(
2229 "UDF '{}' is not registered. Register it via SessionContext.",
2230 self.name
2231 )))
2232 }
2233}
2234
2235pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2239 let mut properties = Vec::new();
2240 collect_properties_recursive(expr, &mut properties);
2241 properties.sort();
2242 properties.dedup();
2243 properties
2244}
2245
/// Walk a Cypher expression and push every `(variable, property)` pair it reads onto
/// `properties`. Duplicates are allowed here; `collect_properties` sorts and deduplicates.
///
/// A pair is recorded for:
/// - `var.prop`;
/// - `var['prop']` with a string-literal key;
/// - map-projection items `var{.prop}`, and `var{.*}`, which is recorded as property `"*"`.
///
/// Pattern comprehensions and subqueries (`EXISTS`, `COUNT {}`, `COLLECT {}`) are not
/// descended into.
fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
    match expr {
        // NOTE(review): properties referenced only inside a pattern comprehension are not
        // collected here; confirm they are gathered elsewhere.
        Expr::PatternComprehension { .. } => {}
        Expr::Property(base, prop) => {
            // Record `base.prop` when `base` resolves to a variable name, then keep walking
            // `base`, which may itself contain property accesses.
            if let Ok(var_name) = extract_variable_name(base) {
                properties.push((var_name, prop.clone()));
            }
            collect_properties_recursive(base, properties);
        }
        Expr::ArrayIndex { array, index } => {
            // `var['prop']` with a string-literal key is a dynamic spelling of `var.prop`.
            if let Ok(var_name) = extract_variable_name(array)
                && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
            {
                properties.push((var_name, prop_name.clone()));
            }
            collect_properties_recursive(array, properties);
            collect_properties_recursive(index, properties);
        }
        Expr::ArraySlice { array, start, end } => {
            collect_properties_recursive(array, properties);
            if let Some(s) = start {
                collect_properties_recursive(s, properties);
            }
            if let Some(e) = end {
                collect_properties_recursive(e, properties);
            }
        }
        Expr::List(items) => {
            for item in items {
                collect_properties_recursive(item, properties);
            }
        }
        Expr::Map(entries) => {
            // Map keys are plain names; only the values can read properties.
            for (_, value) in entries {
                collect_properties_recursive(value, properties);
            }
        }
        Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
            collect_properties_recursive(inner, properties);
        }
        Expr::FunctionCall { args, .. } => {
            for arg in args {
                collect_properties_recursive(arg, properties);
            }
        }
        Expr::BinaryOp { left, right, .. } => {
            collect_properties_recursive(left, properties);
            collect_properties_recursive(right, properties);
        }
        Expr::UnaryOp { expr, .. } => {
            collect_properties_recursive(expr, properties);
        }
        Expr::Case {
            expr,
            when_then,
            else_expr,
        } => {
            if let Some(e) = expr {
                collect_properties_recursive(e, properties);
            }
            for (when_e, then_e) in when_then {
                collect_properties_recursive(when_e, properties);
                collect_properties_recursive(then_e, properties);
            }
            if let Some(e) = else_expr {
                collect_properties_recursive(e, properties);
            }
        }
        Expr::Reduce {
            init, list, expr, ..
        } => {
            // Properties of the accumulator / iteration variable may be recorded too; this
            // appears harmless because such names match no outer variable.
            collect_properties_recursive(init, properties);
            collect_properties_recursive(list, properties);
            collect_properties_recursive(expr, properties);
        }
        Expr::Quantifier {
            list, predicate, ..
        } => {
            collect_properties_recursive(list, properties);
            collect_properties_recursive(predicate, properties);
        }
        Expr::ListComprehension {
            list,
            where_clause,
            map_expr,
            ..
        } => {
            collect_properties_recursive(list, properties);
            if let Some(filter) = where_clause {
                collect_properties_recursive(filter, properties);
            }
            collect_properties_recursive(map_expr, properties);
        }
        Expr::In { expr, list } => {
            collect_properties_recursive(expr, properties);
            collect_properties_recursive(list, properties);
        }
        Expr::ValidAt {
            entity, timestamp, ..
        } => {
            collect_properties_recursive(entity, properties);
            collect_properties_recursive(timestamp, properties);
        }
        Expr::MapProjection { base, items } => {
            collect_properties_recursive(base, properties);
            for item in items {
                match item {
                    // `var{.prop}` reads `var.prop`.
                    uni_cypher::ast::MapProjectionItem::Property(prop) => {
                        if let Ok(var_name) = extract_variable_name(base) {
                            properties.push((var_name, prop.clone()));
                        }
                    }
                    // `var{.*}` reads every property; recorded with the `*` wildcard.
                    uni_cypher::ast::MapProjectionItem::AllProperties => {
                        if let Ok(var_name) = extract_variable_name(base) {
                            properties.push((var_name, "*".to_string()));
                        }
                    }
                    uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
                        collect_properties_recursive(expr, properties);
                    }
                    uni_cypher::ast::MapProjectionItem::Variable(_) => {}
                }
            }
        }
        Expr::LabelCheck { expr, .. } => {
            collect_properties_recursive(expr, properties);
        }
        // Leaves: nothing to collect.
        Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
        // Subqueries are not descended into.
        Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
    }
}
2378
2379pub fn wider_numeric_type(
2386 a: &datafusion::arrow::datatypes::DataType,
2387 b: &datafusion::arrow::datatypes::DataType,
2388) -> datafusion::arrow::datatypes::DataType {
2389 use datafusion::arrow::datatypes::DataType;
2390
2391 fn numeric_rank(dt: &DataType) -> u8 {
2392 match dt {
2393 DataType::Int8 | DataType::UInt8 => 1,
2394 DataType::Int16 | DataType::UInt16 => 2,
2395 DataType::Int32 | DataType::UInt32 => 3,
2396 DataType::Int64 | DataType::UInt64 => 4,
2397 DataType::Float16 => 5,
2398 DataType::Float32 => 6,
2399 DataType::Float64 => 7,
2400 _ => 0,
2401 }
2402 }
2403
2404 if numeric_rank(a) >= numeric_rank(b) {
2405 a.clone()
2406 } else {
2407 b.clone()
2408 }
2409}
2410
2411fn resolve_column_type_fallback(
2417 expr: &DfExpr,
2418 schema: &datafusion::common::DFSchema,
2419) -> Option<datafusion::arrow::datatypes::DataType> {
2420 if let DfExpr::Column(col) = expr {
2421 let col_name = &col.name;
2422 for (_, field) in schema.iter() {
2424 if field.name() == col_name {
2425 return Some(field.data_type().clone());
2426 }
2427 }
2428 }
2429 None
2430}
2431
2432fn contains_division(expr: &DfExpr) -> bool {
2435 match expr {
2436 DfExpr::BinaryExpr(b) => {
2437 b.op == datafusion::logical_expr::Operator::Divide
2438 || contains_division(&b.left)
2439 || contains_division(&b.right)
2440 }
2441 DfExpr::Cast(c) => contains_division(&c.expr),
2442 DfExpr::TryCast(c) => contains_division(&c.expr),
2443 _ => false,
2444 }
2445}
2446
2447pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2453 use datafusion::arrow::datatypes::DataType;
2454 use datafusion::logical_expr::ExprSchemable;
2455
2456 match expr {
2457 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2458 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2459 DfExpr::Case(case) => coerce_case_expr(case, schema),
2460 DfExpr::InList(in_list) => {
2461 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2462 let coerced_list = in_list
2463 .list
2464 .iter()
2465 .map(|e| apply_type_coercion(e, schema))
2466 .collect::<Result<Vec<_>>>()?;
2467 let expr_type = coerced_expr
2468 .get_type(schema)
2469 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2470 crate::query::cypher_type_coerce::build_cypher_in_list(
2471 coerced_expr,
2472 &expr_type,
2473 coerced_list,
2474 in_list.negated,
2475 schema,
2476 )
2477 }
2478 DfExpr::Not(inner) => {
2479 let coerced_inner = apply_type_coercion(inner, schema)?;
2480 let inner_type = coerced_inner.get_type(schema).ok();
2481 let final_inner = if inner_type
2482 .as_ref()
2483 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2484 {
2485 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2486 } else if inner_type
2487 .as_ref()
2488 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2489 {
2490 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2491 } else {
2492 coerced_inner
2493 };
2494 Ok(DfExpr::Not(Box::new(final_inner)))
2495 }
2496 DfExpr::IsNull(inner) => {
2497 let coerced_inner = apply_type_coercion(inner, schema)?;
2498 Ok(coerced_inner.is_null())
2499 }
2500 DfExpr::IsNotNull(inner) => {
2501 let coerced_inner = apply_type_coercion(inner, schema)?;
2502 Ok(coerced_inner.is_not_null())
2503 }
2504 DfExpr::Negative(inner) => {
2505 let coerced_inner = apply_type_coercion(inner, schema)?;
2506 let inner_type = coerced_inner.get_type(schema).ok();
2507 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2508 Ok(dummy_udf_expr(
2509 "_cypher_mul",
2510 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2511 ))
2512 } else {
2513 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2514 }
2515 }
2516 DfExpr::Cast(cast) => {
2517 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2518 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2519 Box::new(coerced_inner),
2520 cast.data_type.clone(),
2521 )))
2522 }
2523 DfExpr::TryCast(cast) => {
2524 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2525 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2526 Box::new(coerced_inner),
2527 cast.data_type.clone(),
2528 )))
2529 }
2530 DfExpr::Alias(alias) => {
2531 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2532 Ok(coerced_inner.alias(alias.name.clone()))
2533 }
2534 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2535 _ => Ok(expr.clone()),
2536 }
2537}
2538
2539fn coerce_logical_operands(
2541 left: DfExpr,
2542 right: DfExpr,
2543 op: datafusion::logical_expr::Operator,
2544 schema: &datafusion::common::DFSchema,
2545) -> Option<DfExpr> {
2546 use datafusion::arrow::datatypes::DataType;
2547 use datafusion::logical_expr::ExprSchemable;
2548
2549 if !matches!(
2550 op,
2551 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2552 ) {
2553 return None;
2554 }
2555 let left_type = left.get_type(schema).ok();
2556 let right_type = right.get_type(schema).ok();
2557 let left_needs_cast = left_type
2558 .as_ref()
2559 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2560 let right_needs_cast = right_type
2561 .as_ref()
2562 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2563 let left_is_lb = left_type
2564 .as_ref()
2565 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2566 let right_is_lb = right_type
2567 .as_ref()
2568 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2569 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2570 return None;
2571 }
2572 let coerced_left = if left_is_lb {
2573 dummy_udf_expr("_cv_to_bool", vec![left])
2574 } else if left_needs_cast {
2575 datafusion::logical_expr::cast(left, DataType::Boolean)
2576 } else {
2577 left
2578 };
2579 let coerced_right = if right_is_lb {
2580 dummy_udf_expr("_cv_to_bool", vec![right])
2581 } else if right_needs_cast {
2582 datafusion::logical_expr::cast(right, DataType::Boolean)
2583 } else {
2584 right
2585 };
2586 Some(binary_expr(coerced_left, op, coerced_right))
2587}
2588
2589#[expect(
2592 clippy::too_many_arguments,
2593 reason = "Binary coercion needs all context"
2594)]
2595fn coerce_large_binary_ops(
2596 left: &DfExpr,
2597 right: &DfExpr,
2598 left_type: &datafusion::arrow::datatypes::DataType,
2599 right_type: &datafusion::arrow::datatypes::DataType,
2600 left_is_null: bool,
2601 op: datafusion::logical_expr::Operator,
2602 is_comparison: bool,
2603 is_arithmetic: bool,
2604) -> Option<Result<DfExpr>> {
2605 use datafusion::arrow::datatypes::DataType;
2606 use datafusion::logical_expr::Operator;
2607
2608 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2609 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2610
2611 if op == Operator::Plus {
2612 if left_is_lb && right_is_lb {
2613 return Some(Ok(dummy_udf_expr(
2614 "_cypher_add",
2615 vec![left.clone(), right.clone()],
2616 )));
2617 }
2618 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2619 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2620 if left_is_native_list && right_is_native_list {
2621 return Some(Ok(dummy_udf_expr(
2622 "_cypher_list_concat",
2623 vec![left.clone(), right.clone()],
2624 )));
2625 }
2626 if left_is_native_list || right_is_native_list {
2627 return Some(Ok(dummy_udf_expr(
2628 "_cypher_list_append",
2629 vec![left.clone(), right.clone()],
2630 )));
2631 }
2632 }
2633
2634 if (left_is_lb || right_is_lb) && is_comparison {
2635 if let Some(udf_name) = comparison_udf_name(op) {
2636 return Some(Ok(dummy_udf_expr(
2637 udf_name,
2638 vec![left.clone(), right.clone()],
2639 )));
2640 }
2641 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2642 }
2643
2644 if (left_is_lb || right_is_lb) && is_arithmetic {
2645 let udf_name =
2646 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2647 return Some(Ok(dummy_udf_expr(
2648 udf_name,
2649 vec![left.clone(), right.clone()],
2650 )));
2651 }
2652
2653 None
2654}
2655
2656fn coerce_temporal_comparisons(
2658 left: DfExpr,
2659 right: DfExpr,
2660 left_type: &datafusion::arrow::datatypes::DataType,
2661 right_type: &datafusion::arrow::datatypes::DataType,
2662 op: datafusion::logical_expr::Operator,
2663 is_comparison: bool,
2664) -> Option<DfExpr> {
2665 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2666 use datafusion::logical_expr::Operator;
2667
2668 if !is_comparison {
2669 return None;
2670 }
2671
2672 if uni_common::core::schema::is_datetime_struct(left_type)
2674 && uni_common::core::schema::is_datetime_struct(right_type)
2675 {
2676 return Some(binary_expr(
2677 extract_datetime_nanos(left),
2678 op,
2679 extract_datetime_nanos(right),
2680 ));
2681 }
2682
2683 if uni_common::core::schema::is_time_struct(left_type)
2685 && uni_common::core::schema::is_time_struct(right_type)
2686 {
2687 return Some(binary_expr(
2688 extract_time_nanos(left),
2689 op,
2690 extract_time_nanos(right),
2691 ));
2692 }
2693
2694 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2696 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2697
2698 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2699 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2700 {
2701 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2702 extract_datetime_nanos(left)
2703 } else {
2704 left
2705 };
2706 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2707 extract_datetime_nanos(right)
2708 } else {
2709 right
2710 };
2711 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2712 return Some(binary_expr(
2713 cast_expr(left_nanos, ts_type.clone()),
2714 op,
2715 cast_expr(right_nanos, ts_type),
2716 ));
2717 }
2718
2719 let left_is_duration = matches!(left_type, DataType::Interval(_));
2723 let right_is_duration = matches!(right_type, DataType::Interval(_));
2724 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2725 || uni_common::core::schema::is_time_struct(left_type)
2726 || matches!(
2727 left_type,
2728 DataType::Timestamp(_, _)
2729 | DataType::Date32
2730 | DataType::Date64
2731 | DataType::Time32(_)
2732 | DataType::Time64(_)
2733 );
2734 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2735 || uni_common::core::schema::is_time_struct(right_type)
2736 || matches!(
2737 right_type,
2738 DataType::Timestamp(_, _)
2739 | DataType::Date32
2740 | DataType::Date64
2741 | DataType::Time32(_)
2742 | DataType::Time64(_)
2743 );
2744
2745 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2746 {
2747 return Some(match op {
2748 Operator::Eq => lit(false),
2749 Operator::NotEq => lit(true),
2750 _ => lit(ScalarValue::Boolean(None)),
2751 });
2752 }
2753
2754 None
2755}
2756
2757fn coerce_mismatched_types(
2760 left: DfExpr,
2761 right: DfExpr,
2762 left_type: &datafusion::arrow::datatypes::DataType,
2763 right_type: &datafusion::arrow::datatypes::DataType,
2764 op: datafusion::logical_expr::Operator,
2765 is_comparison: bool,
2766) -> Option<Result<DfExpr>> {
2767 use datafusion::arrow::datatypes::DataType;
2768 use datafusion::logical_expr::Operator;
2769
2770 if left_type == right_type {
2771 return None;
2772 }
2773
2774 if left_type.is_numeric() && right_type.is_numeric() {
2776 if left_type == &DataType::Int64
2777 && right_type == &DataType::UInt64
2778 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2779 {
2780 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2781 return Some(Ok(binary_expr(coerced_left, op, right)));
2782 }
2783 if left_type == &DataType::UInt64
2784 && right_type == &DataType::Int64
2785 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2786 {
2787 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2788 return Some(Ok(binary_expr(left, op, coerced_right)));
2789 }
2790 let target = wider_numeric_type(left_type, right_type);
2791 let coerced_left = if *left_type != target {
2792 datafusion::logical_expr::cast(left, target.clone())
2793 } else {
2794 left
2795 };
2796 let coerced_right = if *right_type != target {
2797 datafusion::logical_expr::cast(right, target)
2798 } else {
2799 right
2800 };
2801 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2802 }
2803
2804 if is_comparison {
2806 match (left_type, right_type) {
2807 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2808 let right = normalize_datetime_literal(right);
2809 return Some(Ok(binary_expr(
2810 left,
2811 op,
2812 datafusion::logical_expr::cast(right, ts.clone()),
2813 )));
2814 }
2815 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2816 let left = normalize_datetime_literal(left);
2817 return Some(Ok(binary_expr(
2818 datafusion::logical_expr::cast(left, ts.clone()),
2819 op,
2820 right,
2821 )));
2822 }
2823 _ => {}
2824 }
2825 }
2826
2827 if is_comparison
2829 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2830 {
2831 let l_inner = l_field.data_type();
2832 let r_inner = r_field.data_type();
2833 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2834 let target_inner = wider_numeric_type(l_inner, r_inner);
2835 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2836 "item",
2837 target_inner,
2838 true,
2839 )));
2840 return Some(Ok(binary_expr(
2841 datafusion::logical_expr::cast(left, target_type.clone()),
2842 op,
2843 datafusion::logical_expr::cast(right, target_type),
2844 )));
2845 }
2846 }
2847
2848 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2850 if op == Operator::Plus {
2851 return Some(crate::query::cypher_type_coerce::build_cypher_plus(
2852 left, left_type, right, right_type,
2853 ));
2854 }
2855 if is_comparison {
2856 return Some(Ok(
2857 crate::query::cypher_type_coerce::build_cypher_comparison(
2858 left, left_type, right, right_type, op,
2859 ),
2860 ));
2861 }
2862 }
2863
2864 None
2865}
2866
2867fn coerce_list_comparisons(
2869 left: DfExpr,
2870 right: DfExpr,
2871 left_type: &datafusion::arrow::datatypes::DataType,
2872 right_type: &datafusion::arrow::datatypes::DataType,
2873 op: datafusion::logical_expr::Operator,
2874 is_comparison: bool,
2875) -> Option<DfExpr> {
2876 use datafusion::arrow::datatypes::DataType;
2877 use datafusion::logical_expr::Operator;
2878
2879 if !is_comparison {
2880 return None;
2881 }
2882
2883 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2884 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2885
2886 if left_is_list
2888 && right_is_list
2889 && matches!(
2890 op,
2891 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
2892 )
2893 {
2894 let op_str = match op {
2895 Operator::Lt => "lt",
2896 Operator::LtEq => "lteq",
2897 Operator::Gt => "gt",
2898 Operator::GtEq => "gteq",
2899 _ => unreachable!(),
2900 };
2901 return Some(dummy_udf_expr(
2902 "_cypher_list_compare",
2903 vec![left, right, lit(op_str)],
2904 ));
2905 }
2906
2907 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
2909 let udf_name =
2910 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
2911 return Some(dummy_udf_expr(udf_name, vec![left, right]));
2912 }
2913
2914 if (left_is_list != right_is_list)
2916 && !matches!(left_type, DataType::Null)
2917 && !matches!(right_type, DataType::Null)
2918 {
2919 return Some(match op {
2920 Operator::Eq => lit(false),
2921 Operator::NotEq => lit(true),
2922 _ => lit(ScalarValue::Boolean(None)),
2923 });
2924 }
2925
2926 None
2927}
2928
2929fn coerce_binary_expr(
2931 binary: &datafusion::logical_expr::expr::BinaryExpr,
2932 schema: &datafusion::common::DFSchema,
2933) -> Result<DfExpr> {
2934 use datafusion::arrow::datatypes::DataType;
2935 use datafusion::logical_expr::ExprSchemable;
2936 use datafusion::logical_expr::Operator;
2937
2938 let left = apply_type_coercion(&binary.left, schema)?;
2939 let right = apply_type_coercion(&binary.right, schema)?;
2940
2941 let is_comparison = matches!(
2942 binary.op,
2943 Operator::Eq
2944 | Operator::NotEq
2945 | Operator::Lt
2946 | Operator::LtEq
2947 | Operator::Gt
2948 | Operator::GtEq
2949 );
2950 let is_arithmetic = matches!(
2951 binary.op,
2952 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
2953 );
2954
2955 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
2957 return Ok(result);
2958 }
2959
2960 if is_comparison || is_arithmetic {
2961 let left_type = match left.get_type(schema) {
2962 Ok(t) => t,
2963 Err(e) => {
2964 if let Some(t) = resolve_column_type_fallback(&left, schema) {
2965 t
2966 } else {
2967 log::warn!("Failed to get left type in binary expr: {}", e);
2968 return Ok(binary_expr(left, binary.op, right));
2969 }
2970 }
2971 };
2972 let right_type = match right.get_type(schema) {
2973 Ok(t) => t,
2974 Err(e) => {
2975 if let Some(t) = resolve_column_type_fallback(&right, schema) {
2976 t
2977 } else {
2978 log::warn!("Failed to get right type in binary expr: {}", e);
2979 return Ok(binary_expr(left, binary.op, right));
2980 }
2981 }
2982 };
2983
2984 let left_is_null = left_type.is_null();
2986 let right_is_null = right_type.is_null();
2987 if left_is_null && right_is_null {
2988 return Ok(lit(ScalarValue::Boolean(None)));
2989 }
2990 if left_is_null || right_is_null {
2991 let target = if left_is_null {
2992 &right_type
2993 } else {
2994 &left_type
2995 };
2996 if !matches!(target, DataType::LargeBinary) {
2997 let coerced_left = if left_is_null {
2998 datafusion::logical_expr::cast(left, target.clone())
2999 } else {
3000 left
3001 };
3002 let coerced_right = if right_is_null {
3003 datafusion::logical_expr::cast(right, target.clone())
3004 } else {
3005 right
3006 };
3007 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3008 }
3009 }
3010
3011 if let Some(result) = coerce_large_binary_ops(
3013 &left,
3014 &right,
3015 &left_type,
3016 &right_type,
3017 left_is_null,
3018 binary.op,
3019 is_comparison,
3020 is_arithmetic,
3021 ) {
3022 return result;
3023 }
3024
3025 if let Some(result) = coerce_temporal_comparisons(
3027 left.clone(),
3028 right.clone(),
3029 &left_type,
3030 &right_type,
3031 binary.op,
3032 is_comparison,
3033 ) {
3034 return Ok(result);
3035 }
3036
3037 let either_struct =
3039 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3040 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3041 || matches!(left_type, DataType::Struct(_)))
3042 && (matches!(right_type, DataType::LargeBinary)
3043 || matches!(right_type, DataType::Struct(_)));
3044 if is_comparison && either_struct && either_lb_or_struct {
3045 if let Some(udf_name) = comparison_udf_name(binary.op) {
3046 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3047 }
3048 return Ok(lit(ScalarValue::Boolean(None)));
3049 }
3050
3051 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3053 let udf_name = comparison_udf_name(binary.op)
3054 .expect("is_comparison guarantees a valid comparison operator");
3055 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3056 }
3057
3058 if binary.op == Operator::Plus
3060 && (crate::query::cypher_type_coerce::is_string_type(&left_type)
3061 || crate::query::cypher_type_coerce::is_string_type(&right_type))
3062 && is_primitive_type(&left_type)
3063 && is_primitive_type(&right_type)
3064 {
3065 return crate::query::cypher_type_coerce::build_cypher_plus(
3066 left,
3067 &left_type,
3068 right,
3069 &right_type,
3070 );
3071 }
3072
3073 if let Some(result) = coerce_mismatched_types(
3075 left.clone(),
3076 right.clone(),
3077 &left_type,
3078 &right_type,
3079 binary.op,
3080 is_comparison,
3081 ) {
3082 return result;
3083 }
3084
3085 if let Some(result) = coerce_list_comparisons(
3087 left.clone(),
3088 right.clone(),
3089 &left_type,
3090 &right_type,
3091 binary.op,
3092 is_comparison,
3093 ) {
3094 return Ok(result);
3095 }
3096 }
3097
3098 Ok(binary_expr(left, binary.op, right))
3099}
3100
3101fn coerce_scalar_function(
3103 func: &datafusion::logical_expr::expr::ScalarFunction,
3104 schema: &datafusion::common::DFSchema,
3105) -> Result<DfExpr> {
3106 use datafusion::arrow::datatypes::DataType;
3107 use datafusion::logical_expr::ExprSchemable;
3108
3109 let coerced_args: Vec<DfExpr> = func
3110 .args
3111 .iter()
3112 .map(|a| apply_type_coercion(a, schema))
3113 .collect::<Result<Vec<_>>>()?;
3114
3115 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3116 let types: Vec<_> = coerced_args
3117 .iter()
3118 .filter_map(|a| a.get_type(schema).ok())
3119 .collect();
3120 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3121 if has_mixed_types {
3122 let has_large_binary = types.iter().any(|t| matches!(t, DataType::LargeBinary));
3123
3124 if has_large_binary {
3125 let unified_args: Vec<DfExpr> = coerced_args
3126 .into_iter()
3127 .zip(types.iter())
3128 .map(|(arg, t)| match t {
3129 DataType::LargeBinary | DataType::Null => arg,
3130 DataType::List(_) | DataType::LargeList(_) => {
3131 list_to_large_binary_expr(arg)
3132 }
3133 _ => scalar_to_large_binary_expr(arg),
3134 })
3135 .collect();
3136 return Ok(DfExpr::ScalarFunction(
3137 datafusion::logical_expr::expr::ScalarFunction {
3138 func: func.func.clone(),
3139 args: unified_args,
3140 },
3141 ));
3142 }
3143
3144 let all_list_or_lb = types.iter().all(|t| {
3145 matches!(
3146 t,
3147 DataType::Null
3148 | DataType::LargeBinary
3149 | DataType::List(_)
3150 | DataType::LargeList(_)
3151 )
3152 });
3153 if all_list_or_lb {
3154 let unified_args: Vec<DfExpr> = coerced_args
3155 .into_iter()
3156 .zip(types.iter())
3157 .map(|(arg, t)| {
3158 if matches!(t, DataType::List(_) | DataType::LargeList(_)) {
3159 list_to_large_binary_expr(arg)
3160 } else {
3161 arg
3162 }
3163 })
3164 .collect();
3165 return Ok(DfExpr::ScalarFunction(
3166 datafusion::logical_expr::expr::ScalarFunction {
3167 func: func.func.clone(),
3168 args: unified_args,
3169 },
3170 ));
3171 } else {
3172 let unified_args = coerced_args
3173 .into_iter()
3174 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3175 .collect();
3176 return Ok(DfExpr::ScalarFunction(
3177 datafusion::logical_expr::expr::ScalarFunction {
3178 func: func.func.clone(),
3179 args: unified_args,
3180 },
3181 ));
3182 }
3183 }
3184 }
3185
3186 Ok(DfExpr::ScalarFunction(
3187 datafusion::logical_expr::expr::ScalarFunction {
3188 func: func.func.clone(),
3189 args: coerced_args,
3190 },
3191 ))
3192}
3193
3194fn coerce_case_expr(
3197 case: &datafusion::logical_expr::expr::Case,
3198 schema: &datafusion::common::DFSchema,
3199) -> Result<DfExpr> {
3200 use datafusion::arrow::datatypes::DataType;
3201 use datafusion::logical_expr::ExprSchemable;
3202
3203 let coerced_operand = case
3204 .expr
3205 .as_ref()
3206 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3207 .transpose()?;
3208 let coerced_when_then = case
3209 .when_then_expr
3210 .iter()
3211 .map(|(w, t)| {
3212 let cw = apply_type_coercion(w, schema)?;
3213 let cw = match cw.get_type(schema).ok() {
3214 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3215 _ => cw,
3216 };
3217 let ct = apply_type_coercion(t, schema)?;
3218 Ok((Box::new(cw), Box::new(ct)))
3219 })
3220 .collect::<Result<Vec<_>>>()?;
3221 let coerced_else = case
3222 .else_expr
3223 .as_ref()
3224 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3225 .transpose()?;
3226
3227 let mut result_case = if let Some(operand) = coerced_operand {
3228 crate::query::cypher_type_coerce::rewrite_simple_case_to_generic(
3229 *operand,
3230 coerced_when_then,
3231 coerced_else,
3232 schema,
3233 )?
3234 } else {
3235 datafusion::logical_expr::expr::Case {
3236 expr: None,
3237 when_then_expr: coerced_when_then,
3238 else_expr: coerced_else,
3239 }
3240 };
3241
3242 crate::query::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3243
3244 Ok(DfExpr::Case(result_case))
3245}
3246
3247fn coerce_aggregate_function(
3249 agg: &datafusion::logical_expr::expr::AggregateFunction,
3250 schema: &datafusion::common::DFSchema,
3251) -> Result<DfExpr> {
3252 let coerced_args: Vec<DfExpr> = agg
3253 .params
3254 .args
3255 .iter()
3256 .map(|a| apply_type_coercion(a, schema))
3257 .collect::<Result<Vec<_>>>()?;
3258 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3259 .params
3260 .order_by
3261 .iter()
3262 .map(|s| {
3263 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3264 Ok(datafusion::logical_expr::SortExpr {
3265 expr: coerced_expr,
3266 asc: s.asc,
3267 nulls_first: s.nulls_first,
3268 })
3269 })
3270 .collect::<Result<Vec<_>>>()?;
3271 let coerced_filter = agg
3272 .params
3273 .filter
3274 .as_ref()
3275 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3276 .transpose()?;
3277 Ok(DfExpr::AggregateFunction(
3278 datafusion::logical_expr::expr::AggregateFunction {
3279 func: agg.func.clone(),
3280 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3281 args: coerced_args,
3282 distinct: agg.params.distinct,
3283 filter: coerced_filter,
3284 order_by: coerced_order_by,
3285 null_treatment: agg.params.null_treatment,
3286 },
3287 },
3288 ))
3289}
3290
3291#[cfg(test)]
3292mod tests {
3293 use super::*;
3294 use arrow_array::{
3295 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3296 };
3297 use uni_common::TemporalValue;
3298 #[test]
3299 fn test_literal_translation() {
3300 let expr = Expr::Literal(CypherLiteral::Integer(42));
3301 let result = cypher_expr_to_df(&expr, None).unwrap();
3302 let s = format!("{:?}", result);
3303 assert!(s.contains("Literal"));
3305 assert!(s.contains("Int64(42)"));
3306 }
3307
3308 #[test]
3309 fn test_property_access_no_context_uses_index() {
3310 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3312 let result = cypher_expr_to_df(&expr, None).unwrap();
3313 let s = format!("{}", result);
3314 assert!(
3315 s.contains("index"),
3316 "expected index UDF for non-graph variable, got: {s}"
3317 );
3318 }
3319
3320 #[test]
3321 fn test_comparison_operator() {
3322 let expr = Expr::BinaryOp {
3323 left: Box::new(Expr::Property(
3324 Box::new(Expr::Variable("n".to_string())),
3325 "age".to_string(),
3326 )),
3327 op: BinaryOp::Gt,
3328 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3329 };
3330 let result = cypher_expr_to_df(&expr, None).unwrap();
3331 let s = format!("{:?}", result);
3333 assert!(s.contains("age"));
3334 assert!(s.contains("30"));
3335 }
3336
3337 #[test]
3338 fn test_boolean_operators() {
3339 let expr = Expr::BinaryOp {
3340 left: Box::new(Expr::BinaryOp {
3341 left: Box::new(Expr::Property(
3342 Box::new(Expr::Variable("n".to_string())),
3343 "age".to_string(),
3344 )),
3345 op: BinaryOp::Gt,
3346 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3347 }),
3348 op: BinaryOp::And,
3349 right: Box::new(Expr::BinaryOp {
3350 left: Box::new(Expr::Property(
3351 Box::new(Expr::Variable("n".to_string())),
3352 "active".to_string(),
3353 )),
3354 op: BinaryOp::Eq,
3355 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3356 }),
3357 };
3358 let result = cypher_expr_to_df(&expr, None).unwrap();
3359 let s = format!("{:?}", result);
3360 assert!(s.contains("And"));
3361 }
3362
3363 #[test]
3364 fn test_is_null() {
3365 let expr = Expr::IsNull(Box::new(Expr::Property(
3366 Box::new(Expr::Variable("n".to_string())),
3367 "email".to_string(),
3368 )));
3369 let result = cypher_expr_to_df(&expr, None).unwrap();
3370 let s = format!("{:?}", result);
3371 assert!(s.contains("IsNull"));
3372 }
3373
3374 #[test]
3375 fn test_collect_properties() {
3376 let expr = Expr::BinaryOp {
3377 left: Box::new(Expr::Property(
3378 Box::new(Expr::Variable("n".to_string())),
3379 "name".to_string(),
3380 )),
3381 op: BinaryOp::Eq,
3382 right: Box::new(Expr::Property(
3383 Box::new(Expr::Variable("m".to_string())),
3384 "name".to_string(),
3385 )),
3386 };
3387
3388 let props = collect_properties(&expr);
3389 assert_eq!(props.len(), 2);
3390 assert!(props.contains(&("m".to_string(), "name".to_string())));
3391 assert!(props.contains(&("n".to_string(), "name".to_string())));
3392 }
3393
3394 #[test]
3395 fn test_function_call() {
3396 let expr = Expr::FunctionCall {
3397 name: "count".to_string(),
3398 args: vec![Expr::Wildcard],
3399 distinct: false,
3400 window_spec: None,
3401 };
3402 let result = cypher_expr_to_df(&expr, None).unwrap();
3403 let s = format!("{:?}", result);
3404 assert!(s.to_lowercase().contains("count"));
3405 }
3406
3407 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3412 use datafusion::logical_expr::Operator;
3413
3414 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3416 let fields: Vec<_> = cols
3417 .iter()
3418 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3419 .collect();
3420 let schema = Schema::new(fields);
3421 datafusion::common::DFSchema::try_from(schema).unwrap()
3422 }
3423
3424 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3426 let s = format!("{}", expr);
3427 s.contains(name)
3428 }
3429
3430 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3432 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3433 }
3434
3435 #[test]
3436 fn test_coercion_lb_eq_int64() {
3437 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3438 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3439 Box::new(col("lb")),
3440 Operator::Eq,
3441 Box::new(col("i")),
3442 ));
3443 let result = apply_type_coercion(&expr, &schema).unwrap();
3444 assert!(
3446 contains_udf(&result, "_cypher_equal"),
3447 "expected _cypher_equal, got: {result}"
3448 );
3449 }
3450
3451 #[test]
3452 fn test_coercion_lb_noteq_int64() {
3453 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3454 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3455 Box::new(col("lb")),
3456 Operator::NotEq,
3457 Box::new(col("i")),
3458 ));
3459 let result = apply_type_coercion(&expr, &schema).unwrap();
3460 assert!(contains_udf(&result, "_cypher_not_equal"));
3462 }
3463
3464 #[test]
3465 fn test_coercion_lb_lt_int64() {
3466 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3467 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3468 Box::new(col("lb")),
3469 Operator::Lt,
3470 Box::new(col("i")),
3471 ));
3472 let result = apply_type_coercion(&expr, &schema).unwrap();
3473 assert!(contains_udf(&result, "_cypher_lt"));
3475 }
3476
3477 #[test]
3478 fn test_coercion_lb_eq_float64() {
3479 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3480 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3481 Box::new(col("lb")),
3482 Operator::Eq,
3483 Box::new(col("f")),
3484 ));
3485 let result = apply_type_coercion(&expr, &schema).unwrap();
3486 assert!(contains_udf(&result, "_cypher_equal"));
3488 }
3489
3490 #[test]
3491 fn test_coercion_lb_eq_utf8() {
3492 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3493 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3494 Box::new(col("lb")),
3495 Operator::Eq,
3496 Box::new(col("s")),
3497 ));
3498 let result = apply_type_coercion(&expr, &schema).unwrap();
3499 assert!(contains_udf(&result, "_cypher_equal"));
3501 }
3502
3503 #[test]
3504 fn test_coercion_lb_eq_bool() {
3505 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3506 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3507 Box::new(col("lb")),
3508 Operator::Eq,
3509 Box::new(col("b")),
3510 ));
3511 let result = apply_type_coercion(&expr, &schema).unwrap();
3512 assert!(contains_udf(&result, "_cypher_equal"));
3514 }
3515
3516 #[test]
3517 fn test_coercion_int64_eq_lb() {
3518 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3520 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3521 Box::new(col("i")),
3522 Operator::Eq,
3523 Box::new(col("lb")),
3524 ));
3525 let result = apply_type_coercion(&expr, &schema).unwrap();
3526 assert!(contains_udf(&result, "_cypher_equal"));
3528 }
3529
3530 #[test]
3531 fn test_coercion_float64_gt_lb() {
3532 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3533 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3534 Box::new(col("f")),
3535 Operator::Gt,
3536 Box::new(col("lb")),
3537 ));
3538 let result = apply_type_coercion(&expr, &schema).unwrap();
3539 assert!(contains_udf(&result, "_cypher_gt"));
3541 }
3542
3543 #[test]
3544 fn test_coercion_both_lb_eq() {
3545 let schema = make_schema(&[
3546 ("lb1", DataType::LargeBinary),
3547 ("lb2", DataType::LargeBinary),
3548 ]);
3549 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3550 Box::new(col("lb1")),
3551 Operator::Eq,
3552 Box::new(col("lb2")),
3553 ));
3554 let result = apply_type_coercion(&expr, &schema).unwrap();
3555 assert!(contains_udf(&result, "_cypher_equal"));
3556 }
3557
3558 #[test]
3559 fn test_coercion_both_lb_lt() {
3560 let schema = make_schema(&[
3561 ("lb1", DataType::LargeBinary),
3562 ("lb2", DataType::LargeBinary),
3563 ]);
3564 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3565 Box::new(col("lb1")),
3566 Operator::Lt,
3567 Box::new(col("lb2")),
3568 ));
3569 let result = apply_type_coercion(&expr, &schema).unwrap();
3570 assert!(contains_udf(&result, "_cypher_lt"));
3571 }
3572
3573 #[test]
3574 fn test_coercion_both_lb_noteq() {
3575 let schema = make_schema(&[
3576 ("lb1", DataType::LargeBinary),
3577 ("lb2", DataType::LargeBinary),
3578 ]);
3579 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3580 Box::new(col("lb1")),
3581 Operator::NotEq,
3582 Box::new(col("lb2")),
3583 ));
3584 let result = apply_type_coercion(&expr, &schema).unwrap();
3585 assert!(contains_udf(&result, "_cypher_not_equal"));
3586 }
3587
3588 #[test]
3589 fn test_coercion_lb_plus_int64() {
3590 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3591 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3592 Box::new(col("lb")),
3593 Operator::Plus,
3594 Box::new(col("i")),
3595 ));
3596 let result = apply_type_coercion(&expr, &schema).unwrap();
3597 assert!(contains_udf(&result, "_cypher_add"));
3598 }
3599
3600 #[test]
3601 fn test_coercion_lb_minus_int64() {
3602 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3603 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3604 Box::new(col("lb")),
3605 Operator::Minus,
3606 Box::new(col("i")),
3607 ));
3608 let result = apply_type_coercion(&expr, &schema).unwrap();
3609 assert!(contains_udf(&result, "_cypher_sub"));
3610 }
3611
3612 #[test]
3613 fn test_coercion_lb_multiply_float64() {
3614 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3615 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3616 Box::new(col("lb")),
3617 Operator::Multiply,
3618 Box::new(col("f")),
3619 ));
3620 let result = apply_type_coercion(&expr, &schema).unwrap();
3621 assert!(contains_udf(&result, "_cypher_mul"));
3622 }
3623
3624 #[test]
3625 fn test_coercion_int64_plus_lb() {
3626 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3627 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3628 Box::new(col("i")),
3629 Operator::Plus,
3630 Box::new(col("lb")),
3631 ));
3632 let result = apply_type_coercion(&expr, &schema).unwrap();
3633 assert!(contains_udf(&result, "_cypher_add"));
3634 }
3635
3636 #[test]
3637 fn test_coercion_lb_plus_utf8() {
3638 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3640 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3641 Box::new(col("lb")),
3642 Operator::Plus,
3643 Box::new(col("s")),
3644 ));
3645 let result = apply_type_coercion(&expr, &schema).unwrap();
3646 assert!(contains_udf(&result, "_cypher_add"));
3648 }
3649
3650 #[test]
3651 fn test_coercion_and_null_bool() {
3652 let schema = make_schema(&[("b", DataType::Boolean)]);
3653 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3655 Box::new(lit(ScalarValue::Null)),
3656 Operator::And,
3657 Box::new(col("b")),
3658 ));
3659 let result = apply_type_coercion(&expr, &schema).unwrap();
3660 let s = format!("{}", result);
3661 assert!(
3663 s.contains("CAST") || s.contains("Boolean"),
3664 "expected cast to Boolean, got: {s}"
3665 );
3666 assert!(is_binary_op(&result, Operator::And));
3667 }
3668
3669 #[test]
3670 fn test_coercion_bool_and_null() {
3671 let schema = make_schema(&[("b", DataType::Boolean)]);
3672 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3673 Box::new(col("b")),
3674 Operator::And,
3675 Box::new(lit(ScalarValue::Null)),
3676 ));
3677 let result = apply_type_coercion(&expr, &schema).unwrap();
3678 assert!(is_binary_op(&result, Operator::And));
3679 }
3680
3681 #[test]
3682 fn test_coercion_or_null_bool() {
3683 let schema = make_schema(&[("b", DataType::Boolean)]);
3684 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3685 Box::new(lit(ScalarValue::Null)),
3686 Operator::Or,
3687 Box::new(col("b")),
3688 ));
3689 let result = apply_type_coercion(&expr, &schema).unwrap();
3690 assert!(is_binary_op(&result, Operator::Or));
3691 }
3692
3693 #[test]
3694 fn test_coercion_null_and_null() {
3695 let schema = make_schema(&[]);
3696 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3697 Box::new(lit(ScalarValue::Null)),
3698 Operator::And,
3699 Box::new(lit(ScalarValue::Null)),
3700 ));
3701 let result = apply_type_coercion(&expr, &schema).unwrap();
3702 assert!(is_binary_op(&result, Operator::And));
3703 }
3704
3705 #[test]
3706 fn test_coercion_bool_and_bool_noop() {
3707 let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);
3708 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3709 Box::new(col("a")),
3710 Operator::And,
3711 Box::new(col("b")),
3712 ));
3713 let result = apply_type_coercion(&expr, &schema).unwrap();
3714 assert!(is_binary_op(&result, Operator::And));
3716 let s = format!("{}", result);
3717 assert!(!s.contains("CAST"), "should not contain CAST: {s}");
3718 }
3719
3720 #[test]
3721 fn test_coercion_case_when_lb() {
3722 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3724 let when_cond = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3725 Box::new(col("lb")),
3726 Operator::Eq,
3727 Box::new(lit(42_i64)),
3728 ));
3729 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3730 expr: None,
3731 when_then_expr: vec![(Box::new(when_cond), Box::new(lit("a")))],
3732 else_expr: Some(Box::new(lit("b"))),
3733 });
3734 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3735 let s = format!("{}", result);
3736 assert!(
3738 s.contains("_cypher_equal"),
3739 "CASE WHEN should have _cypher_equal, got: {s}"
3740 );
3741 }
3742
3743 #[test]
3744 fn test_coercion_case_then_lb() {
3745 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3747 let then_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3748 Box::new(col("lb")),
3749 Operator::Plus,
3750 Box::new(lit(1_i64)),
3751 ));
3752 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3753 expr: None,
3754 when_then_expr: vec![(Box::new(lit(true)), Box::new(then_expr))],
3755 else_expr: Some(Box::new(lit(0_i64))),
3756 });
3757 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3758 let s = format!("{}", result);
3759 assert!(
3760 s.contains("_cypher_add"),
3761 "CASE THEN should have _cypher_add, got: {s}"
3762 );
3763 }
3764
3765 #[test]
3766 fn test_coercion_case_else_lb() {
3767 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3769 let else_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3770 Box::new(col("lb")),
3771 Operator::Plus,
3772 Box::new(lit(2_i64)),
3773 ));
3774 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3775 expr: None,
3776 when_then_expr: vec![(Box::new(lit(true)), Box::new(lit(1_i64)))],
3777 else_expr: Some(Box::new(else_expr)),
3778 });
3779 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3780 let s = format!("{}", result);
3781 assert!(
3782 s.contains("_cypher_add"),
3783 "CASE ELSE should have _cypher_add, got: {s}"
3784 );
3785 }
3786
3787 #[test]
3788 fn test_coercion_int64_eq_int64_noop() {
3789 let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);
3790 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3791 Box::new(col("a")),
3792 Operator::Eq,
3793 Box::new(col("b")),
3794 ));
3795 let result = apply_type_coercion(&expr, &schema).unwrap();
3796 assert!(is_binary_op(&result, Operator::Eq));
3797 let s = format!("{}", result);
3798 assert!(
3799 !s.contains("_cypher_value"),
3800 "should not contain cypher_value decode: {s}"
3801 );
3802 }
3803
3804 #[test]
3805 fn test_coercion_both_lb_plus() {
3806 let schema = make_schema(&[
3808 ("lb1", DataType::LargeBinary),
3809 ("lb2", DataType::LargeBinary),
3810 ]);
3811 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3812 Box::new(col("lb1")),
3813 Operator::Plus,
3814 Box::new(col("lb2")),
3815 ));
3816 let result = apply_type_coercion(&expr, &schema).unwrap();
3817 assert!(
3818 contains_udf(&result, "_cypher_add"),
3819 "expected _cypher_add, got: {result}"
3820 );
3821 }
3822
3823 #[test]
3824 fn test_coercion_native_list_plus_scalar() {
3825 let schema = make_schema(&[
3827 (
3828 "lst",
3829 DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3830 ),
3831 ("i", DataType::Int32),
3832 ]);
3833 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3834 Box::new(col("lst")),
3835 Operator::Plus,
3836 Box::new(col("i")),
3837 ));
3838 let result = apply_type_coercion(&expr, &schema).unwrap();
3839 assert!(
3840 contains_udf(&result, "_cypher_list_append"),
3841 "expected _cypher_list_append, got: {result}"
3842 );
3843 }
3844
3845 #[test]
3846 fn test_coercion_lb_plus_int64_unchanged() {
3847 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3849 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3850 Box::new(col("lb")),
3851 Operator::Plus,
3852 Box::new(col("i")),
3853 ));
3854 let result = apply_type_coercion(&expr, &schema).unwrap();
3855 assert!(
3856 contains_udf(&result, "_cypher_add"),
3857 "expected _cypher_add, got: {result}"
3858 );
3859 }
3860
3861 #[test]
3866 fn test_mixed_list_with_variables_compiles() {
3867 let expr = Expr::List(vec![
3869 Expr::Variable("n".to_string()),
3870 Expr::Literal(CypherLiteral::Integer(1)),
3871 Expr::Literal(CypherLiteral::String("hello".to_string())),
3872 ]);
3873 let result = cypher_expr_to_df(&expr, None).unwrap();
3874 let s = format!("{}", result);
3875 assert!(
3876 s.contains("_make_cypher_list"),
3877 "expected _make_cypher_list UDF call, got: {s}"
3878 );
3879 }
3880
3881 #[test]
3882 fn test_literal_only_mixed_list_uses_cv_fastpath() {
3883 let expr = Expr::List(vec![
3885 Expr::Literal(CypherLiteral::Integer(1)),
3886 Expr::Literal(CypherLiteral::String("hi".to_string())),
3887 Expr::Literal(CypherLiteral::Bool(true)),
3888 ]);
3889 let result = cypher_expr_to_df(&expr, None).unwrap();
3890 assert!(
3891 matches!(result, DfExpr::Literal(..)),
3892 "expected Literal (CypherValue fast path), got: {result}"
3893 );
3894 }
3895
3896 #[test]
3901 fn test_in_mixed_literal_list_uses_cypher_in() {
3902 let expr = Expr::In {
3904 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
3905 list: Box::new(Expr::List(vec![
3906 Expr::Literal(CypherLiteral::String("1".to_string())),
3907 Expr::Literal(CypherLiteral::Integer(2)),
3908 ])),
3909 };
3910 let result = cypher_expr_to_df(&expr, None).unwrap();
3911 let s = format!("{}", result);
3912 assert!(
3913 s.contains("_cypher_in"),
3914 "expected _cypher_in UDF for mixed-type IN list, got: {s}"
3915 );
3916 }
3917
3918 #[test]
3919 fn test_in_homogeneous_literal_list_uses_cypher_in() {
3920 let expr = Expr::In {
3922 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
3923 list: Box::new(Expr::List(vec![
3924 Expr::Literal(CypherLiteral::Integer(2)),
3925 Expr::Literal(CypherLiteral::Integer(3)),
3926 ])),
3927 };
3928 let result = cypher_expr_to_df(&expr, None).unwrap();
3929 let s = format!("{}", result);
3930 assert!(
3931 s.contains("_cypher_in"),
3932 "expected _cypher_in UDF for homogeneous IN list, got: {s}"
3933 );
3934 }
3935
3936 #[test]
3937 fn test_in_list_with_variables_uses_make_cypher_list() {
3938 let expr = Expr::In {
3940 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
3941 list: Box::new(Expr::List(vec![
3942 Expr::Variable("x".to_string()),
3943 Expr::Literal(CypherLiteral::Integer(2)),
3944 ])),
3945 };
3946 let result = cypher_expr_to_df(&expr, None).unwrap();
3947 let s = format!("{}", result);
3948 assert!(
3949 s.contains("_cypher_in"),
3950 "expected _cypher_in UDF, got: {s}"
3951 );
3952 assert!(
3953 s.contains("_make_cypher_list"),
3954 "expected _make_cypher_list for variable-containing list, got: {s}"
3955 );
3956 }
3957
3958 #[test]
3963 fn test_property_on_graph_entity_uses_column() {
3964 let mut ctx = TranslationContext::new();
3966 ctx.variable_kinds
3967 .insert("n".to_string(), VariableKind::Node);
3968
3969 let expr = Expr::Property(
3970 Box::new(Expr::Variable("n".to_string())),
3971 "name".to_string(),
3972 );
3973 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
3974 let s = format!("{:?}", result);
3975 assert!(
3976 s.contains("Column") && s.contains("n.name"),
3977 "expected flat column 'n.name' for graph entity, got: {s}"
3978 );
3979 }
3980
3981 #[test]
3982 fn test_property_on_non_graph_var_uses_index() {
3983 let ctx = TranslationContext::new();
3985
3986 let expr = Expr::Property(
3987 Box::new(Expr::Variable("map".to_string())),
3988 "name".to_string(),
3989 );
3990 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
3991 let s = format!("{}", result);
3992 assert!(
3993 s.contains("index"),
3994 "expected index UDF for non-graph variable, got: {s}"
3995 );
3996 }
3997
3998 #[test]
3999 fn test_value_to_scalar_non_empty_map_becomes_struct() {
4000 let mut map = std::collections::HashMap::new();
4001 map.insert("k".to_string(), Value::Int(1));
4002 let scalar = value_to_scalar(&Value::Map(map)).unwrap();
4003 assert!(
4004 matches!(scalar, ScalarValue::Struct(_)),
4005 "expected Struct scalar for map input"
4006 );
4007 }
4008
4009 #[test]
4010 fn test_value_to_scalar_empty_map_becomes_struct() {
4011 let scalar = value_to_scalar(&Value::Map(Default::default())).unwrap();
4012 assert!(
4013 matches!(scalar, ScalarValue::Struct(_)),
4014 "empty map should produce an empty Struct scalar"
4015 );
4016 }
4017
4018 #[test]
4019 fn test_value_to_scalar_null_is_untyped_null() {
4020 let scalar = value_to_scalar(&Value::Null).unwrap();
4021 assert!(
4022 matches!(scalar, ScalarValue::Null),
4023 "expected untyped Null scalar for Value::Null"
4024 );
4025 }
4026
4027 #[test]
4028 fn test_value_to_scalar_datetime_produces_struct() {
4029 let datetime = Value::Temporal(TemporalValue::DateTime {
4031 nanos_since_epoch: 441763200000000000, offset_seconds: 3600, timezone_name: Some("Europe/Paris".to_string()),
4034 });
4035
4036 let scalar = value_to_scalar(&datetime).unwrap();
4037
4038 if let ScalarValue::Struct(struct_arr) = scalar {
4040 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4041 assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");
4042
4043 let fields = struct_arr.fields();
4045 assert_eq!(fields[0].name(), "nanos_since_epoch");
4046 assert_eq!(fields[1].name(), "offset_seconds");
4047 assert_eq!(fields[2].name(), "timezone_name");
4048
4049 let nanos_col = struct_arr.column(0);
4051 let offset_col = struct_arr.column(1);
4052 let tz_col = struct_arr.column(2);
4053
4054 if let Some(nanos_arr) = nanos_col
4055 .as_any()
4056 .downcast_ref::<TimestampNanosecondArray>()
4057 {
4058 assert_eq!(nanos_arr.value(0), 441763200000000000);
4059 } else {
4060 panic!("Expected TimestampNanosecondArray for nanos field");
4061 }
4062
4063 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4064 assert_eq!(offset_arr.value(0), 3600);
4065 } else {
4066 panic!("Expected Int32Array for offset field");
4067 }
4068
4069 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4070 assert_eq!(tz_arr.value(0), "Europe/Paris");
4071 } else {
4072 panic!("Expected StringArray for timezone_name field");
4073 }
4074 } else {
4075 panic!(
4076 "Expected ScalarValue::Struct for DateTime, got {:?}",
4077 scalar
4078 );
4079 }
4080 }
4081
4082 #[test]
4083 fn test_value_to_scalar_datetime_with_null_timezone() {
4084 let datetime = Value::Temporal(TemporalValue::DateTime {
4086 nanos_since_epoch: 1704067200000000000, offset_seconds: -18000, timezone_name: None,
4089 });
4090
4091 let scalar = value_to_scalar(&datetime).unwrap();
4092
4093 if let ScalarValue::Struct(struct_arr) = scalar {
4094 assert_eq!(struct_arr.num_columns(), 3);
4095
4096 let tz_col = struct_arr.column(2);
4098 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4099 assert!(tz_arr.is_null(0), "expected null timezone_name");
4100 } else {
4101 panic!("Expected StringArray for timezone_name field");
4102 }
4103 } else {
4104 panic!("Expected ScalarValue::Struct for DateTime");
4105 }
4106 }
4107
4108 #[test]
4109 fn test_value_to_scalar_time_produces_struct() {
4110 let time = Value::Temporal(TemporalValue::Time {
4112 nanos_since_midnight: 37845000000000, offset_seconds: 3600, });
4115
4116 let scalar = value_to_scalar(&time).unwrap();
4117
4118 if let ScalarValue::Struct(struct_arr) = scalar {
4120 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4121 assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");
4122
4123 let fields = struct_arr.fields();
4125 assert_eq!(fields[0].name(), "nanos_since_midnight");
4126 assert_eq!(fields[1].name(), "offset_seconds");
4127
4128 let nanos_col = struct_arr.column(0);
4130 let offset_col = struct_arr.column(1);
4131
4132 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4133 assert_eq!(nanos_arr.value(0), 37845000000000);
4134 } else {
4135 panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
4136 }
4137
4138 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4139 assert_eq!(offset_arr.value(0), 3600);
4140 } else {
4141 panic!("Expected Int32Array for offset field");
4142 }
4143 } else {
4144 panic!("Expected ScalarValue::Struct for Time, got {:?}", scalar);
4145 }
4146 }
4147
4148 #[test]
4149 fn test_value_to_scalar_time_boundary_values() {
4150 let midnight = Value::Temporal(TemporalValue::Time {
4152 nanos_since_midnight: 0,
4153 offset_seconds: 0,
4154 });
4155
4156 let scalar = value_to_scalar(&midnight).unwrap();
4157
4158 if let ScalarValue::Struct(struct_arr) = scalar {
4159 let nanos_col = struct_arr.column(0);
4160 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4161 assert_eq!(nanos_arr.value(0), 0);
4162 } else {
4163 panic!("Expected Time64NanosecondArray");
4164 }
4165 } else {
4166 panic!("Expected ScalarValue::Struct for Time");
4167 }
4168 }
4169}