1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Suffix of the column holding a node variable's vertex id (`<var>._vid`).
const COL_VID: &str = "_vid";
/// Suffix of the column holding an edge variable's id (`<var>._eid`).
const COL_EID: &str = "_eid";
/// Suffix of the column holding a node's label list (`<var>._labels`);
/// label predicates test it with `array_has`.
const COL_LABELS: &str = "_labels";
/// Suffix of the column holding an edge's type (`<var>._type`).
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub(crate) fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub(crate) fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub(crate) fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Insert a `:00` seconds component into a datetime string that has only
/// hours and minutes after the `T` separator.
///
/// The string must be at least 16 bytes long, with `T` at byte 10 and
/// `HH:MM` (digits, colon, digits) at bytes 11..16. When a `:` already
/// follows the minutes (seconds are present), or the shape does not match,
/// `None` is returned. Anything after the minutes (offset, zone, …) is kept.
///
/// Example: `"2015-07-21T21:40Z"` → `Some("2015-07-21T21:40:00Z")`.
pub(crate) fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();
    if bytes.len() < 16 || bytes[10] != b'T' {
        return None;
    }

    let is_digit_at = |i: usize| bytes[i].is_ascii_digit();
    let has_hours_minutes =
        is_digit_at(11) && is_digit_at(12) && bytes[13] == b':' && is_digit_at(14) && is_digit_at(15);
    if !has_hours_minutes || bytes.get(16) == Some(&b':') {
        return None;
    }

    // Bytes 11..16 are ASCII, so byte 16 is always a char boundary.
    let (through_minutes, remainder) = s.split_at(16);
    Some(format!("{through_minutes}:00{remainder}"))
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
/// Names of UDF calls whose results `is_cypher_list_expr` treats as Cypher
/// lists (alongside `LargeBinary` literals).
const CYPHER_LIST_FUNCS: &[&str] = &[
    "_make_cypher_list",
    "_cypher_list_concat",
    "_cypher_list_append",
];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// Kind of a graph-bound variable. The kind decides which column represents
/// the variable: null checks use `<var>._vid` for nodes, `<var>._eid` for
/// edges, and the bare `<var>` column for paths and edge lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node variable.
    Node,
    /// An edge variable from a single-hop relationship pattern.
    Edge,
    /// An edge variable from a variable-length relationship pattern.
    EdgeList,
    /// A path variable.
    Path,
}

impl VariableKind {
    /// Kind for an edge variable: `EdgeList` when the relationship pattern
    /// is variable-length, `Edge` otherwise.
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
246
247pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
282 match expr {
283 Expr::PatternComprehension { .. } => Err(anyhow!(
284 "Pattern comprehensions require fallback executor (graph traversal)"
285 )),
286 Expr::Wildcard => Ok(DfExpr::Literal(
290 datafusion::common::ScalarValue::Int32(Some(1)),
291 None,
292 )),
293
294 Expr::Variable(name) => {
295 if let Some(ctx) = context
300 && ctx.variable_kinds.contains_key(name)
301 {
302 return Ok(DfExpr::Column(Column::from_name(name)));
303 }
304
305 if let Some(ctx) = context
311 && let Some(value) = ctx.outer_values.get(name)
312 {
313 return value_to_scalar(value).map(lit);
314 }
315
316 if let Some(ctx) = context
321 && let Some(value) = ctx.parameters.get(name)
322 {
323 match value {
326 Value::List(values) if name.ends_with("._vid") => {
327 let literals = values
329 .iter()
330 .map(|v| value_to_scalar(v).map(lit))
331 .collect::<Result<Vec<_>>>()?;
332 return Ok(DfExpr::InList(InList {
333 expr: Box::new(DfExpr::Column(Column::from_name(name))),
334 list: literals,
335 negated: false,
336 }));
337 }
338 other_value => return value_to_scalar(other_value).map(lit),
339 }
340 }
341
342 Ok(DfExpr::Column(Column::from_name(name)))
345 }
346
347 Expr::Property(base, prop) => translate_property_access(base, prop, context),
348
349 Expr::ArrayIndex { array, index } => {
350 if let Ok(var_name) = extract_variable_name(array)
353 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
354 {
355 let col_name = format!("{}.{}", var_name, prop_name);
356 return Ok(DfExpr::Column(Column::from_name(col_name)));
357 }
358
359 let array_expr = cypher_expr_to_df(array, context)?;
360 let index_expr = cypher_expr_to_df(index, context)?;
361
362 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
364 }
365
366 Expr::ArraySlice { array, start, end } => {
367 let array_expr = cypher_expr_to_df(array, context)?;
371
372 let start_expr = match start {
373 Some(s) => cypher_expr_to_df(s, context)?,
374 None => lit(0i64),
375 };
376
377 let end_expr = match end {
378 Some(e) => cypher_expr_to_df(e, context)?,
379 None => lit(i64::MAX),
380 };
381
382 Ok(dummy_udf_expr(
385 "_cypher_list_slice",
386 vec![array_expr, start_expr, end_expr],
387 ))
388 }
389
390 Expr::Parameter(name) => {
391 if let Some(ctx) = context
393 && let Some(value) = ctx.parameters.get(name)
394 {
395 return value_to_scalar(value).map(lit);
396 }
397 Err(anyhow!("Unresolved parameter: ${}", name))
398 }
399
400 Expr::Literal(value) => {
401 let scalar = cypher_literal_to_scalar(value)?;
402 Ok(lit(scalar))
403 }
404
405 Expr::List(items) => translate_list_literal(items, context),
406
407 Expr::Map(entries) => {
408 if entries.is_empty() {
409 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
411 Default::default(),
412 ));
413 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
414 }
415 let mut args = Vec::with_capacity(entries.len() * 2);
418 for (key, val_expr) in entries {
419 args.push(lit(key.clone()));
420 args.push(cypher_expr_to_df(val_expr, context)?);
421 }
422 Ok(datafusion::functions::expr_fn::named_struct(args))
423 }
424
425 Expr::IsNull(inner) => translate_null_check(inner, context, true),
426
427 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
428
429 Expr::IsUnique(_) => {
430 Err(anyhow!(
432 "IS UNIQUE can only be used in constraint definitions"
433 ))
434 }
435
436 Expr::FunctionCall {
437 name,
438 args,
439 distinct,
440 window_spec,
441 } => {
442 if window_spec.is_some() {
445 let col_name = expr.to_string_repr();
447 Ok(col(&col_name))
448 } else {
449 translate_function_call(name, args, *distinct, context)
450 }
451 }
452
453 Expr::In { expr, list } => translate_in_expression(expr, list, context),
454
455 Expr::BinaryOp { left, op, right } => {
456 let left_expr = cypher_expr_to_df(left, context)?;
457 let right_expr = cypher_expr_to_df(right, context)?;
458 translate_binary_op(left_expr, op, right_expr)
459 }
460
461 Expr::UnaryOp { op, expr: inner } => {
462 let inner_expr = cypher_expr_to_df(inner, context)?;
463 match op {
464 UnaryOp::Not => Ok(inner_expr.not()),
465 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
466 }
467 }
468
469 Expr::Case {
470 expr,
471 when_then,
472 else_expr,
473 } => translate_case_expression(expr, when_then, else_expr, context),
474
475 Expr::Reduce { .. } => Err(anyhow!(
476 "Reduce expressions not yet supported in DataFusion translation"
477 )),
478
479 Expr::Exists { .. } => Err(anyhow!(
480 "EXISTS subqueries are handled by the physical expression compiler, \
481 not the DataFusion logical expression translator"
482 )),
483
484 Expr::CountSubquery(_) => Err(anyhow!(
485 "Count subqueries not yet supported in DataFusion translation"
486 )),
487
488 Expr::CollectSubquery(_) => Err(anyhow!(
489 "COLLECT subqueries not yet supported in DataFusion translation"
490 )),
491
492 Expr::Quantifier { .. } => {
493 Err(anyhow!(
498 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
499 via CypherPhysicalExprCompiler"
500 ))
501 }
502
503 Expr::ListComprehension { .. } => {
504 Err(anyhow!(
516 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
517 ))
518 }
519
520 Expr::ValidAt { .. } => {
521 Err(anyhow!(
524 "VALID_AT expression should have been transformed to function call in planner"
525 ))
526 }
527
528 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
529
530 Expr::LabelCheck { expr, labels } => {
531 if let Expr::Variable(var) = expr.as_ref() {
532 let is_edge = context
534 .and_then(|ctx| ctx.variable_kinds.get(var))
535 .is_some_and(|k| matches!(k, VariableKind::Edge));
536
537 if is_edge {
538 if labels.len() > 1 {
542 Ok(lit(false))
543 } else {
544 let type_col =
545 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
546 Ok(DfExpr::Case(datafusion::logical_expr::Case {
548 expr: None,
549 when_then_expr: vec![(
550 Box::new(type_col.clone().is_null()),
551 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
552 )],
553 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
554 }))
555 }
556 } else {
557 let labels_col =
559 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
560 let checks = labels
561 .iter()
562 .map(|label| {
563 datafusion::functions_nested::expr_fn::array_has(
564 labels_col.clone(),
565 lit(label.clone()),
566 )
567 })
568 .reduce(|acc, check| acc.and(check));
569 Ok(DfExpr::Case(datafusion::logical_expr::Case {
571 expr: None,
572 when_then_expr: vec![(
573 Box::new(labels_col.is_null()),
574 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
575 )],
576 else_expr: Some(Box::new(checks.unwrap())),
577 }))
578 }
579 } else {
580 Err(anyhow!(
581 "LabelCheck on non-variable expression not yet supported in DataFusion"
582 ))
583 }
584 }
585 }
586}
587
/// Name-resolution inputs used by `cypher_expr_to_df` while translating.
#[derive(Debug, Clone)]
pub struct TranslationContext {
    /// Query parameters by name. Consulted for `$name` parameters, and also
    /// for bare variable names and `var.prop` keys. A `List` bound to a key
    /// ending in `._vid` is expanded into an `IN (...)` filter on that column.
    pub parameters: std::collections::HashMap<String, Value>,

    /// Values for variables that are not graph variables of the current plan;
    /// they are inlined as literals. Presumably bound by an enclosing/outer
    /// scope — confirm against the callers that populate it.
    pub outer_values: std::collections::HashMap<String, Value>,

    /// Variable → label mapping, populated via `with_variable_label`.
    /// NOTE(review): not read by the translation code in this part of the file.
    pub variable_labels: std::collections::HashMap<String, String>,

    /// Kind of each graph variable in scope. Variables listed here become
    /// columns, and their kind selects the id/type/label columns used by null
    /// checks, `IN`, property access and label checks.
    pub variable_kinds: std::collections::HashMap<String, VariableKind>,

    /// NOTE(review): not read by the translation code in this part of the
    /// file; document its meaning from the call sites.
    pub node_variable_hints: Vec<String>,

    /// NOTE(review): not read by the translation code in this part of the
    /// file; document its meaning from the call sites.
    pub mutation_edge_hints: Vec<String>,

    /// Set to `chrono::Utc::now()` by `Default`. Presumably lets time
    /// functions within one statement agree on "now" — confirm.
    pub statement_time: chrono::DateTime<chrono::Utc>,
}
621
622impl Default for TranslationContext {
623 fn default() -> Self {
624 Self {
625 parameters: std::collections::HashMap::new(),
626 outer_values: std::collections::HashMap::new(),
627 variable_labels: std::collections::HashMap::new(),
628 variable_kinds: std::collections::HashMap::new(),
629 node_variable_hints: Vec::new(),
630 mutation_edge_hints: Vec::new(),
631 statement_time: chrono::Utc::now(),
632 }
633 }
634}
635
636impl TranslationContext {
637 pub fn new() -> Self {
639 Self::default()
640 }
641
642 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
644 self.parameters.insert(name.into(), value);
645 self
646 }
647
648 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
650 self.variable_labels.insert(var.into(), label.into());
651 self
652 }
653}
654
655fn extract_variable_name(expr: &Expr) -> Result<String> {
657 match expr {
658 Expr::Variable(name) => Ok(name.clone()),
659 Expr::Property(base, _) => extract_variable_name(base),
660 _ => Err(anyhow!(
661 "Cannot extract variable name from expression: {:?}",
662 expr
663 )),
664 }
665}
666
667fn translate_null_check(
669 inner: &Expr,
670 context: Option<&TranslationContext>,
671 is_null: bool,
672) -> Result<DfExpr> {
673 if let Expr::Variable(var) = inner
674 && let Some(ctx) = context
675 && let Some(kind) = ctx.variable_kinds.get(var)
676 {
677 let col_name = match kind {
678 VariableKind::Node => format!("{}.{}", var, COL_VID),
679 VariableKind::Edge => format!("{}.{}", var, COL_EID),
680 VariableKind::Path | VariableKind::EdgeList => var.clone(),
681 };
682 let col_expr = DfExpr::Column(Column::from_name(col_name));
683 return Ok(if is_null {
684 col_expr.is_null()
685 } else {
686 col_expr.is_not_null()
687 });
688 }
689
690 let inner_expr = cypher_expr_to_df(inner, context)?;
691 Ok(if is_null {
692 inner_expr.is_null()
693 } else {
694 inner_expr.is_not_null()
695 })
696}
697
698fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
703 if crate::query::datetime::is_duration_accessor(prop) {
704 Some(dummy_udf_expr(
705 "_duration_property",
706 vec![base_expr, lit(prop.to_string())],
707 ))
708 } else if crate::query::datetime::is_temporal_accessor(prop) {
709 Some(dummy_udf_expr(
710 "_temporal_property",
711 vec![base_expr, lit(prop.to_string())],
712 ))
713 } else {
714 None
715 }
716}
717
718fn translate_property_access(
720 base: &Expr,
721 prop: &str,
722 context: Option<&TranslationContext>,
723) -> Result<DfExpr> {
724 if let Ok(var_name) = extract_variable_name(base) {
725 let is_graph_entity = context
726 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
727 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
728
729 if !is_graph_entity
730 && let Some(expr) =
731 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
732 {
733 return Ok(expr);
734 }
735
736 let col_name = format!("{}.{}", var_name, prop);
737
738 if let Some(ctx) = context
741 && let Some(value) = ctx.parameters.get(&col_name)
742 {
743 match value {
746 Value::List(values) if col_name.ends_with("._vid") => {
747 let literals = values
748 .iter()
749 .map(|v| value_to_scalar(v).map(lit))
750 .collect::<Result<Vec<_>>>()?;
751 return Ok(DfExpr::InList(InList {
752 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
753 list: literals,
754 negated: false,
755 }));
756 }
757 other_value => return value_to_scalar(other_value).map(lit),
758 }
759 }
760
761 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
764 let base_expr = cypher_expr_to_df(base, context)?;
765 return Ok(dummy_udf_expr(
766 "index",
767 vec![base_expr, lit(prop.to_string())],
768 ));
769 }
770
771 if is_graph_entity {
772 Ok(DfExpr::Column(Column::from_name(col_name)))
773 } else {
774 let base_expr = DfExpr::Column(Column::from_name(var_name));
775 Ok(dummy_udf_expr(
776 "index",
777 vec![base_expr, lit(prop.to_string())],
778 ))
779 }
780 } else {
781 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
783 return Ok(expr);
784 }
785
786 if let Expr::Parameter(param_name) = base {
788 if let Some(ctx) = context
789 && let Some(value) = ctx.parameters.get(param_name)
790 {
791 if let Value::Map(map) = value {
792 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
793 return value_to_scalar(&extracted).map(lit);
794 }
795 return Ok(lit(ScalarValue::Null));
796 }
797 return Err(anyhow!("Unresolved parameter: ${}", param_name));
798 }
799
800 let base_expr = cypher_expr_to_df(base, context)?;
801 Ok(dummy_udf_expr(
802 "index",
803 vec![base_expr, lit(prop.to_string())],
804 ))
805 }
806}
807
/// Translate a Cypher list literal `[a, b, ...]`.
///
/// Two representations are produced:
/// - a Cypher-encoded list (`LargeBinary` literal, or a `_make_cypher_list`
///   UDF call when some item is not a constant). This is used when any item
///   is a list, a map, a graph variable, or a call to a temporal constructor
///   or truncate function, or when literals of more than one kind (numeric /
///   string / bool / map) are mixed;
/// - otherwise an Arrow list built with `make_array`. Mixed int/float literal
///   lists are widened to `Float64`, and NULL literals are cast to the first
///   non-null literal's type. An empty list is an empty `List<Null>` literal.
fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
    // First pass: classify items to decide between the two representations.
    let mut has_string = false;
    let mut has_bool = false;
    let mut has_list = false;
    let mut has_map = false;
    let mut has_numeric = false;
    let mut has_graph_entity = false;
    let mut has_temporal = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
                has_numeric = true
            }
            Expr::Literal(CypherLiteral::String(_)) => has_string = true,
            Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
            Expr::List(_) => has_list = true,
            Expr::Map(_) => has_map = true,
            Expr::Variable(name) => {
                // Only variables registered in `variable_kinds` count as graph entities.
                if context
                    .and_then(|ctx| ctx.variable_kinds.get(name))
                    .is_some()
                {
                    has_graph_entity = true;
                }
            }
            Expr::FunctionCall { name, .. } => {
                // Temporal constructors/truncation yield values that cannot sit
                // in a plain Arrow list alongside other items.
                let upper = name.to_uppercase();
                if matches!(
                    upper.as_str(),
                    "DATE"
                        | "TIME"
                        | "LOCALTIME"
                        | "LOCALDATETIME"
                        | "DATETIME"
                        | "DURATION"
                        | "DATE.TRUNCATE"
                        | "TIME.TRUNCATE"
                        | "DATETIME.TRUNCATE"
                        | "LOCALDATETIME.TRUNCATE"
                        | "LOCALTIME.TRUNCATE"
                ) {
                    has_temporal = true;
                }
            }
            _ => {}
        }
    }

    // Number of distinct literal kinds seen; more than one means a heterogeneous list.
    let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;

    if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
        // All-constant items: encode the whole list once as a literal.
        if let Some(json_array) = try_items_to_json(items) {
            let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
            let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
            return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
        }
        // Otherwise build the Cypher list at runtime from translated items.
        let df_args: Vec<DfExpr> = items
            .iter()
            .map(|item| cypher_expr_to_df(item, context))
            .collect::<Result<_>>()?;
        return Ok(dummy_udf_expr("_make_cypher_list", df_args));
    }

    // Second pass: homogeneous list -> Arrow `make_array`. Note `has_other`
    // is set by anything that is not an int/float literal (NULL included).
    let mut df_args = Vec::with_capacity(items.len());
    let mut has_float = false;
    let mut has_int = false;
    let mut has_other = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
            Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
            _ => has_other = true,
        }
        df_args.push(cypher_expr_to_df(item, context)?);
    }

    if df_args.is_empty() {
        // `make_array()` needs arguments, so `[]` is an explicit empty List<Null>.
        let empty_arr =
            ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
        Ok(lit(ScalarValue::List(empty_arr)))
    } else if has_float && has_int && !has_other {
        // Pure int/float literal mix: widen everything to Float64.
        let promoted_args = df_args
            .into_iter()
            .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
            .collect();
        Ok(datafusion::functions_nested::expr_fn::make_array(
            promoted_args,
        ))
    } else {
        // Give untyped NULL literals the type of the first non-null literal.
        let non_null_type = df_args.iter().find_map(|e| {
            if let DfExpr::Literal(sv, _) = e {
                let dt = sv.data_type();
                if dt != datafusion::arrow::datatypes::DataType::Null {
                    return Some(dt);
                }
            }
            None
        });
        if let Some(ref target_type) = non_null_type {
            let coerced = df_args
                .into_iter()
                .map(|e| {
                    if matches!(&e, DfExpr::Literal(sv, _) if sv.data_type() == datafusion::arrow::datatypes::DataType::Null)
                    {
                        cast_expr(e, target_type.clone())
                    } else {
                        e
                    }
                })
                .collect();
            Ok(datafusion::functions_nested::expr_fn::make_array(coerced))
        } else {
            Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
        }
    }
}
945
946fn translate_in_expression(
948 expr: &Expr,
949 list: &Expr,
950 context: Option<&TranslationContext>,
951) -> Result<DfExpr> {
952 let left_expr = if let Expr::Variable(var) = expr
957 && let Some(ctx) = context
958 && let Some(kind) = ctx.variable_kinds.get(var)
959 {
960 match kind {
961 VariableKind::Node | VariableKind::Edge => {
962 let id_col = match kind {
963 VariableKind::Node => COL_VID,
964 VariableKind::Edge => COL_EID,
965 _ => unreachable!(),
966 };
967 cast_expr(
968 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
969 datafusion::arrow::datatypes::DataType::Int64,
970 )
971 }
972 _ => cypher_expr_to_df(expr, context)?,
973 }
974 } else {
975 cypher_expr_to_df(expr, context)?
976 };
977
978 if let Expr::List(items) = list {
983 if let Some(json_array) = try_items_to_json(items) {
984 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
986 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
987 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
988 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
989 } else {
990 let expanded: Vec<DfExpr> = items
992 .iter()
993 .map(|item| cypher_expr_to_df(item, context))
994 .collect::<Result<Vec<_>>>()?;
995 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
996 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
997 }
998 } else {
999 let right_expr = cypher_expr_to_df(list, context)?;
1000
1001 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
1006 return Ok(lit(ScalarValue::Boolean(None)));
1007 }
1008
1009 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
1010 }
1011}
1012
1013fn translate_case_expression(
1015 operand: &Option<Box<Expr>>,
1016 when_then: &[(Expr, Expr)],
1017 else_expr: &Option<Box<Expr>>,
1018 context: Option<&TranslationContext>,
1019) -> Result<DfExpr> {
1020 let mut case_builder = if let Some(match_expr) = operand {
1021 let match_df = cypher_expr_to_df(match_expr, context)?;
1022 datafusion::logical_expr::case(match_df)
1023 } else {
1024 datafusion::logical_expr::when(
1025 cypher_expr_to_df(&when_then[0].0, context)?,
1026 cypher_expr_to_df(&when_then[0].1, context)?,
1027 )
1028 };
1029
1030 let start_idx = if operand.is_some() { 0 } else { 1 };
1031 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1032 let when_df = cypher_expr_to_df(when_expr, context)?;
1033 let then_df = cypher_expr_to_df(then_expr, context)?;
1034 case_builder = case_builder.when(when_df, then_df);
1035 }
1036
1037 if let Some(else_e) = else_expr {
1038 let else_df = cypher_expr_to_df(else_e, context)?;
1039 Ok(case_builder.otherwise(else_df)?)
1040 } else {
1041 Ok(case_builder.end()?)
1042 }
1043}
1044
1045fn translate_map_projection(
1047 base: &Expr,
1048 items: &[MapProjectionItem],
1049 context: Option<&TranslationContext>,
1050) -> Result<DfExpr> {
1051 let mut args = Vec::new();
1052 for item in items {
1053 match item {
1054 MapProjectionItem::Property(prop) => {
1055 args.push(lit(prop.clone()));
1056 let prop_expr = cypher_expr_to_df(
1057 &Expr::Property(Box::new(base.clone()), prop.clone()),
1058 context,
1059 )?;
1060 args.push(prop_expr);
1061 }
1062 MapProjectionItem::LiteralEntry(key, expr) => {
1063 args.push(lit(key.clone()));
1064 args.push(cypher_expr_to_df(expr, context)?);
1065 }
1066 MapProjectionItem::Variable(var) => {
1067 args.push(lit(var.clone()));
1068 args.push(DfExpr::Column(Column::from_name(var)));
1069 }
1070 MapProjectionItem::AllProperties => {
1071 args.push(lit("__all__"));
1072 args.push(cypher_expr_to_df(base, context)?);
1073 }
1074 }
1075 }
1076 Ok(dummy_udf_expr("_map_project", args))
1077}
1078
1079fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1082 match expr {
1083 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1084 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1085 Expr::Literal(CypherLiteral::Integer(i)) => {
1086 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1087 }
1088 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1089 .map(serde_json::Value::Number)
1090 .or(Some(serde_json::Value::Null)),
1091 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1092 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1093 Expr::Map(entries) => {
1094 let mut map = serde_json::Map::new();
1095 for (k, v) in entries {
1096 map.insert(k.clone(), try_expr_to_json(v)?);
1097 }
1098 Some(serde_json::Value::Object(map))
1099 }
1100 _ => None,
1101 }
1102}
1103
1104fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1106 items.iter().map(try_expr_to_json).collect()
1107}
1108
1109fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1111 match lit {
1112 CypherLiteral::Null => Ok(ScalarValue::Null),
1113 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1114 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1115 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1116 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1117 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1118 }
1119}
1120
1121fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1123 match value {
1124 Value::Null => Ok(ScalarValue::Null),
1125 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1126 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1127 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1128 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1129 Value::List(items) => {
1130 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1132 let scalars = scalars?;
1133
1134 let data_type = infer_common_scalar_type(&scalars);
1136
1137 let typed_scalars: Vec<ScalarValue> = scalars
1139 .into_iter()
1140 .map(|s| {
1141 if matches!(s, ScalarValue::Null) {
1142 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1143 }
1144
1145 match (s, &data_type) {
1146 (
1147 ScalarValue::Int64(Some(v)),
1148 datafusion::arrow::datatypes::DataType::Float64,
1149 ) => ScalarValue::Float64(Some(v as f64)),
1150 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1151 let s_str = s.to_string();
1153 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1154 }
1155 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1156 if matches!(s, ScalarValue::Utf8(_)) {
1158 s
1159 } else {
1160 ScalarValue::Utf8(Some(s.to_string()))
1161 }
1162 }
1163 (s, _) => s,
1164 }
1165 })
1166 .collect();
1167
1168 if typed_scalars.is_empty() {
1170 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1171 &[],
1172 &data_type,
1173 )))
1174 } else {
1175 Ok(ScalarValue::List(ScalarValue::new_list(
1176 &typed_scalars,
1177 &data_type,
1178 true,
1179 )))
1180 }
1181 }
1182 Value::Map(map) => {
1183 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1186 entries.sort_by_key(|(k, _)| *k);
1187
1188 if entries.is_empty() {
1189 return Ok(ScalarValue::Struct(Arc::new(
1190 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1191 )));
1192 }
1193
1194 let mut fields_arrays = Vec::with_capacity(entries.len());
1195
1196 for (k, v) in entries {
1197 let scalar = value_to_scalar(v)?;
1198 let dt = scalar.data_type();
1199 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1200 let array = scalar.to_array()?;
1201 fields_arrays.push((field, array));
1202 }
1203
1204 Ok(ScalarValue::Struct(Arc::new(
1205 datafusion::arrow::array::StructArray::from(fields_arrays),
1206 )))
1207 }
1208 Value::Temporal(tv) => {
1209 use uni_common::TemporalValue;
1210 match tv {
1211 TemporalValue::Date { days_since_epoch } => {
1212 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1213 }
1214 TemporalValue::LocalTime {
1215 nanos_since_midnight,
1216 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1217 TemporalValue::Time {
1218 nanos_since_midnight,
1219 offset_seconds,
1220 } => {
1221 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1223 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1224
1225 let nanos_arr =
1226 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1227 as ArrayRef;
1228 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1229
1230 let fields = Fields::from(vec![
1231 Field::new(
1232 "nanos_since_midnight",
1233 ArrowDataType::Time64(TimeUnit::Nanosecond),
1234 true,
1235 ),
1236 Field::new("offset_seconds", ArrowDataType::Int32, true),
1237 ]);
1238
1239 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1240 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1241 }
1242 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1243 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1244 ),
1245 TemporalValue::DateTime {
1246 nanos_since_epoch,
1247 offset_seconds,
1248 timezone_name,
1249 } => {
1250 use arrow::array::{
1252 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1253 };
1254 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1255
1256 let nanos_arr =
1257 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1258 as ArrayRef;
1259 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1260 let tz_arr =
1261 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1262
1263 let fields = Fields::from(vec![
1264 Field::new(
1265 "nanos_since_epoch",
1266 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1267 true,
1268 ),
1269 Field::new("offset_seconds", ArrowDataType::Int32, true),
1270 Field::new("timezone_name", ArrowDataType::Utf8, true),
1271 ]);
1272
1273 let struct_arr =
1274 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1275 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1276 }
1277 TemporalValue::Duration {
1278 months,
1279 days,
1280 nanos,
1281 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1282 arrow::datatypes::IntervalMonthDayNano {
1283 months: *months as i32,
1284 days: *days as i32,
1285 nanoseconds: *nanos,
1286 },
1287 ))),
1288 TemporalValue::Btic { lo, hi, meta } => {
1289 let btic = uni_btic::Btic::new(*lo, *hi, *meta)
1290 .map_err(|e| anyhow::anyhow!("invalid BTIC value: {}", e))?;
1291 let packed = uni_btic::encode::encode(&btic);
1292 Ok(ScalarValue::FixedSizeBinary(24, Some(packed.to_vec())))
1293 }
1294 }
1295 }
1296 Value::Vector(v) => {
1297 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1299 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1300 }
1301 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1302 other => {
1304 let json_val: serde_json::Value = other.clone().into();
1305 let json_str = serde_json::to_string(&json_val)
1306 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1307 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1308 }
1309 }
1310}
1311
1312fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1314 match op {
1315 BinaryOp::Eq => Ok(left.eq(right)),
1319 BinaryOp::NotEq => Ok(left.not_eq(right)),
1320 BinaryOp::Lt => Ok(left.lt(right)),
1321 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1322 BinaryOp::Gt => Ok(left.gt(right)),
1323 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1324
1325 BinaryOp::And => Ok(left.and(right)),
1327 BinaryOp::Or => Ok(left.or(right)),
1328 BinaryOp::Xor => {
1329 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1331 }
1332
1333 BinaryOp::Add => {
1335 if is_list_expr(&left) || is_list_expr(&right) {
1336 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1337 } else {
1338 Ok(left + right)
1339 }
1340 }
1341 BinaryOp::Sub => Ok(left - right),
1342 BinaryOp::Mul => Ok(left * right),
1343 BinaryOp::Div => Ok(left / right),
1344 BinaryOp::Mod => Ok(left % right),
1345 BinaryOp::Pow => {
1346 let left_f = datafusion::logical_expr::cast(
1349 left,
1350 datafusion::arrow::datatypes::DataType::Float64,
1351 );
1352 let right_f = datafusion::logical_expr::cast(
1353 right,
1354 datafusion::arrow::datatypes::DataType::Float64,
1355 );
1356 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1357 }
1358
1359 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1361 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1362 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1363
1364 BinaryOp::Regex => {
1365 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1366 }
1367
1368 BinaryOp::ApproxEq => Err(anyhow!(
1369 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1370 )),
1371 }
1372}
1373
/// Early-returns `Some(Err(..))` from the enclosing translator when `$df_args`
/// holds fewer than `$n` arguments.
///
/// The enclosing function must return `Option<Result<DfExpr>>`; `$name` is the
/// function name used in the arity error message.
macro_rules! check_args {
    ($n:expr, $df_args:expr, $name:expr) => {
        if let Err(arity_err) = require_args($df_args, $n, $name) {
            return Some(Err(arity_err));
        }
    };
}
1390
1391fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1394 if df_args.len() < count {
1395 let noun = if count == 1 { "argument" } else { "arguments" };
1396 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1397 }
1398 Ok(())
1399}
1400
1401fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1403 require_args(df_args, 1, func_name)
1404}
1405
1406fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1408 df_args[0].clone()
1409}
1410
1411pub(crate) fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1413 DfExpr::Cast(datafusion::logical_expr::Cast {
1414 expr: Box::new(expr),
1415 data_type,
1416 })
1417}
1418
1419pub(crate) fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1425 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1426 Arc::new(crate::query::df_udfs::create_cypher_list_to_cv_udf()),
1427 vec![expr],
1428 ))
1429}
1430
1431pub(crate) fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1435 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1436 Arc::new(crate::query::df_udfs::create_cypher_scalar_to_cv_udf()),
1437 vec![expr],
1438 ))
1439}
1440
1441fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1443 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1444 Box::new(left),
1445 op,
1446 Box::new(right),
1447 ))
1448}
1449
1450pub(crate) fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1455 use datafusion::logical_expr::Operator;
1456 match op {
1457 Operator::Eq => Some("_cypher_equal"),
1458 Operator::NotEq => Some("_cypher_not_equal"),
1459 Operator::Lt => Some("_cypher_lt"),
1460 Operator::LtEq => Some("_cypher_lt_eq"),
1461 Operator::Gt => Some("_cypher_gt"),
1462 Operator::GtEq => Some("_cypher_gt_eq"),
1463 _ => None,
1464 }
1465}
1466
1467fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1469 use datafusion::logical_expr::Operator;
1470 match op {
1471 Operator::Plus => Some("_cypher_add"),
1472 Operator::Minus => Some("_cypher_sub"),
1473 Operator::Multiply => Some("_cypher_mul"),
1474 Operator::Divide => Some("_cypher_div"),
1475 Operator::Modulo => Some("_cypher_mod"),
1476 _ => None,
1477 }
1478}
1479
1480fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1485where
1486 F: FnOnce(DfExpr) -> DfExpr,
1487{
1488 require_arg(df_args, func_name)?;
1489 Ok(math_fn(cast_expr(
1490 first_arg(df_args),
1491 datafusion::arrow::datatypes::DataType::Float64,
1492 )))
1493}
1494
1495fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1497 if distinct {
1498 expr.distinct()
1499 .build()
1500 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1501 } else {
1502 Ok(expr)
1503 }
1504}
1505
1506fn translate_aggregate_function(
1508 name_upper: &str,
1509 df_args: &[DfExpr],
1510 distinct: bool,
1511) -> Option<Result<DfExpr>> {
1512 match name_upper {
1513 "COUNT" => {
1514 let expr = if df_args.is_empty() {
1515 datafusion::functions_aggregate::count::count(lit(1i64))
1516 } else {
1517 datafusion::functions_aggregate::count::count(first_arg(df_args))
1518 };
1519 Some(maybe_distinct(expr, distinct, "COUNT"))
1520 }
1521 "SUM" => {
1522 check_args!(1, df_args, "SUM");
1523 let udaf = Arc::new(crate::query::df_udfs::create_cypher_sum_udaf());
1524 Some(maybe_distinct(
1525 udaf.call(vec![first_arg(df_args)]),
1526 distinct,
1527 "SUM",
1528 ))
1529 }
1530 "AVG" => {
1531 check_args!(1, df_args, "AVG");
1532 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1533 let expr = datafusion::functions_aggregate::average::avg(coerced);
1534 Some(maybe_distinct(expr, distinct, "AVG"))
1535 }
1536 "MIN" => {
1537 check_args!(1, df_args, "MIN");
1538 let udaf = Arc::new(crate::query::df_udfs::create_cypher_min_udaf());
1539 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1540 }
1541 "MAX" => {
1542 check_args!(1, df_args, "MAX");
1543 let udaf = Arc::new(crate::query::df_udfs::create_cypher_max_udaf());
1544 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1545 }
1546 "PERCENTILEDISC" => {
1547 if df_args.len() != 2 {
1548 return Some(Err(anyhow!(
1549 "percentileDisc() requires exactly 2 arguments"
1550 )));
1551 }
1552 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1553 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_disc_udaf());
1554 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1555 }
1556 "PERCENTILECONT" => {
1557 if df_args.len() != 2 {
1558 return Some(Err(anyhow!(
1559 "percentileCont() requires exactly 2 arguments"
1560 )));
1561 }
1562 let coerced = crate::query::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1563 let udaf = Arc::new(crate::query::df_udfs::create_cypher_percentile_cont_udaf());
1564 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1565 }
1566 "COLLECT" => {
1567 check_args!(1, df_args, "COLLECT");
1568 Some(Ok(crate::query::df_udfs::create_cypher_collect_expr(
1569 first_arg(df_args),
1570 distinct,
1571 )))
1572 }
1573 "BTIC_MIN" => {
1575 check_args!(1, df_args, "btic_min");
1576 let udaf = Arc::new(crate::query::df_udfs::create_btic_min_udaf());
1577 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1578 }
1579 "BTIC_MAX" => {
1580 check_args!(1, df_args, "btic_max");
1581 let udaf = Arc::new(crate::query::df_udfs::create_btic_max_udaf());
1582 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1583 }
1584 "BTIC_SPAN_AGG" => {
1585 check_args!(1, df_args, "btic_span_agg");
1586 let udaf = Arc::new(crate::query::df_udfs::create_btic_span_agg_udaf());
1587 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1588 }
1589 "BTIC_COUNT_AT" => {
1590 if df_args.len() != 2 {
1591 return Some(Err(anyhow!("btic_count_at requires 2 arguments")));
1592 }
1593 let udaf = Arc::new(crate::query::df_udfs::create_btic_count_at_udaf());
1594 Some(Ok(udaf.call(df_args.to_vec())))
1595 }
1596 _ => None,
1597 }
1598}
1599
1600fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1603 match name_upper {
1604 "TOSTRING" => {
1605 check_args!(1, df_args, "toString");
1606 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1607 }
1608 "TOINTEGER" | "TOINT" => {
1609 check_args!(1, df_args, "toInteger");
1610 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1611 }
1612 "TOFLOAT" => {
1613 check_args!(1, df_args, "toFloat");
1614 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1615 }
1616 "TOBOOLEAN" | "TOBOOL" => {
1617 check_args!(1, df_args, "toBoolean");
1618 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1619 }
1620 "UPPER" | "TOUPPER" => {
1621 check_args!(1, df_args, "upper");
1622 Some(Ok(datafusion::functions::string::expr_fn::upper(
1623 first_arg(df_args),
1624 )))
1625 }
1626 "LOWER" | "TOLOWER" => {
1627 check_args!(1, df_args, "lower");
1628 Some(Ok(datafusion::functions::string::expr_fn::lower(
1629 first_arg(df_args),
1630 )))
1631 }
1632 "SUBSTRING" => {
1633 check_args!(2, df_args, "substring");
1634 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1635 }
1636 "TRIM" => {
1637 check_args!(1, df_args, "TRIM");
1638 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1639 first_arg(df_args),
1640 ])))
1641 }
1642 "LTRIM" => {
1643 check_args!(1, df_args, "LTRIM");
1644 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1645 first_arg(df_args),
1646 ])))
1647 }
1648 "RTRIM" => {
1649 check_args!(1, df_args, "RTRIM");
1650 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1651 first_arg(df_args),
1652 ])))
1653 }
1654 "LEFT" => {
1655 check_args!(2, df_args, "left");
1656 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1657 df_args[0].clone(),
1658 df_args[1].clone(),
1659 )))
1660 }
1661 "RIGHT" => {
1662 check_args!(2, df_args, "right");
1663 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1664 df_args[0].clone(),
1665 df_args[1].clone(),
1666 )))
1667 }
1668 "REPLACE" => {
1669 check_args!(3, df_args, "replace");
1670 Some(Ok(datafusion::functions::string::expr_fn::replace(
1671 df_args[0].clone(),
1672 df_args[1].clone(),
1673 df_args[2].clone(),
1674 )))
1675 }
1676 "REVERSE" => {
1677 check_args!(1, df_args, "reverse");
1678 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1679 }
1680 "SPLIT" => {
1681 check_args!(2, df_args, "split");
1682 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1683 }
1684 "SIZE" | "LENGTH" => {
1685 check_args!(1, df_args, name_upper);
1686 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1687 }
1688 _ => None,
1689 }
1690}
1691
1692fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1695 use datafusion::functions::math::expr_fn;
1696
1697 let unary_f64 =
1699 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1700
1701 match name_upper {
1702 "ABS" => {
1703 check_args!(1, df_args, "abs");
1704 Some(Ok(crate::query::df_udfs::cypher_abs_expr(first_arg(
1708 df_args,
1709 ))))
1710 }
1711 "CEIL" | "CEILING" => {
1712 check_args!(1, df_args, "ceil");
1713 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1714 }
1715 "FLOOR" => {
1716 check_args!(1, df_args, "floor");
1717 Some(Ok(expr_fn::floor(first_arg(df_args))))
1718 }
1719 "ROUND" => {
1720 check_args!(1, df_args, "round");
1721 let args = if df_args.len() == 1 {
1722 vec![first_arg(df_args)]
1723 } else {
1724 vec![df_args[0].clone(), df_args[1].clone()]
1725 };
1726 Some(Ok(expr_fn::round(args)))
1727 }
1728 "SIGN" => {
1729 check_args!(1, df_args, "sign");
1730 let coerced = crate::query::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1731 Some(Ok(expr_fn::signum(coerced)))
1732 }
1733 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1734 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1735 "LOG10" => unary_f64("log10", expr_fn::log10),
1736 "EXP" => unary_f64("exp", expr_fn::exp),
1737 "SIN" => unary_f64("sin", expr_fn::sin),
1738 "COS" => unary_f64("cos", expr_fn::cos),
1739 "TAN" => unary_f64("tan", expr_fn::tan),
1740 "ASIN" => unary_f64("asin", expr_fn::asin),
1741 "ACOS" => unary_f64("acos", expr_fn::acos),
1742 "ATAN" => unary_f64("atan", expr_fn::atan),
1743 "ATAN2" => {
1744 check_args!(2, df_args, "atan2");
1745 let cast_f64 =
1746 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1747 Some(Ok(expr_fn::atan2(
1748 cast_f64(df_args[0].clone()),
1749 cast_f64(df_args[1].clone()),
1750 )))
1751 }
1752 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1753 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1754 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1755 _ => None,
1756 }
1757}
1758
1759fn translate_temporal_function(
1762 name_upper: &str,
1763 name: &str,
1764 df_args: &[DfExpr],
1765 context: Option<&TranslationContext>,
1766) -> Option<Result<DfExpr>> {
1767 match name_upper {
1768 "DATE"
1769 | "TIME"
1770 | "LOCALTIME"
1771 | "LOCALDATETIME"
1772 | "DATETIME"
1773 | "DURATION"
1774 | "YEAR"
1775 | "MONTH"
1776 | "DAY"
1777 | "HOUR"
1778 | "MINUTE"
1779 | "SECOND"
1780 | "DURATION.BETWEEN"
1781 | "DURATION.INMONTHS"
1782 | "DURATION.INDAYS"
1783 | "DURATION.INSECONDS"
1784 | "DATETIME.FROMEPOCH"
1785 | "DATETIME.FROMEPOCHMILLIS"
1786 | "DATE.TRUNCATE"
1787 | "TIME.TRUNCATE"
1788 | "DATETIME.TRUNCATE"
1789 | "LOCALDATETIME.TRUNCATE"
1790 | "LOCALTIME.TRUNCATE"
1791 | "DATETIME.TRANSACTION"
1792 | "DATETIME.STATEMENT"
1793 | "DATETIME.REALTIME"
1794 | "DATE.TRANSACTION"
1795 | "DATE.STATEMENT"
1796 | "DATE.REALTIME"
1797 | "TIME.TRANSACTION"
1798 | "TIME.STATEMENT"
1799 | "TIME.REALTIME"
1800 | "LOCALTIME.TRANSACTION"
1801 | "LOCALTIME.STATEMENT"
1802 | "LOCALTIME.REALTIME"
1803 | "LOCALDATETIME.TRANSACTION"
1804 | "LOCALDATETIME.STATEMENT"
1805 | "LOCALDATETIME.REALTIME" => {
1806 let stmt_time = context.map(|c| c.statement_time);
1810 if can_constant_fold(name_upper, df_args)
1811 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1812 {
1813 return Some(Ok(folded));
1814 }
1815 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1816 }
1817 _ => None,
1818 }
1819}
1820
1821fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1823 if name.contains("REALTIME") {
1825 return false;
1826 }
1827 if args.is_empty() {
1835 return matches!(
1836 name,
1837 "DATE"
1838 | "TIME"
1839 | "LOCALTIME"
1840 | "LOCALDATETIME"
1841 | "DATETIME"
1842 | "DATE.STATEMENT"
1843 | "TIME.STATEMENT"
1844 | "LOCALTIME.STATEMENT"
1845 | "LOCALDATETIME.STATEMENT"
1846 | "DATETIME.STATEMENT"
1847 | "DATE.TRANSACTION"
1848 | "TIME.TRANSACTION"
1849 | "LOCALTIME.TRANSACTION"
1850 | "LOCALDATETIME.TRANSACTION"
1851 | "DATETIME.TRANSACTION"
1852 );
1853 }
1854 args.iter().all(is_constant_expr)
1856}
1857
1858fn is_constant_expr(expr: &DfExpr) -> bool {
1860 match expr {
1861 DfExpr::Literal(_, _) => true,
1862 DfExpr::ScalarFunction(func) => {
1863 func.args.iter().all(is_constant_expr)
1865 }
1866 _ => false,
1867 }
1868}
1869
1870fn try_constant_fold_temporal(
1876 name: &str,
1877 args: &[DfExpr],
1878 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1879) -> Result<DfExpr> {
1880 let val_args: Vec<Value> = args
1882 .iter()
1883 .map(extract_constant_value)
1884 .collect::<Result<_>>()?;
1885
1886 let result = if val_args.is_empty() {
1888 if let Some(frozen) = stmt_time {
1889 crate::query::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1890 } else {
1891 crate::query::datetime::eval_datetime_function(name, &val_args)?
1892 }
1893 } else {
1894 crate::query::datetime::eval_datetime_function(name, &val_args)?
1895 };
1896
1897 let scalar = value_to_scalar(&result)?;
1899 Ok(DfExpr::Literal(scalar, None))
1900}
1901
1902fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1904 use crate::query::df_udfs::scalar_to_value;
1905 match expr {
1906 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1907 DfExpr::ScalarFunction(func) => {
1908 let mut map = std::collections::HashMap::new();
1911 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1912 for chunk in pairs.chunks(2) {
1913 if let [key_expr, val_expr] = chunk {
1914 let key = match key_expr {
1916 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1917 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1918 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1919 };
1920 let val = extract_constant_value(val_expr)?;
1921 map.insert(key, val);
1922 } else {
1923 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1924 }
1925 }
1926 Ok(Value::Map(map))
1927 }
1928 _ => Err(anyhow::anyhow!(
1929 "Cannot extract constant value from expression"
1930 )),
1931 }
1932}
1933
1934fn translate_btic_function(
1937 name_upper: &str,
1938 name: &str,
1939 df_args: &[DfExpr],
1940) -> Option<Result<DfExpr>> {
1941 if crate::query::expr_eval::is_btic_function(name_upper) {
1942 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1943 } else {
1944 None
1945 }
1946}
1947
1948fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1951 match name_upper {
1952 "HEAD" => {
1953 check_args!(1, df_args, "head");
1954 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1955 }
1956 "LAST" => {
1957 check_args!(1, df_args, "last");
1958 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1959 }
1960 "TAIL" => {
1961 check_args!(1, df_args, "tail");
1962 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1963 }
1964 "RANGE" => {
1965 check_args!(2, df_args, "range");
1966 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1967 }
1968 _ => None,
1969 }
1970}
1971
1972fn translate_graph_function(
1975 name_upper: &str,
1976 name: &str,
1977 df_args: &[DfExpr],
1978 args: &[Expr],
1979 context: Option<&TranslationContext>,
1980) -> Option<Result<DfExpr>> {
1981 match name_upper {
1982 "ID" => {
1983 if let Some(Expr::Variable(var)) = args.first() {
1986 let is_edge = context.is_some_and(|ctx| {
1987 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1988 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1989 });
1990 let id_suffix = if is_edge { COL_EID } else { COL_VID };
1991 Some(Ok(DfExpr::Column(Column::from_name(format!(
1992 "{}.{}",
1993 var, id_suffix
1994 )))))
1995 } else {
1996 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
1997 }
1998 }
1999 "LABELS" | "KEYS" => {
2000 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2005 }
2006 "TYPE" => {
2007 if let Some(Expr::Variable(var)) = args.first()
2011 && let Some(ctx) = context
2012 && let Some(label) = ctx.variable_labels.get(var)
2013 {
2014 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
2017 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
2018 expr: None,
2019 when_then_expr: vec![(
2020 Box::new(eid_col.is_not_null()),
2021 Box::new(lit(label.clone())),
2022 )],
2023 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
2024 })));
2025 }
2026 if let Some(Expr::Variable(var)) = args.first()
2030 && context
2031 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
2032 {
2033 return Some(Ok(DfExpr::Column(Column::from_name(format!(
2034 "{}.{}",
2035 var, COL_TYPE
2036 )))));
2037 }
2038 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
2039 }
2040 "PROPERTIES" => {
2041 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2044 }
2045 "UNI.TEMPORAL.VALIDAT" => {
2046 if let (
2049 Some(Expr::Variable(var)),
2050 Some(Expr::Literal(CypherLiteral::String(start_prop))),
2051 Some(Expr::Literal(CypherLiteral::String(end_prop))),
2052 Some(ts_expr),
2053 ) = (args.first(), args.get(1), args.get(2), args.get(3))
2054 {
2055 let start_col =
2056 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
2057 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
2058 let ts = match cypher_expr_to_df(ts_expr, context) {
2059 Ok(ts) => ts,
2060 Err(e) => return Some(Err(e)),
2061 };
2062
2063 let start_check = start_col.lt_eq(ts.clone());
2065 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
2067 let end_after = end_col.gt(ts);
2068 let end_check = end_null.or(end_after);
2069
2070 Some(Ok(start_check.and(end_check)))
2071 } else {
2072 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2074 }
2075 }
2076 "STARTNODE" | "ENDNODE" => {
2077 let mut udf_args = df_args.to_vec();
2080 let mut seen = std::collections::HashSet::new();
2081 if let Some(ctx) = context {
2082 for (var, kind) in &ctx.variable_kinds {
2084 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2085 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2086 }
2087 }
2088 for var in &ctx.node_variable_hints {
2091 if seen.insert(var.clone()) {
2092 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2093 }
2094 }
2095 }
2096 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2097 }
2098 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2099 "HASLABEL" => {
2100 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2101 return Some(Err(e));
2102 }
2103 if let Some(Expr::Variable(var)) = args.first() {
2105 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2106 let labels_col =
2108 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2109 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2110 labels_col,
2111 lit(label.clone()),
2112 )))
2113 } else {
2114 Some(Err(anyhow::anyhow!(
2116 "hasLabel requires string literal as second argument for DataFusion translation"
2117 )))
2118 }
2119 } else {
2120 Some(Err(anyhow::anyhow!(
2122 "hasLabel requires variable as first argument for DataFusion translation"
2123 )))
2124 }
2125 }
2126 _ => None,
2127 }
2128}
2129
2130fn translate_function_call(
2132 name: &str,
2133 args: &[Expr],
2134 distinct: bool,
2135 context: Option<&TranslationContext>,
2136) -> Result<DfExpr> {
2137 let df_args: Vec<DfExpr> = args
2138 .iter()
2139 .map(|arg| cypher_expr_to_df(arg, context))
2140 .collect::<Result<Vec<_>>>()?;
2141
2142 let name_upper = name.to_uppercase();
2143
2144 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2148 return result;
2149 }
2150
2151 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2152 return result;
2153 }
2154
2155 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2156 return result;
2157 }
2158
2159 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2160 return result;
2161 }
2162
2163 if let Some(result) = translate_btic_function(&name_upper, name, &df_args) {
2164 return result;
2165 }
2166
2167 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2168 return result;
2169 }
2170
2171 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2172 return result;
2173 }
2174
2175 match name_upper.as_str() {
2177 "COALESCE" => {
2178 require_arg(&df_args, "coalesce")?;
2179 if df_args.len() == 1 {
2184 return Ok(df_args.into_iter().next().unwrap());
2185 }
2186 let n = df_args.len();
2187 let (init, last) = df_args.split_at(n - 1);
2188 let mut builder = datafusion::logical_expr::conditional_expressions::CaseBuilder::new(
2189 None,
2190 vec![],
2191 vec![],
2192 None,
2193 );
2194 for arg in init {
2195 builder.when(arg.clone().is_not_null(), arg.clone());
2196 }
2197 return Ok(builder.otherwise(last[0].clone())?);
2198 }
2199 "NULLIF" => {
2200 require_args(&df_args, 2, "nullif")?;
2201 return Ok(datafusion::functions::expr_fn::nullif(
2202 df_args[0].clone(),
2203 df_args[1].clone(),
2204 ));
2205 }
2206 _ => {}
2207 }
2208
2209 match name_upper.as_str() {
2211 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
2212 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2213 }
2214 _ => {}
2215 }
2216
2217 Ok(dummy_udf_expr(name, df_args))
2219}
2220
2221#[derive(Debug)]
2226struct DummyUdf {
2227 name: String,
2228 signature: datafusion::logical_expr::Signature,
2229 ret_type: datafusion::arrow::datatypes::DataType,
2230}
2231
2232impl DummyUdf {
2233 fn new(name: String) -> Self {
2234 let ret_type = dummy_udf_return_type(&name);
2235 Self {
2236 name,
2237 signature: datafusion::logical_expr::Signature::variadic_any(
2238 datafusion::logical_expr::Volatility::Immutable,
2239 ),
2240 ret_type,
2241 }
2242 }
2243}
2244
2245fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2258 use datafusion::arrow::datatypes::DataType;
2259 match name {
2260 "_cypher_add"
2264 | "_cypher_sub"
2265 | "_cypher_mul"
2266 | "_cypher_div"
2267 | "_cypher_mod"
2268 | "_cypher_list_concat"
2269 | "_cypher_list_append"
2270 | "_make_cypher_list"
2271 | "_map_project"
2272 | "_cypher_list_to_cv"
2273 | "_cypher_tail" => DataType::LargeBinary,
2274 _ => DataType::Null,
2278 }
2279}
2280
2281impl PartialEq for DummyUdf {
2282 fn eq(&self, other: &Self) -> bool {
2283 self.name == other.name
2284 }
2285}
2286
2287impl Eq for DummyUdf {}
2288
2289impl Hash for DummyUdf {
2290 fn hash<H: Hasher>(&self, state: &mut H) {
2291 self.name.hash(state);
2292 }
2293}
2294
2295pub(crate) fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2297 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2298 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2299 DummyUdf::new(name.to_lowercase()),
2300 )),
2301 args,
2302 })
2303}
2304
2305impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2306 fn as_any(&self) -> &dyn std::any::Any {
2307 self
2308 }
2309
2310 fn name(&self) -> &str {
2311 &self.name
2312 }
2313
2314 fn signature(&self) -> &datafusion::logical_expr::Signature {
2315 &self.signature
2316 }
2317
2318 fn return_type(
2319 &self,
2320 _arg_types: &[datafusion::arrow::datatypes::DataType],
2321 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2322 Ok(self.ret_type.clone())
2325 }
2326
2327 fn invoke_with_args(
2328 &self,
2329 _args: ScalarFunctionArgs,
2330 ) -> datafusion::error::Result<ColumnarValue> {
2331 Err(datafusion::error::DataFusionError::Plan(format!(
2332 "UDF '{}' is not registered. Register it via SessionContext.",
2333 self.name
2334 )))
2335 }
2336}
2337
2338pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2342 let mut properties = Vec::new();
2343 collect_properties_recursive(expr, &mut properties);
2344 properties.sort();
2345 properties.dedup();
2346 properties
2347}
2348
2349fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
2350 match expr {
2351 Expr::PatternComprehension { .. } => {}
2352 Expr::Property(base, prop) => {
2353 if let Ok(var_name) = extract_variable_name(base) {
2354 properties.push((var_name, prop.clone()));
2355 }
2356 collect_properties_recursive(base, properties);
2357 }
2358 Expr::ArrayIndex { array, index } => {
2359 if let Ok(var_name) = extract_variable_name(array)
2360 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
2361 {
2362 properties.push((var_name, prop_name.clone()));
2363 }
2364 collect_properties_recursive(array, properties);
2365 collect_properties_recursive(index, properties);
2366 }
2367 Expr::ArraySlice { array, start, end } => {
2368 collect_properties_recursive(array, properties);
2369 if let Some(s) = start {
2370 collect_properties_recursive(s, properties);
2371 }
2372 if let Some(e) = end {
2373 collect_properties_recursive(e, properties);
2374 }
2375 }
2376 Expr::List(items) => {
2377 for item in items {
2378 collect_properties_recursive(item, properties);
2379 }
2380 }
2381 Expr::Map(entries) => {
2382 for (_, value) in entries {
2383 collect_properties_recursive(value, properties);
2384 }
2385 }
2386 Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
2387 collect_properties_recursive(inner, properties);
2388 }
2389 Expr::FunctionCall { args, .. } => {
2390 for arg in args {
2391 collect_properties_recursive(arg, properties);
2392 }
2393 }
2394 Expr::BinaryOp { left, right, .. } => {
2395 collect_properties_recursive(left, properties);
2396 collect_properties_recursive(right, properties);
2397 }
2398 Expr::UnaryOp { expr, .. } => {
2399 collect_properties_recursive(expr, properties);
2400 }
2401 Expr::Case {
2402 expr,
2403 when_then,
2404 else_expr,
2405 } => {
2406 if let Some(e) = expr {
2407 collect_properties_recursive(e, properties);
2408 }
2409 for (when_e, then_e) in when_then {
2410 collect_properties_recursive(when_e, properties);
2411 collect_properties_recursive(then_e, properties);
2412 }
2413 if let Some(e) = else_expr {
2414 collect_properties_recursive(e, properties);
2415 }
2416 }
2417 Expr::Reduce {
2418 init, list, expr, ..
2419 } => {
2420 collect_properties_recursive(init, properties);
2421 collect_properties_recursive(list, properties);
2422 collect_properties_recursive(expr, properties);
2423 }
2424 Expr::Quantifier {
2425 list, predicate, ..
2426 } => {
2427 collect_properties_recursive(list, properties);
2428 collect_properties_recursive(predicate, properties);
2429 }
2430 Expr::ListComprehension {
2431 list,
2432 where_clause,
2433 map_expr,
2434 ..
2435 } => {
2436 collect_properties_recursive(list, properties);
2437 if let Some(filter) = where_clause {
2438 collect_properties_recursive(filter, properties);
2439 }
2440 collect_properties_recursive(map_expr, properties);
2441 }
2442 Expr::In { expr, list } => {
2443 collect_properties_recursive(expr, properties);
2444 collect_properties_recursive(list, properties);
2445 }
2446 Expr::ValidAt {
2447 entity, timestamp, ..
2448 } => {
2449 collect_properties_recursive(entity, properties);
2450 collect_properties_recursive(timestamp, properties);
2451 }
2452 Expr::MapProjection { base, items } => {
2453 collect_properties_recursive(base, properties);
2454 for item in items {
2455 match item {
2456 uni_cypher::ast::MapProjectionItem::Property(prop) => {
2457 if let Ok(var_name) = extract_variable_name(base) {
2458 properties.push((var_name, prop.clone()));
2459 }
2460 }
2461 uni_cypher::ast::MapProjectionItem::AllProperties => {
2462 if let Ok(var_name) = extract_variable_name(base) {
2463 properties.push((var_name, "*".to_string()));
2464 }
2465 }
2466 uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
2467 collect_properties_recursive(expr, properties);
2468 }
2469 uni_cypher::ast::MapProjectionItem::Variable(_) => {}
2470 }
2471 }
2472 }
2473 Expr::LabelCheck { expr, .. } => {
2474 collect_properties_recursive(expr, properties);
2475 }
2476 Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
2478 Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
2479 }
2480}
2481
2482pub fn wider_numeric_type(
2489 a: &datafusion::arrow::datatypes::DataType,
2490 b: &datafusion::arrow::datatypes::DataType,
2491) -> datafusion::arrow::datatypes::DataType {
2492 use datafusion::arrow::datatypes::DataType;
2493
2494 fn numeric_rank(dt: &DataType) -> u8 {
2495 match dt {
2496 DataType::Int8 | DataType::UInt8 => 1,
2497 DataType::Int16 | DataType::UInt16 => 2,
2498 DataType::Int32 | DataType::UInt32 => 3,
2499 DataType::Int64 | DataType::UInt64 => 4,
2500 DataType::Float16 => 5,
2501 DataType::Float32 => 6,
2502 DataType::Float64 => 7,
2503 _ => 0,
2504 }
2505 }
2506
2507 if numeric_rank(a) >= numeric_rank(b) {
2508 a.clone()
2509 } else {
2510 b.clone()
2511 }
2512}
2513
2514fn resolve_column_type_fallback(
2520 expr: &DfExpr,
2521 schema: &datafusion::common::DFSchema,
2522) -> Option<datafusion::arrow::datatypes::DataType> {
2523 if let DfExpr::Column(col) = expr {
2524 let col_name = &col.name;
2525 for (_, field) in schema.iter() {
2527 if field.name() == col_name {
2528 return Some(field.data_type().clone());
2529 }
2530 }
2531 }
2532 None
2533}
2534
2535fn contains_division(expr: &DfExpr) -> bool {
2538 match expr {
2539 DfExpr::BinaryExpr(b) => {
2540 b.op == datafusion::logical_expr::Operator::Divide
2541 || contains_division(&b.left)
2542 || contains_division(&b.right)
2543 }
2544 DfExpr::Cast(c) => contains_division(&c.expr),
2545 DfExpr::TryCast(c) => contains_division(&c.expr),
2546 _ => false,
2547 }
2548}
2549
2550pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2556 use datafusion::arrow::datatypes::DataType;
2557 use datafusion::logical_expr::ExprSchemable;
2558
2559 match expr {
2560 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2561 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2562 DfExpr::Case(case) => coerce_case_expr(case, schema),
2563 DfExpr::InList(in_list) => {
2564 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2565 let coerced_list = in_list
2566 .list
2567 .iter()
2568 .map(|e| apply_type_coercion(e, schema))
2569 .collect::<Result<Vec<_>>>()?;
2570 let expr_type = coerced_expr
2571 .get_type(schema)
2572 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2573 crate::query::cypher_type_coerce::build_cypher_in_list(
2574 coerced_expr,
2575 &expr_type,
2576 coerced_list,
2577 in_list.negated,
2578 schema,
2579 )
2580 }
2581 DfExpr::Not(inner) => {
2582 let coerced_inner = apply_type_coercion(inner, schema)?;
2583 let inner_type = coerced_inner.get_type(schema).ok();
2584 let final_inner = if inner_type
2585 .as_ref()
2586 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2587 {
2588 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2589 } else if inner_type
2590 .as_ref()
2591 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2592 {
2593 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2594 } else {
2595 coerced_inner
2596 };
2597 Ok(DfExpr::Not(Box::new(final_inner)))
2598 }
2599 DfExpr::IsNull(inner) => {
2600 let coerced_inner = apply_type_coercion(inner, schema)?;
2601 Ok(coerced_inner.is_null())
2602 }
2603 DfExpr::IsNotNull(inner) => {
2604 let coerced_inner = apply_type_coercion(inner, schema)?;
2605 Ok(coerced_inner.is_not_null())
2606 }
2607 DfExpr::Negative(inner) => {
2608 let coerced_inner = apply_type_coercion(inner, schema)?;
2609 let inner_type = coerced_inner.get_type(schema).ok();
2610 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2611 Ok(dummy_udf_expr(
2612 "_cypher_mul",
2613 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2614 ))
2615 } else {
2616 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2617 }
2618 }
2619 DfExpr::Cast(cast) => {
2620 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2621 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2622 Box::new(coerced_inner),
2623 cast.data_type.clone(),
2624 )))
2625 }
2626 DfExpr::TryCast(cast) => {
2627 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2628 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2629 Box::new(coerced_inner),
2630 cast.data_type.clone(),
2631 )))
2632 }
2633 DfExpr::Alias(alias) => {
2634 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2635 Ok(coerced_inner.alias(alias.name.clone()))
2636 }
2637 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2638 _ => Ok(expr.clone()),
2639 }
2640}
2641
2642fn coerce_logical_operands(
2644 left: DfExpr,
2645 right: DfExpr,
2646 op: datafusion::logical_expr::Operator,
2647 schema: &datafusion::common::DFSchema,
2648) -> Option<DfExpr> {
2649 use datafusion::arrow::datatypes::DataType;
2650 use datafusion::logical_expr::ExprSchemable;
2651
2652 if !matches!(
2653 op,
2654 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2655 ) {
2656 return None;
2657 }
2658 let left_type = left.get_type(schema).ok();
2659 let right_type = right.get_type(schema).ok();
2660 let left_needs_cast = left_type
2661 .as_ref()
2662 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2663 let right_needs_cast = right_type
2664 .as_ref()
2665 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2666 let left_is_lb = left_type
2667 .as_ref()
2668 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2669 let right_is_lb = right_type
2670 .as_ref()
2671 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2672 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2673 return None;
2674 }
2675 let coerced_left = if left_is_lb {
2676 dummy_udf_expr("_cv_to_bool", vec![left])
2677 } else if left_needs_cast {
2678 datafusion::logical_expr::cast(left, DataType::Boolean)
2679 } else {
2680 left
2681 };
2682 let coerced_right = if right_is_lb {
2683 dummy_udf_expr("_cv_to_bool", vec![right])
2684 } else if right_needs_cast {
2685 datafusion::logical_expr::cast(right, DataType::Boolean)
2686 } else {
2687 right
2688 };
2689 Some(binary_expr(coerced_left, op, coerced_right))
2690}
2691
2692#[expect(
2695 clippy::too_many_arguments,
2696 reason = "Binary coercion needs all context"
2697)]
2698fn coerce_large_binary_ops(
2699 left: &DfExpr,
2700 right: &DfExpr,
2701 left_type: &datafusion::arrow::datatypes::DataType,
2702 right_type: &datafusion::arrow::datatypes::DataType,
2703 left_is_null: bool,
2704 op: datafusion::logical_expr::Operator,
2705 is_comparison: bool,
2706 is_arithmetic: bool,
2707) -> Option<Result<DfExpr>> {
2708 use datafusion::arrow::datatypes::DataType;
2709 use datafusion::logical_expr::Operator;
2710
2711 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2712 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2713
2714 if op == Operator::Plus {
2715 if left_is_lb && right_is_lb {
2716 return Some(Ok(dummy_udf_expr(
2717 "_cypher_add",
2718 vec![left.clone(), right.clone()],
2719 )));
2720 }
2721 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2722 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2723 if left_is_native_list && right_is_native_list {
2724 return Some(Ok(dummy_udf_expr(
2725 "_cypher_list_concat",
2726 vec![left.clone(), right.clone()],
2727 )));
2728 }
2729 if left_is_native_list || right_is_native_list {
2730 return Some(Ok(dummy_udf_expr(
2731 "_cypher_list_append",
2732 vec![left.clone(), right.clone()],
2733 )));
2734 }
2735 }
2736
2737 if (left_is_lb || right_is_lb) && is_comparison {
2738 if let Some(udf_name) = comparison_udf_name(op) {
2739 return Some(Ok(dummy_udf_expr(
2740 udf_name,
2741 vec![left.clone(), right.clone()],
2742 )));
2743 }
2744 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2745 }
2746
2747 if (left_is_lb || right_is_lb) && is_arithmetic {
2748 let udf_name =
2749 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2750 return Some(Ok(dummy_udf_expr(
2751 udf_name,
2752 vec![left.clone(), right.clone()],
2753 )));
2754 }
2755
2756 None
2757}
2758
2759fn coerce_temporal_comparisons(
2761 left: DfExpr,
2762 right: DfExpr,
2763 left_type: &datafusion::arrow::datatypes::DataType,
2764 right_type: &datafusion::arrow::datatypes::DataType,
2765 op: datafusion::logical_expr::Operator,
2766 is_comparison: bool,
2767) -> Option<DfExpr> {
2768 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2769 use datafusion::logical_expr::Operator;
2770
2771 if !is_comparison {
2772 return None;
2773 }
2774
2775 if uni_common::core::schema::is_datetime_struct(left_type)
2777 && uni_common::core::schema::is_datetime_struct(right_type)
2778 {
2779 return Some(binary_expr(
2780 extract_datetime_nanos(left),
2781 op,
2782 extract_datetime_nanos(right),
2783 ));
2784 }
2785
2786 if uni_common::core::schema::is_time_struct(left_type)
2788 && uni_common::core::schema::is_time_struct(right_type)
2789 {
2790 return Some(binary_expr(
2791 extract_time_nanos(left),
2792 op,
2793 extract_time_nanos(right),
2794 ));
2795 }
2796
2797 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2799 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2800
2801 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2802 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2803 {
2804 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2805 extract_datetime_nanos(left)
2806 } else {
2807 left
2808 };
2809 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2810 extract_datetime_nanos(right)
2811 } else {
2812 right
2813 };
2814 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2815 return Some(binary_expr(
2816 cast_expr(left_nanos, ts_type.clone()),
2817 op,
2818 cast_expr(right_nanos, ts_type),
2819 ));
2820 }
2821
2822 let left_is_duration = matches!(left_type, DataType::Interval(_));
2826 let right_is_duration = matches!(right_type, DataType::Interval(_));
2827 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2828 || uni_common::core::schema::is_time_struct(left_type)
2829 || matches!(
2830 left_type,
2831 DataType::Timestamp(_, _)
2832 | DataType::Date32
2833 | DataType::Date64
2834 | DataType::Time32(_)
2835 | DataType::Time64(_)
2836 );
2837 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2838 || uni_common::core::schema::is_time_struct(right_type)
2839 || matches!(
2840 right_type,
2841 DataType::Timestamp(_, _)
2842 | DataType::Date32
2843 | DataType::Date64
2844 | DataType::Time32(_)
2845 | DataType::Time64(_)
2846 );
2847
2848 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2849 {
2850 return Some(match op {
2851 Operator::Eq => lit(false),
2852 Operator::NotEq => lit(true),
2853 _ => lit(ScalarValue::Boolean(None)),
2854 });
2855 }
2856
2857 None
2858}
2859
2860fn coerce_mismatched_types(
2863 left: DfExpr,
2864 right: DfExpr,
2865 left_type: &datafusion::arrow::datatypes::DataType,
2866 right_type: &datafusion::arrow::datatypes::DataType,
2867 op: datafusion::logical_expr::Operator,
2868 is_comparison: bool,
2869) -> Option<Result<DfExpr>> {
2870 use datafusion::arrow::datatypes::DataType;
2871 use datafusion::logical_expr::Operator;
2872
2873 if left_type == right_type {
2874 return None;
2875 }
2876
2877 if left_type.is_numeric() && right_type.is_numeric() {
2879 if left_type == &DataType::Int64
2880 && right_type == &DataType::UInt64
2881 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2882 {
2883 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2884 return Some(Ok(binary_expr(coerced_left, op, right)));
2885 }
2886 if left_type == &DataType::UInt64
2887 && right_type == &DataType::Int64
2888 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2889 {
2890 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2891 return Some(Ok(binary_expr(left, op, coerced_right)));
2892 }
2893 let target = wider_numeric_type(left_type, right_type);
2894 let coerced_left = if *left_type != target {
2895 datafusion::logical_expr::cast(left, target.clone())
2896 } else {
2897 left
2898 };
2899 let coerced_right = if *right_type != target {
2900 datafusion::logical_expr::cast(right, target)
2901 } else {
2902 right
2903 };
2904 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2905 }
2906
2907 if is_comparison {
2909 match (left_type, right_type) {
2910 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2911 let right = normalize_datetime_literal(right);
2912 return Some(Ok(binary_expr(
2913 left,
2914 op,
2915 datafusion::logical_expr::cast(right, ts.clone()),
2916 )));
2917 }
2918 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2919 let left = normalize_datetime_literal(left);
2920 return Some(Ok(binary_expr(
2921 datafusion::logical_expr::cast(left, ts.clone()),
2922 op,
2923 right,
2924 )));
2925 }
2926 _ => {}
2927 }
2928 }
2929
2930 if is_comparison
2932 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2933 {
2934 let l_inner = l_field.data_type();
2935 let r_inner = r_field.data_type();
2936 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2937 let target_inner = wider_numeric_type(l_inner, r_inner);
2938 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2939 "item",
2940 target_inner,
2941 true,
2942 )));
2943 return Some(Ok(binary_expr(
2944 datafusion::logical_expr::cast(left, target_type.clone()),
2945 op,
2946 datafusion::logical_expr::cast(right, target_type),
2947 )));
2948 }
2949 }
2950
2951 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2953 if op == Operator::Plus {
2954 return Some(crate::query::cypher_type_coerce::build_cypher_plus(
2955 left, left_type, right, right_type,
2956 ));
2957 }
2958 if is_comparison {
2959 return Some(Ok(
2960 crate::query::cypher_type_coerce::build_cypher_comparison(
2961 left, left_type, right, right_type, op,
2962 ),
2963 ));
2964 }
2965 }
2966
2967 None
2968}
2969
2970fn coerce_list_comparisons(
2972 left: DfExpr,
2973 right: DfExpr,
2974 left_type: &datafusion::arrow::datatypes::DataType,
2975 right_type: &datafusion::arrow::datatypes::DataType,
2976 op: datafusion::logical_expr::Operator,
2977 is_comparison: bool,
2978) -> Option<DfExpr> {
2979 use datafusion::arrow::datatypes::DataType;
2980 use datafusion::logical_expr::Operator;
2981
2982 if !is_comparison {
2983 return None;
2984 }
2985
2986 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2987 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2988
2989 if left_is_list
2991 && right_is_list
2992 && matches!(
2993 op,
2994 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
2995 )
2996 {
2997 let op_str = match op {
2998 Operator::Lt => "lt",
2999 Operator::LtEq => "lteq",
3000 Operator::Gt => "gt",
3001 Operator::GtEq => "gteq",
3002 _ => unreachable!(),
3003 };
3004 return Some(dummy_udf_expr(
3005 "_cypher_list_compare",
3006 vec![left, right, lit(op_str)],
3007 ));
3008 }
3009
3010 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
3012 let udf_name =
3013 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
3014 return Some(dummy_udf_expr(udf_name, vec![left, right]));
3015 }
3016
3017 if (left_is_list != right_is_list)
3019 && !matches!(left_type, DataType::Null)
3020 && !matches!(right_type, DataType::Null)
3021 {
3022 return Some(match op {
3023 Operator::Eq => lit(false),
3024 Operator::NotEq => lit(true),
3025 _ => lit(ScalarValue::Boolean(None)),
3026 });
3027 }
3028
3029 None
3030}
3031
3032fn coerce_binary_expr(
3034 binary: &datafusion::logical_expr::expr::BinaryExpr,
3035 schema: &datafusion::common::DFSchema,
3036) -> Result<DfExpr> {
3037 use datafusion::arrow::datatypes::DataType;
3038 use datafusion::logical_expr::ExprSchemable;
3039 use datafusion::logical_expr::Operator;
3040
3041 let left = apply_type_coercion(&binary.left, schema)?;
3042 let right = apply_type_coercion(&binary.right, schema)?;
3043
3044 let is_comparison = matches!(
3045 binary.op,
3046 Operator::Eq
3047 | Operator::NotEq
3048 | Operator::Lt
3049 | Operator::LtEq
3050 | Operator::Gt
3051 | Operator::GtEq
3052 );
3053 let is_arithmetic = matches!(
3054 binary.op,
3055 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
3056 );
3057
3058 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
3060 return Ok(result);
3061 }
3062
3063 if is_comparison || is_arithmetic {
3064 let left_type = match left.get_type(schema) {
3065 Ok(t) => t,
3066 Err(e) => {
3067 if let Some(t) = resolve_column_type_fallback(&left, schema) {
3068 t
3069 } else {
3070 log::warn!("Failed to get left type in binary expr: {}", e);
3071 return Ok(binary_expr(left, binary.op, right));
3072 }
3073 }
3074 };
3075 let right_type = match right.get_type(schema) {
3076 Ok(t) => t,
3077 Err(e) => {
3078 if let Some(t) = resolve_column_type_fallback(&right, schema) {
3079 t
3080 } else {
3081 log::warn!("Failed to get right type in binary expr: {}", e);
3082 return Ok(binary_expr(left, binary.op, right));
3083 }
3084 }
3085 };
3086
3087 let left_is_null = left_type.is_null();
3089 let right_is_null = right_type.is_null();
3090 if left_is_null && right_is_null {
3091 return Ok(lit(ScalarValue::Boolean(None)));
3092 }
3093 if left_is_null || right_is_null {
3094 let target = if left_is_null {
3095 &right_type
3096 } else {
3097 &left_type
3098 };
3099 if !matches!(target, DataType::LargeBinary) {
3100 let coerced_left = if left_is_null {
3101 datafusion::logical_expr::cast(left, target.clone())
3102 } else {
3103 left
3104 };
3105 let coerced_right = if right_is_null {
3106 datafusion::logical_expr::cast(right, target.clone())
3107 } else {
3108 right
3109 };
3110 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3111 }
3112 }
3113
3114 if let Some(result) = coerce_large_binary_ops(
3116 &left,
3117 &right,
3118 &left_type,
3119 &right_type,
3120 left_is_null,
3121 binary.op,
3122 is_comparison,
3123 is_arithmetic,
3124 ) {
3125 return result;
3126 }
3127
3128 if let Some(result) = coerce_temporal_comparisons(
3130 left.clone(),
3131 right.clone(),
3132 &left_type,
3133 &right_type,
3134 binary.op,
3135 is_comparison,
3136 ) {
3137 return Ok(result);
3138 }
3139
3140 let either_struct =
3142 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3143 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3144 || matches!(left_type, DataType::Struct(_)))
3145 && (matches!(right_type, DataType::LargeBinary)
3146 || matches!(right_type, DataType::Struct(_)));
3147 if is_comparison && either_struct && either_lb_or_struct {
3148 if let Some(udf_name) = comparison_udf_name(binary.op) {
3149 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3150 }
3151 return Ok(lit(ScalarValue::Boolean(None)));
3152 }
3153
3154 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3156 let udf_name = comparison_udf_name(binary.op)
3157 .expect("is_comparison guarantees a valid comparison operator");
3158 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3159 }
3160
3161 if binary.op == Operator::Plus
3163 && (crate::query::cypher_type_coerce::is_string_type(&left_type)
3164 || crate::query::cypher_type_coerce::is_string_type(&right_type))
3165 && is_primitive_type(&left_type)
3166 && is_primitive_type(&right_type)
3167 {
3168 return crate::query::cypher_type_coerce::build_cypher_plus(
3169 left,
3170 &left_type,
3171 right,
3172 &right_type,
3173 );
3174 }
3175
3176 if let Some(result) = coerce_mismatched_types(
3178 left.clone(),
3179 right.clone(),
3180 &left_type,
3181 &right_type,
3182 binary.op,
3183 is_comparison,
3184 ) {
3185 return result;
3186 }
3187
3188 if let Some(result) = coerce_list_comparisons(
3190 left.clone(),
3191 right.clone(),
3192 &left_type,
3193 &right_type,
3194 binary.op,
3195 is_comparison,
3196 ) {
3197 return Ok(result);
3198 }
3199 }
3200
3201 Ok(binary_expr(left, binary.op, right))
3202}
3203
3204fn coerce_scalar_function(
3206 func: &datafusion::logical_expr::expr::ScalarFunction,
3207 schema: &datafusion::common::DFSchema,
3208) -> Result<DfExpr> {
3209 use datafusion::arrow::datatypes::DataType;
3210 use datafusion::logical_expr::ExprSchemable;
3211
3212 let coerced_args: Vec<DfExpr> = func
3213 .args
3214 .iter()
3215 .map(|a| apply_type_coercion(a, schema))
3216 .collect::<Result<Vec<_>>>()?;
3217
3218 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3219 let types: Vec<_> = coerced_args
3220 .iter()
3221 .filter_map(|a| a.get_type(schema).ok())
3222 .collect();
3223 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3224 if has_mixed_types {
3225 let all_string_like = types
3229 .iter()
3230 .all(|t| matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Null));
3231 let unified_args: Vec<DfExpr> = if all_string_like {
3232 coerced_args
3233 .into_iter()
3234 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3235 .collect()
3236 } else {
3237 coerced_args
3239 .into_iter()
3240 .zip(types.iter())
3241 .map(|(arg, t)| match t {
3242 DataType::LargeBinary | DataType::Null => arg,
3243 DataType::List(_) | DataType::LargeList(_) => {
3244 list_to_large_binary_expr(arg)
3245 }
3246 _ => scalar_to_large_binary_expr(arg),
3247 })
3248 .collect()
3249 };
3250 return Ok(DfExpr::ScalarFunction(
3251 datafusion::logical_expr::expr::ScalarFunction {
3252 func: func.func.clone(),
3253 args: unified_args,
3254 },
3255 ));
3256 }
3257 }
3258
3259 Ok(DfExpr::ScalarFunction(
3260 datafusion::logical_expr::expr::ScalarFunction {
3261 func: func.func.clone(),
3262 args: coerced_args,
3263 },
3264 ))
3265}
3266
3267fn coerce_case_expr(
3270 case: &datafusion::logical_expr::expr::Case,
3271 schema: &datafusion::common::DFSchema,
3272) -> Result<DfExpr> {
3273 use datafusion::arrow::datatypes::DataType;
3274 use datafusion::logical_expr::ExprSchemable;
3275
3276 let coerced_operand = case
3277 .expr
3278 .as_ref()
3279 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3280 .transpose()?;
3281 let coerced_when_then = case
3282 .when_then_expr
3283 .iter()
3284 .map(|(w, t)| {
3285 let cw = apply_type_coercion(w, schema)?;
3286 let cw = match cw.get_type(schema).ok() {
3287 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3288 _ => cw,
3289 };
3290 let ct = apply_type_coercion(t, schema)?;
3291 Ok((Box::new(cw), Box::new(ct)))
3292 })
3293 .collect::<Result<Vec<_>>>()?;
3294 let coerced_else = case
3295 .else_expr
3296 .as_ref()
3297 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3298 .transpose()?;
3299
3300 let mut result_case = if let Some(operand) = coerced_operand {
3301 crate::query::cypher_type_coerce::rewrite_simple_case_to_generic(
3302 *operand,
3303 coerced_when_then,
3304 coerced_else,
3305 schema,
3306 )?
3307 } else {
3308 datafusion::logical_expr::expr::Case {
3309 expr: None,
3310 when_then_expr: coerced_when_then,
3311 else_expr: coerced_else,
3312 }
3313 };
3314
3315 crate::query::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3316
3317 Ok(DfExpr::Case(result_case))
3318}
3319
3320fn coerce_aggregate_function(
3322 agg: &datafusion::logical_expr::expr::AggregateFunction,
3323 schema: &datafusion::common::DFSchema,
3324) -> Result<DfExpr> {
3325 let coerced_args: Vec<DfExpr> = agg
3326 .params
3327 .args
3328 .iter()
3329 .map(|a| apply_type_coercion(a, schema))
3330 .collect::<Result<Vec<_>>>()?;
3331 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3332 .params
3333 .order_by
3334 .iter()
3335 .map(|s| {
3336 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3337 Ok(datafusion::logical_expr::SortExpr {
3338 expr: coerced_expr,
3339 asc: s.asc,
3340 nulls_first: s.nulls_first,
3341 })
3342 })
3343 .collect::<Result<Vec<_>>>()?;
3344 let coerced_filter = agg
3345 .params
3346 .filter
3347 .as_ref()
3348 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3349 .transpose()?;
3350 Ok(DfExpr::AggregateFunction(
3351 datafusion::logical_expr::expr::AggregateFunction {
3352 func: agg.func.clone(),
3353 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3354 args: coerced_args,
3355 distinct: agg.params.distinct,
3356 filter: coerced_filter,
3357 order_by: coerced_order_by,
3358 null_treatment: agg.params.null_treatment,
3359 },
3360 },
3361 ))
3362}
3363
3364#[cfg(test)]
3365mod tests {
3366 use super::*;
3367 use arrow_array::{
3368 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3369 };
3370 use uni_common::TemporalValue;
3371 #[test]
3372 fn test_literal_translation() {
3373 let expr = Expr::Literal(CypherLiteral::Integer(42));
3374 let result = cypher_expr_to_df(&expr, None).unwrap();
3375 let s = format!("{:?}", result);
3376 assert!(s.contains("Literal"));
3378 assert!(s.contains("Int64(42)"));
3379 }
3380
3381 #[test]
3382 fn test_property_access_no_context_uses_index() {
3383 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3385 let result = cypher_expr_to_df(&expr, None).unwrap();
3386 let s = format!("{}", result);
3387 assert!(
3388 s.contains("index"),
3389 "expected index UDF for non-graph variable, got: {s}"
3390 );
3391 }
3392
3393 #[test]
3394 fn test_comparison_operator() {
3395 let expr = Expr::BinaryOp {
3396 left: Box::new(Expr::Property(
3397 Box::new(Expr::Variable("n".to_string())),
3398 "age".to_string(),
3399 )),
3400 op: BinaryOp::Gt,
3401 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3402 };
3403 let result = cypher_expr_to_df(&expr, None).unwrap();
3404 let s = format!("{:?}", result);
3406 assert!(s.contains("age"));
3407 assert!(s.contains("30"));
3408 }
3409
3410 #[test]
3411 fn test_boolean_operators() {
3412 let expr = Expr::BinaryOp {
3413 left: Box::new(Expr::BinaryOp {
3414 left: Box::new(Expr::Property(
3415 Box::new(Expr::Variable("n".to_string())),
3416 "age".to_string(),
3417 )),
3418 op: BinaryOp::Gt,
3419 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3420 }),
3421 op: BinaryOp::And,
3422 right: Box::new(Expr::BinaryOp {
3423 left: Box::new(Expr::Property(
3424 Box::new(Expr::Variable("n".to_string())),
3425 "active".to_string(),
3426 )),
3427 op: BinaryOp::Eq,
3428 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3429 }),
3430 };
3431 let result = cypher_expr_to_df(&expr, None).unwrap();
3432 let s = format!("{:?}", result);
3433 assert!(s.contains("And"));
3434 }
3435
3436 #[test]
3437 fn test_is_null() {
3438 let expr = Expr::IsNull(Box::new(Expr::Property(
3439 Box::new(Expr::Variable("n".to_string())),
3440 "email".to_string(),
3441 )));
3442 let result = cypher_expr_to_df(&expr, None).unwrap();
3443 let s = format!("{:?}", result);
3444 assert!(s.contains("IsNull"));
3445 }
3446
3447 #[test]
3448 fn test_collect_properties() {
3449 let expr = Expr::BinaryOp {
3450 left: Box::new(Expr::Property(
3451 Box::new(Expr::Variable("n".to_string())),
3452 "name".to_string(),
3453 )),
3454 op: BinaryOp::Eq,
3455 right: Box::new(Expr::Property(
3456 Box::new(Expr::Variable("m".to_string())),
3457 "name".to_string(),
3458 )),
3459 };
3460
3461 let props = collect_properties(&expr);
3462 assert_eq!(props.len(), 2);
3463 assert!(props.contains(&("m".to_string(), "name".to_string())));
3464 assert!(props.contains(&("n".to_string(), "name".to_string())));
3465 }
3466
3467 #[test]
3468 fn test_function_call() {
3469 let expr = Expr::FunctionCall {
3470 name: "count".to_string(),
3471 args: vec![Expr::Wildcard],
3472 distinct: false,
3473 window_spec: None,
3474 };
3475 let result = cypher_expr_to_df(&expr, None).unwrap();
3476 let s = format!("{:?}", result);
3477 assert!(s.to_lowercase().contains("count"));
3478 }
3479
3480 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3485 use datafusion::logical_expr::Operator;
3486
3487 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3489 let fields: Vec<_> = cols
3490 .iter()
3491 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3492 .collect();
3493 let schema = Schema::new(fields);
3494 datafusion::common::DFSchema::try_from(schema).unwrap()
3495 }
3496
3497 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3499 let s = format!("{}", expr);
3500 s.contains(name)
3501 }
3502
3503 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3505 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3506 }
3507
3508 #[test]
3509 fn test_coercion_lb_eq_int64() {
3510 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3511 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3512 Box::new(col("lb")),
3513 Operator::Eq,
3514 Box::new(col("i")),
3515 ));
3516 let result = apply_type_coercion(&expr, &schema).unwrap();
3517 assert!(
3519 contains_udf(&result, "_cypher_equal"),
3520 "expected _cypher_equal, got: {result}"
3521 );
3522 }
3523
3524 #[test]
3525 fn test_coercion_lb_noteq_int64() {
3526 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3527 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3528 Box::new(col("lb")),
3529 Operator::NotEq,
3530 Box::new(col("i")),
3531 ));
3532 let result = apply_type_coercion(&expr, &schema).unwrap();
3533 assert!(contains_udf(&result, "_cypher_not_equal"));
3535 }
3536
3537 #[test]
3538 fn test_coercion_lb_lt_int64() {
3539 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3540 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3541 Box::new(col("lb")),
3542 Operator::Lt,
3543 Box::new(col("i")),
3544 ));
3545 let result = apply_type_coercion(&expr, &schema).unwrap();
3546 assert!(contains_udf(&result, "_cypher_lt"));
3548 }
3549
3550 #[test]
3551 fn test_coercion_lb_eq_float64() {
3552 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3553 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3554 Box::new(col("lb")),
3555 Operator::Eq,
3556 Box::new(col("f")),
3557 ));
3558 let result = apply_type_coercion(&expr, &schema).unwrap();
3559 assert!(contains_udf(&result, "_cypher_equal"));
3561 }
3562
3563 #[test]
3564 fn test_coercion_lb_eq_utf8() {
3565 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3566 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3567 Box::new(col("lb")),
3568 Operator::Eq,
3569 Box::new(col("s")),
3570 ));
3571 let result = apply_type_coercion(&expr, &schema).unwrap();
3572 assert!(contains_udf(&result, "_cypher_equal"));
3574 }
3575
3576 #[test]
3577 fn test_coercion_lb_eq_bool() {
3578 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3579 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3580 Box::new(col("lb")),
3581 Operator::Eq,
3582 Box::new(col("b")),
3583 ));
3584 let result = apply_type_coercion(&expr, &schema).unwrap();
3585 assert!(contains_udf(&result, "_cypher_equal"));
3587 }
3588
3589 #[test]
3590 fn test_coercion_int64_eq_lb() {
3591 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3593 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3594 Box::new(col("i")),
3595 Operator::Eq,
3596 Box::new(col("lb")),
3597 ));
3598 let result = apply_type_coercion(&expr, &schema).unwrap();
3599 assert!(contains_udf(&result, "_cypher_equal"));
3601 }
3602
3603 #[test]
3604 fn test_coercion_float64_gt_lb() {
3605 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3606 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3607 Box::new(col("f")),
3608 Operator::Gt,
3609 Box::new(col("lb")),
3610 ));
3611 let result = apply_type_coercion(&expr, &schema).unwrap();
3612 assert!(contains_udf(&result, "_cypher_gt"));
3614 }
3615
3616 #[test]
3617 fn test_coercion_both_lb_eq() {
3618 let schema = make_schema(&[
3619 ("lb1", DataType::LargeBinary),
3620 ("lb2", DataType::LargeBinary),
3621 ]);
3622 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3623 Box::new(col("lb1")),
3624 Operator::Eq,
3625 Box::new(col("lb2")),
3626 ));
3627 let result = apply_type_coercion(&expr, &schema).unwrap();
3628 assert!(contains_udf(&result, "_cypher_equal"));
3629 }
3630
#[test]
fn test_coercion_both_lb_lt() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // Both operands are LargeBinary: `<` must be rewritten into `_cypher_lt`.
    let schema = make_schema(&[
        ("lb1", DataType::LargeBinary),
        ("lb2", DataType::LargeBinary),
    ]);
    let less_than = BinaryExpr::new(Box::new(col("lb1")), Operator::Lt, Box::new(col("lb2")));

    let coerced = apply_type_coercion(&DfExpr::BinaryExpr(less_than), &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_lt"));
}
3645
#[test]
fn test_coercion_both_lb_noteq() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // Both operands are LargeBinary: `<>` must be rewritten into `_cypher_not_equal`.
    let schema = make_schema(&[
        ("lb1", DataType::LargeBinary),
        ("lb2", DataType::LargeBinary),
    ]);
    let (lhs, rhs) = (col("lb1"), col("lb2"));
    let inequality = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(lhs),
        Operator::NotEq,
        Box::new(rhs),
    ));

    let coerced = apply_type_coercion(&inequality, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_not_equal"));
}
3660
#[test]
fn test_coercion_lb_plus_int64() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // LargeBinary + Int64: addition is delegated to the `_cypher_add` UDF.
    let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(col("i")),
    ));

    let coerced = apply_type_coercion(&sum, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3672
#[test]
fn test_coercion_lb_minus_int64() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // LargeBinary - Int64: subtraction is delegated to the `_cypher_sub` UDF.
    let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
    let minuend = Box::new(col("lb"));
    let subtrahend = Box::new(col("i"));
    let difference = DfExpr::BinaryExpr(BinaryExpr::new(minuend, Operator::Minus, subtrahend));

    let coerced = apply_type_coercion(&difference, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_sub"));
}
3684
#[test]
fn test_coercion_lb_multiply_float64() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // LargeBinary * Float64: multiplication is delegated to the `_cypher_mul` UDF.
    let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
    let product = BinaryExpr::new(Box::new(col("lb")), Operator::Multiply, Box::new(col("f")));

    let coerced = apply_type_coercion(&DfExpr::BinaryExpr(product), &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_mul"));
}
3696
#[test]
fn test_coercion_int64_plus_lb() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // Same as LargeBinary + Int64 but with the operands swapped: the native
    // value on the left must still route to `_cypher_add`.
    let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
    let (lhs, rhs) = (Box::new(col("i")), Box::new(col("lb")));
    let sum = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Plus, rhs));

    let coerced = apply_type_coercion(&sum, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3708
#[test]
fn test_coercion_lb_plus_utf8() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // LargeBinary + Utf8 is also handled by `_cypher_add` (covers string concatenation).
    let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
    let concat = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(col("s")),
    ));

    let coerced = apply_type_coercion(&concat, &schema).unwrap();
    assert!(contains_udf(&coerced, "_cypher_add"));
}
3722
#[test]
fn test_coercion_and_null_bool() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // An untyped NULL on the left of AND is expected to be cast so the
    // conjunction type-checks; the top-level operator must remain AND.
    let schema = make_schema(&[("b", DataType::Boolean)]);
    let null_lhs = Box::new(lit(ScalarValue::Null));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(null_lhs, Operator::And, Box::new(col("b"))));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    let s = coerced.to_string();
    let mentions_cast = s.contains("CAST") || s.contains("Boolean");
    assert!(mentions_cast, "expected cast to Boolean, got: {s}");
    assert!(is_binary_op(&coerced, Operator::And));
}
3741
#[test]
fn test_coercion_bool_and_null() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // NULL on the right-hand side of AND: coercion must keep the AND node.
    let schema = make_schema(&[("b", DataType::Boolean)]);
    let conjunction = BinaryExpr::new(
        Box::new(col("b")),
        Operator::And,
        Box::new(lit(ScalarValue::Null)),
    );

    let coerced = apply_type_coercion(&DfExpr::BinaryExpr(conjunction), &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
}
3753
#[test]
fn test_coercion_or_null_bool() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // NULL OR bool: coercion must keep the OR node.
    let schema = make_schema(&[("b", DataType::Boolean)]);
    let (lhs, rhs) = (lit(ScalarValue::Null), col("b"));
    let disjunction = DfExpr::BinaryExpr(BinaryExpr::new(Box::new(lhs), Operator::Or, Box::new(rhs)));

    let coerced = apply_type_coercion(&disjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::Or));
}
3765
#[test]
fn test_coercion_null_and_null() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // NULL AND NULL against an empty schema: coercion succeeds and leaves an AND node.
    let schema = make_schema(&[]);
    let null_operand = || Box::new(lit(ScalarValue::Null));
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(null_operand(), Operator::And, null_operand()));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
}
3777
#[test]
fn test_coercion_bool_and_bool_noop() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // Two Boolean columns already type-check: coercion must not insert any CAST.
    let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);
    let conjunction = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("a")),
        Operator::And,
        Box::new(col("b")),
    ));

    let coerced = apply_type_coercion(&conjunction, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::And));
    let s = coerced.to_string();
    assert!(!s.contains("CAST"), "should not contain CAST: {s}");
}
3792
#[test]
fn test_coercion_case_when_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    // Coercion must recurse into the WHEN condition of a CASE with no base expression.
    let schema = make_schema(&[("lb", DataType::LargeBinary)]);
    let condition = BinaryExpr::new(Box::new(col("lb")), Operator::Eq, Box::new(lit(42_i64)));
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: vec![(Box::new(DfExpr::BinaryExpr(condition)), Box::new(lit("a")))],
        else_expr: Some(Box::new(lit("b"))),
    });

    let s = apply_type_coercion(&case_expr, &schema).unwrap().to_string();
    assert!(
        s.contains("_cypher_equal"),
        "CASE WHEN should have _cypher_equal, got: {s}"
    );
}
3815
#[test]
fn test_coercion_case_then_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    // Coercion must recurse into a THEN branch that does LargeBinary arithmetic.
    let schema = make_schema(&[("lb", DataType::LargeBinary)]);
    let then_branch = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(lit(1_i64)),
    ));
    let branches = vec![(Box::new(lit(true)), Box::new(then_branch))];
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: branches,
        else_expr: Some(Box::new(lit(0_i64))),
    });

    let s = apply_type_coercion(&case_expr, &schema).unwrap().to_string();
    assert!(
        s.contains("_cypher_add"),
        "CASE THEN should have _cypher_add, got: {s}"
    );
}
3837
#[test]
fn test_coercion_case_else_lb() {
    use datafusion::logical_expr::expr::{BinaryExpr, Case};

    // Coercion must recurse into the ELSE branch as well.
    let schema = make_schema(&[("lb", DataType::LargeBinary)]);
    let fallback = BinaryExpr::new(Box::new(col("lb")), Operator::Plus, Box::new(lit(2_i64)));
    let case_expr = DfExpr::Case(Case {
        expr: None,
        when_then_expr: vec![(Box::new(lit(true)), Box::new(lit(1_i64)))],
        else_expr: Some(Box::new(DfExpr::BinaryExpr(fallback))),
    });

    let s = apply_type_coercion(&case_expr, &schema).unwrap().to_string();
    assert!(
        s.contains("_cypher_add"),
        "CASE ELSE should have _cypher_add, got: {s}"
    );
}
3859
#[test]
fn test_coercion_int64_eq_int64_noop() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // Two native Int64 columns: the plain `=` must survive untouched, with no
    // CypherValue decoding injected.
    let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);
    let (lhs, rhs) = (Box::new(col("a")), Box::new(col("b")));
    let equality = DfExpr::BinaryExpr(BinaryExpr::new(lhs, Operator::Eq, rhs));

    let coerced = apply_type_coercion(&equality, &schema).unwrap();
    assert!(is_binary_op(&coerced, Operator::Eq));
    let s = coerced.to_string();
    assert!(
        !s.contains("_cypher_value"),
        "should not contain cypher_value decode: {s}"
    );
}
3876
#[test]
fn test_coercion_both_lb_plus() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // LargeBinary + LargeBinary: addition is delegated to `_cypher_add`.
    let schema = make_schema(&[
        ("lb1", DataType::LargeBinary),
        ("lb2", DataType::LargeBinary),
    ]);
    let sum = BinaryExpr::new(Box::new(col("lb1")), Operator::Plus, Box::new(col("lb2")));

    let result = apply_type_coercion(&DfExpr::BinaryExpr(sum), &schema).unwrap();
    assert!(
        contains_udf(&result, "_cypher_add"),
        "expected _cypher_add, got: {result}"
    );
}
3895
#[test]
fn test_coercion_native_list_plus_scalar() {
    use datafusion::logical_expr::expr::BinaryExpr;

    // `List<Int32> + Int32` is rewritten as a list append, not numeric addition.
    let item_field = Arc::new(Field::new("item", DataType::Int32, true));
    let schema = make_schema(&[
        ("lst", DataType::List(item_field)),
        ("i", DataType::Int32),
    ]);
    let append = DfExpr::BinaryExpr(BinaryExpr::new(
        Box::new(col("lst")),
        Operator::Plus,
        Box::new(col("i")),
    ));

    let result = apply_type_coercion(&append, &schema).unwrap();
    assert!(
        contains_udf(&result, "_cypher_list_append"),
        "expected _cypher_list_append, got: {result}"
    );
}
3917
#[test]
fn test_coercion_lb_plus_int64_unchanged() {
    // Regression guard: a LargeBinary (CypherValue) + Int64 addition must keep
    // using `_cypher_add`. It must not be picked up by the native-list `+`
    // rewrite (`_cypher_list_append`), which only applies when an operand is an
    // Arrow List. Without the second assertion this test would duplicate
    // `test_coercion_lb_plus_int64` and check nothing extra.
    let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
    let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
        Box::new(col("lb")),
        Operator::Plus,
        Box::new(col("i")),
    ));
    let result = apply_type_coercion(&expr, &schema).unwrap();
    assert!(
        contains_udf(&result, "_cypher_add"),
        "expected _cypher_add, got: {result}"
    );
    assert!(
        !contains_udf(&result, "_cypher_list_append"),
        "LargeBinary + Int64 must not be routed to _cypher_list_append, got: {result}"
    );
}
3933
#[test]
fn test_mixed_list_with_variables_compiles() {
    // A list containing a variable cannot be folded to a literal, so it is
    // built at runtime through the `_make_cypher_list` UDF.
    let elements = vec![
        Expr::Variable("n".to_string()),
        Expr::Literal(CypherLiteral::Integer(1)),
        Expr::Literal(CypherLiteral::String("hello".to_string())),
    ];
    let list_expr = Expr::List(elements);

    let s = cypher_expr_to_df(&list_expr, None).unwrap().to_string();
    assert!(
        s.contains("_make_cypher_list"),
        "expected _make_cypher_list UDF call, got: {s}"
    );
}
3953
#[test]
fn test_literal_only_mixed_list_uses_cv_fastpath() {
    // Every element is a literal (of mixed types), so the whole list is
    // expected to collapse into a single DataFusion literal (CypherValue fast path).
    let literals = vec![
        CypherLiteral::Integer(1),
        CypherLiteral::String("hi".to_string()),
        CypherLiteral::Bool(true),
    ];
    let expr = Expr::List(literals.into_iter().map(Expr::Literal).collect());

    let result = cypher_expr_to_df(&expr, None).unwrap();
    let is_literal = matches!(result, DfExpr::Literal(..));
    assert!(
        is_literal,
        "expected Literal (CypherValue fast path), got: {result}"
    );
}
3968
#[test]
fn test_in_mixed_literal_list_uses_cypher_in() {
    // `1 IN ['1', 2]`: the candidates have different types, so membership
    // is checked with the `_cypher_in` UDF.
    let candidates = Expr::List(vec![
        Expr::Literal(CypherLiteral::String("1".to_string())),
        Expr::Literal(CypherLiteral::Integer(2)),
    ]);
    let membership = Expr::In {
        expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
        list: Box::new(candidates),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    assert!(
        s.contains("_cypher_in"),
        "expected _cypher_in UDF for mixed-type IN list, got: {s}"
    );
}
3990
#[test]
fn test_in_homogeneous_literal_list_uses_cypher_in() {
    // `1 IN [2, 3]`: even an all-integer list goes through `_cypher_in`.
    let candidates = [2, 3]
        .iter()
        .map(|&v| Expr::Literal(CypherLiteral::Integer(v)))
        .collect();
    let membership = Expr::In {
        expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
        list: Box::new(Expr::List(candidates)),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    assert!(
        s.contains("_cypher_in"),
        "expected _cypher_in UDF for homogeneous IN list, got: {s}"
    );
}
4008
#[test]
fn test_in_list_with_variables_uses_make_cypher_list() {
    // `1 IN [x, 2]`: the list holds a variable, so it is assembled at runtime
    // with `_make_cypher_list` and then tested with `_cypher_in`.
    let candidates = Expr::List(vec![
        Expr::Variable("x".to_string()),
        Expr::Literal(CypherLiteral::Integer(2)),
    ]);
    let membership = Expr::In {
        expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
        list: Box::new(candidates),
    };

    let s = cypher_expr_to_df(&membership, None).unwrap().to_string();
    assert!(
        s.contains("_cypher_in"),
        "expected _cypher_in UDF, got: {s}"
    );
    assert!(
        s.contains("_make_cypher_list"),
        "expected _make_cypher_list for variable-containing list, got: {s}"
    );
}
4030
#[test]
fn test_property_on_graph_entity_uses_column() {
    // `n` is registered as a node, so `n.name` should resolve to the
    // flattened `n.name` column.
    let mut ctx = TranslationContext::new();
    ctx.variable_kinds.insert("n".to_string(), VariableKind::Node);

    let node = Box::new(Expr::Variable("n".to_string()));
    let property = Expr::Property(node, "name".to_string());
    let translated = cypher_expr_to_df(&property, Some(&ctx)).unwrap();

    let s = format!("{:?}", translated);
    let is_flat_column = s.contains("Column") && s.contains("n.name");
    assert!(
        is_flat_column,
        "expected flat column 'n.name' for graph entity, got: {s}"
    );
}
4053
#[test]
fn test_property_on_non_graph_var_uses_index() {
    // `map` is not registered as a graph variable, so property access is
    // expected to go through the `index` UDF instead of a column.
    let ctx = TranslationContext::new();
    let target = Box::new(Expr::Variable("map".to_string()));
    let property = Expr::Property(target, "name".to_string());

    let s = cypher_expr_to_df(&property, Some(&ctx)).unwrap().to_string();
    assert!(
        s.contains("index"),
        "expected index UDF for non-graph variable, got: {s}"
    );
}
4070
#[test]
fn test_value_to_scalar_non_empty_map_becomes_struct() {
    // A one-entry map `{k: 1}` converts to a Struct scalar.
    let entries = std::collections::HashMap::from([("k".to_string(), Value::Int(1))]);
    let converted = value_to_scalar(&Value::Map(entries)).unwrap();
    assert!(
        matches!(converted, ScalarValue::Struct(_)),
        "expected Struct scalar for map input"
    );
}
4081
#[test]
fn test_value_to_scalar_empty_map_becomes_struct() {
    // Even a map with no entries converts to a Struct scalar rather than a null.
    let empty = Value::Map(Default::default());
    let is_struct = matches!(value_to_scalar(&empty).unwrap(), ScalarValue::Struct(_));
    assert!(is_struct, "empty map should produce an empty Struct scalar");
}
4090
#[test]
fn test_value_to_scalar_null_is_untyped_null() {
    // Value::Null maps to DataFusion's untyped ScalarValue::Null, not a typed null.
    let is_untyped_null = matches!(value_to_scalar(&Value::Null).unwrap(), ScalarValue::Null);
    assert!(
        is_untyped_null,
        "expected untyped Null scalar for Value::Null"
    );
}
4099
#[test]
fn test_value_to_scalar_datetime_produces_struct() {
    // 441_763_200 s after the epoch is 1984-01-01T00:00:00Z, carried here with
    // a +01:00 offset and an explicit zone name.
    let datetime = Value::Temporal(TemporalValue::DateTime {
        nanos_since_epoch: 441763200000000000,
        offset_seconds: 3600,
        timezone_name: Some("Europe/Paris".to_string()),
    });

    let struct_arr = match value_to_scalar(&datetime).unwrap() {
        ScalarValue::Struct(arr) => arr,
        other => panic!(
            "Expected ScalarValue::Struct for DateTime, got {:?}",
            other
        ),
    };

    // Shape: a single row with three fields in a fixed order.
    assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
    assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");
    let fields = struct_arr.fields();
    for (idx, expected) in ["nanos_since_epoch", "offset_seconds", "timezone_name"]
        .iter()
        .enumerate()
    {
        assert_eq!(fields[idx].name(), expected);
    }

    // Values, together with the concrete Arrow array type backing each field.
    let nanos_arr = struct_arr
        .column(0)
        .as_any()
        .downcast_ref::<TimestampNanosecondArray>()
        .expect("Expected TimestampNanosecondArray for nanos field");
    assert_eq!(nanos_arr.value(0), 441763200000000000);

    let offset_arr = struct_arr
        .column(1)
        .as_any()
        .downcast_ref::<Int32Array>()
        .expect("Expected Int32Array for offset field");
    assert_eq!(offset_arr.value(0), 3600);

    let tz_arr = struct_arr
        .column(2)
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("Expected StringArray for timezone_name field");
    assert_eq!(tz_arr.value(0), "Europe/Paris");
}
4154
#[test]
fn test_value_to_scalar_datetime_with_null_timezone() {
    // 2024-01-01T00:00:00Z at a -05:00 offset with no zone name: the
    // `timezone_name` field still exists as a string column but holds a null.
    let datetime = Value::Temporal(TemporalValue::DateTime {
        nanos_since_epoch: 1704067200000000000,
        offset_seconds: -18000,
        timezone_name: None,
    });

    let struct_arr = match value_to_scalar(&datetime).unwrap() {
        ScalarValue::Struct(arr) => arr,
        _ => panic!("Expected ScalarValue::Struct for DateTime"),
    };
    assert_eq!(struct_arr.num_columns(), 3);

    let tz_arr = struct_arr
        .column(2)
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("Expected StringArray for timezone_name field");
    assert!(tz_arr.is_null(0), "expected null timezone_name");
}
4180
#[test]
fn test_value_to_scalar_time_produces_struct() {
    // 37_845 s after midnight is 10:30:45, here at a +01:00 offset.
    let time = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 37845000000000,
        offset_seconds: 3600,
    });

    let struct_arr = match value_to_scalar(&time).unwrap() {
        ScalarValue::Struct(arr) => arr,
        other => panic!("Expected ScalarValue::Struct for Time, got {:?}", other),
    };

    // Shape: a single row with two fields in a fixed order.
    assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
    assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");
    let fields = struct_arr.fields();
    for (idx, expected) in ["nanos_since_midnight", "offset_seconds"].iter().enumerate() {
        assert_eq!(fields[idx].name(), expected);
    }

    // Values, together with the concrete Arrow array type backing each field.
    let nanos_arr = struct_arr
        .column(0)
        .as_any()
        .downcast_ref::<Time64NanosecondArray>()
        .expect("Expected Time64NanosecondArray for nanos_since_midnight field");
    assert_eq!(nanos_arr.value(0), 37845000000000);

    let offset_arr = struct_arr
        .column(1)
        .as_any()
        .downcast_ref::<Int32Array>()
        .expect("Expected Int32Array for offset field");
    assert_eq!(offset_arr.value(0), 3600);
}
4220
#[test]
fn test_value_to_scalar_time_boundary_values() {
    // Both ends of the time-of-day range must survive conversion unchanged:
    // exactly midnight, and the last nanosecond before the next midnight
    // (24 * 3600 * 10^9 - 1). The original test covered only the lower bound.

    // Converts a Time value and returns the stored `nanos_since_midnight`
    // (column 0 of the resulting struct).
    let stored_nanos = |time: &Value| -> i64 {
        match value_to_scalar(time).unwrap() {
            ScalarValue::Struct(struct_arr) => struct_arr
                .column(0)
                .as_any()
                .downcast_ref::<Time64NanosecondArray>()
                .expect("Expected Time64NanosecondArray")
                .value(0),
            _ => panic!("Expected ScalarValue::Struct for Time"),
        }
    };

    let midnight = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 0,
        offset_seconds: 0,
    });
    assert_eq!(stored_nanos(&midnight), 0);

    let last_nano_of_day = Value::Temporal(TemporalValue::Time {
        nanos_since_midnight: 86_399_999_999_999,
        offset_seconds: 0,
    });
    assert_eq!(stored_nanos(&last_nano_of_day), 86_399_999_999_999);
}
4241 }
4242}