1use anyhow::{Result, anyhow};
34use datafusion::common::{Column, ScalarValue};
35use datafusion::logical_expr::{
36 ColumnarValue, Expr as DfExpr, ScalarFunctionArgs, col, expr::InList, lit,
37};
38use datafusion::prelude::ExprFunctionExt;
39use std::hash::{Hash, Hasher};
40use std::ops::Not;
41use std::sync::Arc;
42use uni_common::Value;
43use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, MapProjectionItem, UnaryOp};
44
/// Suffix of a node variable's id column; combined as `<var>._vid`.
const COL_VID: &str = "_vid";
/// Suffix of an edge variable's id column; combined as `<var>._eid`.
const COL_EID: &str = "_eid";
/// Suffix of the column holding a variable's labels; combined as `<var>._labels`.
const COL_LABELS: &str = "_labels";
/// Suffix of the column holding an edge variable's type; combined as `<var>._type`.
const COL_TYPE: &str = "_type";
50
51fn is_primitive_type(dt: &datafusion::arrow::datatypes::DataType) -> bool {
56 !matches!(
57 dt,
58 datafusion::arrow::datatypes::DataType::LargeBinary
59 | datafusion::arrow::datatypes::DataType::Struct(_)
60 | datafusion::arrow::datatypes::DataType::List(_)
61 | datafusion::arrow::datatypes::DataType::LargeList(_)
62 )
63}
64
65pub fn struct_getfield(expr: DfExpr, field_name: &str) -> DfExpr {
67 use datafusion::logical_expr::ScalarUDF;
68 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
69 Arc::new(ScalarUDF::from(
70 datafusion::functions::core::getfield::GetFieldFunc::new(),
71 )),
72 vec![expr, lit(field_name)],
73 ))
74}
75
76pub fn extract_datetime_nanos(expr: DfExpr) -> DfExpr {
78 struct_getfield(expr, "nanos_since_epoch")
79}
80
81pub fn extract_time_nanos(expr: DfExpr) -> DfExpr {
89 use datafusion::logical_expr::Operator;
90
91 let nanos_local = struct_getfield(expr.clone(), "nanos_since_midnight");
92 let offset_seconds = struct_getfield(expr, "offset_seconds");
93
94 let nanos_local_i64 = cast_expr(nanos_local, datafusion::arrow::datatypes::DataType::Int64);
98 let offset_nanos = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
99 Box::new(cast_expr(
100 offset_seconds,
101 datafusion::arrow::datatypes::DataType::Int64,
102 )),
103 Operator::Multiply,
104 Box::new(lit(1_000_000_000_i64)),
105 ));
106
107 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
108 Box::new(nanos_local_i64),
109 Operator::Minus,
110 Box::new(offset_nanos),
111 ))
112}
113
114fn normalize_datetime_literal(expr: DfExpr) -> DfExpr {
121 if let DfExpr::Literal(ScalarValue::Utf8(Some(ref s)), _) = expr
122 && let Some(normalized) = normalize_datetime_str(s)
123 {
124 return lit(normalized);
125 }
126 expr
127}
128
/// Inserts a `:00` seconds component into a datetime string that stops at
/// minutes (`…THH:MM` optionally followed by a suffix such as `Z`).
///
/// Returns `None` when the string is shorter than 16 bytes, when bytes 10..16
/// are not `T`, two digits, `:`, two digits, or when a `:` already follows the
/// minutes (i.e. seconds are present). The first ten bytes are not inspected.
pub fn normalize_datetime_str(s: &str) -> Option<String> {
    /// Byte offset just past the `HH:MM` part.
    const TIME_END: usize = 16;

    let bytes = s.as_bytes();
    // `get` yields `None` for inputs shorter than 16 bytes.
    let clock = bytes.get(10..TIME_END)?;
    let is_hh_mm = matches!(
        clock,
        [b'T', h1, h2, b':', m1, m2]
            if [h1, h2, m1, m2].iter().all(|d| d.is_ascii_digit())
    );
    let has_seconds = bytes.get(TIME_END) == Some(&b':');
    if !is_hh_mm || has_seconds {
        return None;
    }

    // Byte 15 is an ASCII digit, so offset 16 is always a char boundary.
    let (head, tail) = s.split_at(TIME_END);
    Some(format!("{head}:00{tail}"))
}
158
159fn infer_common_scalar_type(scalars: &[ScalarValue]) -> datafusion::arrow::datatypes::DataType {
161 use datafusion::arrow::datatypes::DataType;
162
163 let non_null: Vec<_> = scalars
164 .iter()
165 .filter(|s| !matches!(s, ScalarValue::Null))
166 .collect();
167
168 if non_null.is_empty() {
169 return DataType::Null;
170 }
171
172 if non_null.iter().all(|s| matches!(s, ScalarValue::Int64(_))) {
174 DataType::Int64
175 } else if non_null
176 .iter()
177 .all(|s| matches!(s, ScalarValue::Float64(_) | ScalarValue::Int64(_)))
178 {
179 DataType::Float64
180 } else if non_null.iter().all(|s| matches!(s, ScalarValue::Utf8(_))) {
181 DataType::Utf8
182 } else if non_null
183 .iter()
184 .all(|s| matches!(s, ScalarValue::Boolean(_)))
185 {
186 DataType::Boolean
187 } else {
188 DataType::LargeBinary
190 }
191}
192
193const CYPHER_LIST_FUNCS: &[&str] = &[
195 "_make_cypher_list",
196 "_cypher_list_concat",
197 "_cypher_list_append",
198];
199
200fn is_cypher_list_expr(e: &DfExpr) -> bool {
202 matches!(e, DfExpr::Literal(ScalarValue::LargeBinary(_), _))
203 || matches!(e, DfExpr::ScalarFunction(f) if CYPHER_LIST_FUNCS.contains(&f.func.name()))
204}
205
206fn is_list_expr(e: &DfExpr) -> bool {
208 is_cypher_list_expr(e)
209 || matches!(e, DfExpr::Literal(ScalarValue::List(_), _))
210 || matches!(e, DfExpr::ScalarFunction(f) if f.func.name() == "make_array")
211}
212
/// The kind of graph value a query variable is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VariableKind {
    /// A node variable.
    Node,
    /// A single edge variable.
    Edge,
    /// An edge variable bound by a variable-length pattern (see [`VariableKind::edge_for`]).
    EdgeList,
    /// A path variable.
    Path,
}

impl VariableKind {
    /// Chooses the edge kind for a relationship pattern: [`VariableKind::EdgeList`]
    /// when the pattern is variable-length, [`VariableKind::Edge`] otherwise.
    pub fn edge_for(is_variable_length: bool) -> Self {
        match is_variable_length {
            true => Self::EdgeList,
            false => Self::Edge,
        }
    }
}
247
248pub fn cypher_expr_to_df(expr: &Expr, context: Option<&TranslationContext>) -> Result<DfExpr> {
283 match expr {
284 Expr::PatternComprehension { .. } => Err(anyhow!(
285 "Pattern comprehensions require fallback executor (graph traversal)"
286 )),
287 Expr::Wildcard => Ok(DfExpr::Literal(
291 datafusion::common::ScalarValue::Int32(Some(1)),
292 None,
293 )),
294
295 Expr::Variable(name) => {
296 if let Some(ctx) = context
301 && ctx.variable_kinds.contains_key(name)
302 {
303 return Ok(DfExpr::Column(Column::from_name(name)));
304 }
305
306 if let Some(ctx) = context
312 && let Some(value) = ctx.outer_values.get(name)
313 {
314 return value_to_scalar(value).map(lit);
315 }
316
317 if let Some(ctx) = context
322 && let Some(value) = ctx.parameters.get(name)
323 {
324 match value {
327 Value::List(values) if name.ends_with("._vid") => {
328 let literals = values
330 .iter()
331 .map(|v| value_to_scalar(v).map(lit))
332 .collect::<Result<Vec<_>>>()?;
333 return Ok(DfExpr::InList(InList {
334 expr: Box::new(DfExpr::Column(Column::from_name(name))),
335 list: literals,
336 negated: false,
337 }));
338 }
339 other_value => return value_to_scalar(other_value).map(lit),
340 }
341 }
342
343 Ok(DfExpr::Column(Column::from_name(name)))
346 }
347
348 Expr::Property(base, prop) => translate_property_access(base, prop, context),
349
350 Expr::ArrayIndex { array, index } => {
351 if let Ok(var_name) = extract_variable_name(array)
354 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
355 {
356 let col_name = format!("{}.{}", var_name, prop_name);
357 return Ok(DfExpr::Column(Column::from_name(col_name)));
358 }
359
360 let array_expr = cypher_expr_to_df(array, context)?;
361 let index_expr = cypher_expr_to_df(index, context)?;
362
363 Ok(dummy_udf_expr("index", vec![array_expr, index_expr]))
365 }
366
367 Expr::ArraySlice { array, start, end } => {
368 let array_expr = cypher_expr_to_df(array, context)?;
372
373 let start_expr = match start {
374 Some(s) => cypher_expr_to_df(s, context)?,
375 None => lit(0i64),
376 };
377
378 let end_expr = match end {
379 Some(e) => cypher_expr_to_df(e, context)?,
380 None => lit(i64::MAX),
381 };
382
383 Ok(dummy_udf_expr(
386 "_cypher_list_slice",
387 vec![array_expr, start_expr, end_expr],
388 ))
389 }
390
391 Expr::Parameter(name) => {
392 if let Some(ctx) = context
394 && let Some(value) = ctx.parameters.get(name)
395 {
396 return value_to_scalar(value).map(lit);
397 }
398 Err(anyhow!("Unresolved parameter: ${}", name))
399 }
400
401 Expr::Literal(value) => {
402 let scalar = cypher_literal_to_scalar(value)?;
403 Ok(lit(scalar))
404 }
405
406 Expr::List(items) => translate_list_literal(items, context),
407
408 Expr::Map(entries) => {
409 if entries.is_empty() {
410 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_common::Value::Map(
412 Default::default(),
413 ));
414 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
415 }
416 let mut args = Vec::with_capacity(entries.len() * 2);
419 for (key, val_expr) in entries {
420 args.push(lit(key.clone()));
421 args.push(cypher_expr_to_df(val_expr, context)?);
422 }
423 Ok(datafusion::functions::expr_fn::named_struct(args))
424 }
425
426 Expr::IsNull(inner) => translate_null_check(inner, context, true),
427
428 Expr::IsNotNull(inner) => translate_null_check(inner, context, false),
429
430 Expr::IsUnique(_) => {
431 Err(anyhow!(
433 "IS UNIQUE can only be used in constraint definitions"
434 ))
435 }
436
437 Expr::FunctionCall {
438 name,
439 args,
440 distinct,
441 window_spec,
442 } => {
443 if window_spec.is_some() {
446 let col_name = expr.to_string_repr();
448 Ok(col(&col_name))
449 } else {
450 translate_function_call(name, args, *distinct, context)
451 }
452 }
453
454 Expr::In { expr, list } => translate_in_expression(expr, list, context),
455
456 Expr::BinaryOp { left, op, right } => {
457 let left_expr = cypher_expr_to_df(left, context)?;
458 let right_expr = cypher_expr_to_df(right, context)?;
459 translate_binary_op(left_expr, op, right_expr)
460 }
461
462 Expr::UnaryOp { op, expr: inner } => {
463 let inner_expr = cypher_expr_to_df(inner, context)?;
464 match op {
465 UnaryOp::Not => Ok(inner_expr.not()),
466 UnaryOp::Neg => Ok(DfExpr::Negative(Box::new(inner_expr))),
467 }
468 }
469
470 Expr::Case {
471 expr,
472 when_then,
473 else_expr,
474 } => translate_case_expression(expr, when_then, else_expr, context),
475
476 Expr::Reduce { .. } => Err(anyhow!(
477 "Reduce expressions not yet supported in DataFusion translation"
478 )),
479
480 Expr::Exists { .. } => Err(anyhow!(
481 "EXISTS subqueries are handled by the physical expression compiler, \
482 not the DataFusion logical expression translator"
483 )),
484
485 Expr::CountSubquery(_) => Err(anyhow!(
486 "Count subqueries not yet supported in DataFusion translation"
487 )),
488
489 Expr::CollectSubquery(_) => Err(anyhow!(
490 "COLLECT subqueries not yet supported in DataFusion translation"
491 )),
492
493 Expr::Quantifier { .. } => {
494 Err(anyhow!(
499 "Quantifier expressions (ALL/ANY/SINGLE/NONE) require physical compilation \
500 via CypherPhysicalExprCompiler"
501 ))
502 }
503
504 Expr::ListComprehension { .. } => {
505 Err(anyhow!(
517 "List comprehensions not yet supported in DataFusion translation - requires lambda functions"
518 ))
519 }
520
521 Expr::ValidAt { .. } => {
522 Err(anyhow!(
525 "VALID_AT expression should have been transformed to function call in planner"
526 ))
527 }
528
529 Expr::MapProjection { base, items } => translate_map_projection(base, items, context),
530
531 Expr::LabelCheck { expr, labels } => {
532 if let Expr::Variable(var) = expr.as_ref() {
533 let is_edge = context
535 .and_then(|ctx| ctx.variable_kinds.get(var))
536 .is_some_and(|k| matches!(k, VariableKind::Edge));
537
538 if is_edge {
539 if labels.len() > 1 {
543 Ok(lit(false))
544 } else {
545 let type_col =
546 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_TYPE)));
547 Ok(DfExpr::Case(datafusion::logical_expr::Case {
549 expr: None,
550 when_then_expr: vec![(
551 Box::new(type_col.clone().is_null()),
552 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
553 )],
554 else_expr: Some(Box::new(type_col.eq(lit(labels[0].clone())))),
555 }))
556 }
557 } else {
558 let labels_col =
560 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
561 let checks = labels
562 .iter()
563 .map(|label| {
564 datafusion::functions_nested::expr_fn::array_has(
565 labels_col.clone(),
566 lit(label.clone()),
567 )
568 })
569 .reduce(|acc, check| acc.and(check));
570 Ok(DfExpr::Case(datafusion::logical_expr::Case {
572 expr: None,
573 when_then_expr: vec![(
574 Box::new(labels_col.is_null()),
575 Box::new(DfExpr::Literal(ScalarValue::Boolean(None), None)),
576 )],
577 else_expr: Some(Box::new(checks.unwrap())),
578 }))
579 }
580 } else {
581 Err(anyhow!(
582 "LabelCheck on non-variable expression not yet supported in DataFusion"
583 ))
584 }
585 }
586 }
587}
588
589#[derive(Debug, Clone)]
593pub struct TranslationContext {
594 pub parameters: std::collections::HashMap<String, Value>,
596
597 pub outer_values: std::collections::HashMap<String, Value>,
601
602 pub variable_labels: std::collections::HashMap<String, String>,
604
605 pub variable_kinds: std::collections::HashMap<String, VariableKind>,
607
608 pub node_variable_hints: Vec<String>,
611
612 pub mutation_edge_hints: Vec<String>,
615
616 pub statement_time: chrono::DateTime<chrono::Utc>,
621}
622
623impl Default for TranslationContext {
624 fn default() -> Self {
625 Self {
626 parameters: std::collections::HashMap::new(),
627 outer_values: std::collections::HashMap::new(),
628 variable_labels: std::collections::HashMap::new(),
629 variable_kinds: std::collections::HashMap::new(),
630 node_variable_hints: Vec::new(),
631 mutation_edge_hints: Vec::new(),
632 statement_time: chrono::Utc::now(),
633 }
634 }
635}
636
637impl TranslationContext {
638 pub fn new() -> Self {
640 Self::default()
641 }
642
643 pub fn with_parameter(mut self, name: impl Into<String>, value: Value) -> Self {
645 self.parameters.insert(name.into(), value);
646 self
647 }
648
649 pub fn with_variable_label(mut self, var: impl Into<String>, label: impl Into<String>) -> Self {
651 self.variable_labels.insert(var.into(), label.into());
652 self
653 }
654}
655
656fn extract_variable_name(expr: &Expr) -> Result<String> {
658 match expr {
659 Expr::Variable(name) => Ok(name.clone()),
660 Expr::Property(base, _) => extract_variable_name(base),
661 _ => Err(anyhow!(
662 "Cannot extract variable name from expression: {:?}",
663 expr
664 )),
665 }
666}
667
668fn translate_null_check(
670 inner: &Expr,
671 context: Option<&TranslationContext>,
672 is_null: bool,
673) -> Result<DfExpr> {
674 if let Expr::Variable(var) = inner
675 && let Some(ctx) = context
676 && let Some(kind) = ctx.variable_kinds.get(var)
677 {
678 let col_name = match kind {
679 VariableKind::Node => format!("{}.{}", var, COL_VID),
680 VariableKind::Edge => format!("{}.{}", var, COL_EID),
681 VariableKind::Path | VariableKind::EdgeList => var.clone(),
682 };
683 let col_expr = DfExpr::Column(Column::from_name(col_name));
684 return Ok(if is_null {
685 col_expr.is_null()
686 } else {
687 col_expr.is_not_null()
688 });
689 }
690
691 let inner_expr = cypher_expr_to_df(inner, context)?;
692 Ok(if is_null {
693 inner_expr.is_null()
694 } else {
695 inner_expr.is_not_null()
696 })
697}
698
699fn try_temporal_accessor(base_expr: DfExpr, prop: &str) -> Option<DfExpr> {
704 if crate::datetime::is_duration_accessor(prop) {
705 Some(dummy_udf_expr(
706 "_duration_property",
707 vec![base_expr, lit(prop.to_string())],
708 ))
709 } else if crate::datetime::is_temporal_accessor(prop) {
710 Some(dummy_udf_expr(
711 "_temporal_property",
712 vec![base_expr, lit(prop.to_string())],
713 ))
714 } else {
715 None
716 }
717}
718
719fn translate_property_access(
721 base: &Expr,
722 prop: &str,
723 context: Option<&TranslationContext>,
724) -> Result<DfExpr> {
725 if let Ok(var_name) = extract_variable_name(base) {
726 let is_graph_entity = context
727 .and_then(|ctx| ctx.variable_kinds.get(&var_name))
728 .is_some_and(|k| matches!(k, VariableKind::Node | VariableKind::Edge));
729
730 if !is_graph_entity
731 && let Some(expr) =
732 try_temporal_accessor(DfExpr::Column(Column::from_name(&var_name)), prop)
733 {
734 return Ok(expr);
735 }
736
737 let col_name = format!("{}.{}", var_name, prop);
738
739 if let Some(ctx) = context
742 && let Some(value) = ctx.parameters.get(&col_name)
743 {
744 match value {
747 Value::List(values) if col_name.ends_with("._vid") => {
748 let literals = values
749 .iter()
750 .map(|v| value_to_scalar(v).map(lit))
751 .collect::<Result<Vec<_>>>()?;
752 return Ok(DfExpr::InList(InList {
753 expr: Box::new(DfExpr::Column(Column::from_name(&col_name))),
754 list: literals,
755 negated: false,
756 }));
757 }
758 other_value => return value_to_scalar(other_value).map(lit),
759 }
760 }
761
762 if !is_graph_entity && matches!(base, Expr::Property(_, _)) {
765 let base_expr = cypher_expr_to_df(base, context)?;
766 return Ok(dummy_udf_expr(
767 "index",
768 vec![base_expr, lit(prop.to_string())],
769 ));
770 }
771
772 if is_graph_entity {
773 Ok(DfExpr::Column(Column::from_name(col_name)))
774 } else {
775 let base_expr = DfExpr::Column(Column::from_name(var_name));
776 Ok(dummy_udf_expr(
777 "index",
778 vec![base_expr, lit(prop.to_string())],
779 ))
780 }
781 } else {
782 if let Some(expr) = try_temporal_accessor(cypher_expr_to_df(base, context)?, prop) {
784 return Ok(expr);
785 }
786
787 if let Expr::Parameter(param_name) = base {
789 if let Some(ctx) = context
790 && let Some(value) = ctx.parameters.get(param_name)
791 {
792 if let Value::Map(map) = value {
793 let extracted = map.get(prop).cloned().unwrap_or(Value::Null);
794 return value_to_scalar(&extracted).map(lit);
795 }
796 return Ok(lit(ScalarValue::Null));
797 }
798 return Err(anyhow!("Unresolved parameter: ${}", param_name));
799 }
800
801 let base_expr = cypher_expr_to_df(base, context)?;
802 Ok(dummy_udf_expr(
803 "index",
804 vec![base_expr, lit(prop.to_string())],
805 ))
806 }
807}
808
809fn translate_list_literal(items: &[Expr], context: Option<&TranslationContext>) -> Result<DfExpr> {
811 let mut has_string = false;
813 let mut has_bool = false;
814 let mut has_list = false;
815 let mut has_map = false;
816 let mut has_numeric = false;
817 let mut has_graph_entity = false;
818 let mut has_temporal = false;
819
820 for item in items {
821 match item {
822 Expr::Literal(CypherLiteral::Float(_)) | Expr::Literal(CypherLiteral::Integer(_)) => {
823 has_numeric = true
824 }
825 Expr::Literal(CypherLiteral::String(_)) => has_string = true,
826 Expr::Literal(CypherLiteral::Bool(_)) => has_bool = true,
827 Expr::List(_) => has_list = true,
828 Expr::Map(_) => has_map = true,
829 Expr::Variable(name)
832 if context
833 .and_then(|ctx| ctx.variable_kinds.get(name))
834 .is_some() =>
835 {
836 has_graph_entity = true;
837 }
838 Expr::FunctionCall { name, .. } => {
841 let upper = name.to_uppercase();
842 if matches!(
843 upper.as_str(),
844 "DATE"
845 | "TIME"
846 | "LOCALTIME"
847 | "LOCALDATETIME"
848 | "DATETIME"
849 | "DURATION"
850 | "DATE.TRUNCATE"
851 | "TIME.TRUNCATE"
852 | "DATETIME.TRUNCATE"
853 | "LOCALDATETIME.TRUNCATE"
854 | "LOCALTIME.TRUNCATE"
855 ) {
856 has_temporal = true;
857 }
858 }
859 _ => {}
861 }
862 }
863
864 let types_count = has_numeric as u8 + has_string as u8 + has_bool as u8 + has_map as u8;
866
867 if has_list || has_map || types_count > 1 || has_graph_entity || has_temporal {
870 if let Some(json_array) = try_items_to_json(items) {
872 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
873 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
874 return Ok(lit(ScalarValue::LargeBinary(Some(cv_bytes))));
875 }
876 let df_args: Vec<DfExpr> = items
878 .iter()
879 .map(|item| cypher_expr_to_df(item, context))
880 .collect::<Result<_>>()?;
881 return Ok(dummy_udf_expr("_make_cypher_list", df_args));
882 }
883
884 let mut df_args = Vec::with_capacity(items.len());
887 let mut has_float = false;
888 let mut has_int = false;
889 let mut has_other = false;
890
891 for item in items {
892 match item {
893 Expr::Literal(CypherLiteral::Float(_)) => has_float = true,
894 Expr::Literal(CypherLiteral::Integer(_)) => has_int = true,
895 _ => has_other = true,
896 }
897 df_args.push(cypher_expr_to_df(item, context)?);
898 }
899
900 if df_args.is_empty() {
901 let empty_arr =
903 ScalarValue::new_list_nullable(&[], &datafusion::arrow::datatypes::DataType::Null);
904 Ok(lit(ScalarValue::List(empty_arr)))
905 } else if has_float && has_int && !has_other {
906 let promoted_args = df_args
908 .into_iter()
909 .map(|e| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64))
910 .collect();
911 Ok(datafusion::functions_nested::expr_fn::make_array(
912 promoted_args,
913 ))
914 } else {
915 let non_null_type = df_args.iter().find_map(|e| {
919 if let DfExpr::Literal(sv, _) = e {
920 let dt = sv.data_type();
921 if dt != datafusion::arrow::datatypes::DataType::Null {
922 return Some(dt);
923 }
924 }
925 None
926 });
927 if let Some(ref target_type) = non_null_type {
928 let coerced = df_args
929 .into_iter()
930 .map(|e| {
931 if matches!(&e, DfExpr::Literal(sv, _) if sv.data_type() == datafusion::arrow::datatypes::DataType::Null)
932 {
933 cast_expr(e, target_type.clone())
934 } else {
935 e
936 }
937 })
938 .collect();
939 Ok(datafusion::functions_nested::expr_fn::make_array(coerced))
940 } else {
941 Ok(datafusion::functions_nested::expr_fn::make_array(df_args))
942 }
943 }
944}
945
946fn translate_in_expression(
948 expr: &Expr,
949 list: &Expr,
950 context: Option<&TranslationContext>,
951) -> Result<DfExpr> {
952 let left_expr = if let Expr::Variable(var) = expr
957 && let Some(ctx) = context
958 && let Some(kind) = ctx.variable_kinds.get(var)
959 {
960 match kind {
961 VariableKind::Node | VariableKind::Edge => {
962 let id_col = match kind {
963 VariableKind::Node => COL_VID,
964 VariableKind::Edge => COL_EID,
965 _ => unreachable!(),
966 };
967 cast_expr(
968 DfExpr::Column(Column::from_name(format!("{}.{}", var, id_col))),
969 datafusion::arrow::datatypes::DataType::Int64,
970 )
971 }
972 _ => cypher_expr_to_df(expr, context)?,
973 }
974 } else {
975 cypher_expr_to_df(expr, context)?
976 };
977
978 if let Expr::List(items) = list {
983 if let Some(json_array) = try_items_to_json(items) {
984 let uni_val: uni_common::Value = serde_json::Value::Array(json_array).into();
986 let cv_bytes = uni_common::cypher_value_codec::encode(&uni_val);
987 let list_literal = lit(ScalarValue::LargeBinary(Some(cv_bytes)));
988 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_literal]))
989 } else {
990 let expanded: Vec<DfExpr> = items
992 .iter()
993 .map(|item| cypher_expr_to_df(item, context))
994 .collect::<Result<Vec<_>>>()?;
995 let list_expr = dummy_udf_expr("_make_cypher_list", expanded);
996 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, list_expr]))
997 }
998 } else {
999 let right_expr = cypher_expr_to_df(list, context)?;
1000
1001 if matches!(right_expr, DfExpr::Literal(ScalarValue::Null, _)) {
1006 return Ok(lit(ScalarValue::Boolean(None)));
1007 }
1008
1009 Ok(dummy_udf_expr("_cypher_in", vec![left_expr, right_expr]))
1010 }
1011}
1012
1013fn translate_case_expression(
1015 operand: &Option<Box<Expr>>,
1016 when_then: &[(Expr, Expr)],
1017 else_expr: &Option<Box<Expr>>,
1018 context: Option<&TranslationContext>,
1019) -> Result<DfExpr> {
1020 let mut case_builder = if let Some(match_expr) = operand {
1021 let match_df = cypher_expr_to_df(match_expr, context)?;
1022 datafusion::logical_expr::case(match_df)
1023 } else {
1024 datafusion::logical_expr::when(
1025 cypher_expr_to_df(&when_then[0].0, context)?,
1026 cypher_expr_to_df(&when_then[0].1, context)?,
1027 )
1028 };
1029
1030 let start_idx = if operand.is_some() { 0 } else { 1 };
1031 for (when_expr, then_expr) in when_then.iter().skip(start_idx) {
1032 let when_df = cypher_expr_to_df(when_expr, context)?;
1033 let then_df = cypher_expr_to_df(then_expr, context)?;
1034 case_builder = case_builder.when(when_df, then_df);
1035 }
1036
1037 if let Some(else_e) = else_expr {
1038 let else_df = cypher_expr_to_df(else_e, context)?;
1039 Ok(case_builder.otherwise(else_df)?)
1040 } else {
1041 Ok(case_builder.end()?)
1042 }
1043}
1044
1045fn translate_map_projection(
1047 base: &Expr,
1048 items: &[MapProjectionItem],
1049 context: Option<&TranslationContext>,
1050) -> Result<DfExpr> {
1051 let mut args = Vec::new();
1052 for item in items {
1053 match item {
1054 MapProjectionItem::Property(prop) => {
1055 args.push(lit(prop.clone()));
1056 let prop_expr = cypher_expr_to_df(
1057 &Expr::Property(Box::new(base.clone()), prop.clone()),
1058 context,
1059 )?;
1060 args.push(prop_expr);
1061 }
1062 MapProjectionItem::LiteralEntry(key, expr) => {
1063 args.push(lit(key.clone()));
1064 args.push(cypher_expr_to_df(expr, context)?);
1065 }
1066 MapProjectionItem::Variable(var) => {
1067 args.push(lit(var.clone()));
1068 args.push(DfExpr::Column(Column::from_name(var)));
1069 }
1070 MapProjectionItem::AllProperties => {
1071 args.push(lit("__all__"));
1072 args.push(cypher_expr_to_df(base, context)?);
1073 }
1074 }
1075 }
1076 Ok(dummy_udf_expr("_map_project", args))
1077}
1078
1079fn try_expr_to_json(expr: &Expr) -> Option<serde_json::Value> {
1082 match expr {
1083 Expr::Literal(CypherLiteral::Null) => Some(serde_json::Value::Null),
1084 Expr::Literal(CypherLiteral::Bool(b)) => Some(serde_json::Value::Bool(*b)),
1085 Expr::Literal(CypherLiteral::Integer(i)) => {
1086 Some(serde_json::Value::Number(serde_json::Number::from(*i)))
1087 }
1088 Expr::Literal(CypherLiteral::Float(f)) => serde_json::Number::from_f64(*f)
1089 .map(serde_json::Value::Number)
1090 .or(Some(serde_json::Value::Null)),
1091 Expr::Literal(CypherLiteral::String(s)) => Some(serde_json::Value::String(s.clone())),
1092 Expr::List(items) => try_items_to_json(items).map(serde_json::Value::Array),
1093 Expr::Map(entries) => {
1094 let mut map = serde_json::Map::new();
1095 for (k, v) in entries {
1096 map.insert(k.clone(), try_expr_to_json(v)?);
1097 }
1098 Some(serde_json::Value::Object(map))
1099 }
1100 _ => None,
1101 }
1102}
1103
1104fn try_items_to_json(items: &[Expr]) -> Option<Vec<serde_json::Value>> {
1106 items.iter().map(try_expr_to_json).collect()
1107}
1108
1109fn cypher_literal_to_scalar(lit: &CypherLiteral) -> Result<ScalarValue> {
1111 match lit {
1112 CypherLiteral::Null => Ok(ScalarValue::Null),
1113 CypherLiteral::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1114 CypherLiteral::Integer(i) => Ok(ScalarValue::Int64(Some(*i))),
1115 CypherLiteral::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1116 CypherLiteral::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1117 CypherLiteral::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1118 }
1119}
1120
1121fn value_to_scalar(value: &Value) -> Result<ScalarValue> {
1123 match value {
1124 Value::Null => Ok(ScalarValue::Null),
1125 Value::Bool(b) => Ok(ScalarValue::Boolean(Some(*b))),
1126 Value::Int(i) => Ok(ScalarValue::Int64(Some(*i))),
1127 Value::Float(f) => Ok(ScalarValue::Float64(Some(*f))),
1128 Value::String(s) => Ok(ScalarValue::Utf8(Some(s.clone()))),
1129 Value::List(items) => {
1130 let scalars: Result<Vec<ScalarValue>> = items.iter().map(value_to_scalar).collect();
1132 let scalars = scalars?;
1133
1134 let data_type = infer_common_scalar_type(&scalars);
1136
1137 let typed_scalars: Vec<ScalarValue> = scalars
1139 .into_iter()
1140 .map(|s| {
1141 if matches!(s, ScalarValue::Null) {
1142 return ScalarValue::try_from(&data_type).unwrap_or(ScalarValue::Null);
1143 }
1144
1145 match (s, &data_type) {
1146 (
1147 ScalarValue::Int64(Some(v)),
1148 datafusion::arrow::datatypes::DataType::Float64,
1149 ) => ScalarValue::Float64(Some(v as f64)),
1150 (s, datafusion::arrow::datatypes::DataType::LargeBinary) => {
1151 let s_str = s.to_string();
1153 ScalarValue::LargeBinary(Some(s_str.into_bytes()))
1154 }
1155 (s, datafusion::arrow::datatypes::DataType::Utf8) => {
1156 if matches!(s, ScalarValue::Utf8(_)) {
1158 s
1159 } else {
1160 ScalarValue::Utf8(Some(s.to_string()))
1161 }
1162 }
1163 (s, _) => s,
1164 }
1165 })
1166 .collect();
1167
1168 if typed_scalars.is_empty() {
1170 Ok(ScalarValue::List(ScalarValue::new_list_nullable(
1171 &[],
1172 &data_type,
1173 )))
1174 } else {
1175 Ok(ScalarValue::List(ScalarValue::new_list(
1176 &typed_scalars,
1177 &data_type,
1178 true,
1179 )))
1180 }
1181 }
1182 Value::Map(map) => {
1183 let mut entries: Vec<(&String, &Value)> = map.iter().collect();
1186 entries.sort_by_key(|(k, _)| *k);
1187
1188 if entries.is_empty() {
1189 return Ok(ScalarValue::Struct(Arc::new(
1190 datafusion::arrow::array::StructArray::new_empty_fields(1, None),
1191 )));
1192 }
1193
1194 let mut fields_arrays = Vec::with_capacity(entries.len());
1195
1196 for (k, v) in entries {
1197 let scalar = value_to_scalar(v)?;
1198 let dt = scalar.data_type();
1199 let field = Arc::new(datafusion::arrow::datatypes::Field::new(k, dt, true));
1200 let array = scalar.to_array()?;
1201 fields_arrays.push((field, array));
1202 }
1203
1204 Ok(ScalarValue::Struct(Arc::new(
1205 datafusion::arrow::array::StructArray::from(fields_arrays),
1206 )))
1207 }
1208 Value::Temporal(tv) => {
1209 use uni_common::TemporalValue;
1210 match tv {
1211 TemporalValue::Date { days_since_epoch } => {
1212 Ok(ScalarValue::Date32(Some(*days_since_epoch)))
1213 }
1214 TemporalValue::LocalTime {
1215 nanos_since_midnight,
1216 } => Ok(ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight))),
1217 TemporalValue::Time {
1218 nanos_since_midnight,
1219 offset_seconds,
1220 } => {
1221 use arrow::array::{ArrayRef, Int32Array, StructArray, Time64NanosecondArray};
1223 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1224
1225 let nanos_arr =
1226 Arc::new(Time64NanosecondArray::from(vec![*nanos_since_midnight]))
1227 as ArrayRef;
1228 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1229
1230 let fields = Fields::from(vec![
1231 Field::new(
1232 "nanos_since_midnight",
1233 ArrowDataType::Time64(TimeUnit::Nanosecond),
1234 true,
1235 ),
1236 Field::new("offset_seconds", ArrowDataType::Int32, true),
1237 ]);
1238
1239 let struct_arr = StructArray::new(fields, vec![nanos_arr, offset_arr], None);
1240 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1241 }
1242 TemporalValue::LocalDateTime { nanos_since_epoch } => Ok(
1243 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None),
1244 ),
1245 TemporalValue::DateTime {
1246 nanos_since_epoch,
1247 offset_seconds,
1248 timezone_name,
1249 } => {
1250 use arrow::array::{
1252 ArrayRef, Int32Array, StringArray, StructArray, TimestampNanosecondArray,
1253 };
1254 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, TimeUnit};
1255
1256 let nanos_arr =
1257 Arc::new(TimestampNanosecondArray::from(vec![*nanos_since_epoch]))
1258 as ArrayRef;
1259 let offset_arr = Arc::new(Int32Array::from(vec![*offset_seconds])) as ArrayRef;
1260 let tz_arr =
1261 Arc::new(StringArray::from(vec![timezone_name.clone()])) as ArrayRef;
1262
1263 let fields = Fields::from(vec![
1264 Field::new(
1265 "nanos_since_epoch",
1266 ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
1267 true,
1268 ),
1269 Field::new("offset_seconds", ArrowDataType::Int32, true),
1270 Field::new("timezone_name", ArrowDataType::Utf8, true),
1271 ]);
1272
1273 let struct_arr =
1274 StructArray::new(fields, vec![nanos_arr, offset_arr, tz_arr], None);
1275 Ok(ScalarValue::Struct(Arc::new(struct_arr)))
1276 }
1277 TemporalValue::Duration {
1278 months,
1279 days,
1280 nanos,
1281 } => Ok(ScalarValue::IntervalMonthDayNano(Some(
1282 arrow::datatypes::IntervalMonthDayNano {
1283 months: *months as i32,
1284 days: *days as i32,
1285 nanoseconds: *nanos,
1286 },
1287 ))),
1288 TemporalValue::Btic { lo, hi, meta } => {
1289 let btic = uni_btic::Btic::new(*lo, *hi, *meta)
1290 .map_err(|e| anyhow::anyhow!("invalid BTIC value: {}", e))?;
1291 let packed = uni_btic::encode::encode(&btic);
1292 Ok(ScalarValue::FixedSizeBinary(24, Some(packed.to_vec())))
1293 }
1294 }
1295 }
1296 Value::Vector(v) => {
1297 let cv_bytes = uni_common::cypher_value_codec::encode(&Value::Vector(v.clone()));
1299 Ok(ScalarValue::LargeBinary(Some(cv_bytes)))
1300 }
1301 Value::Bytes(b) => Ok(ScalarValue::LargeBinary(Some(b.clone()))),
1302 other => {
1304 let json_val: serde_json::Value = other.clone().into();
1305 let json_str = serde_json::to_string(&json_val)
1306 .map_err(|e| anyhow!("Failed to serialize value: {}", e))?;
1307 Ok(ScalarValue::LargeBinary(Some(json_str.into_bytes())))
1308 }
1309 }
1310}
1311
1312fn translate_binary_op(left: DfExpr, op: &BinaryOp, right: DfExpr) -> Result<DfExpr> {
1314 match op {
1315 BinaryOp::Eq => Ok(left.eq(right)),
1319 BinaryOp::NotEq => Ok(left.not_eq(right)),
1320 BinaryOp::Lt => Ok(left.lt(right)),
1321 BinaryOp::LtEq => Ok(left.lt_eq(right)),
1322 BinaryOp::Gt => Ok(left.gt(right)),
1323 BinaryOp::GtEq => Ok(left.gt_eq(right)),
1324
1325 BinaryOp::And => Ok(left.and(right)),
1327 BinaryOp::Or => Ok(left.or(right)),
1328 BinaryOp::Xor => {
1329 Ok(dummy_udf_expr("_cypher_xor", vec![left, right]))
1331 }
1332
1333 BinaryOp::Add => {
1335 if is_list_expr(&left) || is_list_expr(&right) {
1336 Ok(dummy_udf_expr("_cypher_list_concat", vec![left, right]))
1337 } else {
1338 Ok(left + right)
1339 }
1340 }
1341 BinaryOp::Sub => Ok(left - right),
1342 BinaryOp::Mul => Ok(left * right),
1343 BinaryOp::Div => Ok(left / right),
1344 BinaryOp::Mod => Ok(left % right),
1345 BinaryOp::Pow => {
1346 let left_f = datafusion::logical_expr::cast(
1349 left,
1350 datafusion::arrow::datatypes::DataType::Float64,
1351 );
1352 let right_f = datafusion::logical_expr::cast(
1353 right,
1354 datafusion::arrow::datatypes::DataType::Float64,
1355 );
1356 Ok(datafusion::functions::math::expr_fn::power(left_f, right_f))
1357 }
1358
1359 BinaryOp::Contains => Ok(dummy_udf_expr("_cypher_contains", vec![left, right])),
1361 BinaryOp::StartsWith => Ok(dummy_udf_expr("_cypher_starts_with", vec![left, right])),
1362 BinaryOp::EndsWith => Ok(dummy_udf_expr("_cypher_ends_with", vec![left, right])),
1363
1364 BinaryOp::Regex => {
1365 Ok(datafusion::functions::expr_fn::regexp_match(left, right, None).is_not_null())
1366 }
1367
1368 BinaryOp::ApproxEq => Err(anyhow!(
1369 "Vector similarity operator (~=) cannot be pushed down to DataFusion"
1370 )),
1371 }
1372}
1373
/// Validates the argument count inside a function returning
/// `Option<Result<_>>`, returning `Some(Err(..))` from the enclosing function
/// when too few arguments were supplied.
///
/// The literal-`1` form delegates to `require_arg`; any other count goes
/// through `require_args`.
macro_rules! check_args {
    (1, $df_args:expr, $name:expr) => {
        match require_arg($df_args, $name) {
            Ok(()) => {}
            Err(err) => return Some(Err(err)),
        }
    };
    ($n:expr, $df_args:expr, $name:expr) => {
        match require_args($df_args, $n, $name) {
            Ok(()) => {}
            Err(err) => return Some(Err(err)),
        }
    };
}
1390
1391fn require_args(df_args: &[DfExpr], count: usize, func_name: &str) -> Result<()> {
1394 if df_args.len() < count {
1395 let noun = if count == 1 { "argument" } else { "arguments" };
1396 return Err(anyhow!("{} requires {} {}", func_name, count, noun));
1397 }
1398 Ok(())
1399}
1400
1401fn require_arg(df_args: &[DfExpr], func_name: &str) -> Result<()> {
1403 require_args(df_args, 1, func_name)
1404}
1405
1406fn first_arg(df_args: &[DfExpr]) -> DfExpr {
1408 df_args[0].clone()
1409}
1410
1411pub fn cast_expr(expr: DfExpr, data_type: datafusion::arrow::datatypes::DataType) -> DfExpr {
1413 DfExpr::Cast(datafusion::logical_expr::Cast {
1414 expr: Box::new(expr),
1415 data_type,
1416 })
1417}
1418
1419pub fn list_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1425 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1426 Arc::new(crate::df_udfs::create_cypher_list_to_cv_udf()),
1427 vec![expr],
1428 ))
1429}
1430
1431pub fn scalar_to_large_binary_expr(expr: DfExpr) -> DfExpr {
1435 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
1436 Arc::new(crate::df_udfs::create_cypher_scalar_to_cv_udf()),
1437 vec![expr],
1438 ))
1439}
1440
1441fn binary_expr(left: DfExpr, op: datafusion::logical_expr::Operator, right: DfExpr) -> DfExpr {
1443 DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
1444 Box::new(left),
1445 op,
1446 Box::new(right),
1447 ))
1448}
1449
1450pub fn comparison_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1455 use datafusion::logical_expr::Operator;
1456 match op {
1457 Operator::Eq => Some("_cypher_equal"),
1458 Operator::NotEq => Some("_cypher_not_equal"),
1459 Operator::Lt => Some("_cypher_lt"),
1460 Operator::LtEq => Some("_cypher_lt_eq"),
1461 Operator::Gt => Some("_cypher_gt"),
1462 Operator::GtEq => Some("_cypher_gt_eq"),
1463 _ => None,
1464 }
1465}
1466
1467fn arithmetic_udf_name(op: datafusion::logical_expr::Operator) -> Option<&'static str> {
1469 use datafusion::logical_expr::Operator;
1470 match op {
1471 Operator::Plus => Some("_cypher_add"),
1472 Operator::Minus => Some("_cypher_sub"),
1473 Operator::Multiply => Some("_cypher_mul"),
1474 Operator::Divide => Some("_cypher_div"),
1475 Operator::Modulo => Some("_cypher_mod"),
1476 _ => None,
1477 }
1478}
1479
1480fn apply_unary_math_f64<F>(df_args: &[DfExpr], func_name: &str, math_fn: F) -> Result<DfExpr>
1485where
1486 F: FnOnce(DfExpr) -> DfExpr,
1487{
1488 require_arg(df_args, func_name)?;
1489 Ok(math_fn(cast_expr(
1490 first_arg(df_args),
1491 datafusion::arrow::datatypes::DataType::Float64,
1492 )))
1493}
1494
1495fn maybe_distinct(expr: DfExpr, distinct: bool, name: &str) -> Result<DfExpr> {
1497 if distinct {
1498 expr.distinct()
1499 .build()
1500 .map_err(|e| anyhow!("Failed to build {} DISTINCT: {}", name, e))
1501 } else {
1502 Ok(expr)
1503 }
1504}
1505
1506fn translate_aggregate_function(
1508 name_upper: &str,
1509 df_args: &[DfExpr],
1510 distinct: bool,
1511) -> Option<Result<DfExpr>> {
1512 match name_upper {
1513 "COUNT" => {
1514 let expr = if df_args.is_empty() {
1515 datafusion::functions_aggregate::count::count(lit(1i64))
1516 } else {
1517 datafusion::functions_aggregate::count::count(first_arg(df_args))
1518 };
1519 Some(maybe_distinct(expr, distinct, "COUNT"))
1520 }
1521 "SUM" => {
1522 check_args!(1, df_args, "SUM");
1523 let udaf = Arc::new(crate::df_udfs::create_cypher_sum_udaf());
1524 Some(maybe_distinct(
1525 udaf.call(vec![first_arg(df_args)]),
1526 distinct,
1527 "SUM",
1528 ))
1529 }
1530 "AVG" => {
1531 check_args!(1, df_args, "AVG");
1532 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1533 let expr = datafusion::functions_aggregate::average::avg(coerced);
1534 Some(maybe_distinct(expr, distinct, "AVG"))
1535 }
1536 "MIN" => {
1537 check_args!(1, df_args, "MIN");
1538 let udaf = Arc::new(crate::df_udfs::create_cypher_min_udaf());
1539 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1540 }
1541 "MAX" => {
1542 check_args!(1, df_args, "MAX");
1543 let udaf = Arc::new(crate::df_udfs::create_cypher_max_udaf());
1544 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1545 }
1546 "PERCENTILEDISC" => {
1547 if df_args.len() != 2 {
1548 return Some(Err(anyhow!(
1549 "percentileDisc() requires exactly 2 arguments"
1550 )));
1551 }
1552 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1553 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_disc_udaf());
1554 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1555 }
1556 "PERCENTILECONT" => {
1557 if df_args.len() != 2 {
1558 return Some(Err(anyhow!(
1559 "percentileCont() requires exactly 2 arguments"
1560 )));
1561 }
1562 let coerced = crate::df_udfs::cypher_to_float64_expr(df_args[0].clone());
1563 let udaf = Arc::new(crate::df_udfs::create_cypher_percentile_cont_udaf());
1564 Some(Ok(udaf.call(vec![coerced, df_args[1].clone()])))
1565 }
1566 "COLLECT" => {
1567 check_args!(1, df_args, "COLLECT");
1568 Some(Ok(crate::df_udfs::create_cypher_collect_expr(
1569 first_arg(df_args),
1570 distinct,
1571 )))
1572 }
1573 "BTIC_MIN" => {
1575 check_args!(1, df_args, "btic_min");
1576 let udaf = Arc::new(crate::df_udfs::create_btic_min_udaf());
1577 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1578 }
1579 "BTIC_MAX" => {
1580 check_args!(1, df_args, "btic_max");
1581 let udaf = Arc::new(crate::df_udfs::create_btic_max_udaf());
1582 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1583 }
1584 "BTIC_SPAN_AGG" => {
1585 check_args!(1, df_args, "btic_span_agg");
1586 let udaf = Arc::new(crate::df_udfs::create_btic_span_agg_udaf());
1587 Some(Ok(udaf.call(vec![first_arg(df_args)])))
1588 }
1589 "BTIC_COUNT_AT" => {
1590 if df_args.len() != 2 {
1591 return Some(Err(anyhow!("btic_count_at requires 2 arguments")));
1592 }
1593 let udaf = Arc::new(crate::df_udfs::create_btic_count_at_udaf());
1594 Some(Ok(udaf.call(df_args.to_vec())))
1595 }
1596 _ => None,
1597 }
1598}
1599
1600fn translate_string_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1603 match name_upper {
1604 "TOSTRING" => {
1605 check_args!(1, df_args, "toString");
1606 Some(Ok(dummy_udf_expr("tostring", df_args.to_vec())))
1607 }
1608 "TOINTEGER" | "TOINT" => {
1609 check_args!(1, df_args, "toInteger");
1610 Some(Ok(dummy_udf_expr("toInteger", df_args.to_vec())))
1611 }
1612 "TOFLOAT" => {
1613 check_args!(1, df_args, "toFloat");
1614 Some(Ok(dummy_udf_expr("toFloat", df_args.to_vec())))
1615 }
1616 "TOBOOLEAN" | "TOBOOL" => {
1617 check_args!(1, df_args, "toBoolean");
1618 Some(Ok(dummy_udf_expr("toBoolean", df_args.to_vec())))
1619 }
1620 "UPPER" | "TOUPPER" => {
1621 check_args!(1, df_args, "upper");
1622 Some(Ok(datafusion::functions::string::expr_fn::upper(
1623 first_arg(df_args),
1624 )))
1625 }
1626 "LOWER" | "TOLOWER" => {
1627 check_args!(1, df_args, "lower");
1628 Some(Ok(datafusion::functions::string::expr_fn::lower(
1629 first_arg(df_args),
1630 )))
1631 }
1632 "SUBSTRING" => {
1633 check_args!(2, df_args, "substring");
1634 Some(Ok(dummy_udf_expr("_cypher_substring", df_args.to_vec())))
1635 }
1636 "TRIM" => {
1637 check_args!(1, df_args, "TRIM");
1638 Some(Ok(datafusion::functions::string::expr_fn::btrim(vec![
1639 first_arg(df_args),
1640 ])))
1641 }
1642 "LTRIM" => {
1643 check_args!(1, df_args, "LTRIM");
1644 Some(Ok(datafusion::functions::string::expr_fn::ltrim(vec![
1645 first_arg(df_args),
1646 ])))
1647 }
1648 "RTRIM" => {
1649 check_args!(1, df_args, "RTRIM");
1650 Some(Ok(datafusion::functions::string::expr_fn::rtrim(vec![
1651 first_arg(df_args),
1652 ])))
1653 }
1654 "LEFT" => {
1655 check_args!(2, df_args, "left");
1656 Some(Ok(datafusion::functions::unicode::expr_fn::left(
1657 df_args[0].clone(),
1658 df_args[1].clone(),
1659 )))
1660 }
1661 "RIGHT" => {
1662 check_args!(2, df_args, "right");
1663 Some(Ok(datafusion::functions::unicode::expr_fn::right(
1664 df_args[0].clone(),
1665 df_args[1].clone(),
1666 )))
1667 }
1668 "REPLACE" => {
1669 check_args!(3, df_args, "replace");
1670 Some(Ok(datafusion::functions::string::expr_fn::replace(
1671 df_args[0].clone(),
1672 df_args[1].clone(),
1673 df_args[2].clone(),
1674 )))
1675 }
1676 "REVERSE" => {
1677 check_args!(1, df_args, "reverse");
1678 Some(Ok(dummy_udf_expr("_cypher_reverse", df_args.to_vec())))
1679 }
1680 "SPLIT" => {
1681 check_args!(2, df_args, "split");
1682 Some(Ok(dummy_udf_expr("_cypher_split", df_args.to_vec())))
1683 }
1684 "SIZE" | "LENGTH" => {
1685 check_args!(1, df_args, name_upper);
1686 Some(Ok(dummy_udf_expr("_cypher_size", df_args.to_vec())))
1687 }
1688 _ => None,
1689 }
1690}
1691
1692fn translate_math_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1695 use datafusion::functions::math::expr_fn;
1696
1697 let unary_f64 =
1699 |name: &str, f: fn(DfExpr) -> DfExpr| Some(apply_unary_math_f64(df_args, name, f));
1700
1701 match name_upper {
1702 "ABS" => {
1703 check_args!(1, df_args, "abs");
1704 Some(Ok(crate::df_udfs::cypher_abs_expr(first_arg(df_args))))
1708 }
1709 "CEIL" | "CEILING" => {
1710 check_args!(1, df_args, "ceil");
1711 Some(Ok(expr_fn::ceil(first_arg(df_args))))
1712 }
1713 "FLOOR" => {
1714 check_args!(1, df_args, "floor");
1715 Some(Ok(expr_fn::floor(first_arg(df_args))))
1716 }
1717 "ROUND" => {
1718 check_args!(1, df_args, "round");
1719 let args = if df_args.len() == 1 {
1720 vec![first_arg(df_args)]
1721 } else {
1722 vec![df_args[0].clone(), df_args[1].clone()]
1723 };
1724 Some(Ok(expr_fn::round(args)))
1725 }
1726 "SIGN" => {
1727 check_args!(1, df_args, "sign");
1728 let coerced = crate::df_udfs::cypher_to_float64_expr(first_arg(df_args));
1729 Some(Ok(expr_fn::signum(coerced)))
1730 }
1731 "SQRT" => unary_f64("sqrt", expr_fn::sqrt),
1732 "LOG" | "LN" => unary_f64("log", expr_fn::ln),
1733 "LOG10" => unary_f64("log10", expr_fn::log10),
1734 "EXP" => unary_f64("exp", expr_fn::exp),
1735 "SIN" => unary_f64("sin", expr_fn::sin),
1736 "COS" => unary_f64("cos", expr_fn::cos),
1737 "TAN" => unary_f64("tan", expr_fn::tan),
1738 "ASIN" => unary_f64("asin", expr_fn::asin),
1739 "ACOS" => unary_f64("acos", expr_fn::acos),
1740 "ATAN" => unary_f64("atan", expr_fn::atan),
1741 "ATAN2" => {
1742 check_args!(2, df_args, "atan2");
1743 let cast_f64 =
1744 |e: DfExpr| cast_expr(e, datafusion::arrow::datatypes::DataType::Float64);
1745 Some(Ok(expr_fn::atan2(
1746 cast_f64(df_args[0].clone()),
1747 cast_f64(df_args[1].clone()),
1748 )))
1749 }
1750 "RAND" | "RANDOM" => Some(Ok(expr_fn::random())),
1751 "E" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::E))),
1752 "PI" if df_args.is_empty() => Some(Ok(lit(std::f64::consts::PI))),
1753 _ => None,
1754 }
1755}
1756
1757fn translate_temporal_function(
1760 name_upper: &str,
1761 name: &str,
1762 df_args: &[DfExpr],
1763 context: Option<&TranslationContext>,
1764) -> Option<Result<DfExpr>> {
1765 match name_upper {
1766 "DATE"
1767 | "TIME"
1768 | "LOCALTIME"
1769 | "LOCALDATETIME"
1770 | "DATETIME"
1771 | "DURATION"
1772 | "YEAR"
1773 | "MONTH"
1774 | "DAY"
1775 | "HOUR"
1776 | "MINUTE"
1777 | "SECOND"
1778 | "DURATION.BETWEEN"
1779 | "DURATION.INMONTHS"
1780 | "DURATION.INDAYS"
1781 | "DURATION.INSECONDS"
1782 | "DATETIME.FROMEPOCH"
1783 | "DATETIME.FROMEPOCHMILLIS"
1784 | "DATE.TRUNCATE"
1785 | "TIME.TRUNCATE"
1786 | "DATETIME.TRUNCATE"
1787 | "LOCALDATETIME.TRUNCATE"
1788 | "LOCALTIME.TRUNCATE"
1789 | "DATETIME.TRANSACTION"
1790 | "DATETIME.STATEMENT"
1791 | "DATETIME.REALTIME"
1792 | "DATE.TRANSACTION"
1793 | "DATE.STATEMENT"
1794 | "DATE.REALTIME"
1795 | "TIME.TRANSACTION"
1796 | "TIME.STATEMENT"
1797 | "TIME.REALTIME"
1798 | "LOCALTIME.TRANSACTION"
1799 | "LOCALTIME.STATEMENT"
1800 | "LOCALTIME.REALTIME"
1801 | "LOCALDATETIME.TRANSACTION"
1802 | "LOCALDATETIME.STATEMENT"
1803 | "LOCALDATETIME.REALTIME" => {
1804 let stmt_time = context.map(|c| c.statement_time);
1808 if can_constant_fold(name_upper, df_args)
1809 && let Ok(folded) = try_constant_fold_temporal(name_upper, df_args, stmt_time)
1810 {
1811 return Some(Ok(folded));
1812 }
1813 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1814 }
1815 _ => None,
1816 }
1817}
1818
1819fn can_constant_fold(name: &str, args: &[DfExpr]) -> bool {
1821 if name.contains("REALTIME") {
1823 return false;
1824 }
1825 if args.is_empty() {
1833 return matches!(
1834 name,
1835 "DATE"
1836 | "TIME"
1837 | "LOCALTIME"
1838 | "LOCALDATETIME"
1839 | "DATETIME"
1840 | "DATE.STATEMENT"
1841 | "TIME.STATEMENT"
1842 | "LOCALTIME.STATEMENT"
1843 | "LOCALDATETIME.STATEMENT"
1844 | "DATETIME.STATEMENT"
1845 | "DATE.TRANSACTION"
1846 | "TIME.TRANSACTION"
1847 | "LOCALTIME.TRANSACTION"
1848 | "LOCALDATETIME.TRANSACTION"
1849 | "DATETIME.TRANSACTION"
1850 );
1851 }
1852 args.iter().all(is_constant_expr)
1854}
1855
1856fn is_constant_expr(expr: &DfExpr) -> bool {
1858 match expr {
1859 DfExpr::Literal(_, _) => true,
1860 DfExpr::ScalarFunction(func) => {
1861 func.args.iter().all(is_constant_expr)
1863 }
1864 _ => false,
1865 }
1866}
1867
1868fn try_constant_fold_temporal(
1874 name: &str,
1875 args: &[DfExpr],
1876 stmt_time: Option<chrono::DateTime<chrono::Utc>>,
1877) -> Result<DfExpr> {
1878 let val_args: Vec<Value> = args
1880 .iter()
1881 .map(extract_constant_value)
1882 .collect::<Result<_>>()?;
1883
1884 let result = if val_args.is_empty() {
1886 if let Some(frozen) = stmt_time {
1887 crate::datetime::eval_datetime_function_with_clock(name, &val_args, frozen)?
1888 } else {
1889 crate::datetime::eval_datetime_function(name, &val_args)?
1890 }
1891 } else {
1892 crate::datetime::eval_datetime_function(name, &val_args)?
1893 };
1894
1895 let scalar = value_to_scalar(&result)?;
1897 Ok(DfExpr::Literal(scalar, None))
1898}
1899
1900fn extract_constant_value(expr: &DfExpr) -> Result<Value> {
1902 use crate::df_udfs::scalar_to_value;
1903 match expr {
1904 DfExpr::Literal(sv, _) => scalar_to_value(sv).map_err(|e| anyhow::anyhow!("{}", e)),
1905 DfExpr::ScalarFunction(func) => {
1906 let mut map = std::collections::HashMap::new();
1909 let pairs: Vec<&DfExpr> = func.args.iter().collect();
1910 for chunk in pairs.chunks(2) {
1911 if let [key_expr, val_expr] = chunk {
1912 let key = match key_expr {
1914 DfExpr::Literal(ScalarValue::Utf8(Some(s)), _) => s.clone(),
1915 DfExpr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => s.clone(),
1916 _ => return Err(anyhow::anyhow!("Expected string key in struct")),
1917 };
1918 let val = extract_constant_value(val_expr)?;
1919 map.insert(key, val);
1920 } else {
1921 return Err(anyhow::anyhow!("Odd number of args in named_struct"));
1922 }
1923 }
1924 Ok(Value::Map(map))
1925 }
1926 _ => Err(anyhow::anyhow!(
1927 "Cannot extract constant value from expression"
1928 )),
1929 }
1930}
1931
1932fn translate_btic_function(
1935 name_upper: &str,
1936 name: &str,
1937 df_args: &[DfExpr],
1938) -> Option<Result<DfExpr>> {
1939 if crate::expr_eval::is_btic_function(name_upper) {
1940 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
1941 } else {
1942 None
1943 }
1944}
1945
1946fn translate_list_function(name_upper: &str, df_args: &[DfExpr]) -> Option<Result<DfExpr>> {
1949 match name_upper {
1950 "HEAD" => {
1951 check_args!(1, df_args, "head");
1952 Some(Ok(dummy_udf_expr("head", df_args.to_vec())))
1953 }
1954 "LAST" => {
1955 check_args!(1, df_args, "last");
1956 Some(Ok(dummy_udf_expr("last", df_args.to_vec())))
1957 }
1958 "TAIL" => {
1959 check_args!(1, df_args, "tail");
1960 Some(Ok(dummy_udf_expr("_cypher_tail", df_args.to_vec())))
1961 }
1962 "RANGE" => {
1963 check_args!(2, df_args, "range");
1964 Some(Ok(dummy_udf_expr("range", df_args.to_vec())))
1965 }
1966 _ => None,
1967 }
1968}
1969
1970fn translate_graph_function(
1973 name_upper: &str,
1974 name: &str,
1975 df_args: &[DfExpr],
1976 args: &[Expr],
1977 context: Option<&TranslationContext>,
1978) -> Option<Result<DfExpr>> {
1979 match name_upper {
1980 "ID" => {
1981 if let Some(Expr::Variable(var)) = args.first() {
1984 let is_edge = context.is_some_and(|ctx| {
1985 ctx.variable_kinds.get(var) == Some(&VariableKind::Edge)
1986 || ctx.mutation_edge_hints.iter().any(|h| h == var)
1987 });
1988 let id_suffix = if is_edge { COL_EID } else { COL_VID };
1989 Some(Ok(DfExpr::Column(Column::from_name(format!(
1990 "{}.{}",
1991 var, id_suffix
1992 )))))
1993 } else {
1994 Some(Ok(dummy_udf_expr("id", df_args.to_vec())))
1995 }
1996 }
1997 "CREATED_AT" | "UPDATED_AT" => {
1998 if let Some(Expr::Variable(var)) = args.first() {
2002 let suffix = if name_upper == "CREATED_AT" {
2003 "_created_at"
2004 } else {
2005 "_updated_at"
2006 };
2007 Some(Ok(DfExpr::Column(Column::from_name(format!(
2008 "{}.{}",
2009 var, suffix
2010 )))))
2011 } else {
2012 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2013 }
2014 }
2015 "LABELS" | "KEYS" => {
2016 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2021 }
2022 "TYPE" => {
2023 if let Some(Expr::Variable(var)) = args.first()
2027 && let Some(ctx) = context
2028 && let Some(label) = ctx.variable_labels.get(var)
2029 {
2030 let eid_col = DfExpr::Column(Column::from_name(format!("{}._eid", var)));
2033 return Some(Ok(DfExpr::Case(datafusion::logical_expr::Case {
2034 expr: None,
2035 when_then_expr: vec![(
2036 Box::new(eid_col.is_not_null()),
2037 Box::new(lit(label.clone())),
2038 )],
2039 else_expr: Some(Box::new(lit(ScalarValue::Utf8(None)))),
2040 })));
2041 }
2042 if let Some(Expr::Variable(var)) = args.first()
2046 && context
2047 .is_some_and(|ctx| ctx.variable_kinds.get(var) == Some(&VariableKind::Edge))
2048 {
2049 return Some(Ok(DfExpr::Column(Column::from_name(format!(
2050 "{}.{}",
2051 var, COL_TYPE
2052 )))));
2053 }
2054 Some(Ok(dummy_udf_expr("type", df_args.to_vec())))
2055 }
2056 "PROPERTIES" => {
2057 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2060 }
2061 "UNI.TEMPORAL.VALIDAT" => {
2062 if let (
2065 Some(Expr::Variable(var)),
2066 Some(Expr::Literal(CypherLiteral::String(start_prop))),
2067 Some(Expr::Literal(CypherLiteral::String(end_prop))),
2068 Some(ts_expr),
2069 ) = (args.first(), args.get(1), args.get(2), args.get(3))
2070 {
2071 let start_col =
2072 DfExpr::Column(Column::from_name(format!("{}.{}", var, start_prop)));
2073 let end_col = DfExpr::Column(Column::from_name(format!("{}.{}", var, end_prop)));
2074 let ts = match cypher_expr_to_df(ts_expr, context) {
2075 Ok(ts) => ts,
2076 Err(e) => return Some(Err(e)),
2077 };
2078
2079 let start_check = start_col.lt_eq(ts.clone());
2081 let end_null = DfExpr::IsNull(Box::new(end_col.clone()));
2083 let end_after = end_col.gt(ts);
2084 let end_check = end_null.or(end_after);
2085
2086 Some(Ok(start_check.and(end_check)))
2087 } else {
2088 Some(Ok(dummy_udf_expr(name, df_args.to_vec())))
2090 }
2091 }
2092 "STARTNODE" | "ENDNODE" => {
2093 let mut udf_args = df_args.to_vec();
2096 let mut seen = std::collections::HashSet::new();
2097 if let Some(ctx) = context {
2098 for (var, kind) in &ctx.variable_kinds {
2100 if matches!(kind, VariableKind::Node) && seen.insert(var.clone()) {
2101 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2102 }
2103 }
2104 for var in &ctx.node_variable_hints {
2107 if seen.insert(var.clone()) {
2108 udf_args.push(DfExpr::Column(Column::from_name(var.clone())));
2109 }
2110 }
2111 }
2112 Some(Ok(dummy_udf_expr(&name_upper.to_lowercase(), udf_args)))
2113 }
2114 "NODES" | "RELATIONSHIPS" => Some(Ok(dummy_udf_expr(name, df_args.to_vec()))),
2115 "HASLABEL" => {
2116 if let Err(e) = require_args(df_args, 2, "hasLabel") {
2117 return Some(Err(e));
2118 }
2119 if let Some(Expr::Variable(var)) = args.first() {
2121 if let Some(Expr::Literal(CypherLiteral::String(label))) = args.get(1) {
2122 let labels_col =
2124 DfExpr::Column(Column::from_name(format!("{}.{}", var, COL_LABELS)));
2125 Some(Ok(datafusion::functions_nested::expr_fn::array_has(
2126 labels_col,
2127 lit(label.clone()),
2128 )))
2129 } else {
2130 Some(Err(anyhow::anyhow!(
2132 "hasLabel requires string literal as second argument for DataFusion translation"
2133 )))
2134 }
2135 } else {
2136 Some(Err(anyhow::anyhow!(
2138 "hasLabel requires variable as first argument for DataFusion translation"
2139 )))
2140 }
2141 }
2142 _ => None,
2143 }
2144}
2145
2146fn translate_function_call(
2148 name: &str,
2149 args: &[Expr],
2150 distinct: bool,
2151 context: Option<&TranslationContext>,
2152) -> Result<DfExpr> {
2153 let df_args: Vec<DfExpr> = args
2154 .iter()
2155 .map(|arg| cypher_expr_to_df(arg, context))
2156 .collect::<Result<Vec<_>>>()?;
2157
2158 let name_upper = name.to_uppercase();
2159
2160 if let Some(result) = translate_aggregate_function(&name_upper, &df_args, distinct) {
2164 return result;
2165 }
2166
2167 if let Some(result) = translate_string_function(&name_upper, &df_args) {
2168 return result;
2169 }
2170
2171 if let Some(result) = translate_math_function(&name_upper, &df_args) {
2172 return result;
2173 }
2174
2175 if let Some(result) = translate_temporal_function(&name_upper, name, &df_args, context) {
2176 return result;
2177 }
2178
2179 if let Some(result) = translate_btic_function(&name_upper, name, &df_args) {
2180 return result;
2181 }
2182
2183 if let Some(result) = translate_list_function(&name_upper, &df_args) {
2184 return result;
2185 }
2186
2187 if let Some(result) = translate_graph_function(&name_upper, name, &df_args, args, context) {
2188 return result;
2189 }
2190
2191 match name_upper.as_str() {
2193 "COALESCE" => {
2194 require_arg(&df_args, "coalesce")?;
2195 if df_args.len() == 1 {
2200 return Ok(df_args.into_iter().next().unwrap());
2201 }
2202 let n = df_args.len();
2203 let (init, last) = df_args.split_at(n - 1);
2204 let mut builder = datafusion::logical_expr::conditional_expressions::CaseBuilder::new(
2205 None,
2206 vec![],
2207 vec![],
2208 None,
2209 );
2210 for arg in init {
2211 builder.when(arg.clone().is_not_null(), arg.clone());
2212 }
2213 return Ok(builder.otherwise(last[0].clone())?);
2214 }
2215 "NULLIF" => {
2216 require_args(&df_args, 2, "nullif")?;
2217 return Ok(datafusion::functions::expr_fn::nullif(
2218 df_args[0].clone(),
2219 df_args[1].clone(),
2220 ));
2221 }
2222 _ => {}
2223 }
2224
2225 match name_upper.as_str() {
2227 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
2228 return Ok(dummy_udf_expr(&name_upper.to_lowercase(), df_args));
2229 }
2230 _ => {}
2231 }
2232
2233 Ok(dummy_udf_expr(name, df_args))
2235}
2236
2237#[derive(Debug)]
2242struct DummyUdf {
2243 name: String,
2244 signature: datafusion::logical_expr::Signature,
2245 ret_type: datafusion::arrow::datatypes::DataType,
2246}
2247
2248impl DummyUdf {
2249 fn new(name: String) -> Self {
2250 let ret_type = dummy_udf_return_type(&name);
2251 Self {
2252 name,
2253 signature: datafusion::logical_expr::Signature::variadic_any(
2254 datafusion::logical_expr::Volatility::Immutable,
2255 ),
2256 ret_type,
2257 }
2258 }
2259}
2260
2261fn dummy_udf_return_type(name: &str) -> datafusion::arrow::datatypes::DataType {
2274 use datafusion::arrow::datatypes::DataType;
2275 match name {
2276 "_cypher_add"
2280 | "_cypher_sub"
2281 | "_cypher_mul"
2282 | "_cypher_div"
2283 | "_cypher_mod"
2284 | "_cypher_list_concat"
2285 | "_cypher_list_append"
2286 | "_make_cypher_list"
2287 | "_map_project"
2288 | "_cypher_list_to_cv"
2289 | "_cypher_tail" => DataType::LargeBinary,
2290 _ => DataType::Null,
2294 }
2295}
2296
2297impl PartialEq for DummyUdf {
2298 fn eq(&self, other: &Self) -> bool {
2299 self.name == other.name
2300 }
2301}
2302
2303impl Eq for DummyUdf {}
2304
2305impl Hash for DummyUdf {
2306 fn hash<H: Hasher>(&self, state: &mut H) {
2307 self.name.hash(state);
2308 }
2309}
2310
2311pub fn dummy_udf_expr(name: &str, args: Vec<DfExpr>) -> DfExpr {
2313 DfExpr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction {
2314 func: Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
2315 DummyUdf::new(name.to_lowercase()),
2316 )),
2317 args,
2318 })
2319}
2320
2321impl datafusion::logical_expr::ScalarUDFImpl for DummyUdf {
2322 fn as_any(&self) -> &dyn std::any::Any {
2323 self
2324 }
2325
2326 fn name(&self) -> &str {
2327 &self.name
2328 }
2329
2330 fn signature(&self) -> &datafusion::logical_expr::Signature {
2331 &self.signature
2332 }
2333
2334 fn return_type(
2335 &self,
2336 _arg_types: &[datafusion::arrow::datatypes::DataType],
2337 ) -> datafusion::error::Result<datafusion::arrow::datatypes::DataType> {
2338 Ok(self.ret_type.clone())
2341 }
2342
2343 fn invoke_with_args(
2344 &self,
2345 _args: ScalarFunctionArgs,
2346 ) -> datafusion::error::Result<ColumnarValue> {
2347 Err(datafusion::error::DataFusionError::Plan(format!(
2348 "UDF '{}' is not registered. Register it via SessionContext.",
2349 self.name
2350 )))
2351 }
2352}
2353
2354pub fn collect_properties(expr: &Expr) -> Vec<(String, String)> {
2358 let mut properties = Vec::new();
2359 collect_properties_recursive(expr, &mut properties);
2360 properties.sort();
2361 properties.dedup();
2362 properties
2363}
2364
2365fn collect_properties_recursive(expr: &Expr, properties: &mut Vec<(String, String)>) {
2366 match expr {
2367 Expr::PatternComprehension { .. } => {}
2368 Expr::Property(base, prop) => {
2369 if let Ok(var_name) = extract_variable_name(base) {
2370 properties.push((var_name, prop.clone()));
2371 }
2372 collect_properties_recursive(base, properties);
2373 }
2374 Expr::ArrayIndex { array, index } => {
2375 if let Ok(var_name) = extract_variable_name(array)
2376 && let Expr::Literal(CypherLiteral::String(prop_name)) = index.as_ref()
2377 {
2378 properties.push((var_name, prop_name.clone()));
2379 }
2380 collect_properties_recursive(array, properties);
2381 collect_properties_recursive(index, properties);
2382 }
2383 Expr::ArraySlice { array, start, end } => {
2384 collect_properties_recursive(array, properties);
2385 if let Some(s) = start {
2386 collect_properties_recursive(s, properties);
2387 }
2388 if let Some(e) = end {
2389 collect_properties_recursive(e, properties);
2390 }
2391 }
2392 Expr::List(items) => {
2393 for item in items {
2394 collect_properties_recursive(item, properties);
2395 }
2396 }
2397 Expr::Map(entries) => {
2398 for (_, value) in entries {
2399 collect_properties_recursive(value, properties);
2400 }
2401 }
2402 Expr::IsNull(inner) | Expr::IsNotNull(inner) | Expr::IsUnique(inner) => {
2403 collect_properties_recursive(inner, properties);
2404 }
2405 Expr::FunctionCall { args, .. } => {
2406 for arg in args {
2407 collect_properties_recursive(arg, properties);
2408 }
2409 }
2410 Expr::BinaryOp { left, right, .. } => {
2411 collect_properties_recursive(left, properties);
2412 collect_properties_recursive(right, properties);
2413 }
2414 Expr::UnaryOp { expr, .. } => {
2415 collect_properties_recursive(expr, properties);
2416 }
2417 Expr::Case {
2418 expr,
2419 when_then,
2420 else_expr,
2421 } => {
2422 if let Some(e) = expr {
2423 collect_properties_recursive(e, properties);
2424 }
2425 for (when_e, then_e) in when_then {
2426 collect_properties_recursive(when_e, properties);
2427 collect_properties_recursive(then_e, properties);
2428 }
2429 if let Some(e) = else_expr {
2430 collect_properties_recursive(e, properties);
2431 }
2432 }
2433 Expr::Reduce {
2434 init, list, expr, ..
2435 } => {
2436 collect_properties_recursive(init, properties);
2437 collect_properties_recursive(list, properties);
2438 collect_properties_recursive(expr, properties);
2439 }
2440 Expr::Quantifier {
2441 list, predicate, ..
2442 } => {
2443 collect_properties_recursive(list, properties);
2444 collect_properties_recursive(predicate, properties);
2445 }
2446 Expr::ListComprehension {
2447 list,
2448 where_clause,
2449 map_expr,
2450 ..
2451 } => {
2452 collect_properties_recursive(list, properties);
2453 if let Some(filter) = where_clause {
2454 collect_properties_recursive(filter, properties);
2455 }
2456 collect_properties_recursive(map_expr, properties);
2457 }
2458 Expr::In { expr, list } => {
2459 collect_properties_recursive(expr, properties);
2460 collect_properties_recursive(list, properties);
2461 }
2462 Expr::ValidAt {
2463 entity, timestamp, ..
2464 } => {
2465 collect_properties_recursive(entity, properties);
2466 collect_properties_recursive(timestamp, properties);
2467 }
2468 Expr::MapProjection { base, items } => {
2469 collect_properties_recursive(base, properties);
2470 for item in items {
2471 match item {
2472 uni_cypher::ast::MapProjectionItem::Property(prop) => {
2473 if let Ok(var_name) = extract_variable_name(base) {
2474 properties.push((var_name, prop.clone()));
2475 }
2476 }
2477 uni_cypher::ast::MapProjectionItem::AllProperties => {
2478 if let Ok(var_name) = extract_variable_name(base) {
2479 properties.push((var_name, "*".to_string()));
2480 }
2481 }
2482 uni_cypher::ast::MapProjectionItem::LiteralEntry(_, expr) => {
2483 collect_properties_recursive(expr, properties);
2484 }
2485 uni_cypher::ast::MapProjectionItem::Variable(_) => {}
2486 }
2487 }
2488 }
2489 Expr::LabelCheck { expr, .. } => {
2490 collect_properties_recursive(expr, properties);
2491 }
2492 Expr::Wildcard | Expr::Variable(_) | Expr::Parameter(_) | Expr::Literal(_) => {}
2494 Expr::Exists { .. } | Expr::CountSubquery(_) | Expr::CollectSubquery(_) => {}
2495 }
2496}
2497
2498pub fn wider_numeric_type(
2505 a: &datafusion::arrow::datatypes::DataType,
2506 b: &datafusion::arrow::datatypes::DataType,
2507) -> datafusion::arrow::datatypes::DataType {
2508 use datafusion::arrow::datatypes::DataType;
2509
2510 fn numeric_rank(dt: &DataType) -> u8 {
2511 match dt {
2512 DataType::Int8 | DataType::UInt8 => 1,
2513 DataType::Int16 | DataType::UInt16 => 2,
2514 DataType::Int32 | DataType::UInt32 => 3,
2515 DataType::Int64 | DataType::UInt64 => 4,
2516 DataType::Float16 => 5,
2517 DataType::Float32 => 6,
2518 DataType::Float64 => 7,
2519 _ => 0,
2520 }
2521 }
2522
2523 if numeric_rank(a) >= numeric_rank(b) {
2524 a.clone()
2525 } else {
2526 b.clone()
2527 }
2528}
2529
2530fn resolve_column_type_fallback(
2536 expr: &DfExpr,
2537 schema: &datafusion::common::DFSchema,
2538) -> Option<datafusion::arrow::datatypes::DataType> {
2539 if let DfExpr::Column(col) = expr {
2540 let col_name = &col.name;
2541 for (_, field) in schema.iter() {
2543 if field.name() == col_name {
2544 return Some(field.data_type().clone());
2545 }
2546 }
2547 }
2548 None
2549}
2550
2551fn contains_division(expr: &DfExpr) -> bool {
2554 match expr {
2555 DfExpr::BinaryExpr(b) => {
2556 b.op == datafusion::logical_expr::Operator::Divide
2557 || contains_division(&b.left)
2558 || contains_division(&b.right)
2559 }
2560 DfExpr::Cast(c) => contains_division(&c.expr),
2561 DfExpr::TryCast(c) => contains_division(&c.expr),
2562 _ => false,
2563 }
2564}
2565
2566pub fn apply_type_coercion(expr: &DfExpr, schema: &datafusion::common::DFSchema) -> Result<DfExpr> {
2572 use datafusion::arrow::datatypes::DataType;
2573 use datafusion::logical_expr::ExprSchemable;
2574
2575 match expr {
2576 DfExpr::BinaryExpr(binary) => coerce_binary_expr(binary, schema),
2577 DfExpr::ScalarFunction(func) => coerce_scalar_function(func, schema),
2578 DfExpr::Case(case) => coerce_case_expr(case, schema),
2579 DfExpr::InList(in_list) => {
2580 let coerced_expr = apply_type_coercion(&in_list.expr, schema)?;
2581 let coerced_list = in_list
2582 .list
2583 .iter()
2584 .map(|e| apply_type_coercion(e, schema))
2585 .collect::<Result<Vec<_>>>()?;
2586 let expr_type = coerced_expr
2587 .get_type(schema)
2588 .map_err(|e| anyhow!("Failed to get IN expr type: {}", e))?;
2589 crate::cypher_type_coerce::build_cypher_in_list(
2590 coerced_expr,
2591 &expr_type,
2592 coerced_list,
2593 in_list.negated,
2594 schema,
2595 )
2596 }
2597 DfExpr::Not(inner) => {
2598 let coerced_inner = apply_type_coercion(inner, schema)?;
2599 let inner_type = coerced_inner.get_type(schema).ok();
2600 let final_inner = if inner_type
2601 .as_ref()
2602 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8))
2603 {
2604 datafusion::logical_expr::cast(coerced_inner, DataType::Boolean)
2605 } else if inner_type
2606 .as_ref()
2607 .is_some_and(|t| matches!(t, DataType::LargeBinary))
2608 {
2609 dummy_udf_expr("_cv_to_bool", vec![coerced_inner])
2610 } else {
2611 coerced_inner
2612 };
2613 Ok(DfExpr::Not(Box::new(final_inner)))
2614 }
2615 DfExpr::IsNull(inner) => {
2616 let coerced_inner = apply_type_coercion(inner, schema)?;
2617 Ok(coerced_inner.is_null())
2618 }
2619 DfExpr::IsNotNull(inner) => {
2620 let coerced_inner = apply_type_coercion(inner, schema)?;
2621 Ok(coerced_inner.is_not_null())
2622 }
2623 DfExpr::Negative(inner) => {
2624 let coerced_inner = apply_type_coercion(inner, schema)?;
2625 let inner_type = coerced_inner.get_type(schema).ok();
2626 if matches!(inner_type.as_ref(), Some(DataType::LargeBinary)) {
2627 Ok(dummy_udf_expr(
2628 "_cypher_mul",
2629 vec![coerced_inner, lit(ScalarValue::Int64(Some(-1)))],
2630 ))
2631 } else {
2632 Ok(DfExpr::Negative(Box::new(coerced_inner)))
2633 }
2634 }
2635 DfExpr::Cast(cast) => {
2636 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2637 Ok(DfExpr::Cast(datafusion::logical_expr::Cast::new(
2638 Box::new(coerced_inner),
2639 cast.data_type.clone(),
2640 )))
2641 }
2642 DfExpr::TryCast(cast) => {
2643 let coerced_inner = apply_type_coercion(&cast.expr, schema)?;
2644 Ok(DfExpr::TryCast(datafusion::logical_expr::TryCast::new(
2645 Box::new(coerced_inner),
2646 cast.data_type.clone(),
2647 )))
2648 }
2649 DfExpr::Alias(alias) => {
2650 let coerced_inner = apply_type_coercion(&alias.expr, schema)?;
2651 Ok(coerced_inner.alias(alias.name.clone()))
2652 }
2653 DfExpr::AggregateFunction(agg) => coerce_aggregate_function(agg, schema),
2654 _ => Ok(expr.clone()),
2655 }
2656}
2657
2658fn coerce_logical_operands(
2660 left: DfExpr,
2661 right: DfExpr,
2662 op: datafusion::logical_expr::Operator,
2663 schema: &datafusion::common::DFSchema,
2664) -> Option<DfExpr> {
2665 use datafusion::arrow::datatypes::DataType;
2666 use datafusion::logical_expr::ExprSchemable;
2667
2668 if !matches!(
2669 op,
2670 datafusion::logical_expr::Operator::And | datafusion::logical_expr::Operator::Or
2671 ) {
2672 return None;
2673 }
2674 let left_type = left.get_type(schema).ok();
2675 let right_type = right.get_type(schema).ok();
2676 let left_needs_cast = left_type
2677 .as_ref()
2678 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2679 let right_needs_cast = right_type
2680 .as_ref()
2681 .is_some_and(|t| t.is_null() || matches!(t, DataType::Utf8 | DataType::LargeUtf8));
2682 let left_is_lb = left_type
2683 .as_ref()
2684 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2685 let right_is_lb = right_type
2686 .as_ref()
2687 .is_some_and(|t| matches!(t, DataType::LargeBinary));
2688 if !(left_needs_cast || right_needs_cast || left_is_lb || right_is_lb) {
2689 return None;
2690 }
2691 let coerced_left = if left_is_lb {
2692 dummy_udf_expr("_cv_to_bool", vec![left])
2693 } else if left_needs_cast {
2694 datafusion::logical_expr::cast(left, DataType::Boolean)
2695 } else {
2696 left
2697 };
2698 let coerced_right = if right_is_lb {
2699 dummy_udf_expr("_cv_to_bool", vec![right])
2700 } else if right_needs_cast {
2701 datafusion::logical_expr::cast(right, DataType::Boolean)
2702 } else {
2703 right
2704 };
2705 Some(binary_expr(coerced_left, op, coerced_right))
2706}
2707
2708#[expect(
2711 clippy::too_many_arguments,
2712 reason = "Binary coercion needs all context"
2713)]
2714fn coerce_large_binary_ops(
2715 left: &DfExpr,
2716 right: &DfExpr,
2717 left_type: &datafusion::arrow::datatypes::DataType,
2718 right_type: &datafusion::arrow::datatypes::DataType,
2719 left_is_null: bool,
2720 op: datafusion::logical_expr::Operator,
2721 is_comparison: bool,
2722 is_arithmetic: bool,
2723) -> Option<Result<DfExpr>> {
2724 use datafusion::arrow::datatypes::DataType;
2725 use datafusion::logical_expr::Operator;
2726
2727 let left_is_lb = matches!(left_type, DataType::LargeBinary) || left_is_null;
2728 let right_is_lb = matches!(right_type, DataType::LargeBinary) || (right_type.is_null());
2729
2730 if op == Operator::Plus {
2731 if left_is_lb && right_is_lb {
2732 return Some(Ok(dummy_udf_expr(
2733 "_cypher_add",
2734 vec![left.clone(), right.clone()],
2735 )));
2736 }
2737 let left_is_native_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
2738 let right_is_native_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
2739 if left_is_native_list && right_is_native_list {
2740 return Some(Ok(dummy_udf_expr(
2741 "_cypher_list_concat",
2742 vec![left.clone(), right.clone()],
2743 )));
2744 }
2745 if left_is_native_list || right_is_native_list {
2746 return Some(Ok(dummy_udf_expr(
2747 "_cypher_list_append",
2748 vec![left.clone(), right.clone()],
2749 )));
2750 }
2751 }
2752
2753 if (left_is_lb || right_is_lb) && is_comparison {
2754 if let Some(udf_name) = comparison_udf_name(op) {
2755 return Some(Ok(dummy_udf_expr(
2756 udf_name,
2757 vec![left.clone(), right.clone()],
2758 )));
2759 }
2760 return Some(Ok(binary_expr(left.clone(), op, right.clone())));
2761 }
2762
2763 if (left_is_lb || right_is_lb) && is_arithmetic {
2764 let udf_name =
2765 arithmetic_udf_name(op).expect("is_arithmetic guarantees a valid arithmetic operator");
2766 return Some(Ok(dummy_udf_expr(
2767 udf_name,
2768 vec![left.clone(), right.clone()],
2769 )));
2770 }
2771
2772 None
2773}
2774
2775fn coerce_temporal_comparisons(
2777 left: DfExpr,
2778 right: DfExpr,
2779 left_type: &datafusion::arrow::datatypes::DataType,
2780 right_type: &datafusion::arrow::datatypes::DataType,
2781 op: datafusion::logical_expr::Operator,
2782 is_comparison: bool,
2783) -> Option<DfExpr> {
2784 use datafusion::arrow::datatypes::{DataType, TimeUnit};
2785 use datafusion::logical_expr::Operator;
2786
2787 if !is_comparison {
2788 return None;
2789 }
2790
2791 if uni_common::core::schema::is_datetime_struct(left_type)
2793 && uni_common::core::schema::is_datetime_struct(right_type)
2794 {
2795 return Some(binary_expr(
2796 extract_datetime_nanos(left),
2797 op,
2798 extract_datetime_nanos(right),
2799 ));
2800 }
2801
2802 if uni_common::core::schema::is_time_struct(left_type)
2804 && uni_common::core::schema::is_time_struct(right_type)
2805 {
2806 return Some(binary_expr(
2807 extract_time_nanos(left),
2808 op,
2809 extract_time_nanos(right),
2810 ));
2811 }
2812
2813 let left_is_ts = matches!(left_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2815 let right_is_ts = matches!(right_type, DataType::Timestamp(TimeUnit::Nanosecond, _));
2816
2817 if (left_is_ts && uni_common::core::schema::is_datetime_struct(right_type))
2818 || (uni_common::core::schema::is_datetime_struct(left_type) && right_is_ts)
2819 {
2820 let left_nanos = if uni_common::core::schema::is_datetime_struct(left_type) {
2821 extract_datetime_nanos(left)
2822 } else {
2823 left
2824 };
2825 let right_nanos = if uni_common::core::schema::is_datetime_struct(right_type) {
2826 extract_datetime_nanos(right)
2827 } else {
2828 right
2829 };
2830 let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
2831 return Some(binary_expr(
2832 cast_expr(left_nanos, ts_type.clone()),
2833 op,
2834 cast_expr(right_nanos, ts_type),
2835 ));
2836 }
2837
2838 let left_is_duration = matches!(left_type, DataType::Interval(_));
2842 let right_is_duration = matches!(right_type, DataType::Interval(_));
2843 let left_is_temporal_like = uni_common::core::schema::is_datetime_struct(left_type)
2844 || uni_common::core::schema::is_time_struct(left_type)
2845 || matches!(
2846 left_type,
2847 DataType::Timestamp(_, _)
2848 | DataType::Date32
2849 | DataType::Date64
2850 | DataType::Time32(_)
2851 | DataType::Time64(_)
2852 );
2853 let right_is_temporal_like = uni_common::core::schema::is_datetime_struct(right_type)
2854 || uni_common::core::schema::is_time_struct(right_type)
2855 || matches!(
2856 right_type,
2857 DataType::Timestamp(_, _)
2858 | DataType::Date32
2859 | DataType::Date64
2860 | DataType::Time32(_)
2861 | DataType::Time64(_)
2862 );
2863
2864 if (left_is_duration && right_is_temporal_like) || (right_is_duration && left_is_temporal_like)
2865 {
2866 return Some(match op {
2867 Operator::Eq => lit(false),
2868 Operator::NotEq => lit(true),
2869 _ => lit(ScalarValue::Boolean(None)),
2870 });
2871 }
2872
2873 None
2874}
2875
2876fn coerce_mismatched_types(
2879 left: DfExpr,
2880 right: DfExpr,
2881 left_type: &datafusion::arrow::datatypes::DataType,
2882 right_type: &datafusion::arrow::datatypes::DataType,
2883 op: datafusion::logical_expr::Operator,
2884 is_comparison: bool,
2885) -> Option<Result<DfExpr>> {
2886 use datafusion::arrow::datatypes::DataType;
2887 use datafusion::logical_expr::Operator;
2888
2889 if left_type == right_type {
2890 return None;
2891 }
2892
2893 if left_type.is_numeric() && right_type.is_numeric() {
2895 if left_type == &DataType::Int64
2896 && right_type == &DataType::UInt64
2897 && matches!(&left, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2898 {
2899 let coerced_left = datafusion::logical_expr::cast(left, DataType::UInt64);
2900 return Some(Ok(binary_expr(coerced_left, op, right)));
2901 }
2902 if left_type == &DataType::UInt64
2903 && right_type == &DataType::Int64
2904 && matches!(&right, DfExpr::Literal(ScalarValue::Int64(Some(v)), _) if *v >= 0)
2905 {
2906 let coerced_right = datafusion::logical_expr::cast(right, DataType::UInt64);
2907 return Some(Ok(binary_expr(left, op, coerced_right)));
2908 }
2909 let target = wider_numeric_type(left_type, right_type);
2910 let coerced_left = if *left_type != target {
2911 datafusion::logical_expr::cast(left, target.clone())
2912 } else {
2913 left
2914 };
2915 let coerced_right = if *right_type != target {
2916 datafusion::logical_expr::cast(right, target)
2917 } else {
2918 right
2919 };
2920 return Some(Ok(binary_expr(coerced_left, op, coerced_right)));
2921 }
2922
2923 if is_comparison {
2925 match (left_type, right_type) {
2926 (ts @ DataType::Timestamp(..), DataType::Utf8 | DataType::LargeUtf8) => {
2927 let right = normalize_datetime_literal(right);
2928 return Some(Ok(binary_expr(
2929 left,
2930 op,
2931 datafusion::logical_expr::cast(right, ts.clone()),
2932 )));
2933 }
2934 (DataType::Utf8 | DataType::LargeUtf8, ts @ DataType::Timestamp(..)) => {
2935 let left = normalize_datetime_literal(left);
2936 return Some(Ok(binary_expr(
2937 datafusion::logical_expr::cast(left, ts.clone()),
2938 op,
2939 right,
2940 )));
2941 }
2942 _ => {}
2943 }
2944 }
2945
2946 if is_comparison
2948 && let (DataType::List(l_field), DataType::List(r_field)) = (left_type, right_type)
2949 {
2950 let l_inner = l_field.data_type();
2951 let r_inner = r_field.data_type();
2952 if l_inner.is_numeric() && r_inner.is_numeric() && l_inner != r_inner {
2953 let target_inner = wider_numeric_type(l_inner, r_inner);
2954 let target_type = DataType::List(Arc::new(datafusion::arrow::datatypes::Field::new(
2955 "item",
2956 target_inner,
2957 true,
2958 )));
2959 return Some(Ok(binary_expr(
2960 datafusion::logical_expr::cast(left, target_type.clone()),
2961 op,
2962 datafusion::logical_expr::cast(right, target_type),
2963 )));
2964 }
2965 }
2966
2967 if is_primitive_type(left_type) && is_primitive_type(right_type) {
2969 if op == Operator::Plus {
2970 return Some(crate::cypher_type_coerce::build_cypher_plus(
2971 left, left_type, right, right_type,
2972 ));
2973 }
2974 if is_comparison {
2975 return Some(Ok(crate::cypher_type_coerce::build_cypher_comparison(
2976 left, left_type, right, right_type, op,
2977 )));
2978 }
2979 }
2980
2981 None
2982}
2983
2984fn coerce_list_comparisons(
2986 left: DfExpr,
2987 right: DfExpr,
2988 left_type: &datafusion::arrow::datatypes::DataType,
2989 right_type: &datafusion::arrow::datatypes::DataType,
2990 op: datafusion::logical_expr::Operator,
2991 is_comparison: bool,
2992) -> Option<DfExpr> {
2993 use datafusion::arrow::datatypes::DataType;
2994 use datafusion::logical_expr::Operator;
2995
2996 if !is_comparison {
2997 return None;
2998 }
2999
3000 let left_is_list = matches!(left_type, DataType::List(_) | DataType::LargeList(_));
3001 let right_is_list = matches!(right_type, DataType::List(_) | DataType::LargeList(_));
3002
3003 if left_is_list
3005 && right_is_list
3006 && matches!(
3007 op,
3008 Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
3009 )
3010 {
3011 let op_str = match op {
3012 Operator::Lt => "lt",
3013 Operator::LtEq => "lteq",
3014 Operator::Gt => "gt",
3015 Operator::GtEq => "gteq",
3016 _ => unreachable!(),
3017 };
3018 return Some(dummy_udf_expr(
3019 "_cypher_list_compare",
3020 vec![left, right, lit(op_str)],
3021 ));
3022 }
3023
3024 if left_is_list && right_is_list && matches!(op, Operator::Eq | Operator::NotEq) {
3026 let udf_name =
3027 comparison_udf_name(op).expect("Eq|NotEq is always a valid comparison operator");
3028 return Some(dummy_udf_expr(udf_name, vec![left, right]));
3029 }
3030
3031 if (left_is_list != right_is_list)
3033 && !matches!(left_type, DataType::Null)
3034 && !matches!(right_type, DataType::Null)
3035 {
3036 return Some(match op {
3037 Operator::Eq => lit(false),
3038 Operator::NotEq => lit(true),
3039 _ => lit(ScalarValue::Boolean(None)),
3040 });
3041 }
3042
3043 None
3044}
3045
/// Coerces a binary expression so both operands are type-compatible.
///
/// Both operands are coerced recursively first. The expression is then run
/// through a fixed sequence of rules; the first rule that applies produces the
/// result, so the order of the checks below matters. If no rule applies, the
/// expression is rebuilt from the coerced operands with the original operator.
///
/// # Errors
/// Propagates errors from coercing either operand and from the helper
/// builders that return `Result`.
///
/// # Panics
/// Panics if `comparison_udf_name` / `arithmetic_udf_name` return `None` for
/// an operator that the `is_comparison` / `is_arithmetic` checks accepted.
fn coerce_binary_expr(
    binary: &datafusion::logical_expr::expr::BinaryExpr,
    schema: &datafusion::common::DFSchema,
) -> Result<DfExpr> {
    use datafusion::arrow::datatypes::DataType;
    use datafusion::logical_expr::ExprSchemable;
    use datafusion::logical_expr::Operator;

    // Coerce the children before looking at this node.
    let left = apply_type_coercion(&binary.left, schema)?;
    let right = apply_type_coercion(&binary.right, schema)?;

    let is_comparison = matches!(
        binary.op,
        Operator::Eq
            | Operator::NotEq
            | Operator::Lt
            | Operator::LtEq
            | Operator::Gt
            | Operator::GtEq
    );
    let is_arithmetic = matches!(
        binary.op,
        Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide | Operator::Modulo
    );

    // AND / OR: make non-boolean operands boolean. Returns None for every
    // other operator.
    if let Some(result) = coerce_logical_operands(left.clone(), right.clone(), binary.op, schema) {
        return Ok(result);
    }

    if is_comparison || is_arithmetic {
        // Resolve operand types. If DataFusion cannot type the operand, fall
        // back to a lookup by bare column name; if that also fails, log a
        // warning and return the expression without further coercion.
        let left_type = match left.get_type(schema) {
            Ok(t) => t,
            Err(e) => {
                if let Some(t) = resolve_column_type_fallback(&left, schema) {
                    t
                } else {
                    log::warn!("Failed to get left type in binary expr: {}", e);
                    return Ok(binary_expr(left, binary.op, right));
                }
            }
        };
        let right_type = match right.get_type(schema) {
            Ok(t) => t,
            Err(e) => {
                if let Some(t) = resolve_column_type_fallback(&right, schema) {
                    t
                } else {
                    log::warn!("Failed to get right type in binary expr: {}", e);
                    return Ok(binary_expr(left, binary.op, right));
                }
            }
        };

        // Null-typed operands: null <op> null becomes a boolean NULL literal
        // (for arithmetic operators as well as comparisons).
        let left_is_null = left_type.is_null();
        let right_is_null = right_type.is_null();
        if left_is_null && right_is_null {
            return Ok(lit(ScalarValue::Boolean(None)));
        }
        // Exactly one null-typed side: cast it to the other side's type,
        // unless that type is LargeBinary, which is handled by the
        // large-binary rules below.
        if left_is_null || right_is_null {
            let target = if left_is_null {
                &right_type
            } else {
                &left_type
            };
            if !matches!(target, DataType::LargeBinary) {
                let coerced_left = if left_is_null {
                    datafusion::logical_expr::cast(left, target.clone())
                } else {
                    left
                };
                let coerced_right = if right_is_null {
                    datafusion::logical_expr::cast(right, target.clone())
                } else {
                    right
                };
                return Ok(binary_expr(coerced_left, binary.op, coerced_right));
            }
        }

        // LargeBinary operands and `+` on native lists are routed to UDFs.
        if let Some(result) = coerce_large_binary_ops(
            &left,
            &right,
            &left_type,
            &right_type,
            left_is_null,
            binary.op,
            is_comparison,
            is_arithmetic,
        ) {
            return result;
        }

        // Temporal comparisons (datetime/time structs, timestamps, intervals).
        if let Some(result) = coerce_temporal_comparisons(
            left.clone(),
            right.clone(),
            &left_type,
            &right_type,
            binary.op,
            is_comparison,
        ) {
            return Ok(result);
        }

        // Struct comparisons: at least one side is a Struct and BOTH sides are
        // LargeBinary-or-Struct. NOTE(review): despite its name,
        // `either_lb_or_struct` is a conjunction over both operands.
        let either_struct =
            matches!(left_type, DataType::Struct(_)) || matches!(right_type, DataType::Struct(_));
        let either_lb_or_struct = (matches!(left_type, DataType::LargeBinary)
            || matches!(left_type, DataType::Struct(_)))
            && (matches!(right_type, DataType::LargeBinary)
                || matches!(right_type, DataType::Struct(_)));
        if is_comparison && either_struct && either_lb_or_struct {
            if let Some(udf_name) = comparison_udf_name(binary.op) {
                return Ok(dummy_udf_expr(udf_name, vec![left, right]));
            }
            return Ok(lit(ScalarValue::Boolean(None)));
        }

        // Comparisons whose operands contain a division are sent through the
        // comparison UDF instead of the native operator.
        if is_comparison && (contains_division(&left) || contains_division(&right)) {
            let udf_name = comparison_udf_name(binary.op)
                .expect("is_comparison guarantees a valid comparison operator");
            return Ok(dummy_udf_expr(udf_name, vec![left, right]));
        }

        // `+` with a string operand and primitive types on both sides is
        // built by `build_cypher_plus`, even when both types are equal.
        if binary.op == Operator::Plus
            && (crate::cypher_type_coerce::is_string_type(&left_type)
                || crate::cypher_type_coerce::is_string_type(&right_type))
            && is_primitive_type(&left_type)
            && is_primitive_type(&right_type)
        {
            return crate::cypher_type_coerce::build_cypher_plus(
                left,
                &left_type,
                right,
                &right_type,
            );
        }

        // Remaining type mismatches (numeric widening, timestamp vs string,
        // numeric list element types, other primitive combinations).
        if let Some(result) = coerce_mismatched_types(
            left.clone(),
            right.clone(),
            &left_type,
            &right_type,
            binary.op,
            is_comparison,
        ) {
            return result;
        }

        // List comparisons.
        if let Some(result) = coerce_list_comparisons(
            left.clone(),
            right.clone(),
            &left_type,
            &right_type,
            binary.op,
            is_comparison,
        ) {
            return Ok(result);
        }
    }

    // No rule applied: rebuild the expression from the coerced operands.
    Ok(binary_expr(left, binary.op, right))
}
3217
3218fn coerce_scalar_function(
3220 func: &datafusion::logical_expr::expr::ScalarFunction,
3221 schema: &datafusion::common::DFSchema,
3222) -> Result<DfExpr> {
3223 use datafusion::arrow::datatypes::DataType;
3224 use datafusion::logical_expr::ExprSchemable;
3225
3226 let coerced_args: Vec<DfExpr> = func
3227 .args
3228 .iter()
3229 .map(|a| apply_type_coercion(a, schema))
3230 .collect::<Result<Vec<_>>>()?;
3231
3232 if func.func.name().eq_ignore_ascii_case("coalesce") && coerced_args.len() > 1 {
3233 let types: Vec<_> = coerced_args
3234 .iter()
3235 .filter_map(|a| a.get_type(schema).ok())
3236 .collect();
3237 let has_mixed_types = types.windows(2).any(|w| w[0] != w[1]);
3238 if has_mixed_types {
3239 let all_string_like = types
3243 .iter()
3244 .all(|t| matches!(t, DataType::Utf8 | DataType::LargeUtf8 | DataType::Null));
3245 let unified_args: Vec<DfExpr> = if all_string_like {
3246 coerced_args
3247 .into_iter()
3248 .map(|a| datafusion::logical_expr::cast(a, DataType::Utf8))
3249 .collect()
3250 } else {
3251 coerced_args
3253 .into_iter()
3254 .zip(types.iter())
3255 .map(|(arg, t)| match t {
3256 DataType::LargeBinary | DataType::Null => arg,
3257 DataType::List(_) | DataType::LargeList(_) => {
3258 list_to_large_binary_expr(arg)
3259 }
3260 _ => scalar_to_large_binary_expr(arg),
3261 })
3262 .collect()
3263 };
3264 return Ok(DfExpr::ScalarFunction(
3265 datafusion::logical_expr::expr::ScalarFunction {
3266 func: func.func.clone(),
3267 args: unified_args,
3268 },
3269 ));
3270 }
3271 }
3272
3273 Ok(DfExpr::ScalarFunction(
3274 datafusion::logical_expr::expr::ScalarFunction {
3275 func: func.func.clone(),
3276 args: coerced_args,
3277 },
3278 ))
3279}
3280
3281fn coerce_case_expr(
3284 case: &datafusion::logical_expr::expr::Case,
3285 schema: &datafusion::common::DFSchema,
3286) -> Result<DfExpr> {
3287 use datafusion::arrow::datatypes::DataType;
3288 use datafusion::logical_expr::ExprSchemable;
3289
3290 let coerced_operand = case
3291 .expr
3292 .as_ref()
3293 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3294 .transpose()?;
3295 let coerced_when_then = case
3296 .when_then_expr
3297 .iter()
3298 .map(|(w, t)| {
3299 let cw = apply_type_coercion(w, schema)?;
3300 let cw = match cw.get_type(schema).ok() {
3301 Some(DataType::LargeBinary) => dummy_udf_expr("_cv_to_bool", vec![cw]),
3302 _ => cw,
3303 };
3304 let ct = apply_type_coercion(t, schema)?;
3305 Ok((Box::new(cw), Box::new(ct)))
3306 })
3307 .collect::<Result<Vec<_>>>()?;
3308 let coerced_else = case
3309 .else_expr
3310 .as_ref()
3311 .map(|e| apply_type_coercion(e, schema).map(Box::new))
3312 .transpose()?;
3313
3314 let mut result_case = if let Some(operand) = coerced_operand {
3315 crate::cypher_type_coerce::rewrite_simple_case_to_generic(
3316 *operand,
3317 coerced_when_then,
3318 coerced_else,
3319 schema,
3320 )?
3321 } else {
3322 datafusion::logical_expr::expr::Case {
3323 expr: None,
3324 when_then_expr: coerced_when_then,
3325 else_expr: coerced_else,
3326 }
3327 };
3328
3329 crate::cypher_type_coerce::coerce_case_results(&mut result_case, schema)?;
3330
3331 Ok(DfExpr::Case(result_case))
3332}
3333
3334fn coerce_aggregate_function(
3336 agg: &datafusion::logical_expr::expr::AggregateFunction,
3337 schema: &datafusion::common::DFSchema,
3338) -> Result<DfExpr> {
3339 let coerced_args: Vec<DfExpr> = agg
3340 .params
3341 .args
3342 .iter()
3343 .map(|a| apply_type_coercion(a, schema))
3344 .collect::<Result<Vec<_>>>()?;
3345 let coerced_order_by: Vec<datafusion::logical_expr::SortExpr> = agg
3346 .params
3347 .order_by
3348 .iter()
3349 .map(|s| {
3350 let coerced_expr = apply_type_coercion(&s.expr, schema)?;
3351 Ok(datafusion::logical_expr::SortExpr {
3352 expr: coerced_expr,
3353 asc: s.asc,
3354 nulls_first: s.nulls_first,
3355 })
3356 })
3357 .collect::<Result<Vec<_>>>()?;
3358 let coerced_filter = agg
3359 .params
3360 .filter
3361 .as_ref()
3362 .map(|f| apply_type_coercion(f, schema).map(Box::new))
3363 .transpose()?;
3364 Ok(DfExpr::AggregateFunction(
3365 datafusion::logical_expr::expr::AggregateFunction {
3366 func: agg.func.clone(),
3367 params: datafusion::logical_expr::expr::AggregateFunctionParams {
3368 args: coerced_args,
3369 distinct: agg.params.distinct,
3370 filter: coerced_filter,
3371 order_by: coerced_order_by,
3372 null_treatment: agg.params.null_treatment,
3373 },
3374 },
3375 ))
3376}
3377
3378#[cfg(test)]
3379mod tests {
3380 use super::*;
3381 use arrow_array::{
3382 Array, Int32Array, StringArray, Time64NanosecondArray, TimestampNanosecondArray,
3383 };
3384 use uni_common::TemporalValue;
3385 #[test]
3386 fn test_literal_translation() {
3387 let expr = Expr::Literal(CypherLiteral::Integer(42));
3388 let result = cypher_expr_to_df(&expr, None).unwrap();
3389 let s = format!("{:?}", result);
3390 assert!(s.contains("Literal"));
3392 assert!(s.contains("Int64(42)"));
3393 }
3394
3395 #[test]
3396 fn test_property_access_no_context_uses_index() {
3397 let expr = Expr::Property(Box::new(Expr::Variable("n".to_string())), "age".to_string());
3399 let result = cypher_expr_to_df(&expr, None).unwrap();
3400 let s = format!("{}", result);
3401 assert!(
3402 s.contains("index"),
3403 "expected index UDF for non-graph variable, got: {s}"
3404 );
3405 }
3406
3407 #[test]
3408 fn test_comparison_operator() {
3409 let expr = Expr::BinaryOp {
3410 left: Box::new(Expr::Property(
3411 Box::new(Expr::Variable("n".to_string())),
3412 "age".to_string(),
3413 )),
3414 op: BinaryOp::Gt,
3415 right: Box::new(Expr::Literal(CypherLiteral::Integer(30))),
3416 };
3417 let result = cypher_expr_to_df(&expr, None).unwrap();
3418 let s = format!("{:?}", result);
3420 assert!(s.contains("age"));
3421 assert!(s.contains("30"));
3422 }
3423
3424 #[test]
3425 fn test_boolean_operators() {
3426 let expr = Expr::BinaryOp {
3427 left: Box::new(Expr::BinaryOp {
3428 left: Box::new(Expr::Property(
3429 Box::new(Expr::Variable("n".to_string())),
3430 "age".to_string(),
3431 )),
3432 op: BinaryOp::Gt,
3433 right: Box::new(Expr::Literal(CypherLiteral::Integer(18))),
3434 }),
3435 op: BinaryOp::And,
3436 right: Box::new(Expr::BinaryOp {
3437 left: Box::new(Expr::Property(
3438 Box::new(Expr::Variable("n".to_string())),
3439 "active".to_string(),
3440 )),
3441 op: BinaryOp::Eq,
3442 right: Box::new(Expr::Literal(CypherLiteral::Bool(true))),
3443 }),
3444 };
3445 let result = cypher_expr_to_df(&expr, None).unwrap();
3446 let s = format!("{:?}", result);
3447 assert!(s.contains("And"));
3448 }
3449
3450 #[test]
3451 fn test_is_null() {
3452 let expr = Expr::IsNull(Box::new(Expr::Property(
3453 Box::new(Expr::Variable("n".to_string())),
3454 "email".to_string(),
3455 )));
3456 let result = cypher_expr_to_df(&expr, None).unwrap();
3457 let s = format!("{:?}", result);
3458 assert!(s.contains("IsNull"));
3459 }
3460
3461 #[test]
3462 fn test_collect_properties() {
3463 let expr = Expr::BinaryOp {
3464 left: Box::new(Expr::Property(
3465 Box::new(Expr::Variable("n".to_string())),
3466 "name".to_string(),
3467 )),
3468 op: BinaryOp::Eq,
3469 right: Box::new(Expr::Property(
3470 Box::new(Expr::Variable("m".to_string())),
3471 "name".to_string(),
3472 )),
3473 };
3474
3475 let props = collect_properties(&expr);
3476 assert_eq!(props.len(), 2);
3477 assert!(props.contains(&("m".to_string(), "name".to_string())));
3478 assert!(props.contains(&("n".to_string(), "name".to_string())));
3479 }
3480
3481 #[test]
3482 fn test_function_call() {
3483 let expr = Expr::FunctionCall {
3484 name: "count".to_string(),
3485 args: vec![Expr::Wildcard],
3486 distinct: false,
3487 window_spec: None,
3488 };
3489 let result = cypher_expr_to_df(&expr, None).unwrap();
3490 let s = format!("{:?}", result);
3491 assert!(s.to_lowercase().contains("count"));
3492 }
3493
3494 use datafusion::arrow::datatypes::{DataType, Field, Schema};
3499 use datafusion::logical_expr::Operator;
3500
3501 fn make_schema(cols: &[(&str, DataType)]) -> datafusion::common::DFSchema {
3503 let fields: Vec<_> = cols
3504 .iter()
3505 .map(|(name, dt)| Arc::new(Field::new(*name, dt.clone(), true)))
3506 .collect();
3507 let schema = Schema::new(fields);
3508 datafusion::common::DFSchema::try_from(schema).unwrap()
3509 }
3510
3511 fn contains_udf(expr: &DfExpr, name: &str) -> bool {
3513 let s = format!("{}", expr);
3514 s.contains(name)
3515 }
3516
3517 fn is_binary_op(expr: &DfExpr, expected_op: Operator) -> bool {
3519 matches!(expr, DfExpr::BinaryExpr(b) if b.op == expected_op)
3520 }
3521
3522 #[test]
3523 fn test_coercion_lb_eq_int64() {
3524 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3525 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3526 Box::new(col("lb")),
3527 Operator::Eq,
3528 Box::new(col("i")),
3529 ));
3530 let result = apply_type_coercion(&expr, &schema).unwrap();
3531 assert!(
3533 contains_udf(&result, "_cypher_equal"),
3534 "expected _cypher_equal, got: {result}"
3535 );
3536 }
3537
3538 #[test]
3539 fn test_coercion_lb_noteq_int64() {
3540 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3541 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3542 Box::new(col("lb")),
3543 Operator::NotEq,
3544 Box::new(col("i")),
3545 ));
3546 let result = apply_type_coercion(&expr, &schema).unwrap();
3547 assert!(contains_udf(&result, "_cypher_not_equal"));
3549 }
3550
3551 #[test]
3552 fn test_coercion_lb_lt_int64() {
3553 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3554 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3555 Box::new(col("lb")),
3556 Operator::Lt,
3557 Box::new(col("i")),
3558 ));
3559 let result = apply_type_coercion(&expr, &schema).unwrap();
3560 assert!(contains_udf(&result, "_cypher_lt"));
3562 }
3563
3564 #[test]
3565 fn test_coercion_lb_eq_float64() {
3566 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3567 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3568 Box::new(col("lb")),
3569 Operator::Eq,
3570 Box::new(col("f")),
3571 ));
3572 let result = apply_type_coercion(&expr, &schema).unwrap();
3573 assert!(contains_udf(&result, "_cypher_equal"));
3575 }
3576
3577 #[test]
3578 fn test_coercion_lb_eq_utf8() {
3579 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3580 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3581 Box::new(col("lb")),
3582 Operator::Eq,
3583 Box::new(col("s")),
3584 ));
3585 let result = apply_type_coercion(&expr, &schema).unwrap();
3586 assert!(contains_udf(&result, "_cypher_equal"));
3588 }
3589
3590 #[test]
3591 fn test_coercion_lb_eq_bool() {
3592 let schema = make_schema(&[("lb", DataType::LargeBinary), ("b", DataType::Boolean)]);
3593 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3594 Box::new(col("lb")),
3595 Operator::Eq,
3596 Box::new(col("b")),
3597 ));
3598 let result = apply_type_coercion(&expr, &schema).unwrap();
3599 assert!(contains_udf(&result, "_cypher_equal"));
3601 }
3602
3603 #[test]
3604 fn test_coercion_int64_eq_lb() {
3605 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3607 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3608 Box::new(col("i")),
3609 Operator::Eq,
3610 Box::new(col("lb")),
3611 ));
3612 let result = apply_type_coercion(&expr, &schema).unwrap();
3613 assert!(contains_udf(&result, "_cypher_equal"));
3615 }
3616
3617 #[test]
3618 fn test_coercion_float64_gt_lb() {
3619 let schema = make_schema(&[("f", DataType::Float64), ("lb", DataType::LargeBinary)]);
3620 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3621 Box::new(col("f")),
3622 Operator::Gt,
3623 Box::new(col("lb")),
3624 ));
3625 let result = apply_type_coercion(&expr, &schema).unwrap();
3626 assert!(contains_udf(&result, "_cypher_gt"));
3628 }
3629
3630 #[test]
3631 fn test_coercion_both_lb_eq() {
3632 let schema = make_schema(&[
3633 ("lb1", DataType::LargeBinary),
3634 ("lb2", DataType::LargeBinary),
3635 ]);
3636 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3637 Box::new(col("lb1")),
3638 Operator::Eq,
3639 Box::new(col("lb2")),
3640 ));
3641 let result = apply_type_coercion(&expr, &schema).unwrap();
3642 assert!(contains_udf(&result, "_cypher_equal"));
3643 }
3644
3645 #[test]
3646 fn test_coercion_both_lb_lt() {
3647 let schema = make_schema(&[
3648 ("lb1", DataType::LargeBinary),
3649 ("lb2", DataType::LargeBinary),
3650 ]);
3651 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3652 Box::new(col("lb1")),
3653 Operator::Lt,
3654 Box::new(col("lb2")),
3655 ));
3656 let result = apply_type_coercion(&expr, &schema).unwrap();
3657 assert!(contains_udf(&result, "_cypher_lt"));
3658 }
3659
3660 #[test]
3661 fn test_coercion_both_lb_noteq() {
3662 let schema = make_schema(&[
3663 ("lb1", DataType::LargeBinary),
3664 ("lb2", DataType::LargeBinary),
3665 ]);
3666 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3667 Box::new(col("lb1")),
3668 Operator::NotEq,
3669 Box::new(col("lb2")),
3670 ));
3671 let result = apply_type_coercion(&expr, &schema).unwrap();
3672 assert!(contains_udf(&result, "_cypher_not_equal"));
3673 }
3674
3675 #[test]
3676 fn test_coercion_lb_plus_int64() {
3677 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3678 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3679 Box::new(col("lb")),
3680 Operator::Plus,
3681 Box::new(col("i")),
3682 ));
3683 let result = apply_type_coercion(&expr, &schema).unwrap();
3684 assert!(contains_udf(&result, "_cypher_add"));
3685 }
3686
3687 #[test]
3688 fn test_coercion_lb_minus_int64() {
3689 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3690 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3691 Box::new(col("lb")),
3692 Operator::Minus,
3693 Box::new(col("i")),
3694 ));
3695 let result = apply_type_coercion(&expr, &schema).unwrap();
3696 assert!(contains_udf(&result, "_cypher_sub"));
3697 }
3698
3699 #[test]
3700 fn test_coercion_lb_multiply_float64() {
3701 let schema = make_schema(&[("lb", DataType::LargeBinary), ("f", DataType::Float64)]);
3702 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3703 Box::new(col("lb")),
3704 Operator::Multiply,
3705 Box::new(col("f")),
3706 ));
3707 let result = apply_type_coercion(&expr, &schema).unwrap();
3708 assert!(contains_udf(&result, "_cypher_mul"));
3709 }
3710
3711 #[test]
3712 fn test_coercion_int64_plus_lb() {
3713 let schema = make_schema(&[("i", DataType::Int64), ("lb", DataType::LargeBinary)]);
3714 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3715 Box::new(col("i")),
3716 Operator::Plus,
3717 Box::new(col("lb")),
3718 ));
3719 let result = apply_type_coercion(&expr, &schema).unwrap();
3720 assert!(contains_udf(&result, "_cypher_add"));
3721 }
3722
3723 #[test]
3724 fn test_coercion_lb_plus_utf8() {
3725 let schema = make_schema(&[("lb", DataType::LargeBinary), ("s", DataType::Utf8)]);
3727 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3728 Box::new(col("lb")),
3729 Operator::Plus,
3730 Box::new(col("s")),
3731 ));
3732 let result = apply_type_coercion(&expr, &schema).unwrap();
3733 assert!(contains_udf(&result, "_cypher_add"));
3735 }
3736
3737 #[test]
3738 fn test_coercion_and_null_bool() {
3739 let schema = make_schema(&[("b", DataType::Boolean)]);
3740 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3742 Box::new(lit(ScalarValue::Null)),
3743 Operator::And,
3744 Box::new(col("b")),
3745 ));
3746 let result = apply_type_coercion(&expr, &schema).unwrap();
3747 let s = format!("{}", result);
3748 assert!(
3750 s.contains("CAST") || s.contains("Boolean"),
3751 "expected cast to Boolean, got: {s}"
3752 );
3753 assert!(is_binary_op(&result, Operator::And));
3754 }
3755
3756 #[test]
3757 fn test_coercion_bool_and_null() {
3758 let schema = make_schema(&[("b", DataType::Boolean)]);
3759 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3760 Box::new(col("b")),
3761 Operator::And,
3762 Box::new(lit(ScalarValue::Null)),
3763 ));
3764 let result = apply_type_coercion(&expr, &schema).unwrap();
3765 assert!(is_binary_op(&result, Operator::And));
3766 }
3767
3768 #[test]
3769 fn test_coercion_or_null_bool() {
3770 let schema = make_schema(&[("b", DataType::Boolean)]);
3771 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3772 Box::new(lit(ScalarValue::Null)),
3773 Operator::Or,
3774 Box::new(col("b")),
3775 ));
3776 let result = apply_type_coercion(&expr, &schema).unwrap();
3777 assert!(is_binary_op(&result, Operator::Or));
3778 }
3779
3780 #[test]
3781 fn test_coercion_null_and_null() {
3782 let schema = make_schema(&[]);
3783 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3784 Box::new(lit(ScalarValue::Null)),
3785 Operator::And,
3786 Box::new(lit(ScalarValue::Null)),
3787 ));
3788 let result = apply_type_coercion(&expr, &schema).unwrap();
3789 assert!(is_binary_op(&result, Operator::And));
3790 }
3791
3792 #[test]
3793 fn test_coercion_bool_and_bool_noop() {
3794 let schema = make_schema(&[("a", DataType::Boolean), ("b", DataType::Boolean)]);
3795 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3796 Box::new(col("a")),
3797 Operator::And,
3798 Box::new(col("b")),
3799 ));
3800 let result = apply_type_coercion(&expr, &schema).unwrap();
3801 assert!(is_binary_op(&result, Operator::And));
3803 let s = format!("{}", result);
3804 assert!(!s.contains("CAST"), "should not contain CAST: {s}");
3805 }
3806
3807 #[test]
3808 fn test_coercion_case_when_lb() {
3809 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3811 let when_cond = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3812 Box::new(col("lb")),
3813 Operator::Eq,
3814 Box::new(lit(42_i64)),
3815 ));
3816 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3817 expr: None,
3818 when_then_expr: vec![(Box::new(when_cond), Box::new(lit("a")))],
3819 else_expr: Some(Box::new(lit("b"))),
3820 });
3821 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3822 let s = format!("{}", result);
3823 assert!(
3825 s.contains("_cypher_equal"),
3826 "CASE WHEN should have _cypher_equal, got: {s}"
3827 );
3828 }
3829
3830 #[test]
3831 fn test_coercion_case_then_lb() {
3832 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3834 let then_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3835 Box::new(col("lb")),
3836 Operator::Plus,
3837 Box::new(lit(1_i64)),
3838 ));
3839 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3840 expr: None,
3841 when_then_expr: vec![(Box::new(lit(true)), Box::new(then_expr))],
3842 else_expr: Some(Box::new(lit(0_i64))),
3843 });
3844 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3845 let s = format!("{}", result);
3846 assert!(
3847 s.contains("_cypher_add"),
3848 "CASE THEN should have _cypher_add, got: {s}"
3849 );
3850 }
3851
3852 #[test]
3853 fn test_coercion_case_else_lb() {
3854 let schema = make_schema(&[("lb", DataType::LargeBinary)]);
3856 let else_expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3857 Box::new(col("lb")),
3858 Operator::Plus,
3859 Box::new(lit(2_i64)),
3860 ));
3861 let case_expr = DfExpr::Case(datafusion::logical_expr::expr::Case {
3862 expr: None,
3863 when_then_expr: vec![(Box::new(lit(true)), Box::new(lit(1_i64)))],
3864 else_expr: Some(Box::new(else_expr)),
3865 });
3866 let result = apply_type_coercion(&case_expr, &schema).unwrap();
3867 let s = format!("{}", result);
3868 assert!(
3869 s.contains("_cypher_add"),
3870 "CASE ELSE should have _cypher_add, got: {s}"
3871 );
3872 }
3873
3874 #[test]
3875 fn test_coercion_int64_eq_int64_noop() {
3876 let schema = make_schema(&[("a", DataType::Int64), ("b", DataType::Int64)]);
3877 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3878 Box::new(col("a")),
3879 Operator::Eq,
3880 Box::new(col("b")),
3881 ));
3882 let result = apply_type_coercion(&expr, &schema).unwrap();
3883 assert!(is_binary_op(&result, Operator::Eq));
3884 let s = format!("{}", result);
3885 assert!(
3886 !s.contains("_cypher_value"),
3887 "should not contain cypher_value decode: {s}"
3888 );
3889 }
3890
3891 #[test]
3892 fn test_coercion_both_lb_plus() {
3893 let schema = make_schema(&[
3895 ("lb1", DataType::LargeBinary),
3896 ("lb2", DataType::LargeBinary),
3897 ]);
3898 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3899 Box::new(col("lb1")),
3900 Operator::Plus,
3901 Box::new(col("lb2")),
3902 ));
3903 let result = apply_type_coercion(&expr, &schema).unwrap();
3904 assert!(
3905 contains_udf(&result, "_cypher_add"),
3906 "expected _cypher_add, got: {result}"
3907 );
3908 }
3909
3910 #[test]
3911 fn test_coercion_native_list_plus_scalar() {
3912 let schema = make_schema(&[
3914 (
3915 "lst",
3916 DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
3917 ),
3918 ("i", DataType::Int32),
3919 ]);
3920 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3921 Box::new(col("lst")),
3922 Operator::Plus,
3923 Box::new(col("i")),
3924 ));
3925 let result = apply_type_coercion(&expr, &schema).unwrap();
3926 assert!(
3927 contains_udf(&result, "_cypher_list_append"),
3928 "expected _cypher_list_append, got: {result}"
3929 );
3930 }
3931
3932 #[test]
3933 fn test_coercion_lb_plus_int64_unchanged() {
3934 let schema = make_schema(&[("lb", DataType::LargeBinary), ("i", DataType::Int64)]);
3936 let expr = DfExpr::BinaryExpr(datafusion::logical_expr::expr::BinaryExpr::new(
3937 Box::new(col("lb")),
3938 Operator::Plus,
3939 Box::new(col("i")),
3940 ));
3941 let result = apply_type_coercion(&expr, &schema).unwrap();
3942 assert!(
3943 contains_udf(&result, "_cypher_add"),
3944 "expected _cypher_add, got: {result}"
3945 );
3946 }
3947
3948 #[test]
3953 fn test_mixed_list_with_variables_compiles() {
3954 let expr = Expr::List(vec![
3956 Expr::Variable("n".to_string()),
3957 Expr::Literal(CypherLiteral::Integer(1)),
3958 Expr::Literal(CypherLiteral::String("hello".to_string())),
3959 ]);
3960 let result = cypher_expr_to_df(&expr, None).unwrap();
3961 let s = format!("{}", result);
3962 assert!(
3963 s.contains("_make_cypher_list"),
3964 "expected _make_cypher_list UDF call, got: {s}"
3965 );
3966 }
3967
3968 #[test]
3969 fn test_literal_only_mixed_list_uses_cv_fastpath() {
3970 let expr = Expr::List(vec![
3972 Expr::Literal(CypherLiteral::Integer(1)),
3973 Expr::Literal(CypherLiteral::String("hi".to_string())),
3974 Expr::Literal(CypherLiteral::Bool(true)),
3975 ]);
3976 let result = cypher_expr_to_df(&expr, None).unwrap();
3977 assert!(
3978 matches!(result, DfExpr::Literal(..)),
3979 "expected Literal (CypherValue fast path), got: {result}"
3980 );
3981 }
3982
3983 #[test]
3988 fn test_in_mixed_literal_list_uses_cypher_in() {
3989 let expr = Expr::In {
3991 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
3992 list: Box::new(Expr::List(vec![
3993 Expr::Literal(CypherLiteral::String("1".to_string())),
3994 Expr::Literal(CypherLiteral::Integer(2)),
3995 ])),
3996 };
3997 let result = cypher_expr_to_df(&expr, None).unwrap();
3998 let s = format!("{}", result);
3999 assert!(
4000 s.contains("_cypher_in"),
4001 "expected _cypher_in UDF for mixed-type IN list, got: {s}"
4002 );
4003 }
4004
4005 #[test]
4006 fn test_in_homogeneous_literal_list_uses_cypher_in() {
4007 let expr = Expr::In {
4009 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4010 list: Box::new(Expr::List(vec![
4011 Expr::Literal(CypherLiteral::Integer(2)),
4012 Expr::Literal(CypherLiteral::Integer(3)),
4013 ])),
4014 };
4015 let result = cypher_expr_to_df(&expr, None).unwrap();
4016 let s = format!("{}", result);
4017 assert!(
4018 s.contains("_cypher_in"),
4019 "expected _cypher_in UDF for homogeneous IN list, got: {s}"
4020 );
4021 }
4022
4023 #[test]
4024 fn test_in_list_with_variables_uses_make_cypher_list() {
4025 let expr = Expr::In {
4027 expr: Box::new(Expr::Literal(CypherLiteral::Integer(1))),
4028 list: Box::new(Expr::List(vec![
4029 Expr::Variable("x".to_string()),
4030 Expr::Literal(CypherLiteral::Integer(2)),
4031 ])),
4032 };
4033 let result = cypher_expr_to_df(&expr, None).unwrap();
4034 let s = format!("{}", result);
4035 assert!(
4036 s.contains("_cypher_in"),
4037 "expected _cypher_in UDF, got: {s}"
4038 );
4039 assert!(
4040 s.contains("_make_cypher_list"),
4041 "expected _make_cypher_list for variable-containing list, got: {s}"
4042 );
4043 }
4044
4045 #[test]
4050 fn test_property_on_graph_entity_uses_column() {
4051 let mut ctx = TranslationContext::new();
4053 ctx.variable_kinds
4054 .insert("n".to_string(), VariableKind::Node);
4055
4056 let expr = Expr::Property(
4057 Box::new(Expr::Variable("n".to_string())),
4058 "name".to_string(),
4059 );
4060 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4061 let s = format!("{:?}", result);
4062 assert!(
4063 s.contains("Column") && s.contains("n.name"),
4064 "expected flat column 'n.name' for graph entity, got: {s}"
4065 );
4066 }
4067
4068 #[test]
4069 fn test_property_on_non_graph_var_uses_index() {
4070 let ctx = TranslationContext::new();
4072
4073 let expr = Expr::Property(
4074 Box::new(Expr::Variable("map".to_string())),
4075 "name".to_string(),
4076 );
4077 let result = cypher_expr_to_df(&expr, Some(&ctx)).unwrap();
4078 let s = format!("{}", result);
4079 assert!(
4080 s.contains("index"),
4081 "expected index UDF for non-graph variable, got: {s}"
4082 );
4083 }
4084
4085 #[test]
4086 fn test_value_to_scalar_non_empty_map_becomes_struct() {
4087 let mut map = std::collections::HashMap::new();
4088 map.insert("k".to_string(), Value::Int(1));
4089 let scalar = value_to_scalar(&Value::Map(map)).unwrap();
4090 assert!(
4091 matches!(scalar, ScalarValue::Struct(_)),
4092 "expected Struct scalar for map input"
4093 );
4094 }
4095
4096 #[test]
4097 fn test_value_to_scalar_empty_map_becomes_struct() {
4098 let scalar = value_to_scalar(&Value::Map(Default::default())).unwrap();
4099 assert!(
4100 matches!(scalar, ScalarValue::Struct(_)),
4101 "empty map should produce an empty Struct scalar"
4102 );
4103 }
4104
4105 #[test]
4106 fn test_value_to_scalar_null_is_untyped_null() {
4107 let scalar = value_to_scalar(&Value::Null).unwrap();
4108 assert!(
4109 matches!(scalar, ScalarValue::Null),
4110 "expected untyped Null scalar for Value::Null"
4111 );
4112 }
4113
4114 #[test]
4115 fn test_value_to_scalar_datetime_produces_struct() {
4116 let datetime = Value::Temporal(TemporalValue::DateTime {
4118 nanos_since_epoch: 441763200000000000, offset_seconds: 3600, timezone_name: Some("Europe/Paris".to_string()),
4121 });
4122
4123 let scalar = value_to_scalar(&datetime).unwrap();
4124
4125 if let ScalarValue::Struct(struct_arr) = scalar {
4127 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4128 assert_eq!(struct_arr.num_columns(), 3, "expected 3 fields");
4129
4130 let fields = struct_arr.fields();
4132 assert_eq!(fields[0].name(), "nanos_since_epoch");
4133 assert_eq!(fields[1].name(), "offset_seconds");
4134 assert_eq!(fields[2].name(), "timezone_name");
4135
4136 let nanos_col = struct_arr.column(0);
4138 let offset_col = struct_arr.column(1);
4139 let tz_col = struct_arr.column(2);
4140
4141 if let Some(nanos_arr) = nanos_col
4142 .as_any()
4143 .downcast_ref::<TimestampNanosecondArray>()
4144 {
4145 assert_eq!(nanos_arr.value(0), 441763200000000000);
4146 } else {
4147 panic!("Expected TimestampNanosecondArray for nanos field");
4148 }
4149
4150 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4151 assert_eq!(offset_arr.value(0), 3600);
4152 } else {
4153 panic!("Expected Int32Array for offset field");
4154 }
4155
4156 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4157 assert_eq!(tz_arr.value(0), "Europe/Paris");
4158 } else {
4159 panic!("Expected StringArray for timezone_name field");
4160 }
4161 } else {
4162 panic!(
4163 "Expected ScalarValue::Struct for DateTime, got {:?}",
4164 scalar
4165 );
4166 }
4167 }
4168
4169 #[test]
4170 fn test_value_to_scalar_datetime_with_null_timezone() {
4171 let datetime = Value::Temporal(TemporalValue::DateTime {
4173 nanos_since_epoch: 1704067200000000000, offset_seconds: -18000, timezone_name: None,
4176 });
4177
4178 let scalar = value_to_scalar(&datetime).unwrap();
4179
4180 if let ScalarValue::Struct(struct_arr) = scalar {
4181 assert_eq!(struct_arr.num_columns(), 3);
4182
4183 let tz_col = struct_arr.column(2);
4185 if let Some(tz_arr) = tz_col.as_any().downcast_ref::<StringArray>() {
4186 assert!(tz_arr.is_null(0), "expected null timezone_name");
4187 } else {
4188 panic!("Expected StringArray for timezone_name field");
4189 }
4190 } else {
4191 panic!("Expected ScalarValue::Struct for DateTime");
4192 }
4193 }
4194
4195 #[test]
4196 fn test_value_to_scalar_time_produces_struct() {
4197 let time = Value::Temporal(TemporalValue::Time {
4199 nanos_since_midnight: 37845000000000, offset_seconds: 3600, });
4202
4203 let scalar = value_to_scalar(&time).unwrap();
4204
4205 if let ScalarValue::Struct(struct_arr) = scalar {
4207 assert_eq!(struct_arr.len(), 1, "expected single-row struct array");
4208 assert_eq!(struct_arr.num_columns(), 2, "expected 2 fields");
4209
4210 let fields = struct_arr.fields();
4212 assert_eq!(fields[0].name(), "nanos_since_midnight");
4213 assert_eq!(fields[1].name(), "offset_seconds");
4214
4215 let nanos_col = struct_arr.column(0);
4217 let offset_col = struct_arr.column(1);
4218
4219 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4220 assert_eq!(nanos_arr.value(0), 37845000000000);
4221 } else {
4222 panic!("Expected Time64NanosecondArray for nanos_since_midnight field");
4223 }
4224
4225 if let Some(offset_arr) = offset_col.as_any().downcast_ref::<Int32Array>() {
4226 assert_eq!(offset_arr.value(0), 3600);
4227 } else {
4228 panic!("Expected Int32Array for offset field");
4229 }
4230 } else {
4231 panic!("Expected ScalarValue::Struct for Time, got {:?}", scalar);
4232 }
4233 }
4234
4235 #[test]
4236 fn test_value_to_scalar_time_boundary_values() {
4237 let midnight = Value::Temporal(TemporalValue::Time {
4239 nanos_since_midnight: 0,
4240 offset_seconds: 0,
4241 });
4242
4243 let scalar = value_to_scalar(&midnight).unwrap();
4244
4245 if let ScalarValue::Struct(struct_arr) = scalar {
4246 let nanos_col = struct_arr.column(0);
4247 if let Some(nanos_arr) = nanos_col.as_any().downcast_ref::<Time64NanosecondArray>() {
4248 assert_eq!(nanos_arr.value(0), 0);
4249 } else {
4250 panic!("Expected Time64NanosecondArray");
4251 }
4252 } else {
4253 panic!("Expected ScalarValue::Struct for Time");
4254 }
4255 }
4256}