1use arrow::array::ArrayRef;
27use arrow::datatypes::DataType;
28use arrow_array::{
29 Array, BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, LargeBinaryArray,
30 LargeStringArray, StringArray, UInt64Array,
31};
32use chrono::Offset;
33use datafusion::error::Result as DFResult;
34use datafusion::logical_expr::{
35 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
36 Volatility,
37};
38use datafusion::prelude::SessionContext;
39use datafusion::scalar::ScalarValue;
40use std::any::Any;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43use uni_common::Value;
44use uni_cypher::ast::BinaryOp;
45use uni_store::storage::arrow_convert::values_to_array;
46
47use super::expr_eval::cypher_eq;
48
/// Implements `PartialEq`, `Eq` and `Hash` for a UDF struct that has a
/// `signature: Signature` field and a `name()` method (normally provided by
/// `ScalarUDFImpl`).
///
/// Equality requires both the same name and the same signature, and hashing
/// uses the name, so `a == b` always implies `hash(a) == hash(b)` as the
/// `Hash` trait requires. Comparing the signature alone would make two
/// instances that differ only by name compare equal while hashing differently.
/// For example, every `CustomScalarUdf` shares the same `VariadicAny`/`Volatile`
/// signature. Code that compares UDF implementations could then treat two
/// different functions as the same one.
macro_rules! impl_udf_eq_hash {
    ($type:ty) => {
        impl PartialEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self.name() == other.name() && self.signature == other.signature
            }
        }

        impl Eq for $type {}

        impl Hash for $type {
            fn hash<H: Hasher>(&self, state: &mut H) {
                // Hash only the name: equal values always share a name, so this is
                // consistent with `eq` above.
                self.name().hash(state);
            }
        }
    };
}
69
70pub fn register_cypher_udfs(ctx: &SessionContext) -> DFResult<()> {
80 ctx.register_udf(create_id_udf());
81 ctx.register_udf(create_type_udf());
82 ctx.register_udf(create_keys_udf());
83 ctx.register_udf(create_properties_udf());
84 ctx.register_udf(create_labels_udf());
85 ctx.register_udf(create_nodes_udf());
86 ctx.register_udf(create_relationships_udf());
87 ctx.register_udf(create_range_udf());
88 ctx.register_udf(create_index_udf());
89 ctx.register_udf(create_startnode_udf());
90 ctx.register_udf(create_endnode_udf());
91
92 ctx.register_udf(create_to_integer_udf());
94 ctx.register_udf(create_to_float_udf());
95 ctx.register_udf(create_to_boolean_udf());
96
97 ctx.register_udf(create_bitwise_or_udf());
99 ctx.register_udf(create_bitwise_and_udf());
100 ctx.register_udf(create_bitwise_xor_udf());
101 ctx.register_udf(create_bitwise_not_udf());
102 ctx.register_udf(create_shift_left_udf());
103 ctx.register_udf(create_shift_right_udf());
104
105 for name in &[
107 "date",
109 "time",
110 "localtime",
111 "localdatetime",
112 "datetime",
113 "duration",
114 "duration.between",
116 "duration.inmonths",
117 "duration.indays",
118 "duration.inseconds",
119 "datetime.fromepoch",
120 "datetime.fromepochmillis",
121 "date.truncate",
123 "time.truncate",
124 "datetime.truncate",
125 "localdatetime.truncate",
126 "localtime.truncate",
127 "datetime.transaction",
129 "datetime.statement",
130 "datetime.realtime",
131 "date.transaction",
132 "date.statement",
133 "date.realtime",
134 "time.transaction",
135 "time.statement",
136 "time.realtime",
137 "localtime.transaction",
138 "localtime.statement",
139 "localtime.realtime",
140 "localdatetime.transaction",
141 "localdatetime.statement",
142 "localdatetime.realtime",
143 ] {
144 ctx.register_udf(create_temporal_udf(name));
145 }
146
147 ctx.register_udf(create_duration_property_udf());
149 ctx.register_udf(create_temporal_property_udf());
150 ctx.register_udf(create_tostring_udf());
151 ctx.register_udf(create_cypher_sort_key_udf());
152 ctx.register_udf(create_has_null_udf());
153 ctx.register_udf(create_cypher_size_udf());
154
155 ctx.register_udf(create_cypher_starts_with_udf());
157 ctx.register_udf(create_cypher_ends_with_udf());
158 ctx.register_udf(create_cypher_contains_udf());
159
160 ctx.register_udf(create_cypher_list_compare_udf());
162
163 ctx.register_udf(create_cypher_xor_udf());
165
166 ctx.register_udf(create_cypher_equal_udf());
168 ctx.register_udf(create_cypher_not_equal_udf());
169 ctx.register_udf(create_cypher_gt_udf());
170 ctx.register_udf(create_cypher_gt_eq_udf());
171 ctx.register_udf(create_cypher_lt_udf());
172 ctx.register_udf(create_cypher_lt_eq_udf());
173
174 ctx.register_udf(create_cv_to_bool_udf());
176
177 ctx.register_udf(create_cypher_add_udf());
179 ctx.register_udf(create_cypher_sub_udf());
180 ctx.register_udf(create_cypher_mul_udf());
181 ctx.register_udf(create_cypher_div_udf());
182 ctx.register_udf(create_cypher_mod_udf());
183
184 ctx.register_udf(create_map_project_udf());
186
187 ctx.register_udf(create_make_cypher_list_udf());
189
190 ctx.register_udf(create_cypher_in_udf());
192
193 ctx.register_udf(create_cypher_list_concat_udf());
195 ctx.register_udf(create_cypher_list_append_udf());
196 ctx.register_udf(create_cypher_list_slice_udf());
197 ctx.register_udf(create_cypher_tail_udf());
198 ctx.register_udf(create_cypher_head_udf());
199 ctx.register_udf(create_cypher_last_udf());
200 ctx.register_udf(create_cypher_reverse_udf());
201 ctx.register_udf(create_cypher_substring_udf());
202 ctx.register_udf(create_cypher_split_udf());
203 ctx.register_udf(create_cypher_list_to_cv_udf());
204 ctx.register_udf(create_cypher_scalar_to_cv_udf());
205
206 for name in &["year", "month", "day", "hour", "minute", "second"] {
208 ctx.register_udf(create_temporal_udf(name));
209 }
210
211 ctx.register_udf(create_cypher_to_float64_udf());
213
214 ctx.register_udf(create_similar_to_udf());
216 ctx.register_udf(create_vector_similarity_udf());
217
218 ctx.register_udaf(create_cypher_min_udaf());
220 ctx.register_udaf(create_cypher_max_udaf());
221 ctx.register_udaf(create_cypher_sum_udaf());
222 ctx.register_udaf(create_cypher_collect_udaf());
223
224 ctx.register_udaf(create_cypher_percentile_disc_udaf());
226 ctx.register_udaf(create_cypher_percentile_cont_udaf());
227
228 Ok(())
229}
230
231pub fn register_custom_udfs(
234 ctx: &SessionContext,
235 registry: &super::executor::custom_functions::CustomFunctionRegistry,
236) -> DFResult<()> {
237 for (name, func) in registry.iter() {
238 let lower = name.to_lowercase();
241 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
242 lower,
243 func.clone(),
244 )));
245 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
247 name.to_string(),
248 func.clone(),
249 )));
250 }
251 Ok(())
252}
253
254struct CustomScalarUdf {
259 name: String,
260 func: super::executor::custom_functions::CustomScalarFn,
261 signature: Signature,
262}
263
264impl CustomScalarUdf {
265 fn new(name: String, func: super::executor::custom_functions::CustomScalarFn) -> Self {
266 Self {
267 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
268 name,
269 func,
270 }
271 }
272}
273
274impl std::fmt::Debug for CustomScalarUdf {
275 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276 f.debug_struct("CustomScalarUdf")
277 .field("name", &self.name)
278 .finish()
279 }
280}
281
282impl_udf_eq_hash!(CustomScalarUdf);
283
284impl ScalarUDFImpl for CustomScalarUdf {
285 fn as_any(&self) -> &dyn Any {
286 self
287 }
288
289 fn name(&self) -> &str {
290 &self.name
291 }
292
293 fn signature(&self) -> &Signature {
294 &self.signature
295 }
296
297 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
298 Ok(DataType::LargeBinary)
299 }
300
301 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
302 let func = &self.func;
303 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
304 func(vals).map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
305 })
306 }
307}
308
309pub fn create_id_udf() -> ScalarUDF {
317 ScalarUDF::new_from_impl(IdUdf::new())
318}
319
320#[derive(Debug)]
321struct IdUdf {
322 signature: Signature,
323}
324
325impl IdUdf {
326 fn new() -> Self {
327 Self {
328 signature: Signature::new(
329 TypeSignature::Exact(vec![DataType::UInt64]),
330 Volatility::Immutable,
331 ),
332 }
333 }
334}
335
336impl_udf_eq_hash!(IdUdf);
337
338impl ScalarUDFImpl for IdUdf {
339 fn as_any(&self) -> &dyn Any {
340 self
341 }
342
343 fn name(&self) -> &str {
344 "id"
345 }
346
347 fn signature(&self) -> &Signature {
348 &self.signature
349 }
350
351 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
352 Ok(DataType::UInt64)
353 }
354
355 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
356 if args.args.is_empty() {
358 return Err(datafusion::error::DataFusionError::Execution(
359 "id(): requires 1 argument".to_string(),
360 ));
361 }
362 Ok(args.args[0].clone())
363 }
364}
365
366pub fn create_type_udf() -> ScalarUDF {
374 ScalarUDF::new_from_impl(TypeUdf::new())
375}
376
377#[derive(Debug)]
378struct TypeUdf {
379 signature: Signature,
380}
381
382impl TypeUdf {
383 fn new() -> Self {
384 Self {
385 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
389 }
390 }
391}
392
393impl_udf_eq_hash!(TypeUdf);
394
395impl ScalarUDFImpl for TypeUdf {
396 fn as_any(&self) -> &dyn Any {
397 self
398 }
399
400 fn name(&self) -> &str {
401 "type"
402 }
403
404 fn signature(&self) -> &Signature {
405 &self.signature
406 }
407
408 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
409 Ok(DataType::Utf8)
410 }
411
412 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
413 if args.args.is_empty() {
414 return Err(datafusion::error::DataFusionError::Execution(
415 "type(): requires 1 argument".to_string(),
416 ));
417 }
418 let output_type = DataType::Utf8;
419 invoke_cypher_udf(args, &output_type, |val_args| {
420 if val_args.is_empty() {
421 return Err(datafusion::error::DataFusionError::Execution(
422 "type(): requires 1 argument".to_string(),
423 ));
424 }
425 let val = &val_args[0];
426 match val {
427 Value::Map(map) => {
429 if let Some(Value::String(t)) = map.get("_type") {
430 Ok(Value::String(t.clone()))
431 } else {
432 Err(datafusion::error::DataFusionError::Execution(
434 "TypeError: InvalidArgumentValue - type() requires a relationship argument".to_string(),
435 ))
436 }
437 }
438 Value::Null => Ok(Value::Null),
439 _ => Err(datafusion::error::DataFusionError::Execution(
440 "TypeError: InvalidArgumentValue - type() requires a relationship argument"
441 .to_string(),
442 )),
443 }
444 })
445 }
446}
447
448pub fn create_keys_udf() -> ScalarUDF {
456 ScalarUDF::new_from_impl(KeysUdf::new())
457}
458
459#[derive(Debug)]
460struct KeysUdf {
461 signature: Signature,
462}
463
464impl KeysUdf {
465 fn new() -> Self {
466 Self {
467 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
468 }
469 }
470}
471
472impl_udf_eq_hash!(KeysUdf);
473
474impl ScalarUDFImpl for KeysUdf {
475 fn as_any(&self) -> &dyn Any {
476 self
477 }
478
479 fn name(&self) -> &str {
480 "keys"
481 }
482
483 fn signature(&self) -> &Signature {
484 &self.signature
485 }
486
487 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
488 Ok(DataType::List(Arc::new(
489 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
490 )))
491 }
492
493 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
494 let output_type = self.return_type(&[])?;
495 invoke_cypher_udf(args, &output_type, |val_args| {
496 if val_args.is_empty() {
497 return Err(datafusion::error::DataFusionError::Execution(
498 "keys(): requires 1 argument".to_string(),
499 ));
500 }
501
502 let arg = &val_args[0];
503 let keys = match arg {
504 Value::Map(map) => {
505 let (source, is_entity) = match map.get("_all_props") {
515 Some(Value::Map(all)) => (all, true),
516 _ => (map, false),
517 };
518 let mut key_strings: Vec<String> = source
519 .iter()
520 .filter(|(k, v)| !k.starts_with('_') && (!is_entity || !v.is_null()))
521 .map(|(k, _)| k.clone())
522 .collect();
523 key_strings.sort();
524 key_strings
525 .into_iter()
526 .map(Value::String)
527 .collect::<Vec<_>>()
528 }
529 Value::Null => {
530 return Ok(Value::Null);
531 }
532 _ => {
533 vec![]
536 }
537 };
538
539 Ok(Value::List(keys))
540 })
541 }
542}
543
544pub fn create_properties_udf() -> ScalarUDF {
549 ScalarUDF::new_from_impl(PropertiesUdf::new())
550}
551
552#[derive(Debug)]
553struct PropertiesUdf {
554 signature: Signature,
555}
556
557impl PropertiesUdf {
558 fn new() -> Self {
559 Self {
560 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
561 }
562 }
563}
564
565impl_udf_eq_hash!(PropertiesUdf);
566
567impl ScalarUDFImpl for PropertiesUdf {
568 fn as_any(&self) -> &dyn Any {
569 self
570 }
571
572 fn name(&self) -> &str {
573 "properties"
574 }
575
576 fn signature(&self) -> &Signature {
577 &self.signature
578 }
579
580 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
581 Ok(DataType::LargeBinary)
583 }
584
585 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
586 let output_type = self.return_type(&[])?;
587 invoke_cypher_udf(args, &output_type, |val_args| {
588 if val_args.is_empty() {
589 return Err(datafusion::error::DataFusionError::Execution(
590 "properties(): requires 1 argument".to_string(),
591 ));
592 }
593
594 let arg = &val_args[0];
595 match arg {
596 Value::Map(map) => {
597 let identity_null = map
607 .get("_vid")
608 .map(|v| v.is_null())
609 .or_else(|| map.get("_eid").map(|v| v.is_null()))
610 .unwrap_or(false);
611 if identity_null {
612 return Ok(Value::Null);
613 }
614
615 let source = match map.get("_all_props") {
617 Some(Value::Map(all)) => all,
618 _ => map,
619 };
620 let filtered: std::collections::HashMap<String, Value> = source
622 .iter()
623 .filter(|(k, _)| !k.starts_with('_'))
624 .map(|(k, v)| (k.clone(), v.clone()))
625 .collect();
626 Ok(Value::Map(filtered))
627 }
628 _ => Ok(Value::Null),
629 }
630 })
631 }
632}
633
634pub fn create_index_udf() -> ScalarUDF {
639 ScalarUDF::new_from_impl(IndexUdf::new())
640}
641
642#[derive(Debug)]
643struct IndexUdf {
644 signature: Signature,
645}
646
647impl IndexUdf {
648 fn new() -> Self {
649 Self {
650 signature: Signature::any(2, Volatility::Immutable),
651 }
652 }
653}
654
655impl_udf_eq_hash!(IndexUdf);
656
657impl ScalarUDFImpl for IndexUdf {
658 fn as_any(&self) -> &dyn Any {
659 self
660 }
661
662 fn name(&self) -> &str {
663 "index"
664 }
665
666 fn signature(&self) -> &Signature {
667 &self.signature
668 }
669
670 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
671 Ok(DataType::LargeBinary)
673 }
674
675 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
676 let output_type = self.return_type(&[])?;
677 invoke_cypher_udf(args, &output_type, |val_args| {
678 if val_args.len() != 2 {
679 return Err(datafusion::error::DataFusionError::Execution(
680 "index(): requires 2 arguments".to_string(),
681 ));
682 }
683
684 let container = &val_args[0];
685 let index = &val_args[1];
686
687 let index_as_int = index.as_i64();
691
692 let result = match container {
693 Value::List(arr) => {
694 if let Some(i) = index_as_int {
695 let idx = if i < 0 {
696 let pos = arr.len() as i64 + i;
697 if pos < 0 { -1 } else { pos }
698 } else {
699 i
700 };
701 if idx >= 0 && (idx as usize) < arr.len() {
702 arr[idx as usize].clone()
703 } else {
704 Value::Null
705 }
706 } else if index.is_null() {
707 Value::Null
708 } else {
709 return Err(datafusion::error::DataFusionError::Execution(format!(
710 "TypeError: InvalidArgumentType - list index must be an integer, got: {:?}",
711 index
712 )));
713 }
714 }
715 Value::Map(map) => {
716 if let Some(key) = index.as_str() {
717 if let Some(val) = map.get(key) {
719 val.clone()
720 } else if let Some(Value::Map(all_props)) = map.get("_all_props") {
721 all_props.get(key).cloned().unwrap_or(Value::Null)
723 } else if let Some(Value::Map(props)) = map.get("properties") {
724 props.get(key).cloned().unwrap_or(Value::Null)
726 } else {
727 Value::Null
728 }
729 } else if !index.is_null() {
730 return Err(datafusion::error::DataFusionError::Execution(
731 "index(): map index must be a string".to_string(),
732 ));
733 } else {
734 Value::Null
735 }
736 }
737 Value::Node(node) => {
738 if let Some(key) = index.as_str() {
739 node.properties.get(key).cloned().unwrap_or(Value::Null)
740 } else if !index.is_null() {
741 return Err(datafusion::error::DataFusionError::Execution(
742 "index(): node index must be a string".to_string(),
743 ));
744 } else {
745 Value::Null
746 }
747 }
748 Value::Edge(edge) => {
749 if let Some(key) = index.as_str() {
750 edge.properties.get(key).cloned().unwrap_or(Value::Null)
751 } else if !index.is_null() {
752 return Err(datafusion::error::DataFusionError::Execution(
753 "index(): edge index must be a string".to_string(),
754 ));
755 } else {
756 Value::Null
757 }
758 }
759 Value::Null => Value::Null,
760 _ => {
761 return Err(datafusion::error::DataFusionError::Execution(format!(
762 "TypeError: InvalidArgumentType - cannot index into {:?}",
763 container
764 )));
765 }
766 };
767
768 Ok(result)
769 })
770 }
771}
772
773pub fn create_labels_udf() -> ScalarUDF {
778 ScalarUDF::new_from_impl(LabelsUdf::new())
779}
780
781#[derive(Debug)]
782struct LabelsUdf {
783 signature: Signature,
784}
785
786impl LabelsUdf {
787 fn new() -> Self {
788 Self {
789 signature: Signature::any(1, Volatility::Immutable),
790 }
791 }
792}
793
794impl_udf_eq_hash!(LabelsUdf);
795
796impl ScalarUDFImpl for LabelsUdf {
797 fn as_any(&self) -> &dyn Any {
798 self
799 }
800
801 fn name(&self) -> &str {
802 "labels"
803 }
804
805 fn signature(&self) -> &Signature {
806 &self.signature
807 }
808
809 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
810 Ok(DataType::List(Arc::new(
811 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
812 )))
813 }
814
815 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
816 let output_type = self.return_type(&[])?;
817 invoke_cypher_udf(args, &output_type, |val_args| {
818 if val_args.is_empty() {
819 return Err(datafusion::error::DataFusionError::Execution(
820 "labels(): requires 1 argument".to_string(),
821 ));
822 }
823
824 let node = &val_args[0];
825 match node {
826 Value::Map(map) => {
827 if let Some(Value::List(arr)) = map.get("_labels") {
828 Ok(Value::List(arr.clone()))
829 } else {
830 Err(datafusion::error::DataFusionError::Execution(
832 "TypeError: InvalidArgumentValue - labels() requires a node argument"
833 .to_string(),
834 ))
835 }
836 }
837 Value::Null => Ok(Value::Null),
838 _ => Err(datafusion::error::DataFusionError::Execution(
839 "TypeError: InvalidArgumentValue - labels() requires a node argument"
840 .to_string(),
841 )),
842 }
843 })
844 }
845}
846
847pub fn create_nodes_udf() -> ScalarUDF {
852 ScalarUDF::new_from_impl(NodesUdf::new())
853}
854
855#[derive(Debug)]
856struct NodesUdf {
857 signature: Signature,
858}
859
860impl NodesUdf {
861 fn new() -> Self {
862 Self {
863 signature: Signature::any(1, Volatility::Immutable),
864 }
865 }
866}
867
868impl_udf_eq_hash!(NodesUdf);
869
870impl ScalarUDFImpl for NodesUdf {
871 fn as_any(&self) -> &dyn Any {
872 self
873 }
874
875 fn name(&self) -> &str {
876 "nodes"
877 }
878
879 fn signature(&self) -> &Signature {
880 &self.signature
881 }
882
883 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
884 Ok(DataType::LargeBinary)
885 }
886
887 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
888 let output_type = self.return_type(&[])?;
889 invoke_cypher_udf(args, &output_type, |val_args| {
890 if val_args.is_empty() {
891 return Err(datafusion::error::DataFusionError::Execution(
892 "nodes(): requires 1 argument".to_string(),
893 ));
894 }
895
896 let path = &val_args[0];
897 let nodes = match path {
898 Value::Map(map) => map.get("nodes").cloned().unwrap_or(Value::Null),
899 _ => Value::Null,
900 };
901
902 Ok(nodes)
903 })
904 }
905}
906
907pub fn create_relationships_udf() -> ScalarUDF {
912 ScalarUDF::new_from_impl(RelationshipsUdf::new())
913}
914
915#[derive(Debug)]
916struct RelationshipsUdf {
917 signature: Signature,
918}
919
920impl RelationshipsUdf {
921 fn new() -> Self {
922 Self {
923 signature: Signature::any(1, Volatility::Immutable),
924 }
925 }
926}
927
928impl_udf_eq_hash!(RelationshipsUdf);
929
930impl ScalarUDFImpl for RelationshipsUdf {
931 fn as_any(&self) -> &dyn Any {
932 self
933 }
934
935 fn name(&self) -> &str {
936 "relationships"
937 }
938
939 fn signature(&self) -> &Signature {
940 &self.signature
941 }
942
943 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
944 Ok(DataType::LargeBinary)
945 }
946
947 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
948 let output_type = self.return_type(&[])?;
949 invoke_cypher_udf(args, &output_type, |val_args| {
950 if val_args.is_empty() {
951 return Err(datafusion::error::DataFusionError::Execution(
952 "relationships(): requires 1 argument".to_string(),
953 ));
954 }
955
956 let path = &val_args[0];
957 let rels = match path {
958 Value::Map(map) => map.get("relationships").cloned().unwrap_or(Value::Null),
959 _ => Value::Null,
960 };
961
962 Ok(rels)
963 })
964 }
965}
966
967pub fn create_startnode_udf() -> ScalarUDF {
976 ScalarUDF::new_from_impl(StartNodeUdf::new())
977}
978
979#[derive(Debug)]
980struct StartNodeUdf {
981 signature: Signature,
982}
983
984impl StartNodeUdf {
985 fn new() -> Self {
986 Self {
987 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
988 }
989 }
990}
991
992impl_udf_eq_hash!(StartNodeUdf);
993
994impl ScalarUDFImpl for StartNodeUdf {
995 fn as_any(&self) -> &dyn Any {
996 self
997 }
998
999 fn name(&self) -> &str {
1000 "startnode"
1001 }
1002
1003 fn signature(&self) -> &Signature {
1004 &self.signature
1005 }
1006
1007 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1008 Ok(DataType::LargeBinary)
1009 }
1010
1011 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1012 let output_type = DataType::LargeBinary;
1013 invoke_cypher_udf(args, &output_type, |val_args| {
1014 startnode_endnode_impl(val_args, true)
1015 })
1016 }
1017}
1018
1019pub fn create_endnode_udf() -> ScalarUDF {
1025 ScalarUDF::new_from_impl(EndNodeUdf::new())
1026}
1027
1028#[derive(Debug)]
1029struct EndNodeUdf {
1030 signature: Signature,
1031}
1032
1033impl EndNodeUdf {
1034 fn new() -> Self {
1035 Self {
1036 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1037 }
1038 }
1039}
1040
1041impl_udf_eq_hash!(EndNodeUdf);
1042
1043impl ScalarUDFImpl for EndNodeUdf {
1044 fn as_any(&self) -> &dyn Any {
1045 self
1046 }
1047
1048 fn name(&self) -> &str {
1049 "endnode"
1050 }
1051
1052 fn signature(&self) -> &Signature {
1053 &self.signature
1054 }
1055
1056 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1057 Ok(DataType::LargeBinary)
1058 }
1059
1060 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1061 let output_type = DataType::LargeBinary;
1062 invoke_cypher_udf(args, &output_type, |val_args| {
1063 startnode_endnode_impl(val_args, false)
1064 })
1065 }
1066}
1067
1068fn startnode_endnode_impl(val_args: &[Value], is_start: bool) -> DFResult<Value> {
1073 if val_args.is_empty() {
1074 let fn_name = if is_start { "startNode" } else { "endNode" };
1075 return Err(datafusion::error::DataFusionError::Execution(format!(
1076 "{fn_name}(): requires at least 1 argument"
1077 )));
1078 }
1079
1080 let edge_val = &val_args[0];
1081 let target_vid = extract_endpoint_vid(edge_val, is_start);
1082
1083 let target_vid = match target_vid {
1084 Some(vid) => vid,
1085 None => return Ok(Value::Null),
1086 };
1087
1088 for node_val in val_args.iter().skip(1) {
1090 if let Some(vid) = extract_vid(node_val)
1091 && vid == target_vid
1092 {
1093 return Ok(node_val.clone());
1094 }
1095 }
1096
1097 let mut map = std::collections::HashMap::new();
1099 map.insert("_vid".to_string(), Value::Int(target_vid as i64));
1100 Ok(Value::Map(map))
1101}
1102
1103fn extract_endpoint_vid(val: &Value, is_start: bool) -> Option<u64> {
1105 match val {
1106 Value::Edge(edge) => {
1107 let vid = if is_start { edge.src } else { edge.dst };
1108 Some(vid.as_u64())
1109 }
1110 Value::Map(map) => {
1111 let key = if is_start { "_src_vid" } else { "_dst_vid" };
1113 if let Some(v) = map.get(key) {
1114 return v.as_u64();
1115 }
1116 let key2 = if is_start { "_src" } else { "_dst" };
1118 if let Some(v) = map.get(key2) {
1119 return v.as_u64();
1120 }
1121 let node_key = if is_start { "_startNode" } else { "_endNode" };
1123 if let Some(node_val) = map.get(node_key) {
1124 return extract_vid(node_val);
1125 }
1126 None
1127 }
1128 _ => None,
1129 }
1130}
1131
1132fn extract_vid(val: &Value) -> Option<u64> {
1134 match val {
1135 Value::Map(map) => map.get("_vid").and_then(|v| v.as_u64()),
1136 _ => None,
1137 }
1138}
1139
1140fn extract_i64_range_arg(arg: &ColumnarValue, row_idx: usize, name: &str) -> DFResult<i64> {
1147 match arg {
1148 ColumnarValue::Scalar(sv) => match sv {
1149 ScalarValue::Int8(Some(v)) => Ok(*v as i64),
1150 ScalarValue::Int16(Some(v)) => Ok(*v as i64),
1151 ScalarValue::Int32(Some(v)) => Ok(*v as i64),
1152 ScalarValue::Int64(Some(v)) => Ok(*v),
1153 ScalarValue::UInt8(Some(v)) => Ok(*v as i64),
1154 ScalarValue::UInt16(Some(v)) => Ok(*v as i64),
1155 ScalarValue::UInt32(Some(v)) => Ok(*v as i64),
1156 ScalarValue::UInt64(Some(v)) => Ok(*v as i64),
1157 ScalarValue::LargeBinary(Some(bytes)) => {
1158 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1159 datafusion::error::DataFusionError::Execution(format!(
1160 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1161 name
1162 ))
1163 })
1164 }
1165 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1166 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1167 name
1168 ))),
1169 },
1170 ColumnarValue::Array(arr) => {
1171 if row_idx >= arr.len() || arr.is_null(row_idx) {
1172 return Err(datafusion::error::DataFusionError::Execution(format!(
1173 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1174 name
1175 )));
1176 }
1177 if !arr.is_empty() {
1179 use datafusion::arrow::array::{
1180 Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array,
1181 UInt32Array, UInt64Array,
1182 };
1183 match arr.data_type() {
1184 DataType::Int8 => Ok(arr
1185 .as_any()
1186 .downcast_ref::<Int8Array>()
1187 .unwrap()
1188 .value(row_idx) as i64),
1189 DataType::Int16 => Ok(arr
1190 .as_any()
1191 .downcast_ref::<Int16Array>()
1192 .unwrap()
1193 .value(row_idx) as i64),
1194 DataType::Int32 => Ok(arr
1195 .as_any()
1196 .downcast_ref::<Int32Array>()
1197 .unwrap()
1198 .value(row_idx) as i64),
1199 DataType::Int64 => Ok(arr
1200 .as_any()
1201 .downcast_ref::<Int64Array>()
1202 .unwrap()
1203 .value(row_idx)),
1204 DataType::UInt8 => Ok(arr
1205 .as_any()
1206 .downcast_ref::<UInt8Array>()
1207 .unwrap()
1208 .value(row_idx) as i64),
1209 DataType::UInt16 => Ok(arr
1210 .as_any()
1211 .downcast_ref::<UInt16Array>()
1212 .unwrap()
1213 .value(row_idx) as i64),
1214 DataType::UInt32 => Ok(arr
1215 .as_any()
1216 .downcast_ref::<UInt32Array>()
1217 .unwrap()
1218 .value(row_idx) as i64),
1219 DataType::UInt64 => Ok(arr
1220 .as_any()
1221 .downcast_ref::<UInt64Array>()
1222 .unwrap()
1223 .value(row_idx) as i64),
1224 DataType::LargeBinary => {
1225 let bytes = arr
1226 .as_any()
1227 .downcast_ref::<LargeBinaryArray>()
1228 .unwrap()
1229 .value(row_idx);
1230 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1231 datafusion::error::DataFusionError::Execution(format!(
1232 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1233 name
1234 ))
1235 })
1236 }
1237 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1238 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1239 name
1240 ))),
1241 }
1242 } else {
1243 Err(datafusion::error::DataFusionError::Execution(format!(
1244 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1245 name
1246 )))
1247 }
1248 }
1249 }
1250}
1251
1252pub fn create_range_udf() -> ScalarUDF {
1254 ScalarUDF::new_from_impl(RangeUdf::new())
1255}
1256
1257#[derive(Debug)]
1258struct RangeUdf {
1259 signature: Signature,
1260}
1261
1262impl RangeUdf {
1263 fn new() -> Self {
1264 Self {
1265 signature: Signature::one_of(
1266 vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1267 Volatility::Immutable,
1268 ),
1269 }
1270 }
1271}
1272
1273impl_udf_eq_hash!(RangeUdf);
1274
1275impl ScalarUDFImpl for RangeUdf {
1276 fn as_any(&self) -> &dyn Any {
1277 self
1278 }
1279
1280 fn name(&self) -> &str {
1281 "range"
1282 }
1283
1284 fn signature(&self) -> &Signature {
1285 &self.signature
1286 }
1287
1288 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1289 Ok(DataType::List(Arc::new(
1290 arrow::datatypes::Field::new_list_field(DataType::Int64, true),
1291 )))
1292 }
1293
1294 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1295 if args.args.len() < 2 || args.args.len() > 3 {
1296 return Err(datafusion::error::DataFusionError::Execution(
1297 "range(): requires 2 or 3 arguments".to_string(),
1298 ));
1299 }
1300
1301 let len = args
1302 .args
1303 .iter()
1304 .find_map(|arg| match arg {
1305 ColumnarValue::Array(arr) => Some(arr.len()),
1306 _ => None,
1307 })
1308 .unwrap_or(1);
1309
1310 let mut list_builder =
1311 arrow_array::builder::ListBuilder::new(arrow_array::builder::Int64Builder::new());
1312
1313 for row_idx in 0..len {
1314 let start = extract_i64_range_arg(&args.args[0], row_idx, "start")?;
1315 let end = extract_i64_range_arg(&args.args[1], row_idx, "end")?;
1316 let step = if args.args.len() == 3 {
1317 extract_i64_range_arg(&args.args[2], row_idx, "step")?
1318 } else {
1319 1
1320 };
1321
1322 if step == 0 {
1323 return Err(datafusion::error::DataFusionError::Execution(
1324 "range(): step cannot be zero".to_string(),
1325 ));
1326 }
1327
1328 if step > 0 && start <= end {
1329 let mut current = start;
1330 while current <= end {
1331 list_builder.values().append_value(current);
1332 current += step;
1333 }
1334 } else if step < 0 && start >= end {
1335 let mut current = start;
1336 while current >= end {
1337 list_builder.values().append_value(current);
1338 current += step;
1339 }
1340 }
1341 list_builder.append(true);
1343 }
1344
1345 let list_arr = Arc::new(list_builder.finish()) as ArrayRef;
1346 if len == 1
1347 && args
1348 .args
1349 .iter()
1350 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
1351 {
1352 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
1353 &list_arr, 0,
1354 )?))
1355 } else {
1356 Ok(ColumnarValue::Array(list_arr))
1357 }
1358 }
1359}
1360
1361fn invoke_binary_bitwise_op<F>(
1369 args: &ScalarFunctionArgs,
1370 name: &str,
1371 op: F,
1372) -> DFResult<ColumnarValue>
1373where
1374 F: Fn(i64, i64) -> i64,
1375{
1376 use arrow_array::Int64Array;
1377 use datafusion::common::ScalarValue;
1378 use datafusion::error::DataFusionError;
1379
1380 if args.args.len() != 2 {
1381 return Err(DataFusionError::Execution(format!(
1382 "{}(): requires exactly 2 arguments",
1383 name
1384 )));
1385 }
1386
1387 let left = &args.args[0];
1388 let right = &args.args[1];
1389
1390 match (left, right) {
1391 (
1392 ColumnarValue::Scalar(ScalarValue::Int64(Some(l))),
1393 ColumnarValue::Scalar(ScalarValue::Int64(Some(r))),
1394 ) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*l, *r))))),
1395 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
1396 let l_arr = l_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1397 DataFusionError::Execution(format!("{}(): left array must be Int64", name))
1398 })?;
1399 let r_arr = r_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1400 DataFusionError::Execution(format!("{}(): right array must be Int64", name))
1401 })?;
1402
1403 let result: Int64Array = l_arr
1404 .iter()
1405 .zip(r_arr.iter())
1406 .map(|(l, r)| match (l, r) {
1407 (Some(l), Some(r)) => Some(op(l, r)),
1408 _ => None,
1409 })
1410 .collect();
1411
1412 Ok(ColumnarValue::Array(Arc::new(result)))
1413 }
1414 _ => Err(DataFusionError::Execution(format!(
1415 "{}(): mixed scalar/array not supported",
1416 name
1417 ))),
1418 }
1419}
1420
1421fn invoke_unary_bitwise_op<F>(
1425 args: &ScalarFunctionArgs,
1426 name: &str,
1427 op: F,
1428) -> DFResult<ColumnarValue>
1429where
1430 F: Fn(i64) -> i64,
1431{
1432 use arrow_array::Int64Array;
1433 use datafusion::common::ScalarValue;
1434 use datafusion::error::DataFusionError;
1435
1436 if args.args.len() != 1 {
1437 return Err(DataFusionError::Execution(format!(
1438 "{}(): requires exactly 1 argument",
1439 name
1440 )));
1441 }
1442
1443 let operand = &args.args[0];
1444
1445 match operand {
1446 ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
1447 Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*v)))))
1448 }
1449 ColumnarValue::Array(arr) => {
1450 let arr = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1451 DataFusionError::Execution(format!("{}(): array must be Int64", name))
1452 })?;
1453
1454 let result: Int64Array = arr.iter().map(|v| v.map(&op)).collect();
1455
1456 Ok(ColumnarValue::Array(Arc::new(result)))
1457 }
1458 _ => Err(DataFusionError::Execution(format!(
1459 "{}(): invalid argument type",
1460 name
1461 ))),
1462 }
1463}
1464
/// Declares a stateless `ScalarUDFImpl` type `$struct_name` that exposes the
/// binary Int64 operation `$op` under the function name `$udf_name`.
///
/// The generated UDF takes exactly two Int64 arguments, is immutable and
/// always returns Int64; evaluation is delegated to `invoke_binary_bitwise_op`.
macro_rules! define_binary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                let operand_types = vec![DataType::Int64; 2];
                let signature = Signature::exact(operand_types, Volatility::Immutable);
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn name(&self) -> &str {
                $udf_name
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_binary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1511
/// Declares a stateless `ScalarUDFImpl` type `$struct_name` that exposes the
/// unary Int64 operation `$op` under the function name `$udf_name`.
///
/// The generated UDF takes a single Int64 argument, is immutable and always
/// returns Int64; evaluation is delegated to `invoke_unary_bitwise_op`.
macro_rules! define_unary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                let operand_types = vec![DataType::Int64];
                let signature = Signature::exact(operand_types, Volatility::Immutable);
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn name(&self) -> &str {
                $udf_name
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_unary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1555
1556define_binary_bitwise_udf!(BitwiseOrUdf, "uni.bitwise.or", |l, r| l | r);
1558define_binary_bitwise_udf!(BitwiseAndUdf, "uni.bitwise.and", |l, r| l & r);
1559define_binary_bitwise_udf!(BitwiseXorUdf, "uni.bitwise.xor", |l, r| l ^ r);
1560define_binary_bitwise_udf!(ShiftLeftUdf, "uni.bitwise.shiftLeft", |l, r| l << r);
1561define_binary_bitwise_udf!(ShiftRightUdf, "uni.bitwise.shiftRight", |l, r| l >> r);
1562
1563define_unary_bitwise_udf!(BitwiseNotUdf, "uni.bitwise.not", |v| !v);
1565
1566pub fn create_bitwise_or_udf() -> ScalarUDF {
1568 ScalarUDF::new_from_impl(BitwiseOrUdf::new())
1569}
1570
1571pub fn create_bitwise_and_udf() -> ScalarUDF {
1573 ScalarUDF::new_from_impl(BitwiseAndUdf::new())
1574}
1575
1576pub fn create_bitwise_xor_udf() -> ScalarUDF {
1578 ScalarUDF::new_from_impl(BitwiseXorUdf::new())
1579}
1580
1581pub fn create_bitwise_not_udf() -> ScalarUDF {
1583 ScalarUDF::new_from_impl(BitwiseNotUdf::new())
1584}
1585
1586pub fn create_shift_left_udf() -> ScalarUDF {
1588 ScalarUDF::new_from_impl(ShiftLeftUdf::new())
1589}
1590
1591pub fn create_shift_right_udf() -> ScalarUDF {
1593 ScalarUDF::new_from_impl(ShiftRightUdf::new())
1594}
1595
1596fn create_temporal_udf(name: &str) -> ScalarUDF {
1607 ScalarUDF::new_from_impl(TemporalUdf::new(name.to_string()))
1608}
1609
1610#[derive(Debug)]
1611struct TemporalUdf {
1612 name: String,
1613 signature: Signature,
1614}
1615
1616impl TemporalUdf {
1617 fn new(name: String) -> Self {
1618 Self {
1619 name,
1620 signature: Signature::new(
1623 TypeSignature::OneOf(vec![
1624 TypeSignature::Exact(vec![]),
1625 TypeSignature::VariadicAny,
1626 ]),
1627 Volatility::Immutable,
1628 ),
1629 }
1630 }
1631}
1632
1633impl_udf_eq_hash!(TemporalUdf);
1634
1635impl ScalarUDFImpl for TemporalUdf {
1636 fn as_any(&self) -> &dyn Any {
1637 self
1638 }
1639
1640 fn name(&self) -> &str {
1641 &self.name
1642 }
1643
1644 fn signature(&self) -> &Signature {
1645 &self.signature
1646 }
1647
1648 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1649 let name = self.name.to_lowercase();
1650 match name.as_str() {
1651 "year" | "month" | "day" | "hour" | "minute" | "second" => Ok(DataType::Int64),
1653 "datetime"
1659 | "localdatetime"
1660 | "date"
1661 | "time"
1662 | "localtime"
1663 | "duration"
1664 | "date.truncate"
1665 | "time.truncate"
1666 | "datetime.truncate"
1667 | "localdatetime.truncate"
1668 | "localtime.truncate"
1669 | "duration.between"
1670 | "duration.inmonths"
1671 | "duration.indays"
1672 | "duration.inseconds"
1673 | "datetime.fromepoch"
1674 | "datetime.fromepochmillis"
1675 | "datetime.transaction"
1676 | "datetime.statement"
1677 | "datetime.realtime"
1678 | "date.transaction"
1679 | "date.statement"
1680 | "date.realtime"
1681 | "time.transaction"
1682 | "time.statement"
1683 | "time.realtime"
1684 | "localtime.transaction"
1685 | "localtime.statement"
1686 | "localtime.realtime"
1687 | "localdatetime.transaction"
1688 | "localdatetime.statement"
1689 | "localdatetime.realtime" => Ok(DataType::LargeBinary),
1690 _ => Ok(DataType::Utf8),
1691 }
1692 }
1693
1694 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1695 let func_name = self.name.to_uppercase();
1696 let output_type = self.return_type(&[])?;
1697 invoke_cypher_udf(args, &output_type, |val_args| {
1698 crate::query::datetime::eval_datetime_function(&func_name, val_args).map_err(|e| {
1699 datafusion::error::DataFusionError::Execution(format!("{}(): {}", self.name, e))
1700 })
1701 })
1702 }
1703}
1704
1705fn create_duration_property_udf() -> ScalarUDF {
1710 ScalarUDF::new_from_impl(DurationPropertyUdf::new())
1711}
1712
1713#[derive(Debug)]
1714struct DurationPropertyUdf {
1715 signature: Signature,
1716}
1717
1718impl DurationPropertyUdf {
1719 fn new() -> Self {
1720 Self {
1721 signature: Signature::new(
1722 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
1723 Volatility::Immutable,
1724 ),
1725 }
1726 }
1727}
1728
1729impl_udf_eq_hash!(DurationPropertyUdf);
1730
1731impl ScalarUDFImpl for DurationPropertyUdf {
1732 fn as_any(&self) -> &dyn Any {
1733 self
1734 }
1735
1736 fn name(&self) -> &str {
1737 "_duration_property"
1738 }
1739
1740 fn signature(&self) -> &Signature {
1741 &self.signature
1742 }
1743
1744 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1745 Ok(DataType::Int64)
1746 }
1747
1748 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1749 let output_type = self.return_type(&[])?;
1750 invoke_cypher_udf(args, &output_type, |val_args| {
1751 if val_args.len() != 2 {
1752 return Err(datafusion::error::DataFusionError::Execution(
1753 "_duration_property(): requires 2 arguments (duration_string, component)"
1754 .to_string(),
1755 ));
1756 }
1757
1758 let dur_string_owned;
1759 let dur_str = match &val_args[0] {
1760 Value::String(s) => s.as_str(),
1761 Value::Temporal(uni_common::TemporalValue::Duration { .. }) => {
1762 dur_string_owned = val_args[0].to_string();
1763 &dur_string_owned
1764 }
1765 Value::Null => return Ok(Value::Null),
1766 _ => {
1767 return Err(datafusion::error::DataFusionError::Execution(
1768 "_duration_property(): duration must be a string or temporal duration"
1769 .to_string(),
1770 ));
1771 }
1772 };
1773 let component = match &val_args[1] {
1774 Value::String(s) => s,
1775 _ => {
1776 return Err(datafusion::error::DataFusionError::Execution(
1777 "_duration_property(): component must be a string".to_string(),
1778 ));
1779 }
1780 };
1781
1782 crate::query::datetime::eval_duration_accessor(dur_str, component).map_err(|e| {
1783 datafusion::error::DataFusionError::Execution(format!(
1784 "_duration_property(): {}",
1785 e
1786 ))
1787 })
1788 })
1789 }
1790}
1791
1792fn create_tostring_udf() -> ScalarUDF {
1797 ScalarUDF::new_from_impl(ToStringUdf::new())
1798}
1799
1800#[derive(Debug)]
1801struct ToStringUdf {
1802 signature: Signature,
1803}
1804
1805impl ToStringUdf {
1806 fn new() -> Self {
1807 Self {
1808 signature: Signature::variadic_any(Volatility::Immutable),
1809 }
1810 }
1811}
1812
1813impl_udf_eq_hash!(ToStringUdf);
1814
1815impl ScalarUDFImpl for ToStringUdf {
1816 fn as_any(&self) -> &dyn Any {
1817 self
1818 }
1819
1820 fn name(&self) -> &str {
1821 "tostring"
1822 }
1823
1824 fn signature(&self) -> &Signature {
1825 &self.signature
1826 }
1827
1828 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1829 Ok(DataType::Utf8)
1830 }
1831
1832 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1833 let output_type = self.return_type(&[])?;
1834 invoke_cypher_udf(args, &output_type, |val_args| {
1835 if val_args.is_empty() {
1836 return Err(datafusion::error::DataFusionError::Execution(
1837 "toString(): requires 1 argument".to_string(),
1838 ));
1839 }
1840 match &val_args[0] {
1841 Value::Null => Ok(Value::Null),
1842 Value::String(s) => Ok(Value::String(s.clone())),
1843 Value::Int(i) => Ok(Value::String(i.to_string())),
1844 Value::Float(f) => Ok(Value::String(f.to_string())),
1845 Value::Bool(b) => Ok(Value::String(b.to_string())),
1846 Value::Temporal(t) => Ok(Value::String(t.to_string())),
1847 other => {
1848 let type_name = match other {
1849 Value::List(_) => "List",
1850 Value::Map(_) => "Map",
1851 Value::Node { .. } => "Node",
1852 Value::Edge { .. } => "Relationship",
1853 Value::Path { .. } => "Path",
1854 _ => "Unknown",
1855 };
1856 Err(datafusion::error::DataFusionError::Execution(format!(
1857 "TypeError: InvalidArgumentValue - toString() does not accept {} values",
1858 type_name
1859 )))
1860 }
1861 }
1862 })
1863 }
1864}
1865
1866fn create_temporal_property_udf() -> ScalarUDF {
1871 ScalarUDF::new_from_impl(TemporalPropertyUdf::new())
1872}
1873
1874#[derive(Debug)]
1875struct TemporalPropertyUdf {
1876 signature: Signature,
1877}
1878
1879impl TemporalPropertyUdf {
1880 fn new() -> Self {
1881 Self {
1882 signature: Signature::variadic_any(Volatility::Immutable),
1883 }
1884 }
1885}
1886
1887impl_udf_eq_hash!(TemporalPropertyUdf);
1888
1889impl ScalarUDFImpl for TemporalPropertyUdf {
1890 fn as_any(&self) -> &dyn Any {
1891 self
1892 }
1893
1894 fn name(&self) -> &str {
1895 "_temporal_property"
1896 }
1897
1898 fn signature(&self) -> &Signature {
1899 &self.signature
1900 }
1901
1902 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1903 Ok(DataType::LargeBinary)
1904 }
1905
1906 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1907 let output_type = self.return_type(&[])?;
1908 invoke_cypher_udf(args, &output_type, |val_args| {
1909 if val_args.len() != 2 {
1910 return Err(datafusion::error::DataFusionError::Execution(
1911 "_temporal_property(): requires 2 arguments (temporal_value, component)"
1912 .to_string(),
1913 ));
1914 }
1915
1916 let component = match &val_args[1] {
1917 Value::String(s) => s.clone(),
1918 _ => {
1919 return Err(datafusion::error::DataFusionError::Execution(
1920 "_temporal_property(): component must be a string".to_string(),
1921 ));
1922 }
1923 };
1924
1925 crate::query::datetime::eval_temporal_accessor_value(&val_args[0], &component).map_err(
1926 |e| {
1927 datafusion::error::DataFusionError::Execution(format!(
1928 "_temporal_property(): {}",
1929 e
1930 ))
1931 },
1932 )
1933 })
1934 }
1935}
1936
/// Downcasts `$arr` (anything exposing `as_any()`, e.g. an `ArrayRef`) to the
/// concrete Arrow array type `$array_type`, yielding a reference to it.
///
/// On failure it returns early from the enclosing function through `?` with a
/// `DataFusionError::Execution` naming the target type, so it can only be used
/// inside functions whose error type accepts a `DataFusionError`.
macro_rules! downcast_arr {
    ($arr:expr, $array_type:ty) => {
        $arr.as_any().downcast_ref::<$array_type>().ok_or_else(|| {
            // `stringify!` embeds the requested type's source text in the message.
            datafusion::error::DataFusionError::Execution(format!(
                "Failed to downcast to {}",
                stringify!($array_type)
            ))
        })?
    };
}
1949
1950fn cypher_type_name(val: &Value) -> &'static str {
1952 match val {
1953 Value::Null => "Null",
1954 Value::Bool(_) => "Boolean",
1955 Value::Int(_) => "Integer",
1956 Value::Float(_) => "Float",
1957 Value::String(_) => "String",
1958 Value::Bytes(_) => "Bytes",
1959 Value::List(_) => "List",
1960 Value::Map(_) => "Map",
1961 Value::Node(_) => "Node",
1962 Value::Edge(_) => "Relationship",
1963 Value::Path(_) => "Path",
1964 Value::Vector(_) => "Vector",
1965 Value::Temporal(_) => "Temporal",
1966 _ => "Unknown",
1967 }
1968}
1969
1970fn string_to_value(s: &str) -> Value {
1972 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
1973 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
1974 {
1975 return Value::from(obj);
1976 }
1977 Value::String(s.to_string())
1978}
1979
1980fn get_value_from_array(arr: &ArrayRef, row: usize) -> DFResult<Value> {
1986 if arr.is_null(row) {
1987 return Ok(Value::Null);
1988 }
1989
1990 match arr.data_type() {
1991 DataType::LargeBinary => {
1992 let typed = downcast_arr!(arr, LargeBinaryArray);
1993 let bytes = typed.value(row);
1994 if let Ok(val) = uni_common::cypher_value_codec::decode(bytes) {
1995 return Ok(val);
1996 }
1997 Ok(serde_json::from_slice::<serde_json::Value>(bytes)
1999 .map(Value::from)
2000 .unwrap_or(Value::Null))
2001 }
2002 DataType::Int64 => Ok(Value::Int(downcast_arr!(arr, Int64Array).value(row))),
2003 DataType::Float64 => Ok(Value::Float(downcast_arr!(arr, Float64Array).value(row))),
2004 DataType::Utf8 => Ok(string_to_value(downcast_arr!(arr, StringArray).value(row))),
2005 DataType::LargeUtf8 => Ok(string_to_value(
2006 downcast_arr!(arr, LargeStringArray).value(row),
2007 )),
2008 DataType::Boolean => Ok(Value::Bool(downcast_arr!(arr, BooleanArray).value(row))),
2009 DataType::UInt64 => Ok(Value::Int(downcast_arr!(arr, UInt64Array).value(row) as i64)),
2010 DataType::Int32 => Ok(Value::Int(downcast_arr!(arr, Int32Array).value(row) as i64)),
2011 DataType::Float32 => Ok(Value::Float(
2012 downcast_arr!(arr, Float32Array).value(row) as f64
2013 )),
2014 _ => {
2017 let scalar = ScalarValue::try_from_array(arr, row).map_err(|e| {
2018 datafusion::error::DataFusionError::Execution(format!(
2019 "Cannot extract scalar from array at row {}: {}",
2020 row, e
2021 ))
2022 })?;
2023 scalar_to_value(&scalar)
2024 }
2025 }
2026}
2027
2028fn get_value_args_for_row(args: &[ColumnarValue], row: usize) -> DFResult<Vec<Value>> {
2030 args.iter()
2031 .map(|arg| match arg {
2032 ColumnarValue::Scalar(scalar) => scalar_to_value(scalar),
2033 ColumnarValue::Array(arr) => get_value_from_array(arr, row),
2034 })
2035 .collect()
2036}
2037
2038fn invoke_cypher_udf<F>(
2040 args: ScalarFunctionArgs,
2041 output_type: &DataType,
2042 f: F,
2043) -> DFResult<ColumnarValue>
2044where
2045 F: Fn(&[Value]) -> DFResult<Value>,
2046{
2047 let len = args
2048 .args
2049 .iter()
2050 .find_map(|arg| match arg {
2051 ColumnarValue::Array(arr) => Some(arr.len()),
2052 _ => None,
2053 })
2054 .unwrap_or(1);
2055
2056 if len == 1
2057 && args
2058 .args
2059 .iter()
2060 .all(|a| matches!(a, ColumnarValue::Scalar(_)))
2061 {
2062 let row_args = get_value_args_for_row(&args.args, 0)?;
2063 let res = f(&row_args)?;
2064 if matches!(output_type, DataType::LargeBinary | DataType::List(_)) {
2065 let arr = values_to_array(&[res], output_type)
2067 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2068 return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, 0)?));
2069 }
2070 if res.is_null() {
2072 let typed_null = ScalarValue::try_from(output_type).unwrap_or(ScalarValue::Utf8(None));
2073 return Ok(ColumnarValue::Scalar(typed_null));
2074 }
2075 return value_to_columnar(&res);
2076 }
2077
2078 let mut results = Vec::with_capacity(len);
2079 for i in 0..len {
2080 let row_args = get_value_args_for_row(&args.args, i)?;
2081 results.push(f(&row_args)?);
2082 }
2083
2084 let arr = values_to_array(&results, output_type)
2085 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2086 Ok(ColumnarValue::Array(arr))
2087}
2088
2089fn scalar_arr_to_value(arr: &dyn arrow::array::Array) -> DFResult<Value> {
2092 if arr.is_empty() || arr.is_null(0) {
2093 Ok(Value::Null)
2094 } else {
2095 Ok(uni_store::storage::arrow_convert::arrow_to_value(
2097 arr, 0, None,
2098 ))
2099 }
2100}
2101
2102fn resolve_timezone_offset(tz_name: &str, nanos_utc: i64) -> i32 {
2104 if tz_name == "UTC" || tz_name == "Z" {
2105 return 0;
2106 }
2107 if let Ok(tz) = tz_name.parse::<chrono_tz::Tz>() {
2108 let dt = chrono::DateTime::from_timestamp_nanos(nanos_utc).with_timezone(&tz);
2109 dt.offset().fix().local_minus_utc()
2110 } else {
2111 0
2112 }
2113}
2114
2115fn duration_micros_to_value(micros: i64) -> Value {
2117 let dur = crate::query::datetime::CypherDuration::from_micros(micros);
2118 Value::Temporal(uni_common::TemporalValue::Duration {
2119 months: dur.months,
2120 days: dur.days,
2121 nanos: dur.nanos,
2122 })
2123}
2124
2125fn timestamp_nanos_to_value(nanos: i64, tz: Option<&Arc<str>>) -> DFResult<Value> {
2127 if let Some(tz_str) = tz {
2128 let offset = resolve_timezone_offset(tz_str.as_ref(), nanos);
2129 let tz_name = if tz_str.as_ref() == "UTC" {
2130 None
2131 } else {
2132 Some(tz_str.to_string())
2133 };
2134 Ok(Value::Temporal(uni_common::TemporalValue::DateTime {
2135 nanos_since_epoch: nanos,
2136 offset_seconds: offset,
2137 timezone_name: tz_name,
2138 }))
2139 } else {
2140 Ok(Value::Temporal(uni_common::TemporalValue::LocalDateTime {
2141 nanos_since_epoch: nanos,
2142 }))
2143 }
2144}
2145
2146pub(crate) fn scalar_to_value(scalar: &ScalarValue) -> DFResult<Value> {
2148 match scalar {
2149 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
2150 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2153 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2154 {
2155 return Ok(Value::from(obj));
2156 }
2157 Ok(Value::String(s.clone()))
2158 }
2159 ScalarValue::LargeBinary(Some(b)) => {
2160 if let Ok(val) = uni_common::cypher_value_codec::decode(b) {
2163 return Ok(val);
2164 }
2165 if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(b) {
2166 Ok(Value::from(obj))
2167 } else {
2168 Ok(Value::Null)
2169 }
2170 }
2171 ScalarValue::Int64(Some(i)) => Ok(Value::Int(*i)),
2172 ScalarValue::Int32(Some(i)) => Ok(Value::Int(*i as i64)),
2173 ScalarValue::Float64(Some(f)) => {
2174 Ok(Value::Float(*f))
2176 }
2177 ScalarValue::Boolean(Some(b)) => Ok(Value::Bool(*b)),
2178 ScalarValue::Struct(arr) => scalar_arr_to_value(arr.as_ref()),
2179 ScalarValue::List(arr) => scalar_arr_to_value(arr.as_ref()),
2180 ScalarValue::LargeList(arr) => scalar_arr_to_value(arr.as_ref()),
2181 ScalarValue::FixedSizeList(arr) => scalar_arr_to_value(arr.as_ref()),
2182 ScalarValue::UInt64(Some(u)) => Ok(Value::Int(*u as i64)),
2184 ScalarValue::UInt32(Some(u)) => Ok(Value::Int(*u as i64)),
2185 ScalarValue::UInt16(Some(u)) => Ok(Value::Int(*u as i64)),
2186 ScalarValue::UInt8(Some(u)) => Ok(Value::Int(*u as i64)),
2187 ScalarValue::Int16(Some(i)) => Ok(Value::Int(*i as i64)),
2188 ScalarValue::Int8(Some(i)) => Ok(Value::Int(*i as i64)),
2189
2190 ScalarValue::Date32(Some(days)) => Ok(Value::Temporal(uni_common::TemporalValue::Date {
2192 days_since_epoch: *days,
2193 })),
2194 ScalarValue::Date64(Some(millis)) => {
2195 let days = (*millis / 86_400_000) as i32;
2196 Ok(Value::Temporal(uni_common::TemporalValue::Date {
2197 days_since_epoch: days,
2198 }))
2199 }
2200 ScalarValue::TimestampNanosecond(Some(nanos), tz) => {
2201 timestamp_nanos_to_value(*nanos, tz.as_ref())
2202 }
2203 ScalarValue::TimestampMicrosecond(Some(micros), tz) => {
2204 timestamp_nanos_to_value(*micros * 1_000, tz.as_ref())
2205 }
2206 ScalarValue::TimestampMillisecond(Some(millis), tz) => {
2207 timestamp_nanos_to_value(*millis * 1_000_000, tz.as_ref())
2208 }
2209 ScalarValue::TimestampSecond(Some(secs), tz) => {
2210 timestamp_nanos_to_value(*secs * 1_000_000_000, tz.as_ref())
2211 }
2212 ScalarValue::Time64Nanosecond(Some(nanos)) => {
2213 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2214 nanos_since_midnight: *nanos,
2215 }))
2216 }
2217 ScalarValue::Time64Microsecond(Some(micros)) => {
2218 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2219 nanos_since_midnight: *micros * 1_000,
2220 }))
2221 }
2222 ScalarValue::IntervalMonthDayNano(Some(v)) => {
2223 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2224 months: v.months as i64,
2225 days: v.days as i64,
2226 nanos: v.nanoseconds,
2227 }))
2228 }
2229 ScalarValue::DurationMicrosecond(Some(micros)) => Ok(duration_micros_to_value(*micros)),
2230 ScalarValue::DurationMillisecond(Some(millis)) => {
2231 Ok(duration_micros_to_value(*millis * 1_000))
2232 }
2233 ScalarValue::DurationSecond(Some(secs)) => Ok(duration_micros_to_value(*secs * 1_000_000)),
2234 ScalarValue::DurationNanosecond(Some(nanos)) => {
2235 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2236 months: 0,
2237 days: 0,
2238 nanos: *nanos,
2239 }))
2240 }
2241 ScalarValue::Float32(Some(f)) => Ok(Value::Float(*f as f64)),
2242
2243 ScalarValue::Null
2245 | ScalarValue::Utf8(None)
2246 | ScalarValue::LargeUtf8(None)
2247 | ScalarValue::LargeBinary(None)
2248 | ScalarValue::Int64(None)
2249 | ScalarValue::Int32(None)
2250 | ScalarValue::Int16(None)
2251 | ScalarValue::Int8(None)
2252 | ScalarValue::UInt64(None)
2253 | ScalarValue::UInt32(None)
2254 | ScalarValue::UInt16(None)
2255 | ScalarValue::UInt8(None)
2256 | ScalarValue::Float64(None)
2257 | ScalarValue::Float32(None)
2258 | ScalarValue::Boolean(None)
2259 | ScalarValue::Date32(None)
2260 | ScalarValue::Date64(None)
2261 | ScalarValue::TimestampMicrosecond(None, _)
2262 | ScalarValue::TimestampMillisecond(None, _)
2263 | ScalarValue::TimestampSecond(None, _)
2264 | ScalarValue::TimestampNanosecond(None, _)
2265 | ScalarValue::Time64Microsecond(None)
2266 | ScalarValue::Time64Nanosecond(None)
2267 | ScalarValue::DurationMicrosecond(None)
2268 | ScalarValue::DurationMillisecond(None)
2269 | ScalarValue::DurationSecond(None)
2270 | ScalarValue::DurationNanosecond(None)
2271 | ScalarValue::IntervalMonthDayNano(None) => Ok(Value::Null),
2272 other => Err(datafusion::error::DataFusionError::Execution(format!(
2273 "scalar_to_value(): unsupported scalar type {other:?}"
2274 ))),
2275 }
2276}
2277
2278fn value_to_columnar(val: &Value) -> DFResult<ColumnarValue> {
2280 let scalar = match val {
2281 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
2282 Value::Int(i) => ScalarValue::Int64(Some(*i)),
2283 Value::Float(f) => ScalarValue::Float64(Some(*f)),
2284 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
2285 Value::Null => ScalarValue::Utf8(None),
2286 Value::Temporal(tv) => {
2287 use uni_common::TemporalValue;
2288 match tv {
2289 TemporalValue::Date { days_since_epoch } => {
2290 ScalarValue::Date32(Some(*days_since_epoch))
2291 }
2292 TemporalValue::LocalTime {
2293 nanos_since_midnight,
2294 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2295 TemporalValue::Time {
2296 nanos_since_midnight,
2297 ..
2298 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2299 TemporalValue::LocalDateTime { nanos_since_epoch } => {
2300 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None)
2301 }
2302 TemporalValue::DateTime {
2303 nanos_since_epoch,
2304 timezone_name,
2305 ..
2306 } => {
2307 let tz = timezone_name.as_deref().unwrap_or("UTC");
2308 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), Some(tz.into()))
2309 }
2310 TemporalValue::Duration {
2311 months,
2312 days,
2313 nanos,
2314 } => ScalarValue::IntervalMonthDayNano(Some(
2315 arrow::datatypes::IntervalMonthDayNano {
2316 months: *months as i32,
2317 days: *days as i32,
2318 nanoseconds: *nanos,
2319 },
2320 )),
2321 }
2322 }
2323 other => {
2324 return Err(datafusion::error::DataFusionError::Execution(format!(
2325 "value_to_columnar(): unsupported type {other:?}"
2326 )));
2327 }
2328 };
2329 Ok(ColumnarValue::Scalar(scalar))
2330}
2331
2332pub fn create_has_null_udf() -> ScalarUDF {
2338 ScalarUDF::new_from_impl(HasNullUdf::new())
2339}
2340
2341#[derive(Debug)]
2342struct HasNullUdf {
2343 signature: Signature,
2344}
2345
2346impl HasNullUdf {
2347 fn new() -> Self {
2348 Self {
2349 signature: Signature::any(1, Volatility::Immutable),
2350 }
2351 }
2352}
2353
2354impl_udf_eq_hash!(HasNullUdf);
2355
2356impl ScalarUDFImpl for HasNullUdf {
2357 fn as_any(&self) -> &dyn Any {
2358 self
2359 }
2360
2361 fn name(&self) -> &str {
2362 "_has_null"
2363 }
2364
2365 fn signature(&self) -> &Signature {
2366 &self.signature
2367 }
2368
2369 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2370 Ok(DataType::Boolean)
2371 }
2372
2373 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2374 if args.args.len() != 1 {
2375 return Err(datafusion::error::DataFusionError::Execution(
2376 "_has_null(): requires 1 argument".to_string(),
2377 ));
2378 }
2379
2380 fn check_list_nulls<T: arrow_array::OffsetSizeTrait>(
2382 arr: &arrow_array::GenericListArray<T>,
2383 idx: usize,
2384 ) -> bool {
2385 if arr.is_null(idx) || arr.is_empty() {
2386 false
2387 } else {
2388 arr.value(idx).null_count() > 0
2389 }
2390 }
2391
2392 match &args.args[0] {
2393 ColumnarValue::Scalar(scalar) => {
2394 let has_null = match scalar {
2395 ScalarValue::List(arr) => arr
2396 .as_any()
2397 .downcast_ref::<arrow::array::ListArray>()
2398 .map(|a| !a.is_empty() && a.value(0).null_count() > 0)
2399 .unwrap_or(arr.null_count() > 0),
2400 ScalarValue::LargeList(arr) => arr.len() > 0 && arr.value(0).null_count() > 0,
2401 ScalarValue::FixedSizeList(arr) => {
2402 arr.len() > 0 && arr.value(0).null_count() > 0
2403 }
2404 _ => false,
2405 };
2406 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(has_null))))
2407 }
2408 ColumnarValue::Array(arr) => {
2409 use arrow_array::{LargeListArray, ListArray};
2410
2411 let results: arrow::array::BooleanArray =
2412 if let Some(list_arr) = arr.as_any().downcast_ref::<ListArray>() {
2413 (0..list_arr.len())
2414 .map(|i| {
2415 if list_arr.is_null(i) {
2416 None
2417 } else {
2418 Some(check_list_nulls(list_arr, i))
2419 }
2420 })
2421 .collect()
2422 } else if let Some(large) = arr.as_any().downcast_ref::<LargeListArray>() {
2423 (0..large.len())
2424 .map(|i| {
2425 if large.is_null(i) {
2426 None
2427 } else {
2428 Some(check_list_nulls(large, i))
2429 }
2430 })
2431 .collect()
2432 } else {
2433 return Err(datafusion::error::DataFusionError::Execution(
2434 "_has_null(): requires list array".to_string(),
2435 ));
2436 };
2437 Ok(ColumnarValue::Array(Arc::new(results)))
2438 }
2439 }
2440 }
2441}
2442
2443pub fn create_to_integer_udf() -> ScalarUDF {
2448 ScalarUDF::new_from_impl(ToIntegerUdf::new())
2449}
2450
2451#[derive(Debug)]
2452struct ToIntegerUdf {
2453 signature: Signature,
2454}
2455
2456impl ToIntegerUdf {
2457 fn new() -> Self {
2458 Self {
2459 signature: Signature::any(1, Volatility::Immutable),
2460 }
2461 }
2462}
2463
2464impl_udf_eq_hash!(ToIntegerUdf);
2465
2466impl ScalarUDFImpl for ToIntegerUdf {
2467 fn as_any(&self) -> &dyn Any {
2468 self
2469 }
2470
2471 fn name(&self) -> &str {
2472 "tointeger"
2473 }
2474
2475 fn signature(&self) -> &Signature {
2476 &self.signature
2477 }
2478
2479 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2480 Ok(DataType::Int64)
2481 }
2482
2483 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2484 let output_type = self.return_type(&[])?;
2485 invoke_cypher_udf(args, &output_type, |val_args| {
2486 if val_args.is_empty() {
2487 return Err(datafusion::error::DataFusionError::Execution(
2488 "tointeger(): requires 1 argument".to_string(),
2489 ));
2490 }
2491
2492 let val = &val_args[0];
2493 match val {
2494 Value::Int(i) => Ok(Value::Int(*i)),
2495 Value::Float(f) => Ok(Value::Int(*f as i64)),
2496 Value::String(s) => {
2497 if let Ok(i) = s.parse::<i64>() {
2498 Ok(Value::Int(i))
2499 } else if let Ok(f) = s.parse::<f64>() {
2500 Ok(Value::Int(f as i64))
2501 } else {
2502 Ok(Value::Null)
2503 }
2504 }
2505 Value::Null => Ok(Value::Null),
2506 other => Err(datafusion::error::DataFusionError::Execution(format!(
2507 "InvalidArgumentValue: tointeger(): cannot convert {} to integer",
2508 cypher_type_name(other)
2509 ))),
2510 }
2511 })
2512 }
2513}
2514
2515pub fn create_to_float_udf() -> ScalarUDF {
2520 ScalarUDF::new_from_impl(ToFloatUdf::new())
2521}
2522
2523#[derive(Debug)]
2524struct ToFloatUdf {
2525 signature: Signature,
2526}
2527
2528impl ToFloatUdf {
2529 fn new() -> Self {
2530 Self {
2531 signature: Signature::any(1, Volatility::Immutable),
2532 }
2533 }
2534}
2535
2536impl_udf_eq_hash!(ToFloatUdf);
2537
2538impl ScalarUDFImpl for ToFloatUdf {
2539 fn as_any(&self) -> &dyn Any {
2540 self
2541 }
2542
2543 fn name(&self) -> &str {
2544 "tofloat"
2545 }
2546
2547 fn signature(&self) -> &Signature {
2548 &self.signature
2549 }
2550
2551 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2552 Ok(DataType::Float64)
2553 }
2554
2555 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2556 let output_type = self.return_type(&[])?;
2557 invoke_cypher_udf(args, &output_type, |val_args| {
2558 if val_args.is_empty() {
2559 return Err(datafusion::error::DataFusionError::Execution(
2560 "tofloat(): requires 1 argument".to_string(),
2561 ));
2562 }
2563
2564 let val = &val_args[0];
2565 match val {
2566 Value::Int(i) => Ok(Value::Float(*i as f64)),
2567 Value::Float(f) => Ok(Value::Float(*f)),
2568 Value::String(s) => {
2569 if let Ok(f) = s.parse::<f64>() {
2570 Ok(Value::Float(f))
2571 } else {
2572 Ok(Value::Null)
2573 }
2574 }
2575 Value::Null => Ok(Value::Null),
2576 other => Err(datafusion::error::DataFusionError::Execution(format!(
2577 "InvalidArgumentValue: tofloat(): cannot convert {} to float",
2578 cypher_type_name(other)
2579 ))),
2580 }
2581 })
2582 }
2583}
2584
2585pub fn create_to_boolean_udf() -> ScalarUDF {
2590 ScalarUDF::new_from_impl(ToBooleanUdf::new())
2591}
2592
2593#[derive(Debug)]
2594struct ToBooleanUdf {
2595 signature: Signature,
2596}
2597
2598impl ToBooleanUdf {
2599 fn new() -> Self {
2600 Self {
2601 signature: Signature::any(1, Volatility::Immutable),
2602 }
2603 }
2604}
2605
2606impl_udf_eq_hash!(ToBooleanUdf);
2607
2608impl ScalarUDFImpl for ToBooleanUdf {
2609 fn as_any(&self) -> &dyn Any {
2610 self
2611 }
2612
2613 fn name(&self) -> &str {
2614 "toboolean"
2615 }
2616
2617 fn signature(&self) -> &Signature {
2618 &self.signature
2619 }
2620
2621 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2622 Ok(DataType::Boolean)
2623 }
2624
2625 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2626 let output_type = self.return_type(&[])?;
2627 invoke_cypher_udf(args, &output_type, |val_args| {
2628 if val_args.is_empty() {
2629 return Err(datafusion::error::DataFusionError::Execution(
2630 "toboolean(): requires 1 argument".to_string(),
2631 ));
2632 }
2633
2634 let val = &val_args[0];
2635 match val {
2636 Value::Bool(b) => Ok(Value::Bool(*b)),
2637 Value::String(s) => {
2638 let s_lower = s.to_lowercase();
2639 if s_lower == "true" {
2640 Ok(Value::Bool(true))
2641 } else if s_lower == "false" {
2642 Ok(Value::Bool(false))
2643 } else {
2644 Ok(Value::Null)
2645 }
2646 }
2647 Value::Null => Ok(Value::Null),
2648 Value::Int(i) => Ok(Value::Bool(*i != 0)),
2649 other => Err(datafusion::error::DataFusionError::Execution(format!(
2650 "InvalidArgumentValue: toboolean(): cannot convert {} to boolean",
2651 cypher_type_name(other)
2652 ))),
2653 }
2654 })
2655 }
2656}
2657
2658pub fn create_cypher_sort_key_udf() -> ScalarUDF {
2665 ScalarUDF::new_from_impl(CypherSortKeyUdf::new())
2666}
2667
2668#[derive(Debug)]
2669struct CypherSortKeyUdf {
2670 signature: Signature,
2671}
2672
2673impl CypherSortKeyUdf {
2674 fn new() -> Self {
2675 Self {
2676 signature: Signature::any(1, Volatility::Immutable),
2677 }
2678 }
2679}
2680
2681impl_udf_eq_hash!(CypherSortKeyUdf);
2682
2683impl ScalarUDFImpl for CypherSortKeyUdf {
2684 fn as_any(&self) -> &dyn Any {
2685 self
2686 }
2687
2688 fn name(&self) -> &str {
2689 "_cypher_sort_key"
2690 }
2691
2692 fn signature(&self) -> &Signature {
2693 &self.signature
2694 }
2695
2696 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2697 Ok(DataType::LargeBinary)
2698 }
2699
2700 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2701 if args.args.len() != 1 {
2702 return Err(datafusion::error::DataFusionError::Execution(
2703 "_cypher_sort_key(): requires 1 argument".to_string(),
2704 ));
2705 }
2706
2707 let arg = &args.args[0];
2708 match arg {
2709 ColumnarValue::Scalar(s) => {
2710 let val = if s.is_null() {
2711 Value::Null
2712 } else {
2713 scalar_to_value(s)?
2714 };
2715 let key = encode_cypher_sort_key(&val);
2716 Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(key))))
2717 }
2718 ColumnarValue::Array(arr) => {
2719 let mut keys: Vec<Option<Vec<u8>>> = Vec::with_capacity(arr.len());
2720 for i in 0..arr.len() {
2721 let val = if arr.is_null(i) {
2722 Value::Null
2723 } else {
2724 get_value_from_array(arr, i)?
2725 };
2726 keys.push(Some(encode_cypher_sort_key(&val)));
2727 }
2728 let array = LargeBinaryArray::from(
2729 keys.iter()
2730 .map(|k| k.as_deref())
2731 .collect::<Vec<Option<&[u8]>>>(),
2732 );
2733 Ok(ColumnarValue::Array(Arc::new(array)))
2734 }
2735 }
2736 }
2737}
2738
2739pub fn encode_cypher_sort_key(value: &Value) -> Vec<u8> {
2745 let mut buf = Vec::with_capacity(32);
2746 encode_sort_key_to_buf(value, &mut buf);
2747 buf
2748}
2749
/// Appends the order-preserving sort key for `value` to `buf`.
///
/// The key starts with a one-byte type rank (see `sort_key_type_rank`), so
/// values of different kinds order by rank first. A kind-specific payload follows.
/// Special cases are handled before the generic rank/payload path:
/// - maps recognised as temporal wrappers are keyed as temporals (rank 0x07);
/// - maps shaped like nodes, edges or paths (non-zero `sort_key_map_rank`)
///   use the matching `encode_map_as_*_payload` encoder;
/// - strings that parse as temporals are keyed as temporals (rank 0x07).
fn encode_sort_key_to_buf(value: &Value, buf: &mut Vec<u8>) {
    // Maps may stand in for temporals, nodes, edges or paths; check those shapes first.
    if let Value::Map(map) = value {
        if let Some(tv) = sort_key_map_as_temporal(map) {
            // Temporal rank, then the temporal payload.
            buf.push(0x07);
            encode_temporal_payload(&tv, buf);
            return;
        }
        let rank = sort_key_map_rank(map);
        if rank != 0 {
            buf.push(rank);
            // NOTE(review): rank 0x07 presumably cannot reach here because the
            // temporal check above already returned; any other rank gets no payload.
            match rank {
                0x01 => encode_map_as_node_payload(map, buf),
                0x02 => encode_map_as_edge_payload(map, buf),
                0x04 => encode_map_as_path_payload(map, buf),
                _ => {}
            }
            return;
        }
        // Plain maps (rank 0) fall through to the generic map payload below.
    }

    if let Value::String(s) = value {
        if let Some(tv) = sort_key_string_as_temporal(s) {
            buf.push(0x07);
            encode_temporal_payload(&tv, buf);
            return;
        }
        if let Some(temporal_type) = crate::query::datetime::classify_temporal(s) {
            // Speculatively push the temporal rank, then try the wide (i128)
            // encoder for temporal-looking strings the regular parser rejected.
            buf.push(0x07);
            if encode_wide_temporal_sort_key(s, temporal_type, buf) {
                return;
            }
            // The wide encoder writes nothing on failure: undo the speculative
            // rank byte and key the value as an ordinary string below.
            buf.pop();
        }
    }

    let rank = sort_key_type_rank(value);
    buf.push(rank);

    match value {
        // Null and NaN are fully identified by their rank byte; no payload.
        Value::Null => {}
        Value::Float(f) if f.is_nan() => {}
        Value::Bool(b) => buf.push(if *b { 0x01 } else { 0x00 }),
        Value::Int(i) => {
            // Ints share the numeric rank with floats and are encoded as f64 so the
            // two kinds interleave. NOTE(review): integers beyond 2^53 lose precision
            // here, so distinct large ints can produce identical keys.
            let f = *i as f64;
            buf.extend_from_slice(&encode_order_preserving_f64(f));
        }
        Value::Float(f) => {
            buf.extend_from_slice(&encode_order_preserving_f64(*f));
        }
        Value::String(s) => {
            // Byte-stuffed and terminated so shorter prefixes sort first.
            byte_stuff_terminate(s.as_bytes(), buf);
        }
        Value::Temporal(tv) => {
            encode_temporal_payload(tv, buf);
        }
        Value::List(items) => {
            encode_list_payload(items, buf);
        }
        Value::Map(map) => {
            encode_map_payload(map, buf);
        }
        Value::Node(node) => {
            encode_node_payload(node, buf);
        }
        Value::Edge(edge) => {
            encode_edge_payload(edge, buf);
        }
        Value::Path(path) => {
            encode_path_payload(path, buf);
        }
        Value::Bytes(b) => {
            byte_stuff_terminate(b, buf);
        }
        Value::Vector(v) => {
            // Elements are written back to back with no terminator; callers that
            // nest this key (lists, maps) byte-stuff and terminate it themselves.
            for f in v {
                buf.extend_from_slice(&encode_order_preserving_f64(*f as f64));
            }
        }
        // Any remaining variant is identified by its rank only.
        _ => {}
    }
}
2839
2840fn sort_key_type_rank(v: &Value) -> u8 {
2844 match v {
2845 Value::Map(map) => sort_key_map_rank(map),
2846 Value::Node(_) => 0x01,
2847 Value::Edge(_) => 0x02,
2848 Value::List(_) => 0x03,
2849 Value::Path(_) => 0x04,
2850 Value::String(_) => 0x05,
2851 Value::Bool(_) => 0x06,
2852 Value::Temporal(_) => 0x07,
2853 Value::Int(_) => 0x08,
2854 Value::Float(f) if f.is_nan() => 0x09,
2855 Value::Float(_) => 0x08,
2856 Value::Null => 0x0A,
2857 Value::Bytes(_) | Value::Vector(_) => 0x0B,
2858 _ => 0x0B, }
2860}
2861
2862fn sort_key_map_rank(map: &std::collections::HashMap<String, Value>) -> u8 {
2864 if sort_key_map_as_temporal(map).is_some() {
2865 0x07
2866 } else if map.contains_key("nodes")
2867 && (map.contains_key("relationships") || map.contains_key("edges"))
2868 {
2869 0x04 } else if map.contains_key("_eid")
2871 || map.contains_key("_src")
2872 || map.contains_key("_dst")
2873 || map.contains_key("_type")
2874 || map.contains_key("_type_name")
2875 {
2876 0x02 } else if map.contains_key("_vid") || map.contains_key("_labels") || map.contains_key("_label")
2878 {
2879 0x01 } else {
2881 0x00 }
2883}
2884
2885fn sort_key_map_as_temporal(
2889 map: &std::collections::HashMap<String, Value>,
2890) -> Option<uni_common::TemporalValue> {
2891 super::expr_eval::temporal_from_map_wrapper(map)
2892}
2893
2894fn sort_key_string_as_temporal(s: &str) -> Option<uni_common::TemporalValue> {
2898 super::expr_eval::temporal_from_value(&Value::String(s.to_string()))
2899}
2900
2901fn encode_wide_temporal_sort_key(
2908 s: &str,
2909 temporal_type: uni_common::TemporalType,
2910 buf: &mut Vec<u8>,
2911) -> bool {
2912 match temporal_type {
2913 uni_common::TemporalType::LocalDateTime => {
2914 if let Some(ndt) = parse_naive_datetime(s) {
2915 buf.push(0x03); let wide_nanos = naive_datetime_to_wide_nanos(&ndt);
2917 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
2918 return true;
2919 }
2920 false
2921 }
2922 uni_common::TemporalType::DateTime => {
2923 let base = if let Some(bracket_pos) = s.find('[') {
2925 &s[..bracket_pos]
2926 } else {
2927 s
2928 };
2929 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%.f%:z") {
2930 buf.push(0x04); let utc = dt.naive_utc();
2932 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
2933 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
2934 return true;
2935 }
2936 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%:z") {
2937 buf.push(0x04); let utc = dt.naive_utc();
2939 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
2940 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
2941 return true;
2942 }
2943 false
2944 }
2945 uni_common::TemporalType::Date => {
2946 if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
2947 && let Some(epoch) = chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
2948 {
2949 buf.push(0x00); let days = nd.signed_duration_since(epoch).num_days() as i32;
2951 buf.extend_from_slice(&encode_order_preserving_i32(days));
2952 return true;
2953 }
2954 false
2955 }
2956 _ => false,
2957 }
2958}
2959
2960fn parse_naive_datetime(s: &str) -> Option<chrono::NaiveDateTime> {
2962 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
2963 .ok()
2964 .or_else(|| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").ok())
2965}
2966
2967fn naive_datetime_to_wide_nanos(ndt: &chrono::NaiveDateTime) -> i128 {
2970 let secs = ndt.and_utc().timestamp() as i128;
2971 let subsec_nanos = ndt.and_utc().timestamp_subsec_nanos() as i128;
2972 secs * 1_000_000_000 + subsec_nanos
2973}
2974
2975fn encode_map_as_node_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
2977 let mut labels: Vec<String> = Vec::new();
2979 if let Some(Value::List(lbls)) = map.get("_labels") {
2980 for l in lbls {
2981 if let Value::String(s) = l {
2982 labels.push(s.clone());
2983 }
2984 }
2985 } else if let Some(Value::String(lbl)) = map.get("_label") {
2986 labels.push(lbl.clone());
2987 }
2988 labels.sort();
2989
2990 let vid = map.get("_vid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
2992
2993 let labels_joined = labels.join("\x01");
2995 byte_stuff_terminate(labels_joined.as_bytes(), buf);
2996
2997 buf.extend_from_slice(&vid.to_be_bytes());
2999
3000 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3002 for (k, v) in map {
3003 if !k.starts_with('_') {
3004 props.insert(k.clone(), v.clone());
3005 }
3006 }
3007 encode_map_payload(&props, buf);
3008}
3009
3010fn encode_map_as_edge_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3012 let edge_type = map
3013 .get("_type")
3014 .or_else(|| map.get("_type_name"))
3015 .and_then(|v| {
3016 if let Value::String(s) = v {
3017 Some(s.as_str())
3018 } else {
3019 None
3020 }
3021 })
3022 .unwrap_or("");
3023
3024 byte_stuff_terminate(edge_type.as_bytes(), buf);
3025
3026 let src = map.get("_src").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3027 let dst = map.get("_dst").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3028 let eid = map.get("_eid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3029
3030 buf.extend_from_slice(&src.to_be_bytes());
3031 buf.extend_from_slice(&dst.to_be_bytes());
3032 buf.extend_from_slice(&eid.to_be_bytes());
3033
3034 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3036 for (k, v) in map {
3037 if !k.starts_with('_') {
3038 props.insert(k.clone(), v.clone());
3039 }
3040 }
3041 encode_map_payload(&props, buf);
3042}
3043
3044fn encode_map_as_path_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3046 if let Some(Value::List(nodes)) = map.get("nodes") {
3048 encode_list_payload(nodes, buf);
3049 } else {
3050 buf.push(0x00); }
3052 let edges = map.get("relationships").or_else(|| map.get("edges"));
3054 if let Some(Value::List(edges)) = edges {
3055 encode_list_payload(edges, buf);
3056 } else {
3057 buf.push(0x00); }
3059}
3060
/// Maps an f64 to 8 big-endian bytes whose unsigned order matches numeric order.
///
/// Negative values have all bits inverted; non-negative values get the sign
/// bit set. As a result, -0.0 sorts just below 0.0.
fn encode_order_preserving_f64(f: f64) -> [u8; 8] {
    const SIGN_BIT: u64 = 0x8000_0000_0000_0000;
    let bits = f.to_bits();
    let flipped = if bits & SIGN_BIT != 0 {
        !bits
    } else {
        bits | SIGN_BIT
    };
    flipped.to_be_bytes()
}
3078
/// Maps an i64 to 8 big-endian bytes whose unsigned order matches signed order.
///
/// Flipping the sign bit sends `i64::MIN..=i64::MAX` onto `0..=u64::MAX`.
fn encode_order_preserving_i64(i: i64) -> [u8; 8] {
    let raw = u64::from_ne_bytes(i.to_ne_bytes());
    (raw ^ 0x8000_0000_0000_0000_u64).to_be_bytes()
}
3084
/// Maps an i32 to 4 big-endian bytes whose unsigned order matches signed order.
fn encode_order_preserving_i32(i: i32) -> [u8; 4] {
    let raw = u32::from_ne_bytes(i.to_ne_bytes());
    (raw ^ 0x8000_0000_u32).to_be_bytes()
}
3089
/// Maps an i128 to 16 big-endian bytes whose unsigned order matches signed order.
fn encode_order_preserving_i128(i: i128) -> [u8; 16] {
    let raw = u128::from_ne_bytes(i.to_ne_bytes());
    (raw ^ (1u128 << 127)).to_be_bytes()
}
3094
3095fn byte_stuff_terminate(data: &[u8], buf: &mut Vec<u8>) {
3100 byte_stuff(data, buf);
3101 buf.push(0x00);
3102 buf.push(0x00);
3103}
3104
/// Appends `data` to `buf`, escaping every 0x00 byte as 0x00 0xFF so that
/// 0x00 0x00 can never appear inside the stuffed output.
fn byte_stuff(data: &[u8], buf: &mut Vec<u8>) {
    for &byte in data {
        match byte {
            0x00 => buf.extend_from_slice(&[0x00, 0xFF]),
            other => buf.push(other),
        }
    }
}
3114
3115fn encode_list_payload(items: &[Value], buf: &mut Vec<u8>) {
3120 for item in items {
3121 buf.push(0x01); let elem_key = encode_cypher_sort_key(item);
3123 byte_stuff_terminate(&elem_key, buf);
3124 }
3125 buf.push(0x00); }
3127
3128fn encode_map_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3130 let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
3131 pairs.sort_by_key(|(k, _)| *k);
3132
3133 for (key, value) in pairs {
3134 buf.push(0x01); byte_stuff_terminate(key.as_bytes(), buf);
3136 let val_key = encode_cypher_sort_key(value);
3137 byte_stuff_terminate(&val_key, buf);
3138 }
3139 buf.push(0x00); }
3141
3142fn encode_node_payload(node: &uni_common::Node, buf: &mut Vec<u8>) {
3146 let mut labels = node.labels.clone();
3147 labels.sort();
3148 let labels_joined = labels.join("\x01");
3149 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3150
3151 buf.extend_from_slice(&node.vid.as_u64().to_be_bytes());
3152
3153 encode_map_payload(&node.properties, buf);
3154}
3155
3156fn encode_edge_payload(edge: &uni_common::Edge, buf: &mut Vec<u8>) {
3160 byte_stuff_terminate(edge.edge_type.as_bytes(), buf);
3161
3162 buf.extend_from_slice(&edge.src.as_u64().to_be_bytes());
3163 buf.extend_from_slice(&edge.dst.as_u64().to_be_bytes());
3164 buf.extend_from_slice(&edge.eid.as_u64().to_be_bytes());
3165
3166 encode_map_payload(&edge.properties, buf);
3167}
3168
3169fn encode_path_payload(path: &uni_common::Path, buf: &mut Vec<u8>) {
3173 for node in &path.nodes {
3175 buf.push(0x01); let mut node_key = Vec::new();
3177 node_key.push(0x01); encode_node_payload(node, &mut node_key);
3179 byte_stuff_terminate(&node_key, buf);
3180 }
3181 buf.push(0x00); for edge in &path.edges {
3185 buf.push(0x01); let mut edge_key = Vec::new();
3187 edge_key.push(0x02); encode_edge_payload(edge, &mut edge_key);
3189 byte_stuff_terminate(&edge_key, buf);
3190 }
3191 buf.push(0x00); }
3193
3194fn encode_temporal_payload(tv: &uni_common::TemporalValue, buf: &mut Vec<u8>) {
3196 match tv {
3197 uni_common::TemporalValue::Date { days_since_epoch } => {
3198 buf.push(0x00); buf.extend_from_slice(&encode_order_preserving_i32(*days_since_epoch));
3200 }
3201 uni_common::TemporalValue::LocalTime {
3202 nanos_since_midnight,
3203 } => {
3204 buf.push(0x01); buf.extend_from_slice(&encode_order_preserving_i64(*nanos_since_midnight));
3206 }
3207 uni_common::TemporalValue::Time {
3208 nanos_since_midnight,
3209 offset_seconds,
3210 } => {
3211 buf.push(0x02); let utc_nanos =
3213 *nanos_since_midnight as i128 - (*offset_seconds as i128) * 1_000_000_000;
3214 buf.extend_from_slice(&encode_order_preserving_i128(utc_nanos));
3215 }
3216 uni_common::TemporalValue::LocalDateTime { nanos_since_epoch } => {
3217 buf.push(0x03); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3220 }
3221 uni_common::TemporalValue::DateTime {
3222 nanos_since_epoch, ..
3223 } => {
3224 buf.push(0x04); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3227 }
3228 uni_common::TemporalValue::Duration {
3229 months,
3230 days,
3231 nanos,
3232 } => {
3233 buf.push(0x05); buf.extend_from_slice(&encode_order_preserving_i64(*months));
3235 buf.extend_from_slice(&encode_order_preserving_i64(*days));
3236 buf.extend_from_slice(&encode_order_preserving_i64(*nanos));
3237 }
3238 }
3239}
3240
3241pub fn invoke_cypher_string_op<F>(
3246 args: &ScalarFunctionArgs,
3247 name: &str,
3248 op: F,
3249) -> DFResult<ColumnarValue>
3250where
3251 F: Fn(&str, &str) -> bool,
3252{
3253 use arrow_array::{BooleanArray, LargeBinaryArray, LargeStringArray, StringArray};
3254 use datafusion::common::ScalarValue;
3255 use datafusion::error::DataFusionError;
3256
3257 if args.args.len() != 2 {
3258 return Err(DataFusionError::Execution(format!(
3259 "{}(): requires exactly 2 arguments",
3260 name
3261 )));
3262 }
3263
3264 let left = &args.args[0];
3265 let right = &args.args[1];
3266
3267 let extract_string = |scalar: &ScalarValue| -> Option<String> {
3269 match scalar {
3270 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Some(s.clone()),
3271 ScalarValue::LargeBinary(Some(bytes)) => {
3272 match uni_common::cypher_value_codec::decode(bytes) {
3274 Ok(uni_common::Value::String(s)) => Some(s),
3275 _ => None,
3276 }
3277 }
3278 ScalarValue::Utf8(None)
3279 | ScalarValue::LargeUtf8(None)
3280 | ScalarValue::LargeBinary(None)
3281 | ScalarValue::Null => None,
3282 _ => None,
3283 }
3284 };
3285
3286 match (left, right) {
3287 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Scalar(r_scalar)) => {
3288 let l_str = extract_string(l_scalar);
3289 let r_str = extract_string(r_scalar);
3290
3291 match (l_str, r_str) {
3292 (Some(l), Some(r)) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(op(
3293 &l, &r,
3294 ))))),
3295 _ => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
3296 }
3297 }
3298 (ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_scalar)) => {
3299 let r_val = extract_string(r_scalar);
3301
3302 if r_val.is_none() {
3303 let nulls = arrow_array::new_null_array(&DataType::Boolean, l_arr.len());
3305 return Ok(ColumnarValue::Array(nulls));
3306 }
3307 let pattern = r_val.unwrap();
3308
3309 let result_array = if let Some(arr) = l_arr.as_any().downcast_ref::<StringArray>() {
3311 arr.iter()
3312 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3313 .collect::<BooleanArray>()
3314 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeStringArray>() {
3315 arr.iter()
3316 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3317 .collect::<BooleanArray>()
3318 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3319 arr.iter()
3321 .map(|opt_bytes| {
3322 opt_bytes.and_then(|bytes| {
3323 match uni_common::cypher_value_codec::decode(bytes) {
3324 Ok(uni_common::Value::String(s)) => Some(op(&s, &pattern)),
3325 _ => None,
3326 }
3327 })
3328 })
3329 .collect::<BooleanArray>()
3330 } else {
3331 arrow_array::new_null_array(&DataType::Boolean, l_arr.len())
3333 .as_any()
3334 .downcast_ref::<BooleanArray>()
3335 .unwrap()
3336 .clone()
3337 };
3338
3339 Ok(ColumnarValue::Array(Arc::new(result_array)))
3340 }
3341 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Array(r_arr)) => {
3342 let l_val = extract_string(l_scalar);
3344
3345 if l_val.is_none() {
3346 let nulls = arrow_array::new_null_array(&DataType::Boolean, r_arr.len());
3347 return Ok(ColumnarValue::Array(nulls));
3348 }
3349 let target = l_val.unwrap();
3350
3351 let result_array = if let Some(arr) = r_arr.as_any().downcast_ref::<StringArray>() {
3352 arr.iter()
3353 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3354 .collect::<BooleanArray>()
3355 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeStringArray>() {
3356 arr.iter()
3357 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3358 .collect::<BooleanArray>()
3359 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3360 arr.iter()
3362 .map(|opt_bytes| {
3363 opt_bytes.and_then(|bytes| {
3364 match uni_common::cypher_value_codec::decode(bytes) {
3365 Ok(uni_common::Value::String(s)) => Some(op(&target, &s)),
3366 _ => None,
3367 }
3368 })
3369 })
3370 .collect::<BooleanArray>()
3371 } else {
3372 arrow_array::new_null_array(&DataType::Boolean, r_arr.len())
3374 .as_any()
3375 .downcast_ref::<BooleanArray>()
3376 .unwrap()
3377 .clone()
3378 };
3379
3380 Ok(ColumnarValue::Array(Arc::new(result_array)))
3381 }
3382 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
3383 if l_arr.len() != r_arr.len() {
3385 return Err(DataFusionError::Execution(format!(
3386 "{}(): array lengths must match",
3387 name
3388 )));
3389 }
3390
3391 let extract_string_at = |arr: &dyn Array, idx: usize| -> Option<String> {
3393 if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
3394 str_arr.value(idx).to_string().into()
3395 } else if let Some(str_arr) = arr.as_any().downcast_ref::<LargeStringArray>() {
3396 str_arr.value(idx).to_string().into()
3397 } else if let Some(bin_arr) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
3398 if bin_arr.is_null(idx) {
3399 return None;
3400 }
3401 let bytes = bin_arr.value(idx);
3402 match uni_common::cypher_value_codec::decode(bytes) {
3403 Ok(uni_common::Value::String(s)) => Some(s),
3404 _ => None,
3405 }
3406 } else {
3407 None
3408 }
3409 };
3410
3411 let result: BooleanArray = (0..l_arr.len())
3412 .map(|idx| {
3413 match (
3414 extract_string_at(l_arr.as_ref(), idx),
3415 extract_string_at(r_arr.as_ref(), idx),
3416 ) {
3417 (Some(l_str), Some(r_str)) => Some(op(&l_str, &r_str)),
3418 _ => None,
3419 }
3420 })
3421 .collect();
3422
3423 Ok(ColumnarValue::Array(Arc::new(result)))
3424 }
3425 }
3426}
3427
/// Defines a two-argument, Boolean-returning string-predicate UDF.
///
/// The macro parameters are:
/// - `$struct_name`: the generated UDF type;
/// - `$udf_name`: the name it is registered and reported under;
/// - `$op`: a `Fn(&str, &str) -> bool` closure that `invoke_cypher_string_op`
///   applies to `(lhs, rhs)`.
macro_rules! define_string_op_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                Self {
                    // Exactly two arguments of any type; type handling happens at runtime.
                    signature: Signature::any(2, Volatility::Immutable),
                }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn name(&self) -> &str {
                $udf_name
            }
            fn signature(&self) -> &Signature {
                &self.signature
            }
            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Boolean)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                // Null/non-string handling and scalar/array combinations live in the helper.
                invoke_cypher_string_op(&args, $udf_name, $op)
            }
        }
    };
}
3466
3467define_string_op_udf!(CypherStartsWithUdf, "_cypher_starts_with", |s, p| s
3468 .starts_with(p));
3469define_string_op_udf!(CypherEndsWithUdf, "_cypher_ends_with", |s, p| s
3470 .ends_with(p));
3471define_string_op_udf!(CypherContainsUdf, "_cypher_contains", |s, p| s.contains(p));
3472
3473pub fn create_cypher_starts_with_udf() -> ScalarUDF {
3474 ScalarUDF::new_from_impl(CypherStartsWithUdf::new())
3475}
3476pub fn create_cypher_ends_with_udf() -> ScalarUDF {
3477 ScalarUDF::new_from_impl(CypherEndsWithUdf::new())
3478}
3479pub fn create_cypher_contains_udf() -> ScalarUDF {
3480 ScalarUDF::new_from_impl(CypherContainsUdf::new())
3481}
3482
3483pub fn create_cypher_equal_udf() -> ScalarUDF {
3484 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_equal", BinaryOp::Eq))
3485}
3486pub fn create_cypher_not_equal_udf() -> ScalarUDF {
3487 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_not_equal", BinaryOp::NotEq))
3488}
3489pub fn create_cypher_lt_udf() -> ScalarUDF {
3490 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt", BinaryOp::Lt))
3491}
3492pub fn create_cypher_lt_eq_udf() -> ScalarUDF {
3493 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt_eq", BinaryOp::LtEq))
3494}
3495pub fn create_cypher_gt_udf() -> ScalarUDF {
3496 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt", BinaryOp::Gt))
3497}
3498pub fn create_cypher_gt_eq_udf() -> ScalarUDF {
3499 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt_eq", BinaryOp::GtEq))
3500}
3501
3502#[expect(clippy::match_like_matches_macro)]
3504fn apply_comparison_op(ord: std::cmp::Ordering, op: &BinaryOp) -> bool {
3505 use std::cmp::Ordering;
3506 match (ord, op) {
3507 (Ordering::Less, BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::NotEq) => true,
3508 (Ordering::Equal, BinaryOp::Eq | BinaryOp::LtEq | BinaryOp::GtEq) => true,
3509 (Ordering::Greater, BinaryOp::Gt | BinaryOp::GtEq | BinaryOp::NotEq) => true,
3510 _ => false,
3511 }
3512}
3513
3514fn compare_f64(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<bool> {
3517 if lhs.is_nan() || rhs.is_nan() {
3518 Some(matches!(op, BinaryOp::NotEq))
3519 } else {
3520 Some(apply_comparison_op(lhs.partial_cmp(&rhs)?, op))
3521 }
3522}
3523
3524fn cv_bytes_as_f64(bytes: &[u8]) -> Option<f64> {
3526 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
3527 match peek_tag(bytes)? {
3528 TAG_INT => decode_int(bytes).map(|i| i as f64),
3529 TAG_FLOAT => decode_float(bytes),
3530 _ => None,
3531 }
3532}
3533
3534fn compare_cv_numeric(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<bool> {
3537 use uni_common::cypher_value_codec::{TAG_INT, TAG_NULL, decode_int, peek_tag};
3538 if peek_tag(bytes) == Some(TAG_INT)
3540 && let Some(lhs_int) = decode_int(bytes)
3541 && rhs.fract() == 0.0
3543 && rhs >= i64::MIN as f64
3544 && rhs <= i64::MAX as f64
3545 {
3546 return Some(apply_comparison_op(lhs_int.cmp(&(rhs as i64)), op));
3547 }
3548 if peek_tag(bytes) == Some(TAG_NULL) {
3549 return None;
3550 }
3551 let lhs = cv_bytes_as_f64(bytes)?;
3552 compare_f64(lhs, rhs, op)
3553}
3554
3555fn try_fast_compare(
3559 lhs: &ColumnarValue,
3560 rhs: &ColumnarValue,
3561 op: &BinaryOp,
3562) -> Option<ColumnarValue> {
3563 use arrow_array::builder::BooleanBuilder;
3564 use uni_common::cypher_value_codec::{
3565 TAG_INT, TAG_NULL, TAG_STRING, decode_int, decode_string, peek_tag,
3566 };
3567
3568 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
3569 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
3570 _ => return None,
3571 };
3572
3573 if !matches!(lhs_arr.data_type(), DataType::LargeBinary) {
3575 return None;
3576 }
3577
3578 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
3579
3580 match rhs_arr.data_type() {
3581 DataType::Int64 => {
3583 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
3584 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
3585 for i in 0..lb_arr.len() {
3586 if lb_arr.is_null(i) || int_arr.is_null(i) {
3587 builder.append_null();
3588 } else {
3589 match compare_cv_numeric(lb_arr.value(i), int_arr.value(i) as f64, op) {
3590 Some(result) => builder.append_value(result),
3591 None => builder.append_null(),
3592 }
3593 }
3594 }
3595 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3596 }
3597
3598 DataType::Float64 => {
3600 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
3601 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
3602 for i in 0..lb_arr.len() {
3603 if lb_arr.is_null(i) || float_arr.is_null(i) {
3604 builder.append_null();
3605 } else {
3606 match compare_cv_numeric(lb_arr.value(i), float_arr.value(i), op) {
3607 Some(result) => builder.append_value(result),
3608 None => builder.append_null(),
3609 }
3610 }
3611 }
3612 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3613 }
3614
3615 DataType::Utf8 | DataType::LargeUtf8 => {
3617 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
3618 for i in 0..lb_arr.len() {
3619 if lb_arr.is_null(i) || rhs_arr.is_null(i) {
3620 builder.append_null();
3621 } else {
3622 let bytes = lb_arr.value(i);
3623 let rhs_str = if matches!(rhs_arr.data_type(), DataType::Utf8) {
3624 rhs_arr.as_any().downcast_ref::<StringArray>()?.value(i)
3625 } else {
3626 rhs_arr
3627 .as_any()
3628 .downcast_ref::<LargeStringArray>()?
3629 .value(i)
3630 };
3631 match peek_tag(bytes) {
3632 Some(TAG_STRING) => {
3633 if let Some(lhs_str) = decode_string(bytes) {
3634 builder.append_value(apply_comparison_op(
3635 lhs_str.as_str().cmp(rhs_str),
3636 op,
3637 ));
3638 } else {
3639 builder.append_null();
3640 }
3641 }
3642 _ => builder.append_null(),
3643 }
3644 }
3645 }
3646 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3647 }
3648
3649 DataType::LargeBinary => {
3651 let rhs_lb = rhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
3652 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
3653 for i in 0..lb_arr.len() {
3654 if lb_arr.is_null(i) || rhs_lb.is_null(i) {
3655 builder.append_null();
3656 } else {
3657 let lhs_bytes = lb_arr.value(i);
3658 let rhs_bytes = rhs_lb.value(i);
3659 let lhs_tag = peek_tag(lhs_bytes);
3660 let rhs_tag = peek_tag(rhs_bytes);
3661
3662 if lhs_tag == Some(TAG_NULL) || rhs_tag == Some(TAG_NULL) {
3664 builder.append_null();
3665 continue;
3666 }
3667
3668 if lhs_tag == Some(TAG_INT) && rhs_tag == Some(TAG_INT) {
3670 if let (Some(l), Some(r)) = (decode_int(lhs_bytes), decode_int(rhs_bytes)) {
3671 builder.append_value(apply_comparison_op(l.cmp(&r), op));
3672 } else {
3673 builder.append_null();
3674 }
3675 continue;
3676 }
3677
3678 if lhs_tag == Some(TAG_STRING) && rhs_tag == Some(TAG_STRING) {
3680 if let (Some(l), Some(r)) =
3681 (decode_string(lhs_bytes), decode_string(rhs_bytes))
3682 {
3683 builder.append_value(apply_comparison_op(l.cmp(&r), op));
3684 } else {
3685 builder.append_null();
3686 }
3687 continue;
3688 }
3689
3690 if let (Some(l), Some(r)) =
3692 (cv_bytes_as_f64(lhs_bytes), cv_bytes_as_f64(rhs_bytes))
3693 {
3694 match compare_f64(l, r, op) {
3695 Some(result) => builder.append_value(result),
3696 None => builder.append_null(),
3697 }
3698 } else {
3699 return None;
3703 }
3704 }
3705 }
3706 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3707 }
3708
3709 _ => None, }
3711}
3712
3713#[derive(Debug)]
3714struct CypherCompareUdf {
3715 name: String,
3716 op: BinaryOp,
3717 signature: Signature,
3718}
3719
3720impl CypherCompareUdf {
3721 fn new(name: &str, op: BinaryOp) -> Self {
3722 Self {
3723 name: name.to_string(),
3724 op,
3725 signature: Signature::any(2, Volatility::Immutable),
3726 }
3727 }
3728}
3729
3730impl PartialEq for CypherCompareUdf {
3731 fn eq(&self, other: &Self) -> bool {
3732 self.name == other.name
3733 }
3734}
3735
3736impl Eq for CypherCompareUdf {}
3737
3738impl std::hash::Hash for CypherCompareUdf {
3739 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
3740 self.name.hash(state);
3741 }
3742}
3743
3744impl ScalarUDFImpl for CypherCompareUdf {
3745 fn as_any(&self) -> &dyn Any {
3746 self
3747 }
3748 fn name(&self) -> &str {
3749 &self.name
3750 }
3751 fn signature(&self) -> &Signature {
3752 &self.signature
3753 }
3754 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
3755 Ok(DataType::Boolean)
3756 }
3757
3758 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
3759 if args.args.len() != 2 {
3760 return Err(datafusion::error::DataFusionError::Execution(format!(
3761 "{}(): requires 2 arguments",
3762 self.name
3763 )));
3764 }
3765
3766 if let Some(result) = try_fast_compare(&args.args[0], &args.args[1], &self.op) {
3768 return Ok(result);
3769 }
3770
3771 let output_type = DataType::Boolean;
3773 invoke_cypher_udf(args, &output_type, |val_args| {
3774 crate::query::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
3775 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
3776 })
3777 }
3778}
3779
3780pub fn create_cypher_add_udf() -> ScalarUDF {
3786 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_add", BinaryOp::Add))
3787}
3788pub fn create_cypher_sub_udf() -> ScalarUDF {
3789 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_sub", BinaryOp::Sub))
3790}
3791pub fn create_cypher_mul_udf() -> ScalarUDF {
3792 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mul", BinaryOp::Mul))
3793}
3794pub fn create_cypher_div_udf() -> ScalarUDF {
3795 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_div", BinaryOp::Div))
3796}
3797pub fn create_cypher_mod_udf() -> ScalarUDF {
3798 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mod", BinaryOp::Mod))
3799}
3800
3801pub fn create_cypher_abs_udf() -> ScalarUDF {
3803 ScalarUDF::new_from_impl(CypherAbsUdf::new())
3804}
3805
3806pub(crate) fn cypher_abs_expr(
3808 arg: datafusion::logical_expr::Expr,
3809) -> datafusion::logical_expr::Expr {
3810 datafusion::logical_expr::Expr::ScalarFunction(
3811 datafusion::logical_expr::expr::ScalarFunction::new_udf(
3812 Arc::new(create_cypher_abs_udf()),
3813 vec![arg],
3814 ),
3815 )
3816}
3817
3818#[derive(Debug)]
3819struct CypherAbsUdf {
3820 signature: Signature,
3821}
3822
3823impl CypherAbsUdf {
3824 fn new() -> Self {
3825 Self {
3826 signature: Signature::any(1, Volatility::Immutable),
3827 }
3828 }
3829}
3830
3831impl_udf_eq_hash!(CypherAbsUdf);
3832
3833impl ScalarUDFImpl for CypherAbsUdf {
3834 fn as_any(&self) -> &dyn Any {
3835 self
3836 }
3837 fn name(&self) -> &str {
3838 "_cypher_abs"
3839 }
3840 fn signature(&self) -> &Signature {
3841 &self.signature
3842 }
3843 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3844 Ok(DataType::LargeBinary)
3845 }
3846 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
3847 if args.args.len() != 1 {
3848 return Err(datafusion::error::DataFusionError::Execution(
3849 "_cypher_abs requires exactly 1 argument".into(),
3850 ));
3851 }
3852 invoke_cypher_udf(args, &DataType::LargeBinary, |val_args| {
3853 match &val_args[0] {
3854 Value::Int(i) => i.checked_abs().map(Value::Int).ok_or_else(|| {
3855 datafusion::error::DataFusionError::Execution(
3856 "integer overflow in abs()".into(),
3857 )
3858 }),
3859 Value::Float(f) => Ok(Value::Float(f.abs())),
3860 Value::Null => Ok(Value::Null),
3861 other => Err(datafusion::error::DataFusionError::Execution(format!(
3862 "abs() requires a numeric argument, got {other:?}"
3863 ))),
3864 }
3865 })
3866 }
3867}
3868
3869fn apply_int_arithmetic(lhs: i64, rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
3872 use uni_common::cypher_value_codec::encode_int;
3873 match op {
3874 BinaryOp::Add => lhs.checked_add(rhs).map(encode_int),
3875 BinaryOp::Sub => lhs.checked_sub(rhs).map(encode_int),
3876 BinaryOp::Mul => lhs.checked_mul(rhs).map(encode_int),
3877 BinaryOp::Div => {
3878 if rhs == 0 {
3880 None
3881 } else {
3882 lhs.checked_div(rhs).map(encode_int)
3883 }
3884 }
3885 BinaryOp::Mod => {
3886 if rhs == 0 {
3887 None
3888 } else {
3889 lhs.checked_rem(rhs).map(encode_int)
3890 }
3891 }
3892 _ => None,
3893 }
3894}
3895
3896fn apply_float_arithmetic(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
3898 use uni_common::cypher_value_codec::encode_float;
3899 let result = match op {
3900 BinaryOp::Add => lhs + rhs,
3901 BinaryOp::Sub => lhs - rhs,
3902 BinaryOp::Mul => lhs * rhs,
3903 BinaryOp::Div => lhs / rhs, BinaryOp::Mod => lhs % rhs,
3905 _ => return None,
3906 };
3907 Some(encode_float(result))
3908}
3909
3910fn cv_arithmetic_int(bytes: &[u8], rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
3913 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
3914 match peek_tag(bytes)? {
3915 TAG_INT => apply_int_arithmetic(decode_int(bytes)?, rhs, op),
3916 TAG_FLOAT => apply_float_arithmetic(decode_float(bytes)?, rhs as f64, op),
3917 _ => None,
3918 }
3919}
3920
3921fn cv_arithmetic_float(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
3924 let lhs = cv_bytes_as_f64(bytes)?;
3925 apply_float_arithmetic(lhs, rhs, op)
3926}
3927
3928fn try_fast_arithmetic(
3932 lhs: &ColumnarValue,
3933 rhs: &ColumnarValue,
3934 op: &BinaryOp,
3935) -> Option<ColumnarValue> {
3936 use arrow_array::builder::LargeBinaryBuilder;
3937
3938 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
3939 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
3940 _ => return None,
3941 };
3942
3943 match (lhs_arr.data_type(), rhs_arr.data_type()) {
3944 (DataType::LargeBinary, DataType::Int64) => {
3946 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
3947 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
3948 let mut builder = LargeBinaryBuilder::new();
3949 for i in 0..lb_arr.len() {
3950 if lb_arr.is_null(i) || int_arr.is_null(i) {
3951 builder.append_null();
3952 } else if let Some(bytes) = cv_arithmetic_int(lb_arr.value(i), int_arr.value(i), op)
3953 {
3954 builder.append_value(&bytes);
3955 } else {
3956 builder.append_null();
3957 }
3958 }
3959 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3960 }
3961
3962 (DataType::LargeBinary, DataType::Float64) => {
3964 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
3965 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
3966 let mut builder = LargeBinaryBuilder::new();
3967 for i in 0..lb_arr.len() {
3968 if lb_arr.is_null(i) || float_arr.is_null(i) {
3969 builder.append_null();
3970 } else if let Some(bytes) =
3971 cv_arithmetic_float(lb_arr.value(i), float_arr.value(i), op)
3972 {
3973 builder.append_value(&bytes);
3974 } else {
3975 builder.append_null();
3976 }
3977 }
3978 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3979 }
3980
3981 (DataType::Int64, DataType::Int64) => {
3983 let lhs_int = lhs_arr.as_any().downcast_ref::<Int64Array>()?;
3984 let rhs_int = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
3985 let mut builder = LargeBinaryBuilder::new();
3986 for i in 0..lhs_int.len() {
3987 if lhs_int.is_null(i) || rhs_int.is_null(i) {
3988 builder.append_null();
3989 } else if let Some(bytes) =
3990 apply_int_arithmetic(lhs_int.value(i), rhs_int.value(i), op)
3991 {
3992 builder.append_value(&bytes);
3993 } else {
3994 builder.append_null();
3995 }
3996 }
3997 Some(ColumnarValue::Array(Arc::new(builder.finish())))
3998 }
3999
4000 _ => None, }
4002}
4003
4004#[derive(Debug)]
4005struct CypherArithmeticUdf {
4006 name: String,
4007 op: BinaryOp,
4008 signature: Signature,
4009}
4010
4011impl CypherArithmeticUdf {
4012 fn new(name: &str, op: BinaryOp) -> Self {
4013 Self {
4014 name: name.to_string(),
4015 op,
4016 signature: Signature::any(2, Volatility::Immutable),
4017 }
4018 }
4019}
4020
4021impl PartialEq for CypherArithmeticUdf {
4022 fn eq(&self, other: &Self) -> bool {
4023 self.name == other.name
4024 }
4025}
4026
4027impl Eq for CypherArithmeticUdf {}
4028
4029impl std::hash::Hash for CypherArithmeticUdf {
4030 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4031 self.name.hash(state);
4032 }
4033}
4034
4035impl ScalarUDFImpl for CypherArithmeticUdf {
4036 fn as_any(&self) -> &dyn Any {
4037 self
4038 }
4039 fn name(&self) -> &str {
4040 &self.name
4041 }
4042 fn signature(&self) -> &Signature {
4043 &self.signature
4044 }
4045 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4046 Ok(DataType::LargeBinary) }
4048
4049 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4050 if args.args.len() != 2 {
4051 return Err(datafusion::error::DataFusionError::Execution(format!(
4052 "{}(): requires 2 arguments",
4053 self.name
4054 )));
4055 }
4056
4057 if let Some(result) = try_fast_arithmetic(&args.args[0], &args.args[1], &self.op) {
4059 return Ok(result);
4060 }
4061
4062 let output_type = DataType::LargeBinary;
4064 invoke_cypher_udf(args, &output_type, |val_args| {
4065 crate::query::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4066 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4067 })
4068 }
4069}
4070
4071pub fn create_cypher_xor_udf() -> ScalarUDF {
4076 ScalarUDF::new_from_impl(CypherXorUdf::new())
4077}
4078
4079#[derive(Debug)]
4080struct CypherXorUdf {
4081 signature: Signature,
4082}
4083
4084impl CypherXorUdf {
4085 fn new() -> Self {
4086 Self {
4087 signature: Signature::any(2, Volatility::Immutable),
4088 }
4089 }
4090}
4091
4092impl_udf_eq_hash!(CypherXorUdf);
4093
4094impl ScalarUDFImpl for CypherXorUdf {
4095 fn as_any(&self) -> &dyn Any {
4096 self
4097 }
4098 fn name(&self) -> &str {
4099 "_cypher_xor"
4100 }
4101 fn signature(&self) -> &Signature {
4102 &self.signature
4103 }
4104 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4105 Ok(DataType::Boolean)
4106 }
4107
4108 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4109 let output_type = DataType::Boolean;
4110 invoke_cypher_udf(args, &output_type, |val_args| {
4111 if val_args.len() != 2 {
4112 return Err(datafusion::error::DataFusionError::Execution(
4113 "_cypher_xor(): requires 2 arguments".to_string(),
4114 ));
4115 }
4116 let coerce_bool = |v: &Value| -> Value {
4118 match v {
4119 Value::String(s) if s == "true" => Value::Bool(true),
4120 Value::String(s) if s == "false" => Value::Bool(false),
4121 other => other.clone(),
4122 }
4123 };
4124 let left = coerce_bool(&val_args[0]);
4125 let right = coerce_bool(&val_args[1]);
4126 crate::query::expr_eval::eval_binary_op(&left, &BinaryOp::Xor, &right)
4127 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4128 })
4129 }
4130}
4131
4132pub fn create_cv_to_bool_udf() -> ScalarUDF {
4139 ScalarUDF::new_from_impl(CvToBoolUdf::new())
4140}
4141
4142#[derive(Debug)]
4143struct CvToBoolUdf {
4144 signature: Signature,
4145}
4146
4147impl CvToBoolUdf {
4148 fn new() -> Self {
4149 Self {
4150 signature: Signature::exact(vec![DataType::LargeBinary], Volatility::Immutable),
4151 }
4152 }
4153}
4154
4155impl_udf_eq_hash!(CvToBoolUdf);
4156
4157impl ScalarUDFImpl for CvToBoolUdf {
4158 fn as_any(&self) -> &dyn Any {
4159 self
4160 }
4161 fn name(&self) -> &str {
4162 "_cv_to_bool"
4163 }
4164 fn signature(&self) -> &Signature {
4165 &self.signature
4166 }
4167 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4168 Ok(DataType::Boolean)
4169 }
4170
4171 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4172 if args.args.len() != 1 {
4173 return Err(datafusion::error::DataFusionError::Execution(
4174 "_cv_to_bool() requires exactly 1 argument".to_string(),
4175 ));
4176 }
4177
4178 match &args.args[0] {
4179 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
4180 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4182 let b = match peek_tag(bytes) {
4183 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4184 Some(TAG_NULL) => false,
4185 _ => false, };
4187 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))))
4188 }
4189 ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
4190 ColumnarValue::Array(arr) => {
4191 let lb_arr = arr
4192 .as_any()
4193 .downcast_ref::<arrow_array::LargeBinaryArray>()
4194 .ok_or_else(|| {
4195 datafusion::error::DataFusionError::Execution(format!(
4196 "_cv_to_bool(): expected LargeBinary array, got {:?}",
4197 arr.data_type()
4198 ))
4199 })?;
4200
4201 let mut builder = arrow_array::builder::BooleanBuilder::with_capacity(lb_arr.len());
4202
4203 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4205
4206 for i in 0..lb_arr.len() {
4207 if lb_arr.is_null(i) {
4208 builder.append_null();
4209 } else {
4210 let bytes = lb_arr.value(i);
4211 let b = match peek_tag(bytes) {
4212 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4213 Some(TAG_NULL) => false,
4214 _ => false, };
4216 builder.append_value(b);
4217 }
4218 }
4219 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
4220 }
4221 }
4222 }
4223}
4224
4225pub fn create_cypher_size_udf() -> ScalarUDF {
4231 ScalarUDF::new_from_impl(CypherSizeUdf::new())
4232}
4233
4234#[derive(Debug)]
4235struct CypherSizeUdf {
4236 signature: Signature,
4237}
4238
4239impl CypherSizeUdf {
4240 fn new() -> Self {
4241 Self {
4242 signature: Signature::any(1, Volatility::Immutable),
4243 }
4244 }
4245}
4246
4247impl_udf_eq_hash!(CypherSizeUdf);
4248
4249impl ScalarUDFImpl for CypherSizeUdf {
4250 fn as_any(&self) -> &dyn Any {
4251 self
4252 }
4253
4254 fn name(&self) -> &str {
4255 "_cypher_size"
4256 }
4257
4258 fn signature(&self) -> &Signature {
4259 &self.signature
4260 }
4261
4262 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4263 Ok(DataType::Int64)
4264 }
4265
4266 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4267 if args.args.len() != 1 {
4268 return Err(datafusion::error::DataFusionError::Execution(
4269 "_cypher_size() requires exactly 1 argument".to_string(),
4270 ));
4271 }
4272
4273 match &args.args[0] {
4274 ColumnarValue::Scalar(scalar) => {
4275 let result = cypher_size_scalar(scalar)?;
4276 Ok(ColumnarValue::Scalar(result))
4277 }
4278 ColumnarValue::Array(arr) => {
4279 let mut results: Vec<Option<i64>> = Vec::with_capacity(arr.len());
4280 for i in 0..arr.len() {
4281 if arr.is_null(i) {
4282 results.push(None);
4283 } else {
4284 let scalar = ScalarValue::try_from_array(arr, i)?;
4285 match cypher_size_scalar(&scalar)? {
4286 ScalarValue::Int64(v) => results.push(v),
4287 _ => results.push(None),
4288 }
4289 }
4290 }
4291 let arr: ArrayRef = Arc::new(arrow_array::Int64Array::from(results));
4292 Ok(ColumnarValue::Array(arr))
4293 }
4294 }
4295 }
4296}
4297
/// Computes the Cypher `size()` / `length()` of a single scalar value.
///
/// Returns:
/// - `Int64(Some(n))` for values that have a size;
/// - `Int64(None)` for nulls and for values whose size is undefined here;
/// - an execution error for node/relationship values or any unlisted `ScalarValue` variant.
///
/// NOTE(review): the node/relationship error messages mention `length()` even when the caller
/// invoked `size()`.
fn cypher_size_scalar(scalar: &ScalarValue) -> DFResult<ScalarValue> {
    match scalar {
        // Strings: count Unicode scalar values (`char`s), not UTF-8 bytes.
        ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
            Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
        }
        // A list scalar wraps a list array whose row 0 is the value. Its size is that row's
        // element count; a missing or null row yields null.
        ScalarValue::List(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        // Same as `List`, for 64-bit offsets.
        ScalarValue::LargeList(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        // Encoded bytes: decode, reject graph entities, then size the value through its JSON
        // form. Undecodable bytes yield null.
        ScalarValue::LargeBinary(Some(b)) => {
            if let Ok(uni_val) = uni_common::cypher_value_codec::decode(b) {
                match &uni_val {
                    uni_common::Value::Node(_) => {
                        Err(datafusion::error::DataFusionError::Execution(
                            "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
                        ))
                    }
                    uni_common::Value::Edge(_) => {
                        Err(datafusion::error::DataFusionError::Execution(
                            "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
                        ))
                    }
                    _ => {
                        // JSON arrays count elements, strings count chars and objects count
                        // entries. Any other JSON value (number, bool, null) has no size.
                        let json_val: serde_json::Value = uni_val.into();
                        match json_val {
                            serde_json::Value::Array(arr) => Ok(ScalarValue::Int64(Some(arr.len() as i64))),
                            serde_json::Value::String(s) => {
                                Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
                            }
                            serde_json::Value::Object(m) => Ok(ScalarValue::Int64(Some(m.len() as i64))),
                            _ => Ok(ScalarValue::Int64(None)),
                        }
                    }
                }
            } else {
                Ok(ScalarValue::Int64(None))
            }
        }
        // Map scalar: the entry count of row 0 of the wrapped map array.
        ScalarValue::Map(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        // Struct scalar: classify by field names to detect graph entities and paths.
        ScalarValue::Struct(arr) => {
            if arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                let schema = arr.fields();
                let field_names: Vec<&str> = schema.iter().map(|f| f.name().as_str()).collect();
                // `_vid` without a `relationships` field is treated as a node.
                if field_names.contains(&"_vid") && !field_names.contains(&"relationships") {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
                    ));
                }
                // `_eid`, or both `_src` and `_dst`, is treated as a relationship.
                if field_names.contains(&"_eid")
                    || (field_names.contains(&"_src") && field_names.contains(&"_dst"))
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
                    ));
                }
                // A `relationships` column marks a path. Its size is the number of
                // relationships in row 0, and a null list counts as 0.
                if let Some((rels_idx, _)) = schema
                    .iter()
                    .enumerate()
                    .find(|(_, f)| f.name() == "relationships")
                {
                    let rels_col = arr.column(rels_idx);
                    if let Some(list_arr) =
                        rels_col.as_any().downcast_ref::<arrow_array::ListArray>()
                    {
                        if list_arr.is_null(0) {
                            Ok(ScalarValue::Int64(Some(0)))
                        } else {
                            Ok(ScalarValue::Int64(Some(list_arr.value(0).len() as i64)))
                        }
                    } else {
                        // NOTE(review): a `relationships` column that is not a `ListArray`
                        // (e.g. LargeList) falls back to the field count — confirm intended.
                        Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
                    }
                } else {
                    // Any other struct is sized like a map: the number of fields.
                    Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
                }
            }
        }
        // Typed and untyped nulls have no size.
        ScalarValue::Null
        | ScalarValue::Utf8(None)
        | ScalarValue::LargeUtf8(None)
        | ScalarValue::LargeBinary(None) => Ok(ScalarValue::Int64(None)),
        other => Err(datafusion::error::DataFusionError::Execution(format!(
            "_cypher_size(): unsupported type {other:?}"
        ))),
    }
}
4414
4415pub fn create_cypher_list_compare_udf() -> ScalarUDF {
4421 ScalarUDF::new_from_impl(CypherListCompareUdf::new())
4422}
4423
4424#[derive(Debug)]
4425struct CypherListCompareUdf {
4426 signature: Signature,
4427}
4428
4429impl CypherListCompareUdf {
4430 fn new() -> Self {
4431 Self {
4432 signature: Signature::any(3, Volatility::Immutable),
4433 }
4434 }
4435}
4436
4437impl_udf_eq_hash!(CypherListCompareUdf);
4438
4439impl ScalarUDFImpl for CypherListCompareUdf {
4440 fn as_any(&self) -> &dyn Any {
4441 self
4442 }
4443
4444 fn name(&self) -> &str {
4445 "_cypher_list_compare"
4446 }
4447
4448 fn signature(&self) -> &Signature {
4449 &self.signature
4450 }
4451
4452 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4453 Ok(DataType::Boolean)
4454 }
4455
4456 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4457 let output_type = DataType::Boolean;
4458 invoke_cypher_udf(args, &output_type, |val_args| {
4459 if val_args.len() != 3 {
4460 return Err(datafusion::error::DataFusionError::Execution(
4461 "_cypher_list_compare(): requires 3 arguments (left, right, op)".to_string(),
4462 ));
4463 }
4464
4465 let left = &val_args[0];
4466 let right = &val_args[1];
4467 let op_str = match &val_args[2] {
4468 Value::String(s) => s.as_str(),
4469 _ => {
4470 return Err(datafusion::error::DataFusionError::Execution(
4471 "_cypher_list_compare(): op must be a string".to_string(),
4472 ));
4473 }
4474 };
4475
4476 let (left_items, right_items) = match (left, right) {
4477 (Value::List(l), Value::List(r)) => (l, r),
4478 (Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
4479 _ => {
4480 return Err(datafusion::error::DataFusionError::Execution(
4481 "_cypher_list_compare(): both arguments must be lists".to_string(),
4482 ));
4483 }
4484 };
4485
4486 let cmp = cypher_list_cmp(left_items, right_items);
4488
4489 let result = match (op_str, cmp) {
4490 (_, None) => Value::Null,
4491 ("lt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Less),
4492 ("lteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Greater),
4493 ("gt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Greater),
4494 ("gteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Less),
4495 _ => {
4496 return Err(datafusion::error::DataFusionError::Execution(format!(
4497 "_cypher_list_compare(): unknown op '{}'",
4498 op_str
4499 )));
4500 }
4501 };
4502
4503 Ok(result)
4504 })
4505 }
4506}
4507
4508pub fn create_map_project_udf() -> ScalarUDF {
4513 ScalarUDF::new_from_impl(MapProjectUdf::new())
4514}
4515
4516#[derive(Debug)]
4517struct MapProjectUdf {
4518 signature: Signature,
4519}
4520
4521impl MapProjectUdf {
4522 fn new() -> Self {
4523 Self {
4524 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
4525 }
4526 }
4527}
4528
4529impl_udf_eq_hash!(MapProjectUdf);
4530
4531impl ScalarUDFImpl for MapProjectUdf {
4532 fn as_any(&self) -> &dyn Any {
4533 self
4534 }
4535
4536 fn name(&self) -> &str {
4537 "_map_project"
4538 }
4539
4540 fn signature(&self) -> &Signature {
4541 &self.signature
4542 }
4543
4544 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4545 Ok(DataType::LargeBinary)
4546 }
4547
4548 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4549 let output_type = self.return_type(&[])?;
4550 invoke_cypher_udf(args, &output_type, |val_args| {
4551 let mut result_map = std::collections::HashMap::new();
4552 let mut i = 0;
4553 while i + 1 < val_args.len() {
4554 let key = &val_args[i];
4555 let value = &val_args[i + 1];
4556 if let Some(k) = key.as_str() {
4557 if k == "__all__" {
4558 match value {
4560 Value::Map(map) => {
4561 for (mk, mv) in map {
4562 if !mk.starts_with('_') {
4563 result_map.insert(mk.clone(), mv.clone());
4564 }
4565 }
4566 }
4567 Value::Node(node) => {
4568 for (pk, pv) in &node.properties {
4569 result_map.insert(pk.clone(), pv.clone());
4570 }
4571 }
4572 Value::Edge(edge) => {
4573 for (pk, pv) in &edge.properties {
4574 result_map.insert(pk.clone(), pv.clone());
4575 }
4576 }
4577 _ => {}
4578 }
4579 } else {
4580 result_map.insert(k.to_string(), value.clone());
4581 }
4582 }
4583 i += 2;
4584 }
4585 Ok(Value::Map(result_map))
4586 })
4587 }
4588}
4589
4590pub fn create_make_cypher_list_udf() -> ScalarUDF {
4595 ScalarUDF::new_from_impl(MakeCypherListUdf::new())
4596}
4597
4598#[derive(Debug)]
4599struct MakeCypherListUdf {
4600 signature: Signature,
4601}
4602
4603impl MakeCypherListUdf {
4604 fn new() -> Self {
4605 Self {
4606 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
4607 }
4608 }
4609}
4610
4611impl_udf_eq_hash!(MakeCypherListUdf);
4612
4613impl ScalarUDFImpl for MakeCypherListUdf {
4614 fn as_any(&self) -> &dyn Any {
4615 self
4616 }
4617
4618 fn name(&self) -> &str {
4619 "_make_cypher_list"
4620 }
4621
4622 fn signature(&self) -> &Signature {
4623 &self.signature
4624 }
4625
4626 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4627 Ok(DataType::LargeBinary)
4628 }
4629
4630 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4631 let output_type = self.return_type(&[])?;
4632 invoke_cypher_udf(args, &output_type, |val_args| {
4633 Ok(Value::List(val_args.to_vec()))
4634 })
4635 }
4636}
4637
4638pub fn create_cypher_in_udf() -> ScalarUDF {
4655 ScalarUDF::new_from_impl(CypherInUdf::new())
4656}
4657
4658#[derive(Debug)]
4659struct CypherInUdf {
4660 signature: Signature,
4661}
4662
4663impl CypherInUdf {
4664 fn new() -> Self {
4665 Self {
4666 signature: Signature::any(2, Volatility::Immutable),
4667 }
4668 }
4669}
4670
4671impl_udf_eq_hash!(CypherInUdf);
4672
4673impl ScalarUDFImpl for CypherInUdf {
4674 fn as_any(&self) -> &dyn Any {
4675 self
4676 }
4677
4678 fn name(&self) -> &str {
4679 "_cypher_in"
4680 }
4681
4682 fn signature(&self) -> &Signature {
4683 &self.signature
4684 }
4685
4686 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4687 Ok(DataType::Boolean)
4688 }
4689
4690 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4691 invoke_cypher_udf(args, &DataType::Boolean, |vals| {
4692 if vals.len() != 2 {
4693 return Err(datafusion::error::DataFusionError::Execution(
4694 "_cypher_in(): requires 2 arguments".to_string(),
4695 ));
4696 }
4697 let element = &vals[0];
4698 let list_val = &vals[1];
4699
4700 if list_val.is_null() {
4702 return Ok(Value::Null);
4703 }
4704
4705 let items = match list_val {
4707 Value::List(items) => items.as_slice(),
4708 _ => {
4709 return Err(datafusion::error::DataFusionError::Execution(format!(
4710 "_cypher_in(): second argument must be a list, got {:?}",
4711 list_val
4712 )));
4713 }
4714 };
4715
4716 if element.is_null() {
4718 return if items.is_empty() {
4719 Ok(Value::Bool(false))
4720 } else {
4721 Ok(Value::Null) };
4723 }
4724
4725 let mut has_null = false;
4727 for item in items {
4728 match cypher_eq(element, item) {
4729 Some(true) => return Ok(Value::Bool(true)),
4730 None => has_null = true,
4731 Some(false) => {}
4732 }
4733 }
4734
4735 if has_null {
4736 Ok(Value::Null) } else {
4738 Ok(Value::Bool(false))
4739 }
4740 })
4741 }
4742}
4743
4744pub fn create_cypher_list_concat_udf() -> ScalarUDF {
4750 ScalarUDF::new_from_impl(CypherListConcatUdf::new())
4751}
4752
4753#[derive(Debug)]
4754struct CypherListConcatUdf {
4755 signature: Signature,
4756}
4757
4758impl CypherListConcatUdf {
4759 fn new() -> Self {
4760 Self {
4761 signature: Signature::any(2, Volatility::Immutable),
4762 }
4763 }
4764}
4765
4766impl_udf_eq_hash!(CypherListConcatUdf);
4767
4768impl ScalarUDFImpl for CypherListConcatUdf {
4769 fn as_any(&self) -> &dyn Any {
4770 self
4771 }
4772
4773 fn name(&self) -> &str {
4774 "_cypher_list_concat"
4775 }
4776
4777 fn signature(&self) -> &Signature {
4778 &self.signature
4779 }
4780
4781 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4782 Ok(DataType::LargeBinary)
4783 }
4784
4785 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4786 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
4787 if vals.len() != 2 {
4788 return Err(datafusion::error::DataFusionError::Execution(
4789 "_cypher_list_concat(): requires 2 arguments".to_string(),
4790 ));
4791 }
4792 if vals[0].is_null() || vals[1].is_null() {
4794 return Ok(Value::Null);
4795 }
4796 match (&vals[0], &vals[1]) {
4797 (Value::List(left), Value::List(right)) => {
4798 let mut result = left.clone();
4799 result.extend(right.iter().cloned());
4800 Ok(Value::List(result))
4801 }
4802 (Value::List(list), elem) => {
4805 let mut result = list.clone();
4806 result.push(elem.clone());
4807 Ok(Value::List(result))
4808 }
4809 (elem, Value::List(list)) => {
4810 let mut result = vec![elem.clone()];
4811 result.extend(list.iter().cloned());
4812 Ok(Value::List(result))
4813 }
4814 _ => {
4815 crate::query::expr_eval::eval_binary_op(
4818 &vals[0],
4819 &uni_cypher::ast::BinaryOp::Add,
4820 &vals[1],
4821 )
4822 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4823 }
4824 }
4825 })
4826 }
4827}
4828
4829pub fn create_cypher_list_append_udf() -> ScalarUDF {
4835 ScalarUDF::new_from_impl(CypherListAppendUdf::new())
4836}
4837
4838#[derive(Debug)]
4839struct CypherListAppendUdf {
4840 signature: Signature,
4841}
4842
4843impl CypherListAppendUdf {
4844 fn new() -> Self {
4845 Self {
4846 signature: Signature::any(2, Volatility::Immutable),
4847 }
4848 }
4849}
4850
4851impl_udf_eq_hash!(CypherListAppendUdf);
4852
4853impl ScalarUDFImpl for CypherListAppendUdf {
4854 fn as_any(&self) -> &dyn Any {
4855 self
4856 }
4857
4858 fn name(&self) -> &str {
4859 "_cypher_list_append"
4860 }
4861
4862 fn signature(&self) -> &Signature {
4863 &self.signature
4864 }
4865
4866 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4867 Ok(DataType::LargeBinary)
4868 }
4869
4870 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4871 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
4872 if vals.len() != 2 {
4873 return Err(datafusion::error::DataFusionError::Execution(
4874 "_cypher_list_append(): requires 2 arguments".to_string(),
4875 ));
4876 }
4877 let left = &vals[0];
4878 let right = &vals[1];
4879
4880 if left.is_null() || right.is_null() {
4882 return Ok(Value::Null);
4883 }
4884
4885 match (left, right) {
4886 (Value::List(list), elem) => {
4888 let mut result = list.clone();
4889 result.push(elem.clone());
4890 Ok(Value::List(result))
4891 }
4892 (elem, Value::List(list)) => {
4894 let mut result = vec![elem.clone()];
4895 result.extend(list.iter().cloned());
4896 Ok(Value::List(result))
4897 }
4898 _ => Err(datafusion::error::DataFusionError::Execution(format!(
4899 "_cypher_list_append(): at least one argument must be a list, got {:?} and {:?}",
4900 left, right
4901 ))),
4902 }
4903 })
4904 }
4905}
4906
4907pub fn create_cypher_list_slice_udf() -> ScalarUDF {
4913 ScalarUDF::new_from_impl(CypherListSliceUdf::new())
4914}
4915
4916#[derive(Debug)]
4917struct CypherListSliceUdf {
4918 signature: Signature,
4919}
4920
4921impl CypherListSliceUdf {
4922 fn new() -> Self {
4923 Self {
4924 signature: Signature::any(3, Volatility::Immutable),
4925 }
4926 }
4927}
4928
4929impl_udf_eq_hash!(CypherListSliceUdf);
4930
4931impl ScalarUDFImpl for CypherListSliceUdf {
4932 fn as_any(&self) -> &dyn Any {
4933 self
4934 }
4935
4936 fn name(&self) -> &str {
4937 "_cypher_list_slice"
4938 }
4939
4940 fn signature(&self) -> &Signature {
4941 &self.signature
4942 }
4943
4944 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4945 Ok(DataType::LargeBinary)
4946 }
4947
4948 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4949 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
4950 if vals.len() != 3 {
4951 return Err(datafusion::error::DataFusionError::Execution(
4952 "_cypher_list_slice(): requires 3 arguments (list, start, end)".to_string(),
4953 ));
4954 }
4955 if vals[0].is_null() {
4957 return Ok(Value::Null);
4958 }
4959 let list = match &vals[0] {
4960 Value::List(l) => l,
4961 _ => {
4962 return Err(datafusion::error::DataFusionError::Execution(format!(
4963 "_cypher_list_slice(): first argument must be a list, got {:?}",
4964 vals[0]
4965 )));
4966 }
4967 };
4968 if vals[1].is_null() || vals[2].is_null() {
4970 return Ok(Value::Null);
4971 }
4972
4973 let len = list.len() as i64;
4974 let raw_start = match &vals[1] {
4975 Value::Int(i) => *i,
4976 _ => 0,
4977 };
4978 let raw_end = match &vals[2] {
4979 Value::Int(i) => *i,
4980 _ => len,
4981 };
4982
4983 let start = if raw_start < 0 {
4985 (len + raw_start).max(0) as usize
4986 } else {
4987 (raw_start).min(len) as usize
4988 };
4989 let end = if raw_end == i64::MAX {
4990 len as usize
4991 } else if raw_end < 0 {
4992 (len + raw_end).max(0) as usize
4993 } else {
4994 (raw_end).min(len) as usize
4995 };
4996
4997 if start >= end {
4998 return Ok(Value::List(vec![]));
4999 }
5000 Ok(Value::List(list[start..end.min(list.len())].to_vec()))
5001 })
5002 }
5003}
5004
5005pub fn create_cypher_reverse_udf() -> ScalarUDF {
5016 ScalarUDF::new_from_impl(CypherReverseUdf::new())
5017}
5018
5019#[derive(Debug)]
5020struct CypherReverseUdf {
5021 signature: Signature,
5022}
5023
5024impl CypherReverseUdf {
5025 fn new() -> Self {
5026 Self {
5027 signature: Signature::any(1, Volatility::Immutable),
5028 }
5029 }
5030}
5031
5032impl_udf_eq_hash!(CypherReverseUdf);
5033
5034impl ScalarUDFImpl for CypherReverseUdf {
5035 fn as_any(&self) -> &dyn Any {
5036 self
5037 }
5038
5039 fn name(&self) -> &str {
5040 "_cypher_reverse"
5041 }
5042
5043 fn signature(&self) -> &Signature {
5044 &self.signature
5045 }
5046
5047 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5048 Ok(DataType::LargeBinary)
5049 }
5050
5051 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5052 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5053 if vals.len() != 1 {
5054 return Err(datafusion::error::DataFusionError::Execution(
5055 "_cypher_reverse(): requires exactly 1 argument".to_string(),
5056 ));
5057 }
5058 match &vals[0] {
5059 Value::Null => Ok(Value::Null),
5060 Value::String(s) => Ok(Value::String(s.chars().rev().collect())),
5061 Value::List(l) => {
5062 let mut reversed = l.clone();
5063 reversed.reverse();
5064 Ok(Value::List(reversed))
5065 }
5066 other => Err(datafusion::error::DataFusionError::Execution(format!(
5067 "_cypher_reverse(): expected string or list, got {:?}",
5068 other
5069 ))),
5070 }
5071 })
5072 }
5073}
5074
5075pub fn create_cypher_substring_udf() -> ScalarUDF {
5086 ScalarUDF::new_from_impl(CypherSubstringUdf::new())
5087}
5088
5089#[derive(Debug)]
5090struct CypherSubstringUdf {
5091 signature: Signature,
5092}
5093
5094impl CypherSubstringUdf {
5095 fn new() -> Self {
5096 Self {
5097 signature: Signature::variadic_any(Volatility::Immutable),
5098 }
5099 }
5100}
5101
5102impl_udf_eq_hash!(CypherSubstringUdf);
5103
5104impl ScalarUDFImpl for CypherSubstringUdf {
5105 fn as_any(&self) -> &dyn Any {
5106 self
5107 }
5108
5109 fn name(&self) -> &str {
5110 "_cypher_substring"
5111 }
5112
5113 fn signature(&self) -> &Signature {
5114 &self.signature
5115 }
5116
5117 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5118 Ok(DataType::Utf8)
5119 }
5120
5121 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5122 invoke_cypher_udf(args, &DataType::Utf8, |vals| {
5123 if vals.len() < 2 || vals.len() > 3 {
5124 return Err(datafusion::error::DataFusionError::Execution(
5125 "_cypher_substring(): requires 2 or 3 arguments".to_string(),
5126 ));
5127 }
5128 if vals.iter().any(|v| v.is_null()) {
5130 return Ok(Value::Null);
5131 }
5132 let s = match &vals[0] {
5133 Value::String(s) => s.as_str(),
5134 other => {
5135 return Err(datafusion::error::DataFusionError::Execution(format!(
5136 "_cypher_substring(): first argument must be a string, got {:?}",
5137 other
5138 )));
5139 }
5140 };
5141 let start = match &vals[1] {
5142 Value::Int(i) => *i,
5143 other => {
5144 return Err(datafusion::error::DataFusionError::Execution(format!(
5145 "_cypher_substring(): second argument must be an integer, got {:?}",
5146 other
5147 )));
5148 }
5149 };
5150
5151 let chars: Vec<char> = s.chars().collect();
5153 let len = chars.len() as i64;
5154
5155 let start_idx = start.max(0).min(len) as usize;
5157
5158 let end_idx = if vals.len() == 3 {
5159 let length = match &vals[2] {
5160 Value::Int(i) => *i,
5161 other => {
5162 return Err(datafusion::error::DataFusionError::Execution(format!(
5163 "_cypher_substring(): third argument must be an integer, got {:?}",
5164 other
5165 )));
5166 }
5167 };
5168 if length < 0 {
5169 return Err(datafusion::error::DataFusionError::Execution(
5170 "ArgumentError: NegativeIntegerArgument - substring length must be non-negative".to_string(),
5171 ));
5172 }
5173 (start_idx as i64 + length).min(len) as usize
5174 } else {
5175 len as usize
5176 };
5177
5178 Ok(Value::String(chars[start_idx..end_idx].iter().collect()))
5179 })
5180 }
5181}
5182
5183pub fn create_cypher_split_udf() -> ScalarUDF {
5192 ScalarUDF::new_from_impl(CypherSplitUdf::new())
5193}
5194
5195#[derive(Debug)]
5196struct CypherSplitUdf {
5197 signature: Signature,
5198}
5199
5200impl CypherSplitUdf {
5201 fn new() -> Self {
5202 Self {
5203 signature: Signature::any(2, Volatility::Immutable),
5204 }
5205 }
5206}
5207
5208impl_udf_eq_hash!(CypherSplitUdf);
5209
5210impl ScalarUDFImpl for CypherSplitUdf {
5211 fn as_any(&self) -> &dyn Any {
5212 self
5213 }
5214
5215 fn name(&self) -> &str {
5216 "_cypher_split"
5217 }
5218
5219 fn signature(&self) -> &Signature {
5220 &self.signature
5221 }
5222
5223 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5224 Ok(DataType::LargeBinary)
5225 }
5226
5227 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5228 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5229 if vals.len() != 2 {
5230 return Err(datafusion::error::DataFusionError::Execution(
5231 "_cypher_split(): requires exactly 2 arguments".to_string(),
5232 ));
5233 }
5234 if vals.iter().any(|v| v.is_null()) {
5236 return Ok(Value::Null);
5237 }
5238 let s = match &vals[0] {
5239 Value::String(s) => s.clone(),
5240 other => {
5241 return Err(datafusion::error::DataFusionError::Execution(format!(
5242 "_cypher_split(): first argument must be a string, got {:?}",
5243 other
5244 )));
5245 }
5246 };
5247 let delimiter = match &vals[1] {
5248 Value::String(d) => d.clone(),
5249 other => {
5250 return Err(datafusion::error::DataFusionError::Execution(format!(
5251 "_cypher_split(): second argument must be a string, got {:?}",
5252 other
5253 )));
5254 }
5255 };
5256 let parts: Vec<Value> = s
5257 .split(&delimiter)
5258 .map(|p| Value::String(p.to_string()))
5259 .collect();
5260 Ok(Value::List(parts))
5261 })
5262 }
5263}
5264
5265pub fn create_cypher_list_to_cv_udf() -> ScalarUDF {
5276 ScalarUDF::new_from_impl(CypherListToCvUdf::new())
5277}
5278
5279#[derive(Debug)]
5280struct CypherListToCvUdf {
5281 signature: Signature,
5282}
5283
5284impl CypherListToCvUdf {
5285 fn new() -> Self {
5286 Self {
5287 signature: Signature::any(1, Volatility::Immutable),
5288 }
5289 }
5290}
5291
5292impl_udf_eq_hash!(CypherListToCvUdf);
5293
5294impl ScalarUDFImpl for CypherListToCvUdf {
5295 fn as_any(&self) -> &dyn Any {
5296 self
5297 }
5298
5299 fn name(&self) -> &str {
5300 "_cypher_list_to_cv"
5301 }
5302
5303 fn signature(&self) -> &Signature {
5304 &self.signature
5305 }
5306
5307 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5308 Ok(DataType::LargeBinary)
5309 }
5310
5311 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5312 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5313 if vals.len() != 1 {
5314 return Err(datafusion::error::DataFusionError::Execution(
5315 "_cypher_list_to_cv(): requires exactly 1 argument".to_string(),
5316 ));
5317 }
5318 Ok(vals[0].clone())
5319 })
5320 }
5321}
5322
5323pub fn create_cypher_scalar_to_cv_udf() -> ScalarUDF {
5334 ScalarUDF::new_from_impl(CypherScalarToCvUdf::new())
5335}
5336
5337#[derive(Debug)]
5338struct CypherScalarToCvUdf {
5339 signature: Signature,
5340}
5341
5342impl CypherScalarToCvUdf {
5343 fn new() -> Self {
5344 Self {
5345 signature: Signature::any(1, Volatility::Immutable),
5346 }
5347 }
5348}
5349
5350impl_udf_eq_hash!(CypherScalarToCvUdf);
5351
5352impl ScalarUDFImpl for CypherScalarToCvUdf {
5353 fn as_any(&self) -> &dyn Any {
5354 self
5355 }
5356
5357 fn name(&self) -> &str {
5358 "_cypher_scalar_to_cv"
5359 }
5360
5361 fn signature(&self) -> &Signature {
5362 &self.signature
5363 }
5364
5365 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5366 Ok(DataType::LargeBinary)
5367 }
5368
5369 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5370 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5371 if vals.len() != 1 {
5372 return Err(datafusion::error::DataFusionError::Execution(
5373 "_cypher_scalar_to_cv(): requires exactly 1 argument".to_string(),
5374 ));
5375 }
5376 Ok(vals[0].clone())
5377 })
5378 }
5379}
5380
5381pub fn create_cypher_tail_udf() -> ScalarUDF {
5393 ScalarUDF::new_from_impl(CypherTailUdf::new())
5394}
5395
5396#[derive(Debug)]
5397struct CypherTailUdf {
5398 signature: Signature,
5399}
5400
5401impl CypherTailUdf {
5402 fn new() -> Self {
5403 Self {
5404 signature: Signature::any(1, Volatility::Immutable),
5405 }
5406 }
5407}
5408
5409impl_udf_eq_hash!(CypherTailUdf);
5410
5411impl ScalarUDFImpl for CypherTailUdf {
5412 fn as_any(&self) -> &dyn Any {
5413 self
5414 }
5415
5416 fn name(&self) -> &str {
5417 "_cypher_tail"
5418 }
5419
5420 fn signature(&self) -> &Signature {
5421 &self.signature
5422 }
5423
5424 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5425 Ok(DataType::LargeBinary)
5426 }
5427
5428 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5429 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5430 if vals.len() != 1 {
5431 return Err(datafusion::error::DataFusionError::Execution(
5432 "_cypher_tail(): requires exactly 1 argument".to_string(),
5433 ));
5434 }
5435 match &vals[0] {
5436 Value::Null => Ok(Value::Null),
5437 Value::List(l) => {
5438 if l.is_empty() {
5439 Ok(Value::List(vec![]))
5440 } else {
5441 Ok(Value::List(l[1..].to_vec()))
5442 }
5443 }
5444 other => Err(datafusion::error::DataFusionError::Execution(format!(
5445 "_cypher_tail(): expected list, got {:?}",
5446 other
5447 ))),
5448 }
5449 })
5450 }
5451}
5452
5453pub fn create_cypher_head_udf() -> ScalarUDF {
5464 ScalarUDF::new_from_impl(CypherHeadUdf::new())
5465}
5466
5467#[derive(Debug)]
5468struct CypherHeadUdf {
5469 signature: Signature,
5470}
5471
5472impl CypherHeadUdf {
5473 fn new() -> Self {
5474 Self {
5475 signature: Signature::any(1, Volatility::Immutable),
5476 }
5477 }
5478}
5479
5480impl_udf_eq_hash!(CypherHeadUdf);
5481
5482impl ScalarUDFImpl for CypherHeadUdf {
5483 fn as_any(&self) -> &dyn Any {
5484 self
5485 }
5486
5487 fn name(&self) -> &str {
5488 "head"
5489 }
5490
5491 fn signature(&self) -> &Signature {
5492 &self.signature
5493 }
5494
5495 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5496 Ok(DataType::LargeBinary)
5497 }
5498
5499 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5500 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5501 if vals.len() != 1 {
5502 return Err(datafusion::error::DataFusionError::Execution(
5503 "head(): requires exactly 1 argument".to_string(),
5504 ));
5505 }
5506 match &vals[0] {
5507 Value::Null => Ok(Value::Null),
5508 Value::List(l) => Ok(l.first().cloned().unwrap_or(Value::Null)),
5509 other => Err(datafusion::error::DataFusionError::Execution(format!(
5510 "head(): expected list, got {:?}",
5511 other
5512 ))),
5513 }
5514 })
5515 }
5516}
5517
5518pub fn create_cypher_last_udf() -> ScalarUDF {
5529 ScalarUDF::new_from_impl(CypherLastUdf::new())
5530}
5531
5532#[derive(Debug)]
5533struct CypherLastUdf {
5534 signature: Signature,
5535}
5536
5537impl CypherLastUdf {
5538 fn new() -> Self {
5539 Self {
5540 signature: Signature::any(1, Volatility::Immutable),
5541 }
5542 }
5543}
5544
5545impl_udf_eq_hash!(CypherLastUdf);
5546
5547impl ScalarUDFImpl for CypherLastUdf {
5548 fn as_any(&self) -> &dyn Any {
5549 self
5550 }
5551
5552 fn name(&self) -> &str {
5553 "last"
5554 }
5555
5556 fn signature(&self) -> &Signature {
5557 &self.signature
5558 }
5559
5560 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5561 Ok(DataType::LargeBinary)
5562 }
5563
5564 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5565 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5566 if vals.len() != 1 {
5567 return Err(datafusion::error::DataFusionError::Execution(
5568 "last(): requires exactly 1 argument".to_string(),
5569 ));
5570 }
5571 match &vals[0] {
5572 Value::Null => Ok(Value::Null),
5573 Value::List(l) => Ok(l.last().cloned().unwrap_or(Value::Null)),
5574 other => Err(datafusion::error::DataFusionError::Execution(format!(
5575 "last(): expected list, got {:?}",
5576 other
5577 ))),
5578 }
5579 })
5580 }
5581}
5582
5583fn cypher_list_cmp(left: &[Value], right: &[Value]) -> Option<std::cmp::Ordering> {
5586 let min_len = left.len().min(right.len());
5587 for i in 0..min_len {
5588 let cmp = cypher_value_cmp(&left[i], &right[i])?;
5589 if cmp != std::cmp::Ordering::Equal {
5590 return Some(cmp);
5591 }
5592 }
5593 Some(left.len().cmp(&right.len()))
5595}
5596
5597fn cypher_value_cmp(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
5600 match (a, b) {
5601 (Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
5602 (Value::Null, _) | (_, Value::Null) => None,
5603 (Value::Int(l), Value::Int(r)) => Some(l.cmp(r)),
5604 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r),
5605 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r),
5606 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)),
5607 (Value::String(l), Value::String(r)) => Some(l.cmp(r)),
5608 (Value::Bool(l), Value::Bool(r)) => Some(l.cmp(r)),
5609 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r),
5610 _ => None, }
5612}
5613
5614struct CypherToFloat64Udf {
5622 signature: Signature,
5623}
5624
5625impl CypherToFloat64Udf {
5626 fn new() -> Self {
5627 Self {
5628 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
5629 }
5630 }
5631}
5632
5633impl_udf_eq_hash!(CypherToFloat64Udf);
5634
5635impl std::fmt::Debug for CypherToFloat64Udf {
5636 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5637 f.debug_struct("CypherToFloat64Udf").finish()
5638 }
5639}
5640
5641impl ScalarUDFImpl for CypherToFloat64Udf {
5642 fn as_any(&self) -> &dyn Any {
5643 self
5644 }
5645 fn name(&self) -> &str {
5646 "_cypher_to_float64"
5647 }
5648 fn signature(&self) -> &Signature {
5649 &self.signature
5650 }
5651 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
5652 Ok(DataType::Float64)
5653 }
5654 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5655 if args.args.len() != 1 {
5656 return Err(datafusion::error::DataFusionError::Execution(
5657 "_cypher_to_float64 requires exactly 1 argument".into(),
5658 ));
5659 }
5660 match &args.args[0] {
5661 ColumnarValue::Scalar(scalar) => {
5662 let f = match scalar {
5663 ScalarValue::LargeBinary(Some(bytes)) => cv_bytes_as_f64(bytes),
5664 ScalarValue::Int64(Some(i)) => Some(*i as f64),
5665 ScalarValue::Int32(Some(i)) => Some(*i as f64),
5666 ScalarValue::Float64(Some(f)) => Some(*f),
5667 ScalarValue::Float32(Some(f)) => Some(*f as f64),
5668 _ => None,
5669 };
5670 Ok(ColumnarValue::Scalar(ScalarValue::Float64(f)))
5671 }
5672 ColumnarValue::Array(arr) => {
5673 let len = arr.len();
5674 let mut builder = arrow::array::Float64Builder::with_capacity(len);
5675 match arr.data_type() {
5676 DataType::LargeBinary => {
5677 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
5678 for i in 0..len {
5679 if lb.is_null(i) {
5680 builder.append_null();
5681 } else {
5682 match cv_bytes_as_f64(lb.value(i)) {
5683 Some(f) => builder.append_value(f),
5684 None => builder.append_null(),
5685 }
5686 }
5687 }
5688 }
5689 DataType::Int64 => {
5690 let int_arr = arr.as_any().downcast_ref::<Int64Array>().unwrap();
5691 for i in 0..len {
5692 if int_arr.is_null(i) {
5693 builder.append_null();
5694 } else {
5695 builder.append_value(int_arr.value(i) as f64);
5696 }
5697 }
5698 }
5699 DataType::Float64 => {
5700 let f_arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
5701 for i in 0..len {
5702 if f_arr.is_null(i) {
5703 builder.append_null();
5704 } else {
5705 builder.append_value(f_arr.value(i));
5706 }
5707 }
5708 }
5709 _ => {
5710 for _ in 0..len {
5711 builder.append_null();
5712 }
5713 }
5714 }
5715 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
5716 }
5717 }
5718 }
5719}
5720
5721fn create_cypher_to_float64_udf() -> ScalarUDF {
5722 ScalarUDF::from(CypherToFloat64Udf::new())
5723}
5724
5725pub(crate) fn cypher_to_float64_expr(
5727 arg: datafusion::logical_expr::Expr,
5728) -> datafusion::logical_expr::Expr {
5729 datafusion::logical_expr::Expr::ScalarFunction(
5730 datafusion::logical_expr::expr::ScalarFunction::new_udf(
5731 Arc::new(create_cypher_to_float64_udf()),
5732 vec![arg],
5733 ),
5734 )
5735}
5736
5737pub(crate) fn cypher_to_float64_udf() -> datafusion::logical_expr::ScalarUDF {
5739 create_cypher_to_float64_udf()
5740}
5741
5742fn cypher_type_rank(val: &Value) -> u8 {
5750 match val {
5751 Value::Null => 0,
5752 Value::List(_) => 1,
5753 Value::String(_) => 2,
5754 Value::Bool(_) => 3,
5755 Value::Int(_) | Value::Float(_) => 4,
5756 _ => 5, }
5758}
5759
5760fn cypher_cross_type_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
5763 use std::cmp::Ordering;
5764 let ra = cypher_type_rank(a);
5765 let rb = cypher_type_rank(b);
5766 if ra != rb {
5767 return ra.cmp(&rb);
5768 }
5769 match (a, b) {
5771 (Value::Int(l), Value::Int(r)) => l.cmp(r),
5772 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
5773 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r).unwrap_or(Ordering::Equal),
5774 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)).unwrap_or(Ordering::Equal),
5775 (Value::String(l), Value::String(r)) => l.cmp(r),
5776 (Value::Bool(l), Value::Bool(r)) => l.cmp(r),
5777 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r).unwrap_or(Ordering::Equal),
5778 _ => Ordering::Equal,
5779 }
5780}
5781
5782fn scalar_binary_to_value(bytes: &[u8]) -> Value {
5784 uni_common::cypher_value_codec::decode(bytes).unwrap_or(Value::Null)
5785}
5786
5787use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, AggregateUDFImpl};
5788
5789#[derive(Debug, Clone)]
5791struct CypherMinMaxUdaf {
5792 name: String,
5793 signature: Signature,
5794 is_max: bool,
5795}
5796
5797impl CypherMinMaxUdaf {
5798 fn new(is_max: bool) -> Self {
5799 let name = if is_max { "_cypher_max" } else { "_cypher_min" };
5800 Self {
5801 name: name.to_string(),
5802 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
5803 is_max,
5804 }
5805 }
5806}
5807
5808impl PartialEq for CypherMinMaxUdaf {
5809 fn eq(&self, other: &Self) -> bool {
5810 self.name == other.name
5811 }
5812}
5813
5814impl Eq for CypherMinMaxUdaf {}
5815
5816impl Hash for CypherMinMaxUdaf {
5817 fn hash<H: Hasher>(&self, state: &mut H) {
5818 self.name.hash(state);
5819 }
5820}
5821
5822impl AggregateUDFImpl for CypherMinMaxUdaf {
5823 fn as_any(&self) -> &dyn Any {
5824 self
5825 }
5826 fn name(&self) -> &str {
5827 &self.name
5828 }
5829 fn signature(&self) -> &Signature {
5830 &self.signature
5831 }
5832 fn return_type(&self, args: &[DataType]) -> DFResult<DataType> {
5833 Ok(args.first().cloned().unwrap_or(DataType::LargeBinary))
5835 }
5836 fn accumulator(
5837 &self,
5838 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
5839 ) -> DFResult<Box<dyn DfAccumulator>> {
5840 Ok(Box::new(CypherMinMaxAccumulator {
5841 current: None,
5842 is_max: self.is_max,
5843 return_type: acc_args.return_field.data_type().clone(),
5844 }))
5845 }
5846 fn state_fields(
5847 &self,
5848 args: datafusion::logical_expr::function::StateFieldsArgs,
5849 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
5850 Ok(vec![Arc::new(arrow::datatypes::Field::new(
5851 args.name,
5852 DataType::LargeBinary,
5853 true,
5854 ))])
5855 }
5856}
5857
5858#[derive(Debug)]
5859struct CypherMinMaxAccumulator {
5860 current: Option<Value>,
5861 is_max: bool,
5862 return_type: DataType,
5863}
5864
5865impl DfAccumulator for CypherMinMaxAccumulator {
5866 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
5867 let arr = &values[0];
5868 match arr.data_type() {
5869 DataType::LargeBinary => {
5870 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
5871 for i in 0..lb.len() {
5872 if lb.is_null(i) {
5873 continue;
5874 }
5875 let val = scalar_binary_to_value(lb.value(i));
5876 if val.is_null() {
5877 continue;
5878 }
5879 self.current = Some(match self.current.take() {
5880 None => val,
5881 Some(cur) => {
5882 let ord = cypher_cross_type_cmp(&val, &cur);
5883 if (self.is_max && ord == std::cmp::Ordering::Greater)
5884 || (!self.is_max && ord == std::cmp::Ordering::Less)
5885 {
5886 val
5887 } else {
5888 cur
5889 }
5890 }
5891 });
5892 }
5893 }
5894 _ => {
5895 for i in 0..arr.len() {
5897 if arr.is_null(i) {
5898 continue;
5899 }
5900 let sv = ScalarValue::try_from_array(arr, i).map_err(|e| {
5901 datafusion::error::DataFusionError::Execution(e.to_string())
5902 })?;
5903 let val = scalar_to_value(&sv)?;
5904 if val.is_null() {
5905 continue;
5906 }
5907 self.current = Some(match self.current.take() {
5908 None => val,
5909 Some(cur) => {
5910 let ord = cypher_cross_type_cmp(&val, &cur);
5911 if (self.is_max && ord == std::cmp::Ordering::Greater)
5912 || (!self.is_max && ord == std::cmp::Ordering::Less)
5913 {
5914 val
5915 } else {
5916 cur
5917 }
5918 }
5919 });
5920 }
5921 }
5922 }
5923 Ok(())
5924 }
5925 fn evaluate(&mut self) -> DFResult<ScalarValue> {
5926 match &self.current {
5927 None => {
5928 ScalarValue::try_from(&self.return_type)
5930 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
5931 }
5932 Some(val) => {
5933 if matches!(self.return_type, DataType::LargeBinary) {
5935 let bytes = uni_common::cypher_value_codec::encode(val);
5936 return Ok(ScalarValue::LargeBinary(Some(bytes)));
5937 }
5938 match val {
5940 Value::Int(i) => match &self.return_type {
5941 DataType::Int64 => Ok(ScalarValue::Int64(Some(*i))),
5942 DataType::UInt64 => Ok(ScalarValue::UInt64(Some(*i as u64))),
5943 _ => {
5944 let bytes = uni_common::cypher_value_codec::encode(val);
5945 Ok(ScalarValue::LargeBinary(Some(bytes)))
5946 }
5947 },
5948 Value::Float(f) => match &self.return_type {
5949 DataType::Float64 => Ok(ScalarValue::Float64(Some(*f))),
5950 _ => {
5951 let bytes = uni_common::cypher_value_codec::encode(val);
5952 Ok(ScalarValue::LargeBinary(Some(bytes)))
5953 }
5954 },
5955 Value::String(s) => match &self.return_type {
5956 DataType::Utf8 => Ok(ScalarValue::Utf8(Some(s.clone()))),
5957 DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
5958 _ => {
5959 let bytes = uni_common::cypher_value_codec::encode(val);
5960 Ok(ScalarValue::LargeBinary(Some(bytes)))
5961 }
5962 },
5963 Value::Bool(b) => match &self.return_type {
5964 DataType::Boolean => Ok(ScalarValue::Boolean(Some(*b))),
5965 _ => {
5966 let bytes = uni_common::cypher_value_codec::encode(val);
5967 Ok(ScalarValue::LargeBinary(Some(bytes)))
5968 }
5969 },
5970 _ => {
5971 let bytes = uni_common::cypher_value_codec::encode(val);
5973 Ok(ScalarValue::LargeBinary(Some(bytes)))
5974 }
5975 }
5976 }
5977 }
5978 }
5979 fn size(&self) -> usize {
5980 std::mem::size_of_val(self) + self.current.as_ref().map_or(0, |_| 64)
5981 }
5982 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
5983 Ok(vec![self.evaluate()?])
5984 }
5985 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
5986 self.update_batch(states)
5987 }
5988}
5989
5990pub(crate) fn create_cypher_min_udaf() -> AggregateUDF {
5991 AggregateUDF::from(CypherMinMaxUdaf::new(false))
5992}
5993
5994pub(crate) fn create_cypher_max_udaf() -> AggregateUDF {
5995 AggregateUDF::from(CypherMinMaxUdaf::new(true))
5996}
5997
5998#[derive(Debug, Clone)]
6004struct CypherSumUdaf {
6005 signature: Signature,
6006}
6007
6008impl CypherSumUdaf {
6009 fn new() -> Self {
6010 Self {
6011 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6012 }
6013 }
6014}
6015
6016impl PartialEq for CypherSumUdaf {
6017 fn eq(&self, other: &Self) -> bool {
6018 self.signature == other.signature
6019 }
6020}
6021
6022impl Eq for CypherSumUdaf {}
6023
6024impl Hash for CypherSumUdaf {
6025 fn hash<H: Hasher>(&self, state: &mut H) {
6026 self.name().hash(state);
6027 }
6028}
6029
6030impl AggregateUDFImpl for CypherSumUdaf {
6031 fn as_any(&self) -> &dyn Any {
6032 self
6033 }
6034 fn name(&self) -> &str {
6035 "_cypher_sum"
6036 }
6037 fn signature(&self) -> &Signature {
6038 &self.signature
6039 }
6040 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6041 Ok(DataType::LargeBinary)
6044 }
6045 fn accumulator(
6046 &self,
6047 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6048 ) -> DFResult<Box<dyn DfAccumulator>> {
6049 Ok(Box::new(CypherSumAccumulator {
6050 sum: 0.0,
6051 all_ints: true,
6052 int_sum: 0i64,
6053 has_value: false,
6054 }))
6055 }
6056 fn state_fields(
6057 &self,
6058 args: datafusion::logical_expr::function::StateFieldsArgs,
6059 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6060 Ok(vec![
6061 Arc::new(arrow::datatypes::Field::new(
6062 format!("{}_sum", args.name),
6063 DataType::Float64,
6064 true,
6065 )),
6066 Arc::new(arrow::datatypes::Field::new(
6067 format!("{}_int_sum", args.name),
6068 DataType::Int64,
6069 true,
6070 )),
6071 Arc::new(arrow::datatypes::Field::new(
6072 format!("{}_all_ints", args.name),
6073 DataType::Boolean,
6074 true,
6075 )),
6076 Arc::new(arrow::datatypes::Field::new(
6077 format!("{}_has_value", args.name),
6078 DataType::Boolean,
6079 true,
6080 )),
6081 ])
6082 }
6083}
6084
6085#[derive(Debug)]
6086struct CypherSumAccumulator {
6087 sum: f64,
6088 all_ints: bool,
6089 int_sum: i64,
6090 has_value: bool,
6091}
6092
6093impl DfAccumulator for CypherSumAccumulator {
6094 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6095 let arr = &values[0];
6096 for i in 0..arr.len() {
6097 if arr.is_null(i) {
6098 continue;
6099 }
6100 match arr.data_type() {
6101 DataType::LargeBinary => {
6102 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6103 let bytes = lb.value(i);
6104 use uni_common::cypher_value_codec::{
6105 TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag,
6106 };
6107 match peek_tag(bytes) {
6108 Some(TAG_INT) => {
6109 if let Some(v) = decode_int(bytes) {
6110 self.sum += v as f64;
6111 self.int_sum = self.int_sum.wrapping_add(v);
6112 self.has_value = true;
6113 }
6114 }
6115 Some(TAG_FLOAT) => {
6116 if let Some(v) = decode_float(bytes) {
6117 self.sum += v;
6118 self.all_ints = false;
6119 self.has_value = true;
6120 }
6121 }
6122 _ => {} }
6124 }
6125 DataType::Int64 => {
6126 let a = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6127 let v = a.value(i);
6128 self.sum += v as f64;
6129 self.int_sum = self.int_sum.wrapping_add(v);
6130 self.has_value = true;
6131 }
6132 DataType::Float64 => {
6133 let a = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6134 self.sum += a.value(i);
6135 self.all_ints = false;
6136 self.has_value = true;
6137 }
6138 _ => {}
6139 }
6140 }
6141 Ok(())
6142 }
6143 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6144 if !self.has_value {
6145 return Ok(ScalarValue::LargeBinary(None));
6146 }
6147 let val = if self.all_ints {
6148 Value::Int(self.int_sum)
6149 } else {
6150 Value::Float(self.sum)
6151 };
6152 let bytes = uni_common::cypher_value_codec::encode(&val);
6153 Ok(ScalarValue::LargeBinary(Some(bytes)))
6154 }
6155 fn size(&self) -> usize {
6156 std::mem::size_of_val(self)
6157 }
6158 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6159 Ok(vec![
6160 ScalarValue::Float64(Some(self.sum)),
6161 ScalarValue::Int64(Some(self.int_sum)),
6162 ScalarValue::Boolean(Some(self.all_ints)),
6163 ScalarValue::Boolean(Some(self.has_value)),
6164 ])
6165 }
6166 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6167 let sum_arr = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
6168 let int_sum_arr = states[1].as_any().downcast_ref::<Int64Array>().unwrap();
6169 let all_ints_arr = states[2].as_any().downcast_ref::<BooleanArray>().unwrap();
6170 let has_value_arr = states[3].as_any().downcast_ref::<BooleanArray>().unwrap();
6171 for i in 0..sum_arr.len() {
6172 if !has_value_arr.is_null(i) && has_value_arr.value(i) {
6173 self.sum += sum_arr.value(i);
6174 self.int_sum = self.int_sum.wrapping_add(int_sum_arr.value(i));
6175 if !all_ints_arr.value(i) {
6176 self.all_ints = false;
6177 }
6178 self.has_value = true;
6179 }
6180 }
6181 Ok(())
6182 }
6183}
6184
6185pub(crate) fn create_cypher_sum_udaf() -> AggregateUDF {
6186 AggregateUDF::from(CypherSumUdaf::new())
6187}
6188
6189#[derive(Debug, Clone)]
6196struct CypherCollectUdaf {
6197 signature: Signature,
6198}
6199
6200impl CypherCollectUdaf {
6201 fn new() -> Self {
6202 Self {
6203 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6204 }
6205 }
6206}
6207
6208impl PartialEq for CypherCollectUdaf {
6209 fn eq(&self, other: &Self) -> bool {
6210 self.signature == other.signature
6211 }
6212}
6213
6214impl Eq for CypherCollectUdaf {}
6215
6216impl Hash for CypherCollectUdaf {
6217 fn hash<H: Hasher>(&self, state: &mut H) {
6218 self.name().hash(state);
6219 }
6220}
6221
6222impl AggregateUDFImpl for CypherCollectUdaf {
6223 fn as_any(&self) -> &dyn Any {
6224 self
6225 }
6226 fn name(&self) -> &str {
6227 "_cypher_collect"
6228 }
6229 fn signature(&self) -> &Signature {
6230 &self.signature
6231 }
6232 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6233 Ok(DataType::LargeBinary)
6234 }
6235 fn accumulator(
6236 &self,
6237 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6238 ) -> DFResult<Box<dyn DfAccumulator>> {
6239 Ok(Box::new(CypherCollectAccumulator {
6240 values: Vec::new(),
6241 distinct: acc_args.is_distinct,
6242 }))
6243 }
6244 fn state_fields(
6245 &self,
6246 args: datafusion::logical_expr::function::StateFieldsArgs,
6247 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6248 Ok(vec![Arc::new(arrow::datatypes::Field::new(
6249 args.name,
6250 DataType::LargeBinary,
6251 true,
6252 ))])
6253 }
6254}
6255
6256#[derive(Debug)]
6257struct CypherCollectAccumulator {
6258 values: Vec<Value>,
6259 distinct: bool,
6260}
6261
6262impl DfAccumulator for CypherCollectAccumulator {
6263 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6264 let arr = &values[0];
6265 for i in 0..arr.len() {
6266 if arr.is_null(i) {
6267 continue;
6268 }
6269 if let Some(struct_arr) = arr.as_any().downcast_ref::<arrow::array::StructArray>()
6273 && struct_arr.num_columns() > 0
6274 && struct_arr.column(0).is_null(i)
6275 {
6276 continue;
6277 }
6278 let sv = ScalarValue::try_from_array(arr, i)
6279 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
6280 let val = scalar_to_value(&sv)?;
6281 if val.is_null() {
6282 continue;
6283 }
6284 if self.distinct {
6285 let repr = val.to_string();
6287 if self.values.iter().any(|v| v.to_string() == repr) {
6288 continue;
6289 }
6290 }
6291 self.values.push(val);
6292 }
6293 Ok(())
6294 }
6295 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6296 let val = Value::List(self.values.clone());
6298 let bytes = uni_common::cypher_value_codec::encode(&val);
6299 Ok(ScalarValue::LargeBinary(Some(bytes)))
6300 }
6301 fn size(&self) -> usize {
6302 std::mem::size_of_val(self) + self.values.len() * 64
6303 }
6304 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6305 Ok(vec![self.evaluate()?])
6306 }
6307 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6308 let arr = &states[0];
6310 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
6311 for i in 0..lb.len() {
6312 if lb.is_null(i) {
6313 continue;
6314 }
6315 let val = scalar_binary_to_value(lb.value(i));
6316 if let Value::List(items) = val {
6317 for item in items {
6318 if !item.is_null() {
6319 if self.distinct {
6320 let repr = item.to_string();
6321 if self.values.iter().any(|v| v.to_string() == repr) {
6322 continue;
6323 }
6324 }
6325 self.values.push(item);
6326 }
6327 }
6328 }
6329 }
6330 }
6331 Ok(())
6332 }
6333}
6334
6335pub(crate) fn create_cypher_collect_udaf() -> AggregateUDF {
6336 AggregateUDF::from(CypherCollectUdaf::new())
6337}
6338
6339pub(crate) fn create_cypher_collect_expr(
6341 arg: datafusion::logical_expr::Expr,
6342 distinct: bool,
6343) -> datafusion::logical_expr::Expr {
6344 let udaf = Arc::new(create_cypher_collect_udaf());
6347 if distinct {
6348 datafusion::logical_expr::Expr::AggregateFunction(
6350 datafusion::logical_expr::expr::AggregateFunction::new_udf(
6351 udaf,
6352 vec![arg],
6353 true, None,
6355 vec![],
6356 None,
6357 ),
6358 )
6359 } else {
6360 udaf.call(vec![arg])
6361 }
6362}
6363
6364#[derive(Debug, Clone)]
6370struct CypherPercentileDiscUdaf {
6371 signature: Signature,
6372}
6373
6374impl CypherPercentileDiscUdaf {
6375 fn new() -> Self {
6376 Self {
6377 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
6378 }
6379 }
6380}
6381
6382impl PartialEq for CypherPercentileDiscUdaf {
6383 fn eq(&self, other: &Self) -> bool {
6384 self.signature == other.signature
6385 }
6386}
6387
6388impl Eq for CypherPercentileDiscUdaf {}
6389
6390impl Hash for CypherPercentileDiscUdaf {
6391 fn hash<H: Hasher>(&self, state: &mut H) {
6392 self.name().hash(state);
6393 }
6394}
6395
6396impl AggregateUDFImpl for CypherPercentileDiscUdaf {
6397 fn as_any(&self) -> &dyn Any {
6398 self
6399 }
6400 fn name(&self) -> &str {
6401 "percentiledisc"
6402 }
6403 fn signature(&self) -> &Signature {
6404 &self.signature
6405 }
6406 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6407 Ok(DataType::Float64)
6408 }
6409 fn accumulator(
6410 &self,
6411 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6412 ) -> DFResult<Box<dyn DfAccumulator>> {
6413 Ok(Box::new(CypherPercentileDiscAccumulator {
6414 values: Vec::new(),
6415 percentile: None,
6416 }))
6417 }
6418 fn state_fields(
6419 &self,
6420 args: datafusion::logical_expr::function::StateFieldsArgs,
6421 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6422 Ok(vec![
6423 Arc::new(arrow::datatypes::Field::new(
6424 format!("{}_values", args.name),
6425 DataType::List(Arc::new(arrow::datatypes::Field::new(
6426 "item",
6427 DataType::Float64,
6428 true,
6429 ))),
6430 true,
6431 )),
6432 Arc::new(arrow::datatypes::Field::new(
6433 format!("{}_percentile", args.name),
6434 DataType::Float64,
6435 true,
6436 )),
6437 ])
6438 }
6439}
6440
6441#[derive(Debug)]
6442struct CypherPercentileDiscAccumulator {
6443 values: Vec<f64>,
6444 percentile: Option<f64>,
6445}
6446
6447impl CypherPercentileDiscAccumulator {
6448 fn extract_f64(arr: &ArrayRef, i: usize) -> Option<f64> {
6449 if arr.is_null(i) {
6450 return None;
6451 }
6452 match arr.data_type() {
6453 DataType::LargeBinary => {
6454 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
6455 cv_bytes_as_f64(lb.value(i))
6456 }
6457 DataType::Int64 => {
6458 let a = arr.as_any().downcast_ref::<Int64Array>()?;
6459 Some(a.value(i) as f64)
6460 }
6461 DataType::Float64 => {
6462 let a = arr.as_any().downcast_ref::<Float64Array>()?;
6463 Some(a.value(i))
6464 }
6465 DataType::Int32 => {
6466 let a = arr.as_any().downcast_ref::<Int32Array>()?;
6467 Some(a.value(i) as f64)
6468 }
6469 DataType::Float32 => {
6470 let a = arr.as_any().downcast_ref::<Float32Array>()?;
6471 Some(a.value(i) as f64)
6472 }
6473 _ => None,
6474 }
6475 }
6476
6477 fn extract_percentile(arr: &ArrayRef, i: usize) -> Option<f64> {
6478 if arr.is_null(i) {
6479 return None;
6480 }
6481 match arr.data_type() {
6482 DataType::Float64 => {
6483 let a = arr.as_any().downcast_ref::<Float64Array>()?;
6484 Some(a.value(i))
6485 }
6486 DataType::Int64 => {
6487 let a = arr.as_any().downcast_ref::<Int64Array>()?;
6488 Some(a.value(i) as f64)
6489 }
6490 DataType::LargeBinary => {
6491 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
6492 cv_bytes_as_f64(lb.value(i))
6493 }
6494 _ => None,
6495 }
6496 }
6497}
6498
6499impl DfAccumulator for CypherPercentileDiscAccumulator {
6500 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6501 let expr_arr = &values[0];
6502 let pct_arr = &values[1];
6503 for i in 0..expr_arr.len() {
6504 if self.percentile.is_none()
6506 && let Some(p) = Self::extract_percentile(pct_arr, i)
6507 {
6508 if !(0.0..=1.0).contains(&p) {
6509 return Err(datafusion::error::DataFusionError::Execution(
6510 "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
6511 ));
6512 }
6513 self.percentile = Some(p);
6514 }
6515 if let Some(f) = Self::extract_f64(expr_arr, i) {
6516 self.values.push(f);
6517 }
6518 }
6519 Ok(())
6520 }
6521 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6522 let pct = match self.percentile {
6523 Some(p) if !(0.0..=1.0).contains(&p) => {
6524 return Err(datafusion::error::DataFusionError::Execution(
6525 "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
6526 ));
6527 }
6528 Some(p) => p,
6529 None => 0.0,
6530 };
6531 if self.values.is_empty() {
6532 return Ok(ScalarValue::Float64(None));
6533 }
6534 self.values
6535 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
6536 let n = self.values.len();
6537 let idx = (pct * (n as f64 - 1.0)).round() as usize;
6538 let idx = idx.min(n - 1);
6539 let result = self.values[idx];
6540 Ok(ScalarValue::Float64(Some(result)))
6541 }
6542 fn size(&self) -> usize {
6543 std::mem::size_of_val(self) + self.values.capacity() * 8
6544 }
6545 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6546 let list_values: Vec<ScalarValue> = self
6548 .values
6549 .iter()
6550 .map(|f| ScalarValue::Float64(Some(*f)))
6551 .collect();
6552 let list_scalar = ScalarValue::List(ScalarValue::new_list(
6553 &list_values,
6554 &DataType::Float64,
6555 true,
6556 ));
6557 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
6558 }
6559 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6560 let list_arr = &states[0];
6562 let pct_arr = &states[1];
6563 if self.percentile.is_none()
6565 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
6566 {
6567 for i in 0..f64_arr.len() {
6568 if !f64_arr.is_null(i) {
6569 self.percentile = Some(f64_arr.value(i));
6570 break;
6571 }
6572 }
6573 }
6574 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
6576 for i in 0..list_array.len() {
6577 if list_array.is_null(i) {
6578 continue;
6579 }
6580 let inner = list_array.value(i);
6581 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
6582 for j in 0..f64_arr.len() {
6583 if !f64_arr.is_null(j) {
6584 self.values.push(f64_arr.value(j));
6585 }
6586 }
6587 }
6588 }
6589 }
6590 Ok(())
6591 }
6592}
6593
6594#[derive(Debug, Clone)]
6596struct CypherPercentileContUdaf {
6597 signature: Signature,
6598}
6599
6600impl CypherPercentileContUdaf {
6601 fn new() -> Self {
6602 Self {
6603 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
6604 }
6605 }
6606}
6607
6608impl PartialEq for CypherPercentileContUdaf {
6609 fn eq(&self, other: &Self) -> bool {
6610 self.signature == other.signature
6611 }
6612}
6613
6614impl Eq for CypherPercentileContUdaf {}
6615
6616impl Hash for CypherPercentileContUdaf {
6617 fn hash<H: Hasher>(&self, state: &mut H) {
6618 self.name().hash(state);
6619 }
6620}
6621
6622impl AggregateUDFImpl for CypherPercentileContUdaf {
6623 fn as_any(&self) -> &dyn Any {
6624 self
6625 }
6626 fn name(&self) -> &str {
6627 "percentilecont"
6628 }
6629 fn signature(&self) -> &Signature {
6630 &self.signature
6631 }
6632 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6633 Ok(DataType::Float64)
6634 }
6635 fn accumulator(
6636 &self,
6637 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6638 ) -> DFResult<Box<dyn DfAccumulator>> {
6639 Ok(Box::new(CypherPercentileContAccumulator {
6640 values: Vec::new(),
6641 percentile: None,
6642 }))
6643 }
6644 fn state_fields(
6645 &self,
6646 args: datafusion::logical_expr::function::StateFieldsArgs,
6647 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6648 Ok(vec![
6649 Arc::new(arrow::datatypes::Field::new(
6650 format!("{}_values", args.name),
6651 DataType::List(Arc::new(arrow::datatypes::Field::new(
6652 "item",
6653 DataType::Float64,
6654 true,
6655 ))),
6656 true,
6657 )),
6658 Arc::new(arrow::datatypes::Field::new(
6659 format!("{}_percentile", args.name),
6660 DataType::Float64,
6661 true,
6662 )),
6663 ])
6664 }
6665}
6666
6667#[derive(Debug)]
6668struct CypherPercentileContAccumulator {
6669 values: Vec<f64>,
6670 percentile: Option<f64>,
6671}
6672
6673impl DfAccumulator for CypherPercentileContAccumulator {
6674 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6675 let expr_arr = &values[0];
6676 let pct_arr = &values[1];
6677 for i in 0..expr_arr.len() {
6678 if self.percentile.is_none()
6679 && let Some(p) = CypherPercentileDiscAccumulator::extract_percentile(pct_arr, i)
6680 {
6681 if !(0.0..=1.0).contains(&p) {
6682 return Err(datafusion::error::DataFusionError::Execution(
6683 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
6684 ));
6685 }
6686 self.percentile = Some(p);
6687 }
6688 if let Some(f) = CypherPercentileDiscAccumulator::extract_f64(expr_arr, i) {
6689 self.values.push(f);
6690 }
6691 }
6692 Ok(())
6693 }
6694 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6695 let pct = match self.percentile {
6696 Some(p) if !(0.0..=1.0).contains(&p) => {
6697 return Err(datafusion::error::DataFusionError::Execution(
6698 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
6699 ));
6700 }
6701 Some(p) => p,
6702 None => 0.0,
6703 };
6704 if self.values.is_empty() {
6705 return Ok(ScalarValue::Float64(None));
6706 }
6707 self.values
6708 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
6709 let n = self.values.len();
6710 if n == 1 {
6711 return Ok(ScalarValue::Float64(Some(self.values[0])));
6712 }
6713 let pos = pct * (n as f64 - 1.0);
6714 let lower = pos.floor() as usize;
6715 let upper = pos.ceil() as usize;
6716 let lower = lower.min(n - 1);
6717 let upper = upper.min(n - 1);
6718 if lower == upper {
6719 Ok(ScalarValue::Float64(Some(self.values[lower])))
6720 } else {
6721 let frac = pos - lower as f64;
6722 let result = self.values[lower] + frac * (self.values[upper] - self.values[lower]);
6723 Ok(ScalarValue::Float64(Some(result)))
6724 }
6725 }
6726 fn size(&self) -> usize {
6727 std::mem::size_of_val(self) + self.values.capacity() * 8
6728 }
6729 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6730 let list_values: Vec<ScalarValue> = self
6731 .values
6732 .iter()
6733 .map(|f| ScalarValue::Float64(Some(*f)))
6734 .collect();
6735 let list_scalar = ScalarValue::List(ScalarValue::new_list(
6736 &list_values,
6737 &DataType::Float64,
6738 true,
6739 ));
6740 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
6741 }
6742 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6743 let list_arr = &states[0];
6744 let pct_arr = &states[1];
6745 if self.percentile.is_none()
6746 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
6747 {
6748 for i in 0..f64_arr.len() {
6749 if !f64_arr.is_null(i) {
6750 self.percentile = Some(f64_arr.value(i));
6751 break;
6752 }
6753 }
6754 }
6755 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
6756 for i in 0..list_array.len() {
6757 if list_array.is_null(i) {
6758 continue;
6759 }
6760 let inner = list_array.value(i);
6761 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
6762 for j in 0..f64_arr.len() {
6763 if !f64_arr.is_null(j) {
6764 self.values.push(f64_arr.value(j));
6765 }
6766 }
6767 }
6768 }
6769 }
6770 Ok(())
6771 }
6772}
6773
6774pub(crate) fn create_cypher_percentile_disc_udaf() -> AggregateUDF {
6775 AggregateUDF::from(CypherPercentileDiscUdaf::new())
6776}
6777
6778pub(crate) fn create_cypher_percentile_cont_udaf() -> AggregateUDF {
6779 AggregateUDF::from(CypherPercentileContUdaf::new())
6780}
6781
6782fn invoke_similarity_udf(
6792 func_name: &str,
6793 min_args: usize,
6794 args: ScalarFunctionArgs,
6795) -> DFResult<ColumnarValue> {
6796 let output_type = DataType::Float64;
6797 invoke_cypher_udf(args, &output_type, |val_args| {
6798 if val_args.len() < min_args {
6799 return Err(datafusion::error::DataFusionError::Execution(format!(
6800 "{} requires at least {} arguments",
6801 func_name, min_args
6802 )));
6803 }
6804 crate::query::similar_to::eval_similar_to_pure(&val_args[0], &val_args[1])
6805 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
6806 })
6807}
6808
6809pub fn create_similar_to_udf() -> ScalarUDF {
6811 ScalarUDF::new_from_impl(SimilarToUdf::new())
6812}
6813
6814#[derive(Debug)]
6815struct SimilarToUdf {
6816 signature: Signature,
6817}
6818
6819impl SimilarToUdf {
6820 fn new() -> Self {
6821 Self {
6822 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
6823 }
6824 }
6825}
6826
6827impl_udf_eq_hash!(SimilarToUdf);
6828
6829impl ScalarUDFImpl for SimilarToUdf {
6830 fn as_any(&self) -> &dyn Any {
6831 self
6832 }
6833
6834 fn name(&self) -> &str {
6835 "similar_to"
6836 }
6837
6838 fn signature(&self) -> &Signature {
6839 &self.signature
6840 }
6841
6842 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6843 Ok(DataType::Float64)
6844 }
6845
6846 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6847 invoke_similarity_udf("similar_to", 2, args)
6848 }
6849}
6850
6851pub fn create_vector_similarity_udf() -> ScalarUDF {
6853 ScalarUDF::new_from_impl(VectorSimilarityUdf::new())
6854}
6855
6856#[derive(Debug)]
6857struct VectorSimilarityUdf {
6858 signature: Signature,
6859}
6860
6861impl VectorSimilarityUdf {
6862 fn new() -> Self {
6863 Self {
6864 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
6865 }
6866 }
6867}
6868
6869impl_udf_eq_hash!(VectorSimilarityUdf);
6870
6871impl ScalarUDFImpl for VectorSimilarityUdf {
6872 fn as_any(&self) -> &dyn Any {
6873 self
6874 }
6875
6876 fn name(&self) -> &str {
6877 "vector_similarity"
6878 }
6879
6880 fn signature(&self) -> &Signature {
6881 &self.signature
6882 }
6883
6884 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6885 Ok(DataType::Float64)
6886 }
6887
6888 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6889 invoke_similarity_udf("vector_similarity", 2, args)
6890 }
6891}
6892
// Unit tests for the Cypher UDFs/UDAFs and sort-key helpers in this module.
// Most UDF tests call `invoke_with_args` with scalar arguments and decode the
// CypherValue (`LargeBinary`) result back into JSON for comparison.
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::execution::FunctionRegistry;

    #[test]
    fn test_register_udfs() {
        let ctx = SessionContext::new();
        register_cypher_udfs(&ctx).unwrap();

        assert!(ctx.udf("id").is_ok());
        assert!(ctx.udf("type").is_ok());
        assert!(ctx.udf("keys").is_ok());
        assert!(ctx.udf("range").is_ok());
        assert!(
            ctx.udf("_make_cypher_list").is_ok(),
            "_make_cypher_list UDF should be registered"
        );
        assert!(
            ctx.udf("_cv_to_bool").is_ok(),
            "_cv_to_bool UDF should be registered"
        );
    }

    #[test]
    fn test_id_udf_signature() {
        let udf = create_id_udf();
        assert_eq!(udf.name(), "id");
    }

    #[test]
    fn test_has_null_udf() {
        use datafusion::arrow::datatypes::{DataType, Field};
        use datafusion::config::ConfigOptions;
        use datafusion::scalar::ScalarValue;
        use std::sync::Arc;

        let udf = create_has_null_udf();

        let values = vec![
            ScalarValue::Int64(Some(1)),
            ScalarValue::Int64(Some(2)),
            ScalarValue::Int64(None),
        ];

        let list_scalar = ScalarValue::List(ScalarValue::new_list(&values, &DataType::Int64, true));

        let list_field = Arc::new(Field::new(
            "item",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            true,
        ));

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Scalar(list_scalar)],
            arg_fields: vec![list_field],
            number_rows: 1,
            return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = udf.invoke_with_args(args).unwrap();

        if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
            assert!(b, "has_null should return true for list with null");
        } else {
            panic!("Unexpected result: {:?}", result);
        }
    }

    /// Encodes a JSON value as CypherValue bytes (the `LargeBinary` wire format).
    fn json_to_cv_bytes(val: &serde_json::Value) -> Vec<u8> {
        let uni_val: uni_common::Value = val.clone().into();
        uni_common::cypher_value_codec::encode(&uni_val)
    }

    /// Builds scalar UDF arguments whose declared return type is `LargeBinary`.
    fn make_multi_scalar_args(scalars: Vec<ScalarValue>) -> ScalarFunctionArgs {
        make_multi_scalar_args_with_return(scalars, DataType::LargeBinary)
    }

    /// Builds scalar UDF arguments (one row) with an explicit return type.
    fn make_multi_scalar_args_with_return(
        scalars: Vec<ScalarValue>,
        return_type: DataType,
    ) -> ScalarFunctionArgs {
        use datafusion::arrow::datatypes::Field;
        use datafusion::config::ConfigOptions;

        let arg_fields: Vec<_> = scalars
            .iter()
            .enumerate()
            .map(|(i, s)| Arc::new(Field::new(format!("arg{i}"), s.data_type(), true)))
            .collect();
        let args: Vec<_> = scalars.into_iter().map(ColumnarValue::Scalar).collect();
        ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Arc::new(Field::new("result", return_type, true)),
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    /// Decodes a non-null CypherValue scalar result into JSON, panicking otherwise.
    fn decode_cv_scalar(cv: &ColumnarValue) -> serde_json::Value {
        match cv {
            ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
                let val = uni_common::cypher_value_codec::decode(bytes)
                    .expect("failed to decode CypherValue output");
                val.into()
            }
            other => panic!("expected LargeBinary scalar, got {other:?}"),
        }
    }

    /// Asserts that a CypherValue-returning UDF produced Cypher `null`, either as
    /// a SQL NULL scalar or as an encoded null value.
    fn assert_cv_null(result: ColumnarValue) {
        match result {
            ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
                let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
                let json: serde_json::Value = uni_val.into();
                assert!(json.is_null(), "expected null, got {json}");
            }
            ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
            other => panic!("expected null result, got {other:?}"),
        }
    }

    #[test]
    fn test_make_cypher_list_scalars() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::Int64(Some(1)),
            ScalarValue::Float64(Some(3.21)),
            ScalarValue::Utf8(Some("hello".to_string())),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Null,
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 5);
        assert_eq!(arr[0], serde_json::json!(1));
        assert_eq!(arr[1], serde_json::json!(3.21));
        assert_eq!(arr[2], serde_json::json!("hello"));
        assert_eq!(arr[3], serde_json::json!(true));
        assert!(arr[4].is_null());
    }

    #[test]
    fn test_make_cypher_list_empty() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert!(arr.is_empty());
    }

    #[test]
    fn test_make_cypher_list_single() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![ScalarValue::Int64(Some(42))]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0], serde_json::json!(42));
    }

    #[test]
    fn test_make_cypher_list_nested_cypher_value() {
        let udf = create_make_cypher_list_udf();
        let nested_bytes = json_to_cv_bytes(&serde_json::json!([1, 2]));
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(nested_bytes)),
            ScalarValue::Int64(Some(3)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!([1, 2]));
        assert_eq!(arr[1], serde_json::json!(3));
    }

    /// Builds `element IN list` arguments with a Boolean return type.
    fn make_cypher_in_args(
        element: &serde_json::Value,
        list: &serde_json::Value,
    ) -> ScalarFunctionArgs {
        make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(element))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(list))),
            ],
            DataType::Boolean,
        )
    }

    #[test]
    fn test_cypher_in_found() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(3), &serde_json::json!([1, 2, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
            other => panic!("expected Boolean(true), got {other:?}"),
        }
    }

    #[test]
    fn test_cypher_in_not_found() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, 2, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
            other => panic!("expected Boolean(false), got {other:?}"),
        }
    }

    #[test]
    fn test_cypher_in_null_list() {
        let udf = create_cypher_in_udf();
        let args = make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(1)))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
            ],
            DataType::Boolean,
        );
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
            other => panic!("expected Boolean(None) for null list, got {other:?}"),
        }
    }

    #[test]
    fn test_cypher_in_null_element_nonempty() {
        let udf = create_cypher_in_udf();
        let args = make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
            ],
            DataType::Boolean,
        );
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
            other => panic!("expected Boolean(None) for null IN non-empty list, got {other:?}"),
        }
    }

    #[test]
    fn test_cypher_in_null_element_empty() {
        let udf = create_cypher_in_udf();
        let args = make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([])))),
            ],
            DataType::Boolean,
        );
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
            other => panic!("expected Boolean(false) for null IN [], got {other:?}"),
        }
    }

    #[test]
    fn test_cypher_in_not_found_with_null() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, null, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
            other => panic!("expected Boolean(None) for 4 IN [1,null,3], got {other:?}"),
        }
    }

    #[test]
    fn test_cypher_in_cross_type_int_float() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!([1.0, 2.0]));
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
            other => panic!("expected Boolean(true) for 1 IN [1.0, 2.0], got {other:?}"),
        }
    }

    #[test]
    fn test_list_concat_basic() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([3, 4])))),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        assert_eq!(json, serde_json::json!([1, 2, 3, 4]));
    }

    #[test]
    fn test_list_concat_empty() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([])))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        assert_eq!(json, serde_json::json!([1]));
    }

    #[test]
    fn test_list_concat_null_left() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
        ]);
        assert_cv_null(udf.invoke_with_args(args).unwrap());
    }

    #[test]
    fn test_list_concat_null_right() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
        ]);
        assert_cv_null(udf.invoke_with_args(args).unwrap());
    }

    #[test]
    fn test_list_append_scalar() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        assert_eq!(json, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_list_prepend_scalar() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        assert_eq!(json, serde_json::json!([3, 1, 2]));
    }

    #[test]
    fn test_list_append_null_list() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
        ]);
        assert_cv_null(udf.invoke_with_args(args).unwrap());
    }

    #[test]
    fn test_list_append_null_scalar() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
            ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
        ]);
        assert_cv_null(udf.invoke_with_args(args).unwrap());
    }

    #[test]
    fn test_sort_key_cross_type_ordering() {
        use uni_common::core::id::{Eid, Vid};
        use uni_common::{Edge, Node, Path, TemporalValue, Value};

        let map_val = Value::Map([("a".to_string(), Value::String("map".to_string()))].into());
        let node_val = Value::Node(Node {
            vid: Vid::new(1),
            labels: vec!["L".to_string()],
            properties: Default::default(),
        });
        let edge_val = Value::Edge(Edge {
            eid: Eid::new(1),
            edge_type: "T".to_string(),
            src: Vid::new(1),
            dst: Vid::new(2),
            properties: Default::default(),
        });
        let list_val = Value::List(vec![Value::Int(1)]);
        let path_val = Value::Path(Path {
            nodes: vec![Node {
                vid: Vid::new(1),
                labels: vec!["L".to_string()],
                properties: Default::default(),
            }],
            edges: vec![],
        });
        let string_val = Value::String("hello".to_string());
        let bool_val = Value::Bool(false);
        let temporal_val = Value::Temporal(TemporalValue::Date {
            days_since_epoch: 1000,
        });
        let number_val = Value::Int(42);
        let nan_val = Value::Float(f64::NAN);
        let null_val = Value::Null;

        // Expected ascending Cypher ORDER BY ranking across types.
        let values = vec![
            &map_val,
            &node_val,
            &edge_val,
            &list_val,
            &path_val,
            &string_val,
            &bool_val,
            &temporal_val,
            &number_val,
            &nan_val,
            &null_val,
        ];

        let keys: Vec<Vec<u8>> = values.iter().map(|v| encode_cypher_sort_key(v)).collect();

        for i in 0..keys.len() - 1 {
            assert!(
                keys[i] < keys[i + 1],
                "Expected sort_key({:?}) < sort_key({:?}), but {:?} >= {:?}",
                values[i],
                values[i + 1],
                keys[i],
                keys[i + 1]
            );
        }
    }

    #[test]
    fn test_sort_key_numbers() {
        let neg_inf = encode_cypher_sort_key(&Value::Float(f64::NEG_INFINITY));
        let neg_100 = encode_cypher_sort_key(&Value::Float(-100.0));
        let neg_1 = encode_cypher_sort_key(&Value::Int(-1));
        let zero_int = encode_cypher_sort_key(&Value::Int(0));
        let zero_float = encode_cypher_sort_key(&Value::Float(0.0));
        let one_int = encode_cypher_sort_key(&Value::Int(1));
        let one_float = encode_cypher_sort_key(&Value::Float(1.0));
        let hundred = encode_cypher_sort_key(&Value::Int(100));
        let pos_inf = encode_cypher_sort_key(&Value::Float(f64::INFINITY));
        let nan = encode_cypher_sort_key(&Value::Float(f64::NAN));

        assert!(neg_inf < neg_100, "-inf < -100");
        assert!(neg_100 < neg_1, "-100 < -1");
        assert!(neg_1 < zero_int, "-1 < 0");
        assert_eq!(zero_int, zero_float, "0 int == 0.0 float");
        assert!(zero_int < one_int, "0 < 1");
        assert_eq!(one_int, one_float, "1 int == 1.0 float");
        assert!(one_int < hundred, "1 < 100");
        assert!(hundred < pos_inf, "100 < +inf");
        assert!(pos_inf < nan, "+inf < NaN");
    }

    #[test]
    fn test_sort_key_booleans() {
        let f = encode_cypher_sort_key(&Value::Bool(false));
        let t = encode_cypher_sort_key(&Value::Bool(true));
        assert!(f < t, "false < true");
    }

    #[test]
    fn test_sort_key_strings() {
        let empty = encode_cypher_sort_key(&Value::String(String::new()));
        let a = encode_cypher_sort_key(&Value::String("a".to_string()));
        let ab = encode_cypher_sort_key(&Value::String("ab".to_string()));
        let b = encode_cypher_sort_key(&Value::String("b".to_string()));

        assert!(empty < a, "'' < 'a'");
        assert!(a < ab, "'a' < 'ab'");
        assert!(ab < b, "'ab' < 'b'");
    }

    #[test]
    fn test_sort_key_lists() {
        let empty = encode_cypher_sort_key(&Value::List(vec![]));
        let one = encode_cypher_sort_key(&Value::List(vec![Value::Int(1)]));
        let one_two = encode_cypher_sort_key(&Value::List(vec![Value::Int(1), Value::Int(2)]));
        let two = encode_cypher_sort_key(&Value::List(vec![Value::Int(2)]));

        assert!(empty < one, "[] < [1]");
        assert!(one < one_two, "[1] < [1,2]");
        assert!(one_two < two, "[1,2] < [2]");
    }

    #[test]
    fn test_sort_key_temporal() {
        use uni_common::TemporalValue;

        let date1 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: 100,
        }));
        let date2 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: 200,
        }));
        assert!(date1 < date2, "earlier date < later date");

        // Variant rank dominates the payload: even the largest Date sorts
        // before the smallest LocalTime.
        let date = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: i32::MAX,
        }));
        let local_time = encode_cypher_sort_key(&Value::Temporal(TemporalValue::LocalTime {
            nanos_since_midnight: 0,
        }));
        assert!(date < local_time, "Date < LocalTime (by variant rank)");
    }

    #[test]
    fn test_sort_key_nested_lists() {
        let inner_a = Value::List(vec![Value::Int(1)]);
        let inner_b = Value::List(vec![Value::Int(2)]);

        let list_a = encode_cypher_sort_key(&Value::List(vec![inner_a]));
        let list_b = encode_cypher_sort_key(&Value::List(vec![inner_b]));

        assert!(list_a < list_b, "[[1]] < [[2]]");
    }

    #[test]
    fn test_sort_key_null_handling() {
        let null_key = encode_cypher_sort_key(&Value::Null);
        assert_eq!(null_key, vec![0x0A], "Null produces [0x0A]");

        let number_key = encode_cypher_sort_key(&Value::Int(42));
        assert!(number_key < null_key, "number < null");
    }

    // Despite the name, this checks ordering only: strings containing NUL bytes
    // must still compare correctly after byte-stuffing in the sort key.
    #[test]
    fn test_byte_stuff_roundtrip() {
        let s1 = Value::String("a\x00b".to_string());
        let s2 = Value::String("a\x00c".to_string());
        let s3 = Value::String("a\x01".to_string());

        let k1 = encode_cypher_sort_key(&s1);
        let k2 = encode_cypher_sort_key(&s2);
        let k3 = encode_cypher_sort_key(&s3);

        assert!(k1 < k2, "a\\x00b < a\\x00c");
        assert!(k1 < k3, "a\\x00b < a\\x01");
    }

    #[test]
    fn test_sort_key_order_preserving_f64() {
        let vals = [f64::NEG_INFINITY, -1.0, -0.0, 0.0, 1.0, f64::INFINITY];
        let encoded: Vec<[u8; 8]> = vals
            .iter()
            .map(|f| encode_order_preserving_f64(*f))
            .collect();

        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] <= encoded[i + 1],
                "encode({}) should <= encode({}), got {:?} vs {:?}",
                vals[i],
                vals[i + 1],
                encoded[i],
                encoded[i + 1]
            );
        }
    }

    #[test]
    fn test_sort_key_string_as_temporal_time_with_offset() {
        let tv = sort_key_string_as_temporal("12:35:15+05:00")
            .expect("should parse Time with positive offset");
        match tv {
            uni_common::TemporalValue::Time {
                nanos_since_midnight,
                offset_seconds,
            } => {
                assert_eq!(offset_seconds, 5 * 3600, "offset should be +05:00 = 18000s");
                let expected_nanos = (12 * 3600 + 35 * 60 + 15) * 1_000_000_000i64;
                assert_eq!(nanos_since_midnight, expected_nanos);
            }
            other => panic!("expected TemporalValue::Time, got {other:?}"),
        }
    }

    #[test]
    fn test_sort_key_string_as_temporal_time_negative_offset() {
        let tv = sort_key_string_as_temporal("10:35:00-08:00")
            .expect("should parse Time with negative offset");
        match tv {
            uni_common::TemporalValue::Time {
                nanos_since_midnight,
                offset_seconds,
            } => {
                assert_eq!(
                    offset_seconds,
                    -8 * 3600,
                    "offset should be -08:00 = -28800s"
                );
                let expected_nanos = (10 * 3600 + 35 * 60) * 1_000_000_000i64;
                assert_eq!(nanos_since_midnight, expected_nanos);
            }
            other => panic!("expected TemporalValue::Time, got {other:?}"),
        }
    }

    // Exercises `temporal_from_value` on a plain date string (the date path used
    // when strings are interpreted as temporals).
    #[test]
    fn test_sort_key_string_as_temporal_date() {
        use super::super::expr_eval::temporal_from_value;
        let tv = temporal_from_value(&Value::String("2024-01-15".into()))
            .expect("should parse Date string");
        match tv {
            uni_common::TemporalValue::Date { days_since_epoch } => {
                assert!(days_since_epoch > 0, "2024-01-15 should be after epoch");
            }
            other => panic!("expected TemporalValue::Date, got {other:?}"),
        }
    }

    /// Feeds `values` (with a constant percentile `pct` on every row) through
    /// `acc` and returns the evaluated aggregate.
    fn run_percentile(
        acc: &mut dyn DfAccumulator,
        values: Vec<f64>,
        pct: f64,
    ) -> DFResult<ScalarValue> {
        let rows = values.len();
        let expr: ArrayRef = Arc::new(Float64Array::from(values));
        let pcts: ArrayRef = Arc::new(Float64Array::from(vec![pct; rows]));
        acc.update_batch(&[expr, pcts])?;
        acc.evaluate()
    }

    #[test]
    fn test_percentile_disc_picks_existing_value() {
        let mut acc = CypherPercentileDiscAccumulator {
            values: Vec::new(),
            percentile: None,
        };
        let result = run_percentile(&mut acc, vec![12.0, 10.0, 11.0], 0.5).unwrap();
        assert_eq!(result, ScalarValue::Float64(Some(11.0)));
    }

    #[test]
    fn test_percentile_cont_interpolates() {
        let mut acc = CypherPercentileContAccumulator {
            values: Vec::new(),
            percentile: None,
        };
        let result = run_percentile(&mut acc, vec![20.0, 10.0], 0.5).unwrap();
        assert_eq!(result, ScalarValue::Float64(Some(15.0)));
    }

    #[test]
    fn test_percentile_out_of_range_is_error() {
        let mut acc = CypherPercentileDiscAccumulator {
            values: Vec::new(),
            percentile: None,
        };
        assert!(run_percentile(&mut acc, vec![1.0], 1.5).is_err());
    }

    #[test]
    fn test_percentile_empty_input_is_null() {
        let mut acc = CypherPercentileContAccumulator {
            values: Vec::new(),
            percentile: None,
        };
        let result = run_percentile(&mut acc, Vec::new(), 0.5).unwrap();
        assert_eq!(result, ScalarValue::Float64(None));
    }
}