1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Column-name suffix holding a node's id (`<var>._vid`).
const COL_VID: &str = "_vid";
/// Column-name suffix holding an edge's id (`<var>._eid`).
const COL_EID: &str = "_eid";
/// Column-name suffix holding a node's label list (`<var>._labels`).
const COL_LABELS: &str = "_labels";
/// Column-name suffix holding an edge's relationship type (`<var>._type`).
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Inserts an explicit `:00` seconds field into a datetime string whose time
/// part stops at minutes, e.g. `2020-01-01T12:30Z` → `2020-01-01T12:30:00Z`.
///
/// The string must be at least 16 bytes long, have `T` at byte 10 and
/// `HH:MM` (ASCII digits around a colon) at bytes 11..16. Returns `None` when
/// it does not have that shape or already has seconds (a `:` at byte 16).
/// The date part and any suffix after the minutes are not validated.
pub fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();
    if bytes.len() < 16 || bytes[10] != b'T' {
        return None;
    }

    let is_hour_minute = matches!(
        &bytes[11..16],
        [h1, h2, b':', m1, m2]
            if h1.is_ascii_digit()
                && h2.is_ascii_digit()
                && m1.is_ascii_digit()
                && m2.is_ascii_digit()
    );
    if !is_hour_minute || bytes.get(16) == Some(&b':') {
        return None;
    }

    // Byte 15 is an ASCII digit, so 16 is always a char boundary.
    let (head, tail) = s.split_at(16);
    Some(format!("{head}:00{tail}"))
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
/// Scalar-function names whose calls `is_cypher_list_expr` treats as
/// producing Cypher-encoded lists.
const CYPHER_LIST_FUNCS: &[&str] = &[
    "_make_cypher_list",
    "_cypher_list_concat",
    "_cypher_list_append",
];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// What a bound query variable refers to; decides how the variable's
/// columns are named and compared during translation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A single node.
    Node,
    /// A single edge from a fixed-length relationship pattern.
    Edge,
    /// The edges bound by a variable-length relationship pattern.
    EdgeList,
    /// A path.
    Path,
}

impl VariableKind {
    /// Kind for a relationship variable: `EdgeList` when the pattern is
    /// variable-length, otherwise `Edge`.
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
247
/// Translates a Cypher AST expression into a DataFusion logical expression.
///
/// `context`, when provided, supplies query parameters, outer-scope values and
/// the kind of each bound variable. These decide whether a variable becomes a
/// column reference or an inlined literal, and how entity checks are built.
///
/// # Errors
///
/// Returns an error for:
/// - constructs this translator does not handle: pattern comprehensions,
///   REDUCE, EXISTS/COUNT/COLLECT subqueries, quantifiers, list
///   comprehensions, VALID_AT, `IS UNIQUE`, label checks on non-variables,
///   and the `~=` operator;
/// - unresolved `$parameters`;
/// - failures propagated from nested translation or value conversion.
pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
    match expr {
        Expr::PatternComprehension { .. } => Err(anyhow!(
            "Pattern comprehensions require fallback executor (graph traversal)"
        )),
        // `*` translates to the constant Int32 literal 1.
        Expr::Wildcard => Ok(DfExpr::Literal(
            datafusion::common::ScalarValue::Int32(Some(1)),
            None,
        )),

        Expr::Variable(name) => {
            // Variables with a known kind are columns of the same name.
            if let Some(ctx) = context
                && ctx.variable_kinds.contains_key(name)
            {
                return Ok(DfExpr::Column(Column::from_name(name)));
            }

            // Values captured from an enclosing scope are inlined as literals.
            if let Some(ctx) = context
                && let Some(value) = ctx.outer_values.get(name)
            {
                return value_to_scalar(value).map(lit);
            }

            // A parameter registered under the variable's name is inlined too.
            if let Some(ctx) = context
                && let Some(value) = ctx.parameters.get(name)
            {
                match value {
                    // A list bound to a `._vid` name becomes an IN filter on
                    // that column rather than a list literal.
                    Value::List(values) if name.ends_with("._vid") => {
                        let literals = values
                            .iter()
                            .map(|v| value_to_scalar(v).map(lit))
                            .collect::<Result<Vec<_>>>()?;
                        return Ok(DfExpr::InList(InList {
                            expr: Box::new(DfExpr::Column(Column::from_name(name))),
                            list: literals,
                            negated: false,
                        }));
                    }
                    other_value => return value_to_scalar(other_value).map(lit),
                }
            }

            // Fallback: reference a column named after the variable.
            Ok(DfExpr::Column(Column::from_name(name)))
        }

        Expr::Property(base, prop) => translate_property_access(base, prop, context),

        Expr::ArrayIndex { array, index } => {
            // `var['prop']` with a string-literal key is treated like `var.prop`
            // and mapped to the flattened column `"var.prop"`.
            // NOTE(review): extract_variable_name walks nested Property bases to
            // the root variable, so `a.b['c']` resolves to column `a.c` — confirm
            // this is intended.
            if let Ok(var_name) = extract_variable_name(array)
                && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
            {
                let col_name = format!("{}.{}", var_name, prop_name);
                return Ok(DfExpr::Column(Column::from_name(col_name)));
            }

            let array_expr = cypher_expr_to_df(array, context)?;
            let index_expr = cypher_expr_to_df(index, context)?;

            Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
        }

        Expr::ArraySlice { array, start, end } => {
            let array_expr = cypher_expr_to_df(array, context)?;

            // Open bounds default to 0 and i64::MAX respectively.
            let start_expr = match start {
                Some(s) => cypher_expr_to_df(s, context)?,
                None => lit(0i64),
            };

            let end_expr = match end {
                Some(e) => cypher_expr_to_df(e, context)?,
                None => lit(i64::MAX),
            };

            Ok(dummy_udf_expr(
                "_cypher_list_slice",
                vec![array_expr, start_expr, end_expr],
            ))
        }

        // Parameters must be known at translation time; they are inlined as literals.
        Expr::Parameter(name) => {
            if let Some(ctx) = context
                && let Some(value) = ctx.parameters.get(name)
            {
                return value_to_scalar(value).map(lit);
            }
            Err(anyhow!("Unresolved parameter: ${}", name))
        }

        Expr::Literal(value) => {
            let scalar = cypher_literal_to_scalar(value)?;
            Ok(lit(scalar))
        }

        Expr::List(items) => translate_list_literal(items, context),

        Expr::Map(entries) => {
            // An empty map is emitted as a cypher_value_codec-encoded LargeBinary
            // literal instead of going through named_struct.
            // NOTE(review): presumably because named_struct needs at least one field — confirm.
            if entries.is_empty() {
                let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
                    Default::default(),
                ));
                return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
            }
            // named_struct takes alternating key literals and value expressions.
            let mut args = Vec::with_capacity(entries.len() * 2);
            for (key, val_expr) in entries {
                args.push(lit(key.clone()));
                args.push(cypher_expr_to_df(val_expr, context)?);
            }
            Ok(datafusion::functions::expr_fn::named_struct(args))
        }

        Expr::IsNull(inner) => translate_null_check(inner, context, true),

        Expr::IsNotNull(inner) => translate_null_check(inner, context, false),

        Expr::IsUnique(_) => {
            Err(anyhow!(
                "IS UNIQUE can only be used in constraint definitions"
            ))
        }

        Expr::FunctionCall {
            name,
            args,
            distinct,
            window_spec,
        } => {
            // Windowed calls reference a column named after the expression's
            // string form.
            // NOTE(review): presumably materialised by an upstream window operator — confirm.
            if window_spec.is_some() {
                let col_name = expr.to_string_repr();
                Ok(col(&col_name))
            } else {
                translate_function_call(name, args, *distinct, context)
            }
        }

        Expr::In { expr, list } => translate_in_expression(expr, list, context),

        Expr::BinaryOp { left, op, right } => {
            let left_expr = cypher_expr_to_df(left, context)?;
            let right_expr = cypher_expr_to_df(right, context)?;
            translate_binary_op(left_expr, op, right_expr)
        }

        Expr::UnaryOp { op, expr: inner } => {
            let inner_expr = cypher_expr_to_df(inner, context)?;
            match op {
                UnaryOp::Not => Ok(inner_expr.not()),
                UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
            }
        }

        Expr::Case {
            expr,
            when_then,
            else_expr,
        } => translate_case_expression(expr, when_then, else_expr, context),

        Expr::Reduce { .. } => Err(anyhow!(
            "Reduce expressions not yet supported in DataFusion translation"
        )),

        Expr::Exists { .. } => Err(anyhow!(
            "EXISTS subqueries are handled by the physical expression compiler, \
             not the DataFusion logical expression translator"
        )),

        Expr::CountSubquery(_) => Err(anyhow!(
            "Count subqueries not yet supported in DataFusion translation"
        )),

        Expr::CollectSubquery(_) => Err(anyhow!(
            "COLLECT subqueries not yet supported in DataFusion translation"
        )),

        Expr::Quantifier { .. } => {
            Err(anyhow!(
                "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
                 via CypherPhysicalExprCompiler"
            ))
        }

        Expr::ListComprehension { .. } => {
            Err(anyhow!(
                "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
            ))
        }

        Expr::ValidAt { .. } => {
            Err(anyhow!(
                "VALID_AT expression should have been transformed to function call in planner"
            ))
        }

        Expr::MapProjection { base, items } => translate_map_projection(base, items, context),

        Expr::LabelCheck { expr, labels } => {
            if let Expr::Variable(var) = expr.as_ref() {
                // Only VariableKind::Edge takes the relationship-type path; every
                // other variable (including unknown ones) uses the node labels column.
                let is_edge = context
                    .and_then(|ctx| ctx.variable_kinds.get(var))
                    .is_some_and(|k| matches!(k, VariableKind::Edge));

                if is_edge {
                    // An edge has exactly one type, so a multi-label check is
                    // constant false (even when the type column is NULL).
                    if labels.len() > 1 {
                        Ok(lit(false))
                    } else {
                        // CASE WHEN type IS NULL THEN NULL ELSE type = label END
                        // NOTE(review): labels[0] panics if `labels` is empty — the
                        // parser presumably never produces that; confirm.
                        let type_col =
                            DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
                        Ok(DfExpr::Case(datafusion::logical_expr::Case {
                            expr: None,
                            when_then_expr: vec![(
                                Box::new(type_col.clone().is_null()),
                                Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
                            )],
                            else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
                        }))
                    }
                } else {
                    // CASE WHEN labels IS NULL THEN NULL ELSE AND(array_has(labels, l)...) END
                    let labels_col =
                        DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
                    let checks = labels
                        .iter()
                        .map(|label| {
                            datafusion::functions_nested::expr_fn::array_has(
                                labels_col.clone(),
                                lit(label.clone()),
                            )
                        })
                        .reduce(|acc, check| acc.and(check));
                    // NOTE(review): `checks.unwrap()` panics if `labels` is empty —
                    // confirm the parser guarantees at least one label.
                    Ok(DfExpr::Case(datafusion::logical_expr::Case {
                        expr: None,
                        when_then_expr: vec![(
                            Box::new(labels_col.is_null()),
                            Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
                        )],
                        else_expr: Some(Box::new(checks.unwrap())),
                    }))
                }
            } else {
                Err(anyhow!(
                    "LabelCheck on non-variable expression not yet supported in DataFusion"
                ))
            }
        }
    }
}
588
/// State consulted while translating Cypher expressions into DataFusion
/// expressions.
#[derive(Debug, Clone)]
pub struct TranslationContext {
    /// Parameter values. They are looked up by name for `$name` parameters,
    /// and also by bare variable name and by flattened `"var.prop"` column
    /// name. A hit is inlined as a literal.
    pub parameters: std::collections::HashMap<String, Value>,

    /// Values of variables captured from an enclosing scope. A variable found
    /// here (and not in `variable_kinds`) is inlined as a literal.
    pub outer_values: std::collections::HashMap<String, Value>,

    /// Label associated with a variable, populated by `with_variable_label`.
    /// NOTE(review): not read in this part of the file — confirm its consumers.
    pub variable_labels: std::collections::HashMap<String, String>,

    /// Kind of each variable bound by the plan. Variables listed here
    /// translate to columns, and the kind selects id/label/type column names.
    pub variable_kinds: std::collections::HashMap<String, VariableKind>,

    /// NOTE(review): not read in this part of the file; presumably names of
    /// variables to treat as nodes — confirm.
    pub node_variable_hints: Vec<String>,

    /// NOTE(review): not read in this part of the file; presumably names of
    /// edge variables created by mutations — confirm.
    pub mutation_edge_hints: Vec<String>,

    /// Timestamp fixed for the statement; `Default` initialises it with
    /// `chrono::Utc::now()`.
    /// NOTE(review): presumably used so time functions are stable within one
    /// statement — confirm.
    pub statement_time: chrono::DateTime<chrono::Utc>,
}
622
623impl Default for TranslationContext {
624 fn default() -> Self {
625 Self {
626 parameters: std::collections::HashMap::new(),
627 outer_values: std::collections::HashMap::new(),
628 variable_labels: std::collections::HashMap::new(),
629 variable_kinds: std::collections::HashMap::new(),
630 node_variable_hints: Vec::new(),
631 mutation_edge_hints: Vec::new(),
632 statement_time: chrono::Utc::now(),
633 }
634 }
635}
636
637impl TranslationContext {
638 pub fn new() -> Self {
640 Self::default()
641 }
642
643 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
645 self.parameters.insert(name.into(), value);
646 self
647 }
648
649 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
651 self.variable_labels.insert(var.into(), label.into());
652 self
653 }
654}
655
656fn extract_variable_name(expr: &Expr) -> Result<String> {
658 match expr {
659 Expr::Variable(name) => Ok(name.clone()),
660 Expr::Property(base, _) => extract_variable_name(base),
661 _ => Err(anyhow!(
662 "Cannot extract variable name from expression: {:?}",
663 expr
664 )),
665 }
666}
667
668fn translate_null_check(
670 inner: &Expr,
671 context: Option<&TranslationContext>,
672 is_null: bool,
673) -> Result<DfExpr> {
674 if let Expr::Variable(var) = inner
675 && let Some(ctx) = context
676 && let Some(kind) = ctx.variable_kinds.get(var)
677 {
678 let col_name = match kind {
679 VariableKind::Node => format!("{}.{}", var, COL_VID),
680 VariableKind::Edge => format!("{}.{}", var, COL_EID),
681 VariableKind::Path | VariableKind::EdgeList => var.clone(),
682 };
683 let col_expr = DfExpr::Column(Column::from_name(col_name));
684 return Ok(if is_null {
685 col_expr.is_null()
686 } else {
687 col_expr.is_not_null()
688 });
689 }
690
691 let inner_expr = cypher_expr_to_df(inner, context)?;
692 Ok(if is_null {
693 inner_expr.is_null()
694 } else {
695 inner_expr.is_not_null()
696 })
697}
698
699fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
704 if crate::datetime::is_duration_accessor(prop) {
705 Some(dummy_udf_expr(
706 "_duration_property",
707 vec![base_expr, lit(prop.to_string())],
708 ))
709 } else if crate::datetime::is_temporal_accessor(prop) {
710 Some(dummy_udf_expr(
711 "_temporal_property",
712 vec![base_expr, lit(prop.to_string())],
713 ))
714 } else {
715 None
716 }
717}
718
719fn translate_property_access(
721 base: &Expr,
722 prop: &str,
723 context: Option<&TranslationContext>,
724) -> Result<DfExpr> {
725 if let Ok(var_name) = extract_variable_name(base) {
726 let is_graph_entity = context
727 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
728 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
729
730 if !is_graph_entity
731 && let Some(expr) =
732 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
733 {
734 return Ok(expr);
735 }
736
737 let col_name = format!("{}.{}", var_name, prop);
738
739 if let Some(ctx) = context
742 && let Some(value) = ctx.parameters.get(&col_name)
743 {
744 match value {
747 Value::List(values) if col_name.ends_with("._vid") => {
748 let literals = values
749 .iter()
750 .map(|v| value_to_scalar(v).map(lit))
751 .collect::<Result<Vec<_>>>()?;
752 return Ok(DfExpr::InList(InList {
753 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
754 list: literals,
755 negated: false,
756 }));
757 }
758 other_value => return value_to_scalar(other_value).map(lit),
759 }
760 }
761
762 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
765 let base_expr = cypher_expr_to_df(base, context)?;
766 return Ok(dummy_udf_expr(
767 "index",
768 vec![base_expr, lit(prop.to_string())],
769 ));
770 }
771
772 if is_graph_entity {
773 Ok(DfExpr::Column(Column::from_name(col_name)))
774 } else {
775 let base_expr = DfExpr::Column(Column::from_name(var_name));
776 Ok(dummy_udf_expr(
777 "index",
778 vec![base_expr, lit(prop.to_string())],
779 ))
780 }
781 } else {
782 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
784 return Ok(expr);
785 }
786
787 if let Expr::Parameter(param_name) = base {
789 if let Some(ctx) = context
790 && let Some(value) = ctx.parameters.get(param_name)
791 {
792 if let Value::Map(map) = value {
793 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
794 return value_to_scalar(&extracted).map(lit);
795 }
796 return Ok(lit(ScalarValue::Null));
797 }
798 return Err(anyhow!("Unresolved parameter: ${}", param_name));
799 }
800
801 let base_expr = cypher_expr_to_df(base, context)?;
802 Ok(dummy_udf_expr(
803 "index",
804 vec![base_expr, lit(prop.to_string())],
805 ))
806 }
807}
808
/// Translates a Cypher list literal `[a, b, ...]` into a DataFusion expression.
///
/// Some lists are encoded as Cypher values. This applies when they contain a
/// nested list, a map, more than one kind of literal (numeric / string /
/// boolean), a variable with a known kind, or a call to a temporal
/// constructor. Such a list becomes:
/// - a constant `LargeBinary` literal when every item is JSON-representable
///   (see `try_items_to_json`);
/// - otherwise a `_make_cypher_list(...)` call.
///
/// All other lists use DataFusion's `make_array`:
/// - an empty list becomes an empty `List` literal with `Null` element type;
/// - a list made only of integer and float literals (at least one of each)
///   is cast element-wise to `Float64`;
/// - otherwise NULL-typed literals are cast to the type of the first non-NULL
///   literal.
///
/// NOTE(review): `Null` and `Bytes` literals are not counted in `types_count`,
/// so e.g. `[1, <bytes>]` reaches `make_array` with mixed types — confirm
/// DataFusion coercion handles that.
fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
    // First pass: classify items to choose an encoding.
    let mut has_string = false;
    let mut has_bool = false;
    let mut has_list = false;
    let mut has_map = false;
    let mut has_numeric = false;
    let mut has_graph_entity = false;
    let mut has_temporal = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
                has_numeric = true
            }
            Expr::Literal(CypherLiteral::String(_)) => has_string = true,
            Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
            Expr::List(_) => has_list = true,
            Expr::Map(_) => has_map = true,
            // Any variable with a known kind (node, edge, edge list, path).
            Expr::Variable(name)
                if context
                    .and_then(|ctx| ctx.variable_kinds.get(name))
                    .is_some() =>
            {
                has_graph_entity = true;
            }
            // Temporal constructors (matched case-insensitively).
            Expr::FunctionCall { name, .. } => {
                let upper = name.to_uppercase();
                if matches!(
                    upper.as_str(),
                    "DATE"
                        | "TIME"
                        | "LOCALTIME"
                        | "LOCALDATETIME"
                        | "DATETIME"
                        | "DURATION"
                        | "DATE.TRUNCATE"
                        | "TIME.TRUNCATE"
                        | "DATETIME.TRUNCATE"
                        | "LOCALDATETIME.TRUNCATE"
                        | "LOCALTIME.TRUNCATE"
                ) {
                    has_temporal = true;
                }
            }
            _ => {}
        }
    }

    // Number of distinct literal categories present.
    let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;

    if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
        // Constant-foldable: encode the whole list once as a Cypher value.
        if let Some(json_array) = try_items_to_json(items) {
            let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
            let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
            return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
        }
        // Otherwise build the Cypher list at runtime from translated items.
        let df_args: Vec<DfExpr> = items
            .iter()
            .map(|item| cypher_expr_to_df(item, context))
            .collect::<Result<_>>()?;
        return Ok(dummy_udf_expr("_make_cypher_list", df_args));
    }

    // Homogeneous path: translate items and track int/float literal mixing.
    let mut df_args = Vec::with_capacity(items.len());
    let mut has_float = false;
    let mut has_int = false;
    let mut has_other = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
            Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
            _ => has_other = true,
        }
        df_args.push(cypher_expr_to_df(item, context)?);
    }

    if df_args.is_empty() {
        let empty_arr =
            ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
        Ok(lit(ScalarValue::List(empty_arr)))
    } else if has_float && has_int && !has_other {
        // Pure int/float literal mix: widen everything to Float64.
        let promoted_args = df_args
            .into_iter()
            .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
            .collect();
        Ok(datafusion::functions_nested::expr_fn::make_array(
            promoted_args,
        ))
    } else {
        // Give untyped NULL literals the type of the first typed literal.
        let non_null_type = df_args.iter().find_map(|e| {
            if let DfExpr::Literal(sv, _) = e {
                let dt = sv.data_type();
                if dt != datafusion::arrow::datatypes::DataType::Null {
                    return Some(dt);
                }
            }
            None
        });
        if let Some(ref target_type) = non_null_type {
            let coerced = df_args
                .into_iter()
                .map(|e| {
                    if matches!(&e, DfExpr::Literal(sv, _) if sv.data_type() == datafusion::arrow::datatypes::DataType::Null)
                    {
                        cast_expr(e, target_type.clone())
                    } else {
                        e
                    }
                })
                .collect();
            Ok(datafusion::functions_nested::expr_fn::make_array(coerced))
        } else {
            Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
        }
    }
}
945
946fn translate_in_expression(
948 expr: &Expr,
949 list: &Expr,
950 context: Option<&TranslationContext>,
951) -> Result<DfExpr> {
952 let left_expr = if let Expr::Variable(var) = expr
957 && let Some(ctx) = context
958 && let Some(kind) = ctx.variable_kinds.get(var)
959 {
960 match kind {
961 VariableKind::Node | VariableKind::Edge => {
962 let id_col = match kind {
963 VariableKind::Node => COL_VID,
964 VariableKind::Edge => COL_EID,
965 _ => unreachable!(),
966 };
967 cast_expr(
968 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
969 datafusion::arrow::datatypes::DataType::Int64,
970 )
971 }
972 _ => cypher_expr_to_df(expr, context)?,
973 }
974 } else {
975 cypher_expr_to_df(expr, context)?
976 };
977
978 if let Expr::List(items) = list {
983 if let Some(json_array) = try_items_to_json(items) {
984 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
986 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
987 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
988 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
989 } else {
990 let expanded: Vec<DfExpr> = items
992 .iter()
993 .map(|item| cypher_expr_to_df(item, context))
994 .collect::<Result<Vec<_>>>()?;
995 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
996 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
997 }
998 } else {
999 let right_expr = cypher_expr_to_df(list, context)?;
1000
1001 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
1006 return Ok(lit(ScalarValue::Boolean(None)));
1007 }
1008
1009 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
1010 }
1011}
1012
1013fn translate_case_expression(
1015 operand: &Option<Box<Expr>>,
1016 when_then: &[(Expr, Expr)],
1017 else_expr: &Option<Box<Expr>>,
1018 context: Option<&TranslationContext>,
1019) -> Result<DfExpr> {
1020 let mut case_builder = if let Some(match_expr) = operand {
1021 let match_df = cypher_expr_to_df(match_expr, context)?;
1022 datafusion::logical_expr::case(match_df)
1023 } else {
1024 datafusion::logical_expr::when(
1025 cypher_expr_to_df(&when_then[0].0, context)?,
1026 cypher_expr_to_df(&when_then[0].1, context)?,
1027 )
1028 };
1029
1030 let start_idx = if operand.is_some() { 0 } else { 1 };
1031 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1032 let when_df = cypher_expr_to_df(when_expr, context)?;
1033 let then_df = cypher_expr_to_df(then_expr, context)?;
1034 case_builder = case_builder.when(when_df, then_df);
1035 }
1036
1037 if let Some(else_e) = else_expr {
1038 let else_df = cypher_expr_to_df(else_e, context)?;
1039 Ok(case_builder.otherwise(else_df)?)
1040 } else {
1041 Ok(case_builder.end()?)
1042 }
1043}
1044
1045fn translate_map_projection(
1047 base: &Expr,
1048 items: &[MapProjectionItem],
1049 context: Option<&TranslationContext>,
1050) -> Result<DfExpr> {
1051 let mut args = Vec::new();
1052 for item in items {
1053 match item {
1054 MapProjectionItem::Property(prop) => {
1055 args.push(lit(prop.clone()));
1056 let prop_expr = cypher_expr_to_df(
1057 &Expr::Property(Box::new(base.clone()), prop.clone()),
1058 context,
1059 )?;
1060 args.push(prop_expr);
1061 }
1062 MapProjectionItem::LiteralEntry(key, expr) => {
1063 args.push(lit(key.clone()));
1064 args.push(cypher_expr_to_df(expr, context)?);
1065 }
1066 MapProjectionItem::Variable(var) => {
1067 args.push(lit(var.clone()));
1068 args.push(DfExpr::Column(Column::from_name(var)));
1069 }
1070 MapProjectionItem::AllProperties => {
1071 args.push(lit("__all__"));
1072 args.push(cypher_expr_to_df(base, context)?);
1073 }
1074 }
1075 }
1076 Ok(dummy_udf_expr("_map_project", args))
1077}
1078
1079fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1082 match expr {
1083 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1084 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1085 Expr::Literal(CypherLiteral::Integer(i)) => {
1086 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1087 }
1088 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1089 .map(serde_json::Value::Number)
1090 .or(Some(serde_json::Value::Null)),
1091 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1092 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1093 Expr::Map(entries) => {
1094 let mut map = serde_json::Map::new();
1095 for (k, v) in entries {
1096 map.insert(k.clone(), try_expr_to_json(v)?);
1097 }
1098 Some(serde_json::Value::Object(map))
1099 }
1100 _ => None,
1101 }
1102}
1103
1104fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1106 items.iter().map(try_expr_to_json).collect()
1107}
1108
1109fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1111 match lit {
1112 CypherLiteral::Null => Ok(ScalarValue::Null),
1113 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1114 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1115 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1116 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1117 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1118 }
1119}
1120
1121fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1123 match value {
1124 Value::Null => Ok(ScalarValue::Null),
1125 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1126 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1127 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1128 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1129 Value::List(items) => {
1130 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1132 let scalars = scalars?;
1133
1134 let data_type = infer_common_scalar_type(&scalars);
1136
1137 let typed_scalars: Vec<ScalarValue> = scalars
1139 .into_iter()
1140 .map(|s| {
1141 if matches!(s, ScalarValue::Null) {
1142 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1143 }
1144
1145 match (s, &data_type) {
1146 (
1147 ScalarValue::Int64(Some(v)),
1148 datafusion::arrow::datatypes::DataType::Float64,
1149 ) => ScalarValue::Float64(Some(v as f64)),
1150 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1151 let s_str = s.to_string();
1153 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1154 }
1155 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1156 if matches!(s, ScalarValue::Utf8(_)) {
1158 s
1159 } else {
1160 ScalarValue::Utf8(Some(s.to_string()))
1161 }
1162 }
1163 (s, _) => s,
1164 }
1165 })
1166 .collect();
1167
1168 if typed_scalars.is_empty() {
1170 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1171 &[],
1172 &data_type,
1173 )))
1174 } else {
1175 Ok(ScalarValue::List(ScalarValue::new_list(
1176 &typed_scalars,
1177 &data_type,
1178 true,
1179 )))
1180 }
1181 }
1182 Value::Map(map) => {
1183 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1186 entries.sort_by_key(|(k, _)| *k);
1187
1188 if entries.is_empty() {
1189 return Ok(ScalarValue::Struct(Arc::new(
1190 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1191 )));
1192 }
1193
1194 let mut fields_arrays = Vec::with_capacity(entries.len());
1195
1196 for (k, v) in entries {
1197 let scalar = value_to_scalar(v)?;
1198 let dt = scalar.data_type();
1199 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1200 let array = scalar.to_array()?;
1201 fields_arrays.push((field, array));
1202 }
1203
1204 Ok(ScalarValue::Struct(Arc::new(
1205 datafusion::arrow::array::StructArray::from(fields_arrays),
1206 )))
1207 }
1208 Value::Temporal(tv) => {
1209 use uni_common::TemporalValue;
1210 match tv {
1211 TemporalValue::Date { days_since_epoch } => {
1212 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1213 }
1214 TemporalValue::LocalTime {
1215 nanos_since_midnight,
1216 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1217 TemporalValue::Time {
1218 nanos_since_midnight,
1219 offset_seconds,
1220 } => {
1221 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1223 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1224
1225 let nanos_arr =
1226 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1227 as ArrayRef;
1228 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1229
1230 let fields = Fields::from(vec![
1231 Field::new(
1232 "nanos_since_midnight",
1233 ArrowDataType::Time64(TimeUnit::Nanosecond),
1234 true,
1235 ),
1236 Field::new("offset_seconds", ArrowDataType::Int32, true),
1237 ]);
1238
1239 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1240 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1241 }
1242 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1243 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1244 ),
1245 TemporalValue::DateTime {
1246 nanos_since_epoch,
1247 offset_seconds,
1248 timezone_name,
1249 } => {
1250 use arrow::array::{
1252 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1253 };
1254 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1255
1256 let nanos_arr =
1257 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1258 as ArrayRef;
1259 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1260 let tz_arr =
1261 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1262
1263 let fields = Fields::from(vec![
1264 Field::new(
1265 "nanos_since_epoch",
1266 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1267 true,
1268 ),
1269 Field::new("offset_seconds", ArrowDataType::Int32, true),
1270 Field::new("timezone_name", ArrowDataType::Utf8, true),
1271 ]);
1272
1273 let struct_arr =
1274 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1275 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1276 }
1277 TemporalValue::Duration {
1278 months,
1279 days,
1280 nanos,
1281 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1282 arrow::datatypes::IntervalMonthDayNano {
1283 months: *months as i32,
1284 days: *days as i32,
1285 nanoseconds: *nanos,
1286 },
1287 ))),
1288 TemporalValue::Btic { lo, hi, meta } => {
1289 let btic = uni_btic::Btic::new(*lo, *hi, *meta)
1290 .map_err(|e| anyhow::anyhow!("invalid BTIC value: {}", e))?;
1291 let packed = uni_btic::encode::encode(&btic);
1292 Ok(ScalarValue::FixedSizeBinary(24, Some(packed.to_vec())))
1293 }
1294 }
1295 }
1296 Value::Vector(v) => {
1297 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1299 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1300 }
1301 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1302 other => {
1304 let json_val: serde_json::Value = other.clone().into();
1305 let json_str = serde_json::to_string(&json_val)
1306 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1307 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1308 }
1309 }
1310}
1311
1312fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1314 match op {
1315 BinaryOp::Eq => Ok(left.eq(right)),
1319 BinaryOp::NotEq => Ok(left.not_eq(right)),
1320 BinaryOp::Lt => Ok(left.lt(right)),
1321 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1322 BinaryOp::Gt => Ok(left.gt(right)),
1323 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1324
1325 BinaryOp::And => Ok(left.and(right)),
1327 BinaryOp::Or => Ok(left.or(right)),
1328 BinaryOp::Xor => {
1329 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1331 }
1332
1333 BinaryOp::Add => {
1339 if is_list_expr(&left) || is_list_expr(&right) {
1340 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1341 } else {
1342 Ok(left + right)
1343 }
1344 }
1345 BinaryOp::Sub => Ok(left - right),
1346 BinaryOp::Mul => Ok(left * right),
1347 BinaryOp::Div => Ok(left / right),
1348 BinaryOp::Mod => Ok(left % right),
1349 BinaryOp::Pow => {
1350 let left_f = datafusion::logical_expr::cast(
1353 left,
1354 datafusion::arrow::datatypes::DataType::Float64,
1355 );
1356 let right_f = datafusion::logical_expr::cast(
1357 right,
1358 datafusion::arrow::datatypes::DataType::Float64,
1359 );
1360 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1361 }
1362
1363 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1365 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1366 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1367
1368 BinaryOp::Regex => {
1369 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1370 }
1371
1372 BinaryOp::ApproxEq => Err(anyhow!(
1373 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1374 )),
1375 }
1376}
1377
/// Early-returns `Some(Err(..))` from the enclosing function when `$df_args` holds fewer
/// than `$n` translated arguments.
///
/// Meant for the `translate_*_function` helpers that return `Option<Result<DfExpr>>`.
/// A literal `1` is routed through `require_arg` and any other count through
/// `require_args`. Both produce the same "`<name>` requires N argument(s)" error, since
/// `require_arg` delegates to `require_args` with a count of 1.
macro_rules! check_args {
    (1, $df_args:expr, $name:expr) => {
        if let Err(e) = require_arg($df_args, $name) {
            return Some(Err(e));
        }
    };
    ($n:expr, $df_args:expr, $name:expr) => {
        if let Err(e) = require_args($df_args, $n, $name) {
            return Some(Err(e));
        }
    };
}
1394
1395fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1398 if df_args.len() < count {
1399 let noun = if count == 1 { "argument" } else { "arguments" };
1400 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1401 }
1402 Ok(())
1403}
1404
1405fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1407 require_args(df_args, 1, func_name)
1408}
1409
1410fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1412 df_args[0].clone()
1413}
1414
1415pub fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1417 DfExpr::Cast(datafusion::logical_expr::Cast {
1418 expr: Box::new(expr),
1419 data_type,
1420 })
1421}
1422
1423pub fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1429 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1430 Arc::new(crate::df_udfs::create_cypher_list_to_cv_udf()),
1431 vec![expr],
1432 ))
1433}
1434
1435pub fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1439 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1440 Arc::new(crate::df_udfs::create_cypher_scalar_to_cv_udf()),
1441 vec![expr],
1442 ))
1443}
1444
1445fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1447 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1448 Box::new(left),
1449 op,
1450 Box::new(right),
1451 ))
1452}
1453
1454pub fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1459 use datafusion::logical_expr::Operator;
1460 match op {
1461 Operator::Eq => Some("_cypher_equal"),
1462 Operator::NotEq => Some("_cypher_not_equal"),
1463 Operator::Lt => Some("_cypher_lt"),
1464 Operator::LtEq => Some("_cypher_lt_eq"),
1465 Operator::Gt => Some("_cypher_gt"),
1466 Operator::GtEq => Some("_cypher_gt_eq"),
1467 _ => None,
1468 }
1469}
1470
1471fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1473 use datafusion::logical_expr::Operator;
1474 match op {
1475 Operator::Plus => Some("_cypher_add"),
1476 Operator::Minus => Some("_cypher_sub"),
1477 Operator::Multiply => Some("_cypher_mul"),
1478 Operator::Divide => Some("_cypher_div"),
1479 Operator::Modulo => Some("_cypher_mod"),
1480 _ => None,
1481 }
1482}
1483
1484fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1489where
1490 F: FnOnce(DfExpr) -> DfExpr,
1491{
1492 require_arg(df_args, func_name)?;
1493 Ok(math_fn(cast_expr(
1494 first_arg(df_args),
1495 datafusion::arrow::datatypes::DataType::Float64,
1496 )))
1497}
1498
1499fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1501 if distinct {
1502 expr.distinct()
1503 .build()
1504 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1505 } else {
1506 Ok(expr)
1507 }
1508}
1509
1510fn translate_aggregate_function(
1512 name_upper: &str,
1513 df_args: &[DfExpr],
1514 distinct: bool,
1515) -> Option<Result<DfExpr>> {
1516 match name_upper {
1517 "COUNT" => {
1518 let expr = if df_args.is_empty() {
1519 datafusion::functions_aggregate::count::count(lit(1i64))
1520 } else {
1521 datafusion::functions_aggregate::count::count(first_arg(df_args))
1522 };
1523 Some(maybe_distinct(expr, distinct, "COUNT"))
1524 }
1525 "SUM" => {
1526 check_args!(1, df_args, "SUM");
1527 let udaf = Arc::new(crate::df_udfs::create_cypher_sum_udaf());
1528 Some(maybe_distinct(
1529 udaf.call(vec![first_arg(df_args)]),
1530 distinct,
1531 "SUM",
1532 ))
1533 }
1534 "AVG" => {
1535 check_args!(1, df_args, "AVG");
1536 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1537 let expr = datafusion::functions_aggregate::average::avg(coerced);
1538 Some(maybe_distinct(expr, distinct, "AVG"))
1539 }
1540 "MIN" => {
1541 check_args!(1, df_args, "MIN");
1542 let udaf = Arc::new(crate::df_udfs::create_cypher_min_udaf());
1543 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1544 }
1545 "MAX" => {
1546 check_args!(1, df_args, "MAX");
1547 let udaf = Arc::new(crate::df_udfs::create_cypher_max_udaf());
1548 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1549 }
1550 "PERCENTILEDISC" => {
1551 if df_args.len() != 2 {
1552 return Some(Err(anyhow!(
1553 "percentileDisc() requires exactly 2 arguments"
1554 )));
1555 }
1556 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1557 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_disc_udaf());
1558 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1559 }
1560 "PERCENTILECONT" => {
1561 if df_args.len() != 2 {
1562 return Some(Err(anyhow!(
1563 "percentileCont() requires exactly 2 arguments"
1564 )));
1565 }
1566 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1567 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_cont_udaf());
1568 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1569 }
1570 "COLLECT" => {
1571 check_args!(1, df_args, "COLLECT");
1572 Some(Ok(crate::df_udfs::create_cypher_collect_expr(
1573 first_arg(df_args),
1574 distinct,
1575 )))
1576 }
1577 "BTIC_MIN" => {
1579 check_args!(1, df_args, "btic_min");
1580 let udaf = Arc::new(crate::df_udfs::create_btic_min_udaf());
1581 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1582 }
1583 "BTIC_MAX" => {
1584 check_args!(1, df_args, "btic_max");
1585 let udaf = Arc::new(crate::df_udfs::create_btic_max_udaf());
1586 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1587 }
1588 "BTIC_SPAN_AGG" => {
1589 check_args!(1, df_args, "btic_span_agg");
1590 let udaf = Arc::new(crate::df_udfs::create_btic_span_agg_udaf());
1591 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1592 }
1593 "BTIC_COUNT_AT" => {
1594 if df_args.len() != 2 {
1595 return Some(Err(anyhow!("btic_count_at requires 2 arguments")));
1596 }
1597 let udaf = Arc::new(crate::df_udfs::create_btic_count_at_udaf());
1598 Some(Ok(udaf.call(df_args.to_vec())))
1599 }
1600 _ => None,
1601 }
1602}
1603
1604fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1607 match name_upper {
1608 "TOSTRING" => {
1609 check_args!(1, df_args, "toString");
1610 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1611 }
1612 "TOINTEGER" | "TOINT" => {
1613 check_args!(1, df_args, "toInteger");
1614 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1615 }
1616 "TOFLOAT" => {
1617 check_args!(1, df_args, "toFloat");
1618 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1619 }
1620 "TOBOOLEAN" | "TOBOOL" => {
1621 check_args!(1, df_args, "toBoolean");
1622 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1623 }
1624 "UPPER" | "TOUPPER" => {
1625 check_args!(1, df_args, "upper");
1626 Some(Ok(datafusion::functions::string::expr_fn::upper(
1627 first_arg(df_args),
1628 )))
1629 }
1630 "LOWER" | "TOLOWER" => {
1631 check_args!(1, df_args, "lower");
1632 Some(Ok(datafusion::functions::string::expr_fn::lower(
1633 first_arg(df_args),
1634 )))
1635 }
1636 "SUBSTRING" => {
1637 check_args!(2, df_args, "substring");
1638 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1639 }
1640 "TRIM" => {
1641 check_args!(1, df_args, "TRIM");
1642 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1643 first_arg(df_args),
1644 ])))
1645 }
1646 "LTRIM" => {
1647 check_args!(1, df_args, "LTRIM");
1648 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1649 first_arg(df_args),
1650 ])))
1651 }
1652 "RTRIM" => {
1653 check_args!(1, df_args, "RTRIM");
1654 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1655 first_arg(df_args),
1656 ])))
1657 }
1658 "LEFT" => {
1659 check_args!(2, df_args, "left");
1660 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1661 df_args[0].clone(),
1662 df_args[1].clone(),
1663 )))
1664 }
1665 "RIGHT" => {
1666 check_args!(2, df_args, "right");
1667 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1668 df_args[0].clone(),
1669 df_args[1].clone(),
1670 )))
1671 }
1672 "REPLACE" => {
1673 check_args!(3, df_args, "replace");
1674 Some(Ok(datafusion::functions::string::expr_fn::replace(
1675 df_args[0].clone(),
1676 df_args[1].clone(),
1677 df_args[2].clone(),
1678 )))
1679 }
1680 "REVERSE" => {
1681 check_args!(1, df_args, "reverse");
1682 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1683 }
1684 "SPLIT" => {
1685 check_args!(2, df_args, "split");
1686 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1687 }
1688 "SIZE" | "LENGTH" => {
1689 check_args!(1, df_args, name_upper);
1690 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1691 }
1692 _ => None,
1693 }
1694}
1695
1696fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1699 use datafusion::functions::math::expr_fn;
1700
1701 let unary_f64 =
1703 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1704
1705 match name_upper {
1706 "ABS" => {
1707 check_args!(1, df_args, "abs");
1708 Some(Ok(crate::df_udfs::cypher_abs_expr(first_arg(df_args))))
1712 }
1713 "CEIL" | "CEILING" => {
1714 check_args!(1, df_args, "ceil");
1715 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1716 }
1717 "FLOOR" => {
1718 check_args!(1, df_args, "floor");
1719 Some(Ok(expr_fn::floor(first_arg(df_args))))
1720 }
1721 "ROUND" => {
1722 check_args!(1, df_args, "round");
1723 let args = if df_args.len() == 1 {
1724 vec![first_arg(df_args)]
1725 } else {
1726 vec![df_args[0].clone(), df_args[1].clone()]
1727 };
1728 Some(Ok(expr_fn::round(args)))
1729 }
1730 "SIGN" => {
1731 check_args!(1, df_args, "sign");
1732 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1733 Some(Ok(expr_fn::signum(coerced)))
1734 }
1735 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1736 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1737 "LOG10" => unary_f64("log10", expr_fn::log10),
1738 "EXP" => unary_f64("exp", expr_fn::exp),
1739 "SIN" => unary_f64("sin", expr_fn::sin),
1740 "COS" => unary_f64("cos", expr_fn::cos),
1741 "TAN" => unary_f64("tan", expr_fn::tan),
1742 "ASIN" => unary_f64("asin", expr_fn::asin),
1743 "ACOS" => unary_f64("acos", expr_fn::acos),
1744 "ATAN" => unary_f64("atan", expr_fn::atan),
1745 "ATAN2" => {
1746 check_args!(2, df_args, "atan2");
1747 let cast_f64 =
1748 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1749 Some(Ok(expr_fn::atan2(
1750 cast_f64(df_args[0].clone()),
1751 cast_f64(df_args[1].clone()),
1752 )))
1753 }
1754 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1755 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1756 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1757 _ => None,
1758 }
1759}
1760
1761fn translate_temporal_function(
1764 name_upper: &str,
1765 name: &str,
1766 df_args: &[DfExpr],
1767 context: Option<&TranslationContext>,
1768) -> Option<Result<DfExpr>> {
1769 match name_upper {
1770 "DATE"
1771 | "TIME"
1772 | "LOCALTIME"
1773 | "LOCALDATETIME"
1774 | "DATETIME"
1775 | "DURATION"
1776 | "YEAR"
1777 | "MONTH"
1778 | "DAY"
1779 | "HOUR"
1780 | "MINUTE"
1781 | "SECOND"
1782 | "DURATION.BETWEEN"
1783 | "DURATION.INMONTHS"
1784 | "DURATION.INDAYS"
1785 | "DURATION.INSECONDS"
1786 | "DATETIME.FROMEPOCH"
1787 | "DATETIME.FROMEPOCHMILLIS"
1788 | "DATE.TRUNCATE"
1789 | "TIME.TRUNCATE"
1790 | "DATETIME.TRUNCATE"
1791 | "LOCALDATETIME.TRUNCATE"
1792 | "LOCALTIME.TRUNCATE"
1793 | "DATETIME.TRANSACTION"
1794 | "DATETIME.STATEMENT"
1795 | "DATETIME.REALTIME"
1796 | "DATE.TRANSACTION"
1797 | "DATE.STATEMENT"
1798 | "DATE.REALTIME"
1799 | "TIME.TRANSACTION"
1800 | "TIME.STATEMENT"
1801 | "TIME.REALTIME"
1802 | "LOCALTIME.TRANSACTION"
1803 | "LOCALTIME.STATEMENT"
1804 | "LOCALTIME.REALTIME"
1805 | "LOCALDATETIME.TRANSACTION"
1806 | "LOCALDATETIME.STATEMENT"
1807 | "LOCALDATETIME.REALTIME" => {
1808 let stmt_time = context.map(|c| c.statement_time);
1812 if can_constant_fold(name_upper, df_args)
1813 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1814 {
1815 return Some(Ok(folded));
1816 }
1817 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1818 }
1819 _ => None,
1820 }
1821}
1822
1823fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1825 if name.contains("REALTIME") {
1827 return false;
1828 }
1829 if args.is_empty() {
1837 return matches!(
1838 name,
1839 "DATE"
1840 | "TIME"
1841 | "LOCALTIME"
1842 | "LOCALDATETIME"
1843 | "DATETIME"
1844 | "DATE.STATEMENT"
1845 | "TIME.STATEMENT"
1846 | "LOCALTIME.STATEMENT"
1847 | "LOCALDATETIME.STATEMENT"
1848 | "DATETIME.STATEMENT"
1849 | "DATE.TRANSACTION"
1850 | "TIME.TRANSACTION"
1851 | "LOCALTIME.TRANSACTION"
1852 | "LOCALDATETIME.TRANSACTION"
1853 | "DATETIME.TRANSACTION"
1854 );
1855 }
1856 args.iter().all(is_constant_expr)
1858}
1859
1860fn is_constant_expr(expr: &DfExpr) -> bool {
1862 match expr {
1863 DfExpr::Literal(_, _) => true,
1864 DfExpr::ScalarFunction(func) => {
1865 func.args.iter().all(is_constant_expr)
1867 }
1868 _ => false,
1869 }
1870}
1871
1872fn try_constant_fold_temporal(
1878 name: &str,
1879 args: &[DfExpr],
1880 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1881) -> Result<DfExpr> {
1882 let val_args: Vec<Value> = args
1884 .iter()
1885 .map(extract_constant_value)
1886 .collect::<Result<_>>()?;
1887
1888 let result = if val_args.is_empty() {
1890 if let Some(frozen) = stmt_time {
1891 crate::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1892 } else {
1893 crate::datetime::eval_datetime_function(name, &val_args)?
1894 }
1895 } else {
1896 crate::datetime::eval_datetime_function(name, &val_args)?
1897 };
1898
1899 let scalar = value_to_scalar(&result)?;
1901 Ok(DfExpr::Literal(scalar, None))
1902}
1903
1904fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1906 use crate::df_udfs::scalar_to_value;
1907 match expr {
1908 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1909 DfExpr::ScalarFunction(func) => {
1910 let mut map = std::collections::HashMap::new();
1913 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1914 for chunk in pairs.chunks(2) {
1915 if let [key_expr, val_expr] = chunk {
1916 let key = match key_expr {
1918 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1919 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1920 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1921 };
1922 let val = extract_constant_value(val_expr)?;
1923 map.insert(key, val);
1924 } else {
1925 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1926 }
1927 }
1928 Ok(Value::Map(map))
1929 }
1930 _ => Err(anyhow::anyhow!(
1931 "Cannot extract constant value from expression"
1932 )),
1933 }
1934}
1935
1936fn translate_btic_function(
1939 name_upper: &str,
1940 name: &str,
1941 df_args: &[DfExpr],
1942) -> Option<Result<DfExpr>> {
1943 if crate::expr_eval::is_btic_function(name_upper) {
1944 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1945 } else {
1946 None
1947 }
1948}
1949
1950fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1953 match name_upper {
1954 "HEAD" => {
1955 check_args!(1, df_args, "head");
1956 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1957 }
1958 "LAST" => {
1959 check_args!(1, df_args, "last");
1960 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1961 }
1962 "TAIL" => {
1963 check_args!(1, df_args, "tail");
1964 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1965 }
1966 "RANGE" => {
1967 check_args!(2, df_args, "range");
1968 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1969 }
1970 _ => None,
1971 }
1972}
1973
1974fn translate_graph_function(
1977 name_upper: &str,
1978 name: &str,
1979 df_args: &[DfExpr],
1980 args: &[Expr],
1981 context: Option<&TranslationContext>,
1982) -> Option<Result<DfExpr>> {
1983 match name_upper {
1984 "ID" => {
1985 if let Some(Expr::Variable(var)) = args.first() {
1988 let is_edge = context.is_some_and(|ctx| {
1989 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1990 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1991 });
1992 let id_suffix = if is_edge { COL_EID } else { COL_VID };
1993 Some(Ok(DfExpr::Column(Column::from_name(format!(
1994 "{}.{}",
1995 var, id_suffix
1996 )))))
1997 } else {
1998 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
1999 }
2000 }
2001 "CREATED_AT" | "UPDATED_AT" => {
2002 if let Some(Expr::Variable(var)) = args.first() {
2006 let suffix = if name_upper == "CREATED_AT" {
2007 "_created_at"
2008 } else {
2009 "_updated_at"
2010 };
2011 Some(Ok(DfExpr::Column(Column::from_name(format!(
2012 "{}.{}",
2013 var, suffix
2014 )))))
2015 } else {
2016 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2017 }
2018 }
2019 "LABELS" | "KEYS" => {
2020 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2025 }
2026 "TYPE" => {
2027 if let Some(Expr::Variable(var)) = args.first()
2031 && let Some(ctx) = context
2032 && let Some(label) = ctx.variable_labels.get(var)
2033 {
2034 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
2037 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
2038 expr: None,
2039 when_then_expr: vec![(
2040 Box::new(eid_col.is_not_null()),
2041 Box::new(lit(label.clone())),
2042 )],
2043 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
2044 })));
2045 }
2046 if let Some(Expr::Variable(var)) = args.first()
2050 && context
2051 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
2052 {
2053 return Some(Ok(DfExpr::Column(Column::from_name(format!(
2054 "{}.{}",
2055 var, COL_TYPE
2056 )))));
2057 }
2058 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
2059 }
2060 "PROPERTIES" => {
2061 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2064 }
2065 "UNI.TEMPORAL.VALIDAT" => {
2066 if let (
2069 Some(Expr::Variable(var)),
2070 Some(Expr::Literal(CypherLiteral::String(start_prop))),
2071 Some(Expr::Literal(CypherLiteral::String(end_prop))),
2072 Some(ts_expr),
2073 ) = (args.first(), args.get(1), args.get(2), args.get(3))
2074 {
2075 let start_col =
2076 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
2077 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
2078 let ts = match cypher_expr_to_df(ts_expr, context) {
2079 Ok(ts) => ts,
2080 Err(e) => return Some(Err(e)),
2081 };
2082
2083 let start_check = start_col.lt_eq(ts.clone());
2085 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
2087 let end_after = end_col.gt(ts);
2088 let end_check = end_null.or(end_after);
2089
2090 Some(Ok(start_check.and(end_check)))
2091 } else {
2092 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2094 }
2095 }
2096 "STARTNODE" | "ENDNODE" => {
2097 let mut udf_args = df_args.to_vec();
2100 let mut seen = std::collections::HashSet::new();
2101 if let Some(ctx) = context {
2102 for (var, kind) in &ctx.variable_kinds {
2104 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2105 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2106 }
2107 }
2108 for var in &ctx.node_variable_hints {
2111 if seen.insert(var.clone()) {
2112 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2113 }
2114 }
2115 }
2116 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2117 }
2118 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2119 "HASLABEL" => {
2120 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2121 return Some(Err(e));
2122 }
2123 if let Some(Expr::Variable(var)) = args.first() {
2125 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2126 let labels_col =
2128 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2129 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2130 labels_col,
2131 lit(label.clone()),
2132 )))
2133 } else {
2134 Some(Err(anyhow::anyhow!(
2136 "hasLabel requires string literal as second argument for DataFusion translation"
2137 )))
2138 }
2139 } else {
2140 Some(Err(anyhow::anyhow!(
2142 "hasLabel requires variable as first argument for DataFusion translation"
2143 )))
2144 }
2145 }
2146 _ => None,
2147 }
2148}
2149
2150fn translate_function_call(
2152 name: &str,
2153 args: &[Expr],
2154 distinct: bool,
2155 context: Option<&TranslationContext>,
2156) -> Result<DfExpr> {
2157 let df_args: Vec<DfExpr> = args
2158 .iter()
2159 .map(|arg| cypher_expr_to_df(arg, context))
2160 .collect::<Result<Vec<_>>>()?;
2161
2162 let name_upper = name.to_uppercase();
2163
2164 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2168 return result;
2169 }
2170
2171 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2172 return result;
2173 }
2174
2175 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2176 return result;
2177 }
2178
2179 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2180 return result;
2181 }
2182
2183 if let Some(result) = translate_btic_function(&name_upper, name, &df_args) {
2184 return result;
2185 }
2186
2187 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2188 return result;
2189 }
2190
2191 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2192 return result;
2193 }
2194
2195 match name_upper.as_str() {
2197 "COALESCE" => {
2198 require_arg(&df_args, "coalesce")?;
2199 if df_args.len() == 1 {
2204 return Ok(df_args.into_iter().next().unwrap());
2205 }
2206 let n = df_args.len();
2207 let (init, last) = df_args.split_at(n - 1);
2208 let mut builder = datafusion::logical_expr::conditional_expressions::CaseBuilder::new(
2209 None,
2210 vec![],
2211 vec![],
2212 None,
2213 );
2214 for arg in init {
2215 builder.when(arg.clone().is_not_null(), arg.clone());
2216 }
2217 return Ok(builder.otherwise(last[0].clone())?);
2218 }
2219 "NULLIF" => {
2220 require_args(&df_args, 2, "nullif")?;
2221 return Ok(datafusion::functions::expr_fn::nullif(
2222 df_args[0].clone(),
2223 df_args[1].clone(),
2224 ));
2225 }
2226 _ => {}
2227 }
2228
2229 match name_upper.as_str() {
2231 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
2232 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2233 }
2234 _ => {}
2235 }
2236
2237 Ok(dummy_udf_expr(name, df_args))
2239}
2240
/// Placeholder scalar UDF emitted during Cypher → DataFusion translation.
///
/// It carries only a name, a permissive signature and a return type, so logical plans
/// can be built and type-checked. Its `invoke_with_args` always returns a plan error
/// asking for the UDF to be registered. Presumably the real implementation with the same
/// name is registered on the SessionContext before execution (per that error message).
#[derive(Debug)]
struct DummyUdf {
    // Function name; `dummy_udf_expr` passes it in lower case.
    name: String,
    // Accepts any number of arguments of any type (set by `DummyUdf::new`).
    signature: datafusion::logical_expr::Signature,
    // Return type chosen from the name by `dummy_udf_return_type`.
    ret_type: datafusion::arrow::datatypes::DataType,
}
2251
2252impl DummyUdf {
2253 fn new(name: String) -> Self {
2254 let ret_type = dummy_udf_return_type(&name);
2255 Self {
2256 name,
2257 signature: datafusion::logical_expr::Signature::variadic_any(
2258 datafusion::logical_expr::Volatility::Immutable,
2259 ),
2260 ret_type,
2261 }
2262 }
2263}
2264
2265fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2278 use datafusion::arrow::datatypes::DataType;
2279 match name {
2280 "_cypher_add"
2284 | "_cypher_sub"
2285 | "_cypher_mul"
2286 | "_cypher_div"
2287 | "_cypher_mod"
2288 | "_cypher_list_concat"
2289 | "_cypher_list_append"
2290 | "_make_cypher_list"
2291 | "_map_project"
2292 | "_cypher_list_to_cv"
2293 | "_cypher_tail" => DataType::LargeBinary,
2294 _ => DataType::Null,
2298 }
2299}
2300
2301impl PartialEq for DummyUdf {
2302 fn eq(&self, other: &Self) -> bool {
2303 self.name == other.name
2304 }
2305}
2306
2307impl Eq for DummyUdf {}
2308
2309impl Hash for DummyUdf {
2310 fn hash<H: Hasher>(&self, state: &mut H) {
2311 self.name.hash(state);
2312 }
2313}
2314
2315pub fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2317 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2318 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2319 DummyUdf::new(name.to_lowercase()),
2320 )),
2321 args,
2322 })
2323}
2324
2325impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2326 fn as_any(&self) -> &dyn std::any::Any {
2327 self
2328 }
2329
2330 fn name(&self) -> &str {
2331 &self.name
2332 }
2333
2334 fn signature(&self) -> &datafusion::logical_expr::Signature {
2335 &self.signature
2336 }
2337
2338 fn return_type(
2339 &self,
2340 arg_types: &[datafusion::arrow::datatypes::DataType],
2341 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2342 match self.name.as_str() {
2347 "_cypher_add" | "_cypher_sub" | "_cypher_mul" | "_cypher_div" | "_cypher_mod" => {
2348 Ok(crate::df_udfs::cypher_arith_return_type(arg_types))
2349 }
2350 _ => Ok(self.ret_type.clone()),
2354 }
2355 }
2356
2357 fn invoke_with_args(
2358 &self,
2359 _args: ScalarFunctionArgs,
2360 ) -> datafusion::error::Result<ColumnarValue> {
2361 Err(datafusion::error::DataFusionError::Plan(format!(
2362 "UDF '{}' is not registered. Register it via SessionContext.",
2363 self.name
2364 )))
2365 }
2366}
2367
2368pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2372 let mut properties = Vec::new();
2373 collect_properties_recursive(expr, &mut properties);
2374 properties.sort();
2375 properties.dedup();
2376 properties
2377}
2378
/// Recursively appends every `(variable, property)` pair read anywhere in `expr` to
/// `properties`. Duplicates are allowed; `collect_properties` sorts and dedups them.
///
/// A pair is recorded for `base.prop`, for `base['literal']`, and for map-projection
/// property items, whenever `extract_variable_name` resolves the base to a variable.
/// An all-properties map projection records the sentinel property `"*"`.
/// `PatternComprehension`, `Exists`, `CountSubquery` and `CollectSubquery` expressions
/// are not descended into.
fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
    match expr {
        // Not descended into: references inside it are not collected here.
        Expr::PatternComprehension { .. } => {}
        Expr::Property(base, prop) => {
            if let Ok(var_name) = extract_variable_name(base) {
                properties.push((var_name, prop.clone()));
            }
            // Keep walking the base so nested accesses are found too.
            collect_properties_recursive(base, properties);
        }
        Expr::ArrayIndex { array, index } => {
            // `var['prop']` with a string-literal key is recorded like `var.prop`.
            if let Ok(var_name) = extract_variable_name(array)
                && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
            {
                properties.push((var_name, prop_name.clone()));
            }
            collect_properties_recursive(array, properties);
            collect_properties_recursive(index, properties);
        }
        Expr::ArraySlice { array, start, end } => {
            collect_properties_recursive(array, properties);
            if let Some(s) = start {
                collect_properties_recursive(s, properties);
            }
            if let Some(e) = end {
                collect_properties_recursive(e, properties);
            }
        }
        Expr::List(items) => {
            for item in items {
                collect_properties_recursive(item, properties);
            }
        }
        Expr::Map(entries) => {
            // Only values can reference properties; keys are plain names.
            for (_, value) in entries {
                collect_properties_recursive(value, properties);
            }
        }
        Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
            collect_properties_recursive(inner, properties);
        }
        Expr::FunctionCall { args, .. } => {
            // NOTE(review): property names passed as string literals (e.g. the
            // start/end arguments of uni.temporal.validAt, which the DataFusion
            // translation turns into `var.prop` columns) are not recorded here —
            // confirm another pass collects them.
            for arg in args {
                collect_properties_recursive(arg, properties);
            }
        }
        Expr::BinaryOp { left, right, .. } => {
            collect_properties_recursive(left, properties);
            collect_properties_recursive(right, properties);
        }
        Expr::UnaryOp { expr, .. } => {
            collect_properties_recursive(expr, properties);
        }
        Expr::Case {
            expr,
            when_then,
            else_expr,
        } => {
            if let Some(e) = expr {
                collect_properties_recursive(e, properties);
            }
            for (when_e, then_e) in when_then {
                collect_properties_recursive(when_e, properties);
                collect_properties_recursive(then_e, properties);
            }
            if let Some(e) = else_expr {
                collect_properties_recursive(e, properties);
            }
        }
        Expr::Reduce {
            init, list, expr, ..
        } => {
            collect_properties_recursive(init, properties);
            collect_properties_recursive(list, properties);
            collect_properties_recursive(expr, properties);
        }
        Expr::Quantifier {
            list, predicate, ..
        } => {
            collect_properties_recursive(list, properties);
            collect_properties_recursive(predicate, properties);
        }
        Expr::ListComprehension {
            list,
            where_clause,
            map_expr,
            ..
        } => {
            collect_properties_recursive(list, properties);
            if let Some(filter) = where_clause {
                collect_properties_recursive(filter, properties);
            }
            collect_properties_recursive(map_expr, properties);
        }
        Expr::In { expr, list } => {
            collect_properties_recursive(expr, properties);
            collect_properties_recursive(list, properties);
        }
        Expr::ValidAt {
            entity, timestamp, ..
        } => {
            // NOTE(review): only the entity and timestamp sub-expressions are walked;
            // any fields elided by `..` are ignored — confirm nothing there names
            // properties that must be loaded.
            collect_properties_recursive(entity, properties);
            collect_properties_recursive(timestamp, properties);
        }
        Expr::MapProjection { base, items } => {
            collect_properties_recursive(base, properties);
            for item in items {
                match item {
                    uni_cypher::ast::MapProjectionItem::Property(prop) => {
                        if let Ok(var_name) = extract_variable_name(base) {
                            properties.push((var_name, prop.clone()));
                        }
                    }
                    uni_cypher::ast::MapProjectionItem::AllProperties => {
                        // "*" is recorded as a sentinel for "all properties of var".
                        if let Ok(var_name) = extract_variable_name(base) {
                            properties.push((var_name, "*".to_string()));
                        }
                    }
                    uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
                        collect_properties_recursive(expr, properties);
                    }
                    uni_cypher::ast::MapProjectionItem::Variable(_) => {}
                }
            }
        }
        Expr::LabelCheck { expr, .. } => {
            collect_properties_recursive(expr, properties);
        }
        // Leaves: nothing to collect.
        Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
        // Subqueries are not descended into.
        Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
    }
}
2511
2512pub fn wider_numeric_type(
2519 a: &datafusion::arrow::datatypes::DataType,
2520 b: &datafusion::arrow::datatypes::DataType,
2521) -> datafusion::arrow::datatypes::DataType {
2522 use datafusion::arrow::datatypes::DataType;
2523
2524 fn numeric_rank(dt: &DataType) -> u8 {
2525 match dt {
2526 DataType::Int8 | DataType::UInt8 => 1,
2527 DataType::Int16 | DataType::UInt16 => 2,
2528 DataType::Int32 | DataType::UInt32 => 3,
2529 DataType::Int64 | DataType::UInt64 => 4,
2530 DataType::Float16 => 5,
2531 DataType::Float32 => 6,
2532 DataType::Float64 => 7,
2533 _ => 0,
2534 }
2535 }
2536
2537 if numeric_rank(a) >= numeric_rank(b) {
2538 a.clone()
2539 } else {
2540 b.clone()
2541 }
2542}
2543
2544fn resolve_column_type_fallback(
2550 expr: &DfExpr,
2551 schema: &datafusion::common::DFSchema,
2552) -> Option<datafusion::arrow::datatypes::DataType> {
2553 if let DfExpr::Column(col) = expr {
2554 let col_name = &col.name;
2555 for (_, field) in schema.iter() {
2557 if field.name() == col_name {
2558 return Some(field.data_type().clone());
2559 }
2560 }
2561 }
2562 None
2563}
2564
2565fn contains_division(expr: &DfExpr) -> bool {
2568 match expr {
2569 DfExpr::BinaryExpr(b) => {
2570 b.op == datafusion::logical_expr::Operator::Divide
2571 || contains_division(&b.left)
2572 || contains_division(&b.right)
2573 }
2574 DfExpr::Cast(c) => contains_division(&c.expr),
2575 DfExpr::TryCast(c) => contains_division(&c.expr),
2576 _ => false,
2577 }
2578}
2579
2580pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2586 use datafusion::arrow::datatypes::DataType;
2587 use datafusion::logical_expr::ExprSchemable;
2588
2589 match expr {
2590 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2591 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2592 DfExpr::Case(case) => coerce_case_expr(case, schema),
2593 DfExpr::InList(in_list) => {
2594 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2595 let coerced_list = in_list
2596 .list
2597 .iter()
2598 .map(|e| apply_type_coercion(e, schema))
2599 .collect::<Result<Vec<_>>>()?;
2600 let expr_type = coerced_expr
2601 .get_type(schema)
2602 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2603 crate::cypher_type_coerce::build_cypher_in_list(
2604 coerced_expr,
2605 &expr_type,
2606 coerced_list,
2607 in_list.negated,
2608 schema,
2609 )
2610 }
2611 DfExpr::Not(inner) => {
2612 let coerced_inner = apply_type_coercion(inner, schema)?;
2613 let inner_type = coerced_inner.get_type(schema).ok();
2614 let final_inner = if inner_type
2615 .as_ref()
2616 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2617 {
2618 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2619 } else if inner_type
2620 .as_ref()
2621 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2622 {
2623 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2624 } else {
2625 coerced_inner
2626 };
2627 Ok(DfExpr::Not(Box::new(final_inner)))
2628 }
2629 DfExpr::IsNull(inner) => {
2630 let coerced_inner = apply_type_coercion(inner, schema)?;
2631 Ok(coerced_inner.is_null())
2632 }
2633 DfExpr::IsNotNull(inner) => {
2634 let coerced_inner = apply_type_coercion(inner, schema)?;
2635 Ok(coerced_inner.is_not_null())
2636 }
2637 DfExpr::Negative(inner) => {
2638 let coerced_inner = apply_type_coercion(inner, schema)?;
2639 let inner_type = coerced_inner.get_type(schema).ok();
2640 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2641 Ok(dummy_udf_expr(
2642 "_cypher_mul",
2643 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2644 ))
2645 } else {
2646 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2647 }
2648 }
2649 DfExpr::Cast(cast) => {
2650 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2651 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2652 Box::new(coerced_inner),
2653 cast.data_type.clone(),
2654 )))
2655 }
2656 DfExpr::TryCast(cast) => {
2657 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2658 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2659 Box::new(coerced_inner),
2660 cast.data_type.clone(),
2661 )))
2662 }
2663 DfExpr::Alias(alias) => {
2664 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2665 Ok(coerced_inner.alias(alias.name.clone()))
2666 }
2667 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2668 _ => Ok(expr.clone()),
2669 }
2670}
2671
2672fn coerce_logical_operands(
2674 left: DfExpr,
2675 right: DfExpr,
2676 op: datafusion::logical_expr::Operator,
2677 schema: &datafusion::common::DFSchema,
2678) -> Option<DfExpr> {
2679 use datafusion::arrow::datatypes::DataType;
2680 use datafusion::logical_expr::ExprSchemable;
2681
2682 if !matches!(
2683 op,
2684 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2685 ) {
2686 return None;
2687 }
2688 let left_type = left.get_type(schema).ok();
2689 let right_type = right.get_type(schema).ok();
2690 let left_needs_cast = left_type
2691 .as_ref()
2692 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2693 let right_needs_cast = right_type
2694 .as_ref()
2695 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2696 let left_is_lb = left_type
2697 .as_ref()
2698 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2699 let right_is_lb = right_type
2700 .as_ref()
2701 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2702 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2703 return None;
2704 }
2705 let coerced_left = if left_is_lb {
2706 dummy_udf_expr("_cv_to_bool", vec![left])
2707 } else if left_needs_cast {
2708 datafusion::logical_expr::cast(left, DataType::Boolean)
2709 } else {
2710 left
2711 };
2712 let coerced_right = if right_is_lb {
2713 dummy_udf_expr("_cv_to_bool", vec![right])
2714 } else if right_needs_cast {
2715 datafusion::logical_expr::cast(right, DataType::Boolean)
2716 } else {
2717 right
2718 };
2719 Some(binary_expr(coerced_left, op, coerced_right))
2720}
2721
2722#[expect(
2725 clippy::too_many_arguments,
2726 reason = "Binary coercion needs all context"
2727)]
2728fn coerce_large_binary_ops(
2729 left: &DfExpr,
2730 right: &DfExpr,
2731 left_type: &datafusion::arrow::datatypes::DataType,
2732 right_type: &datafusion::arrow::datatypes::DataType,
2733 left_is_null: bool,
2734 op: datafusion::logical_expr::Operator,
2735 is_comparison: bool,
2736 is_arithmetic: bool,
2737) -> Option<Result<DfExpr>> {
2738 use datafusion::arrow::datatypes::DataType;
2739 use datafusion::logical_expr::Operator;
2740
2741 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2742 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2743
2744 if op == Operator::Plus {
2745 if left_is_lb && right_is_lb {
2746 return Some(Ok(dummy_udf_expr(
2747 "_cypher_add",
2748 vec![left.clone(), right.clone()],
2749 )));
2750 }
2751 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2752 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2753 if left_is_native_list && right_is_native_list {
2754 return Some(Ok(dummy_udf_expr(
2755 "_cypher_list_concat",
2756 vec![left.clone(), right.clone()],
2757 )));
2758 }
2759 if left_is_native_list || right_is_native_list {
2760 return Some(Ok(dummy_udf_expr(
2761 "_cypher_list_append",
2762 vec![left.clone(), right.clone()],
2763 )));
2764 }
2765 }
2766
2767 if (left_is_lb || right_is_lb) && is_comparison {
2768 if let Some(udf_name) = comparison_udf_name(op) {
2769 return Some(Ok(dummy_udf_expr(
2770 udf_name,
2771 vec![left.clone(), right.clone()],
2772 )));
2773 }
2774 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2775 }
2776
2777 if (left_is_lb || right_is_lb) && is_arithmetic {
2778 let udf_name =
2779 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2780 return Some(Ok(dummy_udf_expr(
2781 udf_name,
2782 vec![left.clone(), right.clone()],
2783 )));
2784 }
2785
2786 None
2787}
2788
2789fn coerce_temporal_comparisons(
2791 left: DfExpr,
2792 right: DfExpr,
2793 left_type: &datafusion::arrow::datatypes::DataType,
2794 right_type: &datafusion::arrow::datatypes::DataType,
2795 op: datafusion::logical_expr::Operator,
2796 is_comparison: bool,
2797) -> Option<DfExpr> {
2798 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2799 use datafusion::logical_expr::Operator;
2800
2801 if !is_comparison {
2802 return None;
2803 }
2804
2805 if uni_common::core::schema::is_datetime_struct(left_type)
2807 && uni_common::core::schema::is_datetime_struct(right_type)
2808 {
2809 return Some(binary_expr(
2810 extract_datetime_nanos(left),
2811 op,
2812 extract_datetime_nanos(right),
2813 ));
2814 }
2815
2816 if uni_common::core::schema::is_time_struct(left_type)
2818 && uni_common::core::schema::is_time_struct(right_type)
2819 {
2820 return Some(binary_expr(
2821 extract_time_nanos(left),
2822 op,
2823 extract_time_nanos(right),
2824 ));
2825 }
2826
2827 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2829 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2830
2831 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2832 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2833 {
2834 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2835 extract_datetime_nanos(left)
2836 } else {
2837 left
2838 };
2839 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2840 extract_datetime_nanos(right)
2841 } else {
2842 right
2843 };
2844 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2845 return Some(binary_expr(
2846 cast_expr(left_nanos, ts_type.clone()),
2847 op,
2848 cast_expr(right_nanos, ts_type),
2849 ));
2850 }
2851
2852 let left_is_duration = matches!(left_type, DataType::Interval(_));
2856 let right_is_duration = matches!(right_type, DataType::Interval(_));
2857 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2858 || uni_common::core::schema::is_time_struct(left_type)
2859 || matches!(
2860 left_type,
2861 DataType::Timestamp(_, _)
2862 | DataType::Date32
2863 | DataType::Date64
2864 | DataType::Time32(_)
2865 | DataType::Time64(_)
2866 );
2867 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2868 || uni_common::core::schema::is_time_struct(right_type)
2869 || matches!(
2870 right_type,
2871 DataType::Timestamp(_, _)
2872 | DataType::Date32
2873 | DataType::Date64
2874 | DataType::Time32(_)
2875 | DataType::Time64(_)
2876 );
2877
2878 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2879 {
2880 return Some(match op {
2881 Operator::Eq => lit(false),
2882 Operator::NotEq => lit(true),
2883 _ => lit(ScalarValue::Boolean(None)),
2884 });
2885 }
2886
2887 None
2888}
2889
2890fn coerce_mismatched_types(
2893 left: DfExpr,
2894 right: DfExpr,
2895 left_type: &datafusion::arrow::datatypes::DataType,
2896 right_type: &datafusion::arrow::datatypes::DataType,
2897 op: datafusion::logical_expr::Operator,
2898 is_comparison: bool,
2899) -> Option<Result<DfExpr>> {
2900 use datafusion::arrow::datatypes::DataType;
2901 use datafusion::logical_expr::Operator;
2902
2903 if left_type == right_type {
2904 return None;
2905 }
2906
2907 if left_type.is_numeric() && right_type.is_numeric() {
2909 if left_type == &DataType::Int64
2910 && right_type == &DataType::UInt64
2911 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2912 {
2913 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2914 return Some(Ok(binary_expr(coerced_left, op, right)));
2915 }
2916 if left_type == &DataType::UInt64
2917 && right_type == &DataType::Int64
2918 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2919 {
2920 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2921 return Some(Ok(binary_expr(left, op, coerced_right)));
2922 }
2923 let target = wider_numeric_type(left_type, right_type);
2924 let coerced_left = if *left_type != target {
2925 datafusion::logical_expr::cast(left, target.clone())
2926 } else {
2927 left
2928 };
2929 let coerced_right = if *right_type != target {
2930 datafusion::logical_expr::cast(right, target)
2931 } else {
2932 right
2933 };
2934 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2935 }
2936
2937 if is_comparison {
2939 match (left_type, right_type) {
2940 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2941 let right = normalize_datetime_literal(right);
2942 return Some(Ok(binary_expr(
2943 left,
2944 op,
2945 datafusion::logical_expr::cast(right, ts.clone()),
2946 )));
2947 }
2948 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2949 let left = normalize_datetime_literal(left);
2950 return Some(Ok(binary_expr(
2951 datafusion::logical_expr::cast(left, ts.clone()),
2952 op,
2953 right,
2954 )));
2955 }
2956 _ => {}
2957 }
2958 }
2959
2960 if is_comparison
2962 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2963 {
2964 let l_inner = l_field.data_type();
2965 let r_inner = r_field.data_type();
2966 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2967 let target_inner = wider_numeric_type(l_inner, r_inner);
2968 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2969 "item",
2970 target_inner,
2971 true,
2972 )));
2973 return Some(Ok(binary_expr(
2974 datafusion::logical_expr::cast(left, target_type.clone()),
2975 op,
2976 datafusion::logical_expr::cast(right, target_type),
2977 )));
2978 }
2979 }
2980
2981 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2983 if op == Operator::Plus {
2984 return Some(crate::cypher_type_coerce::build_cypher_plus(
2985 left, left_type, right, right_type,
2986 ));
2987 }
2988 if is_comparison {
2989 return Some(Ok(crate::cypher_type_coerce::build_cypher_comparison(
2990 left, left_type, right, right_type, op,
2991 )));
2992 }
2993 }
2994
2995 None
2996}
2997
2998fn coerce_list_comparisons(
3000 left: DfExpr,
3001 right: DfExpr,
3002 left_type: &datafusion::arrow::datatypes::DataType,
3003 right_type: &datafusion::arrow::datatypes::DataType,
3004 op: datafusion::logical_expr::Operator,
3005 is_comparison: bool,
3006) -> Option<DfExpr> {
3007 use datafusion::arrow::datatypes::DataType;
3008 use datafusion::logical_expr::Operator;
3009
3010 if !is_comparison {
3011 return None;
3012 }
3013
3014 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
3015 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
3016
3017 if left_is_list
3019 && right_is_list
3020 && matches!(
3021 op,
3022 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
3023 )
3024 {
3025 let op_str = match op {
3026 Operator::Lt => "lt",
3027 Operator::LtEq => "lteq",
3028 Operator::Gt => "gt",
3029 Operator::GtEq => "gteq",
3030 _ => unreachable!(),
3031 };
3032 return Some(dummy_udf_expr(
3033 "_cypher_list_compare",
3034 vec![left, right, lit(op_str)],
3035 ));
3036 }
3037
3038 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
3040 let udf_name =
3041 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
3042 return Some(dummy_udf_expr(udf_name, vec![left, right]));
3043 }
3044
3045 if (left_is_list != right_is_list)
3047 && !matches!(left_type, DataType::Null)
3048 && !matches!(right_type, DataType::Null)
3049 {
3050 return Some(match op {
3051 Operator::Eq => lit(false),
3052 Operator::NotEq => lit(true),
3053 _ => lit(ScalarValue::Boolean(None)),
3054 });
3055 }
3056
3057 None
3058}
3059
3060fn coerce_binary_expr(
3062 binary: &datafusion::logical_expr::expr::BinaryExpr,
3063 schema: &datafusion::common::DFSchema,
3064) -> Result<DfExpr> {
3065 use datafusion::arrow::datatypes::DataType;
3066 use datafusion::logical_expr::ExprSchemable;
3067 use datafusion::logical_expr::Operator;
3068
3069 let left = apply_type_coercion(&binary.left, schema)?;
3070 let right = apply_type_coercion(&binary.right, schema)?;
3071
3072 let is_comparison = matches!(
3073 binary.op,
3074 Operator::Eq
3075 | Operator::NotEq
3076 | Operator::Lt
3077 | Operator::LtEq
3078 | Operator::Gt
3079 | Operator::GtEq
3080 );
3081 let is_arithmetic = matches!(
3082 binary.op,
3083 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
3084 );
3085
3086 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
3088 return Ok(result);
3089 }
3090
3091 if is_comparison || is_arithmetic {
3092 let left_type = match left.get_type(schema) {
3093 Ok(t) => t,
3094 Err(e) => {
3095 if let Some(t) = resolve_column_type_fallback(&left, schema) {
3096 t
3097 } else {
3098 log::warn!("Failed to get left type in binary expr: {}", e);
3099 return Ok(binary_expr(left, binary.op, right));
3100 }
3101 }
3102 };
3103 let right_type = match right.get_type(schema) {
3104 Ok(t) => t,
3105 Err(e) => {
3106 if let Some(t) = resolve_column_type_fallback(&right, schema) {
3107 t
3108 } else {
3109 log::warn!("Failed to get right type in binary expr: {}", e);
3110 return Ok(binary_expr(left, binary.op, right));
3111 }
3112 }
3113 };
3114
3115 let left_is_null = left_type.is_null();
3117 let right_is_null = right_type.is_null();
3118 if left_is_null && right_is_null {
3119 return Ok(lit(ScalarValue::Boolean(None)));
3120 }
3121 if left_is_null || right_is_null {
3122 let target = if left_is_null {
3123 &right_type
3124 } else {
3125 &left_type
3126 };
3127 if !matches!(target, DataType::LargeBinary) {
3128 let coerced_left = if left_is_null {
3129 datafusion::logical_expr::cast(left, target.clone())
3130 } else {
3131 left
3132 };
3133 let coerced_right = if right_is_null {
3134 datafusion::logical_expr::cast(right, target.clone())
3135 } else {
3136 right
3137 };
3138 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3139 }
3140 }
3141
3142 if let Some(result) = coerce_large_binary_ops(
3144 &left,
3145 &right,
3146 &left_type,
3147 &right_type,
3148 left_is_null,
3149 binary.op,
3150 is_comparison,
3151 is_arithmetic,
3152 ) {
3153 return result;
3154 }
3155
3156 if let Some(result) = coerce_temporal_comparisons(
3158 left.clone(),
3159 right.clone(),
3160 &left_type,
3161 &right_type,
3162 binary.op,
3163 is_comparison,
3164 ) {
3165 return Ok(result);
3166 }
3167
3168 let either_struct =
3170 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3171 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3172 || matches!(left_type, DataType::Struct(_)))
3173 && (matches!(right_type, DataType::LargeBinary)
3174 || matches!(right_type, DataType::Struct(_)));
3175 if is_comparison && either_struct && either_lb_or_struct {
3176 if let Some(udf_name) = comparison_udf_name(binary.op) {
3177 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3178 }
3179 return Ok(lit(ScalarValue::Boolean(None)));
3180 }
3181
3182 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3184 let udf_name = comparison_udf_name(binary.op)
3185 .expect("is_comparison guarantees a valid comparison operator");
3186 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3187 }
3188
3189 if binary.op == Operator::Plus
3191 && (crate::cypher_type_coerce::is_string_type(&left_type)
3192 || crate::cypher_type_coerce::is_string_type(&right_type))
3193 && is_primitive_type(&left_type)
3194 && is_primitive_type(&right_type)
3195 {
3196 return crate::cypher_type_coerce::build_cypher_plus(
3197 left,
3198 &left_type,
3199 right,
3200 &right_type,
3201 );
3202 }
3203
3204 if let Some(result) = coerce_mismatched_types(
3206 left.clone(),
3207 right.clone(),
3208 &left_type,
3209 &right_type,
3210 binary.op,
3211 is_comparison,
3212 ) {
3213 return result;
3214 }
3215
3216 if let Some(result) = coerce_list_comparisons(
3218 left.clone(),
3219 right.clone(),
3220 &left_type,
3221 &right_type,
3222 binary.op,
3223 is_comparison,
3224 ) {
3225 return Ok(result);
3226 }
3227
3228 if let Some(name) = arithmetic_udf_name(binary.op)
3237 && left_type == DataType::Int64
3238 && right_type == DataType::Int64
3239 && !is_list_expr(&left)
3240 && !is_list_expr(&right)
3241 {
3242 return Ok(dummy_udf_expr(name, vec![left, right]));
3243 }
3244 }
3245
3246 Ok(binary_expr(left, binary.op, right))
3247}
3248
3249fn coerce_scalar_function(
3251 func: &datafusion::logical_expr::expr::ScalarFunction,
3252 schema: &datafusion::common::DFSchema,
3253) -> Result<DfExpr> {
3254 use datafusion::arrow::datatypes::DataType;
3255 use datafusion::logical_expr::ExprSchemable;
3256
3257 let coerced_args: Vec<DfExpr> = func
3258 .args
3259 .iter()
3260 .map(|a| apply_type_coercion(a, schema))
3261 .collect::<Result<Vec<_>>>()?;
3262
3263 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3264 let types: Vec<_> = coerced_args
3265 .iter()
3266 .filter_map(|a| a.get_type(schema).ok())
3267 .collect();
3268 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3269 if has_mixed_types {
3270 let all_string_like = types
3274 .iter()
3275 .all(|t| matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Null));
3276 let unified_args: Vec<DfExpr> = if all_string_like {
3277 coerced_args
3278 .into_iter()
3279 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3280 .collect()
3281 } else {
3282 coerced_args
3284 .into_iter()
3285 .zip(types.iter())
3286 .map(|(arg, t)| match t {
3287 DataType::LargeBinary | DataType::Null => arg,
3288 DataType::List(_) | DataType::LargeList(_) => {
3289 list_to_large_binary_expr(arg)
3290 }
3291 _ => scalar_to_large_binary_expr(arg),
3292 })
3293 .collect()
3294 };
3295 return Ok(DfExpr::ScalarFunction(
3296 datafusion::logical_expr::expr::ScalarFunction {
3297 func: func.func.clone(),
3298 args: unified_args,
3299 },
3300 ));
3301 }
3302 }
3303
3304 Ok(DfExpr::ScalarFunction(
3305 datafusion::logical_expr::expr::ScalarFunction {
3306 func: func.func.clone(),
3307 args: coerced_args,
3308 },
3309 ))
3310}
3311
3312fn coerce_case_expr(
3315 case: &datafusion::logical_expr::expr::Case,
3316 schema: &datafusion::common::DFSchema,
3317) -> Result<DfExpr> {
3318 use datafusion::arrow::datatypes::DataType;
3319 use datafusion::logical_expr::ExprSchemable;
3320
3321 let coerced_operand = case
3322 .expr
3323 .as_ref()
3324 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3325 .transpose()?;
3326 let coerced_when_then = case
3327 .when_then_expr
3328 .iter()
3329 .map(|(w, t)| {
3330 let cw = apply_type_coercion(w, schema)?;
3331 let cw = match cw.get_type(schema).ok() {
3332 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3333 _ => cw,
3334 };
3335 let ct = apply_type_coercion(t, schema)?;
3336 Ok((Box::new(cw), Box::new(ct)))
3337 })
3338 .collect::<Result<Vec<_>>>()?;
3339 let coerced_else = case
3340 .else_expr
3341 .as_ref()
3342 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3343 .transpose()?;
3344
3345 let mut result_case = if let Some(operand) = coerced_operand {
3346 crate::cypher_type_coerce::rewrite_simple_case_to_generic(
3347 *operand,
3348 coerced_when_then,
3349 coerced_else,
3350 schema,
3351 )?
3352 } else {
3353 datafusion::logical_expr::expr::Case {
3354 expr: None,
3355 when_then_expr: coerced_when_then,
3356 else_expr: coerced_else,
3357 }
3358 };
3359
3360 crate::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3361
3362 Ok(DfExpr::Case(result_case))
3363}
3364
3365fn coerce_aggregate_function(
3367 agg: &datafusion::logical_expr::expr::AggregateFunction,
3368 schema: &datafusion::common::DFSchema,
3369) -> Result<DfExpr> {
3370 let coerced_args: Vec<DfExpr> = agg
3371 .params
3372 .args
3373 .iter()
3374 .map(|a| apply_type_coercion(a, schema))
3375 .collect::<Result<Vec<_>>>()?;
3376 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3377 .params
3378 .order_by
3379 .iter()
3380 .map(|s| {
3381 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3382 Ok(datafusion::logical_expr::SortExpr {
3383 expr: coerced_expr,
3384 asc: s.asc,
3385 nulls_first: s.nulls_first,
3386 })
3387 })
3388 .collect::<Result<Vec<_>>>()?;
3389 let coerced_filter = agg
3390 .params
3391 .filter
3392 .as_ref()
3393 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3394 .transpose()?;
3395 Ok(DfExpr::AggregateFunction(
3396 datafusion::logical_expr::expr::AggregateFunction {
3397 func: agg.func.clone(),
3398 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3399 args: coerced_args,
3400 distinct: agg.params.distinct,
3401 filter: coerced_filter,
3402 order_by: coerced_order_by,
3403 null_treatment: agg.params.null_treatment,
3404 },
3405 },
3406 ))
3407}
3408
3409#[cfg(test)]
3410mod tests {
3411 use super::*;
3412 use arrow_array::{
3413 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3414 };
3415 use uni_common::TemporalValue;
3416 #[test]
3417 fn test_literal_translation() {
3418 let expr = Expr::Literal(CypherLiteral::Integer(42));
3419 let result = cypher_expr_to_df(&expr, None).unwrap();
3420 let s = format!("{:?}", result);
3421 assert!(s.contains("Literal"));
3423 assert!(s.contains("Int64(42)"));
3424 }
3425
3426 #[test]
3427 fn test_property_access_no_context_uses_index() {
3428 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3430 let result = cypher_expr_to_df(&expr, None).unwrap();
3431 let s = format!("{}", result);
3432 assert!(
3433 s.contains("index"),
3434 "expected index UDF for non-graph variable, got: {s}"
3435 );
3436 }
3437
3438 #[test]
3439 fn test_comparison_operator() {
3440 let expr = Expr::BinaryOp {
3441 left: Box::new(Expr::Property(
3442 Box::new(Expr::Variable("n".to_string())),
3443 "age".to_string(),
3444 )),
3445 op: BinaryOp::Gt,
3446 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3447 };
3448 let result = cypher_expr_to_df(&expr, None).unwrap();
3449 let s = format!("{:?}", result);
3451 assert!(s.contains("age"));
3452 assert!(s.contains("30"));
3453 }
3454
3455 #[test]
3456 fn test_boolean_operators() {
3457 let expr = Expr::BinaryOp {
3458 left: Box::new(Expr::BinaryOp {
3459 left: Box::new(Expr::Property(
3460 Box::new(Expr::Variable("n".to_string())),
3461 "age".to_string(),
3462 )),
3463 op: BinaryOp::Gt,
3464 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3465 }),
3466 op: BinaryOp::And,
3467 right: Box::new(Expr::BinaryOp {
3468 left: Box::new(Expr::Property(
3469 Box::new(Expr::Variable("n".to_string())),
3470 "active".to_string(),
3471 )),
3472 op: BinaryOp::Eq,
3473 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3474 }),
3475 };
3476 let result = cypher_expr_to_df(&expr, None).unwrap();
3477 let s = format!("{:?}", result);
3478 assert!(s.contains("And"));
3479 }
3480
3481 #[test]
3482 fn test_is_null() {
3483 let expr = Expr::IsNull(Box::new(Expr::Property(
3484 Box::new(Expr::Variable("n".to_string())),
3485 "email".to_string(),
3486 )));
3487 let result = cypher_expr_to_df(&expr, None).unwrap();
3488 let s = format!("{:?}", result);
3489 assert!(s.contains("IsNull"));
3490 }
3491
3492 #[test]
3493 fn test_collect_properties() {
3494 let expr = Expr::BinaryOp {
3495 left: Box::new(Expr::Property(
3496 Box::new(Expr::Variable("n".to_string())),
3497 "name".to_string(),
3498 )),
3499 op: BinaryOp::Eq,
3500 right: Box::new(Expr::Property(
3501 Box::new(Expr::Variable("m".to_string())),
3502 "name".to_string(),
3503 )),
3504 };
3505
3506 let props = collect_properties(&expr);
3507 assert_eq!(props.len(), 2);
3508 assert!(props.contains(&("m".to_string(), "name".to_string())));
3509 assert!(props.contains(&("n".to_string(), "name".to_string())));
3510 }
3511
3512 #[test]
3513 fn test_function_call() {
3514 let expr = Expr::FunctionCall {
3515 name: "count".to_string(),
3516 args: vec![Expr::Wildcard],
3517 distinct: false,
3518 window_spec: None,
3519 };
3520 let result = cypher_expr_to_df(&expr, None).unwrap();
3521 let s = format!("{:?}", result);
3522 assert!(s.to_lowercase().contains("count"));
3523 }
3524
3525 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3530 use datafusion::logical_expr::Operator;
3531
3532 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3534 let fields: Vec<_> = cols
3535 .iter()
3536 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3537 .collect();
3538 let schema = Schema::new(fields);
3539 datafusion::common::DFSchema::try_from(schema).unwrap()
3540 }
3541
3542 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3544 let s = format!("{}", expr);
3545 s.contains(name)
3546 }
3547
3548 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3550 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3551 }
3552
3553 #[test]
3554 fn test_coercion_lb_eq_int64() {
3555 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3556 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3557 Box::new(col("lb")),
3558 Operator::Eq,
3559 Box::new(col("i")),
3560 ));
3561 let result = apply_type_coercion(&expr, &schema).unwrap();
3562 assert!(
3564 contains_udf(&result, "_cypher_equal"),
3565 "expected _cypher_equal, got: {result}"
3566 );
3567 }
3568
3569 #[test]
3570 fn test_coercion_lb_noteq_int64() {
3571 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3572 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3573 Box::new(col("lb")),
3574 Operator::NotEq,
3575 Box::new(col("i")),
3576 ));
3577 let result = apply_type_coercion(&expr, &schema).unwrap();
3578 assert!(contains_udf(&result, "_cypher_not_equal"));
3580 }
3581
3582 #[test]
3583 fn test_coercion_lb_lt_int64() {
3584 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3585 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3586 Box::new(col("lb")),
3587 Operator::Lt,
3588 Box::new(col("i")),
3589 ));
3590 let result = apply_type_coercion(&expr, &schema).unwrap();
3591 assert!(contains_udf(&result, "_cypher_lt"));
3593 }
3594
3595 #[test]
3596 fn test_coercion_lb_eq_float64() {
3597 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3598 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3599 Box::new(col("lb")),
3600 Operator::Eq,
3601 Box::new(col("f")),
3602 ));
3603 let result = apply_type_coercion(&expr, &schema).unwrap();
3604 assert!(contains_udf(&result, "_cypher_equal"));
3606 }
3607
3608 #[test]
3609 fn test_coercion_lb_eq_utf8() {
3610 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3611 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3612 Box::new(col("lb")),
3613 Operator::Eq,
3614 Box::new(col("s")),
3615 ));
3616 let result = apply_type_coercion(&expr, &schema).unwrap();
3617 assert!(contains_udf(&result, "_cypher_equal"));
3619 }
3620
3621 #[test]
3622 fn test_coercion_lb_eq_bool() {
3623 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3624 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3625 Box::new(col("lb")),
3626 Operator::Eq,
3627 Box::new(col("b")),
3628 ));
3629 let result = apply_type_coercion(&expr, &schema).unwrap();
3630 assert!(contains_udf(&result, "_cypher_equal"));
3632 }
3633
3634 #[test]
3635 fn test_coercion_int64_eq_lb() {
3636 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3638 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3639 Box::new(col("i")),
3640 Operator::Eq,
3641 Box::new(col("lb")),
3642 ));
3643 let result = apply_type_coercion(&expr, &schema).unwrap();
3644 assert!(contains_udf(&result, "_cypher_equal"));
3646 }
3647
3648 #[test]
3649 fn test_coercion_float64_gt_lb() {
3650 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3651 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3652 Box::new(col("f")),
3653 Operator::Gt,
3654 Box::new(col("lb")),
3655 ));
3656 let result = apply_type_coercion(&expr, &schema).unwrap();
3657 assert!(contains_udf(&result, "_cypher_gt"));
3659 }
3660
3661 #[test]
3662 fn test_coercion_both_lb_eq() {
3663 let schema = make_schema(&[
3664 ("lb1", DataType::LargeBinary),
3665 ("lb2", DataType::LargeBinary),
3666 ]);
3667 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3668 Box::new(col("lb1")),
3669 Operator::Eq,
3670 Box::new(col("lb2")),
3671 ));
3672 let result = apply_type_coercion(&expr, &schema).unwrap();
3673 assert!(contains_udf(&result, "_cypher_equal"));
3674 }
3675
3676 #[test]
3677 fn test_coercion_both_lb_lt() {
3678 let schema = make_schema(&[
3679 ("lb1", DataType::LargeBinary),
3680 ("lb2", DataType::LargeBinary),
3681 ]);
3682 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3683 Box::new(col("lb1")),
3684 Operator::Lt,
3685 Box::new(col("lb2")),
3686 ));
3687 let result = apply_type_coercion(&expr, &schema).unwrap();
3688 assert!(contains_udf(&result, "_cypher_lt"));
3689 }
3690
3691 #[test]
3692 fn test_coercion_both_lb_noteq() {
3693 let schema = make_schema(&[
3694 ("lb1", DataType::LargeBinary),
3695 ("lb2", DataType::LargeBinary),
3696 ]);
3697 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3698 Box::new(col("lb1")),
3699 Operator::NotEq,
3700 Box::new(col("lb2")),
3701 ));
3702 let result = apply_type_coercion(&expr, &schema).unwrap();
3703 assert!(contains_udf(&result, "_cypher_not_equal"));
3704 }
3705
3706 #[test]
3707 fn test_coercion_lb_plus_int64() {
3708 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3709 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3710 Box::new(col("lb")),
3711 Operator::Plus,
3712 Box::new(col("i")),
3713 ));
3714 let result = apply_type_coercion(&expr, &schema).unwrap();
3715 assert!(contains_udf(&result, "_cypher_add"));
3716 }
3717
3718 #[test]
3719 fn test_coercion_lb_minus_int64() {
3720 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3721 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3722 Box::new(col("lb")),
3723 Operator::Minus,
3724 Box::new(col("i")),
3725 ));
3726 let result = apply_type_coercion(&expr, &schema).unwrap();
3727 assert!(contains_udf(&result, "_cypher_sub"));
3728 }
3729
3730 #[test]
3731 fn test_coercion_lb_multiply_float64() {
3732 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3733 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3734 Box::new(col("lb")),
3735 Operator::Multiply,
3736 Box::new(col("f")),
3737 ));
3738 let result = apply_type_coercion(&expr, &schema).unwrap();
3739 assert!(contains_udf(&result, "_cypher_mul"));
3740 }
3741
3742 #[test]
3743 fn test_coercion_int64_plus_lb() {
3744 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3745 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3746 Box::new(col("i")),
3747 Operator::Plus,
3748 Box::new(col("lb")),
3749 ));
3750 let result = apply_type_coercion(&expr, &schema).unwrap();
3751 assert!(contains_udf(&result, "_cypher_add"));
3752 }
3753
3754 #[test]
3755 fn test_coercion_lb_plus_utf8() {
3756 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3758 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3759 Box::new(col("lb")),
3760 Operator::Plus,
3761 Box::new(col("s")),
3762 ));
3763 let result = apply_type_coercion(&expr, &schema).unwrap();
3764 assert!(contains_udf(&result, "_cypher_add"));
3766 }
3767
3768 #[test]
3769 fn test_coercion_and_null_bool() {
3770 let schema = make_schema(&[("b", DataType::Boolean)]);
3771 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3773 Box::new(lit(ScalarValue::Null)),
3774 Operator::And,
3775 Box::new(col("b")),
3776 ));
3777 let result = apply_type_coercion(&expr, &schema).unwrap();
3778 let s = format!("{}", result);
3779 assert!(
3781 s.contains("CAST") || s.contains("Boolean"),
3782 "expected cast to Boolean, got: {s}"
3783 );
3784 assert!(is_binary_op(&result, Operator::And));
3785 }
3786
3787 #[test]
3788 fn test_coercion_bool_and_null() {
3789 let schema = make_schema(&[("b", DataType::Boolean)]);
3790 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3791 Box::new(col("b")),
3792 Operator::And,
3793 Box::new(lit(ScalarValue::Null)),
3794 ));
3795 let result = apply_type_coercion(&expr, &schema).unwrap();
3796 assert!(is_binary_op(&result, Operator::And));
3797 }
3798
3799 #[test]
3800 fn test_coercion_or_null_bool() {
3801 let schema = make_schema(&[("b", DataType::Boolean)]);
3802 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3803 Box::new(lit(ScalarValue::Null)),
3804 Operator::Or,
3805 Box::new(col("b")),
3806 ));
3807 let result = apply_type_coercion(&expr, &schema).unwrap();
3808 assert!(is_binary_op(&result, Operator::Or));
3809 }
3810
3811 #[test]
3812 fn test_coercion_null_and_null() {
3813 let schema = make_schema(&[]);
3814 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3815 Box::new(lit(ScalarValue::Null)),
3816 Operator::And,
3817 Box::new(lit(ScalarValue::Null)),
3818 ));
3819 let result = apply_type_coercion(&expr, &schema).unwrap();
3820 assert!(is_binary_op(&result, Operator::And));
3821 }
3822
3823 #[test]
3824 fn test_coercion_bool_and_bool_noop() {
3825 let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);
3826 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3827 Box::new(col("a")),
3828 Operator::And,
3829 Box::new(col("b")),
3830 ));
3831 let result = apply_type_coercion(&expr, &schema).unwrap();
3832 assert!(is_binary_op(&result, Operator::And));
3834 let s = format!("{}", result);
3835 assert!(!s.contains("CAST"), "should not contain CAST: {s}");
3836 }
3837
3838 #[test]
3839 fn test_coercion_case_when_lb() {
3840 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3842 let when_cond = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3843 Box::new(col("lb")),
3844 Operator::Eq,
3845 Box::new(lit(42_i64)),
3846 ));
3847 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3848 expr: None,
3849 when_then_expr: vec![(Box::new(when_cond), Box::new(lit("a")))],
3850 else_expr: Some(Box::new(lit("b"))),
3851 });
3852 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3853 let s = format!("{}", result);
3854 assert!(
3856 s.contains("_cypher_equal"),
3857 "CASE WHEN should have _cypher_equal, got: {s}"
3858 );
3859 }
3860
3861 #[test]
3862 fn test_coercion_case_then_lb() {
3863 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3865 let then_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3866 Box::new(col("lb")),
3867 Operator::Plus,
3868 Box::new(lit(1_i64)),
3869 ));
3870 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3871 expr: None,
3872 when_then_expr: vec![(Box::new(lit(true)), Box::new(then_expr))],
3873 else_expr: Some(Box::new(lit(0_i64))),
3874 });
3875 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3876 let s = format!("{}", result);
3877 assert!(
3878 s.contains("_cypher_add"),
3879 "CASE THEN should have _cypher_add, got: {s}"
3880 );
3881 }
3882
3883 #[test]
3884 fn test_coercion_case_else_lb() {
3885 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3887 let else_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3888 Box::new(col("lb")),
3889 Operator::Plus,
3890 Box::new(lit(2_i64)),
3891 ));
3892 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3893 expr: None,
3894 when_then_expr: vec![(Box::new(lit(true)), Box::new(lit(1_i64)))],
3895 else_expr: Some(Box::new(else_expr)),
3896 });
3897 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3898 let s = format!("{}", result);
3899 assert!(
3900 s.contains("_cypher_add"),
3901 "CASE ELSE should have _cypher_add, got: {s}"
3902 );
3903 }
3904
3905 #[test]
3906 fn test_coercion_int64_eq_int64_noop() {
3907 let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);
3908 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3909 Box::new(col("a")),
3910 Operator::Eq,
3911 Box::new(col("b")),
3912 ));
3913 let result = apply_type_coercion(&expr, &schema).unwrap();
3914 assert!(is_binary_op(&result, Operator::Eq));
3915 let s = format!("{}", result);
3916 assert!(
3917 !s.contains("_cypher_value"),
3918 "should not contain cypher_value decode: {s}"
3919 );
3920 }
3921
3922 #[test]
3923 fn test_coercion_both_lb_plus() {
3924 let schema = make_schema(&[
3926 ("lb1", DataType::LargeBinary),
3927 ("lb2", DataType::LargeBinary),
3928 ]);
3929 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3930 Box::new(col("lb1")),
3931 Operator::Plus,
3932 Box::new(col("lb2")),
3933 ));
3934 let result = apply_type_coercion(&expr, &schema).unwrap();
3935 assert!(
3936 contains_udf(&result, "_cypher_add"),
3937 "expected _cypher_add, got: {result}"
3938 );
3939 }
3940
3941 #[test]
3942 fn test_coercion_native_list_plus_scalar() {
3943 let schema = make_schema(&[
3945 (
3946 "lst",
3947 DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3948 ),
3949 ("i", DataType::Int32),
3950 ]);
3951 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3952 Box::new(col("lst")),
3953 Operator::Plus,
3954 Box::new(col("i")),
3955 ));
3956 let result = apply_type_coercion(&expr, &schema).unwrap();
3957 assert!(
3958 contains_udf(&result, "_cypher_list_append"),
3959 "expected _cypher_list_append, got: {result}"
3960 );
3961 }
3962
3963 #[test]
3964 fn test_coercion_lb_plus_int64_unchanged() {
3965 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3967 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3968 Box::new(col("lb")),
3969 Operator::Plus,
3970 Box::new(col("i")),
3971 ));
3972 let result = apply_type_coercion(&expr, &schema).unwrap();
3973 assert!(
3974 contains_udf(&result, "_cypher_add"),
3975 "expected _cypher_add, got: {result}"
3976 );
3977 }
3978
3979 #[test]
3984 fn test_mixed_list_with_variables_compiles() {
3985 let expr = Expr::List(vec![
3987 Expr::Variable("n".to_string()),
3988 Expr::Literal(CypherLiteral::Integer(1)),
3989 Expr::Literal(CypherLiteral::String("hello".to_string())),
3990 ]);
3991 let result = cypher_expr_to_df(&expr, None).unwrap();
3992 let s = format!("{}", result);
3993 assert!(
3994 s.contains("_make_cypher_list"),
3995 "expected _make_cypher_list UDF call, got: {s}"
3996 );
3997 }
3998
3999 #[test]
4000 fn test_literal_only_mixed_list_uses_cv_fastpath() {
4001 let expr = Expr::List(vec![
4003 Expr::Literal(CypherLiteral::Integer(1)),
4004 Expr::Literal(CypherLiteral::String("hi".to_string())),
4005 Expr::Literal(CypherLiteral::Bool(true)),
4006 ]);
4007 let result = cypher_expr_to_df(&expr, None).unwrap();
4008 assert!(
4009 matches!(result, DfExpr::Literal(..)),
4010 "expected Literal (CypherValue fast path), got: {result}"
4011 );
4012 }
4013
4014 #[test]
4019 fn test_in_mixed_literal_list_uses_cypher_in() {
4020 let expr = Expr::In {
4022 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4023 list: Box::new(Expr::List(vec![
4024 Expr::Literal(CypherLiteral::String("1".to_string())),
4025 Expr::Literal(CypherLiteral::Integer(2)),
4026 ])),
4027 };
4028 let result = cypher_expr_to_df(&expr, None).unwrap();
4029 let s = format!("{}", result);
4030 assert!(
4031 s.contains("_cypher_in"),
4032 "expected _cypher_in UDF for mixed-type IN list, got: {s}"
4033 );
4034 }
4035
4036 #[test]
4037 fn test_in_homogeneous_literal_list_uses_cypher_in() {
4038 let expr = Expr::In {
4040 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4041 list: Box::new(Expr::List(vec![
4042 Expr::Literal(CypherLiteral::Integer(2)),
4043 Expr::Literal(CypherLiteral::Integer(3)),
4044 ])),
4045 };
4046 let result = cypher_expr_to_df(&expr, None).unwrap();
4047 let s = format!("{}", result);
4048 assert!(
4049 s.contains("_cypher_in"),
4050 "expected _cypher_in UDF for homogeneous IN list, got: {s}"
4051 );
4052 }
4053
4054 #[test]
4055 fn test_in_list_with_variables_uses_make_cypher_list() {
4056 let expr = Expr::In {
4058 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4059 list: Box::new(Expr::List(vec![
4060 Expr::Variable("x".to_string()),
4061 Expr::Literal(CypherLiteral::Integer(2)),
4062 ])),
4063 };
4064 let result = cypher_expr_to_df(&expr, None).unwrap();
4065 let s = format!("{}", result);
4066 assert!(
4067 s.contains("_cypher_in"),
4068 "expected _cypher_in UDF, got: {s}"
4069 );
4070 assert!(
4071 s.contains("_make_cypher_list"),
4072 "expected _make_cypher_list for variable-containing list, got: {s}"
4073 );
4074 }
4075
4076 #[test]
4081 fn test_property_on_graph_entity_uses_column() {
4082 let mut ctx = TranslationContext::new();
4084 ctx.variable_kinds
4085 .insert("n".to_string(), VariableKind::Node);
4086
4087 let expr = Expr::Property(
4088 Box::new(Expr::Variable("n".to_string())),
4089 "name".to_string(),
4090 );
4091 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4092 let s = format!("{:?}", result);
4093 assert!(
4094 s.contains("Column") && s.contains("n.name"),
4095 "expected flat column 'n.name' for graph entity, got: {s}"
4096 );
4097 }
4098
4099 #[test]
4100 fn test_property_on_non_graph_var_uses_index() {
4101 let ctx = TranslationContext::new();
4103
4104 let expr = Expr::Property(
4105 Box::new(Expr::Variable("map".to_string())),
4106 "name".to_string(),
4107 );
4108 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4109 let s = format!("{}", result);
4110 assert!(
4111 s.contains("index"),
4112 "expected index UDF for non-graph variable, got: {s}"
4113 );
4114 }
4115
4116 #[test]
4117 fn test_value_to_scalar_non_empty_map_becomes_struct() {
4118 let mut map = std::collections::HashMap::new();
4119 map.insert("k".to_string(), Value::Int(1));
4120 let scalar = value_to_scalar(&Value::Map(map)).unwrap();
4121 assert!(
4122 matches!(scalar, ScalarValue::Struct(_)),
4123 "expected Struct scalar for map input"
4124 );
4125 }
4126
4127 #[test]
4128 fn test_value_to_scalar_empty_map_becomes_struct() {
4129 let scalar = value_to_scalar(&Value::Map(Default::default())).unwrap();
4130 assert!(
4131 matches!(scalar, ScalarValue::Struct(_)),
4132 "empty map should produce an empty Struct scalar"
4133 );
4134 }
4135
4136 #[test]
4137 fn test_value_to_scalar_null_is_untyped_null() {
4138 let scalar = value_to_scalar(&Value::Null).unwrap();
4139 assert!(
4140 matches!(scalar, ScalarValue::Null),
4141 "expected untyped Null scalar for Value::Null"
4142 );
4143 }
4144
4145 #[test]
4146 fn test_value_to_scalar_datetime_produces_struct() {
4147 let datetime = Value::Temporal(TemporalValue::DateTime {
4149 nanos_since_epoch: 441763200000000000, offset_seconds: 3600, timezone_name: Some("Europe/Paris".to_string()),
4152 });
4153
4154 let scalar = value_to_scalar(&datetime).unwrap();
4155
4156 if let ScalarValue::Struct(struct_arr) = scalar {
4158 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4159 assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");
4160
4161 let fields = struct_arr.fields();
4163 assert_eq!(fields[0].name(), "nanos_since_epoch");
4164 assert_eq!(fields[1].name(), "offset_seconds");
4165 assert_eq!(fields[2].name(), "timezone_name");
4166
4167 let nanos_col = struct_arr.column(0);
4169 let offset_col = struct_arr.column(1);
4170 let tz_col = struct_arr.column(2);
4171
4172 if let Some(nanos_arr) = nanos_col
4173 .as_any()
4174 .downcast_ref::<TimestampNanosecondArray>()
4175 {
4176 assert_eq!(nanos_arr.value(0), 441763200000000000);
4177 } else {
4178 panic!("Expected TimestampNanosecondArray for nanos field");
4179 }
4180
4181 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4182 assert_eq!(offset_arr.value(0), 3600);
4183 } else {
4184 panic!("Expected Int32Array for offset field");
4185 }
4186
4187 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4188 assert_eq!(tz_arr.value(0), "Europe/Paris");
4189 } else {
4190 panic!("Expected StringArray for timezone_name field");
4191 }
4192 } else {
4193 panic!(
4194 "Expected ScalarValue::Struct for DateTime, got {:?}",
4195 scalar
4196 );
4197 }
4198 }
4199
4200 #[test]
4201 fn test_value_to_scalar_datetime_with_null_timezone() {
4202 let datetime = Value::Temporal(TemporalValue::DateTime {
4204 nanos_since_epoch: 1704067200000000000, offset_seconds: -18000, timezone_name: None,
4207 });
4208
4209 let scalar = value_to_scalar(&datetime).unwrap();
4210
4211 if let ScalarValue::Struct(struct_arr) = scalar {
4212 assert_eq!(struct_arr.num_columns(), 3);
4213
4214 let tz_col = struct_arr.column(2);
4216 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4217 assert!(tz_arr.is_null(0), "expected null timezone_name");
4218 } else {
4219 panic!("Expected StringArray for timezone_name field");
4220 }
4221 } else {
4222 panic!("Expected ScalarValue::Struct for DateTime");
4223 }
4224 }
4225
4226 #[test]
4227 fn test_value_to_scalar_time_produces_struct() {
4228 let time = Value::Temporal(TemporalValue::Time {
4230 nanos_since_midnight: 37845000000000, offset_seconds: 3600, });
4233
4234 let scalar = value_to_scalar(&time).unwrap();
4235
4236 if let ScalarValue::Struct(struct_arr) = scalar {
4238 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4239 assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");
4240
4241 let fields = struct_arr.fields();
4243 assert_eq!(fields[0].name(), "nanos_since_midnight");
4244 assert_eq!(fields[1].name(), "offset_seconds");
4245
4246 let nanos_col = struct_arr.column(0);
4248 let offset_col = struct_arr.column(1);
4249
4250 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4251 assert_eq!(nanos_arr.value(0), 37845000000000);
4252 } else {
4253 panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
4254 }
4255
4256 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4257 assert_eq!(offset_arr.value(0), 3600);
4258 } else {
4259 panic!("Expected Int32Array for offset field");
4260 }
4261 } else {
4262 panic!("Expected ScalarValue::Struct for Time, got {:?}", scalar);
4263 }
4264 }
4265
4266 #[test]
4267 fn test_value_to_scalar_time_boundary_values() {
4268 let midnight = Value::Temporal(TemporalValue::Time {
4270 nanos_since_midnight: 0,
4271 offset_seconds: 0,
4272 });
4273
4274 let scalar = value_to_scalar(&midnight).unwrap();
4275
4276 if let ScalarValue::Struct(struct_arr) = scalar {
4277 let nanos_col = struct_arr.column(0);
4278 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4279 assert_eq!(nanos_arr.value(0), 0);
4280 } else {
4281 panic!("Expected Time64NanosecondArray");
4282 }
4283 } else {
4284 panic!("Expected ScalarValue::Struct for Time");
4285 }
4286 }
4287}