1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Column suffix used for a node variable's id column (`<var>._vid`).
const COL_VID: &str = "_vid";
/// Column suffix used for an edge variable's id column (`<var>._eid`).
const COL_EID: &str = "_eid";
/// Column suffix used for a node variable's label-list column (`<var>._labels`).
const COL_LABELS: &str = "_labels";
/// Column suffix used for an edge variable's type column (`<var>._type`).
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Inserts a `:00` seconds component into a datetime string that stops at
/// minutes (`....T HH:MM` followed by anything other than `:`).
///
/// Returns `None` when the string needs no change or does not look like a
/// datetime: shorter than 16 bytes, no `T` at byte 10, bytes 11..16 not shaped
/// like `DD:DD`, or a `:` already present at byte 16. Anything after the
/// minutes (zone suffix, offset, ...) is preserved after the inserted seconds.
pub fn normalize_datetime_str(s: &str) -> Option<String> {
    let bytes = s.as_bytes();

    // Minimum layout is `YYYY-MM-DDTHH:MM` (16 bytes) with the separator at 10.
    if bytes.len() < 16 || bytes[10] != b'T' {
        return None;
    }

    let hh_mm_ok = matches!(
        &bytes[11..16],
        [h1, h2, b':', m1, m2]
            if h1.is_ascii_digit()
                && h2.is_ascii_digit()
                && m1.is_ascii_digit()
                && m2.is_ascii_digit()
    );
    if !hh_mm_ok {
        return None;
    }

    // A colon right after the minutes means seconds are already there.
    if bytes.get(16) == Some(&b':') {
        return None;
    }

    // Byte 15 is an ASCII digit, so index 16 is always a char boundary.
    let (head, tail) = s.split_at(16);
    Some(format!("{}:00{}", head, tail))
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
/// Names of the scalar functions that `is_cypher_list_expr` treats as
/// producing a Cypher list value.
const CYPHER_LIST_FUNCS: &[&str] = &[
    "_make_cypher_list",
    "_cypher_list_concat",
    "_cypher_list_append",
];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// What a bound Cypher variable refers to; drives how the translator maps the
/// variable onto columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node variable.
    Node,
    /// A single-edge variable.
    Edge,
    /// The edge variable of a variable-length pattern.
    EdgeList,
    /// A path variable.
    Path,
}

impl VariableKind {
    /// Selects the edge flavour for a relationship variable:
    /// [`VariableKind::EdgeList`] for variable-length patterns, otherwise
    /// [`VariableKind::Edge`].
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
247
248pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
283 match expr {
284 Expr::PatternComprehension { .. } => Err(anyhow!(
285 "Pattern comprehensions require fallback executor (graph traversal)"
286 )),
287 Expr::Wildcard => Ok(DfExpr::Literal(
291 datafusion::common::ScalarValue::Int32(Some(1)),
292 None,
293 )),
294
295 Expr::Variable(name) => {
296 if let Some(ctx) = context
301 && ctx.variable_kinds.contains_key(name)
302 {
303 return Ok(DfExpr::Column(Column::from_name(name)));
304 }
305
306 if let Some(ctx) = context
312 && let Some(value) = ctx.outer_values.get(name)
313 {
314 return value_to_scalar(value).map(lit);
315 }
316
317 if let Some(ctx) = context
322 && let Some(value) = ctx.parameters.get(name)
323 {
324 match value {
327 Value::List(values) if name.ends_with("._vid") => {
328 let literals = values
330 .iter()
331 .map(|v| value_to_scalar(v).map(lit))
332 .collect::<Result<Vec<_>>>()?;
333 return Ok(DfExpr::InList(InList {
334 expr: Box::new(DfExpr::Column(Column::from_name(name))),
335 list: literals,
336 negated: false,
337 }));
338 }
339 other_value => return value_to_scalar(other_value).map(lit),
340 }
341 }
342
343 Ok(DfExpr::Column(Column::from_name(name)))
346 }
347
348 Expr::Property(base, prop) => translate_property_access(base, prop, context),
349
350 Expr::ArrayIndex { array, index } => {
351 if let Ok(var_name) = extract_variable_name(array)
354 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
355 {
356 let col_name = format!("{}.{}", var_name, prop_name);
357 return Ok(DfExpr::Column(Column::from_name(col_name)));
358 }
359
360 let array_expr = cypher_expr_to_df(array, context)?;
361 let index_expr = cypher_expr_to_df(index, context)?;
362
363 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
365 }
366
367 Expr::ArraySlice { array, start, end } => {
368 let array_expr = cypher_expr_to_df(array, context)?;
372
373 let start_expr = match start {
374 Some(s) => cypher_expr_to_df(s, context)?,
375 None => lit(0i64),
376 };
377
378 let end_expr = match end {
379 Some(e) => cypher_expr_to_df(e, context)?,
380 None => lit(i64::MAX),
381 };
382
383 Ok(dummy_udf_expr(
386 "_cypher_list_slice",
387 vec![array_expr, start_expr, end_expr],
388 ))
389 }
390
391 Expr::Parameter(name) => {
392 if let Some(ctx) = context
394 && let Some(value) = ctx.parameters.get(name)
395 {
396 return value_to_scalar(value).map(lit);
397 }
398 Err(anyhow!("Unresolved parameter: ${}", name))
399 }
400
401 Expr::Literal(value) => {
402 let scalar = cypher_literal_to_scalar(value)?;
403 Ok(lit(scalar))
404 }
405
406 Expr::List(items) => translate_list_literal(items, context),
407
408 Expr::Map(entries) => {
409 if entries.is_empty() {
410 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
412 Default::default(),
413 ));
414 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
415 }
416 let mut args = Vec::with_capacity(entries.len() * 2);
419 for (key, val_expr) in entries {
420 args.push(lit(key.clone()));
421 args.push(cypher_expr_to_df(val_expr, context)?);
422 }
423 Ok(datafusion::functions::expr_fn::named_struct(args))
424 }
425
426 Expr::IsNull(inner) => translate_null_check(inner, context, true),
427
428 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
429
430 Expr::IsUnique(_) => {
431 Err(anyhow!(
433 "IS UNIQUE can only be used in constraint definitions"
434 ))
435 }
436
437 Expr::FunctionCall {
438 name,
439 args,
440 distinct,
441 window_spec,
442 } => {
443 if window_spec.is_some() {
446 let col_name = expr.to_string_repr();
448 Ok(col(&col_name))
449 } else {
450 translate_function_call(name, args, *distinct, context)
451 }
452 }
453
454 Expr::In { expr, list } => translate_in_expression(expr, list, context),
455
456 Expr::BinaryOp { left, op, right } => {
457 let left_expr = cypher_expr_to_df(left, context)?;
458 let right_expr = cypher_expr_to_df(right, context)?;
459 translate_binary_op(left_expr, op, right_expr)
460 }
461
462 Expr::UnaryOp { op, expr: inner } => {
463 let inner_expr = cypher_expr_to_df(inner, context)?;
464 match op {
465 UnaryOp::Not => Ok(inner_expr.not()),
466 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
467 }
468 }
469
470 Expr::Case {
471 expr,
472 when_then,
473 else_expr,
474 } => translate_case_expression(expr, when_then, else_expr, context),
475
476 Expr::Reduce { .. } => Err(anyhow!(
477 "Reduce expressions not yet supported in DataFusion translation"
478 )),
479
480 Expr::Exists { .. } => Err(anyhow!(
481 "EXISTS subqueries are handled by the physical expression compiler, \
482 not the DataFusion logical expression translator"
483 )),
484
485 Expr::CountSubquery(_) => Err(anyhow!(
486 "Count subqueries not yet supported in DataFusion translation"
487 )),
488
489 Expr::CollectSubquery(_) => Err(anyhow!(
490 "COLLECT subqueries not yet supported in DataFusion translation"
491 )),
492
493 Expr::Quantifier { .. } => {
494 Err(anyhow!(
499 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
500 via CypherPhysicalExprCompiler"
501 ))
502 }
503
504 Expr::ListComprehension { .. } => {
505 Err(anyhow!(
517 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
518 ))
519 }
520
521 Expr::ValidAt { .. } => {
522 Err(anyhow!(
525 "VALID_AT expression should have been transformed to function call in planner"
526 ))
527 }
528
529 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
530
531 Expr::LabelCheck { expr, labels } => {
532 if let Expr::Variable(var) = expr.as_ref() {
533 let is_edge = context
535 .and_then(|ctx| ctx.variable_kinds.get(var))
536 .is_some_and(|k| matches!(k, VariableKind::Edge));
537
538 if is_edge {
539 if labels.len() > 1 {
543 Ok(lit(false))
544 } else {
545 let type_col =
546 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
547 Ok(DfExpr::Case(datafusion::logical_expr::Case {
549 expr: None,
550 when_then_expr: vec![(
551 Box::new(type_col.clone().is_null()),
552 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
553 )],
554 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
555 }))
556 }
557 } else {
558 let labels_col =
560 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
561 let checks = labels
562 .iter()
563 .map(|label| {
564 datafusion::functions_nested::expr_fn::array_has(
565 labels_col.clone(),
566 lit(label.clone()),
567 )
568 })
569 .reduce(|acc, check| acc.and(check));
570 Ok(DfExpr::Case(datafusion::logical_expr::Case {
572 expr: None,
573 when_then_expr: vec![(
574 Box::new(labels_col.is_null()),
575 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
576 )],
577 else_expr: Some(Box::new(checks.unwrap())),
578 }))
579 }
580 } else {
581 Err(anyhow!(
582 "LabelCheck on non-variable expression not yet supported in DataFusion"
583 ))
584 }
585 }
586 }
587}
588
/// Side information consulted while translating Cypher expressions.
#[derive(Debug, Clone)]
pub struct TranslationContext {
    /// Query parameter values keyed by name. Also consulted for variable and
    /// `<var>.<prop>` names during translation.
    pub parameters: std::collections::HashMap<String, Value>,

    /// Values keyed by variable name that are inlined as literals when a
    /// variable has no entry in `variable_kinds`.
    /// NOTE(review): presumably bindings from an enclosing query scope — confirm.
    pub outer_values: std::collections::HashMap<String, Value>,

    /// Label recorded per variable name (populated by `with_variable_label`).
    pub variable_labels: std::collections::HashMap<String, String>,

    /// Kind (node, edge, edge list, path) of each bound variable.
    pub variable_kinds: std::collections::HashMap<String, VariableKind>,

    /// NOTE(review): not read in this part of the file; presumably names of
    /// variables expected to be nodes — confirm against users.
    pub node_variable_hints: Vec<String>,

    /// NOTE(review): not read in this part of the file; presumably names of
    /// edge variables introduced by mutations — confirm against users.
    pub mutation_edge_hints: Vec<String>,

    /// Timestamp associated with the statement; `Default` sets it to
    /// `chrono::Utc::now()` at construction.
    pub statement_time: chrono::DateTime<chrono::Utc>,
}
622
623impl Default for TranslationContext {
624 fn default() -> Self {
625 Self {
626 parameters: std::collections::HashMap::new(),
627 outer_values: std::collections::HashMap::new(),
628 variable_labels: std::collections::HashMap::new(),
629 variable_kinds: std::collections::HashMap::new(),
630 node_variable_hints: Vec::new(),
631 mutation_edge_hints: Vec::new(),
632 statement_time: chrono::Utc::now(),
633 }
634 }
635}
636
637impl TranslationContext {
638 pub fn new() -> Self {
640 Self::default()
641 }
642
643 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
645 self.parameters.insert(name.into(), value);
646 self
647 }
648
649 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
651 self.variable_labels.insert(var.into(), label.into());
652 self
653 }
654}
655
656fn extract_variable_name(expr: &Expr) -> Result<String> {
658 match expr {
659 Expr::Variable(name) => Ok(name.clone()),
660 Expr::Property(base, _) => extract_variable_name(base),
661 _ => Err(anyhow!(
662 "Cannot extract variable name from expression: {:?}",
663 expr
664 )),
665 }
666}
667
668fn translate_null_check(
670 inner: &Expr,
671 context: Option<&TranslationContext>,
672 is_null: bool,
673) -> Result<DfExpr> {
674 if let Expr::Variable(var) = inner
675 && let Some(ctx) = context
676 && let Some(kind) = ctx.variable_kinds.get(var)
677 {
678 let col_name = match kind {
679 VariableKind::Node => format!("{}.{}", var, COL_VID),
680 VariableKind::Edge => format!("{}.{}", var, COL_EID),
681 VariableKind::Path | VariableKind::EdgeList => var.clone(),
682 };
683 let col_expr = DfExpr::Column(Column::from_name(col_name));
684 return Ok(if is_null {
685 col_expr.is_null()
686 } else {
687 col_expr.is_not_null()
688 });
689 }
690
691 let inner_expr = cypher_expr_to_df(inner, context)?;
692 Ok(if is_null {
693 inner_expr.is_null()
694 } else {
695 inner_expr.is_not_null()
696 })
697}
698
699fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
704 if crate::datetime::is_duration_accessor(prop) {
705 Some(dummy_udf_expr(
706 "_duration_property",
707 vec![base_expr, lit(prop.to_string())],
708 ))
709 } else if crate::datetime::is_temporal_accessor(prop) {
710 Some(dummy_udf_expr(
711 "_temporal_property",
712 vec![base_expr, lit(prop.to_string())],
713 ))
714 } else {
715 None
716 }
717}
718
719fn translate_property_access(
721 base: &Expr,
722 prop: &str,
723 context: Option<&TranslationContext>,
724) -> Result<DfExpr> {
725 if let Ok(var_name) = extract_variable_name(base) {
726 let is_graph_entity = context
727 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
728 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
729
730 if !is_graph_entity
731 && let Some(expr) =
732 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
733 {
734 return Ok(expr);
735 }
736
737 let col_name = format!("{}.{}", var_name, prop);
738
739 if let Some(ctx) = context
742 && let Some(value) = ctx.parameters.get(&col_name)
743 {
744 match value {
747 Value::List(values) if col_name.ends_with("._vid") => {
748 let literals = values
749 .iter()
750 .map(|v| value_to_scalar(v).map(lit))
751 .collect::<Result<Vec<_>>>()?;
752 return Ok(DfExpr::InList(InList {
753 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
754 list: literals,
755 negated: false,
756 }));
757 }
758 other_value => return value_to_scalar(other_value).map(lit),
759 }
760 }
761
762 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
765 let base_expr = cypher_expr_to_df(base, context)?;
766 return Ok(dummy_udf_expr(
767 "index",
768 vec![base_expr, lit(prop.to_string())],
769 ));
770 }
771
772 if is_graph_entity {
773 Ok(DfExpr::Column(Column::from_name(col_name)))
774 } else {
775 let base_expr = DfExpr::Column(Column::from_name(var_name));
776 Ok(dummy_udf_expr(
777 "index",
778 vec![base_expr, lit(prop.to_string())],
779 ))
780 }
781 } else {
782 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
784 return Ok(expr);
785 }
786
787 if let Expr::Parameter(param_name) = base {
789 if let Some(ctx) = context
790 && let Some(value) = ctx.parameters.get(param_name)
791 {
792 if let Value::Map(map) = value {
793 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
794 return value_to_scalar(&extracted).map(lit);
795 }
796 return Ok(lit(ScalarValue::Null));
797 }
798 return Err(anyhow!("Unresolved parameter: ${}", param_name));
799 }
800
801 let base_expr = cypher_expr_to_df(base, context)?;
802 Ok(dummy_udf_expr(
803 "index",
804 vec![base_expr, lit(prop.to_string())],
805 ))
806 }
807}
808
/// Translates a Cypher list literal `[a, b, ...]`.
///
/// Two output shapes are produced:
/// * a Cypher-encoded list (a `LargeBinary` literal when every item is a
///   constant that `try_items_to_json` can convert, otherwise a
///   `_make_cypher_list` UDF call) when the list contains nested lists, maps,
///   more than one scalar category, a variable registered in
///   `variable_kinds`, or a temporal constructor call;
/// * a native `make_array(...)` call (or an empty `List` literal) otherwise.
///
/// # Errors
///
/// Propagates any error from translating an item.
fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
    // First pass: classify the syntactic category of every item.
    let mut has_string = false;
    let mut has_bool = false;
    let mut has_list = false;
    let mut has_map = false;
    let mut has_numeric = false;
    let mut has_graph_entity = false;
    let mut has_temporal = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
                has_numeric = true
            }
            Expr::Literal(CypherLiteral::String(_)) => has_string = true,
            Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
            Expr::List(_) => has_list = true,
            Expr::Map(_) => has_map = true,
            // Any variable with a registered kind sets this flag.
            Expr::Variable(name)
                if context
                    .and_then(|ctx| ctx.variable_kinds.get(name))
                    .is_some() =>
            {
                has_graph_entity = true;
            }
            Expr::FunctionCall { name, .. } => {
                // Temporal constructors / truncations, matched case-insensitively.
                let upper = name.to_uppercase();
                if matches!(
                    upper.as_str(),
                    "DATE"
                        | "TIME"
                        | "LOCALTIME"
                        | "LOCALDATETIME"
                        | "DATETIME"
                        | "DURATION"
                        | "DATE.TRUNCATE"
                        | "TIME.TRUNCATE"
                        | "DATETIME.TRUNCATE"
                        | "LOCALDATETIME.TRUNCATE"
                        | "LOCALTIME.TRUNCATE"
                ) {
                    has_temporal = true;
                }
            }
            // Other items (null literals, arbitrary expressions) set no flag.
            _ => {}
        }
    }

    // Number of distinct categories seen (numeric, string, bool, map).
    let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;

    if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
        // Fully constant lists are encoded once, here, as a binary literal.
        if let Some(json_array) = try_items_to_json(items) {
            let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
            let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
            return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
        }
        // Otherwise the list is assembled by the `_make_cypher_list` UDF.
        let df_args: Vec<DfExpr> = items
            .iter()
            .map(|item| cypher_expr_to_df(item, context))
            .collect::<Result<_>>()?;
        return Ok(dummy_udf_expr("_make_cypher_list", df_args));
    }

    // Second pass: translate items for a native array and track whether int
    // and float literals are mixed.
    let mut df_args = Vec::with_capacity(items.len());
    let mut has_float = false;
    let mut has_int = false;
    let mut has_other = false;

    for item in items {
        match item {
            Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
            Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
            _ => has_other = true,
        }
        df_args.push(cypher_expr_to_df(item, context)?);
    }

    if df_args.is_empty() {
        // `[]` → empty list literal with Null element type.
        let empty_arr =
            ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
        Ok(lit(ScalarValue::List(empty_arr)))
    } else if has_float && has_int && !has_other {
        // Mixed int/float literals: cast every element to Float64.
        let promoted_args = df_args
            .into_iter()
            .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
            .collect();
        Ok(datafusion::functions_nested::expr_fn::make_array(
            promoted_args,
        ))
    } else {
        // Use the type of the first non-null literal to give NULL literals a
        // concrete type.
        let non_null_type = df_args.iter().find_map(|e| {
            if let DfExpr::Literal(sv, _) = e {
                let dt = sv.data_type();
                if dt != datafusion::arrow::datatypes::DataType::Null {
                    return Some(dt);
                }
            }
            None
        });
        if let Some(ref target_type) = non_null_type {
            let coerced = df_args
                .into_iter()
                .map(|e| {
                    if matches!(&e, DfExpr::Literal(sv, _) if sv.data_type() == datafusion::arrow::datatypes::DataType::Null)
                    {
                        cast_expr(e, target_type.clone())
                    } else {
                        e
                    }
                })
                .collect();
            Ok(datafusion::functions_nested::expr_fn::make_array(coerced))
        } else {
            Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
        }
    }
}
945
946fn translate_in_expression(
948 expr: &Expr,
949 list: &Expr,
950 context: Option<&TranslationContext>,
951) -> Result<DfExpr> {
952 let left_expr = if let Expr::Variable(var) = expr
957 && let Some(ctx) = context
958 && let Some(kind) = ctx.variable_kinds.get(var)
959 {
960 match kind {
961 VariableKind::Node | VariableKind::Edge => {
962 let id_col = match kind {
963 VariableKind::Node => COL_VID,
964 VariableKind::Edge => COL_EID,
965 _ => unreachable!(),
966 };
967 cast_expr(
968 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
969 datafusion::arrow::datatypes::DataType::Int64,
970 )
971 }
972 _ => cypher_expr_to_df(expr, context)?,
973 }
974 } else {
975 cypher_expr_to_df(expr, context)?
976 };
977
978 if let Expr::List(items) = list {
983 if let Some(json_array) = try_items_to_json(items) {
984 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
986 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
987 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
988 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
989 } else {
990 let expanded: Vec<DfExpr> = items
992 .iter()
993 .map(|item| cypher_expr_to_df(item, context))
994 .collect::<Result<Vec<_>>>()?;
995 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
996 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
997 }
998 } else {
999 let right_expr = cypher_expr_to_df(list, context)?;
1000
1001 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
1006 return Ok(lit(ScalarValue::Boolean(None)));
1007 }
1008
1009 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
1010 }
1011}
1012
1013fn translate_case_expression(
1015 operand: &Option<Box<Expr>>,
1016 when_then: &[(Expr, Expr)],
1017 else_expr: &Option<Box<Expr>>,
1018 context: Option<&TranslationContext>,
1019) -> Result<DfExpr> {
1020 let mut case_builder = if let Some(match_expr) = operand {
1021 let match_df = cypher_expr_to_df(match_expr, context)?;
1022 datafusion::logical_expr::case(match_df)
1023 } else {
1024 datafusion::logical_expr::when(
1025 cypher_expr_to_df(&when_then[0].0, context)?,
1026 cypher_expr_to_df(&when_then[0].1, context)?,
1027 )
1028 };
1029
1030 let start_idx = if operand.is_some() { 0 } else { 1 };
1031 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1032 let when_df = cypher_expr_to_df(when_expr, context)?;
1033 let then_df = cypher_expr_to_df(then_expr, context)?;
1034 case_builder = case_builder.when(when_df, then_df);
1035 }
1036
1037 if let Some(else_e) = else_expr {
1038 let else_df = cypher_expr_to_df(else_e, context)?;
1039 Ok(case_builder.otherwise(else_df)?)
1040 } else {
1041 Ok(case_builder.end()?)
1042 }
1043}
1044
1045fn translate_map_projection(
1047 base: &Expr,
1048 items: &[MapProjectionItem],
1049 context: Option<&TranslationContext>,
1050) -> Result<DfExpr> {
1051 let mut args = Vec::new();
1052 for item in items {
1053 match item {
1054 MapProjectionItem::Property(prop) => {
1055 args.push(lit(prop.clone()));
1056 let prop_expr = cypher_expr_to_df(
1057 &Expr::Property(Box::new(base.clone()), prop.clone()),
1058 context,
1059 )?;
1060 args.push(prop_expr);
1061 }
1062 MapProjectionItem::LiteralEntry(key, expr) => {
1063 args.push(lit(key.clone()));
1064 args.push(cypher_expr_to_df(expr, context)?);
1065 }
1066 MapProjectionItem::Variable(var) => {
1067 args.push(lit(var.clone()));
1068 args.push(DfExpr::Column(Column::from_name(var)));
1069 }
1070 MapProjectionItem::AllProperties => {
1071 args.push(lit("__all__"));
1072 args.push(cypher_expr_to_df(base, context)?);
1073 }
1074 }
1075 }
1076 Ok(dummy_udf_expr("_map_project", args))
1077}
1078
1079fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1082 match expr {
1083 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1084 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1085 Expr::Literal(CypherLiteral::Integer(i)) => {
1086 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1087 }
1088 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1089 .map(serde_json::Value::Number)
1090 .or(Some(serde_json::Value::Null)),
1091 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1092 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1093 Expr::Map(entries) => {
1094 let mut map = serde_json::Map::new();
1095 for (k, v) in entries {
1096 map.insert(k.clone(), try_expr_to_json(v)?);
1097 }
1098 Some(serde_json::Value::Object(map))
1099 }
1100 _ => None,
1101 }
1102}
1103
1104fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1106 items.iter().map(try_expr_to_json).collect()
1107}
1108
1109fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1111 match lit {
1112 CypherLiteral::Null => Ok(ScalarValue::Null),
1113 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1114 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1115 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1116 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1117 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1118 }
1119}
1120
1121fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1123 match value {
1124 Value::Null => Ok(ScalarValue::Null),
1125 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1126 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1127 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1128 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1129 Value::List(items) => {
1130 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1132 let scalars = scalars?;
1133
1134 let data_type = infer_common_scalar_type(&scalars);
1136
1137 let typed_scalars: Vec<ScalarValue> = scalars
1139 .into_iter()
1140 .map(|s| {
1141 if matches!(s, ScalarValue::Null) {
1142 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1143 }
1144
1145 match (s, &data_type) {
1146 (
1147 ScalarValue::Int64(Some(v)),
1148 datafusion::arrow::datatypes::DataType::Float64,
1149 ) => ScalarValue::Float64(Some(v as f64)),
1150 (
1154 s @ ScalarValue::LargeBinary(_),
1155 datafusion::arrow::datatypes::DataType::LargeBinary,
1156 ) => s,
1157 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1158 let s_str = s.to_string();
1160 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1161 }
1162 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1163 if matches!(s, ScalarValue::Utf8(_)) {
1165 s
1166 } else {
1167 ScalarValue::Utf8(Some(s.to_string()))
1168 }
1169 }
1170 (s, _) => s,
1171 }
1172 })
1173 .collect();
1174
1175 if typed_scalars.is_empty() {
1177 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1178 &[],
1179 &data_type,
1180 )))
1181 } else {
1182 Ok(ScalarValue::List(ScalarValue::new_list(
1183 &typed_scalars,
1184 &data_type,
1185 true,
1186 )))
1187 }
1188 }
1189 Value::Map(map) => {
1190 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1193 entries.sort_by_key(|(k, _)| *k);
1194
1195 if entries.is_empty() {
1196 return Ok(ScalarValue::Struct(Arc::new(
1197 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1198 )));
1199 }
1200
1201 let mut fields_arrays = Vec::with_capacity(entries.len());
1202
1203 for (k, v) in entries {
1204 let scalar = value_to_scalar(v)?;
1205 let dt = scalar.data_type();
1206 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1207 let array = scalar.to_array()?;
1208 fields_arrays.push((field, array));
1209 }
1210
1211 Ok(ScalarValue::Struct(Arc::new(
1212 datafusion::arrow::array::StructArray::from(fields_arrays),
1213 )))
1214 }
1215 Value::Temporal(tv) => {
1216 use uni_common::TemporalValue;
1217 match tv {
1218 TemporalValue::Date { days_since_epoch } => {
1219 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1220 }
1221 TemporalValue::LocalTime {
1222 nanos_since_midnight,
1223 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1224 TemporalValue::Time {
1225 nanos_since_midnight,
1226 offset_seconds,
1227 } => {
1228 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1230 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1231
1232 let nanos_arr =
1233 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1234 as ArrayRef;
1235 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1236
1237 let fields = Fields::from(vec![
1238 Field::new(
1239 "nanos_since_midnight",
1240 ArrowDataType::Time64(TimeUnit::Nanosecond),
1241 true,
1242 ),
1243 Field::new("offset_seconds", ArrowDataType::Int32, true),
1244 ]);
1245
1246 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1247 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1248 }
1249 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1250 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1251 ),
1252 TemporalValue::DateTime {
1253 nanos_since_epoch,
1254 offset_seconds,
1255 timezone_name,
1256 } => {
1257 use arrow::array::{
1259 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1260 };
1261 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1262
1263 let nanos_arr =
1264 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1265 as ArrayRef;
1266 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1267 let tz_arr =
1268 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1269
1270 let fields = Fields::from(vec![
1271 Field::new(
1272 "nanos_since_epoch",
1273 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1274 true,
1275 ),
1276 Field::new("offset_seconds", ArrowDataType::Int32, true),
1277 Field::new("timezone_name", ArrowDataType::Utf8, true),
1278 ]);
1279
1280 let struct_arr =
1281 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1282 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1283 }
1284 TemporalValue::Duration {
1285 months,
1286 days,
1287 nanos,
1288 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1289 arrow::datatypes::IntervalMonthDayNano {
1290 months: *months as i32,
1291 days: *days as i32,
1292 nanoseconds: *nanos,
1293 },
1294 ))),
1295 TemporalValue::Btic { lo, hi, meta } => {
1296 let btic = uni_btic::Btic::new(*lo, *hi, *meta)
1297 .map_err(|e| anyhow::anyhow!("invalid BTIC value: {}", e))?;
1298 let packed = uni_btic::encode::encode(&btic);
1299 Ok(ScalarValue::FixedSizeBinary(24, Some(packed.to_vec())))
1300 }
1301 }
1302 }
1303 Value::Vector(v) => {
1304 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1306 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1307 }
1308 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1309 other => {
1311 let json_val: serde_json::Value = other.clone().into();
1312 let json_str = serde_json::to_string(&json_val)
1313 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1314 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1315 }
1316 }
1317}
1318
1319fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1321 match op {
1322 BinaryOp::Eq => Ok(left.eq(right)),
1326 BinaryOp::NotEq => Ok(left.not_eq(right)),
1327 BinaryOp::Lt => Ok(left.lt(right)),
1328 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1329 BinaryOp::Gt => Ok(left.gt(right)),
1330 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1331
1332 BinaryOp::And => Ok(left.and(right)),
1334 BinaryOp::Or => Ok(left.or(right)),
1335 BinaryOp::Xor => {
1336 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1338 }
1339
1340 BinaryOp::Add => {
1346 if is_list_expr(&left) || is_list_expr(&right) {
1347 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1348 } else {
1349 Ok(left + right)
1350 }
1351 }
1352 BinaryOp::Sub => Ok(left - right),
1353 BinaryOp::Mul => Ok(left * right),
1354 BinaryOp::Div => Ok(left / right),
1355 BinaryOp::Mod => Ok(left % right),
1356 BinaryOp::Pow => {
1357 let left_f = datafusion::logical_expr::cast(
1360 left,
1361 datafusion::arrow::datatypes::DataType::Float64,
1362 );
1363 let right_f = datafusion::logical_expr::cast(
1364 right,
1365 datafusion::arrow::datatypes::DataType::Float64,
1366 );
1367 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1368 }
1369
1370 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1372 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1373 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1374
1375 BinaryOp::Regex => {
1376 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1377 }
1378
1379 BinaryOp::ApproxEq => Err(anyhow!(
1380 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1381 )),
1382 }
1383}
1384
/// Validates the argument count inside a function that returns
/// `Option<Result<_>>`.
///
/// When validation fails the macro returns `Some(Err(..))` from the
/// *enclosing* function. The literal-`1` form delegates to `require_arg`;
/// every other count goes through `require_args`.
macro_rules! check_args {
    // Internal rule: bail out of the enclosing function if `$check` failed.
    // It must stay first so that `@bail` is never parsed as an expression.
    (@bail $check:expr) => {
        if let Err(err) = $check {
            return Some(Err(err));
        }
    };
    (1, $df_args:expr, $name:expr) => {
        check_args!(@bail require_arg($df_args, $name))
    };
    ($n:expr, $df_args:expr, $name:expr) => {
        check_args!(@bail require_args($df_args, $n, $name))
    };
}
1401
1402fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1405 if df_args.len() < count {
1406 let noun = if count == 1 { "argument" } else { "arguments" };
1407 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1408 }
1409 Ok(())
1410}
1411
1412fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1414 require_args(df_args, 1, func_name)
1415}
1416
1417fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1419 df_args[0].clone()
1420}
1421
1422pub fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1424 DfExpr::Cast(datafusion::logical_expr::Cast {
1425 expr: Box::new(expr),
1426 data_type,
1427 })
1428}
1429
1430pub fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1436 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1437 Arc::new(crate::df_udfs::create_cypher_list_to_cv_udf()),
1438 vec![expr],
1439 ))
1440}
1441
1442pub fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1446 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1447 Arc::new(crate::df_udfs::create_cypher_scalar_to_cv_udf()),
1448 vec![expr],
1449 ))
1450}
1451
1452fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1454 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1455 Box::new(left),
1456 op,
1457 Box::new(right),
1458 ))
1459}
1460
1461pub fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1466 use datafusion::logical_expr::Operator;
1467 match op {
1468 Operator::Eq => Some("_cypher_equal"),
1469 Operator::NotEq => Some("_cypher_not_equal"),
1470 Operator::Lt => Some("_cypher_lt"),
1471 Operator::LtEq => Some("_cypher_lt_eq"),
1472 Operator::Gt => Some("_cypher_gt"),
1473 Operator::GtEq => Some("_cypher_gt_eq"),
1474 _ => None,
1475 }
1476}
1477
1478fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1480 use datafusion::logical_expr::Operator;
1481 match op {
1482 Operator::Plus => Some("_cypher_add"),
1483 Operator::Minus => Some("_cypher_sub"),
1484 Operator::Multiply => Some("_cypher_mul"),
1485 Operator::Divide => Some("_cypher_div"),
1486 Operator::Modulo => Some("_cypher_mod"),
1487 _ => None,
1488 }
1489}
1490
1491fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1496where
1497 F: FnOnce(DfExpr) -> DfExpr,
1498{
1499 require_arg(df_args, func_name)?;
1500 Ok(math_fn(cast_expr(
1501 first_arg(df_args),
1502 datafusion::arrow::datatypes::DataType::Float64,
1503 )))
1504}
1505
1506fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1508 if distinct {
1509 expr.distinct()
1510 .build()
1511 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1512 } else {
1513 Ok(expr)
1514 }
1515}
1516
1517fn translate_aggregate_function(
1519 name_upper: &str,
1520 df_args: &[DfExpr],
1521 distinct: bool,
1522) -> Option<Result<DfExpr>> {
1523 match name_upper {
1524 "COUNT" => {
1525 let expr = if df_args.is_empty() {
1526 datafusion::functions_aggregate::count::count(lit(1i64))
1527 } else {
1528 datafusion::functions_aggregate::count::count(first_arg(df_args))
1529 };
1530 Some(maybe_distinct(expr, distinct, "COUNT"))
1531 }
1532 "SUM" => {
1533 check_args!(1, df_args, "SUM");
1534 let udaf = Arc::new(crate::df_udfs::create_cypher_sum_udaf());
1535 Some(maybe_distinct(
1536 udaf.call(vec![first_arg(df_args)]),
1537 distinct,
1538 "SUM",
1539 ))
1540 }
1541 "AVG" => {
1542 check_args!(1, df_args, "AVG");
1543 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1544 let expr = datafusion::functions_aggregate::average::avg(coerced);
1545 Some(maybe_distinct(expr, distinct, "AVG"))
1546 }
1547 "MIN" => {
1548 check_args!(1, df_args, "MIN");
1549 let udaf = Arc::new(crate::df_udfs::create_cypher_min_udaf());
1550 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1551 }
1552 "MAX" => {
1553 check_args!(1, df_args, "MAX");
1554 let udaf = Arc::new(crate::df_udfs::create_cypher_max_udaf());
1555 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1556 }
1557 "PERCENTILEDISC" => {
1558 if df_args.len() != 2 {
1559 return Some(Err(anyhow!(
1560 "percentileDisc() requires exactly 2 arguments"
1561 )));
1562 }
1563 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1564 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_disc_udaf());
1565 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1566 }
1567 "PERCENTILECONT" => {
1568 if df_args.len() != 2 {
1569 return Some(Err(anyhow!(
1570 "percentileCont() requires exactly 2 arguments"
1571 )));
1572 }
1573 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1574 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_cont_udaf());
1575 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1576 }
1577 "COLLECT" => {
1578 check_args!(1, df_args, "COLLECT");
1579 Some(Ok(crate::df_udfs::create_cypher_collect_expr(
1580 first_arg(df_args),
1581 distinct,
1582 )))
1583 }
1584 "BTIC_MIN" => {
1586 check_args!(1, df_args, "btic_min");
1587 let udaf = Arc::new(crate::df_udfs::create_btic_min_udaf());
1588 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1589 }
1590 "BTIC_MAX" => {
1591 check_args!(1, df_args, "btic_max");
1592 let udaf = Arc::new(crate::df_udfs::create_btic_max_udaf());
1593 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1594 }
1595 "BTIC_SPAN_AGG" => {
1596 check_args!(1, df_args, "btic_span_agg");
1597 let udaf = Arc::new(crate::df_udfs::create_btic_span_agg_udaf());
1598 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1599 }
1600 "BTIC_COUNT_AT" => {
1601 if df_args.len() != 2 {
1602 return Some(Err(anyhow!("btic_count_at requires 2 arguments")));
1603 }
1604 let udaf = Arc::new(crate::df_udfs::create_btic_count_at_udaf());
1605 Some(Ok(udaf.call(df_args.to_vec())))
1606 }
1607 _ => None,
1608 }
1609}
1610
1611fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1614 match name_upper {
1615 "TOSTRING" => {
1616 check_args!(1, df_args, "toString");
1617 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1618 }
1619 "TOINTEGER" | "TOINT" => {
1620 check_args!(1, df_args, "toInteger");
1621 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1622 }
1623 "TOFLOAT" => {
1624 check_args!(1, df_args, "toFloat");
1625 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1626 }
1627 "TOBOOLEAN" | "TOBOOL" => {
1628 check_args!(1, df_args, "toBoolean");
1629 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1630 }
1631 "UPPER" | "TOUPPER" => {
1632 check_args!(1, df_args, "upper");
1633 Some(Ok(datafusion::functions::string::expr_fn::upper(
1634 first_arg(df_args),
1635 )))
1636 }
1637 "LOWER" | "TOLOWER" => {
1638 check_args!(1, df_args, "lower");
1639 Some(Ok(datafusion::functions::string::expr_fn::lower(
1640 first_arg(df_args),
1641 )))
1642 }
1643 "SUBSTRING" => {
1644 check_args!(2, df_args, "substring");
1645 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1646 }
1647 "TRIM" => {
1648 check_args!(1, df_args, "TRIM");
1649 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1650 first_arg(df_args),
1651 ])))
1652 }
1653 "LTRIM" => {
1654 check_args!(1, df_args, "LTRIM");
1655 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1656 first_arg(df_args),
1657 ])))
1658 }
1659 "RTRIM" => {
1660 check_args!(1, df_args, "RTRIM");
1661 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1662 first_arg(df_args),
1663 ])))
1664 }
1665 "LEFT" => {
1666 check_args!(2, df_args, "left");
1667 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1668 df_args[0].clone(),
1669 df_args[1].clone(),
1670 )))
1671 }
1672 "RIGHT" => {
1673 check_args!(2, df_args, "right");
1674 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1675 df_args[0].clone(),
1676 df_args[1].clone(),
1677 )))
1678 }
1679 "REPLACE" => {
1680 check_args!(3, df_args, "replace");
1681 Some(Ok(datafusion::functions::string::expr_fn::replace(
1682 df_args[0].clone(),
1683 df_args[1].clone(),
1684 df_args[2].clone(),
1685 )))
1686 }
1687 "REVERSE" => {
1688 check_args!(1, df_args, "reverse");
1689 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1690 }
1691 "SPLIT" => {
1692 check_args!(2, df_args, "split");
1693 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1694 }
1695 "SIZE" | "LENGTH" => {
1696 check_args!(1, df_args, name_upper);
1697 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1698 }
1699 _ => None,
1700 }
1701}
1702
1703fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1706 use datafusion::functions::math::expr_fn;
1707
1708 let unary_f64 =
1710 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1711
1712 match name_upper {
1713 "ABS" => {
1714 check_args!(1, df_args, "abs");
1715 Some(Ok(crate::df_udfs::cypher_abs_expr(first_arg(df_args))))
1719 }
1720 "CEIL" | "CEILING" => {
1721 check_args!(1, df_args, "ceil");
1722 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1723 }
1724 "FLOOR" => {
1725 check_args!(1, df_args, "floor");
1726 Some(Ok(expr_fn::floor(first_arg(df_args))))
1727 }
1728 "ROUND" => {
1729 check_args!(1, df_args, "round");
1730 let args = if df_args.len() == 1 {
1731 vec![first_arg(df_args)]
1732 } else {
1733 vec![df_args[0].clone(), df_args[1].clone()]
1734 };
1735 Some(Ok(expr_fn::round(args)))
1736 }
1737 "SIGN" => {
1738 check_args!(1, df_args, "sign");
1739 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1740 Some(Ok(expr_fn::signum(coerced)))
1741 }
1742 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1743 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1744 "LOG10" => unary_f64("log10", expr_fn::log10),
1745 "EXP" => unary_f64("exp", expr_fn::exp),
1746 "SIN" => unary_f64("sin", expr_fn::sin),
1747 "COS" => unary_f64("cos", expr_fn::cos),
1748 "TAN" => unary_f64("tan", expr_fn::tan),
1749 "ASIN" => unary_f64("asin", expr_fn::asin),
1750 "ACOS" => unary_f64("acos", expr_fn::acos),
1751 "ATAN" => unary_f64("atan", expr_fn::atan),
1752 "ATAN2" => {
1753 check_args!(2, df_args, "atan2");
1754 let cast_f64 =
1755 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1756 Some(Ok(expr_fn::atan2(
1757 cast_f64(df_args[0].clone()),
1758 cast_f64(df_args[1].clone()),
1759 )))
1760 }
1761 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1762 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1763 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1764 _ => None,
1765 }
1766}
1767
1768fn translate_temporal_function(
1771 name_upper: &str,
1772 name: &str,
1773 df_args: &[DfExpr],
1774 context: Option<&TranslationContext>,
1775) -> Option<Result<DfExpr>> {
1776 match name_upper {
1777 "DATE"
1778 | "TIME"
1779 | "LOCALTIME"
1780 | "LOCALDATETIME"
1781 | "DATETIME"
1782 | "DURATION"
1783 | "YEAR"
1784 | "MONTH"
1785 | "DAY"
1786 | "HOUR"
1787 | "MINUTE"
1788 | "SECOND"
1789 | "DURATION.BETWEEN"
1790 | "DURATION.INMONTHS"
1791 | "DURATION.INDAYS"
1792 | "DURATION.INSECONDS"
1793 | "DATETIME.FROMEPOCH"
1794 | "DATETIME.FROMEPOCHMILLIS"
1795 | "DATE.TRUNCATE"
1796 | "TIME.TRUNCATE"
1797 | "DATETIME.TRUNCATE"
1798 | "LOCALDATETIME.TRUNCATE"
1799 | "LOCALTIME.TRUNCATE"
1800 | "DATETIME.TRANSACTION"
1801 | "DATETIME.STATEMENT"
1802 | "DATETIME.REALTIME"
1803 | "DATE.TRANSACTION"
1804 | "DATE.STATEMENT"
1805 | "DATE.REALTIME"
1806 | "TIME.TRANSACTION"
1807 | "TIME.STATEMENT"
1808 | "TIME.REALTIME"
1809 | "LOCALTIME.TRANSACTION"
1810 | "LOCALTIME.STATEMENT"
1811 | "LOCALTIME.REALTIME"
1812 | "LOCALDATETIME.TRANSACTION"
1813 | "LOCALDATETIME.STATEMENT"
1814 | "LOCALDATETIME.REALTIME" => {
1815 let stmt_time = context.map(|c| c.statement_time);
1819 if can_constant_fold(name_upper, df_args)
1820 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1821 {
1822 return Some(Ok(folded));
1823 }
1824 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1825 }
1826 _ => None,
1827 }
1828}
1829
1830fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1832 if name.contains("REALTIME") {
1834 return false;
1835 }
1836 if args.is_empty() {
1844 return matches!(
1845 name,
1846 "DATE"
1847 | "TIME"
1848 | "LOCALTIME"
1849 | "LOCALDATETIME"
1850 | "DATETIME"
1851 | "DATE.STATEMENT"
1852 | "TIME.STATEMENT"
1853 | "LOCALTIME.STATEMENT"
1854 | "LOCALDATETIME.STATEMENT"
1855 | "DATETIME.STATEMENT"
1856 | "DATE.TRANSACTION"
1857 | "TIME.TRANSACTION"
1858 | "LOCALTIME.TRANSACTION"
1859 | "LOCALDATETIME.TRANSACTION"
1860 | "DATETIME.TRANSACTION"
1861 );
1862 }
1863 args.iter().all(is_constant_expr)
1865}
1866
1867fn is_constant_expr(expr: &DfExpr) -> bool {
1869 match expr {
1870 DfExpr::Literal(_, _) => true,
1871 DfExpr::ScalarFunction(func) => {
1872 func.args.iter().all(is_constant_expr)
1874 }
1875 _ => false,
1876 }
1877}
1878
1879fn try_constant_fold_temporal(
1885 name: &str,
1886 args: &[DfExpr],
1887 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1888) -> Result<DfExpr> {
1889 let val_args: Vec<Value> = args
1891 .iter()
1892 .map(extract_constant_value)
1893 .collect::<Result<_>>()?;
1894
1895 let result = if val_args.is_empty() {
1897 if let Some(frozen) = stmt_time {
1898 crate::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1899 } else {
1900 crate::datetime::eval_datetime_function(name, &val_args)?
1901 }
1902 } else {
1903 crate::datetime::eval_datetime_function(name, &val_args)?
1904 };
1905
1906 let scalar = value_to_scalar(&result)?;
1908 Ok(DfExpr::Literal(scalar, None))
1909}
1910
1911fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1913 use crate::df_udfs::scalar_to_value;
1914 match expr {
1915 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1916 DfExpr::ScalarFunction(func) => {
1917 let mut map = std::collections::HashMap::new();
1920 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1921 for chunk in pairs.chunks(2) {
1922 if let [key_expr, val_expr] = chunk {
1923 let key = match key_expr {
1925 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1926 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1927 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1928 };
1929 let val = extract_constant_value(val_expr)?;
1930 map.insert(key, val);
1931 } else {
1932 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1933 }
1934 }
1935 Ok(Value::Map(map))
1936 }
1937 _ => Err(anyhow::anyhow!(
1938 "Cannot extract constant value from expression"
1939 )),
1940 }
1941}
1942
1943fn translate_btic_function(
1946 name_upper: &str,
1947 name: &str,
1948 df_args: &[DfExpr],
1949) -> Option<Result<DfExpr>> {
1950 if crate::expr_eval::is_btic_function(name_upper) {
1951 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1952 } else {
1953 None
1954 }
1955}
1956
1957fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1960 match name_upper {
1961 "HEAD" => {
1962 check_args!(1, df_args, "head");
1963 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1964 }
1965 "LAST" => {
1966 check_args!(1, df_args, "last");
1967 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1968 }
1969 "TAIL" => {
1970 check_args!(1, df_args, "tail");
1971 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1972 }
1973 "RANGE" => {
1974 check_args!(2, df_args, "range");
1975 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1976 }
1977 _ => None,
1978 }
1979}
1980
1981fn translate_graph_function(
1984 name_upper: &str,
1985 name: &str,
1986 df_args: &[DfExpr],
1987 args: &[Expr],
1988 context: Option<&TranslationContext>,
1989) -> Option<Result<DfExpr>> {
1990 match name_upper {
1991 "ID" => {
1992 if let Some(Expr::Variable(var)) = args.first() {
1995 let is_edge = context.is_some_and(|ctx| {
1996 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1997 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1998 });
1999 let id_suffix = if is_edge { COL_EID } else { COL_VID };
2000 Some(Ok(DfExpr::Column(Column::from_name(format!(
2001 "{}.{}",
2002 var, id_suffix
2003 )))))
2004 } else {
2005 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
2006 }
2007 }
2008 "CREATED_AT" | "UPDATED_AT" => {
2009 if let Some(Expr::Variable(var)) = args.first() {
2013 let suffix = if name_upper == "CREATED_AT" {
2014 "_created_at"
2015 } else {
2016 "_updated_at"
2017 };
2018 Some(Ok(DfExpr::Column(Column::from_name(format!(
2019 "{}.{}",
2020 var, suffix
2021 )))))
2022 } else {
2023 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2024 }
2025 }
2026 "LABELS" | "KEYS" => {
2027 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2032 }
2033 "TYPE" => {
2034 if let Some(Expr::Variable(var)) = args.first()
2038 && let Some(ctx) = context
2039 && let Some(label) = ctx.variable_labels.get(var)
2040 {
2041 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
2044 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
2045 expr: None,
2046 when_then_expr: vec![(
2047 Box::new(eid_col.is_not_null()),
2048 Box::new(lit(label.clone())),
2049 )],
2050 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
2051 })));
2052 }
2053 if let Some(Expr::Variable(var)) = args.first()
2057 && context
2058 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
2059 {
2060 return Some(Ok(DfExpr::Column(Column::from_name(format!(
2061 "{}.{}",
2062 var, COL_TYPE
2063 )))));
2064 }
2065 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
2066 }
2067 "PROPERTIES" => {
2068 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2071 }
2072 "UNI.TEMPORAL.VALIDAT" => {
2073 if let (
2076 Some(Expr::Variable(var)),
2077 Some(Expr::Literal(CypherLiteral::String(start_prop))),
2078 Some(Expr::Literal(CypherLiteral::String(end_prop))),
2079 Some(ts_expr),
2080 ) = (args.first(), args.get(1), args.get(2), args.get(3))
2081 {
2082 let start_col =
2083 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
2084 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
2085 let ts = match cypher_expr_to_df(ts_expr, context) {
2086 Ok(ts) => ts,
2087 Err(e) => return Some(Err(e)),
2088 };
2089
2090 let start_check = start_col.lt_eq(ts.clone());
2092 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
2094 let end_after = end_col.gt(ts);
2095 let end_check = end_null.or(end_after);
2096
2097 Some(Ok(start_check.and(end_check)))
2098 } else {
2099 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2101 }
2102 }
2103 "STARTNODE" | "ENDNODE" => {
2104 let mut udf_args = df_args.to_vec();
2107 let mut seen = std::collections::HashSet::new();
2108 if let Some(ctx) = context {
2109 for (var, kind) in &ctx.variable_kinds {
2111 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2112 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2113 }
2114 }
2115 for var in &ctx.node_variable_hints {
2118 if seen.insert(var.clone()) {
2119 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2120 }
2121 }
2122 }
2123 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2124 }
2125 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2126 "HASLABEL" => {
2127 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2128 return Some(Err(e));
2129 }
2130 if let Some(Expr::Variable(var)) = args.first() {
2132 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2133 let labels_col =
2135 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2136 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2137 labels_col,
2138 lit(label.clone()),
2139 )))
2140 } else {
2141 Some(Err(anyhow::anyhow!(
2143 "hasLabel requires string literal as second argument for DataFusion translation"
2144 )))
2145 }
2146 } else {
2147 Some(Err(anyhow::anyhow!(
2149 "hasLabel requires variable as first argument for DataFusion translation"
2150 )))
2151 }
2152 }
2153 _ => None,
2154 }
2155}
2156
2157fn translate_function_call(
2159 name: &str,
2160 args: &[Expr],
2161 distinct: bool,
2162 context: Option<&TranslationContext>,
2163) -> Result<DfExpr> {
2164 let df_args: Vec<DfExpr> = args
2165 .iter()
2166 .map(|arg| cypher_expr_to_df(arg, context))
2167 .collect::<Result<Vec<_>>>()?;
2168
2169 let name_upper = name.to_uppercase();
2170
2171 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2175 return result;
2176 }
2177
2178 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2179 return result;
2180 }
2181
2182 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2183 return result;
2184 }
2185
2186 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2187 return result;
2188 }
2189
2190 if let Some(result) = translate_btic_function(&name_upper, name, &df_args) {
2191 return result;
2192 }
2193
2194 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2195 return result;
2196 }
2197
2198 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2199 return result;
2200 }
2201
2202 match name_upper.as_str() {
2204 "COALESCE" => {
2205 require_arg(&df_args, "coalesce")?;
2206 if df_args.len() == 1 {
2211 return Ok(df_args.into_iter().next().unwrap());
2212 }
2213 let n = df_args.len();
2214 let (init, last) = df_args.split_at(n - 1);
2215 let mut builder = datafusion::logical_expr::conditional_expressions::CaseBuilder::new(
2216 None,
2217 vec![],
2218 vec![],
2219 None,
2220 );
2221 for arg in init {
2222 builder.when(arg.clone().is_not_null(), arg.clone());
2223 }
2224 return Ok(builder.otherwise(last[0].clone())?);
2225 }
2226 "NULLIF" => {
2227 require_args(&df_args, 2, "nullif")?;
2228 return Ok(datafusion::functions::expr_fn::nullif(
2229 df_args[0].clone(),
2230 df_args[1].clone(),
2231 ));
2232 }
2233 _ => {}
2234 }
2235
2236 match name_upper.as_str() {
2238 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
2239 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2240 }
2241 _ => {}
2242 }
2243
2244 Ok(dummy_udf_expr(name, df_args))
2246}
2247
2248#[derive(Debug)]
2253struct DummyUdf {
2254 name: String,
2255 signature: datafusion::logical_expr::Signature,
2256 ret_type: datafusion::arrow::datatypes::DataType,
2257}
2258
2259impl DummyUdf {
2260 fn new(name: String) -> Self {
2261 let ret_type = dummy_udf_return_type(&name);
2262 Self {
2263 name,
2264 signature: datafusion::logical_expr::Signature::variadic_any(
2265 datafusion::logical_expr::Volatility::Immutable,
2266 ),
2267 ret_type,
2268 }
2269 }
2270}
2271
2272fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2285 use datafusion::arrow::datatypes::DataType;
2286 match name {
2287 "_cypher_add"
2291 | "_cypher_sub"
2292 | "_cypher_mul"
2293 | "_cypher_div"
2294 | "_cypher_mod"
2295 | "_cypher_list_concat"
2296 | "_cypher_list_append"
2297 | "_make_cypher_list"
2298 | "_map_project"
2299 | "_cypher_list_to_cv"
2300 | "_cypher_tail" => DataType::LargeBinary,
2301 _ => DataType::Null,
2305 }
2306}
2307
2308impl PartialEq for DummyUdf {
2309 fn eq(&self, other: &Self) -> bool {
2310 self.name == other.name
2311 }
2312}
2313
2314impl Eq for DummyUdf {}
2315
2316impl Hash for DummyUdf {
2317 fn hash<H: Hasher>(&self, state: &mut H) {
2318 self.name.hash(state);
2319 }
2320}
2321
2322pub fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2324 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2325 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2326 DummyUdf::new(name.to_lowercase()),
2327 )),
2328 args,
2329 })
2330}
2331
2332impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2333 fn as_any(&self) -> &dyn std::any::Any {
2334 self
2335 }
2336
2337 fn name(&self) -> &str {
2338 &self.name
2339 }
2340
2341 fn signature(&self) -> &datafusion::logical_expr::Signature {
2342 &self.signature
2343 }
2344
2345 fn return_type(
2346 &self,
2347 arg_types: &[datafusion::arrow::datatypes::DataType],
2348 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2349 match self.name.as_str() {
2354 "_cypher_add" | "_cypher_sub" | "_cypher_mul" | "_cypher_div" | "_cypher_mod" => {
2355 Ok(crate::df_udfs::cypher_arith_return_type(arg_types))
2356 }
2357 _ => Ok(self.ret_type.clone()),
2361 }
2362 }
2363
2364 fn invoke_with_args(
2365 &self,
2366 _args: ScalarFunctionArgs,
2367 ) -> datafusion::error::Result<ColumnarValue> {
2368 Err(datafusion::error::DataFusionError::Plan(format!(
2369 "UDF '{}' is not registered. Register it via SessionContext.",
2370 self.name
2371 )))
2372 }
2373}
2374
2375pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2379 let mut properties = Vec::new();
2380 collect_properties_recursive(expr, &mut properties);
2381 properties.sort();
2382 properties.dedup();
2383 properties
2384}
2385
2386fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
2387 match expr {
2388 Expr::PatternComprehension { .. } => {}
2389 Expr::Property(base, prop) => {
2390 if let Ok(var_name) = extract_variable_name(base) {
2391 properties.push((var_name, prop.clone()));
2392 }
2393 collect_properties_recursive(base, properties);
2394 }
2395 Expr::ArrayIndex { array, index } => {
2396 if let Ok(var_name) = extract_variable_name(array)
2397 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
2398 {
2399 properties.push((var_name, prop_name.clone()));
2400 }
2401 collect_properties_recursive(array, properties);
2402 collect_properties_recursive(index, properties);
2403 }
2404 Expr::ArraySlice { array, start, end } => {
2405 collect_properties_recursive(array, properties);
2406 if let Some(s) = start {
2407 collect_properties_recursive(s, properties);
2408 }
2409 if let Some(e) = end {
2410 collect_properties_recursive(e, properties);
2411 }
2412 }
2413 Expr::List(items) => {
2414 for item in items {
2415 collect_properties_recursive(item, properties);
2416 }
2417 }
2418 Expr::Map(entries) => {
2419 for (_, value) in entries {
2420 collect_properties_recursive(value, properties);
2421 }
2422 }
2423 Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
2424 collect_properties_recursive(inner, properties);
2425 }
2426 Expr::FunctionCall { args, .. } => {
2427 for arg in args {
2428 collect_properties_recursive(arg, properties);
2429 }
2430 }
2431 Expr::BinaryOp { left, right, .. } => {
2432 collect_properties_recursive(left, properties);
2433 collect_properties_recursive(right, properties);
2434 }
2435 Expr::UnaryOp { expr, .. } => {
2436 collect_properties_recursive(expr, properties);
2437 }
2438 Expr::Case {
2439 expr,
2440 when_then,
2441 else_expr,
2442 } => {
2443 if let Some(e) = expr {
2444 collect_properties_recursive(e, properties);
2445 }
2446 for (when_e, then_e) in when_then {
2447 collect_properties_recursive(when_e, properties);
2448 collect_properties_recursive(then_e, properties);
2449 }
2450 if let Some(e) = else_expr {
2451 collect_properties_recursive(e, properties);
2452 }
2453 }
2454 Expr::Reduce {
2455 init, list, expr, ..
2456 } => {
2457 collect_properties_recursive(init, properties);
2458 collect_properties_recursive(list, properties);
2459 collect_properties_recursive(expr, properties);
2460 }
2461 Expr::Quantifier {
2462 list, predicate, ..
2463 } => {
2464 collect_properties_recursive(list, properties);
2465 collect_properties_recursive(predicate, properties);
2466 }
2467 Expr::ListComprehension {
2468 list,
2469 where_clause,
2470 map_expr,
2471 ..
2472 } => {
2473 collect_properties_recursive(list, properties);
2474 if let Some(filter) = where_clause {
2475 collect_properties_recursive(filter, properties);
2476 }
2477 collect_properties_recursive(map_expr, properties);
2478 }
2479 Expr::In { expr, list } => {
2480 collect_properties_recursive(expr, properties);
2481 collect_properties_recursive(list, properties);
2482 }
2483 Expr::ValidAt {
2484 entity, timestamp, ..
2485 } => {
2486 collect_properties_recursive(entity, properties);
2487 collect_properties_recursive(timestamp, properties);
2488 }
2489 Expr::MapProjection { base, items } => {
2490 collect_properties_recursive(base, properties);
2491 for item in items {
2492 match item {
2493 uni_cypher::ast::MapProjectionItem::Property(prop) => {
2494 if let Ok(var_name) = extract_variable_name(base) {
2495 properties.push((var_name, prop.clone()));
2496 }
2497 }
2498 uni_cypher::ast::MapProjectionItem::AllProperties => {
2499 if let Ok(var_name) = extract_variable_name(base) {
2500 properties.push((var_name, "*".to_string()));
2501 }
2502 }
2503 uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
2504 collect_properties_recursive(expr, properties);
2505 }
2506 uni_cypher::ast::MapProjectionItem::Variable(_) => {}
2507 }
2508 }
2509 }
2510 Expr::LabelCheck { expr, .. } => {
2511 collect_properties_recursive(expr, properties);
2512 }
2513 Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
2515 Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
2516 }
2517}
2518
2519pub fn wider_numeric_type(
2526 a: &datafusion::arrow::datatypes::DataType,
2527 b: &datafusion::arrow::datatypes::DataType,
2528) -> datafusion::arrow::datatypes::DataType {
2529 use datafusion::arrow::datatypes::DataType;
2530
2531 fn numeric_rank(dt: &DataType) -> u8 {
2532 match dt {
2533 DataType::Int8 | DataType::UInt8 => 1,
2534 DataType::Int16 | DataType::UInt16 => 2,
2535 DataType::Int32 | DataType::UInt32 => 3,
2536 DataType::Int64 | DataType::UInt64 => 4,
2537 DataType::Float16 => 5,
2538 DataType::Float32 => 6,
2539 DataType::Float64 => 7,
2540 _ => 0,
2541 }
2542 }
2543
2544 if numeric_rank(a) >= numeric_rank(b) {
2545 a.clone()
2546 } else {
2547 b.clone()
2548 }
2549}
2550
2551fn resolve_column_type_fallback(
2557 expr: &DfExpr,
2558 schema: &datafusion::common::DFSchema,
2559) -> Option<datafusion::arrow::datatypes::DataType> {
2560 if let DfExpr::Column(col) = expr {
2561 let col_name = &col.name;
2562 for (_, field) in schema.iter() {
2564 if field.name() == col_name {
2565 return Some(field.data_type().clone());
2566 }
2567 }
2568 }
2569 None
2570}
2571
2572fn contains_division(expr: &DfExpr) -> bool {
2575 match expr {
2576 DfExpr::BinaryExpr(b) => {
2577 b.op == datafusion::logical_expr::Operator::Divide
2578 || contains_division(&b.left)
2579 || contains_division(&b.right)
2580 }
2581 DfExpr::Cast(c) => contains_division(&c.expr),
2582 DfExpr::TryCast(c) => contains_division(&c.expr),
2583 _ => false,
2584 }
2585}
2586
2587pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2593 use datafusion::arrow::datatypes::DataType;
2594 use datafusion::logical_expr::ExprSchemable;
2595
2596 match expr {
2597 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2598 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2599 DfExpr::Case(case) => coerce_case_expr(case, schema),
2600 DfExpr::InList(in_list) => {
2601 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2602 let coerced_list = in_list
2603 .list
2604 .iter()
2605 .map(|e| apply_type_coercion(e, schema))
2606 .collect::<Result<Vec<_>>>()?;
2607 let expr_type = coerced_expr
2608 .get_type(schema)
2609 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2610 crate::cypher_type_coerce::build_cypher_in_list(
2611 coerced_expr,
2612 &expr_type,
2613 coerced_list,
2614 in_list.negated,
2615 schema,
2616 )
2617 }
2618 DfExpr::Not(inner) => {
2619 let coerced_inner = apply_type_coercion(inner, schema)?;
2620 let inner_type = coerced_inner.get_type(schema).ok();
2621 let final_inner = if inner_type
2622 .as_ref()
2623 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2624 {
2625 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2626 } else if inner_type
2627 .as_ref()
2628 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2629 {
2630 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2631 } else {
2632 coerced_inner
2633 };
2634 Ok(DfExpr::Not(Box::new(final_inner)))
2635 }
2636 DfExpr::IsNull(inner) => {
2637 let coerced_inner = apply_type_coercion(inner, schema)?;
2638 Ok(coerced_inner.is_null())
2639 }
2640 DfExpr::IsNotNull(inner) => {
2641 let coerced_inner = apply_type_coercion(inner, schema)?;
2642 Ok(coerced_inner.is_not_null())
2643 }
2644 DfExpr::Negative(inner) => {
2645 let coerced_inner = apply_type_coercion(inner, schema)?;
2646 let inner_type = coerced_inner.get_type(schema).ok();
2647 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2648 Ok(dummy_udf_expr(
2649 "_cypher_mul",
2650 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2651 ))
2652 } else {
2653 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2654 }
2655 }
2656 DfExpr::Cast(cast) => {
2657 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2658 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2659 Box::new(coerced_inner),
2660 cast.data_type.clone(),
2661 )))
2662 }
2663 DfExpr::TryCast(cast) => {
2664 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2665 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2666 Box::new(coerced_inner),
2667 cast.data_type.clone(),
2668 )))
2669 }
2670 DfExpr::Alias(alias) => {
2671 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2672 Ok(coerced_inner.alias(alias.name.clone()))
2673 }
2674 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2675 _ => Ok(expr.clone()),
2676 }
2677}
2678
2679fn coerce_logical_operands(
2681 left: DfExpr,
2682 right: DfExpr,
2683 op: datafusion::logical_expr::Operator,
2684 schema: &datafusion::common::DFSchema,
2685) -> Option<DfExpr> {
2686 use datafusion::arrow::datatypes::DataType;
2687 use datafusion::logical_expr::ExprSchemable;
2688
2689 if !matches!(
2690 op,
2691 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2692 ) {
2693 return None;
2694 }
2695 let left_type = left.get_type(schema).ok();
2696 let right_type = right.get_type(schema).ok();
2697 let left_needs_cast = left_type
2698 .as_ref()
2699 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2700 let right_needs_cast = right_type
2701 .as_ref()
2702 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2703 let left_is_lb = left_type
2704 .as_ref()
2705 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2706 let right_is_lb = right_type
2707 .as_ref()
2708 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2709 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2710 return None;
2711 }
2712 let coerced_left = if left_is_lb {
2713 dummy_udf_expr("_cv_to_bool", vec![left])
2714 } else if left_needs_cast {
2715 datafusion::logical_expr::cast(left, DataType::Boolean)
2716 } else {
2717 left
2718 };
2719 let coerced_right = if right_is_lb {
2720 dummy_udf_expr("_cv_to_bool", vec![right])
2721 } else if right_needs_cast {
2722 datafusion::logical_expr::cast(right, DataType::Boolean)
2723 } else {
2724 right
2725 };
2726 Some(binary_expr(coerced_left, op, coerced_right))
2727}
2728
2729#[expect(
2732 clippy::too_many_arguments,
2733 reason = "Binary coercion needs all context"
2734)]
2735fn coerce_large_binary_ops(
2736 left: &DfExpr,
2737 right: &DfExpr,
2738 left_type: &datafusion::arrow::datatypes::DataType,
2739 right_type: &datafusion::arrow::datatypes::DataType,
2740 left_is_null: bool,
2741 op: datafusion::logical_expr::Operator,
2742 is_comparison: bool,
2743 is_arithmetic: bool,
2744) -> Option<Result<DfExpr>> {
2745 use datafusion::arrow::datatypes::DataType;
2746 use datafusion::logical_expr::Operator;
2747
2748 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2749 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2750
2751 if op == Operator::Plus {
2752 if left_is_lb && right_is_lb {
2753 return Some(Ok(dummy_udf_expr(
2754 "_cypher_add",
2755 vec![left.clone(), right.clone()],
2756 )));
2757 }
2758 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2759 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2760 if left_is_native_list && right_is_native_list {
2761 return Some(Ok(dummy_udf_expr(
2762 "_cypher_list_concat",
2763 vec![left.clone(), right.clone()],
2764 )));
2765 }
2766 if left_is_native_list || right_is_native_list {
2767 return Some(Ok(dummy_udf_expr(
2768 "_cypher_list_append",
2769 vec![left.clone(), right.clone()],
2770 )));
2771 }
2772 }
2773
2774 if (left_is_lb || right_is_lb) && is_comparison {
2775 if let Some(udf_name) = comparison_udf_name(op) {
2776 return Some(Ok(dummy_udf_expr(
2777 udf_name,
2778 vec![left.clone(), right.clone()],
2779 )));
2780 }
2781 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2782 }
2783
2784 if (left_is_lb || right_is_lb) && is_arithmetic {
2785 let udf_name =
2786 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2787 return Some(Ok(dummy_udf_expr(
2788 udf_name,
2789 vec![left.clone(), right.clone()],
2790 )));
2791 }
2792
2793 None
2794}
2795
2796fn coerce_temporal_comparisons(
2798 left: DfExpr,
2799 right: DfExpr,
2800 left_type: &datafusion::arrow::datatypes::DataType,
2801 right_type: &datafusion::arrow::datatypes::DataType,
2802 op: datafusion::logical_expr::Operator,
2803 is_comparison: bool,
2804) -> Option<DfExpr> {
2805 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2806 use datafusion::logical_expr::Operator;
2807
2808 if !is_comparison {
2809 return None;
2810 }
2811
2812 if uni_common::core::schema::is_datetime_struct(left_type)
2814 && uni_common::core::schema::is_datetime_struct(right_type)
2815 {
2816 return Some(binary_expr(
2817 extract_datetime_nanos(left),
2818 op,
2819 extract_datetime_nanos(right),
2820 ));
2821 }
2822
2823 if uni_common::core::schema::is_time_struct(left_type)
2825 && uni_common::core::schema::is_time_struct(right_type)
2826 {
2827 return Some(binary_expr(
2828 extract_time_nanos(left),
2829 op,
2830 extract_time_nanos(right),
2831 ));
2832 }
2833
2834 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2836 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2837
2838 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2839 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2840 {
2841 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2842 extract_datetime_nanos(left)
2843 } else {
2844 left
2845 };
2846 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2847 extract_datetime_nanos(right)
2848 } else {
2849 right
2850 };
2851 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2852 return Some(binary_expr(
2853 cast_expr(left_nanos, ts_type.clone()),
2854 op,
2855 cast_expr(right_nanos, ts_type),
2856 ));
2857 }
2858
2859 let left_is_duration = matches!(left_type, DataType::Interval(_));
2863 let right_is_duration = matches!(right_type, DataType::Interval(_));
2864 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2865 || uni_common::core::schema::is_time_struct(left_type)
2866 || matches!(
2867 left_type,
2868 DataType::Timestamp(_, _)
2869 | DataType::Date32
2870 | DataType::Date64
2871 | DataType::Time32(_)
2872 | DataType::Time64(_)
2873 );
2874 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2875 || uni_common::core::schema::is_time_struct(right_type)
2876 || matches!(
2877 right_type,
2878 DataType::Timestamp(_, _)
2879 | DataType::Date32
2880 | DataType::Date64
2881 | DataType::Time32(_)
2882 | DataType::Time64(_)
2883 );
2884
2885 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2886 {
2887 return Some(match op {
2888 Operator::Eq => lit(false),
2889 Operator::NotEq => lit(true),
2890 _ => lit(ScalarValue::Boolean(None)),
2891 });
2892 }
2893
2894 None
2895}
2896
2897fn coerce_mismatched_types(
2900 left: DfExpr,
2901 right: DfExpr,
2902 left_type: &datafusion::arrow::datatypes::DataType,
2903 right_type: &datafusion::arrow::datatypes::DataType,
2904 op: datafusion::logical_expr::Operator,
2905 is_comparison: bool,
2906) -> Option<Result<DfExpr>> {
2907 use datafusion::arrow::datatypes::DataType;
2908 use datafusion::logical_expr::Operator;
2909
2910 if left_type == right_type {
2911 return None;
2912 }
2913
2914 if left_type.is_numeric() && right_type.is_numeric() {
2916 if left_type == &DataType::Int64
2917 && right_type == &DataType::UInt64
2918 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2919 {
2920 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2921 return Some(Ok(binary_expr(coerced_left, op, right)));
2922 }
2923 if left_type == &DataType::UInt64
2924 && right_type == &DataType::Int64
2925 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2926 {
2927 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2928 return Some(Ok(binary_expr(left, op, coerced_right)));
2929 }
2930 let target = wider_numeric_type(left_type, right_type);
2931 let coerced_left = if *left_type != target {
2932 datafusion::logical_expr::cast(left, target.clone())
2933 } else {
2934 left
2935 };
2936 let coerced_right = if *right_type != target {
2937 datafusion::logical_expr::cast(right, target)
2938 } else {
2939 right
2940 };
2941 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2942 }
2943
2944 if is_comparison {
2946 match (left_type, right_type) {
2947 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2948 let right = normalize_datetime_literal(right);
2949 return Some(Ok(binary_expr(
2950 left,
2951 op,
2952 datafusion::logical_expr::cast(right, ts.clone()),
2953 )));
2954 }
2955 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2956 let left = normalize_datetime_literal(left);
2957 return Some(Ok(binary_expr(
2958 datafusion::logical_expr::cast(left, ts.clone()),
2959 op,
2960 right,
2961 )));
2962 }
2963 _ => {}
2964 }
2965 }
2966
2967 if is_comparison
2969 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2970 {
2971 let l_inner = l_field.data_type();
2972 let r_inner = r_field.data_type();
2973 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2974 let target_inner = wider_numeric_type(l_inner, r_inner);
2975 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2976 "item",
2977 target_inner,
2978 true,
2979 )));
2980 return Some(Ok(binary_expr(
2981 datafusion::logical_expr::cast(left, target_type.clone()),
2982 op,
2983 datafusion::logical_expr::cast(right, target_type),
2984 )));
2985 }
2986 }
2987
2988 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2990 if op == Operator::Plus {
2991 return Some(crate::cypher_type_coerce::build_cypher_plus(
2992 left, left_type, right, right_type,
2993 ));
2994 }
2995 if is_comparison {
2996 return Some(Ok(crate::cypher_type_coerce::build_cypher_comparison(
2997 left, left_type, right, right_type, op,
2998 )));
2999 }
3000 }
3001
3002 None
3003}
3004
3005fn coerce_list_comparisons(
3007 left: DfExpr,
3008 right: DfExpr,
3009 left_type: &datafusion::arrow::datatypes::DataType,
3010 right_type: &datafusion::arrow::datatypes::DataType,
3011 op: datafusion::logical_expr::Operator,
3012 is_comparison: bool,
3013) -> Option<DfExpr> {
3014 use datafusion::arrow::datatypes::DataType;
3015 use datafusion::logical_expr::Operator;
3016
3017 if !is_comparison {
3018 return None;
3019 }
3020
3021 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
3022 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
3023
3024 if left_is_list
3026 && right_is_list
3027 && matches!(
3028 op,
3029 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
3030 )
3031 {
3032 let op_str = match op {
3033 Operator::Lt => "lt",
3034 Operator::LtEq => "lteq",
3035 Operator::Gt => "gt",
3036 Operator::GtEq => "gteq",
3037 _ => unreachable!(),
3038 };
3039 return Some(dummy_udf_expr(
3040 "_cypher_list_compare",
3041 vec![left, right, lit(op_str)],
3042 ));
3043 }
3044
3045 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
3047 let udf_name =
3048 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
3049 return Some(dummy_udf_expr(udf_name, vec![left, right]));
3050 }
3051
3052 if (left_is_list != right_is_list)
3054 && !matches!(left_type, DataType::Null)
3055 && !matches!(right_type, DataType::Null)
3056 {
3057 return Some(match op {
3058 Operator::Eq => lit(false),
3059 Operator::NotEq => lit(true),
3060 _ => lit(ScalarValue::Boolean(None)),
3061 });
3062 }
3063
3064 None
3065}
3066
3067fn coerce_binary_expr(
3069 binary: &datafusion::logical_expr::expr::BinaryExpr,
3070 schema: &datafusion::common::DFSchema,
3071) -> Result<DfExpr> {
3072 use datafusion::arrow::datatypes::DataType;
3073 use datafusion::logical_expr::ExprSchemable;
3074 use datafusion::logical_expr::Operator;
3075
3076 let left = apply_type_coercion(&binary.left, schema)?;
3077 let right = apply_type_coercion(&binary.right, schema)?;
3078
3079 let is_comparison = matches!(
3080 binary.op,
3081 Operator::Eq
3082 | Operator::NotEq
3083 | Operator::Lt
3084 | Operator::LtEq
3085 | Operator::Gt
3086 | Operator::GtEq
3087 );
3088 let is_arithmetic = matches!(
3089 binary.op,
3090 Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
3091 );
3092
3093 if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
3095 return Ok(result);
3096 }
3097
3098 if is_comparison || is_arithmetic {
3099 let left_type = match left.get_type(schema) {
3100 Ok(t) => t,
3101 Err(e) => {
3102 if let Some(t) = resolve_column_type_fallback(&left, schema) {
3103 t
3104 } else {
3105 log::warn!("Failed to get left type in binary expr: {}", e);
3106 return Ok(binary_expr(left, binary.op, right));
3107 }
3108 }
3109 };
3110 let right_type = match right.get_type(schema) {
3111 Ok(t) => t,
3112 Err(e) => {
3113 if let Some(t) = resolve_column_type_fallback(&right, schema) {
3114 t
3115 } else {
3116 log::warn!("Failed to get right type in binary expr: {}", e);
3117 return Ok(binary_expr(left, binary.op, right));
3118 }
3119 }
3120 };
3121
3122 let left_is_null = left_type.is_null();
3124 let right_is_null = right_type.is_null();
3125 if left_is_null && right_is_null {
3126 return Ok(lit(ScalarValue::Boolean(None)));
3127 }
3128 if left_is_null || right_is_null {
3129 let target = if left_is_null {
3130 &right_type
3131 } else {
3132 &left_type
3133 };
3134 if !matches!(target, DataType::LargeBinary) {
3135 let coerced_left = if left_is_null {
3136 datafusion::logical_expr::cast(left, target.clone())
3137 } else {
3138 left
3139 };
3140 let coerced_right = if right_is_null {
3141 datafusion::logical_expr::cast(right, target.clone())
3142 } else {
3143 right
3144 };
3145 return Ok(binary_expr(coerced_left, binary.op, coerced_right));
3146 }
3147 }
3148
3149 if let Some(result) = coerce_large_binary_ops(
3151 &left,
3152 &right,
3153 &left_type,
3154 &right_type,
3155 left_is_null,
3156 binary.op,
3157 is_comparison,
3158 is_arithmetic,
3159 ) {
3160 return result;
3161 }
3162
3163 if let Some(result) = coerce_temporal_comparisons(
3165 left.clone(),
3166 right.clone(),
3167 &left_type,
3168 &right_type,
3169 binary.op,
3170 is_comparison,
3171 ) {
3172 return Ok(result);
3173 }
3174
3175 let either_struct =
3177 matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
3178 let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
3179 || matches!(left_type, DataType::Struct(_)))
3180 && (matches!(right_type, DataType::LargeBinary)
3181 || matches!(right_type, DataType::Struct(_)));
3182 if is_comparison && either_struct && either_lb_or_struct {
3183 if let Some(udf_name) = comparison_udf_name(binary.op) {
3184 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3185 }
3186 return Ok(lit(ScalarValue::Boolean(None)));
3187 }
3188
3189 if is_comparison && (contains_division(&left) || contains_division(&right)) {
3191 let udf_name = comparison_udf_name(binary.op)
3192 .expect("is_comparison guarantees a valid comparison operator");
3193 return Ok(dummy_udf_expr(udf_name, vec![left, right]));
3194 }
3195
3196 if binary.op == Operator::Plus
3198 && (crate::cypher_type_coerce::is_string_type(&left_type)
3199 || crate::cypher_type_coerce::is_string_type(&right_type))
3200 && is_primitive_type(&left_type)
3201 && is_primitive_type(&right_type)
3202 {
3203 return crate::cypher_type_coerce::build_cypher_plus(
3204 left,
3205 &left_type,
3206 right,
3207 &right_type,
3208 );
3209 }
3210
3211 if let Some(result) = coerce_mismatched_types(
3213 left.clone(),
3214 right.clone(),
3215 &left_type,
3216 &right_type,
3217 binary.op,
3218 is_comparison,
3219 ) {
3220 return result;
3221 }
3222
3223 if let Some(result) = coerce_list_comparisons(
3225 left.clone(),
3226 right.clone(),
3227 &left_type,
3228 &right_type,
3229 binary.op,
3230 is_comparison,
3231 ) {
3232 return Ok(result);
3233 }
3234
3235 if let Some(name) = arithmetic_udf_name(binary.op)
3244 && left_type == DataType::Int64
3245 && right_type == DataType::Int64
3246 && !is_list_expr(&left)
3247 && !is_list_expr(&right)
3248 {
3249 return Ok(dummy_udf_expr(name, vec![left, right]));
3250 }
3251 }
3252
3253 Ok(binary_expr(left, binary.op, right))
3254}
3255
3256fn coerce_scalar_function(
3258 func: &datafusion::logical_expr::expr::ScalarFunction,
3259 schema: &datafusion::common::DFSchema,
3260) -> Result<DfExpr> {
3261 use datafusion::arrow::datatypes::DataType;
3262 use datafusion::logical_expr::ExprSchemable;
3263
3264 let coerced_args: Vec<DfExpr> = func
3265 .args
3266 .iter()
3267 .map(|a| apply_type_coercion(a, schema))
3268 .collect::<Result<Vec<_>>>()?;
3269
3270 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3271 let types: Vec<_> = coerced_args
3272 .iter()
3273 .filter_map(|a| a.get_type(schema).ok())
3274 .collect();
3275 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3276 if has_mixed_types {
3277 let all_string_like = types
3281 .iter()
3282 .all(|t| matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Null));
3283 let unified_args: Vec<DfExpr> = if all_string_like {
3284 coerced_args
3285 .into_iter()
3286 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3287 .collect()
3288 } else {
3289 coerced_args
3291 .into_iter()
3292 .zip(types.iter())
3293 .map(|(arg, t)| match t {
3294 DataType::LargeBinary | DataType::Null => arg,
3295 DataType::List(_) | DataType::LargeList(_) => {
3296 list_to_large_binary_expr(arg)
3297 }
3298 _ => scalar_to_large_binary_expr(arg),
3299 })
3300 .collect()
3301 };
3302 return Ok(DfExpr::ScalarFunction(
3303 datafusion::logical_expr::expr::ScalarFunction {
3304 func: func.func.clone(),
3305 args: unified_args,
3306 },
3307 ));
3308 }
3309 }
3310
3311 Ok(DfExpr::ScalarFunction(
3312 datafusion::logical_expr::expr::ScalarFunction {
3313 func: func.func.clone(),
3314 args: coerced_args,
3315 },
3316 ))
3317}
3318
3319fn coerce_case_expr(
3322 case: &datafusion::logical_expr::expr::Case,
3323 schema: &datafusion::common::DFSchema,
3324) -> Result<DfExpr> {
3325 use datafusion::arrow::datatypes::DataType;
3326 use datafusion::logical_expr::ExprSchemable;
3327
3328 let coerced_operand = case
3329 .expr
3330 .as_ref()
3331 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3332 .transpose()?;
3333 let coerced_when_then = case
3334 .when_then_expr
3335 .iter()
3336 .map(|(w, t)| {
3337 let cw = apply_type_coercion(w, schema)?;
3338 let cw = match cw.get_type(schema).ok() {
3339 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3340 _ => cw,
3341 };
3342 let ct = apply_type_coercion(t, schema)?;
3343 Ok((Box::new(cw), Box::new(ct)))
3344 })
3345 .collect::<Result<Vec<_>>>()?;
3346 let coerced_else = case
3347 .else_expr
3348 .as_ref()
3349 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3350 .transpose()?;
3351
3352 let mut result_case = if let Some(operand) = coerced_operand {
3353 crate::cypher_type_coerce::rewrite_simple_case_to_generic(
3354 *operand,
3355 coerced_when_then,
3356 coerced_else,
3357 schema,
3358 )?
3359 } else {
3360 datafusion::logical_expr::expr::Case {
3361 expr: None,
3362 when_then_expr: coerced_when_then,
3363 else_expr: coerced_else,
3364 }
3365 };
3366
3367 crate::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3368
3369 Ok(DfExpr::Case(result_case))
3370}
3371
3372fn coerce_aggregate_function(
3374 agg: &datafusion::logical_expr::expr::AggregateFunction,
3375 schema: &datafusion::common::DFSchema,
3376) -> Result<DfExpr> {
3377 let coerced_args: Vec<DfExpr> = agg
3378 .params
3379 .args
3380 .iter()
3381 .map(|a| apply_type_coercion(a, schema))
3382 .collect::<Result<Vec<_>>>()?;
3383 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3384 .params
3385 .order_by
3386 .iter()
3387 .map(|s| {
3388 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3389 Ok(datafusion::logical_expr::SortExpr {
3390 expr: coerced_expr,
3391 asc: s.asc,
3392 nulls_first: s.nulls_first,
3393 })
3394 })
3395 .collect::<Result<Vec<_>>>()?;
3396 let coerced_filter = agg
3397 .params
3398 .filter
3399 .as_ref()
3400 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3401 .transpose()?;
3402 Ok(DfExpr::AggregateFunction(
3403 datafusion::logical_expr::expr::AggregateFunction {
3404 func: agg.func.clone(),
3405 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3406 args: coerced_args,
3407 distinct: agg.params.distinct,
3408 filter: coerced_filter,
3409 order_by: coerced_order_by,
3410 null_treatment: agg.params.null_treatment,
3411 },
3412 },
3413 ))
3414}
3415
3416#[cfg(test)]
3417mod tests {
3418 use super::*;
3419 use arrow_array::{
3420 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3421 };
3422 use uni_common::TemporalValue;
3423 #[test]
3424 fn test_literal_translation() {
3425 let expr = Expr::Literal(CypherLiteral::Integer(42));
3426 let result = cypher_expr_to_df(&expr, None).unwrap();
3427 let s = format!("{:?}", result);
3428 assert!(s.contains("Literal"));
3430 assert!(s.contains("Int64(42)"));
3431 }
3432
3433 #[test]
3434 fn test_property_access_no_context_uses_index() {
3435 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3437 let result = cypher_expr_to_df(&expr, None).unwrap();
3438 let s = format!("{}", result);
3439 assert!(
3440 s.contains("index"),
3441 "expected index UDF for non-graph variable, got: {s}"
3442 );
3443 }
3444
3445 #[test]
3446 fn test_comparison_operator() {
3447 let expr = Expr::BinaryOp {
3448 left: Box::new(Expr::Property(
3449 Box::new(Expr::Variable("n".to_string())),
3450 "age".to_string(),
3451 )),
3452 op: BinaryOp::Gt,
3453 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3454 };
3455 let result = cypher_expr_to_df(&expr, None).unwrap();
3456 let s = format!("{:?}", result);
3458 assert!(s.contains("age"));
3459 assert!(s.contains("30"));
3460 }
3461
3462 #[test]
3463 fn test_boolean_operators() {
3464 let expr = Expr::BinaryOp {
3465 left: Box::new(Expr::BinaryOp {
3466 left: Box::new(Expr::Property(
3467 Box::new(Expr::Variable("n".to_string())),
3468 "age".to_string(),
3469 )),
3470 op: BinaryOp::Gt,
3471 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3472 }),
3473 op: BinaryOp::And,
3474 right: Box::new(Expr::BinaryOp {
3475 left: Box::new(Expr::Property(
3476 Box::new(Expr::Variable("n".to_string())),
3477 "active".to_string(),
3478 )),
3479 op: BinaryOp::Eq,
3480 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3481 }),
3482 };
3483 let result = cypher_expr_to_df(&expr, None).unwrap();
3484 let s = format!("{:?}", result);
3485 assert!(s.contains("And"));
3486 }
3487
3488 #[test]
3489 fn test_is_null() {
3490 let expr = Expr::IsNull(Box::new(Expr::Property(
3491 Box::new(Expr::Variable("n".to_string())),
3492 "email".to_string(),
3493 )));
3494 let result = cypher_expr_to_df(&expr, None).unwrap();
3495 let s = format!("{:?}", result);
3496 assert!(s.contains("IsNull"));
3497 }
3498
3499 #[test]
3500 fn test_collect_properties() {
3501 let expr = Expr::BinaryOp {
3502 left: Box::new(Expr::Property(
3503 Box::new(Expr::Variable("n".to_string())),
3504 "name".to_string(),
3505 )),
3506 op: BinaryOp::Eq,
3507 right: Box::new(Expr::Property(
3508 Box::new(Expr::Variable("m".to_string())),
3509 "name".to_string(),
3510 )),
3511 };
3512
3513 let props = collect_properties(&expr);
3514 assert_eq!(props.len(), 2);
3515 assert!(props.contains(&("m".to_string(), "name".to_string())));
3516 assert!(props.contains(&("n".to_string(), "name".to_string())));
3517 }
3518
3519 #[test]
3520 fn test_function_call() {
3521 let expr = Expr::FunctionCall {
3522 name: "count".to_string(),
3523 args: vec![Expr::Wildcard],
3524 distinct: false,
3525 window_spec: None,
3526 };
3527 let result = cypher_expr_to_df(&expr, None).unwrap();
3528 let s = format!("{:?}", result);
3529 assert!(s.to_lowercase().contains("count"));
3530 }
3531
3532 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3537 use datafusion::logical_expr::Operator;
3538
3539 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3541 let fields: Vec<_> = cols
3542 .iter()
3543 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3544 .collect();
3545 let schema = Schema::new(fields);
3546 datafusion::common::DFSchema::try_from(schema).unwrap()
3547 }
3548
3549 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3551 let s = format!("{}", expr);
3552 s.contains(name)
3553 }
3554
3555 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3557 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3558 }
3559
3560 #[test]
3561 fn test_coercion_lb_eq_int64() {
3562 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3563 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3564 Box::new(col("lb")),
3565 Operator::Eq,
3566 Box::new(col("i")),
3567 ));
3568 let result = apply_type_coercion(&expr, &schema).unwrap();
3569 assert!(
3571 contains_udf(&result, "_cypher_equal"),
3572 "expected _cypher_equal, got: {result}"
3573 );
3574 }
3575
3576 #[test]
3577 fn test_coercion_lb_noteq_int64() {
3578 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3579 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3580 Box::new(col("lb")),
3581 Operator::NotEq,
3582 Box::new(col("i")),
3583 ));
3584 let result = apply_type_coercion(&expr, &schema).unwrap();
3585 assert!(contains_udf(&result, "_cypher_not_equal"));
3587 }
3588
3589 #[test]
3590 fn test_coercion_lb_lt_int64() {
3591 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3592 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3593 Box::new(col("lb")),
3594 Operator::Lt,
3595 Box::new(col("i")),
3596 ));
3597 let result = apply_type_coercion(&expr, &schema).unwrap();
3598 assert!(contains_udf(&result, "_cypher_lt"));
3600 }
3601
3602 #[test]
3603 fn test_coercion_lb_eq_float64() {
3604 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3605 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3606 Box::new(col("lb")),
3607 Operator::Eq,
3608 Box::new(col("f")),
3609 ));
3610 let result = apply_type_coercion(&expr, &schema).unwrap();
3611 assert!(contains_udf(&result, "_cypher_equal"));
3613 }
3614
3615 #[test]
3616 fn test_coercion_lb_eq_utf8() {
3617 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3618 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3619 Box::new(col("lb")),
3620 Operator::Eq,
3621 Box::new(col("s")),
3622 ));
3623 let result = apply_type_coercion(&expr, &schema).unwrap();
3624 assert!(contains_udf(&result, "_cypher_equal"));
3626 }
3627
3628 #[test]
3629 fn test_coercion_lb_eq_bool() {
3630 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3631 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3632 Box::new(col("lb")),
3633 Operator::Eq,
3634 Box::new(col("b")),
3635 ));
3636 let result = apply_type_coercion(&expr, &schema).unwrap();
3637 assert!(contains_udf(&result, "_cypher_equal"));
3639 }
3640
3641 #[test]
3642 fn test_coercion_int64_eq_lb() {
3643 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3645 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3646 Box::new(col("i")),
3647 Operator::Eq,
3648 Box::new(col("lb")),
3649 ));
3650 let result = apply_type_coercion(&expr, &schema).unwrap();
3651 assert!(contains_udf(&result, "_cypher_equal"));
3653 }
3654
3655 #[test]
3656 fn test_coercion_float64_gt_lb() {
3657 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3658 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3659 Box::new(col("f")),
3660 Operator::Gt,
3661 Box::new(col("lb")),
3662 ));
3663 let result = apply_type_coercion(&expr, &schema).unwrap();
3664 assert!(contains_udf(&result, "_cypher_gt"));
3666 }
3667
3668 #[test]
3669 fn test_coercion_both_lb_eq() {
3670 let schema = make_schema(&[
3671 ("lb1", DataType::LargeBinary),
3672 ("lb2", DataType::LargeBinary),
3673 ]);
3674 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3675 Box::new(col("lb1")),
3676 Operator::Eq,
3677 Box::new(col("lb2")),
3678 ));
3679 let result = apply_type_coercion(&expr, &schema).unwrap();
3680 assert!(contains_udf(&result, "_cypher_equal"));
3681 }
3682
3683 #[test]
3684 fn test_coercion_both_lb_lt() {
3685 let schema = make_schema(&[
3686 ("lb1", DataType::LargeBinary),
3687 ("lb2", DataType::LargeBinary),
3688 ]);
3689 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3690 Box::new(col("lb1")),
3691 Operator::Lt,
3692 Box::new(col("lb2")),
3693 ));
3694 let result = apply_type_coercion(&expr, &schema).unwrap();
3695 assert!(contains_udf(&result, "_cypher_lt"));
3696 }
3697
3698 #[test]
3699 fn test_coercion_both_lb_noteq() {
3700 let schema = make_schema(&[
3701 ("lb1", DataType::LargeBinary),
3702 ("lb2", DataType::LargeBinary),
3703 ]);
3704 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3705 Box::new(col("lb1")),
3706 Operator::NotEq,
3707 Box::new(col("lb2")),
3708 ));
3709 let result = apply_type_coercion(&expr, &schema).unwrap();
3710 assert!(contains_udf(&result, "_cypher_not_equal"));
3711 }
3712
3713 #[test]
3714 fn test_coercion_lb_plus_int64() {
3715 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3716 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3717 Box::new(col("lb")),
3718 Operator::Plus,
3719 Box::new(col("i")),
3720 ));
3721 let result = apply_type_coercion(&expr, &schema).unwrap();
3722 assert!(contains_udf(&result, "_cypher_add"));
3723 }
3724
3725 #[test]
3726 fn test_coercion_lb_minus_int64() {
3727 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3728 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3729 Box::new(col("lb")),
3730 Operator::Minus,
3731 Box::new(col("i")),
3732 ));
3733 let result = apply_type_coercion(&expr, &schema).unwrap();
3734 assert!(contains_udf(&result, "_cypher_sub"));
3735 }
3736
3737 #[test]
3738 fn test_coercion_lb_multiply_float64() {
3739 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3740 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3741 Box::new(col("lb")),
3742 Operator::Multiply,
3743 Box::new(col("f")),
3744 ));
3745 let result = apply_type_coercion(&expr, &schema).unwrap();
3746 assert!(contains_udf(&result, "_cypher_mul"));
3747 }
3748
3749 #[test]
3750 fn test_coercion_int64_plus_lb() {
3751 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3752 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3753 Box::new(col("i")),
3754 Operator::Plus,
3755 Box::new(col("lb")),
3756 ));
3757 let result = apply_type_coercion(&expr, &schema).unwrap();
3758 assert!(contains_udf(&result, "_cypher_add"));
3759 }
3760
3761 #[test]
3762 fn test_coercion_lb_plus_utf8() {
3763 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3765 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3766 Box::new(col("lb")),
3767 Operator::Plus,
3768 Box::new(col("s")),
3769 ));
3770 let result = apply_type_coercion(&expr, &schema).unwrap();
3771 assert!(contains_udf(&result, "_cypher_add"));
3773 }
3774
3775 #[test]
3776 fn test_coercion_and_null_bool() {
3777 let schema = make_schema(&[("b", DataType::Boolean)]);
3778 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3780 Box::new(lit(ScalarValue::Null)),
3781 Operator::And,
3782 Box::new(col("b")),
3783 ));
3784 let result = apply_type_coercion(&expr, &schema).unwrap();
3785 let s = format!("{}", result);
3786 assert!(
3788 s.contains("CAST") || s.contains("Boolean"),
3789 "expected cast to Boolean, got: {s}"
3790 );
3791 assert!(is_binary_op(&result, Operator::And));
3792 }
3793
3794 #[test]
3795 fn test_coercion_bool_and_null() {
3796 let schema = make_schema(&[("b", DataType::Boolean)]);
3797 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3798 Box::new(col("b")),
3799 Operator::And,
3800 Box::new(lit(ScalarValue::Null)),
3801 ));
3802 let result = apply_type_coercion(&expr, &schema).unwrap();
3803 assert!(is_binary_op(&result, Operator::And));
3804 }
3805
3806 #[test]
3807 fn test_coercion_or_null_bool() {
3808 let schema = make_schema(&[("b", DataType::Boolean)]);
3809 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3810 Box::new(lit(ScalarValue::Null)),
3811 Operator::Or,
3812 Box::new(col("b")),
3813 ));
3814 let result = apply_type_coercion(&expr, &schema).unwrap();
3815 assert!(is_binary_op(&result, Operator::Or));
3816 }
3817
3818 #[test]
3819 fn test_coercion_null_and_null() {
3820 let schema = make_schema(&[]);
3821 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3822 Box::new(lit(ScalarValue::Null)),
3823 Operator::And,
3824 Box::new(lit(ScalarValue::Null)),
3825 ));
3826 let result = apply_type_coercion(&expr, &schema).unwrap();
3827 assert!(is_binary_op(&result, Operator::And));
3828 }
3829
3830 #[test]
3831 fn test_coercion_bool_and_bool_noop() {
3832 let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);
3833 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3834 Box::new(col("a")),
3835 Operator::And,
3836 Box::new(col("b")),
3837 ));
3838 let result = apply_type_coercion(&expr, &schema).unwrap();
3839 assert!(is_binary_op(&result, Operator::And));
3841 let s = format!("{}", result);
3842 assert!(!s.contains("CAST"), "should not contain CAST: {s}");
3843 }
3844
3845 #[test]
3846 fn test_coercion_case_when_lb() {
3847 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3849 let when_cond = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3850 Box::new(col("lb")),
3851 Operator::Eq,
3852 Box::new(lit(42_i64)),
3853 ));
3854 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3855 expr: None,
3856 when_then_expr: vec![(Box::new(when_cond), Box::new(lit("a")))],
3857 else_expr: Some(Box::new(lit("b"))),
3858 });
3859 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3860 let s = format!("{}", result);
3861 assert!(
3863 s.contains("_cypher_equal"),
3864 "CASE WHEN should have _cypher_equal, got: {s}"
3865 );
3866 }
3867
3868 #[test]
3869 fn test_coercion_case_then_lb() {
3870 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3872 let then_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3873 Box::new(col("lb")),
3874 Operator::Plus,
3875 Box::new(lit(1_i64)),
3876 ));
3877 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3878 expr: None,
3879 when_then_expr: vec![(Box::new(lit(true)), Box::new(then_expr))],
3880 else_expr: Some(Box::new(lit(0_i64))),
3881 });
3882 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3883 let s = format!("{}", result);
3884 assert!(
3885 s.contains("_cypher_add"),
3886 "CASE THEN should have _cypher_add, got: {s}"
3887 );
3888 }
3889
3890 #[test]
3891 fn test_coercion_case_else_lb() {
3892 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3894 let else_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3895 Box::new(col("lb")),
3896 Operator::Plus,
3897 Box::new(lit(2_i64)),
3898 ));
3899 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3900 expr: None,
3901 when_then_expr: vec![(Box::new(lit(true)), Box::new(lit(1_i64)))],
3902 else_expr: Some(Box::new(else_expr)),
3903 });
3904 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3905 let s = format!("{}", result);
3906 assert!(
3907 s.contains("_cypher_add"),
3908 "CASE ELSE should have _cypher_add, got: {s}"
3909 );
3910 }
3911
3912 #[test]
3913 fn test_coercion_int64_eq_int64_noop() {
3914 let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);
3915 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3916 Box::new(col("a")),
3917 Operator::Eq,
3918 Box::new(col("b")),
3919 ));
3920 let result = apply_type_coercion(&expr, &schema).unwrap();
3921 assert!(is_binary_op(&result, Operator::Eq));
3922 let s = format!("{}", result);
3923 assert!(
3924 !s.contains("_cypher_value"),
3925 "should not contain cypher_value decode: {s}"
3926 );
3927 }
3928
3929 #[test]
3930 fn test_coercion_both_lb_plus() {
3931 let schema = make_schema(&[
3933 ("lb1", DataType::LargeBinary),
3934 ("lb2", DataType::LargeBinary),
3935 ]);
3936 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3937 Box::new(col("lb1")),
3938 Operator::Plus,
3939 Box::new(col("lb2")),
3940 ));
3941 let result = apply_type_coercion(&expr, &schema).unwrap();
3942 assert!(
3943 contains_udf(&result, "_cypher_add"),
3944 "expected _cypher_add, got: {result}"
3945 );
3946 }
3947
3948 #[test]
3949 fn test_coercion_native_list_plus_scalar() {
3950 let schema = make_schema(&[
3952 (
3953 "lst",
3954 DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3955 ),
3956 ("i", DataType::Int32),
3957 ]);
3958 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3959 Box::new(col("lst")),
3960 Operator::Plus,
3961 Box::new(col("i")),
3962 ));
3963 let result = apply_type_coercion(&expr, &schema).unwrap();
3964 assert!(
3965 contains_udf(&result, "_cypher_list_append"),
3966 "expected _cypher_list_append, got: {result}"
3967 );
3968 }
3969
3970 #[test]
3971 fn test_coercion_lb_plus_int64_unchanged() {
3972 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3974 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3975 Box::new(col("lb")),
3976 Operator::Plus,
3977 Box::new(col("i")),
3978 ));
3979 let result = apply_type_coercion(&expr, &schema).unwrap();
3980 assert!(
3981 contains_udf(&result, "_cypher_add"),
3982 "expected _cypher_add, got: {result}"
3983 );
3984 }
3985
3986 #[test]
3991 fn test_mixed_list_with_variables_compiles() {
3992 let expr = Expr::List(vec![
3994 Expr::Variable("n".to_string()),
3995 Expr::Literal(CypherLiteral::Integer(1)),
3996 Expr::Literal(CypherLiteral::String("hello".to_string())),
3997 ]);
3998 let result = cypher_expr_to_df(&expr, None).unwrap();
3999 let s = format!("{}", result);
4000 assert!(
4001 s.contains("_make_cypher_list"),
4002 "expected _make_cypher_list UDF call, got: {s}"
4003 );
4004 }
4005
4006 #[test]
4007 fn test_literal_only_mixed_list_uses_cv_fastpath() {
4008 let expr = Expr::List(vec![
4010 Expr::Literal(CypherLiteral::Integer(1)),
4011 Expr::Literal(CypherLiteral::String("hi".to_string())),
4012 Expr::Literal(CypherLiteral::Bool(true)),
4013 ]);
4014 let result = cypher_expr_to_df(&expr, None).unwrap();
4015 assert!(
4016 matches!(result, DfExpr::Literal(..)),
4017 "expected Literal (CypherValue fast path), got: {result}"
4018 );
4019 }
4020
4021 #[test]
4026 fn test_in_mixed_literal_list_uses_cypher_in() {
4027 let expr = Expr::In {
4029 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4030 list: Box::new(Expr::List(vec![
4031 Expr::Literal(CypherLiteral::String("1".to_string())),
4032 Expr::Literal(CypherLiteral::Integer(2)),
4033 ])),
4034 };
4035 let result = cypher_expr_to_df(&expr, None).unwrap();
4036 let s = format!("{}", result);
4037 assert!(
4038 s.contains("_cypher_in"),
4039 "expected _cypher_in UDF for mixed-type IN list, got: {s}"
4040 );
4041 }
4042
4043 #[test]
4044 fn test_in_homogeneous_literal_list_uses_cypher_in() {
4045 let expr = Expr::In {
4047 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4048 list: Box::new(Expr::List(vec![
4049 Expr::Literal(CypherLiteral::Integer(2)),
4050 Expr::Literal(CypherLiteral::Integer(3)),
4051 ])),
4052 };
4053 let result = cypher_expr_to_df(&expr, None).unwrap();
4054 let s = format!("{}", result);
4055 assert!(
4056 s.contains("_cypher_in"),
4057 "expected _cypher_in UDF for homogeneous IN list, got: {s}"
4058 );
4059 }
4060
4061 #[test]
4062 fn test_in_list_with_variables_uses_make_cypher_list() {
4063 let expr = Expr::In {
4065 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4066 list: Box::new(Expr::List(vec![
4067 Expr::Variable("x".to_string()),
4068 Expr::Literal(CypherLiteral::Integer(2)),
4069 ])),
4070 };
4071 let result = cypher_expr_to_df(&expr, None).unwrap();
4072 let s = format!("{}", result);
4073 assert!(
4074 s.contains("_cypher_in"),
4075 "expected _cypher_in UDF, got: {s}"
4076 );
4077 assert!(
4078 s.contains("_make_cypher_list"),
4079 "expected _make_cypher_list for variable-containing list, got: {s}"
4080 );
4081 }
4082
4083 #[test]
4088 fn test_property_on_graph_entity_uses_column() {
4089 let mut ctx = TranslationContext::new();
4091 ctx.variable_kinds
4092 .insert("n".to_string(), VariableKind::Node);
4093
4094 let expr = Expr::Property(
4095 Box::new(Expr::Variable("n".to_string())),
4096 "name".to_string(),
4097 );
4098 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4099 let s = format!("{:?}", result);
4100 assert!(
4101 s.contains("Column") && s.contains("n.name"),
4102 "expected flat column 'n.name' for graph entity, got: {s}"
4103 );
4104 }
4105
4106 #[test]
4107 fn test_property_on_non_graph_var_uses_index() {
4108 let ctx = TranslationContext::new();
4110
4111 let expr = Expr::Property(
4112 Box::new(Expr::Variable("map".to_string())),
4113 "name".to_string(),
4114 );
4115 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4116 let s = format!("{}", result);
4117 assert!(
4118 s.contains("index"),
4119 "expected index UDF for non-graph variable, got: {s}"
4120 );
4121 }
4122
4123 #[test]
4124 fn test_value_to_scalar_non_empty_map_becomes_struct() {
4125 let mut map = std::collections::HashMap::new();
4126 map.insert("k".to_string(), Value::Int(1));
4127 let scalar = value_to_scalar(&Value::Map(map)).unwrap();
4128 assert!(
4129 matches!(scalar, ScalarValue::Struct(_)),
4130 "expected Struct scalar for map input"
4131 );
4132 }
4133
4134 #[test]
4135 fn test_value_to_scalar_empty_map_becomes_struct() {
4136 let scalar = value_to_scalar(&Value::Map(Default::default())).unwrap();
4137 assert!(
4138 matches!(scalar, ScalarValue::Struct(_)),
4139 "empty map should produce an empty Struct scalar"
4140 );
4141 }
4142
4143 #[test]
4144 fn test_value_to_scalar_null_is_untyped_null() {
4145 let scalar = value_to_scalar(&Value::Null).unwrap();
4146 assert!(
4147 matches!(scalar, ScalarValue::Null),
4148 "expected untyped Null scalar for Value::Null"
4149 );
4150 }
4151
4152 #[test]
4153 fn test_value_to_scalar_datetime_produces_struct() {
4154 let datetime = Value::Temporal(TemporalValue::DateTime {
4156 nanos_since_epoch: 441763200000000000, offset_seconds: 3600, timezone_name: Some("Europe/Paris".to_string()),
4159 });
4160
4161 let scalar = value_to_scalar(&datetime).unwrap();
4162
4163 if let ScalarValue::Struct(struct_arr) = scalar {
4165 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4166 assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");
4167
4168 let fields = struct_arr.fields();
4170 assert_eq!(fields[0].name(), "nanos_since_epoch");
4171 assert_eq!(fields[1].name(), "offset_seconds");
4172 assert_eq!(fields[2].name(), "timezone_name");
4173
4174 let nanos_col = struct_arr.column(0);
4176 let offset_col = struct_arr.column(1);
4177 let tz_col = struct_arr.column(2);
4178
4179 if let Some(nanos_arr) = nanos_col
4180 .as_any()
4181 .downcast_ref::<TimestampNanosecondArray>()
4182 {
4183 assert_eq!(nanos_arr.value(0), 441763200000000000);
4184 } else {
4185 panic!("Expected TimestampNanosecondArray for nanos field");
4186 }
4187
4188 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4189 assert_eq!(offset_arr.value(0), 3600);
4190 } else {
4191 panic!("Expected Int32Array for offset field");
4192 }
4193
4194 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4195 assert_eq!(tz_arr.value(0), "Europe/Paris");
4196 } else {
4197 panic!("Expected StringArray for timezone_name field");
4198 }
4199 } else {
4200 panic!(
4201 "Expected ScalarValue::Struct for DateTime, got {:?}",
4202 scalar
4203 );
4204 }
4205 }
4206
4207 #[test]
4208 fn test_value_to_scalar_datetime_with_null_timezone() {
4209 let datetime = Value::Temporal(TemporalValue::DateTime {
4211 nanos_since_epoch: 1704067200000000000, offset_seconds: -18000, timezone_name: None,
4214 });
4215
4216 let scalar = value_to_scalar(&datetime).unwrap();
4217
4218 if let ScalarValue::Struct(struct_arr) = scalar {
4219 assert_eq!(struct_arr.num_columns(), 3);
4220
4221 let tz_col = struct_arr.column(2);
4223 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4224 assert!(tz_arr.is_null(0), "expected null timezone_name");
4225 } else {
4226 panic!("Expected StringArray for timezone_name field");
4227 }
4228 } else {
4229 panic!("Expected ScalarValue::Struct for DateTime");
4230 }
4231 }
4232
4233 #[test]
4234 fn test_value_to_scalar_time_produces_struct() {
4235 let time = Value::Temporal(TemporalValue::Time {
4237 nanos_since_midnight: 37845000000000, offset_seconds: 3600, });
4240
4241 let scalar = value_to_scalar(&time).unwrap();
4242
4243 if let ScalarValue::Struct(struct_arr) = scalar {
4245 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4246 assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");
4247
4248 let fields = struct_arr.fields();
4250 assert_eq!(fields[0].name(), "nanos_since_midnight");
4251 assert_eq!(fields[1].name(), "offset_seconds");
4252
4253 let nanos_col = struct_arr.column(0);
4255 let offset_col = struct_arr.column(1);
4256
4257 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4258 assert_eq!(nanos_arr.value(0), 37845000000000);
4259 } else {
4260 panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
4261 }
4262
4263 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4264 assert_eq!(offset_arr.value(0), 3600);
4265 } else {
4266 panic!("Expected Int32Array for offset field");
4267 }
4268 } else {
4269 panic!("Expected ScalarValue::Struct for Time, got {:?}", scalar);
4270 }
4271 }
4272
4273 #[test]
4274 fn test_value_to_scalar_time_boundary_values() {
4275 let midnight = Value::Temporal(TemporalValue::Time {
4277 nanos_since_midnight: 0,
4278 offset_seconds: 0,
4279 });
4280
4281 let scalar = value_to_scalar(&midnight).unwrap();
4282
4283 if let ScalarValue::Struct(struct_arr) = scalar {
4284 let nanos_col = struct_arr.column(0);
4285 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4286 assert_eq!(nanos_arr.value(0), 0);
4287 } else {
4288 panic!("Expected Time64NanosecondArray");
4289 }
4290 } else {
4291 panic!("Expected ScalarValue::Struct for Time");
4292 }
4293 }
4294}