1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Column-name suffix joined to a variable name as `<var>._vid`; used for node variables.
const COL_VID: &str = "_vid";
/// Column-name suffix joined to a variable name as `<var>._eid`; used for edge variables.
const COL_EID: &str = "_eid";
/// Column-name suffix joined to a variable name as `<var>._labels`; read by label checks.
const COL_LABELS: &str = "_labels";
/// Column-name suffix joined to a variable name as `<var>._type`; read by edge type checks.
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Inserts a `:00` seconds component into a datetime string that has only
/// hours and minutes after the `T` separator.
///
/// Returns `None` when the input is shorter than 16 bytes, has no `T` at
/// byte offset 10, does not have `DD:DD` (ASCII digits) at offsets 11..16, or
/// already has a `:` at offset 16. Otherwise returns the first 16 bytes,
/// followed by `:00`, followed by whatever came after offset 16.
///
/// The first ten bytes are not validated.
pub fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();
    if bytes.len() < 16 || bytes[10] != b'T' {
        return None;
    }

    let has_hours_minutes = matches!(
        &bytes[11..16],
        [h1, h2, b':', m1, m2]
            if h1.is_ascii_digit()
                && h2.is_ascii_digit()
                && m1.is_ascii_digit()
                && m2.is_ascii_digit()
    );
    if !has_hours_minutes {
        return None;
    }

    // A ':' right after the minutes means seconds are already present.
    if bytes.get(16) == Some(&b':') {
        return None;
    }

    // Offset 16 is a char boundary: the byte at offset 15 is an ASCII digit.
    let (head, tail) = s.split_at(16);
    let mut out = String::with_capacity(s.len() + 3);
    out.push_str(head);
    out.push_str(":00");
    out.push_str(tail);
    Some(out)
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
/// Names of scalar functions that `is_cypher_list_expr` treats as producing a
/// Cypher list value.
const CYPHER_LIST_FUNCS: &[&str] = &[
    "_make_cypher_list",
    "_cypher_list_concat",
    "_cypher_list_append",
];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// The kind of graph value a query variable is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node variable.
    Node,
    /// A single-edge variable.
    Edge,
    /// A variable bound to a list of edges (variable-length relationship).
    EdgeList,
    /// A path variable.
    Path,
}

impl VariableKind {
    /// Picks the edge kind for a relationship pattern: [`VariableKind::EdgeList`]
    /// when the pattern is variable-length, [`VariableKind::Edge`] otherwise.
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
247
248pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
283 match expr {
284 Expr::PatternComprehension { .. } => Err(anyhow!(
285 "Pattern comprehensions require fallback executor (graph traversal)"
286 )),
287 Expr::Wildcard => Ok(DfExpr::Literal(
291 datafusion::common::ScalarValue::Int32(Some(1)),
292 None,
293 )),
294
295 Expr::Variable(name) => {
296 if let Some(ctx) = context
301 && ctx.variable_kinds.contains_key(name)
302 {
303 return Ok(DfExpr::Column(Column::from_name(name)));
304 }
305
306 if let Some(ctx) = context
312 && let Some(value) = ctx.outer_values.get(name)
313 {
314 return value_to_scalar(value).map(lit);
315 }
316
317 if let Some(ctx) = context
322 && let Some(value) = ctx.parameters.get(name)
323 {
324 match value {
327 Value::List(values) if name.ends_with("._vid") => {
328 let literals = values
330 .iter()
331 .map(|v| value_to_scalar(v).map(lit))
332 .collect::<Result<Vec<_>>>()?;
333 return Ok(DfExpr::InList(InList {
334 expr: Box::new(DfExpr::Column(Column::from_name(name))),
335 list: literals,
336 negated: false,
337 }));
338 }
339 other_value => return value_to_scalar(other_value).map(lit),
340 }
341 }
342
343 Ok(DfExpr::Column(Column::from_name(name)))
346 }
347
348 Expr::Property(base, prop) => translate_property_access(base, prop, context),
349
350 Expr::ArrayIndex { array, index } => {
351 if let Ok(var_name) = extract_variable_name(array)
354 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
355 {
356 let col_name = format!("{}.{}", var_name, prop_name);
357 return Ok(DfExpr::Column(Column::from_name(col_name)));
358 }
359
360 let array_expr = cypher_expr_to_df(array, context)?;
361 let index_expr = cypher_expr_to_df(index, context)?;
362
363 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
365 }
366
367 Expr::ArraySlice { array, start, end } => {
368 let array_expr = cypher_expr_to_df(array, context)?;
372
373 let start_expr = match start {
374 Some(s) => cypher_expr_to_df(s, context)?,
375 None => lit(0i64),
376 };
377
378 let end_expr = match end {
379 Some(e) => cypher_expr_to_df(e, context)?,
380 None => lit(i64::MAX),
381 };
382
383 Ok(dummy_udf_expr(
386 "_cypher_list_slice",
387 vec![array_expr, start_expr, end_expr],
388 ))
389 }
390
391 Expr::Parameter(name) => {
392 if let Some(ctx) = context
394 && let Some(value) = ctx.parameters.get(name)
395 {
396 return value_to_scalar(value).map(lit);
397 }
398 Err(anyhow!("Unresolved parameter: ${}", name))
399 }
400
401 Expr::Literal(value) => {
402 let scalar = cypher_literal_to_scalar(value)?;
403 Ok(lit(scalar))
404 }
405
406 Expr::List(items) => translate_list_literal(items, context),
407
408 Expr::Map(entries) => {
409 if entries.is_empty() {
410 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
412 Default::default(),
413 ));
414 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
415 }
416 let mut args = Vec::with_capacity(entries.len() * 2);
419 for (key, val_expr) in entries {
420 args.push(lit(key.clone()));
421 args.push(cypher_expr_to_df(val_expr, context)?);
422 }
423 Ok(datafusion::functions::expr_fn::named_struct(args))
424 }
425
426 Expr::IsNull(inner) => translate_null_check(inner, context, true),
427
428 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
429
430 Expr::IsUnique(_) => {
431 Err(anyhow!(
433 "IS UNIQUE can only be used in constraint definitions"
434 ))
435 }
436
437 Expr::FunctionCall {
438 name,
439 args,
440 distinct,
441 window_spec,
442 } => {
443 if window_spec.is_some() {
446 let col_name = expr.to_string_repr();
448 Ok(col(&col_name))
449 } else {
450 translate_function_call(name, args, *distinct, context)
451 }
452 }
453
454 Expr::In { expr, list } => translate_in_expression(expr, list, context),
455
456 Expr::BinaryOp { left, op, right } => {
457 let left_expr = cypher_expr_to_df(left, context)?;
458 let right_expr = cypher_expr_to_df(right, context)?;
459 translate_binary_op(left_expr, op, right_expr)
460 }
461
462 Expr::UnaryOp { op, expr: inner } => {
463 let inner_expr = cypher_expr_to_df(inner, context)?;
464 match op {
465 UnaryOp::Not => Ok(inner_expr.not()),
466 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
467 }
468 }
469
470 Expr::Case {
471 expr,
472 when_then,
473 else_expr,
474 } => translate_case_expression(expr, when_then, else_expr, context),
475
476 Expr::Reduce { .. } => Err(anyhow!(
477 "Reduce expressions not yet supported in DataFusion translation"
478 )),
479
480 Expr::Exists { .. } => Err(anyhow!(
481 "EXISTS subqueries are handled by the physical expression compiler, \
482 not the DataFusion logical expression translator"
483 )),
484
485 Expr::CountSubquery(_) => Err(anyhow!(
486 "Count subqueries not yet supported in DataFusion translation"
487 )),
488
489 Expr::CollectSubquery(_) => Err(anyhow!(
490 "COLLECT subqueries not yet supported in DataFusion translation"
491 )),
492
493 Expr::Quantifier { .. } => {
494 Err(anyhow!(
499 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
500 via CypherPhysicalExprCompiler"
501 ))
502 }
503
504 Expr::ListComprehension { .. } => {
505 Err(anyhow!(
517 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
518 ))
519 }
520
521 Expr::ValidAt { .. } => {
522 Err(anyhow!(
525 "VALID_AT expression should have been transformed to function call in planner"
526 ))
527 }
528
529 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
530
531 Expr::LabelCheck { expr, labels } => {
532 if let Expr::Variable(var) = expr.as_ref() {
533 let is_edge = context
535 .and_then(|ctx| ctx.variable_kinds.get(var))
536 .is_some_and(|k| matches!(k, VariableKind::Edge));
537
538 if is_edge {
539 if labels.len() > 1 {
543 Ok(lit(false))
544 } else {
545 let type_col =
546 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
547 Ok(DfExpr::Case(datafusion::logical_expr::Case {
549 expr: None,
550 when_then_expr: vec![(
551 Box::new(type_col.clone().is_null()),
552 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
553 )],
554 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
555 }))
556 }
557 } else {
558 let labels_col =
560 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
561 let checks = labels
562 .iter()
563 .map(|label| {
564 datafusion::functions_nested::expr_fn::array_has(
565 labels_col.clone(),
566 lit(label.clone()),
567 )
568 })
569 .reduce(|acc, check| acc.and(check));
570 Ok(DfExpr::Case(datafusion::logical_expr::Case {
572 expr: None,
573 when_then_expr: vec![(
574 Box::new(labels_col.is_null()),
575 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
576 )],
577 else_expr: Some(Box::new(checks.unwrap())),
578 }))
579 }
580 } else {
581 Err(anyhow!(
582 "LabelCheck on non-variable expression not yet supported in DataFusion"
583 ))
584 }
585 }
586 }
587}
588
/// Lookup tables consulted while translating Cypher expressions.
#[derive(Debug, Clone)]
pub struct TranslationContext {
    /// Parameter values by name; used to resolve `$param` expressions and,
    /// as a fallback, variables and `var.prop` column names.
    pub parameters: std::collections::HashMap<String, Value>,

    /// Values by variable name; checked before `parameters` when resolving a
    /// variable that has no entry in `variable_kinds`.
    pub outer_values: std::collections::HashMap<String, Value>,

    /// Label recorded per variable via `with_variable_label`.
    // NOTE(review): not read in this part of the file — confirm the consumer.
    pub variable_labels: std::collections::HashMap<String, String>,

    /// Kind (node, edge, edge list, path) of each known variable.
    pub variable_kinds: std::collections::HashMap<String, VariableKind>,

    // NOTE(review): presumably names of variables to treat as nodes — not
    // read in this part of the file; confirm against callers.
    pub node_variable_hints: Vec<String>,

    // NOTE(review): presumably names of edge variables touched by mutations —
    // not read in this part of the file; confirm against callers.
    pub mutation_edge_hints: Vec<String>,

    /// Timestamp captured when the context is created (`Default` uses
    /// `chrono::Utc::now()`).
    pub statement_time: chrono::DateTime<chrono::Utc>,
}
622
623impl Default for TranslationContext {
624 fn default() -> Self {
625 Self {
626 parameters: std::collections::HashMap::new(),
627 outer_values: std::collections::HashMap::new(),
628 variable_labels: std::collections::HashMap::new(),
629 variable_kinds: std::collections::HashMap::new(),
630 node_variable_hints: Vec::new(),
631 mutation_edge_hints: Vec::new(),
632 statement_time: chrono::Utc::now(),
633 }
634 }
635}
636
637impl TranslationContext {
638 pub fn new() -> Self {
640 Self::default()
641 }
642
643 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
645 self.parameters.insert(name.into(), value);
646 self
647 }
648
649 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
651 self.variable_labels.insert(var.into(), label.into());
652 self
653 }
654}
655
656fn extract_variable_name(expr: &Expr) -> Result<String> {
658 match expr {
659 Expr::Variable(name) => Ok(name.clone()),
660 Expr::Property(base, _) => extract_variable_name(base),
661 _ => Err(anyhow!(
662 "Cannot extract variable name from expression: {:?}",
663 expr
664 )),
665 }
666}
667
668fn translate_null_check(
670 inner: &Expr,
671 context: Option<&TranslationContext>,
672 is_null: bool,
673) -> Result<DfExpr> {
674 if let Expr::Variable(var) = inner
675 && let Some(ctx) = context
676 && let Some(kind) = ctx.variable_kinds.get(var)
677 {
678 let col_name = match kind {
679 VariableKind::Node => format!("{}.{}", var, COL_VID),
680 VariableKind::Edge => format!("{}.{}", var, COL_EID),
681 VariableKind::Path | VariableKind::EdgeList => var.clone(),
682 };
683 let col_expr = DfExpr::Column(Column::from_name(col_name));
684 return Ok(if is_null {
685 col_expr.is_null()
686 } else {
687 col_expr.is_not_null()
688 });
689 }
690
691 let inner_expr = cypher_expr_to_df(inner, context)?;
692 Ok(if is_null {
693 inner_expr.is_null()
694 } else {
695 inner_expr.is_not_null()
696 })
697}
698
699fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
704 if crate::datetime::is_duration_accessor(prop) {
705 Some(dummy_udf_expr(
706 "_duration_property",
707 vec![base_expr, lit(prop.to_string())],
708 ))
709 } else if crate::datetime::is_temporal_accessor(prop) {
710 Some(dummy_udf_expr(
711 "_temporal_property",
712 vec![base_expr, lit(prop.to_string())],
713 ))
714 } else {
715 None
716 }
717}
718
719fn translate_property_access(
721 base: &Expr,
722 prop: &str,
723 context: Option<&TranslationContext>,
724) -> Result<DfExpr> {
725 if let Ok(var_name) = extract_variable_name(base) {
726 let is_graph_entity = context
727 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
728 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
729
730 if !is_graph_entity
731 && let Some(expr) =
732 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
733 {
734 return Ok(expr);
735 }
736
737 let col_name = format!("{}.{}", var_name, prop);
738
739 if let Some(ctx) = context
742 && let Some(value) = ctx.parameters.get(&col_name)
743 {
744 match value {
747 Value::List(values) if col_name.ends_with("._vid") => {
748 let literals = values
749 .iter()
750 .map(|v| value_to_scalar(v).map(lit))
751 .collect::<Result<Vec<_>>>()?;
752 return Ok(DfExpr::InList(InList {
753 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
754 list: literals,
755 negated: false,
756 }));
757 }
758 other_value => return value_to_scalar(other_value).map(lit),
759 }
760 }
761
762 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
765 let base_expr = cypher_expr_to_df(base, context)?;
766 return Ok(dummy_udf_expr(
767 "index",
768 vec![base_expr, lit(prop.to_string())],
769 ));
770 }
771
772 if is_graph_entity {
773 Ok(DfExpr::Column(Column::from_name(col_name)))
774 } else {
775 let base_expr = DfExpr::Column(Column::from_name(var_name));
776 Ok(dummy_udf_expr(
777 "index",
778 vec![base_expr, lit(prop.to_string())],
779 ))
780 }
781 } else {
782 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
784 return Ok(expr);
785 }
786
787 if let Expr::Parameter(param_name) = base {
789 if let Some(ctx) = context
790 && let Some(value) = ctx.parameters.get(param_name)
791 {
792 if let Value::Map(map) = value {
793 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
794 return value_to_scalar(&extracted).map(lit);
795 }
796 return Ok(lit(ScalarValue::Null));
797 }
798 return Err(anyhow!("Unresolved parameter: ${}", param_name));
799 }
800
801 let base_expr = cypher_expr_to_df(base, context)?;
802 Ok(dummy_udf_expr(
803 "index",
804 vec![base_expr, lit(prop.to_string())],
805 ))
806 }
807}
808
/// Translates a Cypher list literal.
///
/// Lists that contain nested lists, maps, more than one scalar category,
/// known variables, or temporal constructor calls are emitted either as a
/// single encoded `LargeBinary` literal (when every item is a JSON-convertible
/// literal) or as a `_make_cypher_list(...)` call. All other lists become a
/// `make_array(...)` call, with mixed integer/float literals cast to
/// `Float64` and NULL literals cast to the first non-null literal's type.
///
/// # Errors
///
/// Propagates any error from translating an item.
fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
    // First pass: classify the syntactic shape of each item.
    let mut has_string = false;
    let mut has_bool = false;
    let mut has_list = false;
    let mut has_map = false;
    let mut has_numeric = false;
    let mut has_graph_entity = false;
    let mut has_temporal = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
                has_numeric = true
            }
            Expr::Literal(CypherLiteral::String(_)) => has_string = true,
            Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
            Expr::List(_) => has_list = true,
            Expr::Map(_) => has_map = true,
            // Any variable with a recorded kind counts as a graph entity here.
            Expr::Variable(name)
                if context
                    .and_then(|ctx| ctx.variable_kinds.get(name))
                    .is_some() =>
            {
                has_graph_entity = true;
            }
            Expr::FunctionCall { name, .. } => {
                // Temporal constructors are matched case-insensitively by name.
                let upper = name.to_uppercase();
                if matches!(
                    upper.as_str(),
                    "DATE"
                        | "TIME"
                        | "LOCALTIME"
                        | "LOCALDATETIME"
                        | "DATETIME"
                        | "DURATION"
                        | "DATE.TRUNCATE"
                        | "TIME.TRUNCATE"
                        | "DATETIME.TRUNCATE"
                        | "LOCALDATETIME.TRUNCATE"
                        | "LOCALTIME.TRUNCATE"
                ) {
                    has_temporal = true;
                }
            }
            _ => {}
        }
    }

    // Number of distinct scalar categories seen among literal items.
    let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;

    if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
        // All-literal lists are encoded once, at translation time.
        if let Some(json_array) = try_items_to_json(items) {
            let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
            let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
            return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
        }
        // Otherwise the list is built from the translated items by a function call.
        let df_args: Vec<DfExpr> = items
            .iter()
            .map(|item| cypher_expr_to_df(item, context))
            .collect::<Result<_>>()?;
        return Ok(dummy_udf_expr("_make_cypher_list", df_args));
    }

    // Second pass: translate items and note int/float literal mixing.
    let mut df_args = Vec::with_capacity(items.len());
    let mut has_float = false;
    let mut has_int = false;
    let mut has_other = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
            Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
            _ => has_other = true,
        }
        df_args.push(cypher_expr_to_df(item, context)?);
    }

    if df_args.is_empty() {
        // Empty list: a List scalar with a Null element type.
        let empty_arr =
            ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
        Ok(lit(ScalarValue::List(empty_arr)))
    } else if has_float && has_int && !has_other {
        // Only int and float literals: promote every element to Float64.
        let promoted_args = df_args
            .into_iter()
            .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
            .collect();
        Ok(datafusion::functions_nested::expr_fn::make_array(
            promoted_args,
        ))
    } else {
        // Type of the first literal whose data type is not Null, if any.
        let non_null_type = df_args.iter().find_map(|e| {
            if let DfExpr::Literal(sv, _) = e {
                let dt = sv.data_type();
                if dt != datafusion::arrow::datatypes::DataType::Null {
                    return Some(dt);
                }
            }
            None
        });
        if let Some(ref target_type) = non_null_type {
            // Cast NULL literals to that type; leave other items untouched.
            let coerced = df_args
                .into_iter()
                .map(|e| {
                    if matches!(&e, DfExpr::Literal(sv, _) if sv.data_type() == datafusion::arrow::datatypes::DataType::Null)
                    {
                        cast_expr(e, target_type.clone())
                    } else {
                        e
                    }
                })
                .collect();
            Ok(datafusion::functions_nested::expr_fn::make_array(coerced))
        } else {
            Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
        }
    }
}
945
946fn translate_in_expression(
948 expr: &Expr,
949 list: &Expr,
950 context: Option<&TranslationContext>,
951) -> Result<DfExpr> {
952 let left_expr = if let Expr::Variable(var) = expr
957 && let Some(ctx) = context
958 && let Some(kind) = ctx.variable_kinds.get(var)
959 {
960 match kind {
961 VariableKind::Node | VariableKind::Edge => {
962 let id_col = match kind {
963 VariableKind::Node => COL_VID,
964 VariableKind::Edge => COL_EID,
965 _ => unreachable!(),
966 };
967 cast_expr(
968 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
969 datafusion::arrow::datatypes::DataType::Int64,
970 )
971 }
972 _ => cypher_expr_to_df(expr, context)?,
973 }
974 } else {
975 cypher_expr_to_df(expr, context)?
976 };
977
978 if let Expr::List(items) = list {
983 if let Some(json_array) = try_items_to_json(items) {
984 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
986 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
987 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
988 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
989 } else {
990 let expanded: Vec<DfExpr> = items
992 .iter()
993 .map(|item| cypher_expr_to_df(item, context))
994 .collect::<Result<Vec<_>>>()?;
995 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
996 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
997 }
998 } else {
999 let right_expr = cypher_expr_to_df(list, context)?;
1000
1001 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
1006 return Ok(lit(ScalarValue::Boolean(None)));
1007 }
1008
1009 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
1010 }
1011}
1012
1013fn translate_case_expression(
1015 operand: &Option<Box<Expr>>,
1016 when_then: &[(Expr, Expr)],
1017 else_expr: &Option<Box<Expr>>,
1018 context: Option<&TranslationContext>,
1019) -> Result<DfExpr> {
1020 let mut case_builder = if let Some(match_expr) = operand {
1021 let match_df = cypher_expr_to_df(match_expr, context)?;
1022 datafusion::logical_expr::case(match_df)
1023 } else {
1024 datafusion::logical_expr::when(
1025 cypher_expr_to_df(&when_then[0].0, context)?,
1026 cypher_expr_to_df(&when_then[0].1, context)?,
1027 )
1028 };
1029
1030 let start_idx = if operand.is_some() { 0 } else { 1 };
1031 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1032 let when_df = cypher_expr_to_df(when_expr, context)?;
1033 let then_df = cypher_expr_to_df(then_expr, context)?;
1034 case_builder = case_builder.when(when_df, then_df);
1035 }
1036
1037 if let Some(else_e) = else_expr {
1038 let else_df = cypher_expr_to_df(else_e, context)?;
1039 Ok(case_builder.otherwise(else_df)?)
1040 } else {
1041 Ok(case_builder.end()?)
1042 }
1043}
1044
1045fn translate_map_projection(
1047 base: &Expr,
1048 items: &[MapProjectionItem],
1049 context: Option<&TranslationContext>,
1050) -> Result<DfExpr> {
1051 let mut args = Vec::new();
1052 for item in items {
1053 match item {
1054 MapProjectionItem::Property(prop) => {
1055 args.push(lit(prop.clone()));
1056 let prop_expr = cypher_expr_to_df(
1057 &Expr::Property(Box::new(base.clone()), prop.clone()),
1058 context,
1059 )?;
1060 args.push(prop_expr);
1061 }
1062 MapProjectionItem::LiteralEntry(key, expr) => {
1063 args.push(lit(key.clone()));
1064 args.push(cypher_expr_to_df(expr, context)?);
1065 }
1066 MapProjectionItem::Variable(var) => {
1067 args.push(lit(var.clone()));
1068 args.push(DfExpr::Column(Column::from_name(var)));
1069 }
1070 MapProjectionItem::AllProperties => {
1071 args.push(lit("__all__"));
1072 args.push(cypher_expr_to_df(base, context)?);
1073 }
1074 }
1075 }
1076 Ok(dummy_udf_expr("_map_project", args))
1077}
1078
1079fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1082 match expr {
1083 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1084 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1085 Expr::Literal(CypherLiteral::Integer(i)) => {
1086 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1087 }
1088 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1089 .map(serde_json::Value::Number)
1090 .or(Some(serde_json::Value::Null)),
1091 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1092 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1093 Expr::Map(entries) => {
1094 let mut map = serde_json::Map::new();
1095 for (k, v) in entries {
1096 map.insert(k.clone(), try_expr_to_json(v)?);
1097 }
1098 Some(serde_json::Value::Object(map))
1099 }
1100 _ => None,
1101 }
1102}
1103
1104fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1106 items.iter().map(try_expr_to_json).collect()
1107}
1108
1109fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1111 match lit {
1112 CypherLiteral::Null => Ok(ScalarValue::Null),
1113 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1114 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1115 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1116 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1117 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1118 }
1119}
1120
1121fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1123 match value {
1124 Value::Null => Ok(ScalarValue::Null),
1125 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1126 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1127 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1128 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1129 Value::List(items) => {
1130 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1132 let scalars = scalars?;
1133
1134 let data_type = infer_common_scalar_type(&scalars);
1136
1137 let typed_scalars: Vec<ScalarValue> = scalars
1139 .into_iter()
1140 .map(|s| {
1141 if matches!(s, ScalarValue::Null) {
1142 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1143 }
1144
1145 match (s, &data_type) {
1146 (
1147 ScalarValue::Int64(Some(v)),
1148 datafusion::arrow::datatypes::DataType::Float64,
1149 ) => ScalarValue::Float64(Some(v as f64)),
1150 (
1154 s @ ScalarValue::LargeBinary(_),
1155 datafusion::arrow::datatypes::DataType::LargeBinary,
1156 ) => s,
1157 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1158 let s_str = s.to_string();
1160 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1161 }
1162 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1163 if matches!(s, ScalarValue::Utf8(_)) {
1165 s
1166 } else {
1167 ScalarValue::Utf8(Some(s.to_string()))
1168 }
1169 }
1170 (s, _) => s,
1171 }
1172 })
1173 .collect();
1174
1175 if typed_scalars.is_empty() {
1177 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1178 &[],
1179 &data_type,
1180 )))
1181 } else {
1182 Ok(ScalarValue::List(ScalarValue::new_list(
1183 &typed_scalars,
1184 &data_type,
1185 true,
1186 )))
1187 }
1188 }
1189 Value::Map(map) => {
1190 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1193 entries.sort_by_key(|(k, _)| *k);
1194
1195 if entries.is_empty() {
1196 return Ok(ScalarValue::Struct(Arc::new(
1197 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1198 )));
1199 }
1200
1201 let mut fields_arrays = Vec::with_capacity(entries.len());
1202
1203 for (k, v) in entries {
1204 let scalar = value_to_scalar(v)?;
1205 let dt = scalar.data_type();
1206 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1207 let array = scalar.to_array()?;
1208 fields_arrays.push((field, array));
1209 }
1210
1211 Ok(ScalarValue::Struct(Arc::new(
1212 datafusion::arrow::array::StructArray::from(fields_arrays),
1213 )))
1214 }
1215 Value::Temporal(tv) => {
1216 use uni_common::TemporalValue;
1217 match tv {
1218 TemporalValue::Date { days_since_epoch } => {
1219 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1220 }
1221 TemporalValue::LocalTime {
1222 nanos_since_midnight,
1223 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1224 TemporalValue::Time {
1225 nanos_since_midnight,
1226 offset_seconds,
1227 } => {
1228 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1230 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1231
1232 let nanos_arr =
1233 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1234 as ArrayRef;
1235 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1236
1237 let fields = Fields::from(vec![
1238 Field::new(
1239 "nanos_since_midnight",
1240 ArrowDataType::Time64(TimeUnit::Nanosecond),
1241 true,
1242 ),
1243 Field::new("offset_seconds", ArrowDataType::Int32, true),
1244 ]);
1245
1246 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1247 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1248 }
1249 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1250 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1251 ),
1252 TemporalValue::DateTime {
1253 nanos_since_epoch,
1254 offset_seconds,
1255 timezone_name,
1256 } => {
1257 use arrow::array::{
1259 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1260 };
1261 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1262
1263 let nanos_arr =
1264 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1265 as ArrayRef;
1266 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1267 let tz_arr =
1268 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1269
1270 let fields = Fields::from(vec![
1271 Field::new(
1272 "nanos_since_epoch",
1273 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1274 true,
1275 ),
1276 Field::new("offset_seconds", ArrowDataType::Int32, true),
1277 Field::new("timezone_name", ArrowDataType::Utf8, true),
1278 ]);
1279
1280 let struct_arr =
1281 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1282 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1283 }
1284 TemporalValue::Duration {
1285 months,
1286 days,
1287 nanos,
1288 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1289 arrow::datatypes::IntervalMonthDayNano {
1290 months: *months as i32,
1291 days: *days as i32,
1292 nanoseconds: *nanos,
1293 },
1294 ))),
1295 TemporalValue::Btic { lo, hi, meta } => {
1296 let btic = uni_btic::Btic::new(*lo, *hi, *meta)
1297 .map_err(|e| anyhow::anyhow!("invalid BTIC value: {}", e))?;
1298 let packed = uni_btic::encode::encode(&btic);
1299 Ok(ScalarValue::FixedSizeBinary(24, Some(packed.to_vec())))
1300 }
1301 }
1302 }
1303 Value::Vector(v) => {
1304 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1306 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1307 }
1308 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1309 other => {
1311 let json_val: serde_json::Value = other.clone().into();
1312 let json_str = serde_json::to_string(&json_val)
1313 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1314 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1315 }
1316 }
1317}
1318
1319fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1321 match op {
1322 BinaryOp::Eq => Ok(left.eq(right)),
1326 BinaryOp::NotEq => Ok(left.not_eq(right)),
1327 BinaryOp::Lt => Ok(left.lt(right)),
1328 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1329 BinaryOp::Gt => Ok(left.gt(right)),
1330 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1331
1332 BinaryOp::And => Ok(left.and(right)),
1334 BinaryOp::Or => Ok(left.or(right)),
1335 BinaryOp::Xor => {
1336 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1338 }
1339
1340 BinaryOp::Add => {
1346 if is_list_expr(&left) || is_list_expr(&right) {
1347 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1348 } else {
1349 Ok(left + right)
1350 }
1351 }
1352 BinaryOp::Sub => Ok(left - right),
1353 BinaryOp::Mul => Ok(left * right),
1354 BinaryOp::Div => Ok(left / right),
1355 BinaryOp::Mod => Ok(left % right),
1356 BinaryOp::Pow => {
1357 let left_f = crate::df_udfs::cypher_to_float64_expr(left);
1364 let right_f = crate::df_udfs::cypher_to_float64_expr(right);
1365 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1366 }
1367
1368 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1370 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1371 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1372
1373 BinaryOp::Regex => {
1374 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1375 }
1376
1377 BinaryOp::ApproxEq => Err(anyhow!(
1378 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1379 )),
1380 }
1381}
1382
/// Arity guard for the `translate_*_function` helpers that return
/// `Option<Result<DfExpr>>`.
///
/// Expands to a check that returns `Some(Err(..))` from the enclosing
/// function when too few arguments were supplied. The literal `1` form goes
/// through `require_arg`; any other count goes through `require_args`.
macro_rules! check_args {
    (1, $df_args:expr, $name:expr) => {
        match require_arg($df_args, $name) {
            Ok(()) => {}
            Err(arity_err) => return Some(Err(arity_err)),
        }
    };
    ($n:expr, $df_args:expr, $name:expr) => {
        match require_args($df_args, $n, $name) {
            Ok(()) => {}
            Err(arity_err) => return Some(Err(arity_err)),
        }
    };
}
1399
1400fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1403 if df_args.len() < count {
1404 let noun = if count == 1 { "argument" } else { "arguments" };
1405 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1406 }
1407 Ok(())
1408}
1409
1410fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1412 require_args(df_args, 1, func_name)
1413}
1414
1415fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1417 df_args[0].clone()
1418}
1419
1420pub fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1422 DfExpr::Cast(datafusion::logical_expr::Cast {
1423 expr: Box::new(expr),
1424 data_type,
1425 })
1426}
1427
1428pub fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1434 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1435 Arc::new(crate::df_udfs::create_cypher_list_to_cv_udf()),
1436 vec![expr],
1437 ))
1438}
1439
1440pub fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1444 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1445 Arc::new(crate::df_udfs::create_cypher_scalar_to_cv_udf()),
1446 vec![expr],
1447 ))
1448}
1449
1450fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1452 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1453 Box::new(left),
1454 op,
1455 Box::new(right),
1456 ))
1457}
1458
1459pub fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1464 use datafusion::logical_expr::Operator;
1465 match op {
1466 Operator::Eq => Some("_cypher_equal"),
1467 Operator::NotEq => Some("_cypher_not_equal"),
1468 Operator::Lt => Some("_cypher_lt"),
1469 Operator::LtEq => Some("_cypher_lt_eq"),
1470 Operator::Gt => Some("_cypher_gt"),
1471 Operator::GtEq => Some("_cypher_gt_eq"),
1472 _ => None,
1473 }
1474}
1475
1476fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1478 use datafusion::logical_expr::Operator;
1479 match op {
1480 Operator::Plus => Some("_cypher_add"),
1481 Operator::Minus => Some("_cypher_sub"),
1482 Operator::Multiply => Some("_cypher_mul"),
1483 Operator::Divide => Some("_cypher_div"),
1484 Operator::Modulo => Some("_cypher_mod"),
1485 _ => None,
1486 }
1487}
1488
1489fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1494where
1495 F: FnOnce(DfExpr) -> DfExpr,
1496{
1497 require_arg(df_args, func_name)?;
1498 Ok(math_fn(crate::df_udfs::cypher_to_float64_expr(first_arg(
1506 df_args,
1507 ))))
1508}
1509
1510fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1512 if distinct {
1513 expr.distinct()
1514 .build()
1515 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1516 } else {
1517 Ok(expr)
1518 }
1519}
1520
1521fn translate_aggregate_function(
1523 name_upper: &str,
1524 df_args: &[DfExpr],
1525 distinct: bool,
1526) -> Option<Result<DfExpr>> {
1527 match name_upper {
1528 "COUNT" => {
1529 let expr = if df_args.is_empty() {
1530 datafusion::functions_aggregate::count::count(lit(1i64))
1531 } else {
1532 datafusion::functions_aggregate::count::count(first_arg(df_args))
1533 };
1534 Some(maybe_distinct(expr, distinct, "COUNT"))
1535 }
1536 "SUM" => {
1537 check_args!(1, df_args, "SUM");
1538 let udaf = Arc::new(crate::df_udfs::create_cypher_sum_udaf());
1539 Some(maybe_distinct(
1540 udaf.call(vec![first_arg(df_args)]),
1541 distinct,
1542 "SUM",
1543 ))
1544 }
1545 "AVG" => {
1546 check_args!(1, df_args, "AVG");
1547 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1548 let expr = datafusion::functions_aggregate::average::avg(coerced);
1549 Some(maybe_distinct(expr, distinct, "AVG"))
1550 }
1551 "MIN" => {
1552 check_args!(1, df_args, "MIN");
1553 let udaf = Arc::new(crate::df_udfs::create_cypher_min_udaf());
1554 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1555 }
1556 "MAX" => {
1557 check_args!(1, df_args, "MAX");
1558 let udaf = Arc::new(crate::df_udfs::create_cypher_max_udaf());
1559 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1560 }
1561 "PERCENTILEDISC" => {
1562 if df_args.len() != 2 {
1563 return Some(Err(anyhow!(
1564 "percentileDisc() requires exactly 2 arguments"
1565 )));
1566 }
1567 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1568 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_disc_udaf());
1569 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1570 }
1571 "PERCENTILECONT" => {
1572 if df_args.len() != 2 {
1573 return Some(Err(anyhow!(
1574 "percentileCont() requires exactly 2 arguments"
1575 )));
1576 }
1577 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1578 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_cont_udaf());
1579 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1580 }
1581 "COLLECT" => {
1582 check_args!(1, df_args, "COLLECT");
1583 Some(Ok(crate::df_udfs::create_cypher_collect_expr(
1584 first_arg(df_args),
1585 distinct,
1586 )))
1587 }
1588 "BTIC_MIN" => {
1590 check_args!(1, df_args, "btic_min");
1591 let udaf = Arc::new(crate::df_udfs::create_btic_min_udaf());
1592 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1593 }
1594 "BTIC_MAX" => {
1595 check_args!(1, df_args, "btic_max");
1596 let udaf = Arc::new(crate::df_udfs::create_btic_max_udaf());
1597 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1598 }
1599 "BTIC_SPAN_AGG" => {
1600 check_args!(1, df_args, "btic_span_agg");
1601 let udaf = Arc::new(crate::df_udfs::create_btic_span_agg_udaf());
1602 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1603 }
1604 "BTIC_COUNT_AT" => {
1605 if df_args.len() != 2 {
1606 return Some(Err(anyhow!("btic_count_at requires 2 arguments")));
1607 }
1608 let udaf = Arc::new(crate::df_udfs::create_btic_count_at_udaf());
1609 Some(Ok(udaf.call(df_args.to_vec())))
1610 }
1611 _ => None,
1612 }
1613}
1614
1615fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1618 match name_upper {
1619 "TOSTRING" => {
1620 check_args!(1, df_args, "toString");
1621 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1622 }
1623 "TOINTEGER" | "TOINT" => {
1624 check_args!(1, df_args, "toInteger");
1625 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1626 }
1627 "TOFLOAT" => {
1628 check_args!(1, df_args, "toFloat");
1629 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1630 }
1631 "TOBOOLEAN" | "TOBOOL" => {
1632 check_args!(1, df_args, "toBoolean");
1633 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1634 }
1635 "UPPER" | "TOUPPER" => {
1636 check_args!(1, df_args, "upper");
1637 Some(Ok(datafusion::functions::string::expr_fn::upper(
1638 first_arg(df_args),
1639 )))
1640 }
1641 "LOWER" | "TOLOWER" => {
1642 check_args!(1, df_args, "lower");
1643 Some(Ok(datafusion::functions::string::expr_fn::lower(
1644 first_arg(df_args),
1645 )))
1646 }
1647 "SUBSTRING" => {
1648 check_args!(2, df_args, "substring");
1649 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1650 }
1651 "TRIM" => {
1652 check_args!(1, df_args, "TRIM");
1653 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1654 first_arg(df_args),
1655 ])))
1656 }
1657 "LTRIM" => {
1658 check_args!(1, df_args, "LTRIM");
1659 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1660 first_arg(df_args),
1661 ])))
1662 }
1663 "RTRIM" => {
1664 check_args!(1, df_args, "RTRIM");
1665 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1666 first_arg(df_args),
1667 ])))
1668 }
1669 "LEFT" => {
1670 check_args!(2, df_args, "left");
1671 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1672 df_args[0].clone(),
1673 df_args[1].clone(),
1674 )))
1675 }
1676 "RIGHT" => {
1677 check_args!(2, df_args, "right");
1678 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1679 df_args[0].clone(),
1680 df_args[1].clone(),
1681 )))
1682 }
1683 "REPLACE" => {
1684 check_args!(3, df_args, "replace");
1685 Some(Ok(datafusion::functions::string::expr_fn::replace(
1686 df_args[0].clone(),
1687 df_args[1].clone(),
1688 df_args[2].clone(),
1689 )))
1690 }
1691 "REVERSE" => {
1692 check_args!(1, df_args, "reverse");
1693 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1694 }
1695 "SPLIT" => {
1696 check_args!(2, df_args, "split");
1697 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1698 }
1699 "SIZE" | "LENGTH" => {
1700 check_args!(1, df_args, name_upper);
1701 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1702 }
1703 _ => None,
1704 }
1705}
1706
1707fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1710 use datafusion::functions::math::expr_fn;
1711
1712 let unary_f64 =
1714 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1715
1716 match name_upper {
1717 "ABS" => {
1718 check_args!(1, df_args, "abs");
1719 Some(Ok(crate::df_udfs::cypher_abs_expr(first_arg(df_args))))
1723 }
1724 "CEIL" | "CEILING" => {
1725 check_args!(1, df_args, "ceil");
1726 Some(Ok(expr_fn::ceil(crate::df_udfs::cypher_to_float64_expr(
1730 first_arg(df_args),
1731 ))))
1732 }
1733 "FLOOR" => {
1734 check_args!(1, df_args, "floor");
1735 Some(Ok(expr_fn::floor(crate::df_udfs::cypher_to_float64_expr(
1736 first_arg(df_args),
1737 ))))
1738 }
1739 "ROUND" => {
1740 check_args!(1, df_args, "round");
1741 let args = if df_args.len() == 1 {
1742 vec![crate::df_udfs::cypher_to_float64_expr(first_arg(df_args))]
1743 } else {
1744 vec![
1745 crate::df_udfs::cypher_to_float64_expr(df_args[0].clone()),
1746 df_args[1].clone(),
1747 ]
1748 };
1749 Some(Ok(expr_fn::round(args)))
1750 }
1751 "SIGN" => {
1752 check_args!(1, df_args, "sign");
1753 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1754 Some(Ok(expr_fn::signum(coerced)))
1755 }
1756 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1757 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1758 "LOG10" => unary_f64("log10", expr_fn::log10),
1759 "EXP" => unary_f64("exp", expr_fn::exp),
1760 "SIN" => unary_f64("sin", expr_fn::sin),
1761 "COS" => unary_f64("cos", expr_fn::cos),
1762 "TAN" => unary_f64("tan", expr_fn::tan),
1763 "ASIN" => unary_f64("asin", expr_fn::asin),
1764 "ACOS" => unary_f64("acos", expr_fn::acos),
1765 "ATAN" => unary_f64("atan", expr_fn::atan),
1766 "ATAN2" => {
1767 check_args!(2, df_args, "atan2");
1768 Some(Ok(expr_fn::atan2(
1771 crate::df_udfs::cypher_to_float64_expr(df_args[0].clone()),
1772 crate::df_udfs::cypher_to_float64_expr(df_args[1].clone()),
1773 )))
1774 }
1775 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1776 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1777 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1778 _ => None,
1779 }
1780}
1781
1782fn translate_temporal_function(
1785 name_upper: &str,
1786 name: &str,
1787 df_args: &[DfExpr],
1788 context: Option<&TranslationContext>,
1789) -> Option<Result<DfExpr>> {
1790 match name_upper {
1791 "DATE"
1792 | "TIME"
1793 | "LOCALTIME"
1794 | "LOCALDATETIME"
1795 | "DATETIME"
1796 | "DURATION"
1797 | "YEAR"
1798 | "MONTH"
1799 | "DAY"
1800 | "HOUR"
1801 | "MINUTE"
1802 | "SECOND"
1803 | "DURATION.BETWEEN"
1804 | "DURATION.INMONTHS"
1805 | "DURATION.INDAYS"
1806 | "DURATION.INSECONDS"
1807 | "DATETIME.FROMEPOCH"
1808 | "DATETIME.FROMEPOCHMILLIS"
1809 | "DATE.TRUNCATE"
1810 | "TIME.TRUNCATE"
1811 | "DATETIME.TRUNCATE"
1812 | "LOCALDATETIME.TRUNCATE"
1813 | "LOCALTIME.TRUNCATE"
1814 | "DATETIME.TRANSACTION"
1815 | "DATETIME.STATEMENT"
1816 | "DATETIME.REALTIME"
1817 | "DATE.TRANSACTION"
1818 | "DATE.STATEMENT"
1819 | "DATE.REALTIME"
1820 | "TIME.TRANSACTION"
1821 | "TIME.STATEMENT"
1822 | "TIME.REALTIME"
1823 | "LOCALTIME.TRANSACTION"
1824 | "LOCALTIME.STATEMENT"
1825 | "LOCALTIME.REALTIME"
1826 | "LOCALDATETIME.TRANSACTION"
1827 | "LOCALDATETIME.STATEMENT"
1828 | "LOCALDATETIME.REALTIME" => {
1829 let stmt_time = context.map(|c| c.statement_time);
1833 if can_constant_fold(name_upper, df_args)
1834 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1835 {
1836 return Some(Ok(folded));
1837 }
1838 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1839 }
1840 _ => None,
1841 }
1842}
1843
1844fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1846 if name.contains("REALTIME") {
1848 return false;
1849 }
1850 if args.is_empty() {
1858 return matches!(
1859 name,
1860 "DATE"
1861 | "TIME"
1862 | "LOCALTIME"
1863 | "LOCALDATETIME"
1864 | "DATETIME"
1865 | "DATE.STATEMENT"
1866 | "TIME.STATEMENT"
1867 | "LOCALTIME.STATEMENT"
1868 | "LOCALDATETIME.STATEMENT"
1869 | "DATETIME.STATEMENT"
1870 | "DATE.TRANSACTION"
1871 | "TIME.TRANSACTION"
1872 | "LOCALTIME.TRANSACTION"
1873 | "LOCALDATETIME.TRANSACTION"
1874 | "DATETIME.TRANSACTION"
1875 );
1876 }
1877 args.iter().all(is_constant_expr)
1879}
1880
1881fn is_constant_expr(expr: &DfExpr) -> bool {
1883 match expr {
1884 DfExpr::Literal(_, _) => true,
1885 DfExpr::ScalarFunction(func) => {
1886 func.args.iter().all(is_constant_expr)
1888 }
1889 _ => false,
1890 }
1891}
1892
1893fn try_constant_fold_temporal(
1899 name: &str,
1900 args: &[DfExpr],
1901 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1902) -> Result<DfExpr> {
1903 let val_args: Vec<Value> = args
1905 .iter()
1906 .map(extract_constant_value)
1907 .collect::<Result<_>>()?;
1908
1909 let result = if val_args.is_empty() {
1911 if let Some(frozen) = stmt_time {
1912 crate::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1913 } else {
1914 crate::datetime::eval_datetime_function(name, &val_args)?
1915 }
1916 } else {
1917 crate::datetime::eval_datetime_function(name, &val_args)?
1918 };
1919
1920 let scalar = value_to_scalar(&result)?;
1922 Ok(DfExpr::Literal(scalar, None))
1923}
1924
1925fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1927 use crate::df_udfs::scalar_to_value;
1928 match expr {
1929 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1930 DfExpr::ScalarFunction(func) => {
1931 let mut map = std::collections::HashMap::new();
1934 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1935 for chunk in pairs.chunks(2) {
1936 if let [key_expr, val_expr] = chunk {
1937 let key = match key_expr {
1939 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1940 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1941 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1942 };
1943 let val = extract_constant_value(val_expr)?;
1944 map.insert(key, val);
1945 } else {
1946 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1947 }
1948 }
1949 Ok(Value::Map(map))
1950 }
1951 _ => Err(anyhow::anyhow!(
1952 "Cannot extract constant value from expression"
1953 )),
1954 }
1955}
1956
1957fn translate_btic_function(
1960 name_upper: &str,
1961 name: &str,
1962 df_args: &[DfExpr],
1963) -> Option<Result<DfExpr>> {
1964 let canonical_upper = name_upper
1970 .strip_prefix("BTIC.")
1971 .map(|rest| format!("BTIC_{rest}"));
1972 let lookup_upper = canonical_upper.as_deref().unwrap_or(name_upper);
1973 if crate::expr_eval::is_btic_function(lookup_upper) {
1974 let canonical_name = name
1976 .strip_prefix("btic.")
1977 .or_else(|| name.strip_prefix("BTIC."))
1978 .map(|rest| format!("btic_{}", rest.to_lowercase()));
1979 let udf_name = canonical_name.as_deref().unwrap_or(name);
1980 Some(Ok(dummy_udf_expr(udf_name, df_args.to_vec())))
1981 } else {
1982 None
1983 }
1984}
1985
1986fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1989 match name_upper {
1990 "HEAD" => {
1991 check_args!(1, df_args, "head");
1992 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1993 }
1994 "LAST" => {
1995 check_args!(1, df_args, "last");
1996 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1997 }
1998 "TAIL" => {
1999 check_args!(1, df_args, "tail");
2000 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
2001 }
2002 "RANGE" => {
2003 check_args!(2, df_args, "range");
2004 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
2005 }
2006 _ => None,
2007 }
2008}
2009
2010fn translate_graph_function(
2013 name_upper: &str,
2014 name: &str,
2015 df_args: &[DfExpr],
2016 args: &[Expr],
2017 context: Option<&TranslationContext>,
2018) -> Option<Result<DfExpr>> {
2019 match name_upper {
2020 "ID" => {
2021 if let Some(Expr::Variable(var)) = args.first() {
2024 let is_edge = context.is_some_and(|ctx| {
2025 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
2026 || ctx.mutation_edge_hints.iter().any(|h| h == var)
2027 });
2028 let id_suffix = if is_edge { COL_EID } else { COL_VID };
2029 Some(Ok(DfExpr::Column(Column::from_name(format!(
2030 "{}.{}",
2031 var, id_suffix
2032 )))))
2033 } else {
2034 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
2035 }
2036 }
2037 "CREATED_AT" | "UPDATED_AT" => {
2038 if let Some(Expr::Variable(var)) = args.first() {
2042 let suffix = if name_upper == "CREATED_AT" {
2043 "_created_at"
2044 } else {
2045 "_updated_at"
2046 };
2047 Some(Ok(DfExpr::Column(Column::from_name(format!(
2048 "{}.{}",
2049 var, suffix
2050 )))))
2051 } else {
2052 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2053 }
2054 }
2055 "LABELS" | "KEYS" => {
2056 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2061 }
2062 "TYPE" => {
2063 if let Some(Expr::Variable(var)) = args.first()
2067 && let Some(ctx) = context
2068 && let Some(label) = ctx.variable_labels.get(var)
2069 {
2070 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
2073 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
2074 expr: None,
2075 when_then_expr: vec![(
2076 Box::new(eid_col.is_not_null()),
2077 Box::new(lit(label.clone())),
2078 )],
2079 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
2080 })));
2081 }
2082 if let Some(Expr::Variable(var)) = args.first()
2086 && context
2087 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
2088 {
2089 return Some(Ok(DfExpr::Column(Column::from_name(format!(
2090 "{}.{}",
2091 var, COL_TYPE
2092 )))));
2093 }
2094 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
2095 }
2096 "PROPERTIES" => {
2097 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2100 }
2101 "UNI.TEMPORAL.VALIDAT" => {
2102 if let (
2105 Some(Expr::Variable(var)),
2106 Some(Expr::Literal(CypherLiteral::String(start_prop))),
2107 Some(Expr::Literal(CypherLiteral::String(end_prop))),
2108 Some(ts_expr),
2109 ) = (args.first(), args.get(1), args.get(2), args.get(3))
2110 {
2111 let start_col =
2112 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
2113 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
2114 let ts = match cypher_expr_to_df(ts_expr, context) {
2115 Ok(ts) => ts,
2116 Err(e) => return Some(Err(e)),
2117 };
2118
2119 let start_check = start_col.lt_eq(ts.clone());
2121 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
2123 let end_after = end_col.gt(ts);
2124 let end_check = end_null.or(end_after);
2125
2126 Some(Ok(start_check.and(end_check)))
2127 } else {
2128 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2130 }
2131 }
2132 "STARTNODE" | "ENDNODE" => {
2133 let mut udf_args = df_args.to_vec();
2136 let mut seen = std::collections::HashSet::new();
2137 if let Some(ctx) = context {
2138 for (var, kind) in &ctx.variable_kinds {
2140 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2141 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2142 }
2143 }
2144 for var in &ctx.node_variable_hints {
2147 if seen.insert(var.clone()) {
2148 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2149 }
2150 }
2151 }
2152 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2153 }
2154 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2155 "HASLABEL" => {
2156 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2157 return Some(Err(e));
2158 }
2159 if let Some(Expr::Variable(var)) = args.first() {
2161 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2162 let labels_col =
2164 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2165 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2166 labels_col,
2167 lit(label.clone()),
2168 )))
2169 } else {
2170 Some(Err(anyhow::anyhow!(
2172 "hasLabel requires string literal as second argument for DataFusion translation"
2173 )))
2174 }
2175 } else {
2176 Some(Err(anyhow::anyhow!(
2178 "hasLabel requires variable as first argument for DataFusion translation"
2179 )))
2180 }
2181 }
2182 _ => None,
2183 }
2184}
2185
2186fn translate_function_call(
2188 name: &str,
2189 args: &[Expr],
2190 distinct: bool,
2191 context: Option<&TranslationContext>,
2192) -> Result<DfExpr> {
2193 let df_args: Vec<DfExpr> = args
2194 .iter()
2195 .map(|arg| cypher_expr_to_df(arg, context))
2196 .collect::<Result<Vec<_>>>()?;
2197
2198 let name_upper = name.to_uppercase();
2199
2200 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2204 return result;
2205 }
2206
2207 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2208 return result;
2209 }
2210
2211 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2212 return result;
2213 }
2214
2215 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2216 return result;
2217 }
2218
2219 if let Some(result) = translate_btic_function(&name_upper, name, &df_args) {
2220 return result;
2221 }
2222
2223 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2224 return result;
2225 }
2226
2227 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2228 return result;
2229 }
2230
2231 match name_upper.as_str() {
2233 "COALESCE" => {
2234 require_arg(&df_args, "coalesce")?;
2235 if df_args.len() == 1 {
2240 return Ok(df_args.into_iter().next().unwrap());
2241 }
2242 let n = df_args.len();
2243 let (init, last) = df_args.split_at(n - 1);
2244 let mut builder = datafusion::logical_expr::conditional_expressions::CaseBuilder::new(
2245 None,
2246 vec![],
2247 vec![],
2248 None,
2249 );
2250 for arg in init {
2251 builder.when(arg.clone().is_not_null(), arg.clone());
2252 }
2253 return Ok(builder.otherwise(last[0].clone())?);
2254 }
2255 "NULLIF" => {
2256 require_args(&df_args, 2, "nullif")?;
2257 return Ok(datafusion::functions::expr_fn::nullif(
2258 df_args[0].clone(),
2259 df_args[1].clone(),
2260 ));
2261 }
2262 _ => {}
2263 }
2264
2265 match name_upper.as_str() {
2267 "SIMILAR_TO" | "VECTOR_SIMILARITY" | "SPARSE_SIMILAR_TO" => {
2268 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2269 }
2270 _ => {}
2271 }
2272
2273 Ok(dummy_udf_expr(name, df_args))
2275}
2276
/// Placeholder scalar UDF used while building logical expressions.
///
/// It carries only a name, a signature and a declared return type; invoking
/// it always fails with a "not registered" plan error (see the
/// `ScalarUDFImpl` implementation).
#[derive(Debug)]
struct DummyUdf {
    /// Function name reported through `ScalarUDFImpl::name`.
    name: String,
    /// Signature reported to DataFusion; `new` sets it to variadic-any.
    signature: datafusion::logical_expr::Signature,
    /// Return type chosen from the name by `dummy_udf_return_type`.
    ret_type: datafusion::arrow::datatypes::DataType,
}
2287
2288impl DummyUdf {
2289 fn new(name: String) -> Self {
2290 let ret_type = dummy_udf_return_type(&name);
2291 Self {
2292 name,
2293 signature: datafusion::logical_expr::Signature::variadic_any(
2294 datafusion::logical_expr::Volatility::Immutable,
2295 ),
2296 ret_type,
2297 }
2298 }
2299}
2300
2301fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2314 use datafusion::arrow::datatypes::DataType;
2315 match name {
2316 "_cypher_add"
2320 | "_cypher_sub"
2321 | "_cypher_mul"
2322 | "_cypher_div"
2323 | "_cypher_mod"
2324 | "_cypher_list_concat"
2325 | "_cypher_list_append"
2326 | "_make_cypher_list"
2327 | "_map_project"
2328 | "_cypher_list_to_cv"
2329 | "_cypher_tail" => DataType::LargeBinary,
2330 _ => DataType::Null,
2334 }
2335}
2336
2337impl PartialEq for DummyUdf {
2338 fn eq(&self, other: &Self) -> bool {
2339 self.name == other.name
2340 }
2341}
2342
// Marker impl: equality is by name only (see `PartialEq`).
impl Eq for DummyUdf {}
2344
2345impl Hash for DummyUdf {
2346 fn hash<H: Hasher>(&self, state: &mut H) {
2347 self.name.hash(state);
2348 }
2349}
2350
2351pub fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2353 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2354 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2355 DummyUdf::new(name.to_lowercase()),
2356 )),
2357 args,
2358 })
2359}
2360
2361impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2362 fn as_any(&self) -> &dyn std::any::Any {
2363 self
2364 }
2365
2366 fn name(&self) -> &str {
2367 &self.name
2368 }
2369
2370 fn signature(&self) -> &datafusion::logical_expr::Signature {
2371 &self.signature
2372 }
2373
2374 fn return_type(
2375 &self,
2376 arg_types: &[datafusion::arrow::datatypes::DataType],
2377 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2378 match self.name.as_str() {
2383 "_cypher_add" | "_cypher_sub" | "_cypher_mul" | "_cypher_div" | "_cypher_mod" => {
2384 Ok(crate::df_udfs::cypher_arith_return_type(arg_types))
2385 }
2386 _ => Ok(self.ret_type.clone()),
2390 }
2391 }
2392
2393 fn invoke_with_args(
2394 &self,
2395 _args: ScalarFunctionArgs,
2396 ) -> datafusion::error::Result<ColumnarValue> {
2397 Err(datafusion::error::DataFusionError::Plan(format!(
2398 "UDF '{}' is not registered. Register it via SessionContext.",
2399 self.name
2400 )))
2401 }
2402}
2403
2404pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2408 let mut properties = Vec::new();
2409 collect_properties_recursive(expr, &mut properties);
2410 properties.sort();
2411 properties.dedup();
2412 properties
2413}
2414
/// Walks `expr` and appends every `(variable, property)` reference it finds
/// to `properties`.
///
/// Duplicates are not removed here; `collect_properties` sorts and dedups
/// the final list.
fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
    match expr {
        // Pattern comprehensions are not descended into.
        Expr::PatternComprehension { .. } => {}
        Expr::Property(base, prop) => {
            // `var.prop`: record the pair when the base resolves to a variable name.
            if let Ok(var_name) = extract_variable_name(base) {
                properties.push((var_name, prop.clone()));
            }
            collect_properties_recursive(base, properties);
        }
        Expr::ArrayIndex { array, index } => {
            // `var['prop']` with a string-literal index counts as a property access.
            if let Ok(var_name) = extract_variable_name(array)
                && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
            {
                properties.push((var_name, prop_name.clone()));
            }
            collect_properties_recursive(array, properties);
            collect_properties_recursive(index, properties);
        }
        Expr::ArraySlice { array, start, end } => {
            collect_properties_recursive(array, properties);
            if let Some(s) = start {
                collect_properties_recursive(s, properties);
            }
            if let Some(e) = end {
                collect_properties_recursive(e, properties);
            }
        }
        Expr::List(items) => {
            for item in items {
                collect_properties_recursive(item, properties);
            }
        }
        Expr::Map(entries) => {
            // Only map values are expressions; keys are skipped.
            for (_, value) in entries {
                collect_properties_recursive(value, properties);
            }
        }
        Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
            collect_properties_recursive(inner, properties);
        }
        Expr::FunctionCall { args, .. } => {
            for arg in args {
                collect_properties_recursive(arg, properties);
            }
        }
        Expr::BinaryOp { left, right, .. } => {
            collect_properties_recursive(left, properties);
            collect_properties_recursive(right, properties);
        }
        Expr::UnaryOp { expr, .. } => {
            collect_properties_recursive(expr, properties);
        }
        Expr::Case {
            expr,
            when_then,
            else_expr,
        } => {
            if let Some(e) = expr {
                collect_properties_recursive(e, properties);
            }
            for (when_e, then_e) in when_then {
                collect_properties_recursive(when_e, properties);
                collect_properties_recursive(then_e, properties);
            }
            if let Some(e) = else_expr {
                collect_properties_recursive(e, properties);
            }
        }
        Expr::Reduce {
            init, list, expr, ..
        } => {
            collect_properties_recursive(init, properties);
            collect_properties_recursive(list, properties);
            collect_properties_recursive(expr, properties);
        }
        Expr::Quantifier {
            list, predicate, ..
        } => {
            collect_properties_recursive(list, properties);
            collect_properties_recursive(predicate, properties);
        }
        Expr::ListComprehension {
            list,
            where_clause,
            map_expr,
            ..
        } => {
            collect_properties_recursive(list, properties);
            if let Some(filter) = where_clause {
                collect_properties_recursive(filter, properties);
            }
            collect_properties_recursive(map_expr, properties);
        }
        Expr::In { expr, list } => {
            collect_properties_recursive(expr, properties);
            collect_properties_recursive(list, properties);
        }
        Expr::ValidAt {
            entity, timestamp, ..
        } => {
            collect_properties_recursive(entity, properties);
            collect_properties_recursive(timestamp, properties);
        }
        Expr::MapProjection { base, items } => {
            collect_properties_recursive(base, properties);
            for item in items {
                match item {
                    // `base { .prop }` reads a single property of the base variable.
                    uni_cypher::ast::MapProjectionItem::Property(prop) => {
                        if let Ok(var_name) = extract_variable_name(base) {
                            properties.push((var_name, prop.clone()));
                        }
                    }
                    // `base { .* }` is recorded under the wildcard property name "*".
                    uni_cypher::ast::MapProjectionItem::AllProperties => {
                        if let Ok(var_name) = extract_variable_name(base) {
                            properties.push((var_name, "*".to_string()));
                        }
                    }
                    uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
                        collect_properties_recursive(expr, properties);
                    }
                    uni_cypher::ast::MapProjectionItem::Variable(_) => {}
                }
            }
        }
        Expr::LabelCheck { expr, .. } => {
            collect_properties_recursive(expr, properties);
        }
        // Leaves: nothing to collect.
        Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
        // Subqueries are not descended into.
        Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
    }
}
2547
2548pub fn wider_numeric_type(
2555 a: &datafusion::arrow::datatypes::DataType,
2556 b: &datafusion::arrow::datatypes::DataType,
2557) -> datafusion::arrow::datatypes::DataType {
2558 use datafusion::arrow::datatypes::DataType;
2559
2560 fn numeric_rank(dt: &DataType) -> u8 {
2561 match dt {
2562 DataType::Int8 | DataType::UInt8 => 1,
2563 DataType::Int16 | DataType::UInt16 => 2,
2564 DataType::Int32 | DataType::UInt32 => 3,
2565 DataType::Int64 | DataType::UInt64 => 4,
2566 DataType::Float16 => 5,
2567 DataType::Float32 => 6,
2568 DataType::Float64 => 7,
2569 _ => 0,
2570 }
2571 }
2572
2573 if numeric_rank(a) >= numeric_rank(b) {
2574 a.clone()
2575 } else {
2576 b.clone()
2577 }
2578}
2579
2580fn resolve_column_type_fallback(
2586 expr: &DfExpr,
2587 schema: &datafusion::common::DFSchema,
2588) -> Option<datafusion::arrow::datatypes::DataType> {
2589 if let DfExpr::Column(col) = expr {
2590 let col_name = &col.name;
2591 for (_, field) in schema.iter() {
2593 if field.name() == col_name {
2594 return Some(field.data_type().clone());
2595 }
2596 }
2597 }
2598 None
2599}
2600
2601fn contains_division(expr: &DfExpr) -> bool {
2604 match expr {
2605 DfExpr::BinaryExpr(b) => {
2606 b.op == datafusion::logical_expr::Operator::Divide
2607 || contains_division(&b.left)
2608 || contains_division(&b.right)
2609 }
2610 DfExpr::Cast(c) => contains_division(&c.expr),
2611 DfExpr::TryCast(c) => contains_division(&c.expr),
2612 _ => false,
2613 }
2614}
2615
2616pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2622 use datafusion::arrow::datatypes::DataType;
2623 use datafusion::logical_expr::ExprSchemable;
2624
2625 match expr {
2626 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2627 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2628 DfExpr::Case(case) => coerce_case_expr(case, schema),
2629 DfExpr::InList(in_list) => {
2630 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2631 let coerced_list = in_list
2632 .list
2633 .iter()
2634 .map(|e| apply_type_coercion(e, schema))
2635 .collect::<Result<Vec<_>>>()?;
2636 let expr_type = coerced_expr
2637 .get_type(schema)
2638 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2639 crate::cypher_type_coerce::build_cypher_in_list(
2640 coerced_expr,
2641 &expr_type,
2642 coerced_list,
2643 in_list.negated,
2644 schema,
2645 )
2646 }
2647 DfExpr::Not(inner) => {
2648 let coerced_inner = apply_type_coercion(inner, schema)?;
2649 let inner_type = coerced_inner.get_type(schema).ok();
2650 let final_inner = if inner_type
2651 .as_ref()
2652 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2653 {
2654 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2655 } else if inner_type
2656 .as_ref()
2657 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2658 {
2659 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2660 } else {
2661 coerced_inner
2662 };
2663 Ok(DfExpr::Not(Box::new(final_inner)))
2664 }
2665 DfExpr::IsNull(inner) => {
2666 let coerced_inner = apply_type_coercion(inner, schema)?;
2667 Ok(coerced_inner.is_null())
2668 }
2669 DfExpr::IsNotNull(inner) => {
2670 let coerced_inner = apply_type_coercion(inner, schema)?;
2671 Ok(coerced_inner.is_not_null())
2672 }
2673 DfExpr::Negative(inner) => {
2674 let coerced_inner = apply_type_coercion(inner, schema)?;
2675 let inner_type = coerced_inner.get_type(schema).ok();
2676 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2677 Ok(dummy_udf_expr(
2678 "_cypher_mul",
2679 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2680 ))
2681 } else {
2682 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2683 }
2684 }
2685 DfExpr::Cast(cast) => {
2686 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2687 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2688 Box::new(coerced_inner),
2689 cast.data_type.clone(),
2690 )))
2691 }
2692 DfExpr::TryCast(cast) => {
2693 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2694 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2695 Box::new(coerced_inner),
2696 cast.data_type.clone(),
2697 )))
2698 }
2699 DfExpr::Alias(alias) => {
2700 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2701 Ok(coerced_inner.alias(alias.name.clone()))
2702 }
2703 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2704 _ => Ok(expr.clone()),
2705 }
2706}
2707
2708fn coerce_logical_operands(
2710 left: DfExpr,
2711 right: DfExpr,
2712 op: datafusion::logical_expr::Operator,
2713 schema: &datafusion::common::DFSchema,
2714) -> Option<DfExpr> {
2715 use datafusion::arrow::datatypes::DataType;
2716 use datafusion::logical_expr::ExprSchemable;
2717
2718 if !matches!(
2719 op,
2720 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2721 ) {
2722 return None;
2723 }
2724 let left_type = left.get_type(schema).ok();
2725 let right_type = right.get_type(schema).ok();
2726 let left_needs_cast = left_type
2727 .as_ref()
2728 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2729 let right_needs_cast = right_type
2730 .as_ref()
2731 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2732 let left_is_lb = left_type
2733 .as_ref()
2734 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2735 let right_is_lb = right_type
2736 .as_ref()
2737 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2738 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2739 return None;
2740 }
2741 let coerced_left = if left_is_lb {
2742 dummy_udf_expr("_cv_to_bool", vec![left])
2743 } else if left_needs_cast {
2744 datafusion::logical_expr::cast(left, DataType::Boolean)
2745 } else {
2746 left
2747 };
2748 let coerced_right = if right_is_lb {
2749 dummy_udf_expr("_cv_to_bool", vec![right])
2750 } else if right_needs_cast {
2751 datafusion::logical_expr::cast(right, DataType::Boolean)
2752 } else {
2753 right
2754 };
2755 Some(binary_expr(coerced_left, op, coerced_right))
2756}
2757
2758#[expect(
2761 clippy::too_many_arguments,
2762 reason = "Binary coercion needs all context"
2763)]
2764fn coerce_large_binary_ops(
2765 left: &DfExpr,
2766 right: &DfExpr,
2767 left_type: &datafusion::arrow::datatypes::DataType,
2768 right_type: &datafusion::arrow::datatypes::DataType,
2769 left_is_null: bool,
2770 op: datafusion::logical_expr::Operator,
2771 is_comparison: bool,
2772 is_arithmetic: bool,
2773) -> Option<Result<DfExpr>> {
2774 use datafusion::arrow::datatypes::DataType;
2775 use datafusion::logical_expr::Operator;
2776
2777 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2778 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2779
2780 if op == Operator::Plus {
2781 if left_is_lb && right_is_lb {
2782 return Some(Ok(dummy_udf_expr(
2783 "_cypher_add",
2784 vec![left.clone(), right.clone()],
2785 )));
2786 }
2787 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2788 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2789 if left_is_native_list && right_is_native_list {
2790 return Some(Ok(dummy_udf_expr(
2791 "_cypher_list_concat",
2792 vec![left.clone(), right.clone()],
2793 )));
2794 }
2795 if left_is_native_list || right_is_native_list {
2796 return Some(Ok(dummy_udf_expr(
2797 "_cypher_list_append",
2798 vec![left.clone(), right.clone()],
2799 )));
2800 }
2801 }
2802
2803 if (left_is_lb || right_is_lb) && is_comparison {
2804 if let Some(udf_name) = comparison_udf_name(op) {
2805 return Some(Ok(dummy_udf_expr(
2806 udf_name,
2807 vec![left.clone(), right.clone()],
2808 )));
2809 }
2810 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2811 }
2812
2813 if (left_is_lb || right_is_lb) && is_arithmetic {
2814 let udf_name =
2815 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2816 return Some(Ok(dummy_udf_expr(
2817 udf_name,
2818 vec![left.clone(), right.clone()],
2819 )));
2820 }
2821
2822 None
2823}
2824
2825fn coerce_temporal_comparisons(
2827 left: DfExpr,
2828 right: DfExpr,
2829 left_type: &datafusion::arrow::datatypes::DataType,
2830 right_type: &datafusion::arrow::datatypes::DataType,
2831 op: datafusion::logical_expr::Operator,
2832 is_comparison: bool,
2833) -> Option<DfExpr> {
2834 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2835 use datafusion::logical_expr::Operator;
2836
2837 if !is_comparison {
2838 return None;
2839 }
2840
2841 if uni_common::core::schema::is_datetime_struct(left_type)
2843 && uni_common::core::schema::is_datetime_struct(right_type)
2844 {
2845 return Some(binary_expr(
2846 extract_datetime_nanos(left),
2847 op,
2848 extract_datetime_nanos(right),
2849 ));
2850 }
2851
2852 if uni_common::core::schema::is_time_struct(left_type)
2854 && uni_common::core::schema::is_time_struct(right_type)
2855 {
2856 return Some(binary_expr(
2857 extract_time_nanos(left),
2858 op,
2859 extract_time_nanos(right),
2860 ));
2861 }
2862
2863 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2865 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2866
2867 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2868 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2869 {
2870 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2871 extract_datetime_nanos(left)
2872 } else {
2873 left
2874 };
2875 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2876 extract_datetime_nanos(right)
2877 } else {
2878 right
2879 };
2880 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2881 return Some(binary_expr(
2882 cast_expr(left_nanos, ts_type.clone()),
2883 op,
2884 cast_expr(right_nanos, ts_type),
2885 ));
2886 }
2887
2888 let left_is_duration = matches!(left_type, DataType::Interval(_));
2892 let right_is_duration = matches!(right_type, DataType::Interval(_));
2893 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2894 || uni_common::core::schema::is_time_struct(left_type)
2895 || matches!(
2896 left_type,
2897 DataType::Timestamp(_, _)
2898 | DataType::Date32
2899 | DataType::Date64
2900 | DataType::Time32(_)
2901 | DataType::Time64(_)
2902 );
2903 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2904 || uni_common::core::schema::is_time_struct(right_type)
2905 || matches!(
2906 right_type,
2907 DataType::Timestamp(_, _)
2908 | DataType::Date32
2909 | DataType::Date64
2910 | DataType::Time32(_)
2911 | DataType::Time64(_)
2912 );
2913
2914 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2915 {
2916 return Some(match op {
2917 Operator::Eq => lit(false),
2918 Operator::NotEq => lit(true),
2919 _ => lit(ScalarValue::Boolean(None)),
2920 });
2921 }
2922
2923 None
2924}
2925
2926fn coerce_mismatched_types(
2929 left: DfExpr,
2930 right: DfExpr,
2931 left_type: &datafusion::arrow::datatypes::DataType,
2932 right_type: &datafusion::arrow::datatypes::DataType,
2933 op: datafusion::logical_expr::Operator,
2934 is_comparison: bool,
2935) -> Option<Result<DfExpr>> {
2936 use datafusion::arrow::datatypes::DataType;
2937 use datafusion::logical_expr::Operator;
2938
2939 if left_type == right_type {
2940 return None;
2941 }
2942
2943 if left_type.is_numeric() && right_type.is_numeric() {
2945 if left_type == &DataType::Int64
2946 && right_type == &DataType::UInt64
2947 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2948 {
2949 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2950 return Some(Ok(binary_expr(coerced_left, op, right)));
2951 }
2952 if left_type == &DataType::UInt64
2953 && right_type == &DataType::Int64
2954 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2955 {
2956 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2957 return Some(Ok(binary_expr(left, op, coerced_right)));
2958 }
2959 let target = wider_numeric_type(left_type, right_type);
2960 let coerced_left = if *left_type != target {
2961 datafusion::logical_expr::cast(left, target.clone())
2962 } else {
2963 left
2964 };
2965 let coerced_right = if *right_type != target {
2966 datafusion::logical_expr::cast(right, target)
2967 } else {
2968 right
2969 };
2970 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2971 }
2972
2973 if is_comparison {
2975 match (left_type, right_type) {
2976 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2977 let right = normalize_datetime_literal(right);
2978 return Some(Ok(binary_expr(
2979 left,
2980 op,
2981 datafusion::logical_expr::cast(right, ts.clone()),
2982 )));
2983 }
2984 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2985 let left = normalize_datetime_literal(left);
2986 return Some(Ok(binary_expr(
2987 datafusion::logical_expr::cast(left, ts.clone()),
2988 op,
2989 right,
2990 )));
2991 }
2992 _ => {}
2993 }
2994 }
2995
2996 if is_comparison
2998 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2999 {
3000 let l_inner = l_field.data_type();
3001 let r_inner = r_field.data_type();
3002 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
3003 let target_inner = wider_numeric_type(l_inner, r_inner);
3004 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
3005 "item",
3006 target_inner,
3007 true,
3008 )));
3009 return Some(Ok(binary_expr(
3010 datafusion::logical_expr::cast(left, target_type.clone()),
3011 op,
3012 datafusion::logical_expr::cast(right, target_type),
3013 )));
3014 }
3015 }
3016
3017 if is_primitive_type(left_type) && is_primitive_type(right_type) {
3019 if op == Operator::Plus {
3020 return Some(crate::cypher_type_coerce::build_cypher_plus(
3021 left, left_type, right, right_type,
3022 ));
3023 }
3024 if is_comparison {
3025 return Some(Ok(crate::cypher_type_coerce::build_cypher_comparison(
3026 left, left_type, right, right_type, op,
3027 )));
3028 }
3029 }
3030
3031 None
3032}
3033
3034fn coerce_list_comparisons(
3036 left: DfExpr,
3037 right: DfExpr,
3038 left_type: &datafusion::arrow::datatypes::DataType,
3039 right_type: &datafusion::arrow::datatypes::DataType,
3040 op: datafusion::logical_expr::Operator,
3041 is_comparison: bool,
3042) -> Option<DfExpr> {
3043 use datafusion::arrow::datatypes::DataType;
3044 use datafusion::logical_expr::Operator;
3045
3046 if !is_comparison {
3047 return None;
3048 }
3049
3050 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
3051 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
3052
3053 if left_is_list
3055 && right_is_list
3056 && matches!(
3057 op,
3058 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
3059 )
3060 {
3061 let op_str = match op {
3062 Operator::Lt => "lt",
3063 Operator::LtEq => "lteq",
3064 Operator::Gt => "gt",
3065 Operator::GtEq => "gteq",
3066 _ => unreachable!(),
3067 };
3068 return Some(dummy_udf_expr(
3069 "_cypher_list_compare",
3070 vec![left, right, lit(op_str)],
3071 ));
3072 }
3073
3074 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
3076 let udf_name =
3077 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
3078 return Some(dummy_udf_expr(udf_name, vec![left, right]));
3079 }
3080
3081 if (left_is_list != right_is_list)
3083 && !matches!(left_type, DataType::Null)
3084 && !matches!(right_type, DataType::Null)
3085 {
3086 return Some(match op {
3087 Operator::Eq => lit(false),
3088 Operator::NotEq => lit(true),
3089 _ => lit(ScalarValue::Boolean(None)),
3090 });
3091 }
3092
3093 None
3094}
3095
3096fn coerce_binary_expr(
3098 binary: &datafusion::logical_expr::expr::BinaryExpr,
3099 schema: &datafusion::common::DFSchema,
3100) -> Result<DfExpr> {
3101 use datafusion::arrow::datatypes::DataType;
3102 use datafusion::logical_expr::ExprSchemable;
3103 use datafusion::logical_expr::Operator;
3104
3105 let left = apply_type_coercion(&binary.left, schema)?;
3106 let right = apply_type_coercion(&binary.right, schema)?;
3107
3108 let is_comparison = matches!(
3109 binary.op,
3110 Operator::Eq
3111 | Operator::NotEq
3112 | Operator::Lt
3113 | Operator::LtEq
3114 | Operator::Gt
3115 | Operator::GtEq
3116 );
3117 let is_arithmetic = matches!(
3118 binary.op,
3119 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
3120 );
3121
3122 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
3124 return Ok(result);
3125 }
3126
3127 if is_comparison || is_arithmetic {
3128 let left_type = match left.get_type(schema) {
3129 Ok(t) => t,
3130 Err(e) => {
3131 if let Some(t) = resolve_column_type_fallback(&left, schema) {
3132 t
3133 } else {
3134 log::warn!("Failed to get left type in binary expr: {}", e);
3135 return Ok(binary_expr(left, binary.op, right));
3136 }
3137 }
3138 };
3139 let right_type = match right.get_type(schema) {
3140 Ok(t) => t,
3141 Err(e) => {
3142 if let Some(t) = resolve_column_type_fallback(&right, schema) {
3143 t
3144 } else {
3145 log::warn!("Failed to get right type in binary expr: {}", e);
3146 return Ok(binary_expr(left, binary.op, right));
3147 }
3148 }
3149 };
3150
3151 let left_is_null = left_type.is_null();
3153 let right_is_null = right_type.is_null();
3154 if left_is_null && right_is_null {
3155 return Ok(lit(ScalarValue::Boolean(None)));
3156 }
3157 if left_is_null || right_is_null {
3158 let target = if left_is_null {
3159 &right_type
3160 } else {
3161 &left_type
3162 };
3163 if !matches!(target, DataType::LargeBinary) {
3164 let coerced_left = if left_is_null {
3165 datafusion::logical_expr::cast(left, target.clone())
3166 } else {
3167 left
3168 };
3169 let coerced_right = if right_is_null {
3170 datafusion::logical_expr::cast(right, target.clone())
3171 } else {
3172 right
3173 };
3174 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3175 }
3176 }
3177
3178 if let Some(result) = coerce_large_binary_ops(
3180 &left,
3181 &right,
3182 &left_type,
3183 &right_type,
3184 left_is_null,
3185 binary.op,
3186 is_comparison,
3187 is_arithmetic,
3188 ) {
3189 return result;
3190 }
3191
3192 if let Some(result) = coerce_temporal_comparisons(
3194 left.clone(),
3195 right.clone(),
3196 &left_type,
3197 &right_type,
3198 binary.op,
3199 is_comparison,
3200 ) {
3201 return Ok(result);
3202 }
3203
3204 let either_struct =
3206 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3207 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3208 || matches!(left_type, DataType::Struct(_)))
3209 && (matches!(right_type, DataType::LargeBinary)
3210 || matches!(right_type, DataType::Struct(_)));
3211 if is_comparison && either_struct && either_lb_or_struct {
3212 if let Some(udf_name) = comparison_udf_name(binary.op) {
3213 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3214 }
3215 return Ok(lit(ScalarValue::Boolean(None)));
3216 }
3217
3218 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3220 let udf_name = comparison_udf_name(binary.op)
3221 .expect("is_comparison guarantees a valid comparison operator");
3222 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3223 }
3224
3225 if binary.op == Operator::Plus
3227 && (crate::cypher_type_coerce::is_string_type(&left_type)
3228 || crate::cypher_type_coerce::is_string_type(&right_type))
3229 && is_primitive_type(&left_type)
3230 && is_primitive_type(&right_type)
3231 {
3232 return crate::cypher_type_coerce::build_cypher_plus(
3233 left,
3234 &left_type,
3235 right,
3236 &right_type,
3237 );
3238 }
3239
3240 if let Some(result) = coerce_mismatched_types(
3242 left.clone(),
3243 right.clone(),
3244 &left_type,
3245 &right_type,
3246 binary.op,
3247 is_comparison,
3248 ) {
3249 return result;
3250 }
3251
3252 if let Some(result) = coerce_list_comparisons(
3254 left.clone(),
3255 right.clone(),
3256 &left_type,
3257 &right_type,
3258 binary.op,
3259 is_comparison,
3260 ) {
3261 return Ok(result);
3262 }
3263
3264 if let Some(name) = arithmetic_udf_name(binary.op)
3273 && left_type == DataType::Int64
3274 && right_type == DataType::Int64
3275 && !is_list_expr(&left)
3276 && !is_list_expr(&right)
3277 {
3278 return Ok(dummy_udf_expr(name, vec![left, right]));
3279 }
3280 }
3281
3282 Ok(binary_expr(left, binary.op, right))
3283}
3284
3285fn coerce_scalar_function(
3287 func: &datafusion::logical_expr::expr::ScalarFunction,
3288 schema: &datafusion::common::DFSchema,
3289) -> Result<DfExpr> {
3290 use datafusion::arrow::datatypes::DataType;
3291 use datafusion::logical_expr::ExprSchemable;
3292
3293 let coerced_args: Vec<DfExpr> = func
3294 .args
3295 .iter()
3296 .map(|a| apply_type_coercion(a, schema))
3297 .collect::<Result<Vec<_>>>()?;
3298
3299 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3300 let types: Vec<_> = coerced_args
3301 .iter()
3302 .filter_map(|a| a.get_type(schema).ok())
3303 .collect();
3304 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3305 if has_mixed_types {
3306 let all_string_like = types
3310 .iter()
3311 .all(|t| matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Null));
3312 let unified_args: Vec<DfExpr> = if all_string_like {
3313 coerced_args
3314 .into_iter()
3315 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3316 .collect()
3317 } else {
3318 coerced_args
3320 .into_iter()
3321 .zip(types.iter())
3322 .map(|(arg, t)| match t {
3323 DataType::LargeBinary | DataType::Null => arg,
3324 DataType::List(_) | DataType::LargeList(_) => {
3325 list_to_large_binary_expr(arg)
3326 }
3327 _ => scalar_to_large_binary_expr(arg),
3328 })
3329 .collect()
3330 };
3331 return Ok(DfExpr::ScalarFunction(
3332 datafusion::logical_expr::expr::ScalarFunction {
3333 func: func.func.clone(),
3334 args: unified_args,
3335 },
3336 ));
3337 }
3338 }
3339
3340 Ok(DfExpr::ScalarFunction(
3341 datafusion::logical_expr::expr::ScalarFunction {
3342 func: func.func.clone(),
3343 args: coerced_args,
3344 },
3345 ))
3346}
3347
3348fn coerce_case_expr(
3351 case: &datafusion::logical_expr::expr::Case,
3352 schema: &datafusion::common::DFSchema,
3353) -> Result<DfExpr> {
3354 use datafusion::arrow::datatypes::DataType;
3355 use datafusion::logical_expr::ExprSchemable;
3356
3357 let coerced_operand = case
3358 .expr
3359 .as_ref()
3360 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3361 .transpose()?;
3362 let coerced_when_then = case
3363 .when_then_expr
3364 .iter()
3365 .map(|(w, t)| {
3366 let cw = apply_type_coercion(w, schema)?;
3367 let cw = match cw.get_type(schema).ok() {
3368 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3369 _ => cw,
3370 };
3371 let ct = apply_type_coercion(t, schema)?;
3372 Ok((Box::new(cw), Box::new(ct)))
3373 })
3374 .collect::<Result<Vec<_>>>()?;
3375 let coerced_else = case
3376 .else_expr
3377 .as_ref()
3378 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3379 .transpose()?;
3380
3381 let mut result_case = if let Some(operand) = coerced_operand {
3382 crate::cypher_type_coerce::rewrite_simple_case_to_generic(
3383 *operand,
3384 coerced_when_then,
3385 coerced_else,
3386 schema,
3387 )?
3388 } else {
3389 datafusion::logical_expr::expr::Case {
3390 expr: None,
3391 when_then_expr: coerced_when_then,
3392 else_expr: coerced_else,
3393 }
3394 };
3395
3396 crate::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3397
3398 Ok(DfExpr::Case(result_case))
3399}
3400
3401fn coerce_aggregate_function(
3403 agg: &datafusion::logical_expr::expr::AggregateFunction,
3404 schema: &datafusion::common::DFSchema,
3405) -> Result<DfExpr> {
3406 let coerced_args: Vec<DfExpr> = agg
3407 .params
3408 .args
3409 .iter()
3410 .map(|a| apply_type_coercion(a, schema))
3411 .collect::<Result<Vec<_>>>()?;
3412 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3413 .params
3414 .order_by
3415 .iter()
3416 .map(|s| {
3417 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3418 Ok(datafusion::logical_expr::SortExpr {
3419 expr: coerced_expr,
3420 asc: s.asc,
3421 nulls_first: s.nulls_first,
3422 })
3423 })
3424 .collect::<Result<Vec<_>>>()?;
3425 let coerced_filter = agg
3426 .params
3427 .filter
3428 .as_ref()
3429 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3430 .transpose()?;
3431 Ok(DfExpr::AggregateFunction(
3432 datafusion::logical_expr::expr::AggregateFunction {
3433 func: agg.func.clone(),
3434 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3435 args: coerced_args,
3436 distinct: agg.params.distinct,
3437 filter: coerced_filter,
3438 order_by: coerced_order_by,
3439 null_treatment: agg.params.null_treatment,
3440 },
3441 },
3442 ))
3443}
3444
3445#[cfg(test)]
3446mod tests {
3447 use super::*;
3448 use arrow_array::{
3449 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3450 };
3451 use uni_common::TemporalValue;
3452 #[test]
3453 fn test_literal_translation() {
3454 let expr = Expr::Literal(CypherLiteral::Integer(42));
3455 let result = cypher_expr_to_df(&expr, None).unwrap();
3456 let s = format!("{:?}", result);
3457 assert!(s.contains("Literal"));
3459 assert!(s.contains("Int64(42)"));
3460 }
3461
3462 #[test]
3463 fn test_property_access_no_context_uses_index() {
3464 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3466 let result = cypher_expr_to_df(&expr, None).unwrap();
3467 let s = format!("{}", result);
3468 assert!(
3469 s.contains("index"),
3470 "expected index UDF for non-graph variable, got: {s}"
3471 );
3472 }
3473
3474 #[test]
3475 fn test_comparison_operator() {
3476 let expr = Expr::BinaryOp {
3477 left: Box::new(Expr::Property(
3478 Box::new(Expr::Variable("n".to_string())),
3479 "age".to_string(),
3480 )),
3481 op: BinaryOp::Gt,
3482 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3483 };
3484 let result = cypher_expr_to_df(&expr, None).unwrap();
3485 let s = format!("{:?}", result);
3487 assert!(s.contains("age"));
3488 assert!(s.contains("30"));
3489 }
3490
3491 #[test]
3492 fn test_boolean_operators() {
3493 let expr = Expr::BinaryOp {
3494 left: Box::new(Expr::BinaryOp {
3495 left: Box::new(Expr::Property(
3496 Box::new(Expr::Variable("n".to_string())),
3497 "age".to_string(),
3498 )),
3499 op: BinaryOp::Gt,
3500 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3501 }),
3502 op: BinaryOp::And,
3503 right: Box::new(Expr::BinaryOp {
3504 left: Box::new(Expr::Property(
3505 Box::new(Expr::Variable("n".to_string())),
3506 "active".to_string(),
3507 )),
3508 op: BinaryOp::Eq,
3509 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3510 }),
3511 };
3512 let result = cypher_expr_to_df(&expr, None).unwrap();
3513 let s = format!("{:?}", result);
3514 assert!(s.contains("And"));
3515 }
3516
3517 #[test]
3518 fn test_is_null() {
3519 let expr = Expr::IsNull(Box::new(Expr::Property(
3520 Box::new(Expr::Variable("n".to_string())),
3521 "email".to_string(),
3522 )));
3523 let result = cypher_expr_to_df(&expr, None).unwrap();
3524 let s = format!("{:?}", result);
3525 assert!(s.contains("IsNull"));
3526 }
3527
3528 #[test]
3529 fn test_collect_properties() {
3530 let expr = Expr::BinaryOp {
3531 left: Box::new(Expr::Property(
3532 Box::new(Expr::Variable("n".to_string())),
3533 "name".to_string(),
3534 )),
3535 op: BinaryOp::Eq,
3536 right: Box::new(Expr::Property(
3537 Box::new(Expr::Variable("m".to_string())),
3538 "name".to_string(),
3539 )),
3540 };
3541
3542 let props = collect_properties(&expr);
3543 assert_eq!(props.len(), 2);
3544 assert!(props.contains(&("m".to_string(), "name".to_string())));
3545 assert!(props.contains(&("n".to_string(), "name".to_string())));
3546 }
3547
3548 #[test]
3549 fn test_function_call() {
3550 let expr = Expr::FunctionCall {
3551 name: "count".to_string(),
3552 args: vec![Expr::Wildcard],
3553 distinct: false,
3554 window_spec: None,
3555 };
3556 let result = cypher_expr_to_df(&expr, None).unwrap();
3557 let s = format!("{:?}", result);
3558 assert!(s.to_lowercase().contains("count"));
3559 }
3560
3561 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3566 use datafusion::logical_expr::Operator;
3567
3568 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3570 let fields: Vec<_> = cols
3571 .iter()
3572 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3573 .collect();
3574 let schema = Schema::new(fields);
3575 datafusion::common::DFSchema::try_from(schema).unwrap()
3576 }
3577
3578 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3580 let s = format!("{}", expr);
3581 s.contains(name)
3582 }
3583
3584 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3586 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3587 }
3588
3589 #[test]
3590 fn test_coercion_lb_eq_int64() {
3591 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3592 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3593 Box::new(col("lb")),
3594 Operator::Eq,
3595 Box::new(col("i")),
3596 ));
3597 let result = apply_type_coercion(&expr, &schema).unwrap();
3598 assert!(
3600 contains_udf(&result, "_cypher_equal"),
3601 "expected _cypher_equal, got: {result}"
3602 );
3603 }
3604
3605 #[test]
3606 fn test_coercion_lb_noteq_int64() {
3607 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3608 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3609 Box::new(col("lb")),
3610 Operator::NotEq,
3611 Box::new(col("i")),
3612 ));
3613 let result = apply_type_coercion(&expr, &schema).unwrap();
3614 assert!(contains_udf(&result, "_cypher_not_equal"));
3616 }
3617
3618 #[test]
3619 fn test_coercion_lb_lt_int64() {
3620 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3621 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3622 Box::new(col("lb")),
3623 Operator::Lt,
3624 Box::new(col("i")),
3625 ));
3626 let result = apply_type_coercion(&expr, &schema).unwrap();
3627 assert!(contains_udf(&result, "_cypher_lt"));
3629 }
3630
3631 #[test]
3632 fn test_coercion_lb_eq_float64() {
3633 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3634 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3635 Box::new(col("lb")),
3636 Operator::Eq,
3637 Box::new(col("f")),
3638 ));
3639 let result = apply_type_coercion(&expr, &schema).unwrap();
3640 assert!(contains_udf(&result, "_cypher_equal"));
3642 }
3643
3644 #[test]
3645 fn test_coercion_lb_eq_utf8() {
3646 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3647 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3648 Box::new(col("lb")),
3649 Operator::Eq,
3650 Box::new(col("s")),
3651 ));
3652 let result = apply_type_coercion(&expr, &schema).unwrap();
3653 assert!(contains_udf(&result, "_cypher_equal"));
3655 }
3656
3657 #[test]
3658 fn test_coercion_lb_eq_bool() {
3659 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3660 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3661 Box::new(col("lb")),
3662 Operator::Eq,
3663 Box::new(col("b")),
3664 ));
3665 let result = apply_type_coercion(&expr, &schema).unwrap();
3666 assert!(contains_udf(&result, "_cypher_equal"));
3668 }
3669
3670 #[test]
3671 fn test_coercion_int64_eq_lb() {
3672 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3674 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3675 Box::new(col("i")),
3676 Operator::Eq,
3677 Box::new(col("lb")),
3678 ));
3679 let result = apply_type_coercion(&expr, &schema).unwrap();
3680 assert!(contains_udf(&result, "_cypher_equal"));
3682 }
3683
3684 #[test]
3685 fn test_coercion_float64_gt_lb() {
3686 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3687 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3688 Box::new(col("f")),
3689 Operator::Gt,
3690 Box::new(col("lb")),
3691 ));
3692 let result = apply_type_coercion(&expr, &schema).unwrap();
3693 assert!(contains_udf(&result, "_cypher_gt"));
3695 }
3696
3697 #[test]
3698 fn test_coercion_both_lb_eq() {
3699 let schema = make_schema(&[
3700 ("lb1", DataType::LargeBinary),
3701 ("lb2", DataType::LargeBinary),
3702 ]);
3703 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3704 Box::new(col("lb1")),
3705 Operator::Eq,
3706 Box::new(col("lb2")),
3707 ));
3708 let result = apply_type_coercion(&expr, &schema).unwrap();
3709 assert!(contains_udf(&result, "_cypher_equal"));
3710 }
3711
3712 #[test]
3713 fn test_coercion_both_lb_lt() {
3714 let schema = make_schema(&[
3715 ("lb1", DataType::LargeBinary),
3716 ("lb2", DataType::LargeBinary),
3717 ]);
3718 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3719 Box::new(col("lb1")),
3720 Operator::Lt,
3721 Box::new(col("lb2")),
3722 ));
3723 let result = apply_type_coercion(&expr, &schema).unwrap();
3724 assert!(contains_udf(&result, "_cypher_lt"));
3725 }
3726
3727 #[test]
3728 fn test_coercion_both_lb_noteq() {
3729 let schema = make_schema(&[
3730 ("lb1", DataType::LargeBinary),
3731 ("lb2", DataType::LargeBinary),
3732 ]);
3733 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3734 Box::new(col("lb1")),
3735 Operator::NotEq,
3736 Box::new(col("lb2")),
3737 ));
3738 let result = apply_type_coercion(&expr, &schema).unwrap();
3739 assert!(contains_udf(&result, "_cypher_not_equal"));
3740 }
3741
3742 #[test]
3743 fn test_coercion_lb_plus_int64() {
3744 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3745 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3746 Box::new(col("lb")),
3747 Operator::Plus,
3748 Box::new(col("i")),
3749 ));
3750 let result = apply_type_coercion(&expr, &schema).unwrap();
3751 assert!(contains_udf(&result, "_cypher_add"));
3752 }
3753
3754 #[test]
3755 fn test_coercion_lb_minus_int64() {
3756 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3757 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3758 Box::new(col("lb")),
3759 Operator::Minus,
3760 Box::new(col("i")),
3761 ));
3762 let result = apply_type_coercion(&expr, &schema).unwrap();
3763 assert!(contains_udf(&result, "_cypher_sub"));
3764 }
3765
3766 #[test]
3767 fn test_coercion_lb_multiply_float64() {
3768 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3769 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3770 Box::new(col("lb")),
3771 Operator::Multiply,
3772 Box::new(col("f")),
3773 ));
3774 let result = apply_type_coercion(&expr, &schema).unwrap();
3775 assert!(contains_udf(&result, "_cypher_mul"));
3776 }
3777
3778 #[test]
3779 fn test_coercion_int64_plus_lb() {
3780 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3781 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3782 Box::new(col("i")),
3783 Operator::Plus,
3784 Box::new(col("lb")),
3785 ));
3786 let result = apply_type_coercion(&expr, &schema).unwrap();
3787 assert!(contains_udf(&result, "_cypher_add"));
3788 }
3789
3790 #[test]
3791 fn test_coercion_lb_plus_utf8() {
3792 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3794 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3795 Box::new(col("lb")),
3796 Operator::Plus,
3797 Box::new(col("s")),
3798 ));
3799 let result = apply_type_coercion(&expr, &schema).unwrap();
3800 assert!(contains_udf(&result, "_cypher_add"));
3802 }
3803
3804 #[test]
3805 fn test_coercion_and_null_bool() {
3806 let schema = make_schema(&[("b", DataType::Boolean)]);
3807 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3809 Box::new(lit(ScalarValue::Null)),
3810 Operator::And,
3811 Box::new(col("b")),
3812 ));
3813 let result = apply_type_coercion(&expr, &schema).unwrap();
3814 let s = format!("{}", result);
3815 assert!(
3817 s.contains("CAST") || s.contains("Boolean"),
3818 "expected cast to Boolean, got: {s}"
3819 );
3820 assert!(is_binary_op(&result, Operator::And));
3821 }
3822
3823 #[test]
3824 fn test_coercion_bool_and_null() {
3825 let schema = make_schema(&[("b", DataType::Boolean)]);
3826 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3827 Box::new(col("b")),
3828 Operator::And,
3829 Box::new(lit(ScalarValue::Null)),
3830 ));
3831 let result = apply_type_coercion(&expr, &schema).unwrap();
3832 assert!(is_binary_op(&result, Operator::And));
3833 }
3834
3835 #[test]
3836 fn test_coercion_or_null_bool() {
3837 let schema = make_schema(&[("b", DataType::Boolean)]);
3838 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3839 Box::new(lit(ScalarValue::Null)),
3840 Operator::Or,
3841 Box::new(col("b")),
3842 ));
3843 let result = apply_type_coercion(&expr, &schema).unwrap();
3844 assert!(is_binary_op(&result, Operator::Or));
3845 }
3846
3847 #[test]
3848 fn test_coercion_null_and_null() {
3849 let schema = make_schema(&[]);
3850 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3851 Box::new(lit(ScalarValue::Null)),
3852 Operator::And,
3853 Box::new(lit(ScalarValue::Null)),
3854 ));
3855 let result = apply_type_coercion(&expr, &schema).unwrap();
3856 assert!(is_binary_op(&result, Operator::And));
3857 }
3858
3859 #[test]
3860 fn test_coercion_bool_and_bool_noop() {
3861 let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);
3862 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3863 Box::new(col("a")),
3864 Operator::And,
3865 Box::new(col("b")),
3866 ));
3867 let result = apply_type_coercion(&expr, &schema).unwrap();
3868 assert!(is_binary_op(&result, Operator::And));
3870 let s = format!("{}", result);
3871 assert!(!s.contains("CAST"), "should not contain CAST: {s}");
3872 }
3873
3874 #[test]
3875 fn test_coercion_case_when_lb() {
3876 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3878 let when_cond = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3879 Box::new(col("lb")),
3880 Operator::Eq,
3881 Box::new(lit(42_i64)),
3882 ));
3883 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3884 expr: None,
3885 when_then_expr: vec![(Box::new(when_cond), Box::new(lit("a")))],
3886 else_expr: Some(Box::new(lit("b"))),
3887 });
3888 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3889 let s = format!("{}", result);
3890 assert!(
3892 s.contains("_cypher_equal"),
3893 "CASE WHEN should have _cypher_equal, got: {s}"
3894 );
3895 }
3896
3897 #[test]
3898 fn test_coercion_case_then_lb() {
3899 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3901 let then_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3902 Box::new(col("lb")),
3903 Operator::Plus,
3904 Box::new(lit(1_i64)),
3905 ));
3906 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3907 expr: None,
3908 when_then_expr: vec![(Box::new(lit(true)), Box::new(then_expr))],
3909 else_expr: Some(Box::new(lit(0_i64))),
3910 });
3911 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3912 let s = format!("{}", result);
3913 assert!(
3914 s.contains("_cypher_add"),
3915 "CASE THEN should have _cypher_add, got: {s}"
3916 );
3917 }
3918
3919 #[test]
3920 fn test_coercion_case_else_lb() {
3921 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3923 let else_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3924 Box::new(col("lb")),
3925 Operator::Plus,
3926 Box::new(lit(2_i64)),
3927 ));
3928 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3929 expr: None,
3930 when_then_expr: vec![(Box::new(lit(true)), Box::new(lit(1_i64)))],
3931 else_expr: Some(Box::new(else_expr)),
3932 });
3933 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3934 let s = format!("{}", result);
3935 assert!(
3936 s.contains("_cypher_add"),
3937 "CASE ELSE should have _cypher_add, got: {s}"
3938 );
3939 }
3940
3941 #[test]
3942 fn test_coercion_int64_eq_int64_noop() {
3943 let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);
3944 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3945 Box::new(col("a")),
3946 Operator::Eq,
3947 Box::new(col("b")),
3948 ));
3949 let result = apply_type_coercion(&expr, &schema).unwrap();
3950 assert!(is_binary_op(&result, Operator::Eq));
3951 let s = format!("{}", result);
3952 assert!(
3953 !s.contains("_cypher_value"),
3954 "should not contain cypher_value decode: {s}"
3955 );
3956 }
3957
3958 #[test]
3959 fn test_coercion_both_lb_plus() {
3960 let schema = make_schema(&[
3962 ("lb1", DataType::LargeBinary),
3963 ("lb2", DataType::LargeBinary),
3964 ]);
3965 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3966 Box::new(col("lb1")),
3967 Operator::Plus,
3968 Box::new(col("lb2")),
3969 ));
3970 let result = apply_type_coercion(&expr, &schema).unwrap();
3971 assert!(
3972 contains_udf(&result, "_cypher_add"),
3973 "expected _cypher_add, got: {result}"
3974 );
3975 }
3976
3977 #[test]
3978 fn test_coercion_native_list_plus_scalar() {
3979 let schema = make_schema(&[
3981 (
3982 "lst",
3983 DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3984 ),
3985 ("i", DataType::Int32),
3986 ]);
3987 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3988 Box::new(col("lst")),
3989 Operator::Plus,
3990 Box::new(col("i")),
3991 ));
3992 let result = apply_type_coercion(&expr, &schema).unwrap();
3993 assert!(
3994 contains_udf(&result, "_cypher_list_append"),
3995 "expected _cypher_list_append, got: {result}"
3996 );
3997 }
3998
3999 #[test]
4000 fn test_coercion_lb_plus_int64_unchanged() {
4001 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
4003 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
4004 Box::new(col("lb")),
4005 Operator::Plus,
4006 Box::new(col("i")),
4007 ));
4008 let result = apply_type_coercion(&expr, &schema).unwrap();
4009 assert!(
4010 contains_udf(&result, "_cypher_add"),
4011 "expected _cypher_add, got: {result}"
4012 );
4013 }
4014
4015 #[test]
4020 fn test_mixed_list_with_variables_compiles() {
4021 let expr = Expr::List(vec![
4023 Expr::Variable("n".to_string()),
4024 Expr::Literal(CypherLiteral::Integer(1)),
4025 Expr::Literal(CypherLiteral::String("hello".to_string())),
4026 ]);
4027 let result = cypher_expr_to_df(&expr, None).unwrap();
4028 let s = format!("{}", result);
4029 assert!(
4030 s.contains("_make_cypher_list"),
4031 "expected _make_cypher_list UDF call, got: {s}"
4032 );
4033 }
4034
4035 #[test]
4036 fn test_literal_only_mixed_list_uses_cv_fastpath() {
4037 let expr = Expr::List(vec![
4039 Expr::Literal(CypherLiteral::Integer(1)),
4040 Expr::Literal(CypherLiteral::String("hi".to_string())),
4041 Expr::Literal(CypherLiteral::Bool(true)),
4042 ]);
4043 let result = cypher_expr_to_df(&expr, None).unwrap();
4044 assert!(
4045 matches!(result, DfExpr::Literal(..)),
4046 "expected Literal (CypherValue fast path), got: {result}"
4047 );
4048 }
4049
4050 #[test]
4055 fn test_in_mixed_literal_list_uses_cypher_in() {
4056 let expr = Expr::In {
4058 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4059 list: Box::new(Expr::List(vec![
4060 Expr::Literal(CypherLiteral::String("1".to_string())),
4061 Expr::Literal(CypherLiteral::Integer(2)),
4062 ])),
4063 };
4064 let result = cypher_expr_to_df(&expr, None).unwrap();
4065 let s = format!("{}", result);
4066 assert!(
4067 s.contains("_cypher_in"),
4068 "expected _cypher_in UDF for mixed-type IN list, got: {s}"
4069 );
4070 }
4071
4072 #[test]
4073 fn test_in_homogeneous_literal_list_uses_cypher_in() {
4074 let expr = Expr::In {
4076 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4077 list: Box::new(Expr::List(vec![
4078 Expr::Literal(CypherLiteral::Integer(2)),
4079 Expr::Literal(CypherLiteral::Integer(3)),
4080 ])),
4081 };
4082 let result = cypher_expr_to_df(&expr, None).unwrap();
4083 let s = format!("{}", result);
4084 assert!(
4085 s.contains("_cypher_in"),
4086 "expected _cypher_in UDF for homogeneous IN list, got: {s}"
4087 );
4088 }
4089
4090 #[test]
4091 fn test_in_list_with_variables_uses_make_cypher_list() {
4092 let expr = Expr::In {
4094 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4095 list: Box::new(Expr::List(vec![
4096 Expr::Variable("x".to_string()),
4097 Expr::Literal(CypherLiteral::Integer(2)),
4098 ])),
4099 };
4100 let result = cypher_expr_to_df(&expr, None).unwrap();
4101 let s = format!("{}", result);
4102 assert!(
4103 s.contains("_cypher_in"),
4104 "expected _cypher_in UDF, got: {s}"
4105 );
4106 assert!(
4107 s.contains("_make_cypher_list"),
4108 "expected _make_cypher_list for variable-containing list, got: {s}"
4109 );
4110 }
4111
4112 #[test]
4117 fn test_property_on_graph_entity_uses_column() {
4118 let mut ctx = TranslationContext::new();
4120 ctx.variable_kinds
4121 .insert("n".to_string(), VariableKind::Node);
4122
4123 let expr = Expr::Property(
4124 Box::new(Expr::Variable("n".to_string())),
4125 "name".to_string(),
4126 );
4127 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4128 let s = format!("{:?}", result);
4129 assert!(
4130 s.contains("Column") && s.contains("n.name"),
4131 "expected flat column 'n.name' for graph entity, got: {s}"
4132 );
4133 }
4134
4135 #[test]
4136 fn test_property_on_non_graph_var_uses_index() {
4137 let ctx = TranslationContext::new();
4139
4140 let expr = Expr::Property(
4141 Box::new(Expr::Variable("map".to_string())),
4142 "name".to_string(),
4143 );
4144 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4145 let s = format!("{}", result);
4146 assert!(
4147 s.contains("index"),
4148 "expected index UDF for non-graph variable, got: {s}"
4149 );
4150 }
4151
4152 #[test]
4153 fn test_value_to_scalar_non_empty_map_becomes_struct() {
4154 let mut map = std::collections::HashMap::new();
4155 map.insert("k".to_string(), Value::Int(1));
4156 let scalar = value_to_scalar(&Value::Map(map)).unwrap();
4157 assert!(
4158 matches!(scalar, ScalarValue::Struct(_)),
4159 "expected Struct scalar for map input"
4160 );
4161 }
4162
4163 #[test]
4164 fn test_value_to_scalar_empty_map_becomes_struct() {
4165 let scalar = value_to_scalar(&Value::Map(Default::default())).unwrap();
4166 assert!(
4167 matches!(scalar, ScalarValue::Struct(_)),
4168 "empty map should produce an empty Struct scalar"
4169 );
4170 }
4171
4172 #[test]
4173 fn test_value_to_scalar_null_is_untyped_null() {
4174 let scalar = value_to_scalar(&Value::Null).unwrap();
4175 assert!(
4176 matches!(scalar, ScalarValue::Null),
4177 "expected untyped Null scalar for Value::Null"
4178 );
4179 }
4180
4181 #[test]
4182 fn test_value_to_scalar_datetime_produces_struct() {
4183 let datetime = Value::Temporal(TemporalValue::DateTime {
4185 nanos_since_epoch: 441763200000000000, offset_seconds: 3600, timezone_name: Some("Europe/Paris".to_string()),
4188 });
4189
4190 let scalar = value_to_scalar(&datetime).unwrap();
4191
4192 if let ScalarValue::Struct(struct_arr) = scalar {
4194 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4195 assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");
4196
4197 let fields = struct_arr.fields();
4199 assert_eq!(fields[0].name(), "nanos_since_epoch");
4200 assert_eq!(fields[1].name(), "offset_seconds");
4201 assert_eq!(fields[2].name(), "timezone_name");
4202
4203 let nanos_col = struct_arr.column(0);
4205 let offset_col = struct_arr.column(1);
4206 let tz_col = struct_arr.column(2);
4207
4208 if let Some(nanos_arr) = nanos_col
4209 .as_any()
4210 .downcast_ref::<TimestampNanosecondArray>()
4211 {
4212 assert_eq!(nanos_arr.value(0), 441763200000000000);
4213 } else {
4214 panic!("Expected TimestampNanosecondArray for nanos field");
4215 }
4216
4217 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4218 assert_eq!(offset_arr.value(0), 3600);
4219 } else {
4220 panic!("Expected Int32Array for offset field");
4221 }
4222
4223 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4224 assert_eq!(tz_arr.value(0), "Europe/Paris");
4225 } else {
4226 panic!("Expected StringArray for timezone_name field");
4227 }
4228 } else {
4229 panic!(
4230 "Expected ScalarValue::Struct for DateTime, got {:?}",
4231 scalar
4232 );
4233 }
4234 }
4235
4236 #[test]
4237 fn test_value_to_scalar_datetime_with_null_timezone() {
4238 let datetime = Value::Temporal(TemporalValue::DateTime {
4240 nanos_since_epoch: 1704067200000000000, offset_seconds: -18000, timezone_name: None,
4243 });
4244
4245 let scalar = value_to_scalar(&datetime).unwrap();
4246
4247 if let ScalarValue::Struct(struct_arr) = scalar {
4248 assert_eq!(struct_arr.num_columns(), 3);
4249
4250 let tz_col = struct_arr.column(2);
4252 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4253 assert!(tz_arr.is_null(0), "expected null timezone_name");
4254 } else {
4255 panic!("Expected StringArray for timezone_name field");
4256 }
4257 } else {
4258 panic!("Expected ScalarValue::Struct for DateTime");
4259 }
4260 }
4261
4262 #[test]
4263 fn test_value_to_scalar_time_produces_struct() {
4264 let time = Value::Temporal(TemporalValue::Time {
4266 nanos_since_midnight: 37845000000000, offset_seconds: 3600, });
4269
4270 let scalar = value_to_scalar(&time).unwrap();
4271
4272 if let ScalarValue::Struct(struct_arr) = scalar {
4274 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4275 assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");
4276
4277 let fields = struct_arr.fields();
4279 assert_eq!(fields[0].name(), "nanos_since_midnight");
4280 assert_eq!(fields[1].name(), "offset_seconds");
4281
4282 let nanos_col = struct_arr.column(0);
4284 let offset_col = struct_arr.column(1);
4285
4286 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4287 assert_eq!(nanos_arr.value(0), 37845000000000);
4288 } else {
4289 panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
4290 }
4291
4292 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4293 assert_eq!(offset_arr.value(0), 3600);
4294 } else {
4295 panic!("Expected Int32Array for offset field");
4296 }
4297 } else {
4298 panic!("Expected ScalarValue::Struct for Time, got {:?}", scalar);
4299 }
4300 }
4301
4302 #[test]
4303 fn test_value_to_scalar_time_boundary_values() {
4304 let midnight = Value::Temporal(TemporalValue::Time {
4306 nanos_since_midnight: 0,
4307 offset_seconds: 0,
4308 });
4309
4310 let scalar = value_to_scalar(&midnight).unwrap();
4311
4312 if let ScalarValue::Struct(struct_arr) = scalar {
4313 let nanos_col = struct_arr.column(0);
4314 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4315 assert_eq!(nanos_arr.value(0), 0);
4316 } else {
4317 panic!("Expected Time64NanosecondArray");
4318 }
4319 } else {
4320 panic!("Expected ScalarValue::Struct for Time");
4321 }
4322 }
4323}