1use arrow::array::ArrayRef;
27use arrow::datatypes::DataType;
28use arrow_array::{
29 Array, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array,
30 LargeBinaryArray, LargeStringArray, StringArray, UInt64Array,
31};
32use chrono::Offset;
33use datafusion::error::Result as DFResult;
34use datafusion::logical_expr::{
35 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
36 Volatility,
37};
38use datafusion::prelude::SessionContext;
39use datafusion::scalar::ScalarValue;
40use std::any::Any;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43use uni_common::Value;
44use uni_cypher::ast::BinaryOp;
45use uni_store::storage::arrow_convert::values_to_array;
46
47use super::expr_eval::cypher_eq;
48
/// Implements `PartialEq`, `Eq` and `Hash` for a UDF struct.
///
/// The target type must have a `signature` field comparable with `==` and a
/// `name()` method returning a hashable, comparable value.
///
/// Equality compares both the name and the signature; hashing feeds only the
/// name. Comparing the name is required for the `Hash`/`Eq` contract: several
/// UDF types in this file share one struct (and therefore one signature) across
/// different names, and such values must not compare equal while hashing
/// differently.
macro_rules! impl_udf_eq_hash {
    ($type:ty) => {
        impl PartialEq for $type {
            fn eq(&self, other: &Self) -> bool {
                // Name first: it is the cheap discriminator and the only hashed part.
                self.name() == other.name() && self.signature == other.signature
            }
        }

        impl Eq for $type {}

        impl Hash for $type {
            fn hash<H: Hasher>(&self, state: &mut H) {
                // Equal values always share a name, so hashing the name alone is
                // consistent with `eq` above.
                self.name().hash(state);
            }
        }
    };
}
69
70pub fn register_cypher_udfs(ctx: &SessionContext) -> DFResult<()> {
80 ctx.register_udf(create_id_udf());
81 ctx.register_udf(create_created_at_udf());
82 ctx.register_udf(create_updated_at_udf());
83 ctx.register_udf(create_type_udf());
84 ctx.register_udf(create_keys_udf());
85 ctx.register_udf(create_properties_udf());
86 ctx.register_udf(create_labels_udf());
87 ctx.register_udf(create_nodes_udf());
88 ctx.register_udf(create_relationships_udf());
89 ctx.register_udf(create_range_udf());
90 ctx.register_udf(create_index_udf());
91 ctx.register_udf(create_startnode_udf());
92 ctx.register_udf(create_endnode_udf());
93
94 ctx.register_udf(create_to_integer_udf());
96 ctx.register_udf(create_to_float_udf());
97 ctx.register_udf(create_to_boolean_udf());
98
99 ctx.register_udf(create_bitwise_or_udf());
101 ctx.register_udf(create_bitwise_and_udf());
102 ctx.register_udf(create_bitwise_xor_udf());
103 ctx.register_udf(create_bitwise_not_udf());
104 ctx.register_udf(create_shift_left_udf());
105 ctx.register_udf(create_shift_right_udf());
106
107 for name in &[
109 "date",
111 "time",
112 "localtime",
113 "localdatetime",
114 "datetime",
115 "duration",
116 "btic",
117 "duration.between",
119 "duration.inmonths",
120 "duration.indays",
121 "duration.inseconds",
122 "datetime.fromepoch",
123 "datetime.fromepochmillis",
124 "date.truncate",
126 "time.truncate",
127 "datetime.truncate",
128 "localdatetime.truncate",
129 "localtime.truncate",
130 "datetime.transaction",
132 "datetime.statement",
133 "datetime.realtime",
134 "date.transaction",
135 "date.statement",
136 "date.realtime",
137 "time.transaction",
138 "time.statement",
139 "time.realtime",
140 "localtime.transaction",
141 "localtime.statement",
142 "localtime.realtime",
143 "localdatetime.transaction",
144 "localdatetime.statement",
145 "localdatetime.realtime",
146 ] {
147 ctx.register_udf(create_temporal_udf(name));
148 }
149
150 ctx.register_udf(create_duration_property_udf());
152 ctx.register_udf(create_temporal_property_udf());
153 ctx.register_udf(create_tostring_udf());
154 ctx.register_udf(create_cypher_sort_key_udf());
155 ctx.register_udf(create_has_null_udf());
156 ctx.register_udf(create_cypher_size_udf());
157
158 ctx.register_udf(create_cypher_starts_with_udf());
160 ctx.register_udf(create_cypher_ends_with_udf());
161 ctx.register_udf(create_cypher_contains_udf());
162
163 ctx.register_udf(create_cypher_list_compare_udf());
165
166 ctx.register_udf(create_cypher_xor_udf());
168
169 ctx.register_udf(create_cypher_equal_udf());
171 ctx.register_udf(create_cypher_not_equal_udf());
172 ctx.register_udf(create_cypher_gt_udf());
173 ctx.register_udf(create_cypher_gt_eq_udf());
174 ctx.register_udf(create_cypher_lt_udf());
175 ctx.register_udf(create_cypher_lt_eq_udf());
176
177 ctx.register_udf(create_cv_to_bool_udf());
179
180 ctx.register_udf(create_cypher_add_udf());
182 ctx.register_udf(create_cypher_sub_udf());
183 ctx.register_udf(create_cypher_mul_udf());
184 ctx.register_udf(create_cypher_div_udf());
185 ctx.register_udf(create_cypher_mod_udf());
186
187 ctx.register_udf(create_map_project_udf());
189
190 ctx.register_udf(create_make_cypher_list_udf());
192
193 ctx.register_udf(create_cypher_in_udf());
195
196 ctx.register_udf(create_cypher_list_concat_udf());
198 ctx.register_udf(create_cypher_list_append_udf());
199 ctx.register_udf(create_cypher_list_slice_udf());
200 ctx.register_udf(create_cypher_tail_udf());
201 ctx.register_udf(create_cypher_head_udf());
202 ctx.register_udf(create_cypher_last_udf());
203 ctx.register_udf(create_cypher_reverse_udf());
204 ctx.register_udf(create_cypher_substring_udf());
205 ctx.register_udf(create_cypher_split_udf());
206 ctx.register_udf(create_cypher_list_to_cv_udf());
207 ctx.register_udf(create_cypher_scalar_to_cv_udf());
208
209 for name in &["year", "month", "day", "hour", "minute", "second"] {
211 ctx.register_udf(create_temporal_udf(name));
212 }
213
214 ctx.register_udf(create_cypher_to_float64_udf());
216
217 ctx.register_udf(create_similar_to_udf());
219 ctx.register_udf(create_vector_similarity_udf());
220
221 ctx.register_udaf(create_cypher_min_udaf());
223 ctx.register_udaf(create_cypher_max_udaf());
224 ctx.register_udaf(create_cypher_sum_udaf());
225 ctx.register_udaf(create_cypher_collect_udaf());
226
227 ctx.register_udaf(create_cypher_percentile_disc_udaf());
229 ctx.register_udaf(create_cypher_percentile_cont_udaf());
230
231 register_btic_scalar_udfs(ctx)?;
233
234 ctx.register_udaf(create_btic_min_udaf());
236 ctx.register_udaf(create_btic_max_udaf());
237 ctx.register_udaf(create_btic_span_agg_udaf());
238 ctx.register_udaf(create_btic_count_at_udaf());
239
240 Ok(())
241}
242
243pub fn register_custom_udfs(
251 ctx: &SessionContext,
252 registry: &crate::custom_functions::CustomFunctionRegistry,
253) -> DFResult<()> {
254 for (name, func) in registry.iter() {
255 let lower = name.to_lowercase();
258 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
259 lower,
260 func.clone(),
261 )));
262 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
264 name.to_string(),
265 func.clone(),
266 )));
267 }
268 Ok(())
269}
270
271struct CustomScalarUdf {
276 name: String,
277 func: crate::custom_functions::CustomScalarFn,
278 signature: Signature,
279}
280
281impl CustomScalarUdf {
282 fn new(name: String, func: crate::custom_functions::CustomScalarFn) -> Self {
283 Self {
284 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
285 name,
286 func,
287 }
288 }
289}
290
291impl std::fmt::Debug for CustomScalarUdf {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.debug_struct("CustomScalarUdf")
294 .field("name", &self.name)
295 .finish()
296 }
297}
298
299impl_udf_eq_hash!(CustomScalarUdf);
300
301impl ScalarUDFImpl for CustomScalarUdf {
302 fn as_any(&self) -> &dyn Any {
303 self
304 }
305
306 fn name(&self) -> &str {
307 &self.name
308 }
309
310 fn signature(&self) -> &Signature {
311 &self.signature
312 }
313
314 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
315 Ok(DataType::LargeBinary)
316 }
317
318 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
319 let func = &self.func;
320 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
321 func(vals).map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
322 })
323 }
324}
325
326pub fn create_id_udf() -> ScalarUDF {
334 ScalarUDF::new_from_impl(IdUdf::new())
335}
336
337#[derive(Debug)]
338struct IdUdf {
339 signature: Signature,
340}
341
342impl IdUdf {
343 fn new() -> Self {
344 Self {
345 signature: Signature::new(
346 TypeSignature::Exact(vec![DataType::UInt64]),
347 Volatility::Immutable,
348 ),
349 }
350 }
351}
352
353impl_udf_eq_hash!(IdUdf);
354
355impl ScalarUDFImpl for IdUdf {
356 fn as_any(&self) -> &dyn Any {
357 self
358 }
359
360 fn name(&self) -> &str {
361 "id"
362 }
363
364 fn signature(&self) -> &Signature {
365 &self.signature
366 }
367
368 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
369 Ok(DataType::UInt64)
370 }
371
372 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
373 if args.args.is_empty() {
375 return Err(datafusion::error::DataFusionError::Execution(
376 "id(): requires 1 argument".to_string(),
377 ));
378 }
379 Ok(args.args[0].clone())
380 }
381}
382
383pub fn create_created_at_udf() -> ScalarUDF {
396 ScalarUDF::new_from_impl(SystemTimestampUdf::new("created_at"))
397}
398
399pub fn create_updated_at_udf() -> ScalarUDF {
406 ScalarUDF::new_from_impl(SystemTimestampUdf::new("updated_at"))
407}
408
409#[derive(Debug)]
410struct SystemTimestampUdf {
411 name: &'static str,
412 signature: Signature,
413}
414
415impl SystemTimestampUdf {
416 fn new(name: &'static str) -> Self {
417 Self {
418 name,
419 signature: Signature::new(
420 TypeSignature::Exact(vec![DataType::Timestamp(
421 arrow_schema::TimeUnit::Nanosecond,
422 Some("UTC".into()),
423 )]),
424 Volatility::Immutable,
425 ),
426 }
427 }
428}
429
430impl_udf_eq_hash!(SystemTimestampUdf);
431
432impl ScalarUDFImpl for SystemTimestampUdf {
433 fn as_any(&self) -> &dyn Any {
434 self
435 }
436
437 fn name(&self) -> &str {
438 self.name
439 }
440
441 fn signature(&self) -> &Signature {
442 &self.signature
443 }
444
445 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
446 Ok(DataType::Timestamp(
447 arrow_schema::TimeUnit::Nanosecond,
448 Some("UTC".into()),
449 ))
450 }
451
452 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
453 if args.args.is_empty() {
458 return Err(datafusion::error::DataFusionError::Execution(format!(
459 "{}(): requires 1 argument",
460 self.name
461 )));
462 }
463 Ok(args.args[0].clone())
464 }
465}
466
467pub fn create_type_udf() -> ScalarUDF {
475 ScalarUDF::new_from_impl(TypeUdf::new())
476}
477
478#[derive(Debug)]
479struct TypeUdf {
480 signature: Signature,
481}
482
483impl TypeUdf {
484 fn new() -> Self {
485 Self {
486 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
490 }
491 }
492}
493
494impl_udf_eq_hash!(TypeUdf);
495
496impl ScalarUDFImpl for TypeUdf {
497 fn as_any(&self) -> &dyn Any {
498 self
499 }
500
501 fn name(&self) -> &str {
502 "type"
503 }
504
505 fn signature(&self) -> &Signature {
506 &self.signature
507 }
508
509 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
510 Ok(DataType::Utf8)
511 }
512
513 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
514 if args.args.is_empty() {
515 return Err(datafusion::error::DataFusionError::Execution(
516 "type(): requires 1 argument".to_string(),
517 ));
518 }
519 let output_type = DataType::Utf8;
520 invoke_cypher_udf(args, &output_type, |val_args| {
521 if val_args.is_empty() {
522 return Err(datafusion::error::DataFusionError::Execution(
523 "type(): requires 1 argument".to_string(),
524 ));
525 }
526 let val = &val_args[0];
527 match val {
528 Value::Map(map) => {
530 if let Some(Value::String(t)) = map.get("_type") {
531 Ok(Value::String(t.clone()))
532 } else {
533 Err(datafusion::error::DataFusionError::Execution(
535 "TypeError: InvalidArgumentValue - type() requires a relationship argument".to_string(),
536 ))
537 }
538 }
539 Value::Null => Ok(Value::Null),
540 _ => Err(datafusion::error::DataFusionError::Execution(
541 "TypeError: InvalidArgumentValue - type() requires a relationship argument"
542 .to_string(),
543 )),
544 }
545 })
546 }
547}
548
549pub fn create_keys_udf() -> ScalarUDF {
557 ScalarUDF::new_from_impl(KeysUdf::new())
558}
559
560#[derive(Debug)]
561struct KeysUdf {
562 signature: Signature,
563}
564
565impl KeysUdf {
566 fn new() -> Self {
567 Self {
568 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
569 }
570 }
571}
572
573impl_udf_eq_hash!(KeysUdf);
574
575impl ScalarUDFImpl for KeysUdf {
576 fn as_any(&self) -> &dyn Any {
577 self
578 }
579
580 fn name(&self) -> &str {
581 "keys"
582 }
583
584 fn signature(&self) -> &Signature {
585 &self.signature
586 }
587
588 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
589 Ok(DataType::List(Arc::new(
590 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
591 )))
592 }
593
594 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
595 let output_type = self.return_type(&[])?;
596 invoke_cypher_udf(args, &output_type, |val_args| {
597 if val_args.is_empty() {
598 return Err(datafusion::error::DataFusionError::Execution(
599 "keys(): requires 1 argument".to_string(),
600 ));
601 }
602
603 let arg = &val_args[0];
604 let keys = match arg {
605 Value::Map(map) => {
606 let (source, is_entity) = match map.get("_all_props") {
616 Some(Value::Map(all)) => (all, true),
617 _ => (map, false),
618 };
619 let mut key_strings: Vec<String> = source
620 .iter()
621 .filter(|(k, v)| !k.starts_with('_') && (!is_entity || !v.is_null()))
622 .map(|(k, _)| k.clone())
623 .collect();
624 key_strings.sort();
625 key_strings
626 .into_iter()
627 .map(Value::String)
628 .collect::<Vec<_>>()
629 }
630 Value::Null => {
631 return Ok(Value::Null);
632 }
633 _ => {
634 vec![]
637 }
638 };
639
640 Ok(Value::List(keys))
641 })
642 }
643}
644
645pub fn create_properties_udf() -> ScalarUDF {
650 ScalarUDF::new_from_impl(PropertiesUdf::new())
651}
652
653#[derive(Debug)]
654struct PropertiesUdf {
655 signature: Signature,
656}
657
658impl PropertiesUdf {
659 fn new() -> Self {
660 Self {
661 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
662 }
663 }
664}
665
666impl_udf_eq_hash!(PropertiesUdf);
667
668impl ScalarUDFImpl for PropertiesUdf {
669 fn as_any(&self) -> &dyn Any {
670 self
671 }
672
673 fn name(&self) -> &str {
674 "properties"
675 }
676
677 fn signature(&self) -> &Signature {
678 &self.signature
679 }
680
681 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
682 Ok(DataType::LargeBinary)
684 }
685
686 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
687 let output_type = self.return_type(&[])?;
688 invoke_cypher_udf(args, &output_type, |val_args| {
689 if val_args.is_empty() {
690 return Err(datafusion::error::DataFusionError::Execution(
691 "properties(): requires 1 argument".to_string(),
692 ));
693 }
694
695 let arg = &val_args[0];
696 match arg {
697 Value::Map(map) => {
698 let identity_null = map
708 .get("_vid")
709 .map(|v| v.is_null())
710 .or_else(|| map.get("_eid").map(|v| v.is_null()))
711 .unwrap_or(false);
712 if identity_null {
713 return Ok(Value::Null);
714 }
715
716 let source = match map.get("_all_props") {
718 Some(Value::Map(all)) => all,
719 _ => map,
720 };
721 let filtered: std::collections::HashMap<String, Value> = source
723 .iter()
724 .filter(|(k, _)| !k.starts_with('_'))
725 .map(|(k, v)| (k.clone(), v.clone()))
726 .collect();
727 Ok(Value::Map(filtered))
728 }
729 _ => Ok(Value::Null),
730 }
731 })
732 }
733}
734
735pub fn create_index_udf() -> ScalarUDF {
740 ScalarUDF::new_from_impl(IndexUdf::new())
741}
742
743#[derive(Debug)]
744struct IndexUdf {
745 signature: Signature,
746}
747
748impl IndexUdf {
749 fn new() -> Self {
750 Self {
751 signature: Signature::any(2, Volatility::Immutable),
752 }
753 }
754}
755
756impl_udf_eq_hash!(IndexUdf);
757
758impl ScalarUDFImpl for IndexUdf {
759 fn as_any(&self) -> &dyn Any {
760 self
761 }
762
763 fn name(&self) -> &str {
764 "index"
765 }
766
767 fn signature(&self) -> &Signature {
768 &self.signature
769 }
770
771 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
772 Ok(DataType::LargeBinary)
774 }
775
776 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
777 let output_type = self.return_type(&[])?;
778 invoke_cypher_udf(args, &output_type, |val_args| {
779 if val_args.len() != 2 {
780 return Err(datafusion::error::DataFusionError::Execution(
781 "index(): requires 2 arguments".to_string(),
782 ));
783 }
784
785 let container = &val_args[0];
786 let index = &val_args[1];
787
788 let index_as_int = index.as_i64();
792
793 let result = match container {
794 Value::List(arr) => {
795 if let Some(i) = index_as_int {
796 let idx = if i < 0 {
797 let pos = arr.len() as i64 + i;
798 if pos < 0 { -1 } else { pos }
799 } else {
800 i
801 };
802 if idx >= 0 && (idx as usize) < arr.len() {
803 arr[idx as usize].clone()
804 } else {
805 Value::Null
806 }
807 } else if index.is_null() {
808 Value::Null
809 } else {
810 return Err(datafusion::error::DataFusionError::Execution(format!(
811 "TypeError: InvalidArgumentType - list index must be an integer, got: {:?}",
812 index
813 )));
814 }
815 }
816 Value::Map(map) => {
817 if let Some(key) = index.as_str() {
818 if let Some(val) = map.get(key) {
820 val.clone()
821 } else if let Some(Value::Map(all_props)) = map.get("_all_props") {
822 all_props.get(key).cloned().unwrap_or(Value::Null)
824 } else if let Some(Value::Map(props)) = map.get("properties") {
825 props.get(key).cloned().unwrap_or(Value::Null)
827 } else {
828 Value::Null
829 }
830 } else if !index.is_null() {
831 return Err(datafusion::error::DataFusionError::Execution(
832 "index(): map index must be a string".to_string(),
833 ));
834 } else {
835 Value::Null
836 }
837 }
838 Value::Node(node) => {
839 if let Some(key) = index.as_str() {
840 node.properties.get(key).cloned().unwrap_or(Value::Null)
841 } else if !index.is_null() {
842 return Err(datafusion::error::DataFusionError::Execution(
843 "index(): node index must be a string".to_string(),
844 ));
845 } else {
846 Value::Null
847 }
848 }
849 Value::Edge(edge) => {
850 if let Some(key) = index.as_str() {
851 edge.properties.get(key).cloned().unwrap_or(Value::Null)
852 } else if !index.is_null() {
853 return Err(datafusion::error::DataFusionError::Execution(
854 "index(): edge index must be a string".to_string(),
855 ));
856 } else {
857 Value::Null
858 }
859 }
860 Value::Null => Value::Null,
861 _ => {
862 return Err(datafusion::error::DataFusionError::Execution(format!(
863 "TypeError: InvalidArgumentType - cannot index into {:?}",
864 container
865 )));
866 }
867 };
868
869 Ok(result)
870 })
871 }
872}
873
874pub fn create_labels_udf() -> ScalarUDF {
879 ScalarUDF::new_from_impl(LabelsUdf::new())
880}
881
882#[derive(Debug)]
883struct LabelsUdf {
884 signature: Signature,
885}
886
887impl LabelsUdf {
888 fn new() -> Self {
889 Self {
890 signature: Signature::any(1, Volatility::Immutable),
891 }
892 }
893}
894
895impl_udf_eq_hash!(LabelsUdf);
896
897impl ScalarUDFImpl for LabelsUdf {
898 fn as_any(&self) -> &dyn Any {
899 self
900 }
901
902 fn name(&self) -> &str {
903 "labels"
904 }
905
906 fn signature(&self) -> &Signature {
907 &self.signature
908 }
909
910 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
911 Ok(DataType::List(Arc::new(
912 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
913 )))
914 }
915
916 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
917 let output_type = self.return_type(&[])?;
918 invoke_cypher_udf(args, &output_type, |val_args| {
919 if val_args.is_empty() {
920 return Err(datafusion::error::DataFusionError::Execution(
921 "labels(): requires 1 argument".to_string(),
922 ));
923 }
924
925 let node = &val_args[0];
926 match node {
927 Value::Map(map) => {
928 if let Some(Value::List(arr)) = map.get("_labels") {
929 Ok(Value::List(arr.clone()))
930 } else {
931 Err(datafusion::error::DataFusionError::Execution(
933 "TypeError: InvalidArgumentValue - labels() requires a node argument"
934 .to_string(),
935 ))
936 }
937 }
938 Value::Null => Ok(Value::Null),
939 _ => Err(datafusion::error::DataFusionError::Execution(
940 "TypeError: InvalidArgumentValue - labels() requires a node argument"
941 .to_string(),
942 )),
943 }
944 })
945 }
946}
947
948pub fn create_nodes_udf() -> ScalarUDF {
953 ScalarUDF::new_from_impl(NodesUdf::new())
954}
955
956#[derive(Debug)]
957struct NodesUdf {
958 signature: Signature,
959}
960
961impl NodesUdf {
962 fn new() -> Self {
963 Self {
964 signature: Signature::any(1, Volatility::Immutable),
965 }
966 }
967}
968
969impl_udf_eq_hash!(NodesUdf);
970
971impl ScalarUDFImpl for NodesUdf {
972 fn as_any(&self) -> &dyn Any {
973 self
974 }
975
976 fn name(&self) -> &str {
977 "nodes"
978 }
979
980 fn signature(&self) -> &Signature {
981 &self.signature
982 }
983
984 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
985 Ok(DataType::LargeBinary)
986 }
987
988 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
989 let output_type = self.return_type(&[])?;
990 invoke_cypher_udf(args, &output_type, |val_args| {
991 if val_args.is_empty() {
992 return Err(datafusion::error::DataFusionError::Execution(
993 "nodes(): requires 1 argument".to_string(),
994 ));
995 }
996
997 let path = &val_args[0];
998 let nodes = match path {
999 Value::Map(map) => map.get("nodes").cloned().unwrap_or(Value::Null),
1000 _ => Value::Null,
1001 };
1002
1003 Ok(nodes)
1004 })
1005 }
1006}
1007
1008pub fn create_relationships_udf() -> ScalarUDF {
1013 ScalarUDF::new_from_impl(RelationshipsUdf::new())
1014}
1015
1016#[derive(Debug)]
1017struct RelationshipsUdf {
1018 signature: Signature,
1019}
1020
1021impl RelationshipsUdf {
1022 fn new() -> Self {
1023 Self {
1024 signature: Signature::any(1, Volatility::Immutable),
1025 }
1026 }
1027}
1028
1029impl_udf_eq_hash!(RelationshipsUdf);
1030
1031impl ScalarUDFImpl for RelationshipsUdf {
1032 fn as_any(&self) -> &dyn Any {
1033 self
1034 }
1035
1036 fn name(&self) -> &str {
1037 "relationships"
1038 }
1039
1040 fn signature(&self) -> &Signature {
1041 &self.signature
1042 }
1043
1044 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1045 Ok(DataType::LargeBinary)
1046 }
1047
1048 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1049 let output_type = self.return_type(&[])?;
1050 invoke_cypher_udf(args, &output_type, |val_args| {
1051 if val_args.is_empty() {
1052 return Err(datafusion::error::DataFusionError::Execution(
1053 "relationships(): requires 1 argument".to_string(),
1054 ));
1055 }
1056
1057 let path = &val_args[0];
1058 let rels = match path {
1059 Value::Map(map) => map.get("relationships").cloned().unwrap_or(Value::Null),
1060 _ => Value::Null,
1061 };
1062
1063 Ok(rels)
1064 })
1065 }
1066}
1067
1068pub fn create_startnode_udf() -> ScalarUDF {
1077 ScalarUDF::new_from_impl(StartNodeUdf::new())
1078}
1079
1080#[derive(Debug)]
1081struct StartNodeUdf {
1082 signature: Signature,
1083}
1084
1085impl StartNodeUdf {
1086 fn new() -> Self {
1087 Self {
1088 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1089 }
1090 }
1091}
1092
1093impl_udf_eq_hash!(StartNodeUdf);
1094
1095impl ScalarUDFImpl for StartNodeUdf {
1096 fn as_any(&self) -> &dyn Any {
1097 self
1098 }
1099
1100 fn name(&self) -> &str {
1101 "startnode"
1102 }
1103
1104 fn signature(&self) -> &Signature {
1105 &self.signature
1106 }
1107
1108 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1109 Ok(DataType::LargeBinary)
1110 }
1111
1112 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1113 let output_type = DataType::LargeBinary;
1114 invoke_cypher_udf(args, &output_type, |val_args| {
1115 startnode_endnode_impl(val_args, true)
1116 })
1117 }
1118}
1119
1120pub fn create_endnode_udf() -> ScalarUDF {
1126 ScalarUDF::new_from_impl(EndNodeUdf::new())
1127}
1128
1129#[derive(Debug)]
1130struct EndNodeUdf {
1131 signature: Signature,
1132}
1133
1134impl EndNodeUdf {
1135 fn new() -> Self {
1136 Self {
1137 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1138 }
1139 }
1140}
1141
1142impl_udf_eq_hash!(EndNodeUdf);
1143
1144impl ScalarUDFImpl for EndNodeUdf {
1145 fn as_any(&self) -> &dyn Any {
1146 self
1147 }
1148
1149 fn name(&self) -> &str {
1150 "endnode"
1151 }
1152
1153 fn signature(&self) -> &Signature {
1154 &self.signature
1155 }
1156
1157 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1158 Ok(DataType::LargeBinary)
1159 }
1160
1161 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1162 let output_type = DataType::LargeBinary;
1163 invoke_cypher_udf(args, &output_type, |val_args| {
1164 startnode_endnode_impl(val_args, false)
1165 })
1166 }
1167}
1168
1169fn startnode_endnode_impl(val_args: &[Value], is_start: bool) -> DFResult<Value> {
1174 if val_args.is_empty() {
1175 let fn_name = if is_start { "startNode" } else { "endNode" };
1176 return Err(datafusion::error::DataFusionError::Execution(format!(
1177 "{fn_name}(): requires at least 1 argument"
1178 )));
1179 }
1180
1181 let edge_val = &val_args[0];
1182 let target_vid = extract_endpoint_vid(edge_val, is_start);
1183
1184 let target_vid = match target_vid {
1185 Some(vid) => vid,
1186 None => return Ok(Value::Null),
1187 };
1188
1189 for node_val in val_args.iter().skip(1) {
1191 if let Some(vid) = extract_vid(node_val)
1192 && vid == target_vid
1193 {
1194 return Ok(node_val.clone());
1195 }
1196 }
1197
1198 let mut map = std::collections::HashMap::new();
1200 map.insert("_vid".to_string(), Value::Int(target_vid as i64));
1201 Ok(Value::Map(map))
1202}
1203
1204fn extract_endpoint_vid(val: &Value, is_start: bool) -> Option<u64> {
1206 match val {
1207 Value::Edge(edge) => {
1208 let vid = if is_start { edge.src } else { edge.dst };
1209 Some(vid.as_u64())
1210 }
1211 Value::Map(map) => {
1212 let key = if is_start { "_src_vid" } else { "_dst_vid" };
1214 if let Some(v) = map.get(key) {
1215 return v.as_u64();
1216 }
1217 let key2 = if is_start { "_src" } else { "_dst" };
1219 if let Some(v) = map.get(key2) {
1220 return v.as_u64();
1221 }
1222 let node_key = if is_start { "_startNode" } else { "_endNode" };
1224 if let Some(node_val) = map.get(node_key) {
1225 return extract_vid(node_val);
1226 }
1227 None
1228 }
1229 _ => None,
1230 }
1231}
1232
1233fn extract_vid(val: &Value) -> Option<u64> {
1235 match val {
1236 Value::Map(map) => map.get("_vid").and_then(|v| v.as_u64()),
1237 _ => None,
1238 }
1239}
1240
1241fn extract_i64_range_arg(arg: &ColumnarValue, row_idx: usize, name: &str) -> DFResult<i64> {
1248 match arg {
1249 ColumnarValue::Scalar(sv) => match sv {
1250 ScalarValue::Int8(Some(v)) => Ok(*v as i64),
1251 ScalarValue::Int16(Some(v)) => Ok(*v as i64),
1252 ScalarValue::Int32(Some(v)) => Ok(*v as i64),
1253 ScalarValue::Int64(Some(v)) => Ok(*v),
1254 ScalarValue::UInt8(Some(v)) => Ok(*v as i64),
1255 ScalarValue::UInt16(Some(v)) => Ok(*v as i64),
1256 ScalarValue::UInt32(Some(v)) => Ok(*v as i64),
1257 ScalarValue::UInt64(Some(v)) => Ok(*v as i64),
1258 ScalarValue::LargeBinary(Some(bytes)) => {
1259 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1260 datafusion::error::DataFusionError::Execution(format!(
1261 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1262 name
1263 ))
1264 })
1265 }
1266 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1267 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1268 name
1269 ))),
1270 },
1271 ColumnarValue::Array(arr) => {
1272 if row_idx >= arr.len() || arr.is_null(row_idx) {
1273 return Err(datafusion::error::DataFusionError::Execution(format!(
1274 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1275 name
1276 )));
1277 }
1278 if !arr.is_empty() {
1280 use datafusion::arrow::array::{
1281 Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array,
1282 UInt32Array, UInt64Array,
1283 };
1284 match arr.data_type() {
1285 DataType::Int8 => Ok(arr
1286 .as_any()
1287 .downcast_ref::<Int8Array>()
1288 .unwrap()
1289 .value(row_idx) as i64),
1290 DataType::Int16 => Ok(arr
1291 .as_any()
1292 .downcast_ref::<Int16Array>()
1293 .unwrap()
1294 .value(row_idx) as i64),
1295 DataType::Int32 => Ok(arr
1296 .as_any()
1297 .downcast_ref::<Int32Array>()
1298 .unwrap()
1299 .value(row_idx) as i64),
1300 DataType::Int64 => Ok(arr
1301 .as_any()
1302 .downcast_ref::<Int64Array>()
1303 .unwrap()
1304 .value(row_idx)),
1305 DataType::UInt8 => Ok(arr
1306 .as_any()
1307 .downcast_ref::<UInt8Array>()
1308 .unwrap()
1309 .value(row_idx) as i64),
1310 DataType::UInt16 => Ok(arr
1311 .as_any()
1312 .downcast_ref::<UInt16Array>()
1313 .unwrap()
1314 .value(row_idx) as i64),
1315 DataType::UInt32 => Ok(arr
1316 .as_any()
1317 .downcast_ref::<UInt32Array>()
1318 .unwrap()
1319 .value(row_idx) as i64),
1320 DataType::UInt64 => Ok(arr
1321 .as_any()
1322 .downcast_ref::<UInt64Array>()
1323 .unwrap()
1324 .value(row_idx) as i64),
1325 DataType::LargeBinary => {
1326 let bytes = arr
1327 .as_any()
1328 .downcast_ref::<LargeBinaryArray>()
1329 .unwrap()
1330 .value(row_idx);
1331 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1332 datafusion::error::DataFusionError::Execution(format!(
1333 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1334 name
1335 ))
1336 })
1337 }
1338 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1339 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1340 name
1341 ))),
1342 }
1343 } else {
1344 Err(datafusion::error::DataFusionError::Execution(format!(
1345 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1346 name
1347 )))
1348 }
1349 }
1350 }
1351}
1352
1353pub fn create_range_udf() -> ScalarUDF {
1355 ScalarUDF::new_from_impl(RangeUdf::new())
1356}
1357
1358#[derive(Debug)]
1359struct RangeUdf {
1360 signature: Signature,
1361}
1362
1363impl RangeUdf {
1364 fn new() -> Self {
1365 Self {
1366 signature: Signature::one_of(
1367 vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1368 Volatility::Immutable,
1369 ),
1370 }
1371 }
1372}
1373
1374impl_udf_eq_hash!(RangeUdf);
1375
1376impl ScalarUDFImpl for RangeUdf {
1377 fn as_any(&self) -> &dyn Any {
1378 self
1379 }
1380
1381 fn name(&self) -> &str {
1382 "range"
1383 }
1384
1385 fn signature(&self) -> &Signature {
1386 &self.signature
1387 }
1388
1389 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1390 Ok(DataType::List(Arc::new(
1391 arrow::datatypes::Field::new_list_field(DataType::Int64, true),
1392 )))
1393 }
1394
1395 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1396 if args.args.len() < 2 || args.args.len() > 3 {
1397 return Err(datafusion::error::DataFusionError::Execution(
1398 "range(): requires 2 or 3 arguments".to_string(),
1399 ));
1400 }
1401
1402 let len = args
1403 .args
1404 .iter()
1405 .find_map(|arg| match arg {
1406 ColumnarValue::Array(arr) => Some(arr.len()),
1407 _ => None,
1408 })
1409 .unwrap_or(1);
1410
1411 let mut list_builder =
1412 arrow_array::builder::ListBuilder::new(arrow_array::builder::Int64Builder::new());
1413
1414 for row_idx in 0..len {
1415 let start = extract_i64_range_arg(&args.args[0], row_idx, "start")?;
1416 let end = extract_i64_range_arg(&args.args[1], row_idx, "end")?;
1417 let step = if args.args.len() == 3 {
1418 extract_i64_range_arg(&args.args[2], row_idx, "step")?
1419 } else {
1420 1
1421 };
1422
1423 if step == 0 {
1424 return Err(datafusion::error::DataFusionError::Execution(
1425 "range(): step cannot be zero".to_string(),
1426 ));
1427 }
1428
1429 if step > 0 && start <= end {
1430 let mut current = start;
1431 while current <= end {
1432 list_builder.values().append_value(current);
1433 current += step;
1434 }
1435 } else if step < 0 && start >= end {
1436 let mut current = start;
1437 while current >= end {
1438 list_builder.values().append_value(current);
1439 current += step;
1440 }
1441 }
1442 list_builder.append(true);
1444 }
1445
1446 let list_arr = Arc::new(list_builder.finish()) as ArrayRef;
1447 if len == 1
1448 && args
1449 .args
1450 .iter()
1451 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
1452 {
1453 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
1454 &list_arr, 0,
1455 )?))
1456 } else {
1457 Ok(ColumnarValue::Array(list_arr))
1458 }
1459 }
1460}
1461
1462fn invoke_binary_bitwise_op<F>(
1470 args: &ScalarFunctionArgs,
1471 name: &str,
1472 op: F,
1473) -> DFResult<ColumnarValue>
1474where
1475 F: Fn(i64, i64) -> i64,
1476{
1477 use arrow_array::Int64Array;
1478 use datafusion::common::ScalarValue;
1479 use datafusion::error::DataFusionError;
1480
1481 if args.args.len() != 2 {
1482 return Err(DataFusionError::Execution(format!(
1483 "{}(): requires exactly 2 arguments",
1484 name
1485 )));
1486 }
1487
1488 let left = &args.args[0];
1489 let right = &args.args[1];
1490
1491 match (left, right) {
1492 (
1493 ColumnarValue::Scalar(ScalarValue::Int64(Some(l))),
1494 ColumnarValue::Scalar(ScalarValue::Int64(Some(r))),
1495 ) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*l, *r))))),
1496 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
1497 let l_arr = l_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1498 DataFusionError::Execution(format!("{}(): left array must be Int64", name))
1499 })?;
1500 let r_arr = r_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1501 DataFusionError::Execution(format!("{}(): right array must be Int64", name))
1502 })?;
1503
1504 let result: Int64Array = l_arr
1505 .iter()
1506 .zip(r_arr.iter())
1507 .map(|(l, r)| match (l, r) {
1508 (Some(l), Some(r)) => Some(op(l, r)),
1509 _ => None,
1510 })
1511 .collect();
1512
1513 Ok(ColumnarValue::Array(Arc::new(result)))
1514 }
1515 _ => Err(DataFusionError::Execution(format!(
1516 "{}(): mixed scalar/array not supported",
1517 name
1518 ))),
1519 }
1520}
1521
1522fn invoke_unary_bitwise_op<F>(
1526 args: &ScalarFunctionArgs,
1527 name: &str,
1528 op: F,
1529) -> DFResult<ColumnarValue>
1530where
1531 F: Fn(i64) -> i64,
1532{
1533 use arrow_array::Int64Array;
1534 use datafusion::common::ScalarValue;
1535 use datafusion::error::DataFusionError;
1536
1537 if args.args.len() != 1 {
1538 return Err(DataFusionError::Execution(format!(
1539 "{}(): requires exactly 1 argument",
1540 name
1541 )));
1542 }
1543
1544 let operand = &args.args[0];
1545
1546 match operand {
1547 ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
1548 Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*v)))))
1549 }
1550 ColumnarValue::Array(arr) => {
1551 let arr = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1552 DataFusionError::Execution(format!("{}(): array must be Int64", name))
1553 })?;
1554
1555 let result: Int64Array = arr.iter().map(|v| v.map(&op)).collect();
1556
1557 Ok(ColumnarValue::Array(Arc::new(result)))
1558 }
1559 _ => Err(DataFusionError::Execution(format!(
1560 "{}(): invalid argument type",
1561 name
1562 ))),
1563 }
1564}
1565
// Generates a private UDF struct named `$struct_name` for a two-argument bitwise function.
//
// The generated type has an exact `(Int64, Int64)` immutable signature, reports
// `$udf_name` as its name, returns `Int64`, and evaluates by passing `$op` to
// `invoke_binary_bitwise_op`. Equality and hashing come from `impl_udf_eq_hash!`.
macro_rules! define_binary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                Self {
                    signature: Signature::exact(
                        vec![DataType::Int64, DataType::Int64],
                        Volatility::Immutable,
                    ),
                }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                // The UDF name doubles as the prefix for error messages.
                invoke_binary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1612
// Generates a private UDF struct named `$struct_name` for a one-argument bitwise function.
//
// The generated type has an exact `(Int64)` immutable signature, reports `$udf_name`
// as its name, returns `Int64`, and evaluates by passing `$op` to
// `invoke_unary_bitwise_op`. Equality and hashing come from `impl_udf_eq_hash!`.
macro_rules! define_unary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                Self {
                    signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
                }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                // The UDF name doubles as the prefix for error messages.
                invoke_unary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1656
1657define_binary_bitwise_udf!(BitwiseOrUdf, "uni.bitwise.or", |l, r| l | r);
1659define_binary_bitwise_udf!(BitwiseAndUdf, "uni.bitwise.and", |l, r| l & r);
1660define_binary_bitwise_udf!(BitwiseXorUdf, "uni.bitwise.xor", |l, r| l ^ r);
1661define_binary_bitwise_udf!(ShiftLeftUdf, "uni.bitwise.shiftLeft", |l, r| l << r);
1662define_binary_bitwise_udf!(ShiftRightUdf, "uni.bitwise.shiftRight", |l, r| l >> r);
1663
1664define_unary_bitwise_udf!(BitwiseNotUdf, "uni.bitwise.not", |v| !v);
1666
1667pub fn create_bitwise_or_udf() -> ScalarUDF {
1669 ScalarUDF::new_from_impl(BitwiseOrUdf::new())
1670}
1671
1672pub fn create_bitwise_and_udf() -> ScalarUDF {
1674 ScalarUDF::new_from_impl(BitwiseAndUdf::new())
1675}
1676
1677pub fn create_bitwise_xor_udf() -> ScalarUDF {
1679 ScalarUDF::new_from_impl(BitwiseXorUdf::new())
1680}
1681
1682pub fn create_bitwise_not_udf() -> ScalarUDF {
1684 ScalarUDF::new_from_impl(BitwiseNotUdf::new())
1685}
1686
1687pub fn create_shift_left_udf() -> ScalarUDF {
1689 ScalarUDF::new_from_impl(ShiftLeftUdf::new())
1690}
1691
1692pub fn create_shift_right_udf() -> ScalarUDF {
1694 ScalarUDF::new_from_impl(ShiftRightUdf::new())
1695}
1696
1697fn create_temporal_udf(name: &str) -> ScalarUDF {
1708 ScalarUDF::new_from_impl(TemporalUdf::new(name.to_string()))
1709}
1710
1711#[derive(Debug)]
1712struct TemporalUdf {
1713 name: String,
1714 signature: Signature,
1715}
1716
1717impl TemporalUdf {
1718 fn new(name: String) -> Self {
1719 Self {
1720 name,
1721 signature: Signature::new(
1724 TypeSignature::OneOf(vec![
1725 TypeSignature::Exact(vec![]),
1726 TypeSignature::VariadicAny,
1727 ]),
1728 Volatility::Immutable,
1729 ),
1730 }
1731 }
1732}
1733
1734impl_udf_eq_hash!(TemporalUdf);
1735
1736impl ScalarUDFImpl for TemporalUdf {
1737 fn as_any(&self) -> &dyn Any {
1738 self
1739 }
1740
1741 fn name(&self) -> &str {
1742 &self.name
1743 }
1744
1745 fn signature(&self) -> &Signature {
1746 &self.signature
1747 }
1748
1749 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1750 let name = self.name.to_lowercase();
1751 match name.as_str() {
1752 "year" | "month" | "day" | "hour" | "minute" | "second" => Ok(DataType::Int64),
1754 "datetime"
1760 | "localdatetime"
1761 | "date"
1762 | "time"
1763 | "localtime"
1764 | "duration"
1765 | "date.truncate"
1766 | "time.truncate"
1767 | "datetime.truncate"
1768 | "localdatetime.truncate"
1769 | "localtime.truncate"
1770 | "duration.between"
1771 | "duration.inmonths"
1772 | "duration.indays"
1773 | "duration.inseconds"
1774 | "datetime.fromepoch"
1775 | "datetime.fromepochmillis"
1776 | "datetime.transaction"
1777 | "datetime.statement"
1778 | "datetime.realtime"
1779 | "date.transaction"
1780 | "date.statement"
1781 | "date.realtime"
1782 | "time.transaction"
1783 | "time.statement"
1784 | "time.realtime"
1785 | "localtime.transaction"
1786 | "localtime.statement"
1787 | "localtime.realtime"
1788 | "localdatetime.transaction"
1789 | "localdatetime.statement"
1790 | "localdatetime.realtime"
1791 | "btic" => Ok(DataType::LargeBinary),
1792 _ => Ok(DataType::Utf8),
1793 }
1794 }
1795
1796 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1797 let func_name = self.name.to_uppercase();
1798 let output_type = self.return_type(&[])?;
1799 invoke_cypher_udf(args, &output_type, |val_args| {
1800 crate::datetime::eval_datetime_function(&func_name, val_args).map_err(|e| {
1801 datafusion::error::DataFusionError::Execution(format!("{}(): {}", self.name, e))
1802 })
1803 })
1804 }
1805}
1806
1807fn create_duration_property_udf() -> ScalarUDF {
1812 ScalarUDF::new_from_impl(DurationPropertyUdf::new())
1813}
1814
1815#[derive(Debug)]
1816struct DurationPropertyUdf {
1817 signature: Signature,
1818}
1819
1820impl DurationPropertyUdf {
1821 fn new() -> Self {
1822 Self {
1823 signature: Signature::new(
1824 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
1825 Volatility::Immutable,
1826 ),
1827 }
1828 }
1829}
1830
1831impl_udf_eq_hash!(DurationPropertyUdf);
1832
1833impl ScalarUDFImpl for DurationPropertyUdf {
1834 fn as_any(&self) -> &dyn Any {
1835 self
1836 }
1837
1838 fn name(&self) -> &str {
1839 "_duration_property"
1840 }
1841
1842 fn signature(&self) -> &Signature {
1843 &self.signature
1844 }
1845
1846 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1847 Ok(DataType::Int64)
1848 }
1849
1850 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1851 let output_type = self.return_type(&[])?;
1852 invoke_cypher_udf(args, &output_type, |val_args| {
1853 if val_args.len() != 2 {
1854 return Err(datafusion::error::DataFusionError::Execution(
1855 "_duration_property(): requires 2 arguments (duration_string, component)"
1856 .to_string(),
1857 ));
1858 }
1859
1860 let dur_string_owned;
1861 let dur_str = match &val_args[0] {
1862 Value::String(s) => s.as_str(),
1863 Value::Temporal(uni_common::TemporalValue::Duration { .. }) => {
1864 dur_string_owned = val_args[0].to_string();
1865 &dur_string_owned
1866 }
1867 Value::Null => return Ok(Value::Null),
1868 _ => {
1869 return Err(datafusion::error::DataFusionError::Execution(
1870 "_duration_property(): duration must be a string or temporal duration"
1871 .to_string(),
1872 ));
1873 }
1874 };
1875 let component = match &val_args[1] {
1876 Value::String(s) => s,
1877 _ => {
1878 return Err(datafusion::error::DataFusionError::Execution(
1879 "_duration_property(): component must be a string".to_string(),
1880 ));
1881 }
1882 };
1883
1884 crate::datetime::eval_duration_accessor(dur_str, component).map_err(|e| {
1885 datafusion::error::DataFusionError::Execution(format!(
1886 "_duration_property(): {}",
1887 e
1888 ))
1889 })
1890 })
1891 }
1892}
1893
1894fn create_tostring_udf() -> ScalarUDF {
1899 ScalarUDF::new_from_impl(ToStringUdf::new())
1900}
1901
1902#[derive(Debug)]
1903struct ToStringUdf {
1904 signature: Signature,
1905}
1906
1907impl ToStringUdf {
1908 fn new() -> Self {
1909 Self {
1910 signature: Signature::variadic_any(Volatility::Immutable),
1911 }
1912 }
1913}
1914
1915impl_udf_eq_hash!(ToStringUdf);
1916
1917impl ScalarUDFImpl for ToStringUdf {
1918 fn as_any(&self) -> &dyn Any {
1919 self
1920 }
1921
1922 fn name(&self) -> &str {
1923 "tostring"
1924 }
1925
1926 fn signature(&self) -> &Signature {
1927 &self.signature
1928 }
1929
1930 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1931 Ok(DataType::Utf8)
1932 }
1933
1934 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1935 let output_type = self.return_type(&[])?;
1936 invoke_cypher_udf(args, &output_type, |val_args| {
1937 if val_args.is_empty() {
1938 return Err(datafusion::error::DataFusionError::Execution(
1939 "toString(): requires 1 argument".to_string(),
1940 ));
1941 }
1942 match &val_args[0] {
1943 Value::Null => Ok(Value::Null),
1944 Value::String(s) => Ok(Value::String(s.clone())),
1945 Value::Int(i) => Ok(Value::String(i.to_string())),
1946 Value::Float(f) => Ok(Value::String(f.to_string())),
1947 Value::Bool(b) => Ok(Value::String(b.to_string())),
1948 Value::Temporal(t) => Ok(Value::String(t.to_string())),
1949 other => {
1950 let type_name = match other {
1951 Value::List(_) => "List",
1952 Value::Map(_) => "Map",
1953 Value::Node { .. } => "Node",
1954 Value::Edge { .. } => "Relationship",
1955 Value::Path { .. } => "Path",
1956 _ => "Unknown",
1957 };
1958 Err(datafusion::error::DataFusionError::Execution(format!(
1959 "TypeError: InvalidArgumentValue - toString() does not accept {} values",
1960 type_name
1961 )))
1962 }
1963 }
1964 })
1965 }
1966}
1967
1968fn create_temporal_property_udf() -> ScalarUDF {
1973 ScalarUDF::new_from_impl(TemporalPropertyUdf::new())
1974}
1975
1976#[derive(Debug)]
1977struct TemporalPropertyUdf {
1978 signature: Signature,
1979}
1980
1981impl TemporalPropertyUdf {
1982 fn new() -> Self {
1983 Self {
1984 signature: Signature::variadic_any(Volatility::Immutable),
1985 }
1986 }
1987}
1988
1989impl_udf_eq_hash!(TemporalPropertyUdf);
1990
1991impl ScalarUDFImpl for TemporalPropertyUdf {
1992 fn as_any(&self) -> &dyn Any {
1993 self
1994 }
1995
1996 fn name(&self) -> &str {
1997 "_temporal_property"
1998 }
1999
2000 fn signature(&self) -> &Signature {
2001 &self.signature
2002 }
2003
2004 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2005 Ok(DataType::LargeBinary)
2006 }
2007
2008 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2009 let output_type = self.return_type(&[])?;
2010 invoke_cypher_udf(args, &output_type, |val_args| {
2011 if val_args.len() != 2 {
2012 return Err(datafusion::error::DataFusionError::Execution(
2013 "_temporal_property(): requires 2 arguments (temporal_value, component)"
2014 .to_string(),
2015 ));
2016 }
2017
2018 let component = match &val_args[1] {
2019 Value::String(s) => s.clone(),
2020 _ => {
2021 return Err(datafusion::error::DataFusionError::Execution(
2022 "_temporal_property(): component must be a string".to_string(),
2023 ));
2024 }
2025 };
2026
2027 crate::datetime::eval_temporal_accessor_value(&val_args[0], &component).map_err(|e| {
2028 datafusion::error::DataFusionError::Execution(format!(
2029 "_temporal_property(): {}",
2030 e
2031 ))
2032 })
2033 })
2034 }
2035}
2036
// Downcasts `$arr` to a reference to `$array_type`.
//
// On a type mismatch it returns early (via `?`) with an execution error that names
// the target type, so it can only be used inside a function whose error type
// accepts `DataFusionError`.
macro_rules! downcast_arr {
    ($arr:expr, $array_type:ty) => {
        $arr.as_any().downcast_ref::<$array_type>().ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "Failed to downcast to {}",
                stringify!($array_type)
            ))
        })?
    };
}
2049
2050fn cypher_type_name(val: &Value) -> &'static str {
2052 match val {
2053 Value::Null => "Null",
2054 Value::Bool(_) => "Boolean",
2055 Value::Int(_) => "Integer",
2056 Value::Float(_) => "Float",
2057 Value::String(_) => "String",
2058 Value::Bytes(_) => "Bytes",
2059 Value::List(_) => "List",
2060 Value::Map(_) => "Map",
2061 Value::Node(_) => "Node",
2062 Value::Edge(_) => "Relationship",
2063 Value::Path(_) => "Path",
2064 Value::Vector(_) => "Vector",
2065 Value::Temporal(_) => "Temporal",
2066 _ => "Unknown",
2067 }
2068}
2069
2070fn string_to_value(s: &str) -> Value {
2072 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2073 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2074 {
2075 return Value::from(obj);
2076 }
2077 Value::String(s.to_string())
2078}
2079
2080fn get_value_from_array(arr: &ArrayRef, row: usize) -> DFResult<Value> {
2086 if arr.is_null(row) {
2087 return Ok(Value::Null);
2088 }
2089
2090 match arr.data_type() {
2091 DataType::LargeBinary => {
2092 let typed = downcast_arr!(arr, LargeBinaryArray);
2093 let bytes = typed.value(row);
2094 if let Ok(val) = uni_common::cypher_value_codec::decode(bytes) {
2095 return Ok(val);
2096 }
2097 Ok(serde_json::from_slice::<serde_json::Value>(bytes)
2099 .map(Value::from)
2100 .unwrap_or(Value::Null))
2101 }
2102 DataType::Int64 => Ok(Value::Int(downcast_arr!(arr, Int64Array).value(row))),
2103 DataType::Float64 => Ok(Value::Float(downcast_arr!(arr, Float64Array).value(row))),
2104 DataType::Utf8 => Ok(string_to_value(downcast_arr!(arr, StringArray).value(row))),
2105 DataType::LargeUtf8 => Ok(string_to_value(
2106 downcast_arr!(arr, LargeStringArray).value(row),
2107 )),
2108 DataType::Boolean => Ok(Value::Bool(downcast_arr!(arr, BooleanArray).value(row))),
2109 DataType::UInt64 => Ok(Value::Int(downcast_arr!(arr, UInt64Array).value(row) as i64)),
2110 DataType::Int32 => Ok(Value::Int(downcast_arr!(arr, Int32Array).value(row) as i64)),
2111 DataType::Float32 => Ok(Value::Float(
2112 downcast_arr!(arr, Float32Array).value(row) as f64
2113 )),
2114 _ => {
2117 let scalar = ScalarValue::try_from_array(arr, row).map_err(|e| {
2118 datafusion::error::DataFusionError::Execution(format!(
2119 "Cannot extract scalar from array at row {}: {}",
2120 row, e
2121 ))
2122 })?;
2123 scalar_to_value(&scalar)
2124 }
2125 }
2126}
2127
2128fn get_value_args_for_row(args: &[ColumnarValue], row: usize) -> DFResult<Vec<Value>> {
2130 args.iter()
2131 .map(|arg| match arg {
2132 ColumnarValue::Scalar(scalar) => scalar_to_value(scalar),
2133 ColumnarValue::Array(arr) => get_value_from_array(arr, row),
2134 })
2135 .collect()
2136}
2137
2138fn invoke_cypher_udf<F>(
2140 args: ScalarFunctionArgs,
2141 output_type: &DataType,
2142 f: F,
2143) -> DFResult<ColumnarValue>
2144where
2145 F: Fn(&[Value]) -> DFResult<Value>,
2146{
2147 let len = args
2148 .args
2149 .iter()
2150 .find_map(|arg| match arg {
2151 ColumnarValue::Array(arr) => Some(arr.len()),
2152 _ => None,
2153 })
2154 .unwrap_or(1);
2155
2156 if len == 1
2157 && args
2158 .args
2159 .iter()
2160 .all(|a| matches!(a, ColumnarValue::Scalar(_)))
2161 {
2162 let row_args = get_value_args_for_row(&args.args, 0)?;
2163 let res = f(&row_args)?;
2164 if matches!(output_type, DataType::LargeBinary | DataType::List(_)) {
2165 let arr = values_to_array(&[res], output_type)
2167 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2168 return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, 0)?));
2169 }
2170 if res.is_null() {
2172 let typed_null = ScalarValue::try_from(output_type).unwrap_or(ScalarValue::Utf8(None));
2173 return Ok(ColumnarValue::Scalar(typed_null));
2174 }
2175 return value_to_columnar(&res);
2176 }
2177
2178 let mut results = Vec::with_capacity(len);
2179 for i in 0..len {
2180 let row_args = get_value_args_for_row(&args.args, i)?;
2181 results.push(f(&row_args)?);
2182 }
2183
2184 let arr = values_to_array(&results, output_type)
2185 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2186 Ok(ColumnarValue::Array(arr))
2187}
2188
2189fn scalar_arr_to_value(arr: &dyn arrow::array::Array) -> DFResult<Value> {
2192 if arr.is_empty() || arr.is_null(0) {
2193 Ok(Value::Null)
2194 } else {
2195 Ok(uni_store::storage::arrow_convert::arrow_to_value(
2197 arr, 0, None,
2198 ))
2199 }
2200}
2201
2202fn resolve_timezone_offset(tz_name: &str, nanos_utc: i64) -> i32 {
2204 if tz_name == "UTC" || tz_name == "Z" {
2205 return 0;
2206 }
2207 if let Ok(tz) = tz_name.parse::<chrono_tz::Tz>() {
2208 let dt = chrono::DateTime::from_timestamp_nanos(nanos_utc).with_timezone(&tz);
2209 dt.offset().fix().local_minus_utc()
2210 } else {
2211 0
2212 }
2213}
2214
2215fn duration_micros_to_value(micros: i64) -> Value {
2217 let dur = crate::datetime::CypherDuration::from_micros(micros);
2218 Value::Temporal(uni_common::TemporalValue::Duration {
2219 months: dur.months,
2220 days: dur.days,
2221 nanos: dur.nanos,
2222 })
2223}
2224
2225fn timestamp_nanos_to_value(nanos: i64, tz: Option<&Arc<str>>) -> DFResult<Value> {
2227 if let Some(tz_str) = tz {
2228 let offset = resolve_timezone_offset(tz_str.as_ref(), nanos);
2229 let tz_name = if tz_str.as_ref() == "UTC" {
2230 None
2231 } else {
2232 Some(tz_str.to_string())
2233 };
2234 Ok(Value::Temporal(uni_common::TemporalValue::DateTime {
2235 nanos_since_epoch: nanos,
2236 offset_seconds: offset,
2237 timezone_name: tz_name,
2238 }))
2239 } else {
2240 Ok(Value::Temporal(uni_common::TemporalValue::LocalDateTime {
2241 nanos_since_epoch: nanos,
2242 }))
2243 }
2244}
2245
2246pub fn scalar_to_value(scalar: &ScalarValue) -> DFResult<Value> {
2248 match scalar {
2249 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
2250 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2253 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2254 {
2255 return Ok(Value::from(obj));
2256 }
2257 Ok(Value::String(s.clone()))
2258 }
2259 ScalarValue::LargeBinary(Some(b)) => {
2260 if let Ok(val) = uni_common::cypher_value_codec::decode(b) {
2263 return Ok(val);
2264 }
2265 if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(b) {
2266 Ok(Value::from(obj))
2267 } else {
2268 Ok(Value::Null)
2269 }
2270 }
2271 ScalarValue::Int64(Some(i)) => Ok(Value::Int(*i)),
2272 ScalarValue::Int32(Some(i)) => Ok(Value::Int(*i as i64)),
2273 ScalarValue::Float64(Some(f)) => {
2274 Ok(Value::Float(*f))
2276 }
2277 ScalarValue::Boolean(Some(b)) => Ok(Value::Bool(*b)),
2278 ScalarValue::Struct(arr) => scalar_arr_to_value(arr.as_ref()),
2279 ScalarValue::List(arr) => scalar_arr_to_value(arr.as_ref()),
2280 ScalarValue::LargeList(arr) => scalar_arr_to_value(arr.as_ref()),
2281 ScalarValue::FixedSizeList(arr) => scalar_arr_to_value(arr.as_ref()),
2282 ScalarValue::UInt64(Some(u)) => Ok(Value::Int(*u as i64)),
2284 ScalarValue::UInt32(Some(u)) => Ok(Value::Int(*u as i64)),
2285 ScalarValue::UInt16(Some(u)) => Ok(Value::Int(*u as i64)),
2286 ScalarValue::UInt8(Some(u)) => Ok(Value::Int(*u as i64)),
2287 ScalarValue::Int16(Some(i)) => Ok(Value::Int(*i as i64)),
2288 ScalarValue::Int8(Some(i)) => Ok(Value::Int(*i as i64)),
2289
2290 ScalarValue::Date32(Some(days)) => Ok(Value::Temporal(uni_common::TemporalValue::Date {
2292 days_since_epoch: *days,
2293 })),
2294 ScalarValue::Date64(Some(millis)) => {
2295 let days = (*millis / 86_400_000) as i32;
2296 Ok(Value::Temporal(uni_common::TemporalValue::Date {
2297 days_since_epoch: days,
2298 }))
2299 }
2300 ScalarValue::TimestampNanosecond(Some(nanos), tz) => {
2301 timestamp_nanos_to_value(*nanos, tz.as_ref())
2302 }
2303 ScalarValue::TimestampMicrosecond(Some(micros), tz) => {
2304 timestamp_nanos_to_value(*micros * 1_000, tz.as_ref())
2305 }
2306 ScalarValue::TimestampMillisecond(Some(millis), tz) => {
2307 timestamp_nanos_to_value(*millis * 1_000_000, tz.as_ref())
2308 }
2309 ScalarValue::TimestampSecond(Some(secs), tz) => {
2310 timestamp_nanos_to_value(*secs * 1_000_000_000, tz.as_ref())
2311 }
2312 ScalarValue::Time64Nanosecond(Some(nanos)) => {
2313 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2314 nanos_since_midnight: *nanos,
2315 }))
2316 }
2317 ScalarValue::Time64Microsecond(Some(micros)) => {
2318 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2319 nanos_since_midnight: *micros * 1_000,
2320 }))
2321 }
2322 ScalarValue::IntervalMonthDayNano(Some(v)) => {
2323 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2324 months: v.months as i64,
2325 days: v.days as i64,
2326 nanos: v.nanoseconds,
2327 }))
2328 }
2329 ScalarValue::DurationMicrosecond(Some(micros)) => Ok(duration_micros_to_value(*micros)),
2330 ScalarValue::DurationMillisecond(Some(millis)) => {
2331 Ok(duration_micros_to_value(*millis * 1_000))
2332 }
2333 ScalarValue::DurationSecond(Some(secs)) => Ok(duration_micros_to_value(*secs * 1_000_000)),
2334 ScalarValue::DurationNanosecond(Some(nanos)) => {
2335 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2336 months: 0,
2337 days: 0,
2338 nanos: *nanos,
2339 }))
2340 }
2341 ScalarValue::Float32(Some(f)) => Ok(Value::Float(*f as f64)),
2342
2343 ScalarValue::FixedSizeBinary(24, Some(bytes)) => {
2345 match uni_btic::encode::decode_slice(bytes) {
2346 Ok(btic) => Ok(Value::Temporal(uni_common::TemporalValue::Btic {
2347 lo: btic.lo(),
2348 hi: btic.hi(),
2349 meta: btic.meta(),
2350 })),
2351 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
2352 "BTIC decode error: {e}"
2353 ))),
2354 }
2355 }
2356
2357 ScalarValue::Null
2359 | ScalarValue::Utf8(None)
2360 | ScalarValue::LargeUtf8(None)
2361 | ScalarValue::LargeBinary(None)
2362 | ScalarValue::Int64(None)
2363 | ScalarValue::Int32(None)
2364 | ScalarValue::Int16(None)
2365 | ScalarValue::Int8(None)
2366 | ScalarValue::UInt64(None)
2367 | ScalarValue::UInt32(None)
2368 | ScalarValue::UInt16(None)
2369 | ScalarValue::UInt8(None)
2370 | ScalarValue::Float64(None)
2371 | ScalarValue::Float32(None)
2372 | ScalarValue::Boolean(None)
2373 | ScalarValue::Date32(None)
2374 | ScalarValue::Date64(None)
2375 | ScalarValue::TimestampMicrosecond(None, _)
2376 | ScalarValue::TimestampMillisecond(None, _)
2377 | ScalarValue::TimestampSecond(None, _)
2378 | ScalarValue::TimestampNanosecond(None, _)
2379 | ScalarValue::Time64Microsecond(None)
2380 | ScalarValue::Time64Nanosecond(None)
2381 | ScalarValue::DurationMicrosecond(None)
2382 | ScalarValue::DurationMillisecond(None)
2383 | ScalarValue::DurationSecond(None)
2384 | ScalarValue::DurationNanosecond(None)
2385 | ScalarValue::IntervalMonthDayNano(None)
2386 | ScalarValue::FixedSizeBinary(_, None) => Ok(Value::Null),
2387 other => Err(datafusion::error::DataFusionError::Execution(format!(
2388 "scalar_to_value(): unsupported scalar type {other:?}"
2389 ))),
2390 }
2391}
2392
2393fn value_to_columnar(val: &Value) -> DFResult<ColumnarValue> {
2395 let scalar = match val {
2396 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
2397 Value::Int(i) => ScalarValue::Int64(Some(*i)),
2398 Value::Float(f) => ScalarValue::Float64(Some(*f)),
2399 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
2400 Value::Null => ScalarValue::Utf8(None),
2401 Value::Temporal(tv) => {
2402 use uni_common::TemporalValue;
2403 match tv {
2404 TemporalValue::Date { days_since_epoch } => {
2405 ScalarValue::Date32(Some(*days_since_epoch))
2406 }
2407 TemporalValue::LocalTime {
2408 nanos_since_midnight,
2409 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2410 TemporalValue::Time {
2411 nanos_since_midnight,
2412 ..
2413 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2414 TemporalValue::LocalDateTime { nanos_since_epoch } => {
2415 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None)
2416 }
2417 TemporalValue::DateTime {
2418 nanos_since_epoch,
2419 timezone_name,
2420 ..
2421 } => {
2422 let tz = timezone_name.as_deref().unwrap_or("UTC");
2423 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), Some(tz.into()))
2424 }
2425 TemporalValue::Duration {
2426 months,
2427 days,
2428 nanos,
2429 } => ScalarValue::IntervalMonthDayNano(Some(
2430 arrow::datatypes::IntervalMonthDayNano {
2431 months: *months as i32,
2432 days: *days as i32,
2433 nanoseconds: *nanos,
2434 },
2435 )),
2436 TemporalValue::Btic { lo, hi, meta } => {
2437 let btic = uni_btic::Btic::new(*lo, *hi, *meta).map_err(|e| {
2438 datafusion::error::DataFusionError::Execution(format!("invalid BTIC: {e}"))
2439 })?;
2440 let packed = uni_btic::encode::encode(&btic);
2441 ScalarValue::FixedSizeBinary(24, Some(packed.to_vec()))
2442 }
2443 }
2444 }
2445 other => {
2446 return Err(datafusion::error::DataFusionError::Execution(format!(
2447 "value_to_columnar(): unsupported type {other:?}"
2448 )));
2449 }
2450 };
2451 Ok(ColumnarValue::Scalar(scalar))
2452}
2453
2454pub fn create_has_null_udf() -> ScalarUDF {
2460 ScalarUDF::new_from_impl(HasNullUdf::new())
2461}
2462
2463#[derive(Debug)]
2464struct HasNullUdf {
2465 signature: Signature,
2466}
2467
2468impl HasNullUdf {
2469 fn new() -> Self {
2470 Self {
2471 signature: Signature::any(1, Volatility::Immutable),
2472 }
2473 }
2474}
2475
2476impl_udf_eq_hash!(HasNullUdf);
2477
2478impl ScalarUDFImpl for HasNullUdf {
2479 fn as_any(&self) -> &dyn Any {
2480 self
2481 }
2482
2483 fn name(&self) -> &str {
2484 "_has_null"
2485 }
2486
2487 fn signature(&self) -> &Signature {
2488 &self.signature
2489 }
2490
2491 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2492 Ok(DataType::Boolean)
2493 }
2494
2495 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2496 if args.args.len() != 1 {
2497 return Err(datafusion::error::DataFusionError::Execution(
2498 "_has_null(): requires 1 argument".to_string(),
2499 ));
2500 }
2501
2502 fn check_list_nulls<T: arrow_array::OffsetSizeTrait>(
2504 arr: &arrow_array::GenericListArray<T>,
2505 idx: usize,
2506 ) -> bool {
2507 if arr.is_null(idx) || arr.is_empty() {
2508 false
2509 } else {
2510 arr.value(idx).null_count() > 0
2511 }
2512 }
2513
2514 match &args.args[0] {
2515 ColumnarValue::Scalar(scalar) => {
2516 let has_null = match scalar {
2517 ScalarValue::List(arr) => arr
2518 .as_any()
2519 .downcast_ref::<arrow::array::ListArray>()
2520 .map(|a| !a.is_empty() && a.value(0).null_count() > 0)
2521 .unwrap_or(arr.null_count() > 0),
2522 ScalarValue::LargeList(arr) => arr.len() > 0 && arr.value(0).null_count() > 0,
2523 ScalarValue::FixedSizeList(arr) => {
2524 arr.len() > 0 && arr.value(0).null_count() > 0
2525 }
2526 _ => false,
2527 };
2528 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(has_null))))
2529 }
2530 ColumnarValue::Array(arr) => {
2531 use arrow_array::{LargeListArray, ListArray};
2532
2533 let results: arrow::array::BooleanArray =
2534 if let Some(list_arr) = arr.as_any().downcast_ref::<ListArray>() {
2535 (0..list_arr.len())
2536 .map(|i| {
2537 if list_arr.is_null(i) {
2538 None
2539 } else {
2540 Some(check_list_nulls(list_arr, i))
2541 }
2542 })
2543 .collect()
2544 } else if let Some(large) = arr.as_any().downcast_ref::<LargeListArray>() {
2545 (0..large.len())
2546 .map(|i| {
2547 if large.is_null(i) {
2548 None
2549 } else {
2550 Some(check_list_nulls(large, i))
2551 }
2552 })
2553 .collect()
2554 } else {
2555 return Err(datafusion::error::DataFusionError::Execution(
2556 "_has_null(): requires list array".to_string(),
2557 ));
2558 };
2559 Ok(ColumnarValue::Array(Arc::new(results)))
2560 }
2561 }
2562 }
2563}
2564
2565pub fn create_to_integer_udf() -> ScalarUDF {
2570 ScalarUDF::new_from_impl(ToIntegerUdf::new())
2571}
2572
2573#[derive(Debug)]
2574struct ToIntegerUdf {
2575 signature: Signature,
2576}
2577
2578impl ToIntegerUdf {
2579 fn new() -> Self {
2580 Self {
2581 signature: Signature::any(1, Volatility::Immutable),
2582 }
2583 }
2584}
2585
2586impl_udf_eq_hash!(ToIntegerUdf);
2587
2588impl ScalarUDFImpl for ToIntegerUdf {
2589 fn as_any(&self) -> &dyn Any {
2590 self
2591 }
2592
2593 fn name(&self) -> &str {
2594 "tointeger"
2595 }
2596
2597 fn signature(&self) -> &Signature {
2598 &self.signature
2599 }
2600
2601 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2602 Ok(DataType::Int64)
2603 }
2604
2605 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2606 let output_type = self.return_type(&[])?;
2607 invoke_cypher_udf(args, &output_type, |val_args| {
2608 if val_args.is_empty() {
2609 return Err(datafusion::error::DataFusionError::Execution(
2610 "tointeger(): requires 1 argument".to_string(),
2611 ));
2612 }
2613
2614 let val = &val_args[0];
2615 match val {
2616 Value::Int(i) => Ok(Value::Int(*i)),
2617 Value::Float(f) => Ok(Value::Int(*f as i64)),
2618 Value::String(s) => {
2619 if let Ok(i) = s.parse::<i64>() {
2620 Ok(Value::Int(i))
2621 } else if let Ok(f) = s.parse::<f64>() {
2622 Ok(Value::Int(f as i64))
2623 } else {
2624 Ok(Value::Null)
2625 }
2626 }
2627 Value::Null => Ok(Value::Null),
2628 other => Err(datafusion::error::DataFusionError::Execution(format!(
2629 "InvalidArgumentValue: tointeger(): cannot convert {} to integer",
2630 cypher_type_name(other)
2631 ))),
2632 }
2633 })
2634 }
2635}
2636
2637pub fn create_to_float_udf() -> ScalarUDF {
2642 ScalarUDF::new_from_impl(ToFloatUdf::new())
2643}
2644
2645#[derive(Debug)]
2646struct ToFloatUdf {
2647 signature: Signature,
2648}
2649
2650impl ToFloatUdf {
2651 fn new() -> Self {
2652 Self {
2653 signature: Signature::any(1, Volatility::Immutable),
2654 }
2655 }
2656}
2657
2658impl_udf_eq_hash!(ToFloatUdf);
2659
2660impl ScalarUDFImpl for ToFloatUdf {
2661 fn as_any(&self) -> &dyn Any {
2662 self
2663 }
2664
2665 fn name(&self) -> &str {
2666 "tofloat"
2667 }
2668
2669 fn signature(&self) -> &Signature {
2670 &self.signature
2671 }
2672
2673 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2674 Ok(DataType::Float64)
2675 }
2676
2677 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2678 let output_type = self.return_type(&[])?;
2679 invoke_cypher_udf(args, &output_type, |val_args| {
2680 if val_args.is_empty() {
2681 return Err(datafusion::error::DataFusionError::Execution(
2682 "tofloat(): requires 1 argument".to_string(),
2683 ));
2684 }
2685
2686 let val = &val_args[0];
2687 match val {
2688 Value::Int(i) => Ok(Value::Float(*i as f64)),
2689 Value::Float(f) => Ok(Value::Float(*f)),
2690 Value::String(s) => {
2691 if let Ok(f) = s.parse::<f64>() {
2692 Ok(Value::Float(f))
2693 } else {
2694 Ok(Value::Null)
2695 }
2696 }
2697 Value::Null => Ok(Value::Null),
2698 other => Err(datafusion::error::DataFusionError::Execution(format!(
2699 "InvalidArgumentValue: tofloat(): cannot convert {} to float",
2700 cypher_type_name(other)
2701 ))),
2702 }
2703 })
2704 }
2705}
2706
2707pub fn create_to_boolean_udf() -> ScalarUDF {
2712 ScalarUDF::new_from_impl(ToBooleanUdf::new())
2713}
2714
2715#[derive(Debug)]
2716struct ToBooleanUdf {
2717 signature: Signature,
2718}
2719
2720impl ToBooleanUdf {
2721 fn new() -> Self {
2722 Self {
2723 signature: Signature::any(1, Volatility::Immutable),
2724 }
2725 }
2726}
2727
2728impl_udf_eq_hash!(ToBooleanUdf);
2729
2730impl ScalarUDFImpl for ToBooleanUdf {
2731 fn as_any(&self) -> &dyn Any {
2732 self
2733 }
2734
2735 fn name(&self) -> &str {
2736 "toboolean"
2737 }
2738
2739 fn signature(&self) -> &Signature {
2740 &self.signature
2741 }
2742
2743 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2744 Ok(DataType::Boolean)
2745 }
2746
2747 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2748 let output_type = self.return_type(&[])?;
2749 invoke_cypher_udf(args, &output_type, |val_args| {
2750 if val_args.is_empty() {
2751 return Err(datafusion::error::DataFusionError::Execution(
2752 "toboolean(): requires 1 argument".to_string(),
2753 ));
2754 }
2755
2756 let val = &val_args[0];
2757 match val {
2758 Value::Bool(b) => Ok(Value::Bool(*b)),
2759 Value::String(s) => {
2760 let s_lower = s.to_lowercase();
2761 if s_lower == "true" {
2762 Ok(Value::Bool(true))
2763 } else if s_lower == "false" {
2764 Ok(Value::Bool(false))
2765 } else {
2766 Ok(Value::Null)
2767 }
2768 }
2769 Value::Null => Ok(Value::Null),
2770 Value::Int(i) => Ok(Value::Bool(*i != 0)),
2771 other => Err(datafusion::error::DataFusionError::Execution(format!(
2772 "InvalidArgumentValue: toboolean(): cannot convert {} to boolean",
2773 cypher_type_name(other)
2774 ))),
2775 }
2776 })
2777 }
2778}
2779
2780pub fn create_cypher_sort_key_udf() -> ScalarUDF {
2787 ScalarUDF::new_from_impl(CypherSortKeyUdf::new())
2788}
2789
2790#[derive(Debug)]
2791struct CypherSortKeyUdf {
2792 signature: Signature,
2793}
2794
2795impl CypherSortKeyUdf {
2796 fn new() -> Self {
2797 Self {
2798 signature: Signature::any(1, Volatility::Immutable),
2799 }
2800 }
2801}
2802
2803impl_udf_eq_hash!(CypherSortKeyUdf);
2804
2805impl ScalarUDFImpl for CypherSortKeyUdf {
2806 fn as_any(&self) -> &dyn Any {
2807 self
2808 }
2809
2810 fn name(&self) -> &str {
2811 "_cypher_sort_key"
2812 }
2813
2814 fn signature(&self) -> &Signature {
2815 &self.signature
2816 }
2817
2818 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2819 Ok(DataType::LargeBinary)
2820 }
2821
2822 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2823 if args.args.len() != 1 {
2824 return Err(datafusion::error::DataFusionError::Execution(
2825 "_cypher_sort_key(): requires 1 argument".to_string(),
2826 ));
2827 }
2828
2829 let arg = &args.args[0];
2830 match arg {
2831 ColumnarValue::Scalar(s) => {
2832 let val = if s.is_null() {
2833 Value::Null
2834 } else {
2835 scalar_to_value(s)?
2836 };
2837 let key = encode_cypher_sort_key(&val);
2838 Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(key))))
2839 }
2840 ColumnarValue::Array(arr) => {
2841 let mut keys: Vec<Option<Vec<u8>>> = Vec::with_capacity(arr.len());
2842 for i in 0..arr.len() {
2843 let val = if arr.is_null(i) {
2844 Value::Null
2845 } else {
2846 get_value_from_array(arr, i)?
2847 };
2848 keys.push(Some(encode_cypher_sort_key(&val)));
2849 }
2850 let array = LargeBinaryArray::from(
2851 keys.iter()
2852 .map(|k| k.as_deref())
2853 .collect::<Vec<Option<&[u8]>>>(),
2854 );
2855 Ok(ColumnarValue::Array(Arc::new(array)))
2856 }
2857 }
2858 }
2859}
2860
2861pub fn encode_cypher_sort_key(value: &Value) -> Vec<u8> {
2867 let mut buf = Vec::with_capacity(32);
2868 encode_sort_key_to_buf(value, &mut buf);
2869 buf
2870}
2871
2872fn encode_sort_key_to_buf(value: &Value, buf: &mut Vec<u8>) {
2874 if let Value::Map(map) = value {
2876 if let Some(tv) = sort_key_map_as_temporal(map) {
2877 buf.push(0x07); encode_temporal_payload(&tv, buf);
2879 return;
2880 }
2881 let rank = sort_key_map_rank(map);
2882 if rank != 0 {
2883 buf.push(rank);
2885 match rank {
2886 0x01 => encode_map_as_node_payload(map, buf),
2887 0x02 => encode_map_as_edge_payload(map, buf),
2888 0x04 => encode_map_as_path_payload(map, buf),
2889 _ => {} }
2891 return;
2892 }
2893 }
2894
2895 if let Value::String(s) = value {
2897 if let Some(tv) = sort_key_string_as_temporal(s) {
2898 buf.push(0x07); encode_temporal_payload(&tv, buf);
2900 return;
2901 }
2902 if let Some(temporal_type) = crate::datetime::classify_temporal(s) {
2905 buf.push(0x07); if encode_wide_temporal_sort_key(s, temporal_type, buf) {
2907 return;
2908 }
2909 buf.pop();
2911 }
2912 }
2913
2914 let rank = sort_key_type_rank(value);
2915 buf.push(rank);
2916
2917 match value {
2918 Value::Null => {} Value::Float(f) if f.is_nan() => {} Value::Bool(b) => buf.push(if *b { 0x01 } else { 0x00 }),
2921 Value::Int(i) => {
2922 let f = *i as f64;
2923 buf.extend_from_slice(&encode_order_preserving_f64(f));
2924 }
2925 Value::Float(f) => {
2926 buf.extend_from_slice(&encode_order_preserving_f64(*f));
2927 }
2928 Value::String(s) => {
2929 byte_stuff_terminate(s.as_bytes(), buf);
2930 }
2931 Value::Temporal(tv) => {
2932 encode_temporal_payload(tv, buf);
2933 }
2934 Value::List(items) => {
2935 encode_list_payload(items, buf);
2936 }
2937 Value::Map(map) => {
2938 encode_map_payload(map, buf);
2939 }
2940 Value::Node(node) => {
2941 encode_node_payload(node, buf);
2942 }
2943 Value::Edge(edge) => {
2944 encode_edge_payload(edge, buf);
2945 }
2946 Value::Path(path) => {
2947 encode_path_payload(path, buf);
2948 }
2949 Value::Bytes(b) => {
2951 byte_stuff_terminate(b, buf);
2952 }
2953 Value::Vector(v) => {
2954 for f in v {
2955 buf.extend_from_slice(&encode_order_preserving_f64(*f as f64));
2956 }
2957 }
2958 _ => {} }
2960}
2961
2962fn sort_key_type_rank(v: &Value) -> u8 {
2966 match v {
2967 Value::Map(map) => sort_key_map_rank(map),
2968 Value::Node(_) => 0x01,
2969 Value::Edge(_) => 0x02,
2970 Value::List(_) => 0x03,
2971 Value::Path(_) => 0x04,
2972 Value::String(_) => 0x05,
2973 Value::Bool(_) => 0x06,
2974 Value::Temporal(_) => 0x07,
2975 Value::Int(_) => 0x08,
2976 Value::Float(f) if f.is_nan() => 0x09,
2977 Value::Float(_) => 0x08,
2978 Value::Null => 0x0A,
2979 Value::Bytes(_) | Value::Vector(_) => 0x0B,
2980 _ => 0x0B, }
2982}
2983
2984fn sort_key_map_rank(map: &std::collections::HashMap<String, Value>) -> u8 {
2986 if sort_key_map_as_temporal(map).is_some() {
2987 0x07
2988 } else if map.contains_key("nodes")
2989 && (map.contains_key("relationships") || map.contains_key("edges"))
2990 {
2991 0x04 } else if map.contains_key("_eid")
2993 || map.contains_key("_src")
2994 || map.contains_key("_dst")
2995 || map.contains_key("_type")
2996 || map.contains_key("_type_name")
2997 {
2998 0x02 } else if map.contains_key("_vid") || map.contains_key("_labels") || map.contains_key("_label")
3000 {
3001 0x01 } else {
3003 0x00 }
3005}
3006
3007fn sort_key_map_as_temporal(
3011 map: &std::collections::HashMap<String, Value>,
3012) -> Option<uni_common::TemporalValue> {
3013 super::expr_eval::temporal_from_map_wrapper(map)
3014}
3015
3016fn sort_key_string_as_temporal(s: &str) -> Option<uni_common::TemporalValue> {
3020 super::expr_eval::temporal_from_value(&Value::String(s.to_string()))
3021}
3022
3023fn encode_wide_temporal_sort_key(
3030 s: &str,
3031 temporal_type: uni_common::TemporalType,
3032 buf: &mut Vec<u8>,
3033) -> bool {
3034 match temporal_type {
3035 uni_common::TemporalType::LocalDateTime => {
3036 if let Some(ndt) = parse_naive_datetime(s) {
3037 buf.push(0x03); let wide_nanos = naive_datetime_to_wide_nanos(&ndt);
3039 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3040 return true;
3041 }
3042 false
3043 }
3044 uni_common::TemporalType::DateTime => {
3045 let base = if let Some(bracket_pos) = s.find('[') {
3047 &s[..bracket_pos]
3048 } else {
3049 s
3050 };
3051 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%.f%:z") {
3052 buf.push(0x04); let utc = dt.naive_utc();
3054 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
3055 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3056 return true;
3057 }
3058 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%:z") {
3059 buf.push(0x04); let utc = dt.naive_utc();
3061 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
3062 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3063 return true;
3064 }
3065 false
3066 }
3067 uni_common::TemporalType::Date => {
3068 if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
3069 && let Some(epoch) = chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
3070 {
3071 buf.push(0x00); let days = nd.signed_duration_since(epoch).num_days() as i32;
3073 buf.extend_from_slice(&encode_order_preserving_i32(days));
3074 return true;
3075 }
3076 false
3077 }
3078 _ => false,
3079 }
3080}
3081
3082fn parse_naive_datetime(s: &str) -> Option<chrono::NaiveDateTime> {
3084 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
3085 .ok()
3086 .or_else(|| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").ok())
3087}
3088
3089fn naive_datetime_to_wide_nanos(ndt: &chrono::NaiveDateTime) -> i128 {
3092 let secs = ndt.and_utc().timestamp() as i128;
3093 let subsec_nanos = ndt.and_utc().timestamp_subsec_nanos() as i128;
3094 secs * 1_000_000_000 + subsec_nanos
3095}
3096
3097fn encode_map_as_node_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3099 let mut labels: Vec<String> = Vec::new();
3101 if let Some(Value::List(lbls)) = map.get("_labels") {
3102 for l in lbls {
3103 if let Value::String(s) = l {
3104 labels.push(s.clone());
3105 }
3106 }
3107 } else if let Some(Value::String(lbl)) = map.get("_label") {
3108 labels.push(lbl.clone());
3109 }
3110 labels.sort();
3111
3112 let vid = map.get("_vid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3114
3115 let labels_joined = labels.join("\x01");
3117 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3118
3119 buf.extend_from_slice(&vid.to_be_bytes());
3121
3122 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3124 for (k, v) in map {
3125 if !k.starts_with('_') {
3126 props.insert(k.clone(), v.clone());
3127 }
3128 }
3129 encode_map_payload(&props, buf);
3130}
3131
3132fn encode_map_as_edge_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3134 let edge_type = map
3135 .get("_type")
3136 .or_else(|| map.get("_type_name"))
3137 .and_then(|v| {
3138 if let Value::String(s) = v {
3139 Some(s.as_str())
3140 } else {
3141 None
3142 }
3143 })
3144 .unwrap_or("");
3145
3146 byte_stuff_terminate(edge_type.as_bytes(), buf);
3147
3148 let src = map.get("_src").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3149 let dst = map.get("_dst").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3150 let eid = map.get("_eid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3151
3152 buf.extend_from_slice(&src.to_be_bytes());
3153 buf.extend_from_slice(&dst.to_be_bytes());
3154 buf.extend_from_slice(&eid.to_be_bytes());
3155
3156 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3158 for (k, v) in map {
3159 if !k.starts_with('_') {
3160 props.insert(k.clone(), v.clone());
3161 }
3162 }
3163 encode_map_payload(&props, buf);
3164}
3165
3166fn encode_map_as_path_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3168 if let Some(Value::List(nodes)) = map.get("nodes") {
3170 encode_list_payload(nodes, buf);
3171 } else {
3172 buf.push(0x00); }
3174 let edges = map.get("relationships").or_else(|| map.get("edges"));
3176 if let Some(Value::List(edges)) = edges {
3177 encode_list_payload(edges, buf);
3178 } else {
3179 buf.push(0x00); }
3181}
3182
/// Maps an `f64` to 8 big-endian bytes whose byte-wise order matches the
/// numeric order of the inputs.
///
/// Negative values (sign bit set) have every bit inverted; non-negative
/// values only have the sign bit flipped.
///
/// NOTE(review): `-0.0` and `0.0` get different encodings (`-0.0` sorts
/// first) — confirm that is acceptable wherever keys are compared for
/// equality.
fn encode_order_preserving_f64(f: f64) -> [u8; 8] {
    const SIGN_BIT: u64 = 1 << 63;

    let raw = f.to_bits();
    // XOR with all ones is bitwise NOT; XOR with SIGN_BIT flips just the sign.
    let mask = if raw & SIGN_BIT != 0 {
        u64::MAX
    } else {
        SIGN_BIT
    };
    (raw ^ mask).to_be_bytes()
}
3200
/// Maps an `i64` to 8 big-endian bytes whose byte-wise order matches the
/// numeric order of the inputs, by flipping the sign bit.
fn encode_order_preserving_i64(i: i64) -> [u8; 8] {
    let mut bytes = i.to_be_bytes();
    // The sign bit is the top bit of the first big-endian byte.
    bytes[0] ^= 0x80;
    bytes
}
3206
/// Maps an `i32` to 4 big-endian bytes whose byte-wise order matches the
/// numeric order of the inputs, by flipping the sign bit.
fn encode_order_preserving_i32(i: i32) -> [u8; 4] {
    let mut bytes = i.to_be_bytes();
    // The sign bit is the top bit of the first big-endian byte.
    bytes[0] ^= 0x80;
    bytes
}
3211
/// Maps an `i128` to 16 big-endian bytes whose byte-wise order matches the
/// numeric order of the inputs, by flipping the sign bit.
fn encode_order_preserving_i128(i: i128) -> [u8; 16] {
    let mut bytes = i.to_be_bytes();
    // The sign bit is the top bit of the first big-endian byte.
    bytes[0] ^= 0x80;
    bytes
}
3216
/// Appends `data` to `buf` in byte-stuffed form (see [`byte_stuff`]) followed
/// by the two-byte terminator `0x00 0x00`.
///
/// Because every `0x00` inside the data is followed by `0xFF`, the terminator
/// cannot appear inside the stuffed data.
fn byte_stuff_terminate(data: &[u8], buf: &mut Vec<u8>) {
    byte_stuff(data, buf);
    buf.extend_from_slice(&[0x00, 0x00]);
}

/// Appends `data` to `buf`, inserting `0xFF` after every `0x00` byte.
fn byte_stuff(data: &[u8], buf: &mut Vec<u8>) {
    for byte in data.iter().copied() {
        match byte {
            0x00 => buf.extend_from_slice(&[0x00, 0xFF]),
            other => buf.push(other),
        }
    }
}
3236
3237fn encode_list_payload(items: &[Value], buf: &mut Vec<u8>) {
3242 for item in items {
3243 buf.push(0x01); let elem_key = encode_cypher_sort_key(item);
3245 byte_stuff_terminate(&elem_key, buf);
3246 }
3247 buf.push(0x00); }
3249
3250fn encode_map_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3252 let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
3253 pairs.sort_by_key(|(k, _)| *k);
3254
3255 for (key, value) in pairs {
3256 buf.push(0x01); byte_stuff_terminate(key.as_bytes(), buf);
3258 let val_key = encode_cypher_sort_key(value);
3259 byte_stuff_terminate(&val_key, buf);
3260 }
3261 buf.push(0x00); }
3263
3264fn encode_node_payload(node: &uni_common::Node, buf: &mut Vec<u8>) {
3268 let mut labels = node.labels.clone();
3269 labels.sort();
3270 let labels_joined = labels.join("\x01");
3271 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3272
3273 buf.extend_from_slice(&node.vid.as_u64().to_be_bytes());
3274
3275 encode_map_payload(&node.properties, buf);
3276}
3277
3278fn encode_edge_payload(edge: &uni_common::Edge, buf: &mut Vec<u8>) {
3282 byte_stuff_terminate(edge.edge_type.as_bytes(), buf);
3283
3284 buf.extend_from_slice(&edge.src.as_u64().to_be_bytes());
3285 buf.extend_from_slice(&edge.dst.as_u64().to_be_bytes());
3286 buf.extend_from_slice(&edge.eid.as_u64().to_be_bytes());
3287
3288 encode_map_payload(&edge.properties, buf);
3289}
3290
3291fn encode_path_payload(path: &uni_common::Path, buf: &mut Vec<u8>) {
3295 for node in &path.nodes {
3297 buf.push(0x01); let mut node_key = Vec::new();
3299 node_key.push(0x01); encode_node_payload(node, &mut node_key);
3301 byte_stuff_terminate(&node_key, buf);
3302 }
3303 buf.push(0x00); for edge in &path.edges {
3307 buf.push(0x01); let mut edge_key = Vec::new();
3309 edge_key.push(0x02); encode_edge_payload(edge, &mut edge_key);
3311 byte_stuff_terminate(&edge_key, buf);
3312 }
3313 buf.push(0x00); }
3315
3316fn encode_temporal_payload(tv: &uni_common::TemporalValue, buf: &mut Vec<u8>) {
3318 match tv {
3319 uni_common::TemporalValue::Date { days_since_epoch } => {
3320 buf.push(0x00); buf.extend_from_slice(&encode_order_preserving_i32(*days_since_epoch));
3322 }
3323 uni_common::TemporalValue::LocalTime {
3324 nanos_since_midnight,
3325 } => {
3326 buf.push(0x01); buf.extend_from_slice(&encode_order_preserving_i64(*nanos_since_midnight));
3328 }
3329 uni_common::TemporalValue::Time {
3330 nanos_since_midnight,
3331 offset_seconds,
3332 } => {
3333 buf.push(0x02); let utc_nanos =
3335 *nanos_since_midnight as i128 - (*offset_seconds as i128) * 1_000_000_000;
3336 buf.extend_from_slice(&encode_order_preserving_i128(utc_nanos));
3337 }
3338 uni_common::TemporalValue::LocalDateTime { nanos_since_epoch } => {
3339 buf.push(0x03); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3342 }
3343 uni_common::TemporalValue::DateTime {
3344 nanos_since_epoch, ..
3345 } => {
3346 buf.push(0x04); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3349 }
3350 uni_common::TemporalValue::Duration {
3351 months,
3352 days,
3353 nanos,
3354 } => {
3355 buf.push(0x05); buf.extend_from_slice(&encode_order_preserving_i64(*months));
3357 buf.extend_from_slice(&encode_order_preserving_i64(*days));
3358 buf.extend_from_slice(&encode_order_preserving_i64(*nanos));
3359 }
3360 uni_common::TemporalValue::Btic { lo, hi, meta } => {
3361 buf.push(0x06); if let Ok(btic) = uni_btic::Btic::new(*lo, *hi, *meta) {
3364 buf.extend_from_slice(&uni_btic::encode::encode(&btic));
3365 } else {
3366 buf.extend_from_slice(&encode_order_preserving_i64(*lo));
3367 buf.extend_from_slice(&encode_order_preserving_i64(*hi));
3368 }
3369 }
3370 }
3371}
3372
3373#[derive(Debug)]
3382struct BticScalarUdf {
3383 name: String,
3384 signature: Signature,
3385 return_type: DataType,
3386}
3387
3388impl BticScalarUdf {
3389 fn new(name: &str, num_args: usize, return_type: DataType) -> Self {
3390 Self {
3391 name: name.to_string(),
3392 signature: Signature::new(TypeSignature::Any(num_args), Volatility::Immutable),
3393 return_type,
3394 }
3395 }
3396}
3397
3398impl_udf_eq_hash!(BticScalarUdf);
3399
3400impl ScalarUDFImpl for BticScalarUdf {
3401 fn as_any(&self) -> &dyn Any {
3402 self
3403 }
3404 fn name(&self) -> &str {
3405 &self.name
3406 }
3407 fn signature(&self) -> &Signature {
3408 &self.signature
3409 }
3410 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
3411 Ok(self.return_type.clone())
3412 }
3413 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
3414 let fname = self.name.to_uppercase();
3415 let rt = self.return_type.clone();
3416 invoke_cypher_udf(args, &rt, |val_args| {
3417 crate::expr_eval::eval_btic_function(&fname, val_args)
3418 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
3419 })
3420 }
3421}
3422
3423fn register_btic_scalar_udfs(ctx: &SessionContext) -> DFResult<()> {
3425 for name in &["btic_lo", "btic_hi"] {
3427 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3428 name,
3429 1,
3430 DataType::LargeBinary,
3431 )));
3432 }
3433 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3435 "btic_duration",
3436 1,
3437 DataType::Int64,
3438 )));
3439 for name in &["btic_is_instant", "btic_is_unbounded", "btic_is_finite"] {
3441 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3442 name,
3443 1,
3444 DataType::Boolean,
3445 )));
3446 }
3447 for name in &[
3449 "btic_granularity",
3450 "btic_lo_granularity",
3451 "btic_hi_granularity",
3452 "btic_certainty",
3453 "btic_lo_certainty",
3454 "btic_hi_certainty",
3455 ] {
3456 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3457 name,
3458 1,
3459 DataType::Utf8,
3460 )));
3461 }
3462 for name in &[
3464 "btic_contains_point",
3465 "btic_overlaps",
3466 "btic_contains",
3467 "btic_before",
3468 "btic_after",
3469 "btic_meets",
3470 "btic_adjacent",
3471 "btic_disjoint",
3472 "btic_equals",
3473 "btic_starts",
3474 "btic_during",
3475 "btic_finishes",
3476 ] {
3477 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3478 name,
3479 2,
3480 DataType::Boolean,
3481 )));
3482 }
3483 for name in &["btic_intersection", "btic_span", "btic_gap"] {
3485 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3486 name,
3487 2,
3488 DataType::LargeBinary,
3489 )));
3490 }
3491 Ok(())
3492}
3493
3494#[derive(Debug, Clone)]
3500struct BticMinMaxUdaf {
3501 name: String,
3502 signature: Signature,
3503 is_max: bool,
3504}
3505
3506impl BticMinMaxUdaf {
3507 fn new(is_max: bool) -> Self {
3508 Self {
3509 name: (if is_max { "btic_max" } else { "btic_min" }).to_string(),
3510 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
3511 is_max,
3512 }
3513 }
3514}
3515
3516impl_udf_eq_hash!(BticMinMaxUdaf);
3517
3518impl AggregateUDFImpl for BticMinMaxUdaf {
3519 fn as_any(&self) -> &dyn Any {
3520 self
3521 }
3522 fn name(&self) -> &str {
3523 &self.name
3524 }
3525 fn signature(&self) -> &Signature {
3526 &self.signature
3527 }
3528 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3529 Ok(DataType::LargeBinary)
3530 }
3531 fn accumulator(
3532 &self,
3533 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3534 ) -> DFResult<Box<dyn DfAccumulator>> {
3535 Ok(Box::new(BticMinMaxAccumulator {
3536 current: None,
3537 is_max: self.is_max,
3538 }))
3539 }
3540 fn state_fields(
3541 &self,
3542 args: datafusion::logical_expr::function::StateFieldsArgs,
3543 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3544 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3545 args.name,
3546 DataType::LargeBinary,
3547 true,
3548 ))])
3549 }
3550}
3551
3552#[derive(Debug)]
3553struct BticMinMaxAccumulator {
3554 current: Option<uni_btic::Btic>,
3555 is_max: bool,
3556}
3557
3558impl DfAccumulator for BticMinMaxAccumulator {
3559 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3560 let arr = &values[0];
3561 for i in 0..arr.len() {
3562 if arr.is_null(i) {
3563 continue;
3564 }
3565 let Some(btic) = decode_btic_from_array(arr, i)? else {
3566 continue;
3567 };
3568 self.current = Some(match self.current.take() {
3569 None => btic,
3570 Some(cur) => {
3571 if (self.is_max && btic > cur) || (!self.is_max && btic < cur) {
3572 btic
3573 } else {
3574 cur
3575 }
3576 }
3577 });
3578 }
3579 Ok(())
3580 }
3581 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3582 Ok(btic_to_scalar_value(self.current.as_ref()))
3583 }
3584 fn size(&self) -> usize {
3585 std::mem::size_of::<Self>()
3586 }
3587 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3588 Ok(vec![self.evaluate()?])
3589 }
3590 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3591 self.update_batch(states)
3592 }
3593}
3594
3595#[derive(Debug, Clone)]
3597struct BticSpanAggUdaf {
3598 signature: Signature,
3599}
3600
3601impl BticSpanAggUdaf {
3602 fn new() -> Self {
3603 Self {
3604 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
3605 }
3606 }
3607}
3608
3609impl_udf_eq_hash!(BticSpanAggUdaf);
3610
3611impl AggregateUDFImpl for BticSpanAggUdaf {
3612 fn as_any(&self) -> &dyn Any {
3613 self
3614 }
3615 fn name(&self) -> &str {
3616 "btic_span_agg"
3617 }
3618 fn signature(&self) -> &Signature {
3619 &self.signature
3620 }
3621 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3622 Ok(DataType::LargeBinary)
3623 }
3624 fn accumulator(
3625 &self,
3626 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3627 ) -> DFResult<Box<dyn DfAccumulator>> {
3628 Ok(Box::new(BticSpanAggAccumulator { current: None }))
3629 }
3630 fn state_fields(
3631 &self,
3632 args: datafusion::logical_expr::function::StateFieldsArgs,
3633 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3634 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3635 args.name,
3636 DataType::LargeBinary,
3637 true,
3638 ))])
3639 }
3640}
3641
3642#[derive(Debug)]
3643struct BticSpanAggAccumulator {
3644 current: Option<uni_btic::Btic>,
3645}
3646
3647impl DfAccumulator for BticSpanAggAccumulator {
3648 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3649 let arr = &values[0];
3650 for i in 0..arr.len() {
3651 if arr.is_null(i) {
3652 continue;
3653 }
3654 let Some(btic) = decode_btic_from_array(arr, i)? else {
3655 continue;
3656 };
3657 self.current = Some(match self.current.take() {
3658 None => btic,
3659 Some(cur) => uni_btic::set_ops::span(&cur, &btic),
3660 });
3661 }
3662 Ok(())
3663 }
3664 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3665 Ok(btic_to_scalar_value(self.current.as_ref()))
3666 }
3667 fn size(&self) -> usize {
3668 std::mem::size_of::<Self>()
3669 }
3670 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3671 Ok(vec![self.evaluate()?])
3672 }
3673 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3674 self.update_batch(states)
3675 }
3676}
3677
3678#[derive(Debug, Clone)]
3680struct BticCountAtUdaf {
3681 signature: Signature,
3682}
3683
3684impl BticCountAtUdaf {
3685 fn new() -> Self {
3686 Self {
3687 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
3688 }
3689 }
3690}
3691
3692impl_udf_eq_hash!(BticCountAtUdaf);
3693
3694impl AggregateUDFImpl for BticCountAtUdaf {
3695 fn as_any(&self) -> &dyn Any {
3696 self
3697 }
3698 fn name(&self) -> &str {
3699 "btic_count_at"
3700 }
3701 fn signature(&self) -> &Signature {
3702 &self.signature
3703 }
3704 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3705 Ok(DataType::Int64)
3706 }
3707 fn accumulator(
3708 &self,
3709 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3710 ) -> DFResult<Box<dyn DfAccumulator>> {
3711 Ok(Box::new(BticCountAtAccumulator { count: 0 }))
3712 }
3713 fn state_fields(
3714 &self,
3715 args: datafusion::logical_expr::function::StateFieldsArgs,
3716 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3717 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3718 args.name,
3719 DataType::Int64,
3720 true,
3721 ))])
3722 }
3723}
3724
3725#[derive(Debug)]
3726struct BticCountAtAccumulator {
3727 count: i64,
3728}
3729
3730impl DfAccumulator for BticCountAtAccumulator {
3731 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3732 if values.len() < 2 {
3733 return Ok(());
3734 }
3735 let btic_arr = &values[0];
3736 let point_arr = &values[1];
3737
3738 for i in 0..btic_arr.len() {
3739 if btic_arr.is_null(i) || point_arr.is_null(i) {
3740 continue;
3741 }
3742 let Some(btic) = decode_btic_from_array(btic_arr, i)? else {
3743 continue;
3744 };
3745
3746 let point_ms = if let Some(int_arr) = point_arr.as_any().downcast_ref::<Int64Array>() {
3748 int_arr.value(i)
3749 } else if let Some(lb) = point_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3750 let val = scalar_binary_to_value(lb.value(i));
3751 match &val {
3752 Value::Int(ms) => *ms,
3753 Value::Temporal(uni_common::TemporalValue::DateTime {
3754 nanos_since_epoch,
3755 ..
3756 }) => nanos_since_epoch / 1_000_000,
3757 _ => continue,
3758 }
3759 } else {
3760 continue;
3761 };
3762
3763 if uni_btic::predicates::contains_point(&btic, point_ms) {
3764 self.count += 1;
3765 }
3766 }
3767 Ok(())
3768 }
3769 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3770 Ok(ScalarValue::Int64(Some(self.count)))
3771 }
3772 fn size(&self) -> usize {
3773 std::mem::size_of::<Self>()
3774 }
3775 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3776 Ok(vec![ScalarValue::Int64(Some(self.count))])
3777 }
3778 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3779 let arr = &states[0];
3780 if let Some(int_arr) = arr.as_any().downcast_ref::<Int64Array>() {
3781 for i in 0..int_arr.len() {
3782 if !int_arr.is_null(i) {
3783 self.count += int_arr.value(i);
3784 }
3785 }
3786 }
3787 Ok(())
3788 }
3789}
3790
3791fn btic_to_scalar_value(btic: Option<&uni_btic::Btic>) -> ScalarValue {
3793 match btic {
3794 None => ScalarValue::LargeBinary(None),
3795 Some(b) => {
3796 let val = Value::Temporal(uni_common::TemporalValue::Btic {
3797 lo: b.lo(),
3798 hi: b.hi(),
3799 meta: b.meta(),
3800 });
3801 ScalarValue::LargeBinary(Some(uni_common::cypher_value_codec::encode(&val)))
3802 }
3803 }
3804}
3805
3806fn decode_btic_from_array(arr: &ArrayRef, row: usize) -> DFResult<Option<uni_btic::Btic>> {
3808 if let Some(fsb) = arr.as_any().downcast_ref::<FixedSizeBinaryArray>() {
3810 let bytes = fsb.value(row);
3811 return uni_btic::encode::decode_slice(bytes)
3812 .map(Some)
3813 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3814 }
3815 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
3817 let val = scalar_binary_to_value(lb.value(row));
3818 if let Value::Temporal(uni_common::TemporalValue::Btic { lo, hi, meta }) = val {
3819 return uni_btic::Btic::new(lo, hi, meta)
3820 .map(Some)
3821 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3822 }
3823 return Ok(None);
3824 }
3825 Ok(None)
3826}
3827
3828pub fn create_btic_min_udaf() -> AggregateUDF {
3829 AggregateUDF::from(BticMinMaxUdaf::new(false))
3830}
3831
3832pub fn create_btic_max_udaf() -> AggregateUDF {
3833 AggregateUDF::from(BticMinMaxUdaf::new(true))
3834}
3835
3836pub fn create_btic_span_agg_udaf() -> AggregateUDF {
3837 AggregateUDF::from(BticSpanAggUdaf::new())
3838}
3839
3840pub fn create_btic_count_at_udaf() -> AggregateUDF {
3841 AggregateUDF::from(BticCountAtUdaf::new())
3842}
3843
3844pub fn invoke_cypher_string_op<F>(
3849 args: &ScalarFunctionArgs,
3850 name: &str,
3851 op: F,
3852) -> DFResult<ColumnarValue>
3853where
3854 F: Fn(&str, &str) -> bool,
3855{
3856 use arrow_array::{BooleanArray, LargeBinaryArray, LargeStringArray, StringArray};
3857 use datafusion::common::ScalarValue;
3858 use datafusion::error::DataFusionError;
3859
3860 if args.args.len() != 2 {
3861 return Err(DataFusionError::Execution(format!(
3862 "{}(): requires exactly 2 arguments",
3863 name
3864 )));
3865 }
3866
3867 let left = &args.args[0];
3868 let right = &args.args[1];
3869
3870 let extract_string = |scalar: &ScalarValue| -> Option<String> {
3872 match scalar {
3873 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Some(s.clone()),
3874 ScalarValue::LargeBinary(Some(bytes)) => {
3875 match uni_common::cypher_value_codec::decode(bytes) {
3877 Ok(uni_common::Value::String(s)) => Some(s),
3878 _ => None,
3879 }
3880 }
3881 ScalarValue::Utf8(None)
3882 | ScalarValue::LargeUtf8(None)
3883 | ScalarValue::LargeBinary(None)
3884 | ScalarValue::Null => None,
3885 _ => None,
3886 }
3887 };
3888
3889 match (left, right) {
3890 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Scalar(r_scalar)) => {
3891 let l_str = extract_string(l_scalar);
3892 let r_str = extract_string(r_scalar);
3893
3894 match (l_str, r_str) {
3895 (Some(l), Some(r)) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(op(
3896 &l, &r,
3897 ))))),
3898 _ => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
3899 }
3900 }
3901 (ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_scalar)) => {
3902 let r_val = extract_string(r_scalar);
3904
3905 if r_val.is_none() {
3906 let nulls = arrow_array::new_null_array(&DataType::Boolean, l_arr.len());
3908 return Ok(ColumnarValue::Array(nulls));
3909 }
3910 let pattern = r_val.unwrap();
3911
3912 let result_array = if let Some(arr) = l_arr.as_any().downcast_ref::<StringArray>() {
3914 arr.iter()
3915 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3916 .collect::<BooleanArray>()
3917 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeStringArray>() {
3918 arr.iter()
3919 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3920 .collect::<BooleanArray>()
3921 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3922 arr.iter()
3924 .map(|opt_bytes| {
3925 opt_bytes.and_then(|bytes| {
3926 match uni_common::cypher_value_codec::decode(bytes) {
3927 Ok(uni_common::Value::String(s)) => Some(op(&s, &pattern)),
3928 _ => None,
3929 }
3930 })
3931 })
3932 .collect::<BooleanArray>()
3933 } else {
3934 arrow_array::new_null_array(&DataType::Boolean, l_arr.len())
3936 .as_any()
3937 .downcast_ref::<BooleanArray>()
3938 .unwrap()
3939 .clone()
3940 };
3941
3942 Ok(ColumnarValue::Array(Arc::new(result_array)))
3943 }
3944 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Array(r_arr)) => {
3945 let l_val = extract_string(l_scalar);
3947
3948 if l_val.is_none() {
3949 let nulls = arrow_array::new_null_array(&DataType::Boolean, r_arr.len());
3950 return Ok(ColumnarValue::Array(nulls));
3951 }
3952 let target = l_val.unwrap();
3953
3954 let result_array = if let Some(arr) = r_arr.as_any().downcast_ref::<StringArray>() {
3955 arr.iter()
3956 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3957 .collect::<BooleanArray>()
3958 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeStringArray>() {
3959 arr.iter()
3960 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3961 .collect::<BooleanArray>()
3962 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3963 arr.iter()
3965 .map(|opt_bytes| {
3966 opt_bytes.and_then(|bytes| {
3967 match uni_common::cypher_value_codec::decode(bytes) {
3968 Ok(uni_common::Value::String(s)) => Some(op(&target, &s)),
3969 _ => None,
3970 }
3971 })
3972 })
3973 .collect::<BooleanArray>()
3974 } else {
3975 arrow_array::new_null_array(&DataType::Boolean, r_arr.len())
3977 .as_any()
3978 .downcast_ref::<BooleanArray>()
3979 .unwrap()
3980 .clone()
3981 };
3982
3983 Ok(ColumnarValue::Array(Arc::new(result_array)))
3984 }
3985 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
3986 if l_arr.len() != r_arr.len() {
3988 return Err(DataFusionError::Execution(format!(
3989 "{}(): array lengths must match",
3990 name
3991 )));
3992 }
3993
3994 let extract_string_at = |arr: &dyn Array, idx: usize| -> Option<String> {
3996 if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
3997 str_arr.value(idx).to_string().into()
3998 } else if let Some(str_arr) = arr.as_any().downcast_ref::<LargeStringArray>() {
3999 str_arr.value(idx).to_string().into()
4000 } else if let Some(bin_arr) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
4001 if bin_arr.is_null(idx) {
4002 return None;
4003 }
4004 let bytes = bin_arr.value(idx);
4005 match uni_common::cypher_value_codec::decode(bytes) {
4006 Ok(uni_common::Value::String(s)) => Some(s),
4007 _ => None,
4008 }
4009 } else {
4010 None
4011 }
4012 };
4013
4014 let result: BooleanArray = (0..l_arr.len())
4015 .map(|idx| {
4016 match (
4017 extract_string_at(l_arr.as_ref(), idx),
4018 extract_string_at(r_arr.as_ref(), idx),
4019 ) {
4020 (Some(l_str), Some(r_str)) => Some(op(&l_str, &r_str)),
4021 _ => None,
4022 }
4023 })
4024 .collect();
4025
4026 Ok(ColumnarValue::Array(Arc::new(result)))
4027 }
4028 }
4029}
4030
/// Generates a private two-argument string-predicate UDF.
///
/// * `$struct_name` - identifier of the generated UDF struct.
/// * `$udf_name` - string literal reported by `name()` and forwarded to
///   `invoke_cypher_string_op` as the function name.
/// * `$op` - predicate expression forwarded to `invoke_cypher_string_op`.
///
/// The generated UDF accepts any two arguments (`Signature::any(2, ..)`),
/// is declared immutable, and always reports a `Boolean` return type.
macro_rules! define_string_op_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                let signature = Signature::any(2, Volatility::Immutable);
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Boolean)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                // All evaluation (scalar/array handling, null propagation) lives in the
                // shared helper; the macro only supplies the name and the predicate.
                invoke_cypher_string_op(&args, $udf_name, $op)
            }
        }
    };
}
4069
4070define_string_op_udf!(CypherStartsWithUdf, "_cypher_starts_with", |s, p| s
4071 .starts_with(p));
4072define_string_op_udf!(CypherEndsWithUdf, "_cypher_ends_with", |s, p| s
4073 .ends_with(p));
4074define_string_op_udf!(CypherContainsUdf, "_cypher_contains", |s, p| s.contains(p));
4075
4076pub fn create_cypher_starts_with_udf() -> ScalarUDF {
4077 ScalarUDF::new_from_impl(CypherStartsWithUdf::new())
4078}
4079pub fn create_cypher_ends_with_udf() -> ScalarUDF {
4080 ScalarUDF::new_from_impl(CypherEndsWithUdf::new())
4081}
4082pub fn create_cypher_contains_udf() -> ScalarUDF {
4083 ScalarUDF::new_from_impl(CypherContainsUdf::new())
4084}
4085
4086pub fn create_cypher_equal_udf() -> ScalarUDF {
4087 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_equal", BinaryOp::Eq))
4088}
4089pub fn create_cypher_not_equal_udf() -> ScalarUDF {
4090 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_not_equal", BinaryOp::NotEq))
4091}
4092pub fn create_cypher_lt_udf() -> ScalarUDF {
4093 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt", BinaryOp::Lt))
4094}
4095pub fn create_cypher_lt_eq_udf() -> ScalarUDF {
4096 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt_eq", BinaryOp::LtEq))
4097}
4098pub fn create_cypher_gt_udf() -> ScalarUDF {
4099 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt", BinaryOp::Gt))
4100}
4101pub fn create_cypher_gt_eq_udf() -> ScalarUDF {
4102 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt_eq", BinaryOp::GtEq))
4103}
4104
4105#[expect(clippy::match_like_matches_macro)]
4107fn apply_comparison_op(ord: std::cmp::Ordering, op: &BinaryOp) -> bool {
4108 use std::cmp::Ordering;
4109 match (ord, op) {
4110 (Ordering::Less, BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::NotEq) => true,
4111 (Ordering::Equal, BinaryOp::Eq | BinaryOp::LtEq | BinaryOp::GtEq) => true,
4112 (Ordering::Greater, BinaryOp::Gt | BinaryOp::GtEq | BinaryOp::NotEq) => true,
4113 _ => false,
4114 }
4115}
4116
4117fn compare_f64(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<bool> {
4120 if lhs.is_nan() || rhs.is_nan() {
4121 Some(matches!(op, BinaryOp::NotEq))
4122 } else {
4123 Some(apply_comparison_op(lhs.partial_cmp(&rhs)?, op))
4124 }
4125}
4126
4127fn cv_bytes_as_f64(bytes: &[u8]) -> Option<f64> {
4129 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4130 match peek_tag(bytes)? {
4131 TAG_INT => decode_int(bytes).map(|i| i as f64),
4132 TAG_FLOAT => decode_float(bytes),
4133 _ => None,
4134 }
4135}
4136
4137fn compare_cv_numeric(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<bool> {
4140 use uni_common::cypher_value_codec::{TAG_INT, TAG_NULL, decode_int, peek_tag};
4141 if peek_tag(bytes) == Some(TAG_INT)
4143 && let Some(lhs_int) = decode_int(bytes)
4144 && rhs.fract() == 0.0
4146 && rhs >= i64::MIN as f64
4147 && rhs <= i64::MAX as f64
4148 {
4149 return Some(apply_comparison_op(lhs_int.cmp(&(rhs as i64)), op));
4150 }
4151 if peek_tag(bytes) == Some(TAG_NULL) {
4152 return None;
4153 }
4154 let lhs = cv_bytes_as_f64(bytes)?;
4155 compare_f64(lhs, rhs, op)
4156}
4157
4158fn try_fast_compare(
4162 lhs: &ColumnarValue,
4163 rhs: &ColumnarValue,
4164 op: &BinaryOp,
4165) -> Option<ColumnarValue> {
4166 use arrow_array::builder::BooleanBuilder;
4167 use uni_common::cypher_value_codec::{
4168 TAG_INT, TAG_NULL, TAG_STRING, decode_int, decode_string, peek_tag,
4169 };
4170
4171 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
4172 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
4173 _ => return None,
4174 };
4175
4176 if !matches!(lhs_arr.data_type(), DataType::LargeBinary) {
4178 return None;
4179 }
4180
4181 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4182
4183 match rhs_arr.data_type() {
4184 DataType::Int64 => {
4186 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4187 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4188 for i in 0..lb_arr.len() {
4189 if lb_arr.is_null(i) || int_arr.is_null(i) {
4190 builder.append_null();
4191 } else {
4192 match compare_cv_numeric(lb_arr.value(i), int_arr.value(i) as f64, op) {
4193 Some(result) => builder.append_value(result),
4194 None => builder.append_null(),
4195 }
4196 }
4197 }
4198 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4199 }
4200
4201 DataType::Float64 => {
4203 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
4204 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4205 for i in 0..lb_arr.len() {
4206 if lb_arr.is_null(i) || float_arr.is_null(i) {
4207 builder.append_null();
4208 } else {
4209 match compare_cv_numeric(lb_arr.value(i), float_arr.value(i), op) {
4210 Some(result) => builder.append_value(result),
4211 None => builder.append_null(),
4212 }
4213 }
4214 }
4215 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4216 }
4217
4218 DataType::Utf8 | DataType::LargeUtf8 => {
4220 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4221 for i in 0..lb_arr.len() {
4222 if lb_arr.is_null(i) || rhs_arr.is_null(i) {
4223 builder.append_null();
4224 } else {
4225 let bytes = lb_arr.value(i);
4226 let rhs_str = if matches!(rhs_arr.data_type(), DataType::Utf8) {
4227 rhs_arr.as_any().downcast_ref::<StringArray>()?.value(i)
4228 } else {
4229 rhs_arr
4230 .as_any()
4231 .downcast_ref::<LargeStringArray>()?
4232 .value(i)
4233 };
4234 match peek_tag(bytes) {
4235 Some(TAG_STRING) => {
4236 if let Some(lhs_str) = decode_string(bytes) {
4237 builder.append_value(apply_comparison_op(
4238 lhs_str.as_str().cmp(rhs_str),
4239 op,
4240 ));
4241 } else {
4242 builder.append_null();
4243 }
4244 }
4245 _ => builder.append_null(),
4246 }
4247 }
4248 }
4249 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4250 }
4251
4252 DataType::LargeBinary => {
4254 let rhs_lb = rhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4255 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4256 for i in 0..lb_arr.len() {
4257 if lb_arr.is_null(i) || rhs_lb.is_null(i) {
4258 builder.append_null();
4259 } else {
4260 let lhs_bytes = lb_arr.value(i);
4261 let rhs_bytes = rhs_lb.value(i);
4262 let lhs_tag = peek_tag(lhs_bytes);
4263 let rhs_tag = peek_tag(rhs_bytes);
4264
4265 if lhs_tag == Some(TAG_NULL) || rhs_tag == Some(TAG_NULL) {
4267 builder.append_null();
4268 continue;
4269 }
4270
4271 if lhs_tag == Some(TAG_INT) && rhs_tag == Some(TAG_INT) {
4273 if let (Some(l), Some(r)) = (decode_int(lhs_bytes), decode_int(rhs_bytes)) {
4274 builder.append_value(apply_comparison_op(l.cmp(&r), op));
4275 } else {
4276 builder.append_null();
4277 }
4278 continue;
4279 }
4280
4281 if lhs_tag == Some(TAG_STRING) && rhs_tag == Some(TAG_STRING) {
4283 if let (Some(l), Some(r)) =
4284 (decode_string(lhs_bytes), decode_string(rhs_bytes))
4285 {
4286 builder.append_value(apply_comparison_op(l.cmp(&r), op));
4287 } else {
4288 builder.append_null();
4289 }
4290 continue;
4291 }
4292
4293 if let (Some(l), Some(r)) =
4295 (cv_bytes_as_f64(lhs_bytes), cv_bytes_as_f64(rhs_bytes))
4296 {
4297 match compare_f64(l, r, op) {
4298 Some(result) => builder.append_value(result),
4299 None => builder.append_null(),
4300 }
4301 } else {
4302 return None;
4306 }
4307 }
4308 }
4309 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4310 }
4311
4312 _ => None, }
4314}
4315
4316#[derive(Debug)]
4317struct CypherCompareUdf {
4318 name: String,
4319 op: BinaryOp,
4320 signature: Signature,
4321}
4322
4323impl CypherCompareUdf {
4324 fn new(name: &str, op: BinaryOp) -> Self {
4325 Self {
4326 name: name.to_string(),
4327 op,
4328 signature: Signature::any(2, Volatility::Immutable),
4329 }
4330 }
4331}
4332
4333impl PartialEq for CypherCompareUdf {
4334 fn eq(&self, other: &Self) -> bool {
4335 self.name == other.name
4336 }
4337}
4338
4339impl Eq for CypherCompareUdf {}
4340
4341impl std::hash::Hash for CypherCompareUdf {
4342 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4343 self.name.hash(state);
4344 }
4345}
4346
4347impl ScalarUDFImpl for CypherCompareUdf {
4348 fn as_any(&self) -> &dyn Any {
4349 self
4350 }
4351 fn name(&self) -> &str {
4352 &self.name
4353 }
4354 fn signature(&self) -> &Signature {
4355 &self.signature
4356 }
4357 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4358 Ok(DataType::Boolean)
4359 }
4360
4361 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4362 if args.args.len() != 2 {
4363 return Err(datafusion::error::DataFusionError::Execution(format!(
4364 "{}(): requires 2 arguments",
4365 self.name
4366 )));
4367 }
4368
4369 if let Some(result) = try_fast_compare(&args.args[0], &args.args[1], &self.op) {
4371 return Ok(result);
4372 }
4373
4374 let output_type = DataType::Boolean;
4376 invoke_cypher_udf(args, &output_type, |val_args| {
4377 crate::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4378 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4379 })
4380 }
4381}
4382
4383pub fn create_cypher_add_udf() -> ScalarUDF {
4389 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_add", BinaryOp::Add))
4390}
4391pub fn create_cypher_sub_udf() -> ScalarUDF {
4392 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_sub", BinaryOp::Sub))
4393}
4394pub fn create_cypher_mul_udf() -> ScalarUDF {
4395 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mul", BinaryOp::Mul))
4396}
4397pub fn create_cypher_div_udf() -> ScalarUDF {
4398 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_div", BinaryOp::Div))
4399}
4400pub fn create_cypher_mod_udf() -> ScalarUDF {
4401 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mod", BinaryOp::Mod))
4402}
4403
4404pub fn create_cypher_abs_udf() -> ScalarUDF {
4406 ScalarUDF::new_from_impl(CypherAbsUdf::new())
4407}
4408
4409pub fn cypher_abs_expr(arg: datafusion::logical_expr::Expr) -> datafusion::logical_expr::Expr {
4411 datafusion::logical_expr::Expr::ScalarFunction(
4412 datafusion::logical_expr::expr::ScalarFunction::new_udf(
4413 Arc::new(create_cypher_abs_udf()),
4414 vec![arg],
4415 ),
4416 )
4417}
4418
4419#[derive(Debug)]
4420struct CypherAbsUdf {
4421 signature: Signature,
4422}
4423
4424impl CypherAbsUdf {
4425 fn new() -> Self {
4426 Self {
4427 signature: Signature::any(1, Volatility::Immutable),
4428 }
4429 }
4430}
4431
4432impl_udf_eq_hash!(CypherAbsUdf);
4433
4434impl ScalarUDFImpl for CypherAbsUdf {
4435 fn as_any(&self) -> &dyn Any {
4436 self
4437 }
4438 fn name(&self) -> &str {
4439 "_cypher_abs"
4440 }
4441 fn signature(&self) -> &Signature {
4442 &self.signature
4443 }
4444 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
4445 Ok(DataType::LargeBinary)
4446 }
4447 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4448 if args.args.len() != 1 {
4449 return Err(datafusion::error::DataFusionError::Execution(
4450 "_cypher_abs requires exactly 1 argument".into(),
4451 ));
4452 }
4453 invoke_cypher_udf(args, &DataType::LargeBinary, |val_args| {
4454 match &val_args[0] {
4455 Value::Int(i) => i.checked_abs().map(Value::Int).ok_or_else(|| {
4456 datafusion::error::DataFusionError::Execution(
4457 "integer overflow in abs()".into(),
4458 )
4459 }),
4460 Value::Float(f) => Ok(Value::Float(f.abs())),
4461 Value::Null => Ok(Value::Null),
4462 other => Err(datafusion::error::DataFusionError::Execution(format!(
4463 "abs() requires a numeric argument, got {other:?}"
4464 ))),
4465 }
4466 })
4467 }
4468}
4469
4470pub(crate) fn checked_int_op(lhs: i64, rhs: i64, op: &BinaryOp) -> Option<i64> {
4481 match op {
4482 BinaryOp::Add => lhs.checked_add(rhs),
4483 BinaryOp::Sub => lhs.checked_sub(rhs),
4484 BinaryOp::Mul => lhs.checked_mul(rhs),
4485 BinaryOp::Div => {
4487 if rhs == 0 {
4488 None
4489 } else {
4490 lhs.checked_div(rhs)
4491 }
4492 }
4493 BinaryOp::Mod => {
4494 if rhs == 0 {
4495 None
4496 } else {
4497 lhs.checked_rem(rhs)
4498 }
4499 }
4500 _ => None,
4501 }
4502}
4503
4504enum CvArithOutcome {
4511 Value(Vec<u8>),
4513 Null,
4515 Error(datafusion::error::DataFusionError),
4517}
4518
4519fn apply_int_arithmetic(lhs: i64, rhs: i64, op: &BinaryOp) -> CvArithOutcome {
4526 use uni_common::cypher_value_codec::encode_int;
4527 if !matches!(
4528 op,
4529 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod
4530 ) {
4531 return CvArithOutcome::Null;
4532 }
4533 match checked_int_op(lhs, rhs, op) {
4534 Some(v) => CvArithOutcome::Value(encode_int(v)),
4535 None => {
4536 let msg = if matches!(op, BinaryOp::Div | BinaryOp::Mod) && rhs == 0 {
4537 "division by zero"
4538 } else {
4539 "integer overflow"
4540 };
4541 CvArithOutcome::Error(datafusion::error::DataFusionError::Execution(msg.into()))
4542 }
4543 }
4544}
4545
4546fn apply_float_arithmetic(lhs: f64, rhs: f64, op: &BinaryOp) -> CvArithOutcome {
4552 use uni_common::cypher_value_codec::encode_float;
4553 let result = match op {
4554 BinaryOp::Add => lhs + rhs,
4555 BinaryOp::Sub => lhs - rhs,
4556 BinaryOp::Mul => lhs * rhs,
4557 BinaryOp::Div => lhs / rhs, BinaryOp::Mod => lhs % rhs,
4559 _ => return CvArithOutcome::Null,
4560 };
4561 CvArithOutcome::Value(encode_float(result))
4562}
4563
4564fn cv_arithmetic_int(bytes: &[u8], rhs: i64, op: &BinaryOp) -> CvArithOutcome {
4570 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4571 match peek_tag(bytes) {
4572 Some(TAG_INT) => match decode_int(bytes) {
4573 Some(lhs) => apply_int_arithmetic(lhs, rhs, op),
4574 None => CvArithOutcome::Null,
4575 },
4576 Some(TAG_FLOAT) => match decode_float(bytes) {
4577 Some(lhs) => apply_float_arithmetic(lhs, rhs as f64, op),
4578 None => CvArithOutcome::Null,
4579 },
4580 _ => CvArithOutcome::Null,
4581 }
4582}
4583
4584fn cv_arithmetic_float(bytes: &[u8], rhs: f64, op: &BinaryOp) -> CvArithOutcome {
4590 match cv_bytes_as_f64(bytes) {
4591 Some(lhs) => apply_float_arithmetic(lhs, rhs, op),
4592 None => CvArithOutcome::Null,
4593 }
4594}
4595
4596pub(crate) fn cypher_arith_return_type(arg_types: &[DataType]) -> DataType {
4611 let any_large_binary = arg_types.iter().any(|t| matches!(t, DataType::LargeBinary));
4612 if any_large_binary {
4613 return DataType::LargeBinary;
4614 }
4615 let any_float = arg_types.iter().any(|t| matches!(t, DataType::Float64));
4616 if any_float {
4617 return DataType::Float64;
4618 }
4619 let all_int_or_null = !arg_types.is_empty()
4623 && arg_types
4624 .iter()
4625 .all(|t| matches!(t, DataType::Int64 | DataType::Null));
4626 let any_int = arg_types.iter().any(|t| matches!(t, DataType::Int64));
4627 if all_int_or_null && any_int {
4628 DataType::Int64
4629 } else {
4630 DataType::LargeBinary
4631 }
4632}
4633
4634fn try_fast_arithmetic(
4644 lhs: &ColumnarValue,
4645 rhs: &ColumnarValue,
4646 op: &BinaryOp,
4647) -> Option<DFResult<ColumnarValue>> {
4648 use arrow_array::builder::LargeBinaryBuilder;
4649
4650 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
4651 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
4652 _ => return None,
4653 };
4654
4655 match (lhs_arr.data_type(), rhs_arr.data_type()) {
4656 (DataType::Int64, DataType::Int64) => Some(int_kernel_arithmetic(lhs_arr, rhs_arr, op)),
4660
4661 (DataType::Int64, DataType::Float64)
4664 | (DataType::Float64, DataType::Int64)
4665 | (DataType::Float64, DataType::Float64) => {
4666 Some(float_kernel_arithmetic(lhs_arr, rhs_arr, op))
4667 }
4668
4669 (DataType::LargeBinary, DataType::Int64) => {
4671 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4672 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4673 let mut builder = LargeBinaryBuilder::new();
4674 for i in 0..lb_arr.len() {
4675 if lb_arr.is_null(i) || int_arr.is_null(i) {
4676 builder.append_null();
4677 } else {
4678 match cv_arithmetic_int(lb_arr.value(i), int_arr.value(i), op) {
4679 CvArithOutcome::Value(bytes) => builder.append_value(&bytes),
4680 CvArithOutcome::Null => builder.append_null(),
4681 CvArithOutcome::Error(e) => return Some(Err(e)),
4682 }
4683 }
4684 }
4685 Some(Ok(ColumnarValue::Array(Arc::new(builder.finish()))))
4686 }
4687
4688 (DataType::LargeBinary, DataType::Float64) => {
4690 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4691 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
4692 let mut builder = LargeBinaryBuilder::new();
4693 for i in 0..lb_arr.len() {
4694 if lb_arr.is_null(i) || float_arr.is_null(i) {
4695 builder.append_null();
4696 } else {
4697 match cv_arithmetic_float(lb_arr.value(i), float_arr.value(i), op) {
4698 CvArithOutcome::Value(bytes) => builder.append_value(&bytes),
4699 CvArithOutcome::Null => builder.append_null(),
4700 CvArithOutcome::Error(e) => return Some(Err(e)),
4701 }
4702 }
4703 }
4704 Some(Ok(ColumnarValue::Array(Arc::new(builder.finish()))))
4705 }
4706
4707 _ => None, }
4709}
4710
4711fn int_kernel_arithmetic(
4719 lhs_arr: &ArrayRef,
4720 rhs_arr: &ArrayRef,
4721 op: &BinaryOp,
4722) -> DFResult<ColumnarValue> {
4723 use arrow::compute::kernels::numeric::{add, div, mul, rem, sub};
4724
4725 let result = match op {
4726 BinaryOp::Add => add(lhs_arr, rhs_arr),
4727 BinaryOp::Sub => sub(lhs_arr, rhs_arr),
4728 BinaryOp::Mul => mul(lhs_arr, rhs_arr),
4729 BinaryOp::Div => div(lhs_arr, rhs_arr),
4730 BinaryOp::Mod => rem(lhs_arr, rhs_arr),
4731 other => {
4732 return Err(datafusion::error::DataFusionError::Execution(format!(
4733 "unsupported integer arithmetic operator: {other:?}"
4734 )));
4735 }
4736 };
4737 result
4738 .map(ColumnarValue::Array)
4739 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4740}
4741
4742fn float_kernel_arithmetic(
4748 lhs_arr: &ArrayRef,
4749 rhs_arr: &ArrayRef,
4750 op: &BinaryOp,
4751) -> DFResult<ColumnarValue> {
4752 use arrow::compute::cast;
4753 use arrow::compute::kernels::numeric::{add, div, mul, rem, sub};
4754
4755 let lhs_f = cast(lhs_arr, &DataType::Float64)
4756 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
4757 let rhs_f = cast(rhs_arr, &DataType::Float64)
4758 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
4759
4760 let result = match op {
4761 BinaryOp::Add => add(&lhs_f, &rhs_f),
4762 BinaryOp::Sub => sub(&lhs_f, &rhs_f),
4763 BinaryOp::Mul => mul(&lhs_f, &rhs_f),
4764 BinaryOp::Div => div(&lhs_f, &rhs_f),
4765 BinaryOp::Mod => rem(&lhs_f, &rhs_f),
4766 other => {
4767 return Err(datafusion::error::DataFusionError::Execution(format!(
4768 "unsupported float arithmetic operator: {other:?}"
4769 )));
4770 }
4771 };
4772 result
4773 .map(ColumnarValue::Array)
4774 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4775}
4776
4777#[derive(Debug)]
4778struct CypherArithmeticUdf {
4779 name: String,
4780 op: BinaryOp,
4781 signature: Signature,
4782}
4783
4784impl CypherArithmeticUdf {
4785 fn new(name: &str, op: BinaryOp) -> Self {
4786 Self {
4787 name: name.to_string(),
4788 op,
4789 signature: Signature::any(2, Volatility::Immutable),
4790 }
4791 }
4792}
4793
4794impl PartialEq for CypherArithmeticUdf {
4795 fn eq(&self, other: &Self) -> bool {
4796 self.name == other.name
4797 }
4798}
4799
4800impl Eq for CypherArithmeticUdf {}
4801
4802impl std::hash::Hash for CypherArithmeticUdf {
4803 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4804 self.name.hash(state);
4805 }
4806}
4807
4808impl ScalarUDFImpl for CypherArithmeticUdf {
4809 fn as_any(&self) -> &dyn Any {
4810 self
4811 }
4812 fn name(&self) -> &str {
4813 &self.name
4814 }
4815 fn signature(&self) -> &Signature {
4816 &self.signature
4817 }
4818 fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
4819 Ok(cypher_arith_return_type(arg_types))
4823 }
4824
4825 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4826 if args.args.len() != 2 {
4827 return Err(datafusion::error::DataFusionError::Execution(format!(
4828 "{}(): requires 2 arguments",
4829 self.name
4830 )));
4831 }
4832
4833 if let Some(result) = try_fast_arithmetic(&args.args[0], &args.args[1], &self.op) {
4836 return result;
4837 }
4838
4839 let output_type = args.return_field.data_type().clone();
4846 invoke_cypher_udf(args, &output_type, |val_args| {
4847 crate::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4848 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4849 })
4850 }
4851}
4852
4853pub fn create_cypher_xor_udf() -> ScalarUDF {
4858 ScalarUDF::new_from_impl(CypherXorUdf::new())
4859}
4860
4861#[derive(Debug)]
4862struct CypherXorUdf {
4863 signature: Signature,
4864}
4865
4866impl CypherXorUdf {
4867 fn new() -> Self {
4868 Self {
4869 signature: Signature::any(2, Volatility::Immutable),
4870 }
4871 }
4872}
4873
4874impl_udf_eq_hash!(CypherXorUdf);
4875
4876impl ScalarUDFImpl for CypherXorUdf {
4877 fn as_any(&self) -> &dyn Any {
4878 self
4879 }
4880 fn name(&self) -> &str {
4881 "_cypher_xor"
4882 }
4883 fn signature(&self) -> &Signature {
4884 &self.signature
4885 }
4886 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4887 Ok(DataType::Boolean)
4888 }
4889
4890 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4891 let output_type = DataType::Boolean;
4892 invoke_cypher_udf(args, &output_type, |val_args| {
4893 if val_args.len() != 2 {
4894 return Err(datafusion::error::DataFusionError::Execution(
4895 "_cypher_xor(): requires 2 arguments".to_string(),
4896 ));
4897 }
4898 let coerce_bool = |v: &Value| -> Value {
4900 match v {
4901 Value::String(s) if s == "true" => Value::Bool(true),
4902 Value::String(s) if s == "false" => Value::Bool(false),
4903 other => other.clone(),
4904 }
4905 };
4906 let left = coerce_bool(&val_args[0]);
4907 let right = coerce_bool(&val_args[1]);
4908 crate::expr_eval::eval_binary_op(&left, &BinaryOp::Xor, &right)
4909 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4910 })
4911 }
4912}
4913
4914pub fn create_cv_to_bool_udf() -> ScalarUDF {
4921 ScalarUDF::new_from_impl(CvToBoolUdf::new())
4922}
4923
4924#[derive(Debug)]
4925struct CvToBoolUdf {
4926 signature: Signature,
4927}
4928
4929impl CvToBoolUdf {
4930 fn new() -> Self {
4931 Self {
4932 signature: Signature::exact(vec![DataType::LargeBinary], Volatility::Immutable),
4933 }
4934 }
4935}
4936
4937impl_udf_eq_hash!(CvToBoolUdf);
4938
4939impl ScalarUDFImpl for CvToBoolUdf {
4940 fn as_any(&self) -> &dyn Any {
4941 self
4942 }
4943 fn name(&self) -> &str {
4944 "_cv_to_bool"
4945 }
4946 fn signature(&self) -> &Signature {
4947 &self.signature
4948 }
4949 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4950 Ok(DataType::Boolean)
4951 }
4952
4953 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4954 if args.args.len() != 1 {
4955 return Err(datafusion::error::DataFusionError::Execution(
4956 "_cv_to_bool() requires exactly 1 argument".to_string(),
4957 ));
4958 }
4959
4960 match &args.args[0] {
4961 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
4962 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4964 let b = match peek_tag(bytes) {
4965 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4966 Some(TAG_NULL) => false,
4967 _ => false, };
4969 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))))
4970 }
4971 ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
4972 ColumnarValue::Array(arr) => {
4973 let lb_arr = arr
4974 .as_any()
4975 .downcast_ref::<arrow_array::LargeBinaryArray>()
4976 .ok_or_else(|| {
4977 datafusion::error::DataFusionError::Execution(format!(
4978 "_cv_to_bool(): expected LargeBinary array, got {:?}",
4979 arr.data_type()
4980 ))
4981 })?;
4982
4983 let mut builder = arrow_array::builder::BooleanBuilder::with_capacity(lb_arr.len());
4984
4985 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4987
4988 for i in 0..lb_arr.len() {
4989 if lb_arr.is_null(i) {
4990 builder.append_null();
4991 } else {
4992 let bytes = lb_arr.value(i);
4993 let b = match peek_tag(bytes) {
4994 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4995 Some(TAG_NULL) => false,
4996 _ => false, };
4998 builder.append_value(b);
4999 }
5000 }
5001 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
5002 }
5003 }
5004 }
5005}
5006
5007pub fn create_cypher_size_udf() -> ScalarUDF {
5013 ScalarUDF::new_from_impl(CypherSizeUdf::new())
5014}
5015
5016#[derive(Debug)]
5017struct CypherSizeUdf {
5018 signature: Signature,
5019}
5020
5021impl CypherSizeUdf {
5022 fn new() -> Self {
5023 Self {
5024 signature: Signature::any(1, Volatility::Immutable),
5025 }
5026 }
5027}
5028
5029impl_udf_eq_hash!(CypherSizeUdf);
5030
5031impl ScalarUDFImpl for CypherSizeUdf {
5032 fn as_any(&self) -> &dyn Any {
5033 self
5034 }
5035
5036 fn name(&self) -> &str {
5037 "_cypher_size"
5038 }
5039
5040 fn signature(&self) -> &Signature {
5041 &self.signature
5042 }
5043
5044 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5045 Ok(DataType::Int64)
5046 }
5047
5048 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5049 if args.args.len() != 1 {
5050 return Err(datafusion::error::DataFusionError::Execution(
5051 "_cypher_size() requires exactly 1 argument".to_string(),
5052 ));
5053 }
5054
5055 match &args.args[0] {
5056 ColumnarValue::Scalar(scalar) => {
5057 let result = cypher_size_scalar(scalar)?;
5058 Ok(ColumnarValue::Scalar(result))
5059 }
5060 ColumnarValue::Array(arr) => {
5061 let mut results: Vec<Option<i64>> = Vec::with_capacity(arr.len());
5062 for i in 0..arr.len() {
5063 if arr.is_null(i) {
5064 results.push(None);
5065 } else {
5066 let scalar = ScalarValue::try_from_array(arr, i)?;
5067 match cypher_size_scalar(&scalar)? {
5068 ScalarValue::Int64(v) => results.push(v),
5069 _ => results.push(None),
5070 }
5071 }
5072 }
5073 let arr: ArrayRef = Arc::new(arrow_array::Int64Array::from(results));
5074 Ok(ColumnarValue::Array(arr))
5075 }
5076 }
5077 }
5078}
5079
5080fn cypher_size_scalar(scalar: &ScalarValue) -> DFResult<ScalarValue> {
5081 match scalar {
5082 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
5084 Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
5085 }
5086 ScalarValue::List(arr) => {
5089 if arr.is_empty() || arr.is_null(0) {
5090 Ok(ScalarValue::Int64(None))
5091 } else {
5092 Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
5093 }
5094 }
5095 ScalarValue::LargeList(arr) => {
5096 if arr.is_empty() || arr.is_null(0) {
5097 Ok(ScalarValue::Int64(None))
5098 } else {
5099 Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
5100 }
5101 }
5102 ScalarValue::LargeBinary(Some(b)) => {
5104 if let Ok(uni_val) = uni_common::cypher_value_codec::decode(b) {
5105 match &uni_val {
5106 uni_common::Value::Node(_) => {
5107 Err(datafusion::error::DataFusionError::Execution(
5108 "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
5109 ))
5110 }
5111 uni_common::Value::Edge(_) => {
5112 Err(datafusion::error::DataFusionError::Execution(
5113 "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
5114 ))
5115 }
5116 _ => {
5117 let json_val: serde_json::Value = uni_val.into();
5118 match json_val {
5119 serde_json::Value::Array(arr) => Ok(ScalarValue::Int64(Some(arr.len() as i64))),
5120 serde_json::Value::String(s) => {
5121 Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
5122 }
5123 serde_json::Value::Object(m) => Ok(ScalarValue::Int64(Some(m.len() as i64))),
5124 _ => Ok(ScalarValue::Int64(None)),
5125 }
5126 }
5127 }
5128 } else {
5129 Ok(ScalarValue::Int64(None))
5130 }
5131 }
5132 ScalarValue::Map(arr) => {
5134 if arr.is_empty() || arr.is_null(0) {
5135 Ok(ScalarValue::Int64(None))
5136 } else {
5137 Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
5139 }
5140 }
5141 ScalarValue::Struct(arr) => {
5143 if arr.is_null(0) {
5144 Ok(ScalarValue::Int64(None))
5145 } else {
5146 let schema = arr.fields();
5147 let field_names: Vec<&str> = schema.iter().map(|f| f.name().as_str()).collect();
5148 if field_names.contains(&"_vid") && !field_names.contains(&"relationships") {
5150 return Err(datafusion::error::DataFusionError::Execution(
5151 "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
5152 ));
5153 }
5154 if field_names.contains(&"_eid")
5156 || (field_names.contains(&"_src") && field_names.contains(&"_dst"))
5157 {
5158 return Err(datafusion::error::DataFusionError::Execution(
5159 "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
5160 ));
5161 }
5162 if let Some((rels_idx, _)) = schema
5164 .iter()
5165 .enumerate()
5166 .find(|(_, f)| f.name() == "relationships")
5167 {
5168 let rels_col = arr.column(rels_idx);
5170 if let Some(list_arr) =
5171 rels_col.as_any().downcast_ref::<arrow_array::ListArray>()
5172 {
5173 if list_arr.is_null(0) {
5174 Ok(ScalarValue::Int64(Some(0)))
5175 } else {
5176 Ok(ScalarValue::Int64(Some(list_arr.value(0).len() as i64)))
5177 }
5178 } else {
5179 Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
5180 }
5181 } else {
5182 Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
5183 }
5184 }
5185 }
5186 ScalarValue::Null
5188 | ScalarValue::Utf8(None)
5189 | ScalarValue::LargeUtf8(None)
5190 | ScalarValue::LargeBinary(None) => Ok(ScalarValue::Int64(None)),
5191 other => Err(datafusion::error::DataFusionError::Execution(format!(
5192 "_cypher_size(): unsupported type {other:?}"
5193 ))),
5194 }
5195}
5196
5197pub fn create_cypher_list_compare_udf() -> ScalarUDF {
5203 ScalarUDF::new_from_impl(CypherListCompareUdf::new())
5204}
5205
5206#[derive(Debug)]
5207struct CypherListCompareUdf {
5208 signature: Signature,
5209}
5210
5211impl CypherListCompareUdf {
5212 fn new() -> Self {
5213 Self {
5214 signature: Signature::any(3, Volatility::Immutable),
5215 }
5216 }
5217}
5218
5219impl_udf_eq_hash!(CypherListCompareUdf);
5220
5221impl ScalarUDFImpl for CypherListCompareUdf {
5222 fn as_any(&self) -> &dyn Any {
5223 self
5224 }
5225
5226 fn name(&self) -> &str {
5227 "_cypher_list_compare"
5228 }
5229
5230 fn signature(&self) -> &Signature {
5231 &self.signature
5232 }
5233
5234 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5235 Ok(DataType::Boolean)
5236 }
5237
5238 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5239 let output_type = DataType::Boolean;
5240 invoke_cypher_udf(args, &output_type, |val_args| {
5241 if val_args.len() != 3 {
5242 return Err(datafusion::error::DataFusionError::Execution(
5243 "_cypher_list_compare(): requires 3 arguments (left, right, op)".to_string(),
5244 ));
5245 }
5246
5247 let left = &val_args[0];
5248 let right = &val_args[1];
5249 let op_str = match &val_args[2] {
5250 Value::String(s) => s.as_str(),
5251 _ => {
5252 return Err(datafusion::error::DataFusionError::Execution(
5253 "_cypher_list_compare(): op must be a string".to_string(),
5254 ));
5255 }
5256 };
5257
5258 let (left_items, right_items) = match (left, right) {
5259 (Value::List(l), Value::List(r)) => (l, r),
5260 (Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
5261 _ => {
5262 return Err(datafusion::error::DataFusionError::Execution(
5263 "_cypher_list_compare(): both arguments must be lists".to_string(),
5264 ));
5265 }
5266 };
5267
5268 let cmp = cypher_list_cmp(left_items, right_items);
5270
5271 let result = match (op_str, cmp) {
5272 (_, None) => Value::Null,
5273 ("lt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Less),
5274 ("lteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Greater),
5275 ("gt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Greater),
5276 ("gteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Less),
5277 _ => {
5278 return Err(datafusion::error::DataFusionError::Execution(format!(
5279 "_cypher_list_compare(): unknown op '{}'",
5280 op_str
5281 )));
5282 }
5283 };
5284
5285 Ok(result)
5286 })
5287 }
5288}
5289
5290pub fn create_map_project_udf() -> ScalarUDF {
5295 ScalarUDF::new_from_impl(MapProjectUdf::new())
5296}
5297
5298#[derive(Debug)]
5299struct MapProjectUdf {
5300 signature: Signature,
5301}
5302
5303impl MapProjectUdf {
5304 fn new() -> Self {
5305 Self {
5306 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5307 }
5308 }
5309}
5310
5311impl_udf_eq_hash!(MapProjectUdf);
5312
5313impl ScalarUDFImpl for MapProjectUdf {
5314 fn as_any(&self) -> &dyn Any {
5315 self
5316 }
5317
5318 fn name(&self) -> &str {
5319 "_map_project"
5320 }
5321
5322 fn signature(&self) -> &Signature {
5323 &self.signature
5324 }
5325
5326 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5327 Ok(DataType::LargeBinary)
5328 }
5329
5330 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5331 let output_type = self.return_type(&[])?;
5332 invoke_cypher_udf(args, &output_type, |val_args| {
5333 let mut result_map = std::collections::HashMap::new();
5334 let mut i = 0;
5335 while i + 1 < val_args.len() {
5336 let key = &val_args[i];
5337 let value = &val_args[i + 1];
5338 if let Some(k) = key.as_str() {
5339 if k == "__all__" {
5340 match value {
5342 Value::Map(map) => {
5343 for (mk, mv) in map {
5344 if !mk.starts_with('_') {
5345 result_map.insert(mk.clone(), mv.clone());
5346 }
5347 }
5348 }
5349 Value::Node(node) => {
5350 for (pk, pv) in &node.properties {
5351 result_map.insert(pk.clone(), pv.clone());
5352 }
5353 }
5354 Value::Edge(edge) => {
5355 for (pk, pv) in &edge.properties {
5356 result_map.insert(pk.clone(), pv.clone());
5357 }
5358 }
5359 _ => {}
5360 }
5361 } else {
5362 result_map.insert(k.to_string(), value.clone());
5363 }
5364 }
5365 i += 2;
5366 }
5367 Ok(Value::Map(result_map))
5368 })
5369 }
5370}
5371
5372pub fn create_make_cypher_list_udf() -> ScalarUDF {
5377 ScalarUDF::new_from_impl(MakeCypherListUdf::new())
5378}
5379
5380#[derive(Debug)]
5381struct MakeCypherListUdf {
5382 signature: Signature,
5383}
5384
5385impl MakeCypherListUdf {
5386 fn new() -> Self {
5387 Self {
5388 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5389 }
5390 }
5391}
5392
5393impl_udf_eq_hash!(MakeCypherListUdf);
5394
5395impl ScalarUDFImpl for MakeCypherListUdf {
5396 fn as_any(&self) -> &dyn Any {
5397 self
5398 }
5399
5400 fn name(&self) -> &str {
5401 "_make_cypher_list"
5402 }
5403
5404 fn signature(&self) -> &Signature {
5405 &self.signature
5406 }
5407
5408 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5409 Ok(DataType::LargeBinary)
5410 }
5411
5412 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5413 let output_type = self.return_type(&[])?;
5414 invoke_cypher_udf(args, &output_type, |val_args| {
5415 Ok(Value::List(val_args.to_vec()))
5416 })
5417 }
5418}
5419
5420pub fn create_cypher_in_udf() -> ScalarUDF {
5437 ScalarUDF::new_from_impl(CypherInUdf::new())
5438}
5439
5440#[derive(Debug)]
5441struct CypherInUdf {
5442 signature: Signature,
5443}
5444
5445impl CypherInUdf {
5446 fn new() -> Self {
5447 Self {
5448 signature: Signature::any(2, Volatility::Immutable),
5449 }
5450 }
5451}
5452
5453impl_udf_eq_hash!(CypherInUdf);
5454
5455impl ScalarUDFImpl for CypherInUdf {
5456 fn as_any(&self) -> &dyn Any {
5457 self
5458 }
5459
5460 fn name(&self) -> &str {
5461 "_cypher_in"
5462 }
5463
5464 fn signature(&self) -> &Signature {
5465 &self.signature
5466 }
5467
5468 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5469 Ok(DataType::Boolean)
5470 }
5471
5472 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5473 invoke_cypher_udf(args, &DataType::Boolean, |vals| {
5474 if vals.len() != 2 {
5475 return Err(datafusion::error::DataFusionError::Execution(
5476 "_cypher_in(): requires 2 arguments".to_string(),
5477 ));
5478 }
5479 let element = &vals[0];
5480 let list_val = &vals[1];
5481
5482 if list_val.is_null() {
5484 return Ok(Value::Null);
5485 }
5486
5487 let items = match list_val {
5489 Value::List(items) => items.as_slice(),
5490 _ => {
5491 return Err(datafusion::error::DataFusionError::Execution(format!(
5492 "_cypher_in(): second argument must be a list, got {:?}",
5493 list_val
5494 )));
5495 }
5496 };
5497
5498 if element.is_null() {
5500 return if items.is_empty() {
5501 Ok(Value::Bool(false))
5502 } else {
5503 Ok(Value::Null) };
5505 }
5506
5507 let mut has_null = false;
5509 for item in items {
5510 match cypher_eq(element, item) {
5511 Some(true) => return Ok(Value::Bool(true)),
5512 None => has_null = true,
5513 Some(false) => {}
5514 }
5515 }
5516
5517 if has_null {
5518 Ok(Value::Null) } else {
5520 Ok(Value::Bool(false))
5521 }
5522 })
5523 }
5524}
5525
5526pub fn create_cypher_list_concat_udf() -> ScalarUDF {
5532 ScalarUDF::new_from_impl(CypherListConcatUdf::new())
5533}
5534
5535#[derive(Debug)]
5536struct CypherListConcatUdf {
5537 signature: Signature,
5538}
5539
5540impl CypherListConcatUdf {
5541 fn new() -> Self {
5542 Self {
5543 signature: Signature::any(2, Volatility::Immutable),
5544 }
5545 }
5546}
5547
5548impl_udf_eq_hash!(CypherListConcatUdf);
5549
5550impl ScalarUDFImpl for CypherListConcatUdf {
5551 fn as_any(&self) -> &dyn Any {
5552 self
5553 }
5554
5555 fn name(&self) -> &str {
5556 "_cypher_list_concat"
5557 }
5558
5559 fn signature(&self) -> &Signature {
5560 &self.signature
5561 }
5562
5563 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5564 Ok(DataType::LargeBinary)
5565 }
5566
5567 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5568 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5569 if vals.len() != 2 {
5570 return Err(datafusion::error::DataFusionError::Execution(
5571 "_cypher_list_concat(): requires 2 arguments".to_string(),
5572 ));
5573 }
5574 if vals[0].is_null() || vals[1].is_null() {
5576 return Ok(Value::Null);
5577 }
5578 match (&vals[0], &vals[1]) {
5579 (Value::List(left), Value::List(right)) => {
5580 let mut result = left.clone();
5581 result.extend(right.iter().cloned());
5582 Ok(Value::List(result))
5583 }
5584 (Value::List(list), elem) => {
5587 let mut result = list.clone();
5588 result.push(elem.clone());
5589 Ok(Value::List(result))
5590 }
5591 (elem, Value::List(list)) => {
5592 let mut result = vec![elem.clone()];
5593 result.extend(list.iter().cloned());
5594 Ok(Value::List(result))
5595 }
5596 _ => {
5597 crate::expr_eval::eval_binary_op(
5600 &vals[0],
5601 &uni_cypher::ast::BinaryOp::Add,
5602 &vals[1],
5603 )
5604 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
5605 }
5606 }
5607 })
5608 }
5609}
5610
5611pub fn create_cypher_list_append_udf() -> ScalarUDF {
5617 ScalarUDF::new_from_impl(CypherListAppendUdf::new())
5618}
5619
5620#[derive(Debug)]
5621struct CypherListAppendUdf {
5622 signature: Signature,
5623}
5624
5625impl CypherListAppendUdf {
5626 fn new() -> Self {
5627 Self {
5628 signature: Signature::any(2, Volatility::Immutable),
5629 }
5630 }
5631}
5632
5633impl_udf_eq_hash!(CypherListAppendUdf);
5634
5635impl ScalarUDFImpl for CypherListAppendUdf {
5636 fn as_any(&self) -> &dyn Any {
5637 self
5638 }
5639
5640 fn name(&self) -> &str {
5641 "_cypher_list_append"
5642 }
5643
5644 fn signature(&self) -> &Signature {
5645 &self.signature
5646 }
5647
5648 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5649 Ok(DataType::LargeBinary)
5650 }
5651
5652 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5653 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5654 if vals.len() != 2 {
5655 return Err(datafusion::error::DataFusionError::Execution(
5656 "_cypher_list_append(): requires 2 arguments".to_string(),
5657 ));
5658 }
5659 let left = &vals[0];
5660 let right = &vals[1];
5661
5662 if left.is_null() || right.is_null() {
5664 return Ok(Value::Null);
5665 }
5666
5667 match (left, right) {
5668 (Value::List(list), elem) => {
5670 let mut result = list.clone();
5671 result.push(elem.clone());
5672 Ok(Value::List(result))
5673 }
5674 (elem, Value::List(list)) => {
5676 let mut result = vec![elem.clone()];
5677 result.extend(list.iter().cloned());
5678 Ok(Value::List(result))
5679 }
5680 _ => Err(datafusion::error::DataFusionError::Execution(format!(
5681 "_cypher_list_append(): at least one argument must be a list, got {:?} and {:?}",
5682 left, right
5683 ))),
5684 }
5685 })
5686 }
5687}
5688
5689pub fn create_cypher_list_slice_udf() -> ScalarUDF {
5695 ScalarUDF::new_from_impl(CypherListSliceUdf::new())
5696}
5697
5698#[derive(Debug)]
5699struct CypherListSliceUdf {
5700 signature: Signature,
5701}
5702
5703impl CypherListSliceUdf {
5704 fn new() -> Self {
5705 Self {
5706 signature: Signature::any(3, Volatility::Immutable),
5707 }
5708 }
5709}
5710
5711impl_udf_eq_hash!(CypherListSliceUdf);
5712
5713impl ScalarUDFImpl for CypherListSliceUdf {
5714 fn as_any(&self) -> &dyn Any {
5715 self
5716 }
5717
5718 fn name(&self) -> &str {
5719 "_cypher_list_slice"
5720 }
5721
5722 fn signature(&self) -> &Signature {
5723 &self.signature
5724 }
5725
5726 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5727 Ok(DataType::LargeBinary)
5728 }
5729
5730 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5731 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5732 if vals.len() != 3 {
5733 return Err(datafusion::error::DataFusionError::Execution(
5734 "_cypher_list_slice(): requires 3 arguments (list, start, end)".to_string(),
5735 ));
5736 }
5737 if vals[0].is_null() {
5739 return Ok(Value::Null);
5740 }
5741 let list = match &vals[0] {
5742 Value::List(l) => l,
5743 _ => {
5744 return Err(datafusion::error::DataFusionError::Execution(format!(
5745 "_cypher_list_slice(): first argument must be a list, got {:?}",
5746 vals[0]
5747 )));
5748 }
5749 };
5750 if vals[1].is_null() || vals[2].is_null() {
5752 return Ok(Value::Null);
5753 }
5754
5755 let len = list.len() as i64;
5756 let raw_start = match &vals[1] {
5757 Value::Int(i) => *i,
5758 _ => 0,
5759 };
5760 let raw_end = match &vals[2] {
5761 Value::Int(i) => *i,
5762 _ => len,
5763 };
5764
5765 let start = if raw_start < 0 {
5767 (len + raw_start).max(0) as usize
5768 } else {
5769 (raw_start).min(len) as usize
5770 };
5771 let end = if raw_end == i64::MAX {
5772 len as usize
5773 } else if raw_end < 0 {
5774 (len + raw_end).max(0) as usize
5775 } else {
5776 (raw_end).min(len) as usize
5777 };
5778
5779 if start >= end {
5780 return Ok(Value::List(vec![]));
5781 }
5782 Ok(Value::List(list[start..end.min(list.len())].to_vec()))
5783 })
5784 }
5785}
5786
5787pub fn create_cypher_reverse_udf() -> ScalarUDF {
5798 ScalarUDF::new_from_impl(CypherReverseUdf::new())
5799}
5800
5801#[derive(Debug)]
5802struct CypherReverseUdf {
5803 signature: Signature,
5804}
5805
5806impl CypherReverseUdf {
5807 fn new() -> Self {
5808 Self {
5809 signature: Signature::any(1, Volatility::Immutable),
5810 }
5811 }
5812}
5813
5814impl_udf_eq_hash!(CypherReverseUdf);
5815
5816impl ScalarUDFImpl for CypherReverseUdf {
5817 fn as_any(&self) -> &dyn Any {
5818 self
5819 }
5820
5821 fn name(&self) -> &str {
5822 "_cypher_reverse"
5823 }
5824
5825 fn signature(&self) -> &Signature {
5826 &self.signature
5827 }
5828
5829 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5830 Ok(DataType::LargeBinary)
5831 }
5832
5833 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5834 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5835 if vals.len() != 1 {
5836 return Err(datafusion::error::DataFusionError::Execution(
5837 "_cypher_reverse(): requires exactly 1 argument".to_string(),
5838 ));
5839 }
5840 match &vals[0] {
5841 Value::Null => Ok(Value::Null),
5842 Value::String(s) => Ok(Value::String(s.chars().rev().collect())),
5843 Value::List(l) => {
5844 let mut reversed = l.clone();
5845 reversed.reverse();
5846 Ok(Value::List(reversed))
5847 }
5848 other => Err(datafusion::error::DataFusionError::Execution(format!(
5849 "_cypher_reverse(): expected string or list, got {:?}",
5850 other
5851 ))),
5852 }
5853 })
5854 }
5855}
5856
5857pub fn create_cypher_substring_udf() -> ScalarUDF {
5868 ScalarUDF::new_from_impl(CypherSubstringUdf::new())
5869}
5870
5871#[derive(Debug)]
5872struct CypherSubstringUdf {
5873 signature: Signature,
5874}
5875
5876impl CypherSubstringUdf {
5877 fn new() -> Self {
5878 Self {
5879 signature: Signature::variadic_any(Volatility::Immutable),
5880 }
5881 }
5882}
5883
5884impl_udf_eq_hash!(CypherSubstringUdf);
5885
5886impl ScalarUDFImpl for CypherSubstringUdf {
5887 fn as_any(&self) -> &dyn Any {
5888 self
5889 }
5890
5891 fn name(&self) -> &str {
5892 "_cypher_substring"
5893 }
5894
5895 fn signature(&self) -> &Signature {
5896 &self.signature
5897 }
5898
5899 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5900 Ok(DataType::Utf8)
5901 }
5902
5903 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5904 invoke_cypher_udf(args, &DataType::Utf8, |vals| {
5905 if vals.len() < 2 || vals.len() > 3 {
5906 return Err(datafusion::error::DataFusionError::Execution(
5907 "_cypher_substring(): requires 2 or 3 arguments".to_string(),
5908 ));
5909 }
5910 if vals.iter().any(|v| v.is_null()) {
5912 return Ok(Value::Null);
5913 }
5914 let s = match &vals[0] {
5915 Value::String(s) => s.as_str(),
5916 other => {
5917 return Err(datafusion::error::DataFusionError::Execution(format!(
5918 "_cypher_substring(): first argument must be a string, got {:?}",
5919 other
5920 )));
5921 }
5922 };
5923 let start = match &vals[1] {
5924 Value::Int(i) => *i,
5925 other => {
5926 return Err(datafusion::error::DataFusionError::Execution(format!(
5927 "_cypher_substring(): second argument must be an integer, got {:?}",
5928 other
5929 )));
5930 }
5931 };
5932
5933 let chars: Vec<char> = s.chars().collect();
5935 let len = chars.len() as i64;
5936
5937 let start_idx = start.max(0).min(len) as usize;
5939
5940 let end_idx = if vals.len() == 3 {
5941 let length = match &vals[2] {
5942 Value::Int(i) => *i,
5943 other => {
5944 return Err(datafusion::error::DataFusionError::Execution(format!(
5945 "_cypher_substring(): third argument must be an integer, got {:?}",
5946 other
5947 )));
5948 }
5949 };
5950 if length < 0 {
5951 return Err(datafusion::error::DataFusionError::Execution(
5952 "ArgumentError: NegativeIntegerArgument - substring length must be non-negative".to_string(),
5953 ));
5954 }
5955 (start_idx as i64 + length).min(len) as usize
5956 } else {
5957 len as usize
5958 };
5959
5960 Ok(Value::String(chars[start_idx..end_idx].iter().collect()))
5961 })
5962 }
5963}
5964
5965pub fn create_cypher_split_udf() -> ScalarUDF {
5974 ScalarUDF::new_from_impl(CypherSplitUdf::new())
5975}
5976
5977#[derive(Debug)]
5978struct CypherSplitUdf {
5979 signature: Signature,
5980}
5981
5982impl CypherSplitUdf {
5983 fn new() -> Self {
5984 Self {
5985 signature: Signature::any(2, Volatility::Immutable),
5986 }
5987 }
5988}
5989
5990impl_udf_eq_hash!(CypherSplitUdf);
5991
5992impl ScalarUDFImpl for CypherSplitUdf {
5993 fn as_any(&self) -> &dyn Any {
5994 self
5995 }
5996
5997 fn name(&self) -> &str {
5998 "_cypher_split"
5999 }
6000
6001 fn signature(&self) -> &Signature {
6002 &self.signature
6003 }
6004
6005 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6006 Ok(DataType::LargeBinary)
6007 }
6008
6009 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6010 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6011 if vals.len() != 2 {
6012 return Err(datafusion::error::DataFusionError::Execution(
6013 "_cypher_split(): requires exactly 2 arguments".to_string(),
6014 ));
6015 }
6016 if vals.iter().any(|v| v.is_null()) {
6018 return Ok(Value::Null);
6019 }
6020 let s = match &vals[0] {
6021 Value::String(s) => s.clone(),
6022 other => {
6023 return Err(datafusion::error::DataFusionError::Execution(format!(
6024 "_cypher_split(): first argument must be a string, got {:?}",
6025 other
6026 )));
6027 }
6028 };
6029 let delimiter = match &vals[1] {
6030 Value::String(d) => d.clone(),
6031 other => {
6032 return Err(datafusion::error::DataFusionError::Execution(format!(
6033 "_cypher_split(): second argument must be a string, got {:?}",
6034 other
6035 )));
6036 }
6037 };
6038 let parts: Vec<Value> = s
6039 .split(&delimiter)
6040 .map(|p| Value::String(p.to_string()))
6041 .collect();
6042 Ok(Value::List(parts))
6043 })
6044 }
6045}
6046
6047pub fn create_cypher_list_to_cv_udf() -> ScalarUDF {
6058 ScalarUDF::new_from_impl(CypherListToCvUdf::new())
6059}
6060
6061#[derive(Debug)]
6062struct CypherListToCvUdf {
6063 signature: Signature,
6064}
6065
6066impl CypherListToCvUdf {
6067 fn new() -> Self {
6068 Self {
6069 signature: Signature::any(1, Volatility::Immutable),
6070 }
6071 }
6072}
6073
6074impl_udf_eq_hash!(CypherListToCvUdf);
6075
6076impl ScalarUDFImpl for CypherListToCvUdf {
6077 fn as_any(&self) -> &dyn Any {
6078 self
6079 }
6080
6081 fn name(&self) -> &str {
6082 "_cypher_list_to_cv"
6083 }
6084
6085 fn signature(&self) -> &Signature {
6086 &self.signature
6087 }
6088
6089 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6090 Ok(DataType::LargeBinary)
6091 }
6092
6093 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6094 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6095 if vals.len() != 1 {
6096 return Err(datafusion::error::DataFusionError::Execution(
6097 "_cypher_list_to_cv(): requires exactly 1 argument".to_string(),
6098 ));
6099 }
6100 Ok(vals[0].clone())
6101 })
6102 }
6103}
6104
6105pub fn create_cypher_scalar_to_cv_udf() -> ScalarUDF {
6116 ScalarUDF::new_from_impl(CypherScalarToCvUdf::new())
6117}
6118
6119#[derive(Debug)]
6120struct CypherScalarToCvUdf {
6121 signature: Signature,
6122}
6123
6124impl CypherScalarToCvUdf {
6125 fn new() -> Self {
6126 Self {
6127 signature: Signature::any(1, Volatility::Immutable),
6128 }
6129 }
6130}
6131
6132impl_udf_eq_hash!(CypherScalarToCvUdf);
6133
6134impl ScalarUDFImpl for CypherScalarToCvUdf {
6135 fn as_any(&self) -> &dyn Any {
6136 self
6137 }
6138
6139 fn name(&self) -> &str {
6140 "_cypher_scalar_to_cv"
6141 }
6142
6143 fn signature(&self) -> &Signature {
6144 &self.signature
6145 }
6146
6147 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6148 Ok(DataType::LargeBinary)
6149 }
6150
6151 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6152 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6153 if vals.len() != 1 {
6154 return Err(datafusion::error::DataFusionError::Execution(
6155 "_cypher_scalar_to_cv(): requires exactly 1 argument".to_string(),
6156 ));
6157 }
6158 Ok(vals[0].clone())
6159 })
6160 }
6161}
6162
6163pub fn create_cypher_tail_udf() -> ScalarUDF {
6175 ScalarUDF::new_from_impl(CypherTailUdf::new())
6176}
6177
6178#[derive(Debug)]
6179struct CypherTailUdf {
6180 signature: Signature,
6181}
6182
6183impl CypherTailUdf {
6184 fn new() -> Self {
6185 Self {
6186 signature: Signature::any(1, Volatility::Immutable),
6187 }
6188 }
6189}
6190
6191impl_udf_eq_hash!(CypherTailUdf);
6192
6193impl ScalarUDFImpl for CypherTailUdf {
6194 fn as_any(&self) -> &dyn Any {
6195 self
6196 }
6197
6198 fn name(&self) -> &str {
6199 "_cypher_tail"
6200 }
6201
6202 fn signature(&self) -> &Signature {
6203 &self.signature
6204 }
6205
6206 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6207 Ok(DataType::LargeBinary)
6208 }
6209
6210 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6211 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6212 if vals.len() != 1 {
6213 return Err(datafusion::error::DataFusionError::Execution(
6214 "_cypher_tail(): requires exactly 1 argument".to_string(),
6215 ));
6216 }
6217 match &vals[0] {
6218 Value::Null => Ok(Value::Null),
6219 Value::List(l) => {
6220 if l.is_empty() {
6221 Ok(Value::List(vec![]))
6222 } else {
6223 Ok(Value::List(l[1..].to_vec()))
6224 }
6225 }
6226 other => Err(datafusion::error::DataFusionError::Execution(format!(
6227 "_cypher_tail(): expected list, got {:?}",
6228 other
6229 ))),
6230 }
6231 })
6232 }
6233}
6234
6235pub fn create_cypher_head_udf() -> ScalarUDF {
6246 ScalarUDF::new_from_impl(CypherHeadUdf::new())
6247}
6248
6249#[derive(Debug)]
6250struct CypherHeadUdf {
6251 signature: Signature,
6252}
6253
6254impl CypherHeadUdf {
6255 fn new() -> Self {
6256 Self {
6257 signature: Signature::any(1, Volatility::Immutable),
6258 }
6259 }
6260}
6261
6262impl_udf_eq_hash!(CypherHeadUdf);
6263
6264impl ScalarUDFImpl for CypherHeadUdf {
6265 fn as_any(&self) -> &dyn Any {
6266 self
6267 }
6268
6269 fn name(&self) -> &str {
6270 "head"
6271 }
6272
6273 fn signature(&self) -> &Signature {
6274 &self.signature
6275 }
6276
6277 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6278 Ok(DataType::LargeBinary)
6279 }
6280
6281 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6282 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6283 if vals.len() != 1 {
6284 return Err(datafusion::error::DataFusionError::Execution(
6285 "head(): requires exactly 1 argument".to_string(),
6286 ));
6287 }
6288 match &vals[0] {
6289 Value::Null => Ok(Value::Null),
6290 Value::List(l) => Ok(l.first().cloned().unwrap_or(Value::Null)),
6291 other => Err(datafusion::error::DataFusionError::Execution(format!(
6292 "head(): expected list, got {:?}",
6293 other
6294 ))),
6295 }
6296 })
6297 }
6298}
6299
6300pub fn create_cypher_last_udf() -> ScalarUDF {
6311 ScalarUDF::new_from_impl(CypherLastUdf::new())
6312}
6313
6314#[derive(Debug)]
6315struct CypherLastUdf {
6316 signature: Signature,
6317}
6318
6319impl CypherLastUdf {
6320 fn new() -> Self {
6321 Self {
6322 signature: Signature::any(1, Volatility::Immutable),
6323 }
6324 }
6325}
6326
6327impl_udf_eq_hash!(CypherLastUdf);
6328
6329impl ScalarUDFImpl for CypherLastUdf {
6330 fn as_any(&self) -> &dyn Any {
6331 self
6332 }
6333
6334 fn name(&self) -> &str {
6335 "last"
6336 }
6337
6338 fn signature(&self) -> &Signature {
6339 &self.signature
6340 }
6341
6342 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6343 Ok(DataType::LargeBinary)
6344 }
6345
6346 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6347 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6348 if vals.len() != 1 {
6349 return Err(datafusion::error::DataFusionError::Execution(
6350 "last(): requires exactly 1 argument".to_string(),
6351 ));
6352 }
6353 match &vals[0] {
6354 Value::Null => Ok(Value::Null),
6355 Value::List(l) => Ok(l.last().cloned().unwrap_or(Value::Null)),
6356 other => Err(datafusion::error::DataFusionError::Execution(format!(
6357 "last(): expected list, got {:?}",
6358 other
6359 ))),
6360 }
6361 })
6362 }
6363}
6364
6365fn cypher_list_cmp(left: &[Value], right: &[Value]) -> Option<std::cmp::Ordering> {
6368 let min_len = left.len().min(right.len());
6369 for i in 0..min_len {
6370 let cmp = cypher_value_cmp(&left[i], &right[i])?;
6371 if cmp != std::cmp::Ordering::Equal {
6372 return Some(cmp);
6373 }
6374 }
6375 Some(left.len().cmp(&right.len()))
6377}
6378
6379fn cypher_value_cmp(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
6382 match (a, b) {
6383 (Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
6384 (Value::Null, _) | (_, Value::Null) => None,
6385 (Value::Int(l), Value::Int(r)) => Some(l.cmp(r)),
6386 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r),
6387 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r),
6388 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)),
6389 (Value::String(l), Value::String(r)) => Some(l.cmp(r)),
6390 (Value::Bool(l), Value::Bool(r)) => Some(l.cmp(r)),
6391 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r),
6392 _ => None, }
6394}
6395
6396struct CypherToFloat64Udf {
6404 signature: Signature,
6405}
6406
6407impl CypherToFloat64Udf {
6408 fn new() -> Self {
6409 Self {
6410 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6411 }
6412 }
6413}
6414
6415impl_udf_eq_hash!(CypherToFloat64Udf);
6416
6417impl std::fmt::Debug for CypherToFloat64Udf {
6418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6419 f.debug_struct("CypherToFloat64Udf").finish()
6420 }
6421}
6422
6423impl ScalarUDFImpl for CypherToFloat64Udf {
6424 fn as_any(&self) -> &dyn Any {
6425 self
6426 }
6427 fn name(&self) -> &str {
6428 "_cypher_to_float64"
6429 }
6430 fn signature(&self) -> &Signature {
6431 &self.signature
6432 }
6433 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6434 Ok(DataType::Float64)
6435 }
6436 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6437 if args.args.len() != 1 {
6438 return Err(datafusion::error::DataFusionError::Execution(
6439 "_cypher_to_float64 requires exactly 1 argument".into(),
6440 ));
6441 }
6442 match &args.args[0] {
6443 ColumnarValue::Scalar(scalar) => {
6444 let f = match scalar {
6445 ScalarValue::LargeBinary(Some(bytes)) => cv_bytes_as_f64(bytes),
6446 ScalarValue::Int64(Some(i)) => Some(*i as f64),
6447 ScalarValue::Int32(Some(i)) => Some(*i as f64),
6448 ScalarValue::Float64(Some(f)) => Some(*f),
6449 ScalarValue::Float32(Some(f)) => Some(*f as f64),
6450 _ => None,
6451 };
6452 Ok(ColumnarValue::Scalar(ScalarValue::Float64(f)))
6453 }
6454 ColumnarValue::Array(arr) => {
6455 let len = arr.len();
6456 let mut builder = arrow::array::Float64Builder::with_capacity(len);
6457 match arr.data_type() {
6458 DataType::LargeBinary => {
6459 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6460 for i in 0..len {
6461 if lb.is_null(i) {
6462 builder.append_null();
6463 } else {
6464 match cv_bytes_as_f64(lb.value(i)) {
6465 Some(f) => builder.append_value(f),
6466 None => builder.append_null(),
6467 }
6468 }
6469 }
6470 }
6471 DataType::Int64 => {
6472 let int_arr = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6473 for i in 0..len {
6474 if int_arr.is_null(i) {
6475 builder.append_null();
6476 } else {
6477 builder.append_value(int_arr.value(i) as f64);
6478 }
6479 }
6480 }
6481 DataType::Float64 => {
6482 let f_arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6483 for i in 0..len {
6484 if f_arr.is_null(i) {
6485 builder.append_null();
6486 } else {
6487 builder.append_value(f_arr.value(i));
6488 }
6489 }
6490 }
6491 _ => {
6492 for _ in 0..len {
6493 builder.append_null();
6494 }
6495 }
6496 }
6497 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
6498 }
6499 }
6500 }
6501}
6502
6503fn create_cypher_to_float64_udf() -> ScalarUDF {
6504 ScalarUDF::from(CypherToFloat64Udf::new())
6505}
6506
6507pub fn cypher_to_float64_expr(
6509 arg: datafusion::logical_expr::Expr,
6510) -> datafusion::logical_expr::Expr {
6511 datafusion::logical_expr::Expr::ScalarFunction(
6512 datafusion::logical_expr::expr::ScalarFunction::new_udf(
6513 Arc::new(create_cypher_to_float64_udf()),
6514 vec![arg],
6515 ),
6516 )
6517}
6518
6519pub fn cypher_to_float64_udf() -> datafusion::logical_expr::ScalarUDF {
6521 create_cypher_to_float64_udf()
6522}
6523
6524fn cypher_type_rank(val: &Value) -> u8 {
6532 match val {
6533 Value::Null => 0,
6534 Value::List(_) => 1,
6535 Value::String(_) => 2,
6536 Value::Bool(_) => 3,
6537 Value::Int(_) | Value::Float(_) => 4,
6538 _ => 5, }
6540}
6541
6542fn cypher_cross_type_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
6545 use std::cmp::Ordering;
6546 let ra = cypher_type_rank(a);
6547 let rb = cypher_type_rank(b);
6548 if ra != rb {
6549 return ra.cmp(&rb);
6550 }
6551 match (a, b) {
6553 (Value::Int(l), Value::Int(r)) => l.cmp(r),
6554 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
6555 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r).unwrap_or(Ordering::Equal),
6556 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)).unwrap_or(Ordering::Equal),
6557 (Value::String(l), Value::String(r)) => l.cmp(r),
6558 (Value::Bool(l), Value::Bool(r)) => l.cmp(r),
6559 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r).unwrap_or(Ordering::Equal),
6560 _ => Ordering::Equal,
6561 }
6562}
6563
6564fn scalar_binary_to_value(bytes: &[u8]) -> Value {
6566 uni_common::cypher_value_codec::decode(bytes).unwrap_or(Value::Null)
6567}
6568
6569use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, AggregateUDFImpl};
6570
6571#[derive(Debug, Clone)]
6573struct CypherMinMaxUdaf {
6574 name: String,
6575 signature: Signature,
6576 is_max: bool,
6577}
6578
6579impl CypherMinMaxUdaf {
6580 fn new(is_max: bool) -> Self {
6581 let name = if is_max { "_cypher_max" } else { "_cypher_min" };
6582 Self {
6583 name: name.to_string(),
6584 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6585 is_max,
6586 }
6587 }
6588}
6589
6590impl PartialEq for CypherMinMaxUdaf {
6591 fn eq(&self, other: &Self) -> bool {
6592 self.name == other.name
6593 }
6594}
6595
6596impl Eq for CypherMinMaxUdaf {}
6597
6598impl Hash for CypherMinMaxUdaf {
6599 fn hash<H: Hasher>(&self, state: &mut H) {
6600 self.name.hash(state);
6601 }
6602}
6603
6604impl AggregateUDFImpl for CypherMinMaxUdaf {
6605 fn as_any(&self) -> &dyn Any {
6606 self
6607 }
6608 fn name(&self) -> &str {
6609 &self.name
6610 }
6611 fn signature(&self) -> &Signature {
6612 &self.signature
6613 }
6614 fn return_type(&self, args: &[DataType]) -> DFResult<DataType> {
6615 Ok(args.first().cloned().unwrap_or(DataType::LargeBinary))
6617 }
6618 fn accumulator(
6619 &self,
6620 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6621 ) -> DFResult<Box<dyn DfAccumulator>> {
6622 Ok(Box::new(CypherMinMaxAccumulator {
6623 current: None,
6624 is_max: self.is_max,
6625 return_type: acc_args.return_field.data_type().clone(),
6626 }))
6627 }
6628 fn state_fields(
6629 &self,
6630 args: datafusion::logical_expr::function::StateFieldsArgs,
6631 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6632 Ok(vec![Arc::new(arrow::datatypes::Field::new(
6633 args.name,
6634 DataType::LargeBinary,
6635 true,
6636 ))])
6637 }
6638}
6639
6640#[derive(Debug)]
6641struct CypherMinMaxAccumulator {
6642 current: Option<Value>,
6643 is_max: bool,
6644 return_type: DataType,
6645}
6646
6647impl DfAccumulator for CypherMinMaxAccumulator {
6648 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6649 let arr = &values[0];
6650 match arr.data_type() {
6651 DataType::LargeBinary => {
6652 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6653 for i in 0..lb.len() {
6654 if lb.is_null(i) {
6655 continue;
6656 }
6657 let val = scalar_binary_to_value(lb.value(i));
6658 if val.is_null() {
6659 continue;
6660 }
6661 self.current = Some(match self.current.take() {
6662 None => val,
6663 Some(cur) => {
6664 let ord = cypher_cross_type_cmp(&val, &cur);
6665 if (self.is_max && ord == std::cmp::Ordering::Greater)
6666 || (!self.is_max && ord == std::cmp::Ordering::Less)
6667 {
6668 val
6669 } else {
6670 cur
6671 }
6672 }
6673 });
6674 }
6675 }
6676 _ => {
6677 for i in 0..arr.len() {
6679 if arr.is_null(i) {
6680 continue;
6681 }
6682 let sv = ScalarValue::try_from_array(arr, i).map_err(|e| {
6683 datafusion::error::DataFusionError::Execution(e.to_string())
6684 })?;
6685 let val = scalar_to_value(&sv)?;
6686 if val.is_null() {
6687 continue;
6688 }
6689 self.current = Some(match self.current.take() {
6690 None => val,
6691 Some(cur) => {
6692 let ord = cypher_cross_type_cmp(&val, &cur);
6693 if (self.is_max && ord == std::cmp::Ordering::Greater)
6694 || (!self.is_max && ord == std::cmp::Ordering::Less)
6695 {
6696 val
6697 } else {
6698 cur
6699 }
6700 }
6701 });
6702 }
6703 }
6704 }
6705 Ok(())
6706 }
6707 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6708 match &self.current {
6709 None => {
6710 ScalarValue::try_from(&self.return_type)
6712 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
6713 }
6714 Some(val) => {
6715 if matches!(self.return_type, DataType::LargeBinary) {
6717 let bytes = uni_common::cypher_value_codec::encode(val);
6718 return Ok(ScalarValue::LargeBinary(Some(bytes)));
6719 }
6720 match val {
6722 Value::Int(i) => match &self.return_type {
6723 DataType::Int64 => Ok(ScalarValue::Int64(Some(*i))),
6724 DataType::UInt64 => Ok(ScalarValue::UInt64(Some(*i as u64))),
6725 _ => {
6726 let bytes = uni_common::cypher_value_codec::encode(val);
6727 Ok(ScalarValue::LargeBinary(Some(bytes)))
6728 }
6729 },
6730 Value::Float(f) => match &self.return_type {
6731 DataType::Float64 => Ok(ScalarValue::Float64(Some(*f))),
6732 _ => {
6733 let bytes = uni_common::cypher_value_codec::encode(val);
6734 Ok(ScalarValue::LargeBinary(Some(bytes)))
6735 }
6736 },
6737 Value::String(s) => match &self.return_type {
6738 DataType::Utf8 => Ok(ScalarValue::Utf8(Some(s.clone()))),
6739 DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
6740 _ => {
6741 let bytes = uni_common::cypher_value_codec::encode(val);
6742 Ok(ScalarValue::LargeBinary(Some(bytes)))
6743 }
6744 },
6745 Value::Bool(b) => match &self.return_type {
6746 DataType::Boolean => Ok(ScalarValue::Boolean(Some(*b))),
6747 _ => {
6748 let bytes = uni_common::cypher_value_codec::encode(val);
6749 Ok(ScalarValue::LargeBinary(Some(bytes)))
6750 }
6751 },
6752 _ => {
6753 let bytes = uni_common::cypher_value_codec::encode(val);
6755 Ok(ScalarValue::LargeBinary(Some(bytes)))
6756 }
6757 }
6758 }
6759 }
6760 }
6761 fn size(&self) -> usize {
6762 std::mem::size_of_val(self) + self.current.as_ref().map_or(0, |_| 64)
6763 }
6764 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6765 Ok(vec![self.evaluate()?])
6766 }
6767 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6768 self.update_batch(states)
6769 }
6770}
6771
6772pub fn create_cypher_min_udaf() -> AggregateUDF {
6773 AggregateUDF::from(CypherMinMaxUdaf::new(false))
6774}
6775
6776pub fn create_cypher_max_udaf() -> AggregateUDF {
6777 AggregateUDF::from(CypherMinMaxUdaf::new(true))
6778}
6779
6780#[derive(Debug, Clone)]
6786struct CypherSumUdaf {
6787 signature: Signature,
6788}
6789
6790impl CypherSumUdaf {
6791 fn new() -> Self {
6792 Self {
6793 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6794 }
6795 }
6796}
6797
6798impl PartialEq for CypherSumUdaf {
6799 fn eq(&self, other: &Self) -> bool {
6800 self.signature == other.signature
6801 }
6802}
6803
6804impl Eq for CypherSumUdaf {}
6805
6806impl Hash for CypherSumUdaf {
6807 fn hash<H: Hasher>(&self, state: &mut H) {
6808 self.name().hash(state);
6809 }
6810}
6811
6812impl AggregateUDFImpl for CypherSumUdaf {
6813 fn as_any(&self) -> &dyn Any {
6814 self
6815 }
6816 fn name(&self) -> &str {
6817 "_cypher_sum"
6818 }
6819 fn signature(&self) -> &Signature {
6820 &self.signature
6821 }
6822 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6823 Ok(DataType::LargeBinary)
6826 }
6827 fn accumulator(
6828 &self,
6829 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6830 ) -> DFResult<Box<dyn DfAccumulator>> {
6831 Ok(Box::new(CypherSumAccumulator {
6832 sum: 0.0,
6833 all_ints: true,
6834 int_sum: 0i64,
6835 has_value: false,
6836 }))
6837 }
6838 fn state_fields(
6839 &self,
6840 args: datafusion::logical_expr::function::StateFieldsArgs,
6841 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6842 Ok(vec![
6843 Arc::new(arrow::datatypes::Field::new(
6844 format!("{}_sum", args.name),
6845 DataType::Float64,
6846 true,
6847 )),
6848 Arc::new(arrow::datatypes::Field::new(
6849 format!("{}_int_sum", args.name),
6850 DataType::Int64,
6851 true,
6852 )),
6853 Arc::new(arrow::datatypes::Field::new(
6854 format!("{}_all_ints", args.name),
6855 DataType::Boolean,
6856 true,
6857 )),
6858 Arc::new(arrow::datatypes::Field::new(
6859 format!("{}_has_value", args.name),
6860 DataType::Boolean,
6861 true,
6862 )),
6863 ])
6864 }
6865}
6866
6867#[derive(Debug)]
6868struct CypherSumAccumulator {
6869 sum: f64,
6870 all_ints: bool,
6871 int_sum: i64,
6872 has_value: bool,
6873}
6874
6875impl DfAccumulator for CypherSumAccumulator {
6876 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6877 let arr = &values[0];
6878 for i in 0..arr.len() {
6879 if arr.is_null(i) {
6880 continue;
6881 }
6882 match arr.data_type() {
6883 DataType::LargeBinary => {
6884 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6885 let bytes = lb.value(i);
6886 use uni_common::cypher_value_codec::{
6887 TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag,
6888 };
6889 match peek_tag(bytes) {
6890 Some(TAG_INT) => {
6891 if let Some(v) = decode_int(bytes) {
6892 self.sum += v as f64;
6893 self.int_sum = self.int_sum.wrapping_add(v);
6894 self.has_value = true;
6895 }
6896 }
6897 Some(TAG_FLOAT) => {
6898 if let Some(v) = decode_float(bytes) {
6899 self.sum += v;
6900 self.all_ints = false;
6901 self.has_value = true;
6902 }
6903 }
6904 _ => {} }
6906 }
6907 DataType::Int64 => {
6908 let a = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6909 let v = a.value(i);
6910 self.sum += v as f64;
6911 self.int_sum = self.int_sum.wrapping_add(v);
6912 self.has_value = true;
6913 }
6914 DataType::Float64 => {
6915 let a = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6916 self.sum += a.value(i);
6917 self.all_ints = false;
6918 self.has_value = true;
6919 }
6920 _ => {}
6921 }
6922 }
6923 Ok(())
6924 }
6925 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6926 if !self.has_value {
6927 return Ok(ScalarValue::LargeBinary(None));
6928 }
6929 let val = if self.all_ints {
6930 Value::Int(self.int_sum)
6931 } else {
6932 Value::Float(self.sum)
6933 };
6934 let bytes = uni_common::cypher_value_codec::encode(&val);
6935 Ok(ScalarValue::LargeBinary(Some(bytes)))
6936 }
6937 fn size(&self) -> usize {
6938 std::mem::size_of_val(self)
6939 }
6940 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6941 Ok(vec![
6942 ScalarValue::Float64(Some(self.sum)),
6943 ScalarValue::Int64(Some(self.int_sum)),
6944 ScalarValue::Boolean(Some(self.all_ints)),
6945 ScalarValue::Boolean(Some(self.has_value)),
6946 ])
6947 }
6948 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6949 let sum_arr = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
6950 let int_sum_arr = states[1].as_any().downcast_ref::<Int64Array>().unwrap();
6951 let all_ints_arr = states[2].as_any().downcast_ref::<BooleanArray>().unwrap();
6952 let has_value_arr = states[3].as_any().downcast_ref::<BooleanArray>().unwrap();
6953 for i in 0..sum_arr.len() {
6954 if !has_value_arr.is_null(i) && has_value_arr.value(i) {
6955 self.sum += sum_arr.value(i);
6956 self.int_sum = self.int_sum.wrapping_add(int_sum_arr.value(i));
6957 if !all_ints_arr.value(i) {
6958 self.all_ints = false;
6959 }
6960 self.has_value = true;
6961 }
6962 }
6963 Ok(())
6964 }
6965}
6966
6967pub fn create_cypher_sum_udaf() -> AggregateUDF {
6968 AggregateUDF::from(CypherSumUdaf::new())
6969}
6970
6971#[derive(Debug, Clone)]
6978struct CypherCollectUdaf {
6979 signature: Signature,
6980}
6981
6982impl CypherCollectUdaf {
6983 fn new() -> Self {
6984 Self {
6985 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6986 }
6987 }
6988}
6989
6990impl PartialEq for CypherCollectUdaf {
6991 fn eq(&self, other: &Self) -> bool {
6992 self.signature == other.signature
6993 }
6994}
6995
6996impl Eq for CypherCollectUdaf {}
6997
6998impl Hash for CypherCollectUdaf {
6999 fn hash<H: Hasher>(&self, state: &mut H) {
7000 self.name().hash(state);
7001 }
7002}
7003
7004impl AggregateUDFImpl for CypherCollectUdaf {
7005 fn as_any(&self) -> &dyn Any {
7006 self
7007 }
7008 fn name(&self) -> &str {
7009 "_cypher_collect"
7010 }
7011 fn signature(&self) -> &Signature {
7012 &self.signature
7013 }
7014 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7015 Ok(DataType::LargeBinary)
7016 }
7017 fn accumulator(
7018 &self,
7019 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7020 ) -> DFResult<Box<dyn DfAccumulator>> {
7021 Ok(Box::new(CypherCollectAccumulator {
7022 values: Vec::new(),
7023 distinct: acc_args.is_distinct,
7024 }))
7025 }
7026 fn state_fields(
7027 &self,
7028 args: datafusion::logical_expr::function::StateFieldsArgs,
7029 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7030 Ok(vec![Arc::new(arrow::datatypes::Field::new(
7031 args.name,
7032 DataType::LargeBinary,
7033 true,
7034 ))])
7035 }
7036}
7037
7038#[derive(Debug)]
7039struct CypherCollectAccumulator {
7040 values: Vec<Value>,
7041 distinct: bool,
7042}
7043
7044impl DfAccumulator for CypherCollectAccumulator {
7045 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7046 let arr = &values[0];
7047 for i in 0..arr.len() {
7048 if arr.is_null(i) {
7049 continue;
7050 }
7051 if let Some(struct_arr) = arr.as_any().downcast_ref::<arrow::array::StructArray>()
7055 && struct_arr.num_columns() > 0
7056 && struct_arr.column(0).is_null(i)
7057 {
7058 continue;
7059 }
7060 let sv = ScalarValue::try_from_array(arr, i)
7061 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
7062 let val = scalar_to_value(&sv)?;
7063 if val.is_null() {
7064 continue;
7065 }
7066 if self.distinct {
7067 let repr = val.to_string();
7069 if self.values.iter().any(|v| v.to_string() == repr) {
7070 continue;
7071 }
7072 }
7073 self.values.push(val);
7074 }
7075 Ok(())
7076 }
7077 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7078 let val = Value::List(self.values.clone());
7080 let bytes = uni_common::cypher_value_codec::encode(&val);
7081 Ok(ScalarValue::LargeBinary(Some(bytes)))
7082 }
7083 fn size(&self) -> usize {
7084 std::mem::size_of_val(self) + self.values.len() * 64
7085 }
7086 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7087 Ok(vec![self.evaluate()?])
7088 }
7089 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7090 let arr = &states[0];
7092 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
7093 for i in 0..lb.len() {
7094 if lb.is_null(i) {
7095 continue;
7096 }
7097 let val = scalar_binary_to_value(lb.value(i));
7098 if let Value::List(items) = val {
7099 for item in items {
7100 if !item.is_null() {
7101 if self.distinct {
7102 let repr = item.to_string();
7103 if self.values.iter().any(|v| v.to_string() == repr) {
7104 continue;
7105 }
7106 }
7107 self.values.push(item);
7108 }
7109 }
7110 }
7111 }
7112 }
7113 Ok(())
7114 }
7115}
7116
7117pub fn create_cypher_collect_udaf() -> AggregateUDF {
7118 AggregateUDF::from(CypherCollectUdaf::new())
7119}
7120
7121pub fn create_cypher_collect_expr(
7123 arg: datafusion::logical_expr::Expr,
7124 distinct: bool,
7125) -> datafusion::logical_expr::Expr {
7126 let udaf = Arc::new(create_cypher_collect_udaf());
7129 if distinct {
7130 datafusion::logical_expr::Expr::AggregateFunction(
7132 datafusion::logical_expr::expr::AggregateFunction::new_udf(
7133 udaf,
7134 vec![arg],
7135 true, None,
7137 vec![],
7138 None,
7139 ),
7140 )
7141 } else {
7142 udaf.call(vec![arg])
7143 }
7144}
7145
7146#[derive(Debug, Clone)]
7152struct CypherPercentileDiscUdaf {
7153 signature: Signature,
7154}
7155
7156impl CypherPercentileDiscUdaf {
7157 fn new() -> Self {
7158 Self {
7159 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7160 }
7161 }
7162}
7163
7164impl PartialEq for CypherPercentileDiscUdaf {
7165 fn eq(&self, other: &Self) -> bool {
7166 self.signature == other.signature
7167 }
7168}
7169
7170impl Eq for CypherPercentileDiscUdaf {}
7171
7172impl Hash for CypherPercentileDiscUdaf {
7173 fn hash<H: Hasher>(&self, state: &mut H) {
7174 self.name().hash(state);
7175 }
7176}
7177
7178impl AggregateUDFImpl for CypherPercentileDiscUdaf {
7179 fn as_any(&self) -> &dyn Any {
7180 self
7181 }
7182 fn name(&self) -> &str {
7183 "percentiledisc"
7184 }
7185 fn signature(&self) -> &Signature {
7186 &self.signature
7187 }
7188 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7189 Ok(DataType::Float64)
7190 }
7191 fn accumulator(
7192 &self,
7193 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7194 ) -> DFResult<Box<dyn DfAccumulator>> {
7195 Ok(Box::new(CypherPercentileDiscAccumulator {
7196 values: Vec::new(),
7197 percentile: None,
7198 }))
7199 }
7200 fn state_fields(
7201 &self,
7202 args: datafusion::logical_expr::function::StateFieldsArgs,
7203 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7204 Ok(vec![
7205 Arc::new(arrow::datatypes::Field::new(
7206 format!("{}_values", args.name),
7207 DataType::List(Arc::new(arrow::datatypes::Field::new(
7208 "item",
7209 DataType::Float64,
7210 true,
7211 ))),
7212 true,
7213 )),
7214 Arc::new(arrow::datatypes::Field::new(
7215 format!("{}_percentile", args.name),
7216 DataType::Float64,
7217 true,
7218 )),
7219 ])
7220 }
7221}
7222
7223#[derive(Debug)]
7224struct CypherPercentileDiscAccumulator {
7225 values: Vec<f64>,
7226 percentile: Option<f64>,
7227}
7228
7229impl CypherPercentileDiscAccumulator {
7230 fn extract_f64(arr: &ArrayRef, i: usize) -> Option<f64> {
7231 if arr.is_null(i) {
7232 return None;
7233 }
7234 match arr.data_type() {
7235 DataType::LargeBinary => {
7236 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
7237 cv_bytes_as_f64(lb.value(i))
7238 }
7239 DataType::Int64 => {
7240 let a = arr.as_any().downcast_ref::<Int64Array>()?;
7241 Some(a.value(i) as f64)
7242 }
7243 DataType::Float64 => {
7244 let a = arr.as_any().downcast_ref::<Float64Array>()?;
7245 Some(a.value(i))
7246 }
7247 DataType::Int32 => {
7248 let a = arr.as_any().downcast_ref::<Int32Array>()?;
7249 Some(a.value(i) as f64)
7250 }
7251 DataType::Float32 => {
7252 let a = arr.as_any().downcast_ref::<Float32Array>()?;
7253 Some(a.value(i) as f64)
7254 }
7255 _ => None,
7256 }
7257 }
7258
7259 fn extract_percentile(arr: &ArrayRef, i: usize) -> Option<f64> {
7260 if arr.is_null(i) {
7261 return None;
7262 }
7263 match arr.data_type() {
7264 DataType::Float64 => {
7265 let a = arr.as_any().downcast_ref::<Float64Array>()?;
7266 Some(a.value(i))
7267 }
7268 DataType::Int64 => {
7269 let a = arr.as_any().downcast_ref::<Int64Array>()?;
7270 Some(a.value(i) as f64)
7271 }
7272 DataType::LargeBinary => {
7273 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
7274 cv_bytes_as_f64(lb.value(i))
7275 }
7276 _ => None,
7277 }
7278 }
7279}
7280
7281impl DfAccumulator for CypherPercentileDiscAccumulator {
7282 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7283 let expr_arr = &values[0];
7284 let pct_arr = &values[1];
7285 for i in 0..expr_arr.len() {
7286 if self.percentile.is_none()
7288 && let Some(p) = Self::extract_percentile(pct_arr, i)
7289 {
7290 if !(0.0..=1.0).contains(&p) {
7291 return Err(datafusion::error::DataFusionError::Execution(
7292 "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
7293 ));
7294 }
7295 self.percentile = Some(p);
7296 }
7297 if let Some(f) = Self::extract_f64(expr_arr, i) {
7298 self.values.push(f);
7299 }
7300 }
7301 Ok(())
7302 }
7303 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7304 let pct = match self.percentile {
7305 Some(p) if !(0.0..=1.0).contains(&p) => {
7306 return Err(datafusion::error::DataFusionError::Execution(
7307 "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
7308 ));
7309 }
7310 Some(p) => p,
7311 None => 0.0,
7312 };
7313 if self.values.is_empty() {
7314 return Ok(ScalarValue::Float64(None));
7315 }
7316 self.values
7317 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
7318 let n = self.values.len();
7319 let idx = (pct * (n as f64 - 1.0)).round() as usize;
7320 let idx = idx.min(n - 1);
7321 let result = self.values[idx];
7322 Ok(ScalarValue::Float64(Some(result)))
7323 }
7324 fn size(&self) -> usize {
7325 std::mem::size_of_val(self) + self.values.capacity() * 8
7326 }
7327 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7328 let list_values: Vec<ScalarValue> = self
7330 .values
7331 .iter()
7332 .map(|f| ScalarValue::Float64(Some(*f)))
7333 .collect();
7334 let list_scalar = ScalarValue::List(ScalarValue::new_list(
7335 &list_values,
7336 &DataType::Float64,
7337 true,
7338 ));
7339 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
7340 }
7341 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7342 let list_arr = &states[0];
7344 let pct_arr = &states[1];
7345 if self.percentile.is_none()
7347 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
7348 {
7349 for i in 0..f64_arr.len() {
7350 if !f64_arr.is_null(i) {
7351 self.percentile = Some(f64_arr.value(i));
7352 break;
7353 }
7354 }
7355 }
7356 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
7358 for i in 0..list_array.len() {
7359 if list_array.is_null(i) {
7360 continue;
7361 }
7362 let inner = list_array.value(i);
7363 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
7364 for j in 0..f64_arr.len() {
7365 if !f64_arr.is_null(j) {
7366 self.values.push(f64_arr.value(j));
7367 }
7368 }
7369 }
7370 }
7371 }
7372 Ok(())
7373 }
7374}
7375
7376#[derive(Debug, Clone)]
7378struct CypherPercentileContUdaf {
7379 signature: Signature,
7380}
7381
7382impl CypherPercentileContUdaf {
7383 fn new() -> Self {
7384 Self {
7385 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7386 }
7387 }
7388}
7389
7390impl PartialEq for CypherPercentileContUdaf {
7391 fn eq(&self, other: &Self) -> bool {
7392 self.signature == other.signature
7393 }
7394}
7395
7396impl Eq for CypherPercentileContUdaf {}
7397
7398impl Hash for CypherPercentileContUdaf {
7399 fn hash<H: Hasher>(&self, state: &mut H) {
7400 self.name().hash(state);
7401 }
7402}
7403
7404impl AggregateUDFImpl for CypherPercentileContUdaf {
7405 fn as_any(&self) -> &dyn Any {
7406 self
7407 }
7408 fn name(&self) -> &str {
7409 "percentilecont"
7410 }
7411 fn signature(&self) -> &Signature {
7412 &self.signature
7413 }
7414 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7415 Ok(DataType::Float64)
7416 }
7417 fn accumulator(
7418 &self,
7419 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7420 ) -> DFResult<Box<dyn DfAccumulator>> {
7421 Ok(Box::new(CypherPercentileContAccumulator {
7422 values: Vec::new(),
7423 percentile: None,
7424 }))
7425 }
7426 fn state_fields(
7427 &self,
7428 args: datafusion::logical_expr::function::StateFieldsArgs,
7429 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7430 Ok(vec![
7431 Arc::new(arrow::datatypes::Field::new(
7432 format!("{}_values", args.name),
7433 DataType::List(Arc::new(arrow::datatypes::Field::new(
7434 "item",
7435 DataType::Float64,
7436 true,
7437 ))),
7438 true,
7439 )),
7440 Arc::new(arrow::datatypes::Field::new(
7441 format!("{}_percentile", args.name),
7442 DataType::Float64,
7443 true,
7444 )),
7445 ])
7446 }
7447}
7448
7449#[derive(Debug)]
7450struct CypherPercentileContAccumulator {
7451 values: Vec<f64>,
7452 percentile: Option<f64>,
7453}
7454
7455impl DfAccumulator for CypherPercentileContAccumulator {
7456 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7457 let expr_arr = &values[0];
7458 let pct_arr = &values[1];
7459 for i in 0..expr_arr.len() {
7460 if self.percentile.is_none()
7461 && let Some(p) = CypherPercentileDiscAccumulator::extract_percentile(pct_arr, i)
7462 {
7463 if !(0.0..=1.0).contains(&p) {
7464 return Err(datafusion::error::DataFusionError::Execution(
7465 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
7466 ));
7467 }
7468 self.percentile = Some(p);
7469 }
7470 if let Some(f) = CypherPercentileDiscAccumulator::extract_f64(expr_arr, i) {
7471 self.values.push(f);
7472 }
7473 }
7474 Ok(())
7475 }
7476 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7477 let pct = match self.percentile {
7478 Some(p) if !(0.0..=1.0).contains(&p) => {
7479 return Err(datafusion::error::DataFusionError::Execution(
7480 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
7481 ));
7482 }
7483 Some(p) => p,
7484 None => 0.0,
7485 };
7486 if self.values.is_empty() {
7487 return Ok(ScalarValue::Float64(None));
7488 }
7489 self.values
7490 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
7491 let n = self.values.len();
7492 if n == 1 {
7493 return Ok(ScalarValue::Float64(Some(self.values[0])));
7494 }
7495 let pos = pct * (n as f64 - 1.0);
7496 let lower = pos.floor() as usize;
7497 let upper = pos.ceil() as usize;
7498 let lower = lower.min(n - 1);
7499 let upper = upper.min(n - 1);
7500 if lower == upper {
7501 Ok(ScalarValue::Float64(Some(self.values[lower])))
7502 } else {
7503 let frac = pos - lower as f64;
7504 let result = self.values[lower] + frac * (self.values[upper] - self.values[lower]);
7505 Ok(ScalarValue::Float64(Some(result)))
7506 }
7507 }
7508 fn size(&self) -> usize {
7509 std::mem::size_of_val(self) + self.values.capacity() * 8
7510 }
7511 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7512 let list_values: Vec<ScalarValue> = self
7513 .values
7514 .iter()
7515 .map(|f| ScalarValue::Float64(Some(*f)))
7516 .collect();
7517 let list_scalar = ScalarValue::List(ScalarValue::new_list(
7518 &list_values,
7519 &DataType::Float64,
7520 true,
7521 ));
7522 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
7523 }
7524 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7525 let list_arr = &states[0];
7526 let pct_arr = &states[1];
7527 if self.percentile.is_none()
7528 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
7529 {
7530 for i in 0..f64_arr.len() {
7531 if !f64_arr.is_null(i) {
7532 self.percentile = Some(f64_arr.value(i));
7533 break;
7534 }
7535 }
7536 }
7537 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
7538 for i in 0..list_array.len() {
7539 if list_array.is_null(i) {
7540 continue;
7541 }
7542 let inner = list_array.value(i);
7543 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
7544 for j in 0..f64_arr.len() {
7545 if !f64_arr.is_null(j) {
7546 self.values.push(f64_arr.value(j));
7547 }
7548 }
7549 }
7550 }
7551 }
7552 Ok(())
7553 }
7554}
7555
7556pub fn create_cypher_percentile_disc_udaf() -> AggregateUDF {
7557 AggregateUDF::from(CypherPercentileDiscUdaf::new())
7558}
7559
7560pub fn create_cypher_percentile_cont_udaf() -> AggregateUDF {
7561 AggregateUDF::from(CypherPercentileContUdaf::new())
7562}
7563
7564fn invoke_similarity_udf(
7574 func_name: &str,
7575 min_args: usize,
7576 args: ScalarFunctionArgs,
7577) -> DFResult<ColumnarValue> {
7578 let output_type = DataType::Float64;
7579 invoke_cypher_udf(args, &output_type, |val_args| {
7580 if val_args.len() < min_args {
7581 return Err(datafusion::error::DataFusionError::Execution(format!(
7582 "{} requires at least {} arguments",
7583 func_name, min_args
7584 )));
7585 }
7586 crate::similar_to::eval_similar_to_pure(&val_args[0], &val_args[1])
7587 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
7588 })
7589}
7590
7591pub fn create_similar_to_udf() -> ScalarUDF {
7593 ScalarUDF::new_from_impl(SimilarToUdf::new())
7594}
7595
7596#[derive(Debug)]
7597struct SimilarToUdf {
7598 signature: Signature,
7599}
7600
7601impl SimilarToUdf {
7602 fn new() -> Self {
7603 Self {
7604 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
7605 }
7606 }
7607}
7608
7609impl_udf_eq_hash!(SimilarToUdf);
7610
7611impl ScalarUDFImpl for SimilarToUdf {
7612 fn as_any(&self) -> &dyn Any {
7613 self
7614 }
7615
7616 fn name(&self) -> &str {
7617 "similar_to"
7618 }
7619
7620 fn signature(&self) -> &Signature {
7621 &self.signature
7622 }
7623
7624 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7625 Ok(DataType::Float64)
7626 }
7627
7628 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7629 invoke_similarity_udf("similar_to", 2, args)
7630 }
7631}
7632
7633pub fn create_vector_similarity_udf() -> ScalarUDF {
7635 ScalarUDF::new_from_impl(VectorSimilarityUdf::new())
7636}
7637
7638#[derive(Debug)]
7639struct VectorSimilarityUdf {
7640 signature: Signature,
7641}
7642
7643impl VectorSimilarityUdf {
7644 fn new() -> Self {
7645 Self {
7646 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7647 }
7648 }
7649}
7650
7651impl_udf_eq_hash!(VectorSimilarityUdf);
7652
7653impl ScalarUDFImpl for VectorSimilarityUdf {
7654 fn as_any(&self) -> &dyn Any {
7655 self
7656 }
7657
7658 fn name(&self) -> &str {
7659 "vector_similarity"
7660 }
7661
7662 fn signature(&self) -> &Signature {
7663 &self.signature
7664 }
7665
7666 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7667 Ok(DataType::Float64)
7668 }
7669
7670 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7671 invoke_similarity_udf("vector_similarity", 2, args)
7672 }
7673}
7674
7675#[cfg(test)]
7676mod tests {
7677 use super::*;
7678 use datafusion::execution::FunctionRegistry;
7679
7680 #[test]
7681 fn test_register_udfs() {
7682 let ctx = SessionContext::new();
7683 register_cypher_udfs(&ctx).unwrap();
7684
7685 assert!(ctx.udf("id").is_ok());
7688 assert!(ctx.udf("type").is_ok());
7689 assert!(ctx.udf("keys").is_ok());
7690 assert!(ctx.udf("range").is_ok());
7691 assert!(
7692 ctx.udf("_make_cypher_list").is_ok(),
7693 "_make_cypher_list UDF should be registered"
7694 );
7695 assert!(
7696 ctx.udf("_cv_to_bool").is_ok(),
7697 "_cv_to_bool UDF should be registered"
7698 );
7699 }
7700
7701 #[test]
7702 fn test_id_udf_signature() {
7703 let udf = create_id_udf();
7704 assert_eq!(udf.name(), "id");
7705 }
7706
7707 #[test]
7708 fn test_has_null_udf() {
7709 use datafusion::arrow::datatypes::{DataType, Field};
7710 use datafusion::config::ConfigOptions;
7711 use datafusion::scalar::ScalarValue;
7712 use std::sync::Arc;
7713
7714 let udf = create_has_null_udf();
7715
7716 let values = vec![
7718 ScalarValue::Int64(Some(1)),
7719 ScalarValue::Int64(Some(2)),
7720 ScalarValue::Int64(None),
7721 ];
7722
7723 let list_scalar = ScalarValue::List(ScalarValue::new_list(&values, &DataType::Int64, true));
7725
7726 let list_field = Arc::new(Field::new(
7727 "item",
7728 DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
7729 true,
7730 ));
7731
7732 let args = ScalarFunctionArgs {
7733 args: vec![ColumnarValue::Scalar(list_scalar)],
7734 arg_fields: vec![list_field],
7735 number_rows: 1,
7736 return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
7737 config_options: Arc::new(ConfigOptions::default()),
7738 };
7739
7740 let result = udf.invoke_with_args(args).unwrap();
7741
7742 if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
7743 assert!(b, "has_null should return true for list with null");
7744 } else {
7745 panic!("Unexpected result: {:?}", result);
7746 }
7747 }
7748
7749 fn json_to_cv_bytes(val: &serde_json::Value) -> Vec<u8> {
7755 let uni_val: uni_common::Value = val.clone().into();
7756 uni_common::cypher_value_codec::encode(&uni_val)
7757 }
7758
7759 fn make_multi_scalar_args(scalars: Vec<ScalarValue>) -> ScalarFunctionArgs {
7768 make_multi_scalar_args_with_return(scalars, DataType::LargeBinary)
7769 }
7770
7771 fn make_multi_scalar_args_with_return(
7772 scalars: Vec<ScalarValue>,
7773 return_type: DataType,
7774 ) -> ScalarFunctionArgs {
7775 use datafusion::arrow::datatypes::Field;
7776 use datafusion::config::ConfigOptions;
7777
7778 let arg_fields: Vec<_> = scalars
7779 .iter()
7780 .enumerate()
7781 .map(|(i, s)| Arc::new(Field::new(format!("arg{i}"), s.data_type(), true)))
7782 .collect();
7783 let args: Vec<_> = scalars.into_iter().map(ColumnarValue::Scalar).collect();
7784 ScalarFunctionArgs {
7785 args,
7786 arg_fields,
7787 number_rows: 1,
7788 return_field: Arc::new(Field::new("result", return_type, true)),
7789 config_options: Arc::new(ConfigOptions::default()),
7790 }
7791 }
7792
7793 fn decode_cv_scalar(cv: &ColumnarValue) -> serde_json::Value {
7795 match cv {
7796 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
7797 let val = uni_common::cypher_value_codec::decode(bytes)
7798 .expect("failed to decode CypherValue output");
7799 val.into()
7800 }
7801 other => panic!("expected LargeBinary scalar, got {other:?}"),
7802 }
7803 }
7804
/// The UDF from `create_make_cypher_list_udf` packs mixed scalar arguments
/// (integer, float, string, boolean, null) into one five-element list, in order.
#[test]
fn test_make_cypher_list_scalars() {
    let inputs = vec![
        ScalarValue::Int64(Some(1)),
        ScalarValue::Float64(Some(3.21)),
        ScalarValue::Utf8(Some("hello".to_string())),
        ScalarValue::Boolean(Some(true)),
        ScalarValue::Null,
    ];
    let output = create_make_cypher_list_udf()
        .invoke_with_args(make_multi_scalar_args(inputs))
        .unwrap();

    let json = decode_cv_scalar(&output);
    let arr = json.as_array().expect("should be array");
    assert_eq!(arr.len(), 5);

    // The four non-null inputs must come back unchanged and in argument order.
    let expected = [
        serde_json::json!(1),
        serde_json::json!(3.21),
        serde_json::json!("hello"),
        serde_json::json!(true),
    ];
    for (actual, wanted) in arr.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }

    // The trailing `ScalarValue::Null` argument must surface as a JSON null.
    assert!(arr[4].is_null());
}
7825
/// Invoking the list-building UDF with no arguments yields an empty list.
#[test]
fn test_make_cypher_list_empty() {
    let no_inputs = make_multi_scalar_args(Vec::new());
    let output = create_make_cypher_list_udf()
        .invoke_with_args(no_inputs)
        .unwrap();

    let json = decode_cv_scalar(&output);
    let arr = json.as_array().expect("should be array");
    assert!(arr.is_empty());
}
7835
/// A single integer argument produces a one-element list holding that integer.
#[test]
fn test_make_cypher_list_single() {
    let only_input = vec![ScalarValue::Int64(Some(42))];
    let output = create_make_cypher_list_udf()
        .invoke_with_args(make_multi_scalar_args(only_input))
        .unwrap();

    let json = decode_cv_scalar(&output);
    let items = json.as_array().expect("should be array");
    assert_eq!(items.len(), 1);
    assert_eq!(items[0], serde_json::json!(42));
}
7846
/// An argument that is already an encoded CypherValue list is kept as a nested
/// list element rather than being flattened into the outer list.
#[test]
fn test_make_cypher_list_nested_cypher_value() {
    // First argument: the list [1, 2] pre-encoded as CypherValue bytes.
    let nested = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let plain = ScalarValue::Int64(Some(3));

    let output = create_make_cypher_list_udf()
        .invoke_with_args(make_multi_scalar_args(vec![nested, plain]))
        .unwrap();

    let json = decode_cv_scalar(&output);
    let items = json.as_array().expect("should be array");
    assert_eq!(items.len(), 2);
    assert_eq!(items[0], serde_json::json!([1, 2]));
    assert_eq!(items[1], serde_json::json!(3));
}
7863
7864 fn make_cypher_in_args(
7870 element: &serde_json::Value,
7871 list: &serde_json::Value,
7872 ) -> ScalarFunctionArgs {
7873 make_multi_scalar_args_with_return(
7874 vec![
7875 ScalarValue::LargeBinary(Some(json_to_cv_bytes(element))),
7876 ScalarValue::LargeBinary(Some(json_to_cv_bytes(list))),
7877 ],
7878 DataType::Boolean,
7879 )
7880 }
7881
/// `3 IN [1, 2, 3]` evaluates to `true`.
#[test]
fn test_cypher_in_found() {
    let result = create_cypher_in_udf()
        .invoke_with_args(make_cypher_in_args(
            &serde_json::json!(3),
            &serde_json::json!([1, 2, 3]),
        ))
        .unwrap();

    if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
        assert!(b);
    } else {
        panic!("expected Boolean(true), got {result:?}");
    }
}
7892
/// `4 IN [1, 2, 3]` evaluates to `false` when the list contains no nulls.
#[test]
fn test_cypher_in_not_found() {
    let result = create_cypher_in_udf()
        .invoke_with_args(make_cypher_in_args(
            &serde_json::json!(4),
            &serde_json::json!([1, 2, 3]),
        ))
        .unwrap();

    if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
        assert!(!b);
    } else {
        panic!("expected Boolean(false), got {result:?}");
    }
}
7903
/// `1 IN null` evaluates to null: the UDF must return `Boolean(None)` when the
/// list operand is an encoded null.
#[test]
fn test_cypher_in_null_list() {
    // Use the shared helper so the argument layout matches the other
    // `cypher_in` tests (two LargeBinary scalars, Boolean return field).
    let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!(null));
    let result = create_cypher_in_udf().invoke_with_args(args).unwrap();

    assert!(
        matches!(result, ColumnarValue::Scalar(ScalarValue::Boolean(None))),
        "expected Boolean(None) for null list, got {result:?}"
    );
}
7920
/// `null IN [1, 2]` evaluates to null: a null element against a non-empty
/// list must produce `Boolean(None)`.
#[test]
fn test_cypher_in_null_element_nonempty() {
    // Use the shared helper so the argument layout matches the other
    // `cypher_in` tests (two LargeBinary scalars, Boolean return field).
    let args = make_cypher_in_args(&serde_json::json!(null), &serde_json::json!([1, 2]));
    let result = create_cypher_in_udf().invoke_with_args(args).unwrap();

    assert!(
        matches!(result, ColumnarValue::Scalar(ScalarValue::Boolean(None))),
        "expected Boolean(None) for null IN non-empty list, got {result:?}"
    );
}
7937
/// `null IN []` evaluates to `false`: with an empty list there is nothing the
/// null element could match, so the result is a definite `Boolean(Some(false))`.
#[test]
fn test_cypher_in_null_element_empty() {
    // Use the shared helper so the argument layout matches the other
    // `cypher_in` tests (two LargeBinary scalars, Boolean return field).
    let args = make_cypher_in_args(&serde_json::json!(null), &serde_json::json!([]));
    let result = create_cypher_in_udf().invoke_with_args(args).unwrap();

    match result {
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
        other => panic!("expected Boolean(false) for null IN [], got {other:?}"),
    }
}
7954
/// `4 IN [1, null, 3]` evaluates to null: the element is not found, but the
/// list contains a null, so the expected result is `Boolean(None)`.
#[test]
fn test_cypher_in_not_found_with_null() {
    let result = create_cypher_in_udf()
        .invoke_with_args(make_cypher_in_args(
            &serde_json::json!(4),
            &serde_json::json!([1, null, 3]),
        ))
        .unwrap();

    assert!(
        matches!(result, ColumnarValue::Scalar(ScalarValue::Boolean(None))),
        "expected Boolean(None) for 4 IN [1,null,3], got {result:?}"
    );
}
7965
/// Integer `1` is found in the float list `[1.0, 2.0]`, so membership must
/// compare numbers across the int/float boundary.
#[test]
fn test_cypher_in_cross_type_int_float() {
    let result = create_cypher_in_udf()
        .invoke_with_args(make_cypher_in_args(
            &serde_json::json!(1),
            &serde_json::json!([1.0, 2.0]),
        ))
        .unwrap();

    if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
        assert!(b);
    } else {
        panic!("expected Boolean(true) for 1 IN [1.0, 2.0], got {result:?}");
    }
}
7976
/// Concatenating `[1, 2]` and `[3, 4]` yields `[1, 2, 3, 4]`.
#[test]
fn test_list_concat_basic() {
    let left = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let right = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([3, 4]))));

    let output = create_cypher_list_concat_udf()
        .invoke_with_args(make_multi_scalar_args(vec![left, right]))
        .unwrap();

    assert_eq!(decode_cv_scalar(&output), serde_json::json!([1, 2, 3, 4]));
}
7992
/// Concatenating an empty list with `[1]` yields `[1]`.
#[test]
fn test_list_concat_empty() {
    let left = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([]))));
    let right = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1]))));

    let output = create_cypher_list_concat_udf()
        .invoke_with_args(make_multi_scalar_args(vec![left, right]))
        .unwrap();

    assert_eq!(decode_cv_scalar(&output), serde_json::json!([1]));
}
8004
/// Concatenating `null` with `[1]` yields null.
///
/// The null may come back either as a `LargeBinary(None)` scalar or as bytes
/// that decode to a null value; both are accepted.
#[test]
fn test_list_concat_null_left() {
    let left = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));
    let right = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1]))));

    let outcome = create_cypher_list_concat_udf()
        .invoke_with_args(make_multi_scalar_args(vec![left, right]))
        .unwrap();

    let bytes = match outcome {
        // An Arrow-level null needs no further checking.
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => return,
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => bytes,
        other => panic!("expected null result, got {other:?}"),
    };

    // Otherwise the payload itself must decode to a null value.
    let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
    let json: serde_json::Value = uni_val.into();
    assert!(json.is_null(), "expected null, got {json}");
}
8023
/// Concatenating `[1]` with `null` yields null.
///
/// The null may come back either as a `LargeBinary(None)` scalar or as bytes
/// that decode to a null value; both are accepted.
#[test]
fn test_list_concat_null_right() {
    let left = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1]))));
    let right = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));

    let outcome = create_cypher_list_concat_udf()
        .invoke_with_args(make_multi_scalar_args(vec![left, right]))
        .unwrap();

    let bytes = match outcome {
        // An Arrow-level null needs no further checking.
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => return,
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => bytes,
        other => panic!("expected null result, got {other:?}"),
    };

    // Otherwise the payload itself must decode to a null value.
    let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
    let json: serde_json::Value = uni_val.into();
    assert!(json.is_null(), "expected null, got {json}");
}
8042
/// With the list first and the scalar second, the scalar is appended:
/// `[1, 2]` and `3` give `[1, 2, 3]`.
#[test]
fn test_list_append_scalar() {
    let list = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let scalar = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3))));

    let output = create_cypher_list_append_udf()
        .invoke_with_args(make_multi_scalar_args(vec![list, scalar]))
        .unwrap();

    assert_eq!(decode_cv_scalar(&output), serde_json::json!([1, 2, 3]));
}
8058
/// With the scalar first and the list second, the same append UDF prepends:
/// `3` and `[1, 2]` give `[3, 1, 2]`.
#[test]
fn test_list_prepend_scalar() {
    let scalar = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3))));
    let list = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));

    let output = create_cypher_list_append_udf()
        .invoke_with_args(make_multi_scalar_args(vec![scalar, list]))
        .unwrap();

    assert_eq!(decode_cv_scalar(&output), serde_json::json!([3, 1, 2]));
}
8070
/// Appending `3` to a null first operand yields null.
///
/// The null may come back either as a `LargeBinary(None)` scalar or as bytes
/// that decode to a null value; both are accepted.
#[test]
fn test_list_append_null_list() {
    let first = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));
    let second = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3))));

    let outcome = create_cypher_list_append_udf()
        .invoke_with_args(make_multi_scalar_args(vec![first, second]))
        .unwrap();

    let bytes = match outcome {
        // An Arrow-level null needs no further checking.
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => return,
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => bytes,
        other => panic!("expected null result, got {other:?}"),
    };

    // Otherwise the payload itself must decode to a null value.
    let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
    let json: serde_json::Value = uni_val.into();
    assert!(json.is_null(), "expected null, got {json}");
}
8089
/// Appending a null second operand to `[1, 2]` yields null.
///
/// The null may come back either as a `LargeBinary(None)` scalar or as bytes
/// that decode to a null value; both are accepted.
#[test]
fn test_list_append_null_scalar() {
    let first = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let second = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));

    let outcome = create_cypher_list_append_udf()
        .invoke_with_args(make_multi_scalar_args(vec![first, second]))
        .unwrap();

    let bytes = match outcome {
        // An Arrow-level null needs no further checking.
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => return,
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => bytes,
        other => panic!("expected null result, got {other:?}"),
    };

    // Otherwise the payload itself must decode to a null value.
    let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
    let json: serde_json::Value = uni_val.into();
    assert!(json.is_null(), "expected null, got {json}");
}
8108
/// Sort keys of different value kinds must be strictly ascending in the order
/// Map, Node, Edge, List, Path, String, Bool, Temporal, Int, NaN float, Null.
#[test]
fn test_sort_key_cross_type_ordering() {
    use uni_common::core::id::{Eid, Vid};
    use uni_common::{Edge, Node, Path, TemporalValue, Value};

    // The standalone node and the path's only node share this shape.
    let sample_node = || Node {
        vid: Vid::new(1),
        labels: vec!["L".to_string()],
        properties: Default::default(),
    };

    // One representative per kind, listed in the order the keys must ascend.
    let values = vec![
        Value::Map([("a".to_string(), Value::String("map".to_string()))].into()),
        Value::Node(sample_node()),
        Value::Edge(Edge {
            eid: Eid::new(1),
            edge_type: "T".to_string(),
            src: Vid::new(1),
            dst: Vid::new(2),
            properties: Default::default(),
        }),
        Value::List(vec![Value::Int(1)]),
        Value::Path(Path {
            nodes: vec![sample_node()],
            edges: vec![],
        }),
        Value::String("hello".to_string()),
        Value::Bool(false),
        Value::Temporal(TemporalValue::Date {
            days_since_epoch: 1000,
        }),
        Value::Int(42),
        Value::Float(f64::NAN),
        Value::Null,
    ];

    let keys: Vec<Vec<u8>> = values
        .iter()
        .map(|value| encode_cypher_sort_key(value))
        .collect();

    // Every adjacent pair must compare strictly less, byte-wise.
    for (value_pair, key_pair) in values.windows(2).zip(keys.windows(2)) {
        assert!(
            key_pair[0] < key_pair[1],
            "Expected sort_key({:?}) < sort_key({:?}), but {:?} >= {:?}",
            value_pair[0],
            value_pair[1],
            key_pair[0],
            key_pair[1]
        );
    }
}
8179
/// Numeric sort keys order by numeric value across ints and floats: equal
/// int/float values share a key, infinities bracket the finite values, and NaN
/// sorts after positive infinity.
#[test]
fn test_sort_key_numbers() {
    let float_key = |x: f64| encode_cypher_sort_key(&Value::Float(x));
    let int_key = |n| encode_cypher_sort_key(&Value::Int(n));

    let neg_inf = float_key(f64::NEG_INFINITY);
    let neg_100 = float_key(-100.0);
    let neg_1 = int_key(-1);
    let zero_int = int_key(0);
    let zero_float = float_key(0.0);
    let one_int = int_key(1);
    let one_float = float_key(1.0);
    let hundred = int_key(100);
    let pos_inf = float_key(f64::INFINITY);
    let nan = float_key(f64::NAN);

    // Strictly ascending chain from -inf up to NaN.
    assert!(neg_inf < neg_100, "-inf < -100");
    assert!(neg_100 < neg_1, "-100 < -1");
    assert!(neg_1 < zero_int, "-1 < 0");
    assert!(zero_int < one_int, "0 < 1");
    assert!(one_int < hundred, "1 < 100");
    assert!(hundred < pos_inf, "100 < +inf");
    assert!(pos_inf < nan, "+inf < NaN");

    // An int and a float of the same numeric value encode identically.
    assert_eq!(zero_int, zero_float, "0 int == 0.0 float");
    assert_eq!(one_int, one_float, "1 int == 1.0 float");
}
8204
/// The sort key of `false` is strictly less than the sort key of `true`.
#[test]
fn test_sort_key_booleans() {
    let [false_key, true_key] = [false, true].map(|flag| encode_cypher_sort_key(&Value::Bool(flag)));
    assert!(false_key < true_key, "false < true");
}
8211
/// String sort keys follow the order `""`, `"a"`, `"ab"`, `"b"`: a prefix
/// sorts before its extensions, and `"ab"` sorts before `"b"`.
#[test]
fn test_sort_key_strings() {
    let key_of = |text: &str| encode_cypher_sort_key(&Value::String(text.to_string()));

    let empty = key_of("");
    let a = key_of("a");
    let ab = key_of("ab");
    let b = key_of("b");

    assert!(empty < a, "'' < 'a'");
    assert!(a < ab, "'a' < 'ab'");
    assert!(ab < b, "'ab' < 'b'");
}
8223
/// List sort keys follow the order `[]`, `[1]`, `[1, 2]`, `[2]`: a prefix list
/// sorts before its extensions, and `[1, 2]` sorts before `[2]`.
#[test]
fn test_sort_key_lists() {
    let key_of = |items: Vec<Value>| encode_cypher_sort_key(&Value::List(items));

    let empty = key_of(Vec::new());
    let one = key_of(vec![Value::Int(1)]);
    let one_two = key_of(vec![Value::Int(1), Value::Int(2)]);
    let two = key_of(vec![Value::Int(2)]);

    assert!(empty < one, "[] < [1]");
    assert!(one < one_two, "[1] < [1,2]");
    assert!(one_two < two, "[1,2] < [2]");
}
8235
/// Temporal sort keys order dates chronologically, and any `Date` sorts before
/// any `LocalTime` regardless of the values they carry.
#[test]
fn test_sort_key_temporal() {
    use uni_common::TemporalValue;

    let date_key = |days| {
        encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: days,
        }))
    };

    // Within the same variant, the earlier date has the smaller key.
    let earlier = date_key(100);
    let later = date_key(200);
    assert!(earlier < later, "earlier date < later date");

    // Across variants: the largest possible date still sorts before the
    // smallest local time shown here (midnight).
    let date = date_key(i32::MAX);
    let local_time = encode_cypher_sort_key(&Value::Temporal(TemporalValue::LocalTime {
        nanos_since_midnight: 0,
    }));
    assert!(date < local_time, "Date < LocalTime (by variant rank)");
}
8257
/// Nested lists are ordered by their inner contents: `[[1]]` sorts before `[[2]]`.
#[test]
fn test_sort_key_nested_lists() {
    let inner_a = Value::List(vec![Value::Int(1)]);
    let inner_b = Value::List(vec![Value::Int(2)]);

    // Each inner list is used exactly once, so it is moved into its outer
    // list instead of being cloned.
    let list_a = encode_cypher_sort_key(&Value::List(vec![inner_a]));
    let list_b = encode_cypher_sort_key(&Value::List(vec![inner_b]));

    assert!(list_a < list_b, "[[1]] < [[2]]");
}
8268
/// `Value::Null` encodes to the single byte `0x0A`, and an integer's key sorts
/// before it.
#[test]
fn test_sort_key_null_handling() {
    let number_key = encode_cypher_sort_key(&Value::Int(42));
    let null_key = encode_cypher_sort_key(&Value::Null);

    // The null key is exactly one byte.
    assert_eq!(null_key, vec![0x0A], "Null produces [0x0A]");

    // Numbers sort before null.
    assert!(number_key < null_key, "number < null");
}
8278
/// Strings containing an embedded NUL byte keep their relative order once
/// encoded as sort keys.
///
/// NOTE(review): despite the name, this checks ordering only; nothing is
/// decoded back. Presumably it guards the encoder's escaping of `0x00` bytes;
/// confirm against `encode_cypher_sort_key`.
#[test]
fn test_byte_stuff_roundtrip() {
    let key_of = |text: &str| encode_cypher_sort_key(&Value::String(text.to_string()));

    let k1 = key_of("a\x00b");
    let k2 = key_of("a\x00c");
    let k3 = key_of("a\x01");

    // Same prefix "a\x00": the following byte decides the order.
    assert!(k1 < k2, "a\\x00b < a\\x00c");
    // An embedded 0x00 must still sort before 0x01 at the same position.
    assert!(k1 < k3, "a\\x00b < a\\x01");
}
8295
/// `encode_order_preserving_f64` maps ascending floats to non-descending byte
/// arrays, from negative infinity through both zeros to positive infinity.
#[test]
fn test_sort_key_order_preserving_f64() {
    let vals = [f64::NEG_INFINITY, -1.0, -0.0, 0.0, 1.0, f64::INFINITY];
    let encoded: Vec<[u8; 8]> = vals
        .iter()
        .map(|&x| encode_order_preserving_f64(x))
        .collect();

    // `<=` rather than `<`: adjacent inputs (notably -0.0 and 0.0) are allowed
    // to share an encoding.
    for (val_pair, enc_pair) in vals.windows(2).zip(encoded.windows(2)) {
        assert!(
            enc_pair[0] <= enc_pair[1],
            "encode({}) should <= encode({}), got {:?} vs {:?}",
            val_pair[0],
            val_pair[1],
            enc_pair[0],
            enc_pair[1]
        );
    }
}
8316
/// `"12:35:15+05:00"` parses to a `Time` with an 18000-second offset, and
/// `nanos_since_midnight` equal to 12:35:15 expressed in nanoseconds (the
/// written clock time, with no offset adjustment applied).
#[test]
fn test_sort_key_string_as_temporal_time_with_offset() {
    let tv = sort_key_string_as_temporal("12:35:15+05:00")
        .expect("should parse Time with positive offset");

    if let uni_common::TemporalValue::Time {
        nanos_since_midnight,
        offset_seconds,
    } = tv
    {
        assert_eq!(offset_seconds, 5 * 3600, "offset should be +05:00 = 18000s");

        // 12h 35m 15s as seconds, then scaled to nanoseconds.
        let seconds_into_day: i64 = 12 * 3600 + 35 * 60 + 15;
        assert_eq!(nanos_since_midnight, seconds_into_day * 1_000_000_000);
    } else {
        panic!("expected TemporalValue::Time, got {tv:?}");
    }
}
8337
/// `"10:35:00-08:00"` parses to a `Time` with a -28800-second offset, and
/// `nanos_since_midnight` equal to 10:35:00 expressed in nanoseconds.
#[test]
fn test_sort_key_string_as_temporal_time_negative_offset() {
    let tv = sort_key_string_as_temporal("10:35:00-08:00")
        .expect("should parse Time with negative offset");

    if let uni_common::TemporalValue::Time {
        nanos_since_midnight,
        offset_seconds,
    } = tv
    {
        assert_eq!(
            offset_seconds,
            -8 * 3600,
            "offset should be -08:00 = -28800s"
        );

        // 10h 35m as seconds, then scaled to nanoseconds.
        let seconds_into_day: i64 = 10 * 3600 + 35 * 60;
        assert_eq!(nanos_since_midnight, seconds_into_day * 1_000_000_000);
    } else {
        panic!("expected TemporalValue::Time, got {tv:?}");
    }
}
8358
/// The string `"2024-01-15"` is recognised as a `Date` with a positive
/// `days_since_epoch`.
///
/// Note that this goes through `temporal_from_value` from `expr_eval`, not
/// through `sort_key_string_as_temporal` as the test name might suggest.
#[test]
fn test_sort_key_string_as_temporal_date() {
    use super::super::expr_eval::temporal_from_value;

    let date_text = Value::String("2024-01-15".into());
    let tv = temporal_from_value(&date_text).expect("should parse Date string");

    if let uni_common::TemporalValue::Date { days_since_epoch } = tv {
        assert!(days_since_epoch > 0, "2024-01-15 should be after epoch");
    } else {
        panic!("expected TemporalValue::Date, got {tv:?}");
    }
}
8372}