1use arrow::array::ArrayRef;
27use arrow::datatypes::DataType;
28use arrow_array::{
29 Array, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array,
30 LargeBinaryArray, LargeStringArray, StringArray, UInt64Array,
31};
32use chrono::Offset;
33use datafusion::error::Result as DFResult;
34use datafusion::logical_expr::{
35 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
36 Volatility,
37};
38use datafusion::prelude::SessionContext;
39use datafusion::scalar::ScalarValue;
40use std::any::Any;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43use uni_common::Value;
44use uni_cypher::ast::BinaryOp;
45use uni_store::storage::arrow_convert::values_to_array;
46
47use super::expr_eval::cypher_eq;
48
/// Implements `PartialEq`, `Eq` and `Hash` for a UDF struct.
///
/// Two instances are equal only when both their `name()` and their `signature`
/// field match; the hash covers `name()` alone. Because equal values always
/// share a name, equal values always hash identically, as `Hash` requires of an
/// `Eq` type. Comparing the signature alone would make distinct UDFs with the
/// same signature (e.g. `created_at` and `updated_at`, or two custom functions)
/// compare equal while hashing differently.
///
/// The target type needs a `signature` field implementing `PartialEq` and a
/// `name()` method whose result is `PartialEq + Hash`.
macro_rules! impl_udf_eq_hash {
    ($type:ty) => {
        impl PartialEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self.name() == other.name() && self.signature == other.signature
            }
        }

        impl Eq for $type {}

        impl Hash for $type {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.name().hash(state);
            }
        }
    };
}
69
70pub fn register_cypher_udfs(ctx: &SessionContext) -> DFResult<()> {
80 ctx.register_udf(create_id_udf());
81 ctx.register_udf(create_created_at_udf());
82 ctx.register_udf(create_updated_at_udf());
83 ctx.register_udf(create_type_udf());
84 ctx.register_udf(create_keys_udf());
85 ctx.register_udf(create_properties_udf());
86 ctx.register_udf(create_labels_udf());
87 ctx.register_udf(create_nodes_udf());
88 ctx.register_udf(create_relationships_udf());
89 ctx.register_udf(create_range_udf());
90 ctx.register_udf(create_index_udf());
91 ctx.register_udf(create_startnode_udf());
92 ctx.register_udf(create_endnode_udf());
93
94 ctx.register_udf(create_to_integer_udf());
96 ctx.register_udf(create_to_float_udf());
97 ctx.register_udf(create_to_boolean_udf());
98
99 ctx.register_udf(create_bitwise_or_udf());
101 ctx.register_udf(create_bitwise_and_udf());
102 ctx.register_udf(create_bitwise_xor_udf());
103 ctx.register_udf(create_bitwise_not_udf());
104 ctx.register_udf(create_shift_left_udf());
105 ctx.register_udf(create_shift_right_udf());
106
107 for name in &[
109 "date",
111 "time",
112 "localtime",
113 "localdatetime",
114 "datetime",
115 "duration",
116 "btic",
117 "duration.between",
119 "duration.inmonths",
120 "duration.indays",
121 "duration.inseconds",
122 "datetime.fromepoch",
123 "datetime.fromepochmillis",
124 "date.truncate",
126 "time.truncate",
127 "datetime.truncate",
128 "localdatetime.truncate",
129 "localtime.truncate",
130 "datetime.transaction",
132 "datetime.statement",
133 "datetime.realtime",
134 "date.transaction",
135 "date.statement",
136 "date.realtime",
137 "time.transaction",
138 "time.statement",
139 "time.realtime",
140 "localtime.transaction",
141 "localtime.statement",
142 "localtime.realtime",
143 "localdatetime.transaction",
144 "localdatetime.statement",
145 "localdatetime.realtime",
146 ] {
147 ctx.register_udf(create_temporal_udf(name));
148 }
149
150 ctx.register_udf(create_duration_property_udf());
152 ctx.register_udf(create_temporal_property_udf());
153 ctx.register_udf(create_tostring_udf());
154 ctx.register_udf(create_cypher_sort_key_udf());
155 ctx.register_udf(create_has_null_udf());
156 ctx.register_udf(create_cypher_size_udf());
157
158 ctx.register_udf(create_cypher_starts_with_udf());
160 ctx.register_udf(create_cypher_ends_with_udf());
161 ctx.register_udf(create_cypher_contains_udf());
162
163 ctx.register_udf(create_cypher_list_compare_udf());
165
166 ctx.register_udf(create_cypher_xor_udf());
168
169 ctx.register_udf(create_cypher_equal_udf());
171 ctx.register_udf(create_cypher_not_equal_udf());
172 ctx.register_udf(create_cypher_gt_udf());
173 ctx.register_udf(create_cypher_gt_eq_udf());
174 ctx.register_udf(create_cypher_lt_udf());
175 ctx.register_udf(create_cypher_lt_eq_udf());
176
177 ctx.register_udf(create_cv_to_bool_udf());
179
180 ctx.register_udf(create_cypher_add_udf());
182 ctx.register_udf(create_cypher_sub_udf());
183 ctx.register_udf(create_cypher_mul_udf());
184 ctx.register_udf(create_cypher_div_udf());
185 ctx.register_udf(create_cypher_mod_udf());
186
187 ctx.register_udf(create_map_project_udf());
189
190 ctx.register_udf(create_make_cypher_list_udf());
192
193 ctx.register_udf(create_cypher_in_udf());
195
196 ctx.register_udf(create_cypher_list_concat_udf());
198 ctx.register_udf(create_cypher_list_append_udf());
199 ctx.register_udf(create_cypher_list_slice_udf());
200 ctx.register_udf(create_cypher_tail_udf());
201 ctx.register_udf(create_cypher_head_udf());
202 ctx.register_udf(create_cypher_last_udf());
203 ctx.register_udf(create_cypher_reverse_udf());
204 ctx.register_udf(create_cypher_substring_udf());
205 ctx.register_udf(create_cypher_split_udf());
206 ctx.register_udf(create_cypher_list_to_cv_udf());
207 ctx.register_udf(create_cypher_scalar_to_cv_udf());
208
209 for name in &["year", "month", "day", "hour", "minute", "second"] {
211 ctx.register_udf(create_temporal_udf(name));
212 }
213
214 ctx.register_udf(create_cypher_to_float64_udf());
216
217 ctx.register_udf(create_similar_to_udf());
219 ctx.register_udf(create_vector_similarity_udf());
220
221 ctx.register_udaf(create_cypher_min_udaf());
223 ctx.register_udaf(create_cypher_max_udaf());
224 ctx.register_udaf(create_cypher_sum_udaf());
225 ctx.register_udaf(create_cypher_collect_udaf());
226
227 ctx.register_udaf(create_cypher_percentile_disc_udaf());
229 ctx.register_udaf(create_cypher_percentile_cont_udaf());
230
231 register_btic_scalar_udfs(ctx)?;
233
234 ctx.register_udaf(create_btic_min_udaf());
236 ctx.register_udaf(create_btic_max_udaf());
237 ctx.register_udaf(create_btic_span_agg_udaf());
238 ctx.register_udaf(create_btic_count_at_udaf());
239
240 Ok(())
241}
242
243pub fn register_custom_udfs(
251 ctx: &SessionContext,
252 registry: &crate::custom_functions::CustomFunctionRegistry,
253) -> DFResult<()> {
254 for (name, func) in registry.iter() {
255 let lower = name.to_lowercase();
258 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
259 lower,
260 func.clone(),
261 )));
262 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
264 name.to_string(),
265 func.clone(),
266 )));
267 }
268 Ok(())
269}
270
271struct CustomScalarUdf {
276 name: String,
277 func: crate::custom_functions::CustomScalarFn,
278 signature: Signature,
279}
280
281impl CustomScalarUdf {
282 fn new(name: String, func: crate::custom_functions::CustomScalarFn) -> Self {
283 Self {
284 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
285 name,
286 func,
287 }
288 }
289}
290
291impl std::fmt::Debug for CustomScalarUdf {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.debug_struct("CustomScalarUdf")
294 .field("name", &self.name)
295 .finish()
296 }
297}
298
299impl_udf_eq_hash!(CustomScalarUdf);
300
301impl ScalarUDFImpl for CustomScalarUdf {
302 fn as_any(&self) -> &dyn Any {
303 self
304 }
305
306 fn name(&self) -> &str {
307 &self.name
308 }
309
310 fn signature(&self) -> &Signature {
311 &self.signature
312 }
313
314 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
315 Ok(DataType::LargeBinary)
316 }
317
318 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
319 let func = &self.func;
320 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
321 func(vals).map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
322 })
323 }
324}
325
326pub fn create_id_udf() -> ScalarUDF {
334 ScalarUDF::new_from_impl(IdUdf::new())
335}
336
337#[derive(Debug)]
338struct IdUdf {
339 signature: Signature,
340}
341
342impl IdUdf {
343 fn new() -> Self {
344 Self {
345 signature: Signature::new(
346 TypeSignature::Exact(vec![DataType::UInt64]),
347 Volatility::Immutable,
348 ),
349 }
350 }
351}
352
353impl_udf_eq_hash!(IdUdf);
354
355impl ScalarUDFImpl for IdUdf {
356 fn as_any(&self) -> &dyn Any {
357 self
358 }
359
360 fn name(&self) -> &str {
361 "id"
362 }
363
364 fn signature(&self) -> &Signature {
365 &self.signature
366 }
367
368 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
369 Ok(DataType::UInt64)
370 }
371
372 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
373 if args.args.is_empty() {
375 return Err(datafusion::error::DataFusionError::Execution(
376 "id(): requires 1 argument".to_string(),
377 ));
378 }
379 Ok(args.args[0].clone())
380 }
381}
382
383pub fn create_created_at_udf() -> ScalarUDF {
396 ScalarUDF::new_from_impl(SystemTimestampUdf::new("created_at"))
397}
398
399pub fn create_updated_at_udf() -> ScalarUDF {
406 ScalarUDF::new_from_impl(SystemTimestampUdf::new("updated_at"))
407}
408
409#[derive(Debug)]
410struct SystemTimestampUdf {
411 name: &'static str,
412 signature: Signature,
413}
414
415impl SystemTimestampUdf {
416 fn new(name: &'static str) -> Self {
417 Self {
418 name,
419 signature: Signature::new(
420 TypeSignature::Exact(vec![DataType::Timestamp(
421 arrow_schema::TimeUnit::Nanosecond,
422 Some("UTC".into()),
423 )]),
424 Volatility::Immutable,
425 ),
426 }
427 }
428}
429
430impl_udf_eq_hash!(SystemTimestampUdf);
431
432impl ScalarUDFImpl for SystemTimestampUdf {
433 fn as_any(&self) -> &dyn Any {
434 self
435 }
436
437 fn name(&self) -> &str {
438 self.name
439 }
440
441 fn signature(&self) -> &Signature {
442 &self.signature
443 }
444
445 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
446 Ok(DataType::Timestamp(
447 arrow_schema::TimeUnit::Nanosecond,
448 Some("UTC".into()),
449 ))
450 }
451
452 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
453 if args.args.is_empty() {
458 return Err(datafusion::error::DataFusionError::Execution(format!(
459 "{}(): requires 1 argument",
460 self.name
461 )));
462 }
463 Ok(args.args[0].clone())
464 }
465}
466
467pub fn create_type_udf() -> ScalarUDF {
475 ScalarUDF::new_from_impl(TypeUdf::new())
476}
477
478#[derive(Debug)]
479struct TypeUdf {
480 signature: Signature,
481}
482
483impl TypeUdf {
484 fn new() -> Self {
485 Self {
486 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
490 }
491 }
492}
493
494impl_udf_eq_hash!(TypeUdf);
495
496impl ScalarUDFImpl for TypeUdf {
497 fn as_any(&self) -> &dyn Any {
498 self
499 }
500
501 fn name(&self) -> &str {
502 "type"
503 }
504
505 fn signature(&self) -> &Signature {
506 &self.signature
507 }
508
509 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
510 Ok(DataType::Utf8)
511 }
512
513 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
514 if args.args.is_empty() {
515 return Err(datafusion::error::DataFusionError::Execution(
516 "type(): requires 1 argument".to_string(),
517 ));
518 }
519 let output_type = DataType::Utf8;
520 invoke_cypher_udf(args, &output_type, |val_args| {
521 if val_args.is_empty() {
522 return Err(datafusion::error::DataFusionError::Execution(
523 "type(): requires 1 argument".to_string(),
524 ));
525 }
526 let val = &val_args[0];
527 match val {
528 Value::Map(map) => {
530 if let Some(Value::String(t)) = map.get("_type") {
531 Ok(Value::String(t.clone()))
532 } else {
533 Err(datafusion::error::DataFusionError::Execution(
535 "TypeError: InvalidArgumentValue - type() requires a relationship argument".to_string(),
536 ))
537 }
538 }
539 Value::Null => Ok(Value::Null),
540 _ => Err(datafusion::error::DataFusionError::Execution(
541 "TypeError: InvalidArgumentValue - type() requires a relationship argument"
542 .to_string(),
543 )),
544 }
545 })
546 }
547}
548
549pub fn create_keys_udf() -> ScalarUDF {
557 ScalarUDF::new_from_impl(KeysUdf::new())
558}
559
560#[derive(Debug)]
561struct KeysUdf {
562 signature: Signature,
563}
564
565impl KeysUdf {
566 fn new() -> Self {
567 Self {
568 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
569 }
570 }
571}
572
573impl_udf_eq_hash!(KeysUdf);
574
575impl ScalarUDFImpl for KeysUdf {
576 fn as_any(&self) -> &dyn Any {
577 self
578 }
579
580 fn name(&self) -> &str {
581 "keys"
582 }
583
584 fn signature(&self) -> &Signature {
585 &self.signature
586 }
587
588 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
589 Ok(DataType::List(Arc::new(
590 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
591 )))
592 }
593
594 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
595 let output_type = self.return_type(&[])?;
596 invoke_cypher_udf(args, &output_type, |val_args| {
597 if val_args.is_empty() {
598 return Err(datafusion::error::DataFusionError::Execution(
599 "keys(): requires 1 argument".to_string(),
600 ));
601 }
602
603 let arg = &val_args[0];
604 let keys = match arg {
605 Value::Map(map) => {
606 let (source, is_entity) = match map.get("_all_props") {
616 Some(Value::Map(all)) => (all, true),
617 _ => (map, false),
618 };
619 let mut key_strings: Vec<String> = source
620 .iter()
621 .filter(|(k, v)| !k.starts_with('_') && (!is_entity || !v.is_null()))
622 .map(|(k, _)| k.clone())
623 .collect();
624 key_strings.sort();
625 key_strings
626 .into_iter()
627 .map(Value::String)
628 .collect::<Vec<_>>()
629 }
630 Value::Null => {
631 return Ok(Value::Null);
632 }
633 _ => {
634 vec![]
637 }
638 };
639
640 Ok(Value::List(keys))
641 })
642 }
643}
644
645pub fn create_properties_udf() -> ScalarUDF {
650 ScalarUDF::new_from_impl(PropertiesUdf::new())
651}
652
653#[derive(Debug)]
654struct PropertiesUdf {
655 signature: Signature,
656}
657
658impl PropertiesUdf {
659 fn new() -> Self {
660 Self {
661 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
662 }
663 }
664}
665
666impl_udf_eq_hash!(PropertiesUdf);
667
668impl ScalarUDFImpl for PropertiesUdf {
669 fn as_any(&self) -> &dyn Any {
670 self
671 }
672
673 fn name(&self) -> &str {
674 "properties"
675 }
676
677 fn signature(&self) -> &Signature {
678 &self.signature
679 }
680
681 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
682 Ok(DataType::LargeBinary)
684 }
685
686 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
687 let output_type = self.return_type(&[])?;
688 invoke_cypher_udf(args, &output_type, |val_args| {
689 if val_args.is_empty() {
690 return Err(datafusion::error::DataFusionError::Execution(
691 "properties(): requires 1 argument".to_string(),
692 ));
693 }
694
695 let arg = &val_args[0];
696 match arg {
697 Value::Map(map) => {
698 let identity_null = map
708 .get("_vid")
709 .map(|v| v.is_null())
710 .or_else(|| map.get("_eid").map(|v| v.is_null()))
711 .unwrap_or(false);
712 if identity_null {
713 return Ok(Value::Null);
714 }
715
716 let source = match map.get("_all_props") {
718 Some(Value::Map(all)) => all,
719 _ => map,
720 };
721 let filtered: std::collections::HashMap<String, Value> = source
723 .iter()
724 .filter(|(k, _)| !k.starts_with('_'))
725 .map(|(k, v)| (k.clone(), v.clone()))
726 .collect();
727 Ok(Value::Map(filtered))
728 }
729 _ => Ok(Value::Null),
730 }
731 })
732 }
733}
734
735pub fn create_index_udf() -> ScalarUDF {
740 ScalarUDF::new_from_impl(IndexUdf::new())
741}
742
743#[derive(Debug)]
744struct IndexUdf {
745 signature: Signature,
746}
747
748impl IndexUdf {
749 fn new() -> Self {
750 Self {
751 signature: Signature::any(2, Volatility::Immutable),
752 }
753 }
754}
755
756impl_udf_eq_hash!(IndexUdf);
757
758impl ScalarUDFImpl for IndexUdf {
759 fn as_any(&self) -> &dyn Any {
760 self
761 }
762
763 fn name(&self) -> &str {
764 "index"
765 }
766
767 fn signature(&self) -> &Signature {
768 &self.signature
769 }
770
771 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
772 Ok(DataType::LargeBinary)
774 }
775
776 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
777 let output_type = self.return_type(&[])?;
778 invoke_cypher_udf(args, &output_type, |val_args| {
779 if val_args.len() != 2 {
780 return Err(datafusion::error::DataFusionError::Execution(
781 "index(): requires 2 arguments".to_string(),
782 ));
783 }
784
785 let container = &val_args[0];
786 let index = &val_args[1];
787
788 let index_as_int = index.as_i64();
792
793 let result = match container {
794 Value::List(arr) => {
795 if let Some(i) = index_as_int {
796 let idx = if i < 0 {
797 let pos = arr.len() as i64 + i;
798 if pos < 0 { -1 } else { pos }
799 } else {
800 i
801 };
802 if idx >= 0 && (idx as usize) < arr.len() {
803 arr[idx as usize].clone()
804 } else {
805 Value::Null
806 }
807 } else if index.is_null() {
808 Value::Null
809 } else {
810 return Err(datafusion::error::DataFusionError::Execution(format!(
811 "TypeError: InvalidArgumentType - list index must be an integer, got: {:?}",
812 index
813 )));
814 }
815 }
816 Value::Map(map) => {
817 if let Some(key) = index.as_str() {
818 if let Some(val) = map.get(key) {
820 val.clone()
821 } else if let Some(Value::Map(all_props)) = map.get("_all_props") {
822 all_props.get(key).cloned().unwrap_or(Value::Null)
824 } else if let Some(Value::Map(props)) = map.get("properties") {
825 props.get(key).cloned().unwrap_or(Value::Null)
827 } else {
828 Value::Null
829 }
830 } else if !index.is_null() {
831 return Err(datafusion::error::DataFusionError::Execution(
832 "index(): map index must be a string".to_string(),
833 ));
834 } else {
835 Value::Null
836 }
837 }
838 Value::Node(node) => {
839 if let Some(key) = index.as_str() {
840 node.properties.get(key).cloned().unwrap_or(Value::Null)
841 } else if !index.is_null() {
842 return Err(datafusion::error::DataFusionError::Execution(
843 "index(): node index must be a string".to_string(),
844 ));
845 } else {
846 Value::Null
847 }
848 }
849 Value::Edge(edge) => {
850 if let Some(key) = index.as_str() {
851 edge.properties.get(key).cloned().unwrap_or(Value::Null)
852 } else if !index.is_null() {
853 return Err(datafusion::error::DataFusionError::Execution(
854 "index(): edge index must be a string".to_string(),
855 ));
856 } else {
857 Value::Null
858 }
859 }
860 Value::Null => Value::Null,
861 _ => {
862 return Err(datafusion::error::DataFusionError::Execution(format!(
863 "TypeError: InvalidArgumentType - cannot index into {:?}",
864 container
865 )));
866 }
867 };
868
869 Ok(result)
870 })
871 }
872}
873
874pub fn create_labels_udf() -> ScalarUDF {
879 ScalarUDF::new_from_impl(LabelsUdf::new())
880}
881
882#[derive(Debug)]
883struct LabelsUdf {
884 signature: Signature,
885}
886
887impl LabelsUdf {
888 fn new() -> Self {
889 Self {
890 signature: Signature::any(1, Volatility::Immutable),
891 }
892 }
893}
894
895impl_udf_eq_hash!(LabelsUdf);
896
897impl ScalarUDFImpl for LabelsUdf {
898 fn as_any(&self) -> &dyn Any {
899 self
900 }
901
902 fn name(&self) -> &str {
903 "labels"
904 }
905
906 fn signature(&self) -> &Signature {
907 &self.signature
908 }
909
910 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
911 Ok(DataType::List(Arc::new(
912 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
913 )))
914 }
915
916 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
917 let output_type = self.return_type(&[])?;
918 invoke_cypher_udf(args, &output_type, |val_args| {
919 if val_args.is_empty() {
920 return Err(datafusion::error::DataFusionError::Execution(
921 "labels(): requires 1 argument".to_string(),
922 ));
923 }
924
925 let node = &val_args[0];
926 match node {
927 Value::Map(map) => {
928 if let Some(Value::List(arr)) = map.get("_labels") {
929 Ok(Value::List(arr.clone()))
930 } else {
931 Err(datafusion::error::DataFusionError::Execution(
933 "TypeError: InvalidArgumentValue - labels() requires a node argument"
934 .to_string(),
935 ))
936 }
937 }
938 Value::Null => Ok(Value::Null),
939 _ => Err(datafusion::error::DataFusionError::Execution(
940 "TypeError: InvalidArgumentValue - labels() requires a node argument"
941 .to_string(),
942 )),
943 }
944 })
945 }
946}
947
948pub fn create_nodes_udf() -> ScalarUDF {
953 ScalarUDF::new_from_impl(NodesUdf::new())
954}
955
956#[derive(Debug)]
957struct NodesUdf {
958 signature: Signature,
959}
960
961impl NodesUdf {
962 fn new() -> Self {
963 Self {
964 signature: Signature::any(1, Volatility::Immutable),
965 }
966 }
967}
968
969impl_udf_eq_hash!(NodesUdf);
970
971impl ScalarUDFImpl for NodesUdf {
972 fn as_any(&self) -> &dyn Any {
973 self
974 }
975
976 fn name(&self) -> &str {
977 "nodes"
978 }
979
980 fn signature(&self) -> &Signature {
981 &self.signature
982 }
983
984 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
985 Ok(DataType::LargeBinary)
986 }
987
988 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
989 let output_type = self.return_type(&[])?;
990 invoke_cypher_udf(args, &output_type, |val_args| {
991 if val_args.is_empty() {
992 return Err(datafusion::error::DataFusionError::Execution(
993 "nodes(): requires 1 argument".to_string(),
994 ));
995 }
996
997 let path = &val_args[0];
998 let nodes = match path {
999 Value::Map(map) => map.get("nodes").cloned().unwrap_or(Value::Null),
1000 _ => Value::Null,
1001 };
1002
1003 Ok(nodes)
1004 })
1005 }
1006}
1007
1008pub fn create_relationships_udf() -> ScalarUDF {
1013 ScalarUDF::new_from_impl(RelationshipsUdf::new())
1014}
1015
1016#[derive(Debug)]
1017struct RelationshipsUdf {
1018 signature: Signature,
1019}
1020
1021impl RelationshipsUdf {
1022 fn new() -> Self {
1023 Self {
1024 signature: Signature::any(1, Volatility::Immutable),
1025 }
1026 }
1027}
1028
1029impl_udf_eq_hash!(RelationshipsUdf);
1030
1031impl ScalarUDFImpl for RelationshipsUdf {
1032 fn as_any(&self) -> &dyn Any {
1033 self
1034 }
1035
1036 fn name(&self) -> &str {
1037 "relationships"
1038 }
1039
1040 fn signature(&self) -> &Signature {
1041 &self.signature
1042 }
1043
1044 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1045 Ok(DataType::LargeBinary)
1046 }
1047
1048 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1049 let output_type = self.return_type(&[])?;
1050 invoke_cypher_udf(args, &output_type, |val_args| {
1051 if val_args.is_empty() {
1052 return Err(datafusion::error::DataFusionError::Execution(
1053 "relationships(): requires 1 argument".to_string(),
1054 ));
1055 }
1056
1057 let path = &val_args[0];
1058 let rels = match path {
1059 Value::Map(map) => map.get("relationships").cloned().unwrap_or(Value::Null),
1060 _ => Value::Null,
1061 };
1062
1063 Ok(rels)
1064 })
1065 }
1066}
1067
1068pub fn create_startnode_udf() -> ScalarUDF {
1077 ScalarUDF::new_from_impl(StartNodeUdf::new())
1078}
1079
1080#[derive(Debug)]
1081struct StartNodeUdf {
1082 signature: Signature,
1083}
1084
1085impl StartNodeUdf {
1086 fn new() -> Self {
1087 Self {
1088 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1089 }
1090 }
1091}
1092
1093impl_udf_eq_hash!(StartNodeUdf);
1094
1095impl ScalarUDFImpl for StartNodeUdf {
1096 fn as_any(&self) -> &dyn Any {
1097 self
1098 }
1099
1100 fn name(&self) -> &str {
1101 "startnode"
1102 }
1103
1104 fn signature(&self) -> &Signature {
1105 &self.signature
1106 }
1107
1108 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1109 Ok(DataType::LargeBinary)
1110 }
1111
1112 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1113 let output_type = DataType::LargeBinary;
1114 invoke_cypher_udf(args, &output_type, |val_args| {
1115 startnode_endnode_impl(val_args, true)
1116 })
1117 }
1118}
1119
1120pub fn create_endnode_udf() -> ScalarUDF {
1126 ScalarUDF::new_from_impl(EndNodeUdf::new())
1127}
1128
1129#[derive(Debug)]
1130struct EndNodeUdf {
1131 signature: Signature,
1132}
1133
1134impl EndNodeUdf {
1135 fn new() -> Self {
1136 Self {
1137 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1138 }
1139 }
1140}
1141
1142impl_udf_eq_hash!(EndNodeUdf);
1143
1144impl ScalarUDFImpl for EndNodeUdf {
1145 fn as_any(&self) -> &dyn Any {
1146 self
1147 }
1148
1149 fn name(&self) -> &str {
1150 "endnode"
1151 }
1152
1153 fn signature(&self) -> &Signature {
1154 &self.signature
1155 }
1156
1157 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1158 Ok(DataType::LargeBinary)
1159 }
1160
1161 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1162 let output_type = DataType::LargeBinary;
1163 invoke_cypher_udf(args, &output_type, |val_args| {
1164 startnode_endnode_impl(val_args, false)
1165 })
1166 }
1167}
1168
1169fn startnode_endnode_impl(val_args: &[Value], is_start: bool) -> DFResult<Value> {
1174 if val_args.is_empty() {
1175 let fn_name = if is_start { "startNode" } else { "endNode" };
1176 return Err(datafusion::error::DataFusionError::Execution(format!(
1177 "{fn_name}(): requires at least 1 argument"
1178 )));
1179 }
1180
1181 let edge_val = &val_args[0];
1182 let target_vid = extract_endpoint_vid(edge_val, is_start);
1183
1184 let target_vid = match target_vid {
1185 Some(vid) => vid,
1186 None => return Ok(Value::Null),
1187 };
1188
1189 for node_val in val_args.iter().skip(1) {
1191 if let Some(vid) = extract_vid(node_val)
1192 && vid == target_vid
1193 {
1194 return Ok(node_val.clone());
1195 }
1196 }
1197
1198 let mut map = std::collections::HashMap::new();
1200 map.insert("_vid".to_string(), Value::Int(target_vid as i64));
1201 Ok(Value::Map(map))
1202}
1203
1204fn extract_endpoint_vid(val: &Value, is_start: bool) -> Option<u64> {
1206 match val {
1207 Value::Edge(edge) => {
1208 let vid = if is_start { edge.src } else { edge.dst };
1209 Some(vid.as_u64())
1210 }
1211 Value::Map(map) => {
1212 let key = if is_start { "_src_vid" } else { "_dst_vid" };
1214 if let Some(v) = map.get(key) {
1215 return v.as_u64();
1216 }
1217 let key2 = if is_start { "_src" } else { "_dst" };
1219 if let Some(v) = map.get(key2) {
1220 return v.as_u64();
1221 }
1222 let node_key = if is_start { "_startNode" } else { "_endNode" };
1224 if let Some(node_val) = map.get(node_key) {
1225 return extract_vid(node_val);
1226 }
1227 None
1228 }
1229 _ => None,
1230 }
1231}
1232
1233fn extract_vid(val: &Value) -> Option<u64> {
1235 match val {
1236 Value::Map(map) => map.get("_vid").and_then(|v| v.as_u64()),
1237 _ => None,
1238 }
1239}
1240
1241fn extract_i64_range_arg(arg: &ColumnarValue, row_idx: usize, name: &str) -> DFResult<i64> {
1248 match arg {
1249 ColumnarValue::Scalar(sv) => match sv {
1250 ScalarValue::Int8(Some(v)) => Ok(*v as i64),
1251 ScalarValue::Int16(Some(v)) => Ok(*v as i64),
1252 ScalarValue::Int32(Some(v)) => Ok(*v as i64),
1253 ScalarValue::Int64(Some(v)) => Ok(*v),
1254 ScalarValue::UInt8(Some(v)) => Ok(*v as i64),
1255 ScalarValue::UInt16(Some(v)) => Ok(*v as i64),
1256 ScalarValue::UInt32(Some(v)) => Ok(*v as i64),
1257 ScalarValue::UInt64(Some(v)) => Ok(*v as i64),
1258 ScalarValue::LargeBinary(Some(bytes)) => {
1259 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1260 datafusion::error::DataFusionError::Execution(format!(
1261 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1262 name
1263 ))
1264 })
1265 }
1266 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1267 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1268 name
1269 ))),
1270 },
1271 ColumnarValue::Array(arr) => {
1272 if row_idx >= arr.len() || arr.is_null(row_idx) {
1273 return Err(datafusion::error::DataFusionError::Execution(format!(
1274 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1275 name
1276 )));
1277 }
1278 if !arr.is_empty() {
1280 use datafusion::arrow::array::{
1281 Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array,
1282 UInt32Array, UInt64Array,
1283 };
1284 match arr.data_type() {
1285 DataType::Int8 => Ok(arr
1286 .as_any()
1287 .downcast_ref::<Int8Array>()
1288 .unwrap()
1289 .value(row_idx) as i64),
1290 DataType::Int16 => Ok(arr
1291 .as_any()
1292 .downcast_ref::<Int16Array>()
1293 .unwrap()
1294 .value(row_idx) as i64),
1295 DataType::Int32 => Ok(arr
1296 .as_any()
1297 .downcast_ref::<Int32Array>()
1298 .unwrap()
1299 .value(row_idx) as i64),
1300 DataType::Int64 => Ok(arr
1301 .as_any()
1302 .downcast_ref::<Int64Array>()
1303 .unwrap()
1304 .value(row_idx)),
1305 DataType::UInt8 => Ok(arr
1306 .as_any()
1307 .downcast_ref::<UInt8Array>()
1308 .unwrap()
1309 .value(row_idx) as i64),
1310 DataType::UInt16 => Ok(arr
1311 .as_any()
1312 .downcast_ref::<UInt16Array>()
1313 .unwrap()
1314 .value(row_idx) as i64),
1315 DataType::UInt32 => Ok(arr
1316 .as_any()
1317 .downcast_ref::<UInt32Array>()
1318 .unwrap()
1319 .value(row_idx) as i64),
1320 DataType::UInt64 => Ok(arr
1321 .as_any()
1322 .downcast_ref::<UInt64Array>()
1323 .unwrap()
1324 .value(row_idx) as i64),
1325 DataType::LargeBinary => {
1326 let bytes = arr
1327 .as_any()
1328 .downcast_ref::<LargeBinaryArray>()
1329 .unwrap()
1330 .value(row_idx);
1331 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1332 datafusion::error::DataFusionError::Execution(format!(
1333 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1334 name
1335 ))
1336 })
1337 }
1338 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1339 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1340 name
1341 ))),
1342 }
1343 } else {
1344 Err(datafusion::error::DataFusionError::Execution(format!(
1345 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1346 name
1347 )))
1348 }
1349 }
1350 }
1351}
1352
1353pub fn create_range_udf() -> ScalarUDF {
1355 ScalarUDF::new_from_impl(RangeUdf::new())
1356}
1357
1358#[derive(Debug)]
1359struct RangeUdf {
1360 signature: Signature,
1361}
1362
1363impl RangeUdf {
1364 fn new() -> Self {
1365 Self {
1366 signature: Signature::one_of(
1367 vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1368 Volatility::Immutable,
1369 ),
1370 }
1371 }
1372}
1373
1374impl_udf_eq_hash!(RangeUdf);
1375
1376impl ScalarUDFImpl for RangeUdf {
1377 fn as_any(&self) -> &dyn Any {
1378 self
1379 }
1380
1381 fn name(&self) -> &str {
1382 "range"
1383 }
1384
1385 fn signature(&self) -> &Signature {
1386 &self.signature
1387 }
1388
1389 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1390 Ok(DataType::List(Arc::new(
1391 arrow::datatypes::Field::new_list_field(DataType::Int64, true),
1392 )))
1393 }
1394
1395 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1396 if args.args.len() < 2 || args.args.len() > 3 {
1397 return Err(datafusion::error::DataFusionError::Execution(
1398 "range(): requires 2 or 3 arguments".to_string(),
1399 ));
1400 }
1401
1402 let len = args
1403 .args
1404 .iter()
1405 .find_map(|arg| match arg {
1406 ColumnarValue::Array(arr) => Some(arr.len()),
1407 _ => None,
1408 })
1409 .unwrap_or(1);
1410
1411 let mut list_builder =
1412 arrow_array::builder::ListBuilder::new(arrow_array::builder::Int64Builder::new());
1413
1414 for row_idx in 0..len {
1415 let start = extract_i64_range_arg(&args.args[0], row_idx, "start")?;
1416 let end = extract_i64_range_arg(&args.args[1], row_idx, "end")?;
1417 let step = if args.args.len() == 3 {
1418 extract_i64_range_arg(&args.args[2], row_idx, "step")?
1419 } else {
1420 1
1421 };
1422
1423 if step == 0 {
1424 return Err(datafusion::error::DataFusionError::Execution(
1425 "range(): step cannot be zero".to_string(),
1426 ));
1427 }
1428
1429 if step > 0 && start <= end {
1430 let mut current = start;
1431 while current <= end {
1432 list_builder.values().append_value(current);
1433 current += step;
1434 }
1435 } else if step < 0 && start >= end {
1436 let mut current = start;
1437 while current >= end {
1438 list_builder.values().append_value(current);
1439 current += step;
1440 }
1441 }
1442 list_builder.append(true);
1444 }
1445
1446 let list_arr = Arc::new(list_builder.finish()) as ArrayRef;
1447 if len == 1
1448 && args
1449 .args
1450 .iter()
1451 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
1452 {
1453 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
1454 &list_arr, 0,
1455 )?))
1456 } else {
1457 Ok(ColumnarValue::Array(list_arr))
1458 }
1459 }
1460}
1461
1462fn invoke_binary_bitwise_op<F>(
1470 args: &ScalarFunctionArgs,
1471 name: &str,
1472 op: F,
1473) -> DFResult<ColumnarValue>
1474where
1475 F: Fn(i64, i64) -> i64,
1476{
1477 use arrow_array::Int64Array;
1478 use datafusion::common::ScalarValue;
1479 use datafusion::error::DataFusionError;
1480
1481 if args.args.len() != 2 {
1482 return Err(DataFusionError::Execution(format!(
1483 "{}(): requires exactly 2 arguments",
1484 name
1485 )));
1486 }
1487
1488 let left = &args.args[0];
1489 let right = &args.args[1];
1490
1491 match (left, right) {
1492 (
1493 ColumnarValue::Scalar(ScalarValue::Int64(Some(l))),
1494 ColumnarValue::Scalar(ScalarValue::Int64(Some(r))),
1495 ) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*l, *r))))),
1496 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
1497 let l_arr = l_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1498 DataFusionError::Execution(format!("{}(): left array must be Int64", name))
1499 })?;
1500 let r_arr = r_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1501 DataFusionError::Execution(format!("{}(): right array must be Int64", name))
1502 })?;
1503
1504 let result: Int64Array = l_arr
1505 .iter()
1506 .zip(r_arr.iter())
1507 .map(|(l, r)| match (l, r) {
1508 (Some(l), Some(r)) => Some(op(l, r)),
1509 _ => None,
1510 })
1511 .collect();
1512
1513 Ok(ColumnarValue::Array(Arc::new(result)))
1514 }
1515 _ => Err(DataFusionError::Execution(format!(
1516 "{}(): mixed scalar/array not supported",
1517 name
1518 ))),
1519 }
1520}
1521
1522fn invoke_unary_bitwise_op<F>(
1526 args: &ScalarFunctionArgs,
1527 name: &str,
1528 op: F,
1529) -> DFResult<ColumnarValue>
1530where
1531 F: Fn(i64) -> i64,
1532{
1533 use arrow_array::Int64Array;
1534 use datafusion::common::ScalarValue;
1535 use datafusion::error::DataFusionError;
1536
1537 if args.args.len() != 1 {
1538 return Err(DataFusionError::Execution(format!(
1539 "{}(): requires exactly 1 argument",
1540 name
1541 )));
1542 }
1543
1544 let operand = &args.args[0];
1545
1546 match operand {
1547 ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
1548 Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*v)))))
1549 }
1550 ColumnarValue::Array(arr) => {
1551 let arr = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1552 DataFusionError::Execution(format!("{}(): array must be Int64", name))
1553 })?;
1554
1555 let result: Int64Array = arr.iter().map(|v| v.map(&op)).collect();
1556
1557 Ok(ColumnarValue::Array(Arc::new(result)))
1558 }
1559 _ => Err(DataFusionError::Execution(format!(
1560 "{}(): invalid argument type",
1561 name
1562 ))),
1563 }
1564}
1565
/// Defines a struct `$struct_name` implementing `ScalarUDFImpl` for a bitwise
/// function of two Int64 operands, registered as `$udf_name`.
///
/// The generated UDF returns Int64 and delegates evaluation to
/// `invoke_binary_bitwise_op`, which applies `$op` to each pair of values.
macro_rules! define_binary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                // Both operands must already be Int64.
                let operand_types = vec![DataType::Int64; 2];
                Self {
                    signature: Signature::exact(operand_types, Volatility::Immutable),
                }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self as &dyn Any
            }

            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_binary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1612
/// Defines a struct `$struct_name` implementing `ScalarUDFImpl` for a bitwise
/// function of one Int64 operand, registered as `$udf_name`.
///
/// The generated UDF returns Int64 and delegates evaluation to
/// `invoke_unary_bitwise_op`, which applies `$op` to each value.
macro_rules! define_unary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                let operand_types = vec![DataType::Int64];
                let signature = Signature::exact(operand_types, Volatility::Immutable);
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self as &dyn Any
            }

            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_unary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1656
1657define_binary_bitwise_udf!(BitwiseOrUdf, "uni.bitwise.or", |l, r| l | r);
1659define_binary_bitwise_udf!(BitwiseAndUdf, "uni.bitwise.and", |l, r| l & r);
1660define_binary_bitwise_udf!(BitwiseXorUdf, "uni.bitwise.xor", |l, r| l ^ r);
1661define_binary_bitwise_udf!(ShiftLeftUdf, "uni.bitwise.shiftLeft", |l, r| l << r);
1662define_binary_bitwise_udf!(ShiftRightUdf, "uni.bitwise.shiftRight", |l, r| l >> r);
1663
1664define_unary_bitwise_udf!(BitwiseNotUdf, "uni.bitwise.not", |v| !v);
1666
1667pub fn create_bitwise_or_udf() -> ScalarUDF {
1669 ScalarUDF::new_from_impl(BitwiseOrUdf::new())
1670}
1671
1672pub fn create_bitwise_and_udf() -> ScalarUDF {
1674 ScalarUDF::new_from_impl(BitwiseAndUdf::new())
1675}
1676
1677pub fn create_bitwise_xor_udf() -> ScalarUDF {
1679 ScalarUDF::new_from_impl(BitwiseXorUdf::new())
1680}
1681
1682pub fn create_bitwise_not_udf() -> ScalarUDF {
1684 ScalarUDF::new_from_impl(BitwiseNotUdf::new())
1685}
1686
1687pub fn create_shift_left_udf() -> ScalarUDF {
1689 ScalarUDF::new_from_impl(ShiftLeftUdf::new())
1690}
1691
1692pub fn create_shift_right_udf() -> ScalarUDF {
1694 ScalarUDF::new_from_impl(ShiftRightUdf::new())
1695}
1696
1697fn create_temporal_udf(name: &str) -> ScalarUDF {
1708 ScalarUDF::new_from_impl(TemporalUdf::new(name.to_string()))
1709}
1710
1711#[derive(Debug)]
1712struct TemporalUdf {
1713 name: String,
1714 signature: Signature,
1715}
1716
1717impl TemporalUdf {
1718 fn new(name: String) -> Self {
1719 Self {
1720 name,
1721 signature: Signature::new(
1724 TypeSignature::OneOf(vec![
1725 TypeSignature::Exact(vec![]),
1726 TypeSignature::VariadicAny,
1727 ]),
1728 Volatility::Immutable,
1729 ),
1730 }
1731 }
1732}
1733
1734impl_udf_eq_hash!(TemporalUdf);
1735
1736impl ScalarUDFImpl for TemporalUdf {
1737 fn as_any(&self) -> &dyn Any {
1738 self
1739 }
1740
1741 fn name(&self) -> &str {
1742 &self.name
1743 }
1744
1745 fn signature(&self) -> &Signature {
1746 &self.signature
1747 }
1748
1749 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1750 let name = self.name.to_lowercase();
1751 match name.as_str() {
1752 "year" | "month" | "day" | "hour" | "minute" | "second" => Ok(DataType::Int64),
1754 "datetime"
1760 | "localdatetime"
1761 | "date"
1762 | "time"
1763 | "localtime"
1764 | "duration"
1765 | "date.truncate"
1766 | "time.truncate"
1767 | "datetime.truncate"
1768 | "localdatetime.truncate"
1769 | "localtime.truncate"
1770 | "duration.between"
1771 | "duration.inmonths"
1772 | "duration.indays"
1773 | "duration.inseconds"
1774 | "datetime.fromepoch"
1775 | "datetime.fromepochmillis"
1776 | "datetime.transaction"
1777 | "datetime.statement"
1778 | "datetime.realtime"
1779 | "date.transaction"
1780 | "date.statement"
1781 | "date.realtime"
1782 | "time.transaction"
1783 | "time.statement"
1784 | "time.realtime"
1785 | "localtime.transaction"
1786 | "localtime.statement"
1787 | "localtime.realtime"
1788 | "localdatetime.transaction"
1789 | "localdatetime.statement"
1790 | "localdatetime.realtime"
1791 | "btic" => Ok(DataType::LargeBinary),
1792 _ => Ok(DataType::Utf8),
1793 }
1794 }
1795
1796 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1797 let func_name = self.name.to_uppercase();
1798 let output_type = self.return_type(&[])?;
1799 invoke_cypher_udf(args, &output_type, |val_args| {
1800 crate::datetime::eval_datetime_function(&func_name, val_args).map_err(|e| {
1801 datafusion::error::DataFusionError::Execution(format!("{}(): {}", self.name, e))
1802 })
1803 })
1804 }
1805}
1806
1807fn create_duration_property_udf() -> ScalarUDF {
1812 ScalarUDF::new_from_impl(DurationPropertyUdf::new())
1813}
1814
1815#[derive(Debug)]
1816struct DurationPropertyUdf {
1817 signature: Signature,
1818}
1819
1820impl DurationPropertyUdf {
1821 fn new() -> Self {
1822 Self {
1823 signature: Signature::new(
1824 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
1825 Volatility::Immutable,
1826 ),
1827 }
1828 }
1829}
1830
1831impl_udf_eq_hash!(DurationPropertyUdf);
1832
1833impl ScalarUDFImpl for DurationPropertyUdf {
1834 fn as_any(&self) -> &dyn Any {
1835 self
1836 }
1837
1838 fn name(&self) -> &str {
1839 "_duration_property"
1840 }
1841
1842 fn signature(&self) -> &Signature {
1843 &self.signature
1844 }
1845
1846 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1847 Ok(DataType::Int64)
1848 }
1849
1850 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1851 let output_type = self.return_type(&[])?;
1852 invoke_cypher_udf(args, &output_type, |val_args| {
1853 if val_args.len() != 2 {
1854 return Err(datafusion::error::DataFusionError::Execution(
1855 "_duration_property(): requires 2 arguments (duration_string, component)"
1856 .to_string(),
1857 ));
1858 }
1859
1860 let dur_string_owned;
1861 let dur_str = match &val_args[0] {
1862 Value::String(s) => s.as_str(),
1863 Value::Temporal(uni_common::TemporalValue::Duration { .. }) => {
1864 dur_string_owned = val_args[0].to_string();
1865 &dur_string_owned
1866 }
1867 Value::Null => return Ok(Value::Null),
1868 _ => {
1869 return Err(datafusion::error::DataFusionError::Execution(
1870 "_duration_property(): duration must be a string or temporal duration"
1871 .to_string(),
1872 ));
1873 }
1874 };
1875 let component = match &val_args[1] {
1876 Value::String(s) => s,
1877 _ => {
1878 return Err(datafusion::error::DataFusionError::Execution(
1879 "_duration_property(): component must be a string".to_string(),
1880 ));
1881 }
1882 };
1883
1884 crate::datetime::eval_duration_accessor(dur_str, component).map_err(|e| {
1885 datafusion::error::DataFusionError::Execution(format!(
1886 "_duration_property(): {}",
1887 e
1888 ))
1889 })
1890 })
1891 }
1892}
1893
1894fn create_tostring_udf() -> ScalarUDF {
1899 ScalarUDF::new_from_impl(ToStringUdf::new())
1900}
1901
1902#[derive(Debug)]
1903struct ToStringUdf {
1904 signature: Signature,
1905}
1906
1907impl ToStringUdf {
1908 fn new() -> Self {
1909 Self {
1910 signature: Signature::variadic_any(Volatility::Immutable),
1911 }
1912 }
1913}
1914
1915impl_udf_eq_hash!(ToStringUdf);
1916
1917impl ScalarUDFImpl for ToStringUdf {
1918 fn as_any(&self) -> &dyn Any {
1919 self
1920 }
1921
1922 fn name(&self) -> &str {
1923 "tostring"
1924 }
1925
1926 fn signature(&self) -> &Signature {
1927 &self.signature
1928 }
1929
1930 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1931 Ok(DataType::Utf8)
1932 }
1933
1934 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1935 let output_type = self.return_type(&[])?;
1936 invoke_cypher_udf(args, &output_type, |val_args| {
1937 if val_args.is_empty() {
1938 return Err(datafusion::error::DataFusionError::Execution(
1939 "toString(): requires 1 argument".to_string(),
1940 ));
1941 }
1942 match &val_args[0] {
1943 Value::Null => Ok(Value::Null),
1944 Value::String(s) => Ok(Value::String(s.clone())),
1945 Value::Int(i) => Ok(Value::String(i.to_string())),
1946 Value::Float(f) => Ok(Value::String(f.to_string())),
1947 Value::Bool(b) => Ok(Value::String(b.to_string())),
1948 Value::Temporal(t) => Ok(Value::String(t.to_string())),
1949 other => {
1950 let type_name = match other {
1951 Value::List(_) => "List",
1952 Value::Map(_) => "Map",
1953 Value::Node { .. } => "Node",
1954 Value::Edge { .. } => "Relationship",
1955 Value::Path { .. } => "Path",
1956 _ => "Unknown",
1957 };
1958 Err(datafusion::error::DataFusionError::Execution(format!(
1959 "TypeError: InvalidArgumentValue - toString() does not accept {} values",
1960 type_name
1961 )))
1962 }
1963 }
1964 })
1965 }
1966}
1967
1968fn create_temporal_property_udf() -> ScalarUDF {
1973 ScalarUDF::new_from_impl(TemporalPropertyUdf::new())
1974}
1975
1976#[derive(Debug)]
1977struct TemporalPropertyUdf {
1978 signature: Signature,
1979}
1980
1981impl TemporalPropertyUdf {
1982 fn new() -> Self {
1983 Self {
1984 signature: Signature::variadic_any(Volatility::Immutable),
1985 }
1986 }
1987}
1988
1989impl_udf_eq_hash!(TemporalPropertyUdf);
1990
1991impl ScalarUDFImpl for TemporalPropertyUdf {
1992 fn as_any(&self) -> &dyn Any {
1993 self
1994 }
1995
1996 fn name(&self) -> &str {
1997 "_temporal_property"
1998 }
1999
2000 fn signature(&self) -> &Signature {
2001 &self.signature
2002 }
2003
2004 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2005 Ok(DataType::LargeBinary)
2006 }
2007
2008 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2009 let output_type = self.return_type(&[])?;
2010 invoke_cypher_udf(args, &output_type, |val_args| {
2011 if val_args.len() != 2 {
2012 return Err(datafusion::error::DataFusionError::Execution(
2013 "_temporal_property(): requires 2 arguments (temporal_value, component)"
2014 .to_string(),
2015 ));
2016 }
2017
2018 let component = match &val_args[1] {
2019 Value::String(s) => s.clone(),
2020 _ => {
2021 return Err(datafusion::error::DataFusionError::Execution(
2022 "_temporal_property(): component must be a string".to_string(),
2023 ));
2024 }
2025 };
2026
2027 crate::datetime::eval_temporal_accessor_value(&val_args[0], &component).map_err(|e| {
2028 datafusion::error::DataFusionError::Execution(format!(
2029 "_temporal_property(): {}",
2030 e
2031 ))
2032 })
2033 })
2034 }
2035}
2036
/// Downcasts `$arr` (anything exposing `as_any()`) to `$array_type`.
///
/// On a type mismatch it early-returns, via `?`, an `Execution` error naming
/// the expected array type from the enclosing function.
macro_rules! downcast_arr {
    ($arr:expr, $array_type:ty) => {
        match $arr.as_any().downcast_ref::<$array_type>() {
            Some(typed) => typed,
            None => Err(datafusion::error::DataFusionError::Execution(format!(
                "Failed to downcast to {}",
                stringify!($array_type)
            )))?,
        }
    };
}
2049
2050fn cypher_type_name(val: &Value) -> &'static str {
2052 match val {
2053 Value::Null => "Null",
2054 Value::Bool(_) => "Boolean",
2055 Value::Int(_) => "Integer",
2056 Value::Float(_) => "Float",
2057 Value::String(_) => "String",
2058 Value::Bytes(_) => "Bytes",
2059 Value::List(_) => "List",
2060 Value::Map(_) => "Map",
2061 Value::Node(_) => "Node",
2062 Value::Edge(_) => "Relationship",
2063 Value::Path(_) => "Path",
2064 Value::Vector(_) => "Vector",
2065 Value::Temporal(_) => "Temporal",
2066 _ => "Unknown",
2067 }
2068}
2069
2070fn string_to_value(s: &str) -> Value {
2072 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2073 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2074 {
2075 return Value::from(obj);
2076 }
2077 Value::String(s.to_string())
2078}
2079
2080fn get_value_from_array(arr: &ArrayRef, row: usize) -> DFResult<Value> {
2086 if arr.is_null(row) {
2087 return Ok(Value::Null);
2088 }
2089
2090 match arr.data_type() {
2091 DataType::LargeBinary => {
2092 let typed = downcast_arr!(arr, LargeBinaryArray);
2093 let bytes = typed.value(row);
2094 if let Ok(val) = uni_common::cypher_value_codec::decode(bytes) {
2095 return Ok(val);
2096 }
2097 Ok(serde_json::from_slice::<serde_json::Value>(bytes)
2099 .map(Value::from)
2100 .unwrap_or(Value::Null))
2101 }
2102 DataType::Int64 => Ok(Value::Int(downcast_arr!(arr, Int64Array).value(row))),
2103 DataType::Float64 => Ok(Value::Float(downcast_arr!(arr, Float64Array).value(row))),
2104 DataType::Utf8 => Ok(string_to_value(downcast_arr!(arr, StringArray).value(row))),
2105 DataType::LargeUtf8 => Ok(string_to_value(
2106 downcast_arr!(arr, LargeStringArray).value(row),
2107 )),
2108 DataType::Boolean => Ok(Value::Bool(downcast_arr!(arr, BooleanArray).value(row))),
2109 DataType::UInt64 => Ok(Value::Int(downcast_arr!(arr, UInt64Array).value(row) as i64)),
2110 DataType::Int32 => Ok(Value::Int(downcast_arr!(arr, Int32Array).value(row) as i64)),
2111 DataType::Float32 => Ok(Value::Float(
2112 downcast_arr!(arr, Float32Array).value(row) as f64
2113 )),
2114 _ => {
2117 let scalar = ScalarValue::try_from_array(arr, row).map_err(|e| {
2118 datafusion::error::DataFusionError::Execution(format!(
2119 "Cannot extract scalar from array at row {}: {}",
2120 row, e
2121 ))
2122 })?;
2123 scalar_to_value(&scalar)
2124 }
2125 }
2126}
2127
2128fn get_value_args_for_row(args: &[ColumnarValue], row: usize) -> DFResult<Vec<Value>> {
2130 args.iter()
2131 .map(|arg| match arg {
2132 ColumnarValue::Scalar(scalar) => scalar_to_value(scalar),
2133 ColumnarValue::Array(arr) => get_value_from_array(arr, row),
2134 })
2135 .collect()
2136}
2137
2138fn invoke_cypher_udf<F>(
2140 args: ScalarFunctionArgs,
2141 output_type: &DataType,
2142 f: F,
2143) -> DFResult<ColumnarValue>
2144where
2145 F: Fn(&[Value]) -> DFResult<Value>,
2146{
2147 let len = args
2148 .args
2149 .iter()
2150 .find_map(|arg| match arg {
2151 ColumnarValue::Array(arr) => Some(arr.len()),
2152 _ => None,
2153 })
2154 .unwrap_or(1);
2155
2156 if len == 1
2157 && args
2158 .args
2159 .iter()
2160 .all(|a| matches!(a, ColumnarValue::Scalar(_)))
2161 {
2162 let row_args = get_value_args_for_row(&args.args, 0)?;
2163 let res = f(&row_args)?;
2164 if matches!(output_type, DataType::LargeBinary | DataType::List(_)) {
2165 let arr = values_to_array(&[res], output_type)
2167 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2168 return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, 0)?));
2169 }
2170 if res.is_null() {
2172 let typed_null = ScalarValue::try_from(output_type).unwrap_or(ScalarValue::Utf8(None));
2173 return Ok(ColumnarValue::Scalar(typed_null));
2174 }
2175 return value_to_columnar(&res);
2176 }
2177
2178 let mut results = Vec::with_capacity(len);
2179 for i in 0..len {
2180 let row_args = get_value_args_for_row(&args.args, i)?;
2181 results.push(f(&row_args)?);
2182 }
2183
2184 let arr = values_to_array(&results, output_type)
2185 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2186 Ok(ColumnarValue::Array(arr))
2187}
2188
2189fn scalar_arr_to_value(arr: &dyn arrow::array::Array) -> DFResult<Value> {
2192 if arr.is_empty() || arr.is_null(0) {
2193 Ok(Value::Null)
2194 } else {
2195 Ok(uni_store::storage::arrow_convert::arrow_to_value(
2197 arr, 0, None,
2198 ))
2199 }
2200}
2201
2202fn resolve_timezone_offset(tz_name: &str, nanos_utc: i64) -> i32 {
2204 if tz_name == "UTC" || tz_name == "Z" {
2205 return 0;
2206 }
2207 if let Ok(tz) = tz_name.parse::<chrono_tz::Tz>() {
2208 let dt = chrono::DateTime::from_timestamp_nanos(nanos_utc).with_timezone(&tz);
2209 dt.offset().fix().local_minus_utc()
2210 } else {
2211 0
2212 }
2213}
2214
2215fn duration_micros_to_value(micros: i64) -> Value {
2217 let dur = crate::datetime::CypherDuration::from_micros(micros);
2218 Value::Temporal(uni_common::TemporalValue::Duration {
2219 months: dur.months,
2220 days: dur.days,
2221 nanos: dur.nanos,
2222 })
2223}
2224
2225fn timestamp_nanos_to_value(nanos: i64, tz: Option<&Arc<str>>) -> DFResult<Value> {
2227 if let Some(tz_str) = tz {
2228 let offset = resolve_timezone_offset(tz_str.as_ref(), nanos);
2229 let tz_name = if tz_str.as_ref() == "UTC" {
2230 None
2231 } else {
2232 Some(tz_str.to_string())
2233 };
2234 Ok(Value::Temporal(uni_common::TemporalValue::DateTime {
2235 nanos_since_epoch: nanos,
2236 offset_seconds: offset,
2237 timezone_name: tz_name,
2238 }))
2239 } else {
2240 Ok(Value::Temporal(uni_common::TemporalValue::LocalDateTime {
2241 nanos_since_epoch: nanos,
2242 }))
2243 }
2244}
2245
2246pub fn scalar_to_value(scalar: &ScalarValue) -> DFResult<Value> {
2248 match scalar {
2249 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
2250 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2253 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2254 {
2255 return Ok(Value::from(obj));
2256 }
2257 Ok(Value::String(s.clone()))
2258 }
2259 ScalarValue::LargeBinary(Some(b)) => {
2260 if let Ok(val) = uni_common::cypher_value_codec::decode(b) {
2263 return Ok(val);
2264 }
2265 if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(b) {
2266 Ok(Value::from(obj))
2267 } else {
2268 Ok(Value::Null)
2269 }
2270 }
2271 ScalarValue::Int64(Some(i)) => Ok(Value::Int(*i)),
2272 ScalarValue::Int32(Some(i)) => Ok(Value::Int(*i as i64)),
2273 ScalarValue::Float64(Some(f)) => {
2274 Ok(Value::Float(*f))
2276 }
2277 ScalarValue::Boolean(Some(b)) => Ok(Value::Bool(*b)),
2278 ScalarValue::Struct(arr) => scalar_arr_to_value(arr.as_ref()),
2279 ScalarValue::List(arr) => scalar_arr_to_value(arr.as_ref()),
2280 ScalarValue::LargeList(arr) => scalar_arr_to_value(arr.as_ref()),
2281 ScalarValue::FixedSizeList(arr) => scalar_arr_to_value(arr.as_ref()),
2282 ScalarValue::UInt64(Some(u)) => Ok(Value::Int(*u as i64)),
2284 ScalarValue::UInt32(Some(u)) => Ok(Value::Int(*u as i64)),
2285 ScalarValue::UInt16(Some(u)) => Ok(Value::Int(*u as i64)),
2286 ScalarValue::UInt8(Some(u)) => Ok(Value::Int(*u as i64)),
2287 ScalarValue::Int16(Some(i)) => Ok(Value::Int(*i as i64)),
2288 ScalarValue::Int8(Some(i)) => Ok(Value::Int(*i as i64)),
2289
2290 ScalarValue::Date32(Some(days)) => Ok(Value::Temporal(uni_common::TemporalValue::Date {
2292 days_since_epoch: *days,
2293 })),
2294 ScalarValue::Date64(Some(millis)) => {
2295 let days = (*millis / 86_400_000) as i32;
2296 Ok(Value::Temporal(uni_common::TemporalValue::Date {
2297 days_since_epoch: days,
2298 }))
2299 }
2300 ScalarValue::TimestampNanosecond(Some(nanos), tz) => {
2301 timestamp_nanos_to_value(*nanos, tz.as_ref())
2302 }
2303 ScalarValue::TimestampMicrosecond(Some(micros), tz) => {
2304 timestamp_nanos_to_value(*micros * 1_000, tz.as_ref())
2305 }
2306 ScalarValue::TimestampMillisecond(Some(millis), tz) => {
2307 timestamp_nanos_to_value(*millis * 1_000_000, tz.as_ref())
2308 }
2309 ScalarValue::TimestampSecond(Some(secs), tz) => {
2310 timestamp_nanos_to_value(*secs * 1_000_000_000, tz.as_ref())
2311 }
2312 ScalarValue::Time64Nanosecond(Some(nanos)) => {
2313 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2314 nanos_since_midnight: *nanos,
2315 }))
2316 }
2317 ScalarValue::Time64Microsecond(Some(micros)) => {
2318 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2319 nanos_since_midnight: *micros * 1_000,
2320 }))
2321 }
2322 ScalarValue::IntervalMonthDayNano(Some(v)) => {
2323 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2324 months: v.months as i64,
2325 days: v.days as i64,
2326 nanos: v.nanoseconds,
2327 }))
2328 }
2329 ScalarValue::DurationMicrosecond(Some(micros)) => Ok(duration_micros_to_value(*micros)),
2330 ScalarValue::DurationMillisecond(Some(millis)) => {
2331 Ok(duration_micros_to_value(*millis * 1_000))
2332 }
2333 ScalarValue::DurationSecond(Some(secs)) => Ok(duration_micros_to_value(*secs * 1_000_000)),
2334 ScalarValue::DurationNanosecond(Some(nanos)) => {
2335 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2336 months: 0,
2337 days: 0,
2338 nanos: *nanos,
2339 }))
2340 }
2341 ScalarValue::Float32(Some(f)) => Ok(Value::Float(*f as f64)),
2342
2343 ScalarValue::FixedSizeBinary(24, Some(bytes)) => {
2345 match uni_btic::encode::decode_slice(bytes) {
2346 Ok(btic) => Ok(Value::Temporal(uni_common::TemporalValue::Btic {
2347 lo: btic.lo(),
2348 hi: btic.hi(),
2349 meta: btic.meta(),
2350 })),
2351 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
2352 "BTIC decode error: {e}"
2353 ))),
2354 }
2355 }
2356
2357 ScalarValue::Null
2359 | ScalarValue::Utf8(None)
2360 | ScalarValue::LargeUtf8(None)
2361 | ScalarValue::LargeBinary(None)
2362 | ScalarValue::Int64(None)
2363 | ScalarValue::Int32(None)
2364 | ScalarValue::Int16(None)
2365 | ScalarValue::Int8(None)
2366 | ScalarValue::UInt64(None)
2367 | ScalarValue::UInt32(None)
2368 | ScalarValue::UInt16(None)
2369 | ScalarValue::UInt8(None)
2370 | ScalarValue::Float64(None)
2371 | ScalarValue::Float32(None)
2372 | ScalarValue::Boolean(None)
2373 | ScalarValue::Date32(None)
2374 | ScalarValue::Date64(None)
2375 | ScalarValue::TimestampMicrosecond(None, _)
2376 | ScalarValue::TimestampMillisecond(None, _)
2377 | ScalarValue::TimestampSecond(None, _)
2378 | ScalarValue::TimestampNanosecond(None, _)
2379 | ScalarValue::Time64Microsecond(None)
2380 | ScalarValue::Time64Nanosecond(None)
2381 | ScalarValue::DurationMicrosecond(None)
2382 | ScalarValue::DurationMillisecond(None)
2383 | ScalarValue::DurationSecond(None)
2384 | ScalarValue::DurationNanosecond(None)
2385 | ScalarValue::IntervalMonthDayNano(None)
2386 | ScalarValue::FixedSizeBinary(_, None) => Ok(Value::Null),
2387 other => Err(datafusion::error::DataFusionError::Execution(format!(
2388 "scalar_to_value(): unsupported scalar type {other:?}"
2389 ))),
2390 }
2391}
2392
2393fn value_to_columnar(val: &Value) -> DFResult<ColumnarValue> {
2395 let scalar = match val {
2396 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
2397 Value::Int(i) => ScalarValue::Int64(Some(*i)),
2398 Value::Float(f) => ScalarValue::Float64(Some(*f)),
2399 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
2400 Value::Null => ScalarValue::Utf8(None),
2401 Value::Temporal(tv) => {
2402 use uni_common::TemporalValue;
2403 match tv {
2404 TemporalValue::Date { days_since_epoch } => {
2405 ScalarValue::Date32(Some(*days_since_epoch))
2406 }
2407 TemporalValue::LocalTime {
2408 nanos_since_midnight,
2409 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2410 TemporalValue::Time {
2411 nanos_since_midnight,
2412 ..
2413 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2414 TemporalValue::LocalDateTime { nanos_since_epoch } => {
2415 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None)
2416 }
2417 TemporalValue::DateTime {
2418 nanos_since_epoch,
2419 timezone_name,
2420 ..
2421 } => {
2422 let tz = timezone_name.as_deref().unwrap_or("UTC");
2423 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), Some(tz.into()))
2424 }
2425 TemporalValue::Duration {
2426 months,
2427 days,
2428 nanos,
2429 } => ScalarValue::IntervalMonthDayNano(Some(
2430 arrow::datatypes::IntervalMonthDayNano {
2431 months: *months as i32,
2432 days: *days as i32,
2433 nanoseconds: *nanos,
2434 },
2435 )),
2436 TemporalValue::Btic { lo, hi, meta } => {
2437 let btic = uni_btic::Btic::new(*lo, *hi, *meta).map_err(|e| {
2438 datafusion::error::DataFusionError::Execution(format!("invalid BTIC: {e}"))
2439 })?;
2440 let packed = uni_btic::encode::encode(&btic);
2441 ScalarValue::FixedSizeBinary(24, Some(packed.to_vec()))
2442 }
2443 }
2444 }
2445 other => {
2446 return Err(datafusion::error::DataFusionError::Execution(format!(
2447 "value_to_columnar(): unsupported type {other:?}"
2448 )));
2449 }
2450 };
2451 Ok(ColumnarValue::Scalar(scalar))
2452}
2453
2454pub fn create_has_null_udf() -> ScalarUDF {
2460 ScalarUDF::new_from_impl(HasNullUdf::new())
2461}
2462
2463#[derive(Debug)]
2464struct HasNullUdf {
2465 signature: Signature,
2466}
2467
2468impl HasNullUdf {
2469 fn new() -> Self {
2470 Self {
2471 signature: Signature::any(1, Volatility::Immutable),
2472 }
2473 }
2474}
2475
2476impl_udf_eq_hash!(HasNullUdf);
2477
2478impl ScalarUDFImpl for HasNullUdf {
2479 fn as_any(&self) -> &dyn Any {
2480 self
2481 }
2482
2483 fn name(&self) -> &str {
2484 "_has_null"
2485 }
2486
2487 fn signature(&self) -> &Signature {
2488 &self.signature
2489 }
2490
2491 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2492 Ok(DataType::Boolean)
2493 }
2494
2495 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2496 if args.args.len() != 1 {
2497 return Err(datafusion::error::DataFusionError::Execution(
2498 "_has_null(): requires 1 argument".to_string(),
2499 ));
2500 }
2501
2502 fn check_list_nulls<T: arrow_array::OffsetSizeTrait>(
2504 arr: &arrow_array::GenericListArray<T>,
2505 idx: usize,
2506 ) -> bool {
2507 if arr.is_null(idx) || arr.is_empty() {
2508 false
2509 } else {
2510 arr.value(idx).null_count() > 0
2511 }
2512 }
2513
2514 match &args.args[0] {
2515 ColumnarValue::Scalar(scalar) => {
2516 let has_null = match scalar {
2517 ScalarValue::List(arr) => arr
2518 .as_any()
2519 .downcast_ref::<arrow::array::ListArray>()
2520 .map(|a| !a.is_empty() && a.value(0).null_count() > 0)
2521 .unwrap_or(arr.null_count() > 0),
2522 ScalarValue::LargeList(arr) => arr.len() > 0 && arr.value(0).null_count() > 0,
2523 ScalarValue::FixedSizeList(arr) => {
2524 arr.len() > 0 && arr.value(0).null_count() > 0
2525 }
2526 _ => false,
2527 };
2528 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(has_null))))
2529 }
2530 ColumnarValue::Array(arr) => {
2531 use arrow_array::{LargeListArray, ListArray};
2532
2533 let results: arrow::array::BooleanArray =
2534 if let Some(list_arr) = arr.as_any().downcast_ref::<ListArray>() {
2535 (0..list_arr.len())
2536 .map(|i| {
2537 if list_arr.is_null(i) {
2538 None
2539 } else {
2540 Some(check_list_nulls(list_arr, i))
2541 }
2542 })
2543 .collect()
2544 } else if let Some(large) = arr.as_any().downcast_ref::<LargeListArray>() {
2545 (0..large.len())
2546 .map(|i| {
2547 if large.is_null(i) {
2548 None
2549 } else {
2550 Some(check_list_nulls(large, i))
2551 }
2552 })
2553 .collect()
2554 } else {
2555 return Err(datafusion::error::DataFusionError::Execution(
2556 "_has_null(): requires list array".to_string(),
2557 ));
2558 };
2559 Ok(ColumnarValue::Array(Arc::new(results)))
2560 }
2561 }
2562 }
2563}
2564
2565pub fn create_to_integer_udf() -> ScalarUDF {
2570 ScalarUDF::new_from_impl(ToIntegerUdf::new())
2571}
2572
2573#[derive(Debug)]
2574struct ToIntegerUdf {
2575 signature: Signature,
2576}
2577
2578impl ToIntegerUdf {
2579 fn new() -> Self {
2580 Self {
2581 signature: Signature::any(1, Volatility::Immutable),
2582 }
2583 }
2584}
2585
2586impl_udf_eq_hash!(ToIntegerUdf);
2587
2588impl ScalarUDFImpl for ToIntegerUdf {
2589 fn as_any(&self) -> &dyn Any {
2590 self
2591 }
2592
2593 fn name(&self) -> &str {
2594 "tointeger"
2595 }
2596
2597 fn signature(&self) -> &Signature {
2598 &self.signature
2599 }
2600
2601 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2602 Ok(DataType::Int64)
2603 }
2604
2605 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2606 let output_type = self.return_type(&[])?;
2607 invoke_cypher_udf(args, &output_type, |val_args| {
2608 if val_args.is_empty() {
2609 return Err(datafusion::error::DataFusionError::Execution(
2610 "tointeger(): requires 1 argument".to_string(),
2611 ));
2612 }
2613
2614 let val = &val_args[0];
2615 match val {
2616 Value::Int(i) => Ok(Value::Int(*i)),
2617 Value::Float(f) => Ok(Value::Int(*f as i64)),
2618 Value::String(s) => {
2619 if let Ok(i) = s.parse::<i64>() {
2620 Ok(Value::Int(i))
2621 } else if let Ok(f) = s.parse::<f64>() {
2622 Ok(Value::Int(f as i64))
2623 } else {
2624 Ok(Value::Null)
2625 }
2626 }
2627 Value::Null => Ok(Value::Null),
2628 other => Err(datafusion::error::DataFusionError::Execution(format!(
2629 "InvalidArgumentValue: tointeger(): cannot convert {} to integer",
2630 cypher_type_name(other)
2631 ))),
2632 }
2633 })
2634 }
2635}
2636
2637pub fn create_to_float_udf() -> ScalarUDF {
2642 ScalarUDF::new_from_impl(ToFloatUdf::new())
2643}
2644
2645#[derive(Debug)]
2646struct ToFloatUdf {
2647 signature: Signature,
2648}
2649
2650impl ToFloatUdf {
2651 fn new() -> Self {
2652 Self {
2653 signature: Signature::any(1, Volatility::Immutable),
2654 }
2655 }
2656}
2657
2658impl_udf_eq_hash!(ToFloatUdf);
2659
2660impl ScalarUDFImpl for ToFloatUdf {
2661 fn as_any(&self) -> &dyn Any {
2662 self
2663 }
2664
2665 fn name(&self) -> &str {
2666 "tofloat"
2667 }
2668
2669 fn signature(&self) -> &Signature {
2670 &self.signature
2671 }
2672
2673 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2674 Ok(DataType::Float64)
2675 }
2676
2677 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2678 let output_type = self.return_type(&[])?;
2679 invoke_cypher_udf(args, &output_type, |val_args| {
2680 if val_args.is_empty() {
2681 return Err(datafusion::error::DataFusionError::Execution(
2682 "tofloat(): requires 1 argument".to_string(),
2683 ));
2684 }
2685
2686 let val = &val_args[0];
2687 match val {
2688 Value::Int(i) => Ok(Value::Float(*i as f64)),
2689 Value::Float(f) => Ok(Value::Float(*f)),
2690 Value::String(s) => {
2691 if let Ok(f) = s.parse::<f64>() {
2692 Ok(Value::Float(f))
2693 } else {
2694 Ok(Value::Null)
2695 }
2696 }
2697 Value::Null => Ok(Value::Null),
2698 other => Err(datafusion::error::DataFusionError::Execution(format!(
2699 "InvalidArgumentValue: tofloat(): cannot convert {} to float",
2700 cypher_type_name(other)
2701 ))),
2702 }
2703 })
2704 }
2705}
2706
2707pub fn create_to_boolean_udf() -> ScalarUDF {
2712 ScalarUDF::new_from_impl(ToBooleanUdf::new())
2713}
2714
2715#[derive(Debug)]
2716struct ToBooleanUdf {
2717 signature: Signature,
2718}
2719
2720impl ToBooleanUdf {
2721 fn new() -> Self {
2722 Self {
2723 signature: Signature::any(1, Volatility::Immutable),
2724 }
2725 }
2726}
2727
2728impl_udf_eq_hash!(ToBooleanUdf);
2729
2730impl ScalarUDFImpl for ToBooleanUdf {
2731 fn as_any(&self) -> &dyn Any {
2732 self
2733 }
2734
2735 fn name(&self) -> &str {
2736 "toboolean"
2737 }
2738
2739 fn signature(&self) -> &Signature {
2740 &self.signature
2741 }
2742
2743 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2744 Ok(DataType::Boolean)
2745 }
2746
2747 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2748 let output_type = self.return_type(&[])?;
2749 invoke_cypher_udf(args, &output_type, |val_args| {
2750 if val_args.is_empty() {
2751 return Err(datafusion::error::DataFusionError::Execution(
2752 "toboolean(): requires 1 argument".to_string(),
2753 ));
2754 }
2755
2756 let val = &val_args[0];
2757 match val {
2758 Value::Bool(b) => Ok(Value::Bool(*b)),
2759 Value::String(s) => {
2760 let s_lower = s.to_lowercase();
2761 if s_lower == "true" {
2762 Ok(Value::Bool(true))
2763 } else if s_lower == "false" {
2764 Ok(Value::Bool(false))
2765 } else {
2766 Ok(Value::Null)
2767 }
2768 }
2769 Value::Null => Ok(Value::Null),
2770 Value::Int(i) => Ok(Value::Bool(*i != 0)),
2771 other => Err(datafusion::error::DataFusionError::Execution(format!(
2772 "InvalidArgumentValue: toboolean(): cannot convert {} to boolean",
2773 cypher_type_name(other)
2774 ))),
2775 }
2776 })
2777 }
2778}
2779
2780pub fn create_cypher_sort_key_udf() -> ScalarUDF {
2787 ScalarUDF::new_from_impl(CypherSortKeyUdf::new())
2788}
2789
2790#[derive(Debug)]
2791struct CypherSortKeyUdf {
2792 signature: Signature,
2793}
2794
2795impl CypherSortKeyUdf {
2796 fn new() -> Self {
2797 Self {
2798 signature: Signature::any(1, Volatility::Immutable),
2799 }
2800 }
2801}
2802
2803impl_udf_eq_hash!(CypherSortKeyUdf);
2804
2805impl ScalarUDFImpl for CypherSortKeyUdf {
2806 fn as_any(&self) -> &dyn Any {
2807 self
2808 }
2809
2810 fn name(&self) -> &str {
2811 "_cypher_sort_key"
2812 }
2813
2814 fn signature(&self) -> &Signature {
2815 &self.signature
2816 }
2817
2818 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2819 Ok(DataType::LargeBinary)
2820 }
2821
2822 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2823 if args.args.len() != 1 {
2824 return Err(datafusion::error::DataFusionError::Execution(
2825 "_cypher_sort_key(): requires 1 argument".to_string(),
2826 ));
2827 }
2828
2829 let arg = &args.args[0];
2830 match arg {
2831 ColumnarValue::Scalar(s) => {
2832 let val = if s.is_null() {
2833 Value::Null
2834 } else {
2835 scalar_to_value(s)?
2836 };
2837 let key = encode_cypher_sort_key(&val);
2838 Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(key))))
2839 }
2840 ColumnarValue::Array(arr) => {
2841 let mut keys: Vec<Option<Vec<u8>>> = Vec::with_capacity(arr.len());
2842 for i in 0..arr.len() {
2843 let val = if arr.is_null(i) {
2844 Value::Null
2845 } else {
2846 get_value_from_array(arr, i)?
2847 };
2848 keys.push(Some(encode_cypher_sort_key(&val)));
2849 }
2850 let array = LargeBinaryArray::from(
2851 keys.iter()
2852 .map(|k| k.as_deref())
2853 .collect::<Vec<Option<&[u8]>>>(),
2854 );
2855 Ok(ColumnarValue::Array(Arc::new(array)))
2856 }
2857 }
2858 }
2859}
2860
2861pub fn encode_cypher_sort_key(value: &Value) -> Vec<u8> {
2867 let mut buf = Vec::with_capacity(32);
2868 encode_sort_key_to_buf(value, &mut buf);
2869 buf
2870}
2871
2872fn encode_sort_key_to_buf(value: &Value, buf: &mut Vec<u8>) {
2874 if let Value::Map(map) = value {
2876 if let Some(tv) = sort_key_map_as_temporal(map) {
2877 buf.push(0x07); encode_temporal_payload(&tv, buf);
2879 return;
2880 }
2881 let rank = sort_key_map_rank(map);
2882 if rank != 0 {
2883 buf.push(rank);
2885 match rank {
2886 0x01 => encode_map_as_node_payload(map, buf),
2887 0x02 => encode_map_as_edge_payload(map, buf),
2888 0x04 => encode_map_as_path_payload(map, buf),
2889 _ => {} }
2891 return;
2892 }
2893 }
2894
2895 if let Value::String(s) = value {
2897 if let Some(tv) = sort_key_string_as_temporal(s) {
2898 buf.push(0x07); encode_temporal_payload(&tv, buf);
2900 return;
2901 }
2902 if let Some(temporal_type) = crate::datetime::classify_temporal(s) {
2905 buf.push(0x07); if encode_wide_temporal_sort_key(s, temporal_type, buf) {
2907 return;
2908 }
2909 buf.pop();
2911 }
2912 }
2913
2914 let rank = sort_key_type_rank(value);
2915 buf.push(rank);
2916
2917 match value {
2918 Value::Null => {} Value::Float(f) if f.is_nan() => {} Value::Bool(b) => buf.push(if *b { 0x01 } else { 0x00 }),
2921 Value::Int(i) => {
2922 let f = *i as f64;
2923 buf.extend_from_slice(&encode_order_preserving_f64(f));
2924 }
2925 Value::Float(f) => {
2926 buf.extend_from_slice(&encode_order_preserving_f64(*f));
2927 }
2928 Value::String(s) => {
2929 byte_stuff_terminate(s.as_bytes(), buf);
2930 }
2931 Value::Temporal(tv) => {
2932 encode_temporal_payload(tv, buf);
2933 }
2934 Value::List(items) => {
2935 encode_list_payload(items, buf);
2936 }
2937 Value::Map(map) => {
2938 encode_map_payload(map, buf);
2939 }
2940 Value::Node(node) => {
2941 encode_node_payload(node, buf);
2942 }
2943 Value::Edge(edge) => {
2944 encode_edge_payload(edge, buf);
2945 }
2946 Value::Path(path) => {
2947 encode_path_payload(path, buf);
2948 }
2949 Value::Bytes(b) => {
2951 byte_stuff_terminate(b, buf);
2952 }
2953 Value::Vector(v) => {
2954 for f in v {
2955 buf.extend_from_slice(&encode_order_preserving_f64(*f as f64));
2956 }
2957 }
2958 _ => {} }
2960}
2961
2962fn sort_key_type_rank(v: &Value) -> u8 {
2966 match v {
2967 Value::Map(map) => sort_key_map_rank(map),
2968 Value::Node(_) => 0x01,
2969 Value::Edge(_) => 0x02,
2970 Value::List(_) => 0x03,
2971 Value::Path(_) => 0x04,
2972 Value::String(_) => 0x05,
2973 Value::Bool(_) => 0x06,
2974 Value::Temporal(_) => 0x07,
2975 Value::Int(_) => 0x08,
2976 Value::Float(f) if f.is_nan() => 0x09,
2977 Value::Float(_) => 0x08,
2978 Value::Null => 0x0A,
2979 Value::Bytes(_) | Value::Vector(_) => 0x0B,
2980 _ => 0x0B, }
2982}
2983
2984fn sort_key_map_rank(map: &std::collections::HashMap<String, Value>) -> u8 {
2986 if sort_key_map_as_temporal(map).is_some() {
2987 0x07
2988 } else if map.contains_key("nodes")
2989 && (map.contains_key("relationships") || map.contains_key("edges"))
2990 {
2991 0x04 } else if map.contains_key("_eid")
2993 || map.contains_key("_src")
2994 || map.contains_key("_dst")
2995 || map.contains_key("_type")
2996 || map.contains_key("_type_name")
2997 {
2998 0x02 } else if map.contains_key("_vid") || map.contains_key("_labels") || map.contains_key("_label")
3000 {
3001 0x01 } else {
3003 0x00 }
3005}
3006
3007fn sort_key_map_as_temporal(
3011 map: &std::collections::HashMap<String, Value>,
3012) -> Option<uni_common::TemporalValue> {
3013 super::expr_eval::temporal_from_map_wrapper(map)
3014}
3015
3016fn sort_key_string_as_temporal(s: &str) -> Option<uni_common::TemporalValue> {
3020 super::expr_eval::temporal_from_value(&Value::String(s.to_string()))
3021}
3022
3023fn encode_wide_temporal_sort_key(
3030 s: &str,
3031 temporal_type: uni_common::TemporalType,
3032 buf: &mut Vec<u8>,
3033) -> bool {
3034 match temporal_type {
3035 uni_common::TemporalType::LocalDateTime => {
3036 if let Some(ndt) = parse_naive_datetime(s) {
3037 buf.push(0x03); let wide_nanos = naive_datetime_to_wide_nanos(&ndt);
3039 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3040 return true;
3041 }
3042 false
3043 }
3044 uni_common::TemporalType::DateTime => {
3045 let base = if let Some(bracket_pos) = s.find('[') {
3047 &s[..bracket_pos]
3048 } else {
3049 s
3050 };
3051 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%.f%:z") {
3052 buf.push(0x04); let utc = dt.naive_utc();
3054 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
3055 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3056 return true;
3057 }
3058 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%:z") {
3059 buf.push(0x04); let utc = dt.naive_utc();
3061 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
3062 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3063 return true;
3064 }
3065 false
3066 }
3067 uni_common::TemporalType::Date => {
3068 if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
3069 && let Some(epoch) = chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
3070 {
3071 buf.push(0x00); let days = nd.signed_duration_since(epoch).num_days() as i32;
3073 buf.extend_from_slice(&encode_order_preserving_i32(days));
3074 return true;
3075 }
3076 false
3077 }
3078 _ => false,
3079 }
3080}
3081
3082fn parse_naive_datetime(s: &str) -> Option<chrono::NaiveDateTime> {
3084 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
3085 .ok()
3086 .or_else(|| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").ok())
3087}
3088
3089fn naive_datetime_to_wide_nanos(ndt: &chrono::NaiveDateTime) -> i128 {
3092 let secs = ndt.and_utc().timestamp() as i128;
3093 let subsec_nanos = ndt.and_utc().timestamp_subsec_nanos() as i128;
3094 secs * 1_000_000_000 + subsec_nanos
3095}
3096
3097fn encode_map_as_node_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3099 let mut labels: Vec<String> = Vec::new();
3101 if let Some(Value::List(lbls)) = map.get("_labels") {
3102 for l in lbls {
3103 if let Value::String(s) = l {
3104 labels.push(s.clone());
3105 }
3106 }
3107 } else if let Some(Value::String(lbl)) = map.get("_label") {
3108 labels.push(lbl.clone());
3109 }
3110 labels.sort();
3111
3112 let vid = map.get("_vid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3114
3115 let labels_joined = labels.join("\x01");
3117 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3118
3119 buf.extend_from_slice(&vid.to_be_bytes());
3121
3122 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3124 for (k, v) in map {
3125 if !k.starts_with('_') {
3126 props.insert(k.clone(), v.clone());
3127 }
3128 }
3129 encode_map_payload(&props, buf);
3130}
3131
3132fn encode_map_as_edge_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3134 let edge_type = map
3135 .get("_type")
3136 .or_else(|| map.get("_type_name"))
3137 .and_then(|v| {
3138 if let Value::String(s) = v {
3139 Some(s.as_str())
3140 } else {
3141 None
3142 }
3143 })
3144 .unwrap_or("");
3145
3146 byte_stuff_terminate(edge_type.as_bytes(), buf);
3147
3148 let src = map.get("_src").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3149 let dst = map.get("_dst").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3150 let eid = map.get("_eid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3151
3152 buf.extend_from_slice(&src.to_be_bytes());
3153 buf.extend_from_slice(&dst.to_be_bytes());
3154 buf.extend_from_slice(&eid.to_be_bytes());
3155
3156 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3158 for (k, v) in map {
3159 if !k.starts_with('_') {
3160 props.insert(k.clone(), v.clone());
3161 }
3162 }
3163 encode_map_payload(&props, buf);
3164}
3165
3166fn encode_map_as_path_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3168 if let Some(Value::List(nodes)) = map.get("nodes") {
3170 encode_list_payload(nodes, buf);
3171 } else {
3172 buf.push(0x00); }
3174 let edges = map.get("relationships").or_else(|| map.get("edges"));
3176 if let Some(Value::List(edges)) = edges {
3177 encode_list_payload(edges, buf);
3178 } else {
3179 buf.push(0x00); }
3181}
3182
/// Encodes an `f64` as 8 big-endian bytes whose unsigned lexicographic order
/// matches numeric order.
///
/// Negative values have every bit flipped, which reverses their order and
/// places them below all non-negative values. Non-negative values only get the
/// sign bit set. `-0.0` is first normalised to `+0.0`: the two compare equal
/// numerically, so they must yield identical keys. Otherwise a secondary sort
/// key would be ignored between them.
///
/// NaN gets no special handling here.
fn encode_order_preserving_f64(f: f64) -> [u8; 8] {
    // `-0.0 == 0.0` is true, so both zeros collapse to the +0.0 bit pattern.
    let f = if f == 0.0 { 0.0 } else { f };
    let bits = f.to_bits();
    let encoded = if bits >> 63 == 1 {
        !bits
    } else {
        bits ^ (1u64 << 63)
    };
    encoded.to_be_bytes()
}
3200
/// Big-endian encoding of an `i64` whose unsigned byte order matches signed order.
///
/// XOR-ing with `i64::MIN` flips only the sign bit, mapping `MIN..=MAX` onto
/// `0..=u64::MAX` monotonically.
fn encode_order_preserving_i64(i: i64) -> [u8; 8] {
    (i ^ i64::MIN).to_be_bytes()
}

/// `i32` counterpart of `encode_order_preserving_i64`.
fn encode_order_preserving_i32(i: i32) -> [u8; 4] {
    (i ^ i32::MIN).to_be_bytes()
}

/// `i128` counterpart of `encode_order_preserving_i64`.
fn encode_order_preserving_i128(i: i128) -> [u8; 16] {
    (i ^ i128::MIN).to_be_bytes()
}
3216
/// Writes `data` byte-stuffed, then the two-byte terminator `0x00 0x00`.
///
/// Inside stuffed data every 0x00 is followed by 0xFF, so the terminator cannot
/// occur within it. It also sorts below any continuation, so a shorter string
/// orders before a longer string it prefixes.
fn byte_stuff_terminate(data: &[u8], buf: &mut Vec<u8>) {
    byte_stuff(data, buf);
    buf.extend_from_slice(&[0x00, 0x00]);
}

/// Copies `data` into `buf`, escaping each 0x00 as `0x00 0xFF`.
fn byte_stuff(data: &[u8], buf: &mut Vec<u8>) {
    for &byte in data {
        match byte {
            0x00 => buf.extend_from_slice(&[0x00, 0xFF]),
            other => buf.push(other),
        }
    }
}
3236
3237fn encode_list_payload(items: &[Value], buf: &mut Vec<u8>) {
3242 for item in items {
3243 buf.push(0x01); let elem_key = encode_cypher_sort_key(item);
3245 byte_stuff_terminate(&elem_key, buf);
3246 }
3247 buf.push(0x00); }
3249
3250fn encode_map_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3252 let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
3253 pairs.sort_by_key(|(k, _)| *k);
3254
3255 for (key, value) in pairs {
3256 buf.push(0x01); byte_stuff_terminate(key.as_bytes(), buf);
3258 let val_key = encode_cypher_sort_key(value);
3259 byte_stuff_terminate(&val_key, buf);
3260 }
3261 buf.push(0x00); }
3263
3264fn encode_node_payload(node: &uni_common::Node, buf: &mut Vec<u8>) {
3268 let mut labels = node.labels.clone();
3269 labels.sort();
3270 let labels_joined = labels.join("\x01");
3271 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3272
3273 buf.extend_from_slice(&node.vid.as_u64().to_be_bytes());
3274
3275 encode_map_payload(&node.properties, buf);
3276}
3277
3278fn encode_edge_payload(edge: &uni_common::Edge, buf: &mut Vec<u8>) {
3282 byte_stuff_terminate(edge.edge_type.as_bytes(), buf);
3283
3284 buf.extend_from_slice(&edge.src.as_u64().to_be_bytes());
3285 buf.extend_from_slice(&edge.dst.as_u64().to_be_bytes());
3286 buf.extend_from_slice(&edge.eid.as_u64().to_be_bytes());
3287
3288 encode_map_payload(&edge.properties, buf);
3289}
3290
3291fn encode_path_payload(path: &uni_common::Path, buf: &mut Vec<u8>) {
3295 for node in &path.nodes {
3297 buf.push(0x01); let mut node_key = Vec::new();
3299 node_key.push(0x01); encode_node_payload(node, &mut node_key);
3301 byte_stuff_terminate(&node_key, buf);
3302 }
3303 buf.push(0x00); for edge in &path.edges {
3307 buf.push(0x01); let mut edge_key = Vec::new();
3309 edge_key.push(0x02); encode_edge_payload(edge, &mut edge_key);
3311 byte_stuff_terminate(&edge_key, buf);
3312 }
3313 buf.push(0x00); }
3315
3316fn encode_temporal_payload(tv: &uni_common::TemporalValue, buf: &mut Vec<u8>) {
3318 match tv {
3319 uni_common::TemporalValue::Date { days_since_epoch } => {
3320 buf.push(0x00); buf.extend_from_slice(&encode_order_preserving_i32(*days_since_epoch));
3322 }
3323 uni_common::TemporalValue::LocalTime {
3324 nanos_since_midnight,
3325 } => {
3326 buf.push(0x01); buf.extend_from_slice(&encode_order_preserving_i64(*nanos_since_midnight));
3328 }
3329 uni_common::TemporalValue::Time {
3330 nanos_since_midnight,
3331 offset_seconds,
3332 } => {
3333 buf.push(0x02); let utc_nanos =
3335 *nanos_since_midnight as i128 - (*offset_seconds as i128) * 1_000_000_000;
3336 buf.extend_from_slice(&encode_order_preserving_i128(utc_nanos));
3337 }
3338 uni_common::TemporalValue::LocalDateTime { nanos_since_epoch } => {
3339 buf.push(0x03); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3342 }
3343 uni_common::TemporalValue::DateTime {
3344 nanos_since_epoch, ..
3345 } => {
3346 buf.push(0x04); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3349 }
3350 uni_common::TemporalValue::Duration {
3351 months,
3352 days,
3353 nanos,
3354 } => {
3355 buf.push(0x05); buf.extend_from_slice(&encode_order_preserving_i64(*months));
3357 buf.extend_from_slice(&encode_order_preserving_i64(*days));
3358 buf.extend_from_slice(&encode_order_preserving_i64(*nanos));
3359 }
3360 uni_common::TemporalValue::Btic { lo, hi, meta } => {
3361 buf.push(0x06); if let Ok(btic) = uni_btic::Btic::new(*lo, *hi, *meta) {
3364 buf.extend_from_slice(&uni_btic::encode::encode(&btic));
3365 } else {
3366 buf.extend_from_slice(&encode_order_preserving_i64(*lo));
3367 buf.extend_from_slice(&encode_order_preserving_i64(*hi));
3368 }
3369 }
3370 }
3371}
3372
3373#[derive(Debug)]
3382struct BticScalarUdf {
3383 name: String,
3384 signature: Signature,
3385 return_type: DataType,
3386}
3387
3388impl BticScalarUdf {
3389 fn new(name: &str, num_args: usize, return_type: DataType) -> Self {
3390 Self {
3391 name: name.to_string(),
3392 signature: Signature::new(TypeSignature::Any(num_args), Volatility::Immutable),
3393 return_type,
3394 }
3395 }
3396}
3397
3398impl_udf_eq_hash!(BticScalarUdf);
3399
3400impl ScalarUDFImpl for BticScalarUdf {
3401 fn as_any(&self) -> &dyn Any {
3402 self
3403 }
3404 fn name(&self) -> &str {
3405 &self.name
3406 }
3407 fn signature(&self) -> &Signature {
3408 &self.signature
3409 }
3410 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
3411 Ok(self.return_type.clone())
3412 }
3413 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
3414 let fname = self.name.to_uppercase();
3415 let rt = self.return_type.clone();
3416 invoke_cypher_udf(args, &rt, |val_args| {
3417 crate::expr_eval::eval_btic_function(&fname, val_args)
3418 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
3419 })
3420 }
3421}
3422
3423fn register_btic_scalar_udfs(ctx: &SessionContext) -> DFResult<()> {
3425 for name in &["btic_lo", "btic_hi"] {
3427 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3428 name,
3429 1,
3430 DataType::LargeBinary,
3431 )));
3432 }
3433 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3435 "btic_duration",
3436 1,
3437 DataType::Int64,
3438 )));
3439 for name in &["btic_is_instant", "btic_is_unbounded", "btic_is_finite"] {
3441 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3442 name,
3443 1,
3444 DataType::Boolean,
3445 )));
3446 }
3447 for name in &[
3449 "btic_granularity",
3450 "btic_lo_granularity",
3451 "btic_hi_granularity",
3452 "btic_certainty",
3453 "btic_lo_certainty",
3454 "btic_hi_certainty",
3455 ] {
3456 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3457 name,
3458 1,
3459 DataType::Utf8,
3460 )));
3461 }
3462 for name in &[
3464 "btic_contains_point",
3465 "btic_overlaps",
3466 "btic_contains",
3467 "btic_before",
3468 "btic_after",
3469 "btic_meets",
3470 "btic_adjacent",
3471 "btic_disjoint",
3472 "btic_equals",
3473 "btic_starts",
3474 "btic_during",
3475 "btic_finishes",
3476 ] {
3477 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3478 name,
3479 2,
3480 DataType::Boolean,
3481 )));
3482 }
3483 for name in &["btic_intersection", "btic_span", "btic_gap"] {
3485 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3486 name,
3487 2,
3488 DataType::LargeBinary,
3489 )));
3490 }
3491 Ok(())
3492}
3493
3494#[derive(Debug, Clone)]
3500struct BticMinMaxUdaf {
3501 name: String,
3502 signature: Signature,
3503 is_max: bool,
3504}
3505
3506impl BticMinMaxUdaf {
3507 fn new(is_max: bool) -> Self {
3508 Self {
3509 name: (if is_max { "btic_max" } else { "btic_min" }).to_string(),
3510 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
3511 is_max,
3512 }
3513 }
3514}
3515
3516impl_udf_eq_hash!(BticMinMaxUdaf);
3517
3518impl AggregateUDFImpl for BticMinMaxUdaf {
3519 fn as_any(&self) -> &dyn Any {
3520 self
3521 }
3522 fn name(&self) -> &str {
3523 &self.name
3524 }
3525 fn signature(&self) -> &Signature {
3526 &self.signature
3527 }
3528 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3529 Ok(DataType::LargeBinary)
3530 }
3531 fn accumulator(
3532 &self,
3533 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3534 ) -> DFResult<Box<dyn DfAccumulator>> {
3535 Ok(Box::new(BticMinMaxAccumulator {
3536 current: None,
3537 is_max: self.is_max,
3538 }))
3539 }
3540 fn state_fields(
3541 &self,
3542 args: datafusion::logical_expr::function::StateFieldsArgs,
3543 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3544 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3545 args.name,
3546 DataType::LargeBinary,
3547 true,
3548 ))])
3549 }
3550}
3551
3552#[derive(Debug)]
3553struct BticMinMaxAccumulator {
3554 current: Option<uni_btic::Btic>,
3555 is_max: bool,
3556}
3557
3558impl DfAccumulator for BticMinMaxAccumulator {
3559 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3560 let arr = &values[0];
3561 for i in 0..arr.len() {
3562 if arr.is_null(i) {
3563 continue;
3564 }
3565 let Some(btic) = decode_btic_from_array(arr, i)? else {
3566 continue;
3567 };
3568 self.current = Some(match self.current.take() {
3569 None => btic,
3570 Some(cur) => {
3571 if (self.is_max && btic > cur) || (!self.is_max && btic < cur) {
3572 btic
3573 } else {
3574 cur
3575 }
3576 }
3577 });
3578 }
3579 Ok(())
3580 }
3581 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3582 Ok(btic_to_scalar_value(self.current.as_ref()))
3583 }
3584 fn size(&self) -> usize {
3585 std::mem::size_of::<Self>()
3586 }
3587 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3588 Ok(vec![self.evaluate()?])
3589 }
3590 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3591 self.update_batch(states)
3592 }
3593}
3594
3595#[derive(Debug, Clone)]
3597struct BticSpanAggUdaf {
3598 signature: Signature,
3599}
3600
3601impl BticSpanAggUdaf {
3602 fn new() -> Self {
3603 Self {
3604 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
3605 }
3606 }
3607}
3608
3609impl_udf_eq_hash!(BticSpanAggUdaf);
3610
3611impl AggregateUDFImpl for BticSpanAggUdaf {
3612 fn as_any(&self) -> &dyn Any {
3613 self
3614 }
3615 fn name(&self) -> &str {
3616 "btic_span_agg"
3617 }
3618 fn signature(&self) -> &Signature {
3619 &self.signature
3620 }
3621 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3622 Ok(DataType::LargeBinary)
3623 }
3624 fn accumulator(
3625 &self,
3626 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3627 ) -> DFResult<Box<dyn DfAccumulator>> {
3628 Ok(Box::new(BticSpanAggAccumulator { current: None }))
3629 }
3630 fn state_fields(
3631 &self,
3632 args: datafusion::logical_expr::function::StateFieldsArgs,
3633 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3634 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3635 args.name,
3636 DataType::LargeBinary,
3637 true,
3638 ))])
3639 }
3640}
3641
3642#[derive(Debug)]
3643struct BticSpanAggAccumulator {
3644 current: Option<uni_btic::Btic>,
3645}
3646
3647impl DfAccumulator for BticSpanAggAccumulator {
3648 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3649 let arr = &values[0];
3650 for i in 0..arr.len() {
3651 if arr.is_null(i) {
3652 continue;
3653 }
3654 let Some(btic) = decode_btic_from_array(arr, i)? else {
3655 continue;
3656 };
3657 self.current = Some(match self.current.take() {
3658 None => btic,
3659 Some(cur) => uni_btic::set_ops::span(&cur, &btic),
3660 });
3661 }
3662 Ok(())
3663 }
3664 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3665 Ok(btic_to_scalar_value(self.current.as_ref()))
3666 }
3667 fn size(&self) -> usize {
3668 std::mem::size_of::<Self>()
3669 }
3670 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3671 Ok(vec![self.evaluate()?])
3672 }
3673 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3674 self.update_batch(states)
3675 }
3676}
3677
3678#[derive(Debug, Clone)]
3680struct BticCountAtUdaf {
3681 signature: Signature,
3682}
3683
3684impl BticCountAtUdaf {
3685 fn new() -> Self {
3686 Self {
3687 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
3688 }
3689 }
3690}
3691
3692impl_udf_eq_hash!(BticCountAtUdaf);
3693
3694impl AggregateUDFImpl for BticCountAtUdaf {
3695 fn as_any(&self) -> &dyn Any {
3696 self
3697 }
3698 fn name(&self) -> &str {
3699 "btic_count_at"
3700 }
3701 fn signature(&self) -> &Signature {
3702 &self.signature
3703 }
3704 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3705 Ok(DataType::Int64)
3706 }
3707 fn accumulator(
3708 &self,
3709 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3710 ) -> DFResult<Box<dyn DfAccumulator>> {
3711 Ok(Box::new(BticCountAtAccumulator { count: 0 }))
3712 }
3713 fn state_fields(
3714 &self,
3715 args: datafusion::logical_expr::function::StateFieldsArgs,
3716 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3717 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3718 args.name,
3719 DataType::Int64,
3720 true,
3721 ))])
3722 }
3723}
3724
3725#[derive(Debug)]
3726struct BticCountAtAccumulator {
3727 count: i64,
3728}
3729
3730impl DfAccumulator for BticCountAtAccumulator {
3731 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3732 if values.len() < 2 {
3733 return Ok(());
3734 }
3735 let btic_arr = &values[0];
3736 let point_arr = &values[1];
3737
3738 for i in 0..btic_arr.len() {
3739 if btic_arr.is_null(i) || point_arr.is_null(i) {
3740 continue;
3741 }
3742 let Some(btic) = decode_btic_from_array(btic_arr, i)? else {
3743 continue;
3744 };
3745
3746 let point_ms = if let Some(int_arr) = point_arr.as_any().downcast_ref::<Int64Array>() {
3748 int_arr.value(i)
3749 } else if let Some(lb) = point_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3750 let val = scalar_binary_to_value(lb.value(i));
3751 match &val {
3752 Value::Int(ms) => *ms,
3753 Value::Temporal(uni_common::TemporalValue::DateTime {
3754 nanos_since_epoch,
3755 ..
3756 }) => nanos_since_epoch / 1_000_000,
3757 _ => continue,
3758 }
3759 } else {
3760 continue;
3761 };
3762
3763 if uni_btic::predicates::contains_point(&btic, point_ms) {
3764 self.count += 1;
3765 }
3766 }
3767 Ok(())
3768 }
3769 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3770 Ok(ScalarValue::Int64(Some(self.count)))
3771 }
3772 fn size(&self) -> usize {
3773 std::mem::size_of::<Self>()
3774 }
3775 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3776 Ok(vec![ScalarValue::Int64(Some(self.count))])
3777 }
3778 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3779 let arr = &states[0];
3780 if let Some(int_arr) = arr.as_any().downcast_ref::<Int64Array>() {
3781 for i in 0..int_arr.len() {
3782 if !int_arr.is_null(i) {
3783 self.count += int_arr.value(i);
3784 }
3785 }
3786 }
3787 Ok(())
3788 }
3789}
3790
3791fn btic_to_scalar_value(btic: Option<&uni_btic::Btic>) -> ScalarValue {
3793 match btic {
3794 None => ScalarValue::LargeBinary(None),
3795 Some(b) => {
3796 let val = Value::Temporal(uni_common::TemporalValue::Btic {
3797 lo: b.lo(),
3798 hi: b.hi(),
3799 meta: b.meta(),
3800 });
3801 ScalarValue::LargeBinary(Some(uni_common::cypher_value_codec::encode(&val)))
3802 }
3803 }
3804}
3805
3806fn decode_btic_from_array(arr: &ArrayRef, row: usize) -> DFResult<Option<uni_btic::Btic>> {
3808 if let Some(fsb) = arr.as_any().downcast_ref::<FixedSizeBinaryArray>() {
3810 let bytes = fsb.value(row);
3811 return uni_btic::encode::decode_slice(bytes)
3812 .map(Some)
3813 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3814 }
3815 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
3817 let val = scalar_binary_to_value(lb.value(row));
3818 if let Value::Temporal(uni_common::TemporalValue::Btic { lo, hi, meta }) = val {
3819 return uni_btic::Btic::new(lo, hi, meta)
3820 .map(Some)
3821 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3822 }
3823 return Ok(None);
3824 }
3825 Ok(None)
3826}
3827
3828pub fn create_btic_min_udaf() -> AggregateUDF {
3829 AggregateUDF::from(BticMinMaxUdaf::new(false))
3830}
3831
3832pub fn create_btic_max_udaf() -> AggregateUDF {
3833 AggregateUDF::from(BticMinMaxUdaf::new(true))
3834}
3835
3836pub fn create_btic_span_agg_udaf() -> AggregateUDF {
3837 AggregateUDF::from(BticSpanAggUdaf::new())
3838}
3839
3840pub fn create_btic_count_at_udaf() -> AggregateUDF {
3841 AggregateUDF::from(BticCountAtUdaf::new())
3842}
3843
3844pub fn invoke_cypher_string_op<F>(
3849 args: &ScalarFunctionArgs,
3850 name: &str,
3851 op: F,
3852) -> DFResult<ColumnarValue>
3853where
3854 F: Fn(&str, &str) -> bool,
3855{
3856 use arrow_array::{BooleanArray, LargeBinaryArray, LargeStringArray, StringArray};
3857 use datafusion::common::ScalarValue;
3858 use datafusion::error::DataFusionError;
3859
3860 if args.args.len() != 2 {
3861 return Err(DataFusionError::Execution(format!(
3862 "{}(): requires exactly 2 arguments",
3863 name
3864 )));
3865 }
3866
3867 let left = &args.args[0];
3868 let right = &args.args[1];
3869
3870 let extract_string = |scalar: &ScalarValue| -> Option<String> {
3872 match scalar {
3873 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Some(s.clone()),
3874 ScalarValue::LargeBinary(Some(bytes)) => {
3875 match uni_common::cypher_value_codec::decode(bytes) {
3877 Ok(uni_common::Value::String(s)) => Some(s),
3878 _ => None,
3879 }
3880 }
3881 ScalarValue::Utf8(None)
3882 | ScalarValue::LargeUtf8(None)
3883 | ScalarValue::LargeBinary(None)
3884 | ScalarValue::Null => None,
3885 _ => None,
3886 }
3887 };
3888
3889 match (left, right) {
3890 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Scalar(r_scalar)) => {
3891 let l_str = extract_string(l_scalar);
3892 let r_str = extract_string(r_scalar);
3893
3894 match (l_str, r_str) {
3895 (Some(l), Some(r)) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(op(
3896 &l, &r,
3897 ))))),
3898 _ => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
3899 }
3900 }
3901 (ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_scalar)) => {
3902 let r_val = extract_string(r_scalar);
3904
3905 if r_val.is_none() {
3906 let nulls = arrow_array::new_null_array(&DataType::Boolean, l_arr.len());
3908 return Ok(ColumnarValue::Array(nulls));
3909 }
3910 let pattern = r_val.unwrap();
3911
3912 let result_array = if let Some(arr) = l_arr.as_any().downcast_ref::<StringArray>() {
3914 arr.iter()
3915 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3916 .collect::<BooleanArray>()
3917 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeStringArray>() {
3918 arr.iter()
3919 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3920 .collect::<BooleanArray>()
3921 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3922 arr.iter()
3924 .map(|opt_bytes| {
3925 opt_bytes.and_then(|bytes| {
3926 match uni_common::cypher_value_codec::decode(bytes) {
3927 Ok(uni_common::Value::String(s)) => Some(op(&s, &pattern)),
3928 _ => None,
3929 }
3930 })
3931 })
3932 .collect::<BooleanArray>()
3933 } else {
3934 arrow_array::new_null_array(&DataType::Boolean, l_arr.len())
3936 .as_any()
3937 .downcast_ref::<BooleanArray>()
3938 .unwrap()
3939 .clone()
3940 };
3941
3942 Ok(ColumnarValue::Array(Arc::new(result_array)))
3943 }
3944 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Array(r_arr)) => {
3945 let l_val = extract_string(l_scalar);
3947
3948 if l_val.is_none() {
3949 let nulls = arrow_array::new_null_array(&DataType::Boolean, r_arr.len());
3950 return Ok(ColumnarValue::Array(nulls));
3951 }
3952 let target = l_val.unwrap();
3953
3954 let result_array = if let Some(arr) = r_arr.as_any().downcast_ref::<StringArray>() {
3955 arr.iter()
3956 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3957 .collect::<BooleanArray>()
3958 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeStringArray>() {
3959 arr.iter()
3960 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3961 .collect::<BooleanArray>()
3962 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3963 arr.iter()
3965 .map(|opt_bytes| {
3966 opt_bytes.and_then(|bytes| {
3967 match uni_common::cypher_value_codec::decode(bytes) {
3968 Ok(uni_common::Value::String(s)) => Some(op(&target, &s)),
3969 _ => None,
3970 }
3971 })
3972 })
3973 .collect::<BooleanArray>()
3974 } else {
3975 arrow_array::new_null_array(&DataType::Boolean, r_arr.len())
3977 .as_any()
3978 .downcast_ref::<BooleanArray>()
3979 .unwrap()
3980 .clone()
3981 };
3982
3983 Ok(ColumnarValue::Array(Arc::new(result_array)))
3984 }
3985 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
3986 if l_arr.len() != r_arr.len() {
3988 return Err(DataFusionError::Execution(format!(
3989 "{}(): array lengths must match",
3990 name
3991 )));
3992 }
3993
3994 let extract_string_at = |arr: &dyn Array, idx: usize| -> Option<String> {
3996 if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
3997 str_arr.value(idx).to_string().into()
3998 } else if let Some(str_arr) = arr.as_any().downcast_ref::<LargeStringArray>() {
3999 str_arr.value(idx).to_string().into()
4000 } else if let Some(bin_arr) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
4001 if bin_arr.is_null(idx) {
4002 return None;
4003 }
4004 let bytes = bin_arr.value(idx);
4005 match uni_common::cypher_value_codec::decode(bytes) {
4006 Ok(uni_common::Value::String(s)) => Some(s),
4007 _ => None,
4008 }
4009 } else {
4010 None
4011 }
4012 };
4013
4014 let result: BooleanArray = (0..l_arr.len())
4015 .map(|idx| {
4016 match (
4017 extract_string_at(l_arr.as_ref(), idx),
4018 extract_string_at(r_arr.as_ref(), idx),
4019 ) {
4020 (Some(l_str), Some(r_str)) => Some(op(&l_str, &r_str)),
4021 _ => None,
4022 }
4023 })
4024 .collect();
4025
4026 Ok(ColumnarValue::Array(Arc::new(result)))
4027 }
4028 }
4029}
4030
/// Declares a two-argument string-predicate UDF.
///
/// * `$struct_name` – name of the generated implementation struct.
/// * `$udf_name` – name the UDF is registered under (also used in error messages).
/// * `$op` – predicate handed to `invoke_cypher_string_op` together with the call arguments.
///
/// The generated UDF always reports a `Boolean` return type.
macro_rules! define_string_op_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                // Any two argument types are accepted here; interpreting them is left to
                // `invoke_cypher_string_op`.
                let signature = Signature::any(2, Volatility::Immutable);
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Boolean)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_cypher_string_op(&args, $udf_name, $op)
            }
        }
    };
}
4069
4070define_string_op_udf!(CypherStartsWithUdf, "_cypher_starts_with", |s, p| s
4071 .starts_with(p));
4072define_string_op_udf!(CypherEndsWithUdf, "_cypher_ends_with", |s, p| s
4073 .ends_with(p));
4074define_string_op_udf!(CypherContainsUdf, "_cypher_contains", |s, p| s.contains(p));
4075
4076pub fn create_cypher_starts_with_udf() -> ScalarUDF {
4077 ScalarUDF::new_from_impl(CypherStartsWithUdf::new())
4078}
4079pub fn create_cypher_ends_with_udf() -> ScalarUDF {
4080 ScalarUDF::new_from_impl(CypherEndsWithUdf::new())
4081}
4082pub fn create_cypher_contains_udf() -> ScalarUDF {
4083 ScalarUDF::new_from_impl(CypherContainsUdf::new())
4084}
4085
4086pub fn create_cypher_equal_udf() -> ScalarUDF {
4087 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_equal", BinaryOp::Eq))
4088}
4089pub fn create_cypher_not_equal_udf() -> ScalarUDF {
4090 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_not_equal", BinaryOp::NotEq))
4091}
4092pub fn create_cypher_lt_udf() -> ScalarUDF {
4093 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt", BinaryOp::Lt))
4094}
4095pub fn create_cypher_lt_eq_udf() -> ScalarUDF {
4096 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt_eq", BinaryOp::LtEq))
4097}
4098pub fn create_cypher_gt_udf() -> ScalarUDF {
4099 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt", BinaryOp::Gt))
4100}
4101pub fn create_cypher_gt_eq_udf() -> ScalarUDF {
4102 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt_eq", BinaryOp::GtEq))
4103}
4104
4105#[expect(clippy::match_like_matches_macro)]
4107fn apply_comparison_op(ord: std::cmp::Ordering, op: &BinaryOp) -> bool {
4108 use std::cmp::Ordering;
4109 match (ord, op) {
4110 (Ordering::Less, BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::NotEq) => true,
4111 (Ordering::Equal, BinaryOp::Eq | BinaryOp::LtEq | BinaryOp::GtEq) => true,
4112 (Ordering::Greater, BinaryOp::Gt | BinaryOp::GtEq | BinaryOp::NotEq) => true,
4113 _ => false,
4114 }
4115}
4116
4117fn compare_f64(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<bool> {
4120 if lhs.is_nan() || rhs.is_nan() {
4121 Some(matches!(op, BinaryOp::NotEq))
4122 } else {
4123 Some(apply_comparison_op(lhs.partial_cmp(&rhs)?, op))
4124 }
4125}
4126
4127fn cv_bytes_as_f64(bytes: &[u8]) -> Option<f64> {
4129 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4130 match peek_tag(bytes)? {
4131 TAG_INT => decode_int(bytes).map(|i| i as f64),
4132 TAG_FLOAT => decode_float(bytes),
4133 _ => None,
4134 }
4135}
4136
4137fn compare_cv_numeric(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<bool> {
4140 use uni_common::cypher_value_codec::{TAG_INT, TAG_NULL, decode_int, peek_tag};
4141 if peek_tag(bytes) == Some(TAG_INT)
4143 && let Some(lhs_int) = decode_int(bytes)
4144 && rhs.fract() == 0.0
4146 && rhs >= i64::MIN as f64
4147 && rhs <= i64::MAX as f64
4148 {
4149 return Some(apply_comparison_op(lhs_int.cmp(&(rhs as i64)), op));
4150 }
4151 if peek_tag(bytes) == Some(TAG_NULL) {
4152 return None;
4153 }
4154 let lhs = cv_bytes_as_f64(bytes)?;
4155 compare_f64(lhs, rhs, op)
4156}
4157
4158fn try_fast_compare(
4162 lhs: &ColumnarValue,
4163 rhs: &ColumnarValue,
4164 op: &BinaryOp,
4165) -> Option<ColumnarValue> {
4166 use arrow_array::builder::BooleanBuilder;
4167 use uni_common::cypher_value_codec::{
4168 TAG_INT, TAG_NULL, TAG_STRING, decode_int, decode_string, peek_tag,
4169 };
4170
4171 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
4172 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
4173 _ => return None,
4174 };
4175
4176 if !matches!(lhs_arr.data_type(), DataType::LargeBinary) {
4178 return None;
4179 }
4180
4181 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4182
4183 match rhs_arr.data_type() {
4184 DataType::Int64 => {
4186 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4187 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4188 for i in 0..lb_arr.len() {
4189 if lb_arr.is_null(i) || int_arr.is_null(i) {
4190 builder.append_null();
4191 } else {
4192 match compare_cv_numeric(lb_arr.value(i), int_arr.value(i) as f64, op) {
4193 Some(result) => builder.append_value(result),
4194 None => builder.append_null(),
4195 }
4196 }
4197 }
4198 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4199 }
4200
4201 DataType::Float64 => {
4203 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
4204 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4205 for i in 0..lb_arr.len() {
4206 if lb_arr.is_null(i) || float_arr.is_null(i) {
4207 builder.append_null();
4208 } else {
4209 match compare_cv_numeric(lb_arr.value(i), float_arr.value(i), op) {
4210 Some(result) => builder.append_value(result),
4211 None => builder.append_null(),
4212 }
4213 }
4214 }
4215 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4216 }
4217
4218 DataType::Utf8 | DataType::LargeUtf8 => {
4220 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4221 for i in 0..lb_arr.len() {
4222 if lb_arr.is_null(i) || rhs_arr.is_null(i) {
4223 builder.append_null();
4224 } else {
4225 let bytes = lb_arr.value(i);
4226 let rhs_str = if matches!(rhs_arr.data_type(), DataType::Utf8) {
4227 rhs_arr.as_any().downcast_ref::<StringArray>()?.value(i)
4228 } else {
4229 rhs_arr
4230 .as_any()
4231 .downcast_ref::<LargeStringArray>()?
4232 .value(i)
4233 };
4234 match peek_tag(bytes) {
4235 Some(TAG_STRING) => {
4236 if let Some(lhs_str) = decode_string(bytes) {
4237 builder.append_value(apply_comparison_op(
4238 lhs_str.as_str().cmp(rhs_str),
4239 op,
4240 ));
4241 } else {
4242 builder.append_null();
4243 }
4244 }
4245 _ => builder.append_null(),
4246 }
4247 }
4248 }
4249 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4250 }
4251
4252 DataType::LargeBinary => {
4254 let rhs_lb = rhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4255 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4256 for i in 0..lb_arr.len() {
4257 if lb_arr.is_null(i) || rhs_lb.is_null(i) {
4258 builder.append_null();
4259 } else {
4260 let lhs_bytes = lb_arr.value(i);
4261 let rhs_bytes = rhs_lb.value(i);
4262 let lhs_tag = peek_tag(lhs_bytes);
4263 let rhs_tag = peek_tag(rhs_bytes);
4264
4265 if lhs_tag == Some(TAG_NULL) || rhs_tag == Some(TAG_NULL) {
4267 builder.append_null();
4268 continue;
4269 }
4270
4271 if lhs_tag == Some(TAG_INT) && rhs_tag == Some(TAG_INT) {
4273 if let (Some(l), Some(r)) = (decode_int(lhs_bytes), decode_int(rhs_bytes)) {
4274 builder.append_value(apply_comparison_op(l.cmp(&r), op));
4275 } else {
4276 builder.append_null();
4277 }
4278 continue;
4279 }
4280
4281 if lhs_tag == Some(TAG_STRING) && rhs_tag == Some(TAG_STRING) {
4283 if let (Some(l), Some(r)) =
4284 (decode_string(lhs_bytes), decode_string(rhs_bytes))
4285 {
4286 builder.append_value(apply_comparison_op(l.cmp(&r), op));
4287 } else {
4288 builder.append_null();
4289 }
4290 continue;
4291 }
4292
4293 if let (Some(l), Some(r)) =
4295 (cv_bytes_as_f64(lhs_bytes), cv_bytes_as_f64(rhs_bytes))
4296 {
4297 match compare_f64(l, r, op) {
4298 Some(result) => builder.append_value(result),
4299 None => builder.append_null(),
4300 }
4301 } else {
4302 return None;
4306 }
4307 }
4308 }
4309 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4310 }
4311
4312 _ => None, }
4314}
4315
4316#[derive(Debug)]
4317struct CypherCompareUdf {
4318 name: String,
4319 op: BinaryOp,
4320 signature: Signature,
4321}
4322
4323impl CypherCompareUdf {
4324 fn new(name: &str, op: BinaryOp) -> Self {
4325 Self {
4326 name: name.to_string(),
4327 op,
4328 signature: Signature::any(2, Volatility::Immutable),
4329 }
4330 }
4331}
4332
4333impl PartialEq for CypherCompareUdf {
4334 fn eq(&self, other: &Self) -> bool {
4335 self.name == other.name
4336 }
4337}
4338
4339impl Eq for CypherCompareUdf {}
4340
4341impl std::hash::Hash for CypherCompareUdf {
4342 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4343 self.name.hash(state);
4344 }
4345}
4346
4347impl ScalarUDFImpl for CypherCompareUdf {
4348 fn as_any(&self) -> &dyn Any {
4349 self
4350 }
4351 fn name(&self) -> &str {
4352 &self.name
4353 }
4354 fn signature(&self) -> &Signature {
4355 &self.signature
4356 }
4357 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4358 Ok(DataType::Boolean)
4359 }
4360
4361 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4362 if args.args.len() != 2 {
4363 return Err(datafusion::error::DataFusionError::Execution(format!(
4364 "{}(): requires 2 arguments",
4365 self.name
4366 )));
4367 }
4368
4369 if let Some(result) = try_fast_compare(&args.args[0], &args.args[1], &self.op) {
4371 return Ok(result);
4372 }
4373
4374 let output_type = DataType::Boolean;
4376 invoke_cypher_udf(args, &output_type, |val_args| {
4377 crate::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4378 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4379 })
4380 }
4381}
4382
4383pub fn create_cypher_add_udf() -> ScalarUDF {
4389 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_add", BinaryOp::Add))
4390}
4391pub fn create_cypher_sub_udf() -> ScalarUDF {
4392 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_sub", BinaryOp::Sub))
4393}
4394pub fn create_cypher_mul_udf() -> ScalarUDF {
4395 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mul", BinaryOp::Mul))
4396}
4397pub fn create_cypher_div_udf() -> ScalarUDF {
4398 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_div", BinaryOp::Div))
4399}
4400pub fn create_cypher_mod_udf() -> ScalarUDF {
4401 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mod", BinaryOp::Mod))
4402}
4403
4404pub fn create_cypher_abs_udf() -> ScalarUDF {
4406 ScalarUDF::new_from_impl(CypherAbsUdf::new())
4407}
4408
4409pub fn cypher_abs_expr(arg: datafusion::logical_expr::Expr) -> datafusion::logical_expr::Expr {
4411 datafusion::logical_expr::Expr::ScalarFunction(
4412 datafusion::logical_expr::expr::ScalarFunction::new_udf(
4413 Arc::new(create_cypher_abs_udf()),
4414 vec![arg],
4415 ),
4416 )
4417}
4418
4419#[derive(Debug)]
4420struct CypherAbsUdf {
4421 signature: Signature,
4422}
4423
4424impl CypherAbsUdf {
4425 fn new() -> Self {
4426 Self {
4427 signature: Signature::any(1, Volatility::Immutable),
4428 }
4429 }
4430}
4431
4432impl_udf_eq_hash!(CypherAbsUdf);
4433
4434impl ScalarUDFImpl for CypherAbsUdf {
4435 fn as_any(&self) -> &dyn Any {
4436 self
4437 }
4438 fn name(&self) -> &str {
4439 "_cypher_abs"
4440 }
4441 fn signature(&self) -> &Signature {
4442 &self.signature
4443 }
4444 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
4445 Ok(DataType::LargeBinary)
4446 }
4447 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4448 if args.args.len() != 1 {
4449 return Err(datafusion::error::DataFusionError::Execution(
4450 "_cypher_abs requires exactly 1 argument".into(),
4451 ));
4452 }
4453 invoke_cypher_udf(args, &DataType::LargeBinary, |val_args| {
4454 match &val_args[0] {
4455 Value::Int(i) => i.checked_abs().map(Value::Int).ok_or_else(|| {
4456 datafusion::error::DataFusionError::Execution(
4457 "integer overflow in abs()".into(),
4458 )
4459 }),
4460 Value::Float(f) => Ok(Value::Float(f.abs())),
4461 Value::Null => Ok(Value::Null),
4462 other => Err(datafusion::error::DataFusionError::Execution(format!(
4463 "abs() requires a numeric argument, got {other:?}"
4464 ))),
4465 }
4466 })
4467 }
4468}
4469
4470fn apply_int_arithmetic(lhs: i64, rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
4473 use uni_common::cypher_value_codec::encode_int;
4474 match op {
4475 BinaryOp::Add => lhs.checked_add(rhs).map(encode_int),
4476 BinaryOp::Sub => lhs.checked_sub(rhs).map(encode_int),
4477 BinaryOp::Mul => lhs.checked_mul(rhs).map(encode_int),
4478 BinaryOp::Div => {
4479 if rhs == 0 {
4481 None
4482 } else {
4483 lhs.checked_div(rhs).map(encode_int)
4484 }
4485 }
4486 BinaryOp::Mod => {
4487 if rhs == 0 {
4488 None
4489 } else {
4490 lhs.checked_rem(rhs).map(encode_int)
4491 }
4492 }
4493 _ => None,
4494 }
4495}
4496
4497fn apply_float_arithmetic(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
4499 use uni_common::cypher_value_codec::encode_float;
4500 let result = match op {
4501 BinaryOp::Add => lhs + rhs,
4502 BinaryOp::Sub => lhs - rhs,
4503 BinaryOp::Mul => lhs * rhs,
4504 BinaryOp::Div => lhs / rhs, BinaryOp::Mod => lhs % rhs,
4506 _ => return None,
4507 };
4508 Some(encode_float(result))
4509}
4510
4511fn cv_arithmetic_int(bytes: &[u8], rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
4514 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4515 match peek_tag(bytes)? {
4516 TAG_INT => apply_int_arithmetic(decode_int(bytes)?, rhs, op),
4517 TAG_FLOAT => apply_float_arithmetic(decode_float(bytes)?, rhs as f64, op),
4518 _ => None,
4519 }
4520}
4521
4522fn cv_arithmetic_float(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
4525 let lhs = cv_bytes_as_f64(bytes)?;
4526 apply_float_arithmetic(lhs, rhs, op)
4527}
4528
4529fn try_fast_arithmetic(
4533 lhs: &ColumnarValue,
4534 rhs: &ColumnarValue,
4535 op: &BinaryOp,
4536) -> Option<ColumnarValue> {
4537 use arrow_array::builder::LargeBinaryBuilder;
4538
4539 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
4540 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
4541 _ => return None,
4542 };
4543
4544 match (lhs_arr.data_type(), rhs_arr.data_type()) {
4545 (DataType::LargeBinary, DataType::Int64) => {
4547 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4548 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4549 let mut builder = LargeBinaryBuilder::new();
4550 for i in 0..lb_arr.len() {
4551 if lb_arr.is_null(i) || int_arr.is_null(i) {
4552 builder.append_null();
4553 } else if let Some(bytes) = cv_arithmetic_int(lb_arr.value(i), int_arr.value(i), op)
4554 {
4555 builder.append_value(&bytes);
4556 } else {
4557 builder.append_null();
4558 }
4559 }
4560 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4561 }
4562
4563 (DataType::LargeBinary, DataType::Float64) => {
4565 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4566 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
4567 let mut builder = LargeBinaryBuilder::new();
4568 for i in 0..lb_arr.len() {
4569 if lb_arr.is_null(i) || float_arr.is_null(i) {
4570 builder.append_null();
4571 } else if let Some(bytes) =
4572 cv_arithmetic_float(lb_arr.value(i), float_arr.value(i), op)
4573 {
4574 builder.append_value(&bytes);
4575 } else {
4576 builder.append_null();
4577 }
4578 }
4579 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4580 }
4581
4582 (DataType::Int64, DataType::Int64) => {
4584 let lhs_int = lhs_arr.as_any().downcast_ref::<Int64Array>()?;
4585 let rhs_int = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4586 let mut builder = LargeBinaryBuilder::new();
4587 for i in 0..lhs_int.len() {
4588 if lhs_int.is_null(i) || rhs_int.is_null(i) {
4589 builder.append_null();
4590 } else if let Some(bytes) =
4591 apply_int_arithmetic(lhs_int.value(i), rhs_int.value(i), op)
4592 {
4593 builder.append_value(&bytes);
4594 } else {
4595 builder.append_null();
4596 }
4597 }
4598 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4599 }
4600
4601 _ => None, }
4603}
4604
4605#[derive(Debug)]
4606struct CypherArithmeticUdf {
4607 name: String,
4608 op: BinaryOp,
4609 signature: Signature,
4610}
4611
4612impl CypherArithmeticUdf {
4613 fn new(name: &str, op: BinaryOp) -> Self {
4614 Self {
4615 name: name.to_string(),
4616 op,
4617 signature: Signature::any(2, Volatility::Immutable),
4618 }
4619 }
4620}
4621
4622impl PartialEq for CypherArithmeticUdf {
4623 fn eq(&self, other: &Self) -> bool {
4624 self.name == other.name
4625 }
4626}
4627
4628impl Eq for CypherArithmeticUdf {}
4629
4630impl std::hash::Hash for CypherArithmeticUdf {
4631 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4632 self.name.hash(state);
4633 }
4634}
4635
4636impl ScalarUDFImpl for CypherArithmeticUdf {
4637 fn as_any(&self) -> &dyn Any {
4638 self
4639 }
4640 fn name(&self) -> &str {
4641 &self.name
4642 }
4643 fn signature(&self) -> &Signature {
4644 &self.signature
4645 }
4646 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4647 Ok(DataType::LargeBinary) }
4649
4650 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4651 if args.args.len() != 2 {
4652 return Err(datafusion::error::DataFusionError::Execution(format!(
4653 "{}(): requires 2 arguments",
4654 self.name
4655 )));
4656 }
4657
4658 if let Some(result) = try_fast_arithmetic(&args.args[0], &args.args[1], &self.op) {
4660 return Ok(result);
4661 }
4662
4663 let output_type = DataType::LargeBinary;
4665 invoke_cypher_udf(args, &output_type, |val_args| {
4666 crate::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4667 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4668 })
4669 }
4670}
4671
4672pub fn create_cypher_xor_udf() -> ScalarUDF {
4677 ScalarUDF::new_from_impl(CypherXorUdf::new())
4678}
4679
4680#[derive(Debug)]
4681struct CypherXorUdf {
4682 signature: Signature,
4683}
4684
4685impl CypherXorUdf {
4686 fn new() -> Self {
4687 Self {
4688 signature: Signature::any(2, Volatility::Immutable),
4689 }
4690 }
4691}
4692
4693impl_udf_eq_hash!(CypherXorUdf);
4694
4695impl ScalarUDFImpl for CypherXorUdf {
4696 fn as_any(&self) -> &dyn Any {
4697 self
4698 }
4699 fn name(&self) -> &str {
4700 "_cypher_xor"
4701 }
4702 fn signature(&self) -> &Signature {
4703 &self.signature
4704 }
4705 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4706 Ok(DataType::Boolean)
4707 }
4708
4709 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4710 let output_type = DataType::Boolean;
4711 invoke_cypher_udf(args, &output_type, |val_args| {
4712 if val_args.len() != 2 {
4713 return Err(datafusion::error::DataFusionError::Execution(
4714 "_cypher_xor(): requires 2 arguments".to_string(),
4715 ));
4716 }
4717 let coerce_bool = |v: &Value| -> Value {
4719 match v {
4720 Value::String(s) if s == "true" => Value::Bool(true),
4721 Value::String(s) if s == "false" => Value::Bool(false),
4722 other => other.clone(),
4723 }
4724 };
4725 let left = coerce_bool(&val_args[0]);
4726 let right = coerce_bool(&val_args[1]);
4727 crate::expr_eval::eval_binary_op(&left, &BinaryOp::Xor, &right)
4728 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4729 })
4730 }
4731}
4732
4733pub fn create_cv_to_bool_udf() -> ScalarUDF {
4740 ScalarUDF::new_from_impl(CvToBoolUdf::new())
4741}
4742
4743#[derive(Debug)]
4744struct CvToBoolUdf {
4745 signature: Signature,
4746}
4747
4748impl CvToBoolUdf {
4749 fn new() -> Self {
4750 Self {
4751 signature: Signature::exact(vec![DataType::LargeBinary], Volatility::Immutable),
4752 }
4753 }
4754}
4755
4756impl_udf_eq_hash!(CvToBoolUdf);
4757
4758impl ScalarUDFImpl for CvToBoolUdf {
4759 fn as_any(&self) -> &dyn Any {
4760 self
4761 }
4762 fn name(&self) -> &str {
4763 "_cv_to_bool"
4764 }
4765 fn signature(&self) -> &Signature {
4766 &self.signature
4767 }
4768 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4769 Ok(DataType::Boolean)
4770 }
4771
4772 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4773 if args.args.len() != 1 {
4774 return Err(datafusion::error::DataFusionError::Execution(
4775 "_cv_to_bool() requires exactly 1 argument".to_string(),
4776 ));
4777 }
4778
4779 match &args.args[0] {
4780 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
4781 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4783 let b = match peek_tag(bytes) {
4784 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4785 Some(TAG_NULL) => false,
4786 _ => false, };
4788 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))))
4789 }
4790 ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
4791 ColumnarValue::Array(arr) => {
4792 let lb_arr = arr
4793 .as_any()
4794 .downcast_ref::<arrow_array::LargeBinaryArray>()
4795 .ok_or_else(|| {
4796 datafusion::error::DataFusionError::Execution(format!(
4797 "_cv_to_bool(): expected LargeBinary array, got {:?}",
4798 arr.data_type()
4799 ))
4800 })?;
4801
4802 let mut builder = arrow_array::builder::BooleanBuilder::with_capacity(lb_arr.len());
4803
4804 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4806
4807 for i in 0..lb_arr.len() {
4808 if lb_arr.is_null(i) {
4809 builder.append_null();
4810 } else {
4811 let bytes = lb_arr.value(i);
4812 let b = match peek_tag(bytes) {
4813 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4814 Some(TAG_NULL) => false,
4815 _ => false, };
4817 builder.append_value(b);
4818 }
4819 }
4820 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
4821 }
4822 }
4823 }
4824}
4825
4826pub fn create_cypher_size_udf() -> ScalarUDF {
4832 ScalarUDF::new_from_impl(CypherSizeUdf::new())
4833}
4834
4835#[derive(Debug)]
4836struct CypherSizeUdf {
4837 signature: Signature,
4838}
4839
4840impl CypherSizeUdf {
4841 fn new() -> Self {
4842 Self {
4843 signature: Signature::any(1, Volatility::Immutable),
4844 }
4845 }
4846}
4847
4848impl_udf_eq_hash!(CypherSizeUdf);
4849
4850impl ScalarUDFImpl for CypherSizeUdf {
4851 fn as_any(&self) -> &dyn Any {
4852 self
4853 }
4854
4855 fn name(&self) -> &str {
4856 "_cypher_size"
4857 }
4858
4859 fn signature(&self) -> &Signature {
4860 &self.signature
4861 }
4862
4863 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4864 Ok(DataType::Int64)
4865 }
4866
4867 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4868 if args.args.len() != 1 {
4869 return Err(datafusion::error::DataFusionError::Execution(
4870 "_cypher_size() requires exactly 1 argument".to_string(),
4871 ));
4872 }
4873
4874 match &args.args[0] {
4875 ColumnarValue::Scalar(scalar) => {
4876 let result = cypher_size_scalar(scalar)?;
4877 Ok(ColumnarValue::Scalar(result))
4878 }
4879 ColumnarValue::Array(arr) => {
4880 let mut results: Vec<Option<i64>> = Vec::with_capacity(arr.len());
4881 for i in 0..arr.len() {
4882 if arr.is_null(i) {
4883 results.push(None);
4884 } else {
4885 let scalar = ScalarValue::try_from_array(arr, i)?;
4886 match cypher_size_scalar(&scalar)? {
4887 ScalarValue::Int64(v) => results.push(v),
4888 _ => results.push(None),
4889 }
4890 }
4891 }
4892 let arr: ArrayRef = Arc::new(arrow_array::Int64Array::from(results));
4893 Ok(ColumnarValue::Array(arr))
4894 }
4895 }
4896 }
4897}
4898
4899fn cypher_size_scalar(scalar: &ScalarValue) -> DFResult<ScalarValue> {
4900 match scalar {
4901 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
4903 Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
4904 }
4905 ScalarValue::List(arr) => {
4908 if arr.is_empty() || arr.is_null(0) {
4909 Ok(ScalarValue::Int64(None))
4910 } else {
4911 Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
4912 }
4913 }
4914 ScalarValue::LargeList(arr) => {
4915 if arr.is_empty() || arr.is_null(0) {
4916 Ok(ScalarValue::Int64(None))
4917 } else {
4918 Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
4919 }
4920 }
4921 ScalarValue::LargeBinary(Some(b)) => {
4923 if let Ok(uni_val) = uni_common::cypher_value_codec::decode(b) {
4924 match &uni_val {
4925 uni_common::Value::Node(_) => {
4926 Err(datafusion::error::DataFusionError::Execution(
4927 "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
4928 ))
4929 }
4930 uni_common::Value::Edge(_) => {
4931 Err(datafusion::error::DataFusionError::Execution(
4932 "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
4933 ))
4934 }
4935 _ => {
4936 let json_val: serde_json::Value = uni_val.into();
4937 match json_val {
4938 serde_json::Value::Array(arr) => Ok(ScalarValue::Int64(Some(arr.len() as i64))),
4939 serde_json::Value::String(s) => {
4940 Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
4941 }
4942 serde_json::Value::Object(m) => Ok(ScalarValue::Int64(Some(m.len() as i64))),
4943 _ => Ok(ScalarValue::Int64(None)),
4944 }
4945 }
4946 }
4947 } else {
4948 Ok(ScalarValue::Int64(None))
4949 }
4950 }
4951 ScalarValue::Map(arr) => {
4953 if arr.is_empty() || arr.is_null(0) {
4954 Ok(ScalarValue::Int64(None))
4955 } else {
4956 Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
4958 }
4959 }
4960 ScalarValue::Struct(arr) => {
4962 if arr.is_null(0) {
4963 Ok(ScalarValue::Int64(None))
4964 } else {
4965 let schema = arr.fields();
4966 let field_names: Vec<&str> = schema.iter().map(|f| f.name().as_str()).collect();
4967 if field_names.contains(&"_vid") && !field_names.contains(&"relationships") {
4969 return Err(datafusion::error::DataFusionError::Execution(
4970 "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
4971 ));
4972 }
4973 if field_names.contains(&"_eid")
4975 || (field_names.contains(&"_src") && field_names.contains(&"_dst"))
4976 {
4977 return Err(datafusion::error::DataFusionError::Execution(
4978 "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
4979 ));
4980 }
4981 if let Some((rels_idx, _)) = schema
4983 .iter()
4984 .enumerate()
4985 .find(|(_, f)| f.name() == "relationships")
4986 {
4987 let rels_col = arr.column(rels_idx);
4989 if let Some(list_arr) =
4990 rels_col.as_any().downcast_ref::<arrow_array::ListArray>()
4991 {
4992 if list_arr.is_null(0) {
4993 Ok(ScalarValue::Int64(Some(0)))
4994 } else {
4995 Ok(ScalarValue::Int64(Some(list_arr.value(0).len() as i64)))
4996 }
4997 } else {
4998 Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
4999 }
5000 } else {
5001 Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
5002 }
5003 }
5004 }
5005 ScalarValue::Null
5007 | ScalarValue::Utf8(None)
5008 | ScalarValue::LargeUtf8(None)
5009 | ScalarValue::LargeBinary(None) => Ok(ScalarValue::Int64(None)),
5010 other => Err(datafusion::error::DataFusionError::Execution(format!(
5011 "_cypher_size(): unsupported type {other:?}"
5012 ))),
5013 }
5014}
5015
5016pub fn create_cypher_list_compare_udf() -> ScalarUDF {
5022 ScalarUDF::new_from_impl(CypherListCompareUdf::new())
5023}
5024
5025#[derive(Debug)]
5026struct CypherListCompareUdf {
5027 signature: Signature,
5028}
5029
5030impl CypherListCompareUdf {
5031 fn new() -> Self {
5032 Self {
5033 signature: Signature::any(3, Volatility::Immutable),
5034 }
5035 }
5036}
5037
5038impl_udf_eq_hash!(CypherListCompareUdf);
5039
5040impl ScalarUDFImpl for CypherListCompareUdf {
5041 fn as_any(&self) -> &dyn Any {
5042 self
5043 }
5044
5045 fn name(&self) -> &str {
5046 "_cypher_list_compare"
5047 }
5048
5049 fn signature(&self) -> &Signature {
5050 &self.signature
5051 }
5052
5053 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5054 Ok(DataType::Boolean)
5055 }
5056
5057 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5058 let output_type = DataType::Boolean;
5059 invoke_cypher_udf(args, &output_type, |val_args| {
5060 if val_args.len() != 3 {
5061 return Err(datafusion::error::DataFusionError::Execution(
5062 "_cypher_list_compare(): requires 3 arguments (left, right, op)".to_string(),
5063 ));
5064 }
5065
5066 let left = &val_args[0];
5067 let right = &val_args[1];
5068 let op_str = match &val_args[2] {
5069 Value::String(s) => s.as_str(),
5070 _ => {
5071 return Err(datafusion::error::DataFusionError::Execution(
5072 "_cypher_list_compare(): op must be a string".to_string(),
5073 ));
5074 }
5075 };
5076
5077 let (left_items, right_items) = match (left, right) {
5078 (Value::List(l), Value::List(r)) => (l, r),
5079 (Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
5080 _ => {
5081 return Err(datafusion::error::DataFusionError::Execution(
5082 "_cypher_list_compare(): both arguments must be lists".to_string(),
5083 ));
5084 }
5085 };
5086
5087 let cmp = cypher_list_cmp(left_items, right_items);
5089
5090 let result = match (op_str, cmp) {
5091 (_, None) => Value::Null,
5092 ("lt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Less),
5093 ("lteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Greater),
5094 ("gt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Greater),
5095 ("gteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Less),
5096 _ => {
5097 return Err(datafusion::error::DataFusionError::Execution(format!(
5098 "_cypher_list_compare(): unknown op '{}'",
5099 op_str
5100 )));
5101 }
5102 };
5103
5104 Ok(result)
5105 })
5106 }
5107}
5108
5109pub fn create_map_project_udf() -> ScalarUDF {
5114 ScalarUDF::new_from_impl(MapProjectUdf::new())
5115}
5116
5117#[derive(Debug)]
5118struct MapProjectUdf {
5119 signature: Signature,
5120}
5121
5122impl MapProjectUdf {
5123 fn new() -> Self {
5124 Self {
5125 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5126 }
5127 }
5128}
5129
5130impl_udf_eq_hash!(MapProjectUdf);
5131
5132impl ScalarUDFImpl for MapProjectUdf {
5133 fn as_any(&self) -> &dyn Any {
5134 self
5135 }
5136
5137 fn name(&self) -> &str {
5138 "_map_project"
5139 }
5140
5141 fn signature(&self) -> &Signature {
5142 &self.signature
5143 }
5144
5145 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5146 Ok(DataType::LargeBinary)
5147 }
5148
5149 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5150 let output_type = self.return_type(&[])?;
5151 invoke_cypher_udf(args, &output_type, |val_args| {
5152 let mut result_map = std::collections::HashMap::new();
5153 let mut i = 0;
5154 while i + 1 < val_args.len() {
5155 let key = &val_args[i];
5156 let value = &val_args[i + 1];
5157 if let Some(k) = key.as_str() {
5158 if k == "__all__" {
5159 match value {
5161 Value::Map(map) => {
5162 for (mk, mv) in map {
5163 if !mk.starts_with('_') {
5164 result_map.insert(mk.clone(), mv.clone());
5165 }
5166 }
5167 }
5168 Value::Node(node) => {
5169 for (pk, pv) in &node.properties {
5170 result_map.insert(pk.clone(), pv.clone());
5171 }
5172 }
5173 Value::Edge(edge) => {
5174 for (pk, pv) in &edge.properties {
5175 result_map.insert(pk.clone(), pv.clone());
5176 }
5177 }
5178 _ => {}
5179 }
5180 } else {
5181 result_map.insert(k.to_string(), value.clone());
5182 }
5183 }
5184 i += 2;
5185 }
5186 Ok(Value::Map(result_map))
5187 })
5188 }
5189}
5190
5191pub fn create_make_cypher_list_udf() -> ScalarUDF {
5196 ScalarUDF::new_from_impl(MakeCypherListUdf::new())
5197}
5198
5199#[derive(Debug)]
5200struct MakeCypherListUdf {
5201 signature: Signature,
5202}
5203
5204impl MakeCypherListUdf {
5205 fn new() -> Self {
5206 Self {
5207 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5208 }
5209 }
5210}
5211
5212impl_udf_eq_hash!(MakeCypherListUdf);
5213
5214impl ScalarUDFImpl for MakeCypherListUdf {
5215 fn as_any(&self) -> &dyn Any {
5216 self
5217 }
5218
5219 fn name(&self) -> &str {
5220 "_make_cypher_list"
5221 }
5222
5223 fn signature(&self) -> &Signature {
5224 &self.signature
5225 }
5226
5227 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5228 Ok(DataType::LargeBinary)
5229 }
5230
5231 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5232 let output_type = self.return_type(&[])?;
5233 invoke_cypher_udf(args, &output_type, |val_args| {
5234 Ok(Value::List(val_args.to_vec()))
5235 })
5236 }
5237}
5238
5239pub fn create_cypher_in_udf() -> ScalarUDF {
5256 ScalarUDF::new_from_impl(CypherInUdf::new())
5257}
5258
5259#[derive(Debug)]
5260struct CypherInUdf {
5261 signature: Signature,
5262}
5263
5264impl CypherInUdf {
5265 fn new() -> Self {
5266 Self {
5267 signature: Signature::any(2, Volatility::Immutable),
5268 }
5269 }
5270}
5271
5272impl_udf_eq_hash!(CypherInUdf);
5273
5274impl ScalarUDFImpl for CypherInUdf {
5275 fn as_any(&self) -> &dyn Any {
5276 self
5277 }
5278
5279 fn name(&self) -> &str {
5280 "_cypher_in"
5281 }
5282
5283 fn signature(&self) -> &Signature {
5284 &self.signature
5285 }
5286
5287 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5288 Ok(DataType::Boolean)
5289 }
5290
5291 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5292 invoke_cypher_udf(args, &DataType::Boolean, |vals| {
5293 if vals.len() != 2 {
5294 return Err(datafusion::error::DataFusionError::Execution(
5295 "_cypher_in(): requires 2 arguments".to_string(),
5296 ));
5297 }
5298 let element = &vals[0];
5299 let list_val = &vals[1];
5300
5301 if list_val.is_null() {
5303 return Ok(Value::Null);
5304 }
5305
5306 let items = match list_val {
5308 Value::List(items) => items.as_slice(),
5309 _ => {
5310 return Err(datafusion::error::DataFusionError::Execution(format!(
5311 "_cypher_in(): second argument must be a list, got {:?}",
5312 list_val
5313 )));
5314 }
5315 };
5316
5317 if element.is_null() {
5319 return if items.is_empty() {
5320 Ok(Value::Bool(false))
5321 } else {
5322 Ok(Value::Null) };
5324 }
5325
5326 let mut has_null = false;
5328 for item in items {
5329 match cypher_eq(element, item) {
5330 Some(true) => return Ok(Value::Bool(true)),
5331 None => has_null = true,
5332 Some(false) => {}
5333 }
5334 }
5335
5336 if has_null {
5337 Ok(Value::Null) } else {
5339 Ok(Value::Bool(false))
5340 }
5341 })
5342 }
5343}
5344
5345pub fn create_cypher_list_concat_udf() -> ScalarUDF {
5351 ScalarUDF::new_from_impl(CypherListConcatUdf::new())
5352}
5353
5354#[derive(Debug)]
5355struct CypherListConcatUdf {
5356 signature: Signature,
5357}
5358
5359impl CypherListConcatUdf {
5360 fn new() -> Self {
5361 Self {
5362 signature: Signature::any(2, Volatility::Immutable),
5363 }
5364 }
5365}
5366
5367impl_udf_eq_hash!(CypherListConcatUdf);
5368
5369impl ScalarUDFImpl for CypherListConcatUdf {
5370 fn as_any(&self) -> &dyn Any {
5371 self
5372 }
5373
5374 fn name(&self) -> &str {
5375 "_cypher_list_concat"
5376 }
5377
5378 fn signature(&self) -> &Signature {
5379 &self.signature
5380 }
5381
5382 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5383 Ok(DataType::LargeBinary)
5384 }
5385
5386 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5387 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5388 if vals.len() != 2 {
5389 return Err(datafusion::error::DataFusionError::Execution(
5390 "_cypher_list_concat(): requires 2 arguments".to_string(),
5391 ));
5392 }
5393 if vals[0].is_null() || vals[1].is_null() {
5395 return Ok(Value::Null);
5396 }
5397 match (&vals[0], &vals[1]) {
5398 (Value::List(left), Value::List(right)) => {
5399 let mut result = left.clone();
5400 result.extend(right.iter().cloned());
5401 Ok(Value::List(result))
5402 }
5403 (Value::List(list), elem) => {
5406 let mut result = list.clone();
5407 result.push(elem.clone());
5408 Ok(Value::List(result))
5409 }
5410 (elem, Value::List(list)) => {
5411 let mut result = vec![elem.clone()];
5412 result.extend(list.iter().cloned());
5413 Ok(Value::List(result))
5414 }
5415 _ => {
5416 crate::expr_eval::eval_binary_op(
5419 &vals[0],
5420 &uni_cypher::ast::BinaryOp::Add,
5421 &vals[1],
5422 )
5423 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
5424 }
5425 }
5426 })
5427 }
5428}
5429
5430pub fn create_cypher_list_append_udf() -> ScalarUDF {
5436 ScalarUDF::new_from_impl(CypherListAppendUdf::new())
5437}
5438
5439#[derive(Debug)]
5440struct CypherListAppendUdf {
5441 signature: Signature,
5442}
5443
5444impl CypherListAppendUdf {
5445 fn new() -> Self {
5446 Self {
5447 signature: Signature::any(2, Volatility::Immutable),
5448 }
5449 }
5450}
5451
5452impl_udf_eq_hash!(CypherListAppendUdf);
5453
5454impl ScalarUDFImpl for CypherListAppendUdf {
5455 fn as_any(&self) -> &dyn Any {
5456 self
5457 }
5458
5459 fn name(&self) -> &str {
5460 "_cypher_list_append"
5461 }
5462
5463 fn signature(&self) -> &Signature {
5464 &self.signature
5465 }
5466
5467 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5468 Ok(DataType::LargeBinary)
5469 }
5470
5471 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5472 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5473 if vals.len() != 2 {
5474 return Err(datafusion::error::DataFusionError::Execution(
5475 "_cypher_list_append(): requires 2 arguments".to_string(),
5476 ));
5477 }
5478 let left = &vals[0];
5479 let right = &vals[1];
5480
5481 if left.is_null() || right.is_null() {
5483 return Ok(Value::Null);
5484 }
5485
5486 match (left, right) {
5487 (Value::List(list), elem) => {
5489 let mut result = list.clone();
5490 result.push(elem.clone());
5491 Ok(Value::List(result))
5492 }
5493 (elem, Value::List(list)) => {
5495 let mut result = vec![elem.clone()];
5496 result.extend(list.iter().cloned());
5497 Ok(Value::List(result))
5498 }
5499 _ => Err(datafusion::error::DataFusionError::Execution(format!(
5500 "_cypher_list_append(): at least one argument must be a list, got {:?} and {:?}",
5501 left, right
5502 ))),
5503 }
5504 })
5505 }
5506}
5507
5508pub fn create_cypher_list_slice_udf() -> ScalarUDF {
5514 ScalarUDF::new_from_impl(CypherListSliceUdf::new())
5515}
5516
5517#[derive(Debug)]
5518struct CypherListSliceUdf {
5519 signature: Signature,
5520}
5521
5522impl CypherListSliceUdf {
5523 fn new() -> Self {
5524 Self {
5525 signature: Signature::any(3, Volatility::Immutable),
5526 }
5527 }
5528}
5529
5530impl_udf_eq_hash!(CypherListSliceUdf);
5531
5532impl ScalarUDFImpl for CypherListSliceUdf {
5533 fn as_any(&self) -> &dyn Any {
5534 self
5535 }
5536
5537 fn name(&self) -> &str {
5538 "_cypher_list_slice"
5539 }
5540
5541 fn signature(&self) -> &Signature {
5542 &self.signature
5543 }
5544
5545 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5546 Ok(DataType::LargeBinary)
5547 }
5548
5549 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5550 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5551 if vals.len() != 3 {
5552 return Err(datafusion::error::DataFusionError::Execution(
5553 "_cypher_list_slice(): requires 3 arguments (list, start, end)".to_string(),
5554 ));
5555 }
5556 if vals[0].is_null() {
5558 return Ok(Value::Null);
5559 }
5560 let list = match &vals[0] {
5561 Value::List(l) => l,
5562 _ => {
5563 return Err(datafusion::error::DataFusionError::Execution(format!(
5564 "_cypher_list_slice(): first argument must be a list, got {:?}",
5565 vals[0]
5566 )));
5567 }
5568 };
5569 if vals[1].is_null() || vals[2].is_null() {
5571 return Ok(Value::Null);
5572 }
5573
5574 let len = list.len() as i64;
5575 let raw_start = match &vals[1] {
5576 Value::Int(i) => *i,
5577 _ => 0,
5578 };
5579 let raw_end = match &vals[2] {
5580 Value::Int(i) => *i,
5581 _ => len,
5582 };
5583
5584 let start = if raw_start < 0 {
5586 (len + raw_start).max(0) as usize
5587 } else {
5588 (raw_start).min(len) as usize
5589 };
5590 let end = if raw_end == i64::MAX {
5591 len as usize
5592 } else if raw_end < 0 {
5593 (len + raw_end).max(0) as usize
5594 } else {
5595 (raw_end).min(len) as usize
5596 };
5597
5598 if start >= end {
5599 return Ok(Value::List(vec![]));
5600 }
5601 Ok(Value::List(list[start..end.min(list.len())].to_vec()))
5602 })
5603 }
5604}
5605
5606pub fn create_cypher_reverse_udf() -> ScalarUDF {
5617 ScalarUDF::new_from_impl(CypherReverseUdf::new())
5618}
5619
5620#[derive(Debug)]
5621struct CypherReverseUdf {
5622 signature: Signature,
5623}
5624
5625impl CypherReverseUdf {
5626 fn new() -> Self {
5627 Self {
5628 signature: Signature::any(1, Volatility::Immutable),
5629 }
5630 }
5631}
5632
5633impl_udf_eq_hash!(CypherReverseUdf);
5634
5635impl ScalarUDFImpl for CypherReverseUdf {
5636 fn as_any(&self) -> &dyn Any {
5637 self
5638 }
5639
5640 fn name(&self) -> &str {
5641 "_cypher_reverse"
5642 }
5643
5644 fn signature(&self) -> &Signature {
5645 &self.signature
5646 }
5647
5648 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5649 Ok(DataType::LargeBinary)
5650 }
5651
5652 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5653 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5654 if vals.len() != 1 {
5655 return Err(datafusion::error::DataFusionError::Execution(
5656 "_cypher_reverse(): requires exactly 1 argument".to_string(),
5657 ));
5658 }
5659 match &vals[0] {
5660 Value::Null => Ok(Value::Null),
5661 Value::String(s) => Ok(Value::String(s.chars().rev().collect())),
5662 Value::List(l) => {
5663 let mut reversed = l.clone();
5664 reversed.reverse();
5665 Ok(Value::List(reversed))
5666 }
5667 other => Err(datafusion::error::DataFusionError::Execution(format!(
5668 "_cypher_reverse(): expected string or list, got {:?}",
5669 other
5670 ))),
5671 }
5672 })
5673 }
5674}
5675
5676pub fn create_cypher_substring_udf() -> ScalarUDF {
5687 ScalarUDF::new_from_impl(CypherSubstringUdf::new())
5688}
5689
5690#[derive(Debug)]
5691struct CypherSubstringUdf {
5692 signature: Signature,
5693}
5694
5695impl CypherSubstringUdf {
5696 fn new() -> Self {
5697 Self {
5698 signature: Signature::variadic_any(Volatility::Immutable),
5699 }
5700 }
5701}
5702
5703impl_udf_eq_hash!(CypherSubstringUdf);
5704
5705impl ScalarUDFImpl for CypherSubstringUdf {
5706 fn as_any(&self) -> &dyn Any {
5707 self
5708 }
5709
5710 fn name(&self) -> &str {
5711 "_cypher_substring"
5712 }
5713
5714 fn signature(&self) -> &Signature {
5715 &self.signature
5716 }
5717
5718 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5719 Ok(DataType::Utf8)
5720 }
5721
5722 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5723 invoke_cypher_udf(args, &DataType::Utf8, |vals| {
5724 if vals.len() < 2 || vals.len() > 3 {
5725 return Err(datafusion::error::DataFusionError::Execution(
5726 "_cypher_substring(): requires 2 or 3 arguments".to_string(),
5727 ));
5728 }
5729 if vals.iter().any(|v| v.is_null()) {
5731 return Ok(Value::Null);
5732 }
5733 let s = match &vals[0] {
5734 Value::String(s) => s.as_str(),
5735 other => {
5736 return Err(datafusion::error::DataFusionError::Execution(format!(
5737 "_cypher_substring(): first argument must be a string, got {:?}",
5738 other
5739 )));
5740 }
5741 };
5742 let start = match &vals[1] {
5743 Value::Int(i) => *i,
5744 other => {
5745 return Err(datafusion::error::DataFusionError::Execution(format!(
5746 "_cypher_substring(): second argument must be an integer, got {:?}",
5747 other
5748 )));
5749 }
5750 };
5751
5752 let chars: Vec<char> = s.chars().collect();
5754 let len = chars.len() as i64;
5755
5756 let start_idx = start.max(0).min(len) as usize;
5758
5759 let end_idx = if vals.len() == 3 {
5760 let length = match &vals[2] {
5761 Value::Int(i) => *i,
5762 other => {
5763 return Err(datafusion::error::DataFusionError::Execution(format!(
5764 "_cypher_substring(): third argument must be an integer, got {:?}",
5765 other
5766 )));
5767 }
5768 };
5769 if length < 0 {
5770 return Err(datafusion::error::DataFusionError::Execution(
5771 "ArgumentError: NegativeIntegerArgument - substring length must be non-negative".to_string(),
5772 ));
5773 }
5774 (start_idx as i64 + length).min(len) as usize
5775 } else {
5776 len as usize
5777 };
5778
5779 Ok(Value::String(chars[start_idx..end_idx].iter().collect()))
5780 })
5781 }
5782}
5783
5784pub fn create_cypher_split_udf() -> ScalarUDF {
5793 ScalarUDF::new_from_impl(CypherSplitUdf::new())
5794}
5795
5796#[derive(Debug)]
5797struct CypherSplitUdf {
5798 signature: Signature,
5799}
5800
5801impl CypherSplitUdf {
5802 fn new() -> Self {
5803 Self {
5804 signature: Signature::any(2, Volatility::Immutable),
5805 }
5806 }
5807}
5808
5809impl_udf_eq_hash!(CypherSplitUdf);
5810
5811impl ScalarUDFImpl for CypherSplitUdf {
5812 fn as_any(&self) -> &dyn Any {
5813 self
5814 }
5815
5816 fn name(&self) -> &str {
5817 "_cypher_split"
5818 }
5819
5820 fn signature(&self) -> &Signature {
5821 &self.signature
5822 }
5823
5824 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5825 Ok(DataType::LargeBinary)
5826 }
5827
5828 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5829 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5830 if vals.len() != 2 {
5831 return Err(datafusion::error::DataFusionError::Execution(
5832 "_cypher_split(): requires exactly 2 arguments".to_string(),
5833 ));
5834 }
5835 if vals.iter().any(|v| v.is_null()) {
5837 return Ok(Value::Null);
5838 }
5839 let s = match &vals[0] {
5840 Value::String(s) => s.clone(),
5841 other => {
5842 return Err(datafusion::error::DataFusionError::Execution(format!(
5843 "_cypher_split(): first argument must be a string, got {:?}",
5844 other
5845 )));
5846 }
5847 };
5848 let delimiter = match &vals[1] {
5849 Value::String(d) => d.clone(),
5850 other => {
5851 return Err(datafusion::error::DataFusionError::Execution(format!(
5852 "_cypher_split(): second argument must be a string, got {:?}",
5853 other
5854 )));
5855 }
5856 };
5857 let parts: Vec<Value> = s
5858 .split(&delimiter)
5859 .map(|p| Value::String(p.to_string()))
5860 .collect();
5861 Ok(Value::List(parts))
5862 })
5863 }
5864}
5865
5866pub fn create_cypher_list_to_cv_udf() -> ScalarUDF {
5877 ScalarUDF::new_from_impl(CypherListToCvUdf::new())
5878}
5879
5880#[derive(Debug)]
5881struct CypherListToCvUdf {
5882 signature: Signature,
5883}
5884
5885impl CypherListToCvUdf {
5886 fn new() -> Self {
5887 Self {
5888 signature: Signature::any(1, Volatility::Immutable),
5889 }
5890 }
5891}
5892
5893impl_udf_eq_hash!(CypherListToCvUdf);
5894
5895impl ScalarUDFImpl for CypherListToCvUdf {
5896 fn as_any(&self) -> &dyn Any {
5897 self
5898 }
5899
5900 fn name(&self) -> &str {
5901 "_cypher_list_to_cv"
5902 }
5903
5904 fn signature(&self) -> &Signature {
5905 &self.signature
5906 }
5907
5908 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5909 Ok(DataType::LargeBinary)
5910 }
5911
5912 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5913 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5914 if vals.len() != 1 {
5915 return Err(datafusion::error::DataFusionError::Execution(
5916 "_cypher_list_to_cv(): requires exactly 1 argument".to_string(),
5917 ));
5918 }
5919 Ok(vals[0].clone())
5920 })
5921 }
5922}
5923
5924pub fn create_cypher_scalar_to_cv_udf() -> ScalarUDF {
5935 ScalarUDF::new_from_impl(CypherScalarToCvUdf::new())
5936}
5937
5938#[derive(Debug)]
5939struct CypherScalarToCvUdf {
5940 signature: Signature,
5941}
5942
5943impl CypherScalarToCvUdf {
5944 fn new() -> Self {
5945 Self {
5946 signature: Signature::any(1, Volatility::Immutable),
5947 }
5948 }
5949}
5950
5951impl_udf_eq_hash!(CypherScalarToCvUdf);
5952
5953impl ScalarUDFImpl for CypherScalarToCvUdf {
5954 fn as_any(&self) -> &dyn Any {
5955 self
5956 }
5957
5958 fn name(&self) -> &str {
5959 "_cypher_scalar_to_cv"
5960 }
5961
5962 fn signature(&self) -> &Signature {
5963 &self.signature
5964 }
5965
5966 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5967 Ok(DataType::LargeBinary)
5968 }
5969
5970 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5971 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5972 if vals.len() != 1 {
5973 return Err(datafusion::error::DataFusionError::Execution(
5974 "_cypher_scalar_to_cv(): requires exactly 1 argument".to_string(),
5975 ));
5976 }
5977 Ok(vals[0].clone())
5978 })
5979 }
5980}
5981
5982pub fn create_cypher_tail_udf() -> ScalarUDF {
5994 ScalarUDF::new_from_impl(CypherTailUdf::new())
5995}
5996
5997#[derive(Debug)]
5998struct CypherTailUdf {
5999 signature: Signature,
6000}
6001
6002impl CypherTailUdf {
6003 fn new() -> Self {
6004 Self {
6005 signature: Signature::any(1, Volatility::Immutable),
6006 }
6007 }
6008}
6009
6010impl_udf_eq_hash!(CypherTailUdf);
6011
6012impl ScalarUDFImpl for CypherTailUdf {
6013 fn as_any(&self) -> &dyn Any {
6014 self
6015 }
6016
6017 fn name(&self) -> &str {
6018 "_cypher_tail"
6019 }
6020
6021 fn signature(&self) -> &Signature {
6022 &self.signature
6023 }
6024
6025 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6026 Ok(DataType::LargeBinary)
6027 }
6028
6029 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6030 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6031 if vals.len() != 1 {
6032 return Err(datafusion::error::DataFusionError::Execution(
6033 "_cypher_tail(): requires exactly 1 argument".to_string(),
6034 ));
6035 }
6036 match &vals[0] {
6037 Value::Null => Ok(Value::Null),
6038 Value::List(l) => {
6039 if l.is_empty() {
6040 Ok(Value::List(vec![]))
6041 } else {
6042 Ok(Value::List(l[1..].to_vec()))
6043 }
6044 }
6045 other => Err(datafusion::error::DataFusionError::Execution(format!(
6046 "_cypher_tail(): expected list, got {:?}",
6047 other
6048 ))),
6049 }
6050 })
6051 }
6052}
6053
6054pub fn create_cypher_head_udf() -> ScalarUDF {
6065 ScalarUDF::new_from_impl(CypherHeadUdf::new())
6066}
6067
6068#[derive(Debug)]
6069struct CypherHeadUdf {
6070 signature: Signature,
6071}
6072
6073impl CypherHeadUdf {
6074 fn new() -> Self {
6075 Self {
6076 signature: Signature::any(1, Volatility::Immutable),
6077 }
6078 }
6079}
6080
6081impl_udf_eq_hash!(CypherHeadUdf);
6082
6083impl ScalarUDFImpl for CypherHeadUdf {
6084 fn as_any(&self) -> &dyn Any {
6085 self
6086 }
6087
6088 fn name(&self) -> &str {
6089 "head"
6090 }
6091
6092 fn signature(&self) -> &Signature {
6093 &self.signature
6094 }
6095
6096 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6097 Ok(DataType::LargeBinary)
6098 }
6099
6100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6101 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6102 if vals.len() != 1 {
6103 return Err(datafusion::error::DataFusionError::Execution(
6104 "head(): requires exactly 1 argument".to_string(),
6105 ));
6106 }
6107 match &vals[0] {
6108 Value::Null => Ok(Value::Null),
6109 Value::List(l) => Ok(l.first().cloned().unwrap_or(Value::Null)),
6110 other => Err(datafusion::error::DataFusionError::Execution(format!(
6111 "head(): expected list, got {:?}",
6112 other
6113 ))),
6114 }
6115 })
6116 }
6117}
6118
6119pub fn create_cypher_last_udf() -> ScalarUDF {
6130 ScalarUDF::new_from_impl(CypherLastUdf::new())
6131}
6132
6133#[derive(Debug)]
6134struct CypherLastUdf {
6135 signature: Signature,
6136}
6137
6138impl CypherLastUdf {
6139 fn new() -> Self {
6140 Self {
6141 signature: Signature::any(1, Volatility::Immutable),
6142 }
6143 }
6144}
6145
6146impl_udf_eq_hash!(CypherLastUdf);
6147
6148impl ScalarUDFImpl for CypherLastUdf {
6149 fn as_any(&self) -> &dyn Any {
6150 self
6151 }
6152
6153 fn name(&self) -> &str {
6154 "last"
6155 }
6156
6157 fn signature(&self) -> &Signature {
6158 &self.signature
6159 }
6160
6161 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6162 Ok(DataType::LargeBinary)
6163 }
6164
6165 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6166 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6167 if vals.len() != 1 {
6168 return Err(datafusion::error::DataFusionError::Execution(
6169 "last(): requires exactly 1 argument".to_string(),
6170 ));
6171 }
6172 match &vals[0] {
6173 Value::Null => Ok(Value::Null),
6174 Value::List(l) => Ok(l.last().cloned().unwrap_or(Value::Null)),
6175 other => Err(datafusion::error::DataFusionError::Execution(format!(
6176 "last(): expected list, got {:?}",
6177 other
6178 ))),
6179 }
6180 })
6181 }
6182}
6183
6184fn cypher_list_cmp(left: &[Value], right: &[Value]) -> Option<std::cmp::Ordering> {
6187 let min_len = left.len().min(right.len());
6188 for i in 0..min_len {
6189 let cmp = cypher_value_cmp(&left[i], &right[i])?;
6190 if cmp != std::cmp::Ordering::Equal {
6191 return Some(cmp);
6192 }
6193 }
6194 Some(left.len().cmp(&right.len()))
6196}
6197
6198fn cypher_value_cmp(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
6201 match (a, b) {
6202 (Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
6203 (Value::Null, _) | (_, Value::Null) => None,
6204 (Value::Int(l), Value::Int(r)) => Some(l.cmp(r)),
6205 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r),
6206 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r),
6207 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)),
6208 (Value::String(l), Value::String(r)) => Some(l.cmp(r)),
6209 (Value::Bool(l), Value::Bool(r)) => Some(l.cmp(r)),
6210 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r),
6211 _ => None, }
6213}
6214
6215struct CypherToFloat64Udf {
6223 signature: Signature,
6224}
6225
6226impl CypherToFloat64Udf {
6227 fn new() -> Self {
6228 Self {
6229 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6230 }
6231 }
6232}
6233
6234impl_udf_eq_hash!(CypherToFloat64Udf);
6235
6236impl std::fmt::Debug for CypherToFloat64Udf {
6237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6238 f.debug_struct("CypherToFloat64Udf").finish()
6239 }
6240}
6241
6242impl ScalarUDFImpl for CypherToFloat64Udf {
6243 fn as_any(&self) -> &dyn Any {
6244 self
6245 }
6246 fn name(&self) -> &str {
6247 "_cypher_to_float64"
6248 }
6249 fn signature(&self) -> &Signature {
6250 &self.signature
6251 }
6252 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6253 Ok(DataType::Float64)
6254 }
6255 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6256 if args.args.len() != 1 {
6257 return Err(datafusion::error::DataFusionError::Execution(
6258 "_cypher_to_float64 requires exactly 1 argument".into(),
6259 ));
6260 }
6261 match &args.args[0] {
6262 ColumnarValue::Scalar(scalar) => {
6263 let f = match scalar {
6264 ScalarValue::LargeBinary(Some(bytes)) => cv_bytes_as_f64(bytes),
6265 ScalarValue::Int64(Some(i)) => Some(*i as f64),
6266 ScalarValue::Int32(Some(i)) => Some(*i as f64),
6267 ScalarValue::Float64(Some(f)) => Some(*f),
6268 ScalarValue::Float32(Some(f)) => Some(*f as f64),
6269 _ => None,
6270 };
6271 Ok(ColumnarValue::Scalar(ScalarValue::Float64(f)))
6272 }
6273 ColumnarValue::Array(arr) => {
6274 let len = arr.len();
6275 let mut builder = arrow::array::Float64Builder::with_capacity(len);
6276 match arr.data_type() {
6277 DataType::LargeBinary => {
6278 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6279 for i in 0..len {
6280 if lb.is_null(i) {
6281 builder.append_null();
6282 } else {
6283 match cv_bytes_as_f64(lb.value(i)) {
6284 Some(f) => builder.append_value(f),
6285 None => builder.append_null(),
6286 }
6287 }
6288 }
6289 }
6290 DataType::Int64 => {
6291 let int_arr = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6292 for i in 0..len {
6293 if int_arr.is_null(i) {
6294 builder.append_null();
6295 } else {
6296 builder.append_value(int_arr.value(i) as f64);
6297 }
6298 }
6299 }
6300 DataType::Float64 => {
6301 let f_arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6302 for i in 0..len {
6303 if f_arr.is_null(i) {
6304 builder.append_null();
6305 } else {
6306 builder.append_value(f_arr.value(i));
6307 }
6308 }
6309 }
6310 _ => {
6311 for _ in 0..len {
6312 builder.append_null();
6313 }
6314 }
6315 }
6316 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
6317 }
6318 }
6319 }
6320}
6321
6322fn create_cypher_to_float64_udf() -> ScalarUDF {
6323 ScalarUDF::from(CypherToFloat64Udf::new())
6324}
6325
6326pub fn cypher_to_float64_expr(
6328 arg: datafusion::logical_expr::Expr,
6329) -> datafusion::logical_expr::Expr {
6330 datafusion::logical_expr::Expr::ScalarFunction(
6331 datafusion::logical_expr::expr::ScalarFunction::new_udf(
6332 Arc::new(create_cypher_to_float64_udf()),
6333 vec![arg],
6334 ),
6335 )
6336}
6337
6338pub fn cypher_to_float64_udf() -> datafusion::logical_expr::ScalarUDF {
6340 create_cypher_to_float64_udf()
6341}
6342
6343fn cypher_type_rank(val: &Value) -> u8 {
6351 match val {
6352 Value::Null => 0,
6353 Value::List(_) => 1,
6354 Value::String(_) => 2,
6355 Value::Bool(_) => 3,
6356 Value::Int(_) | Value::Float(_) => 4,
6357 _ => 5, }
6359}
6360
6361fn cypher_cross_type_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
6364 use std::cmp::Ordering;
6365 let ra = cypher_type_rank(a);
6366 let rb = cypher_type_rank(b);
6367 if ra != rb {
6368 return ra.cmp(&rb);
6369 }
6370 match (a, b) {
6372 (Value::Int(l), Value::Int(r)) => l.cmp(r),
6373 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
6374 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r).unwrap_or(Ordering::Equal),
6375 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)).unwrap_or(Ordering::Equal),
6376 (Value::String(l), Value::String(r)) => l.cmp(r),
6377 (Value::Bool(l), Value::Bool(r)) => l.cmp(r),
6378 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r).unwrap_or(Ordering::Equal),
6379 _ => Ordering::Equal,
6380 }
6381}
6382
6383fn scalar_binary_to_value(bytes: &[u8]) -> Value {
6385 uni_common::cypher_value_codec::decode(bytes).unwrap_or(Value::Null)
6386}
6387
6388use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, AggregateUDFImpl};
6389
6390#[derive(Debug, Clone)]
6392struct CypherMinMaxUdaf {
6393 name: String,
6394 signature: Signature,
6395 is_max: bool,
6396}
6397
6398impl CypherMinMaxUdaf {
6399 fn new(is_max: bool) -> Self {
6400 let name = if is_max { "_cypher_max" } else { "_cypher_min" };
6401 Self {
6402 name: name.to_string(),
6403 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6404 is_max,
6405 }
6406 }
6407}
6408
6409impl PartialEq for CypherMinMaxUdaf {
6410 fn eq(&self, other: &Self) -> bool {
6411 self.name == other.name
6412 }
6413}
6414
6415impl Eq for CypherMinMaxUdaf {}
6416
6417impl Hash for CypherMinMaxUdaf {
6418 fn hash<H: Hasher>(&self, state: &mut H) {
6419 self.name.hash(state);
6420 }
6421}
6422
6423impl AggregateUDFImpl for CypherMinMaxUdaf {
6424 fn as_any(&self) -> &dyn Any {
6425 self
6426 }
6427 fn name(&self) -> &str {
6428 &self.name
6429 }
6430 fn signature(&self) -> &Signature {
6431 &self.signature
6432 }
6433 fn return_type(&self, args: &[DataType]) -> DFResult<DataType> {
6434 Ok(args.first().cloned().unwrap_or(DataType::LargeBinary))
6436 }
6437 fn accumulator(
6438 &self,
6439 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6440 ) -> DFResult<Box<dyn DfAccumulator>> {
6441 Ok(Box::new(CypherMinMaxAccumulator {
6442 current: None,
6443 is_max: self.is_max,
6444 return_type: acc_args.return_field.data_type().clone(),
6445 }))
6446 }
6447 fn state_fields(
6448 &self,
6449 args: datafusion::logical_expr::function::StateFieldsArgs,
6450 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6451 Ok(vec![Arc::new(arrow::datatypes::Field::new(
6452 args.name,
6453 DataType::LargeBinary,
6454 true,
6455 ))])
6456 }
6457}
6458
/// Running state of one `_cypher_min` / `_cypher_max` aggregation group.
#[derive(Debug)]
struct CypherMinMaxAccumulator {
    /// Best (smallest or largest) non-null value seen so far; `None` until the
    /// first non-null input arrives.
    current: Option<Value>,
    /// `true` keeps the maximum, `false` keeps the minimum.
    is_max: bool,
    /// Output type taken from the planner's return field. `evaluate` emits a
    /// scalar of this type when the value fits it, and otherwise falls back to
    /// the encoded `LargeBinary` form.
    return_type: DataType,
}
6465
6466impl DfAccumulator for CypherMinMaxAccumulator {
6467 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6468 let arr = &values[0];
6469 match arr.data_type() {
6470 DataType::LargeBinary => {
6471 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6472 for i in 0..lb.len() {
6473 if lb.is_null(i) {
6474 continue;
6475 }
6476 let val = scalar_binary_to_value(lb.value(i));
6477 if val.is_null() {
6478 continue;
6479 }
6480 self.current = Some(match self.current.take() {
6481 None => val,
6482 Some(cur) => {
6483 let ord = cypher_cross_type_cmp(&val, &cur);
6484 if (self.is_max && ord == std::cmp::Ordering::Greater)
6485 || (!self.is_max && ord == std::cmp::Ordering::Less)
6486 {
6487 val
6488 } else {
6489 cur
6490 }
6491 }
6492 });
6493 }
6494 }
6495 _ => {
6496 for i in 0..arr.len() {
6498 if arr.is_null(i) {
6499 continue;
6500 }
6501 let sv = ScalarValue::try_from_array(arr, i).map_err(|e| {
6502 datafusion::error::DataFusionError::Execution(e.to_string())
6503 })?;
6504 let val = scalar_to_value(&sv)?;
6505 if val.is_null() {
6506 continue;
6507 }
6508 self.current = Some(match self.current.take() {
6509 None => val,
6510 Some(cur) => {
6511 let ord = cypher_cross_type_cmp(&val, &cur);
6512 if (self.is_max && ord == std::cmp::Ordering::Greater)
6513 || (!self.is_max && ord == std::cmp::Ordering::Less)
6514 {
6515 val
6516 } else {
6517 cur
6518 }
6519 }
6520 });
6521 }
6522 }
6523 }
6524 Ok(())
6525 }
6526 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6527 match &self.current {
6528 None => {
6529 ScalarValue::try_from(&self.return_type)
6531 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
6532 }
6533 Some(val) => {
6534 if matches!(self.return_type, DataType::LargeBinary) {
6536 let bytes = uni_common::cypher_value_codec::encode(val);
6537 return Ok(ScalarValue::LargeBinary(Some(bytes)));
6538 }
6539 match val {
6541 Value::Int(i) => match &self.return_type {
6542 DataType::Int64 => Ok(ScalarValue::Int64(Some(*i))),
6543 DataType::UInt64 => Ok(ScalarValue::UInt64(Some(*i as u64))),
6544 _ => {
6545 let bytes = uni_common::cypher_value_codec::encode(val);
6546 Ok(ScalarValue::LargeBinary(Some(bytes)))
6547 }
6548 },
6549 Value::Float(f) => match &self.return_type {
6550 DataType::Float64 => Ok(ScalarValue::Float64(Some(*f))),
6551 _ => {
6552 let bytes = uni_common::cypher_value_codec::encode(val);
6553 Ok(ScalarValue::LargeBinary(Some(bytes)))
6554 }
6555 },
6556 Value::String(s) => match &self.return_type {
6557 DataType::Utf8 => Ok(ScalarValue::Utf8(Some(s.clone()))),
6558 DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
6559 _ => {
6560 let bytes = uni_common::cypher_value_codec::encode(val);
6561 Ok(ScalarValue::LargeBinary(Some(bytes)))
6562 }
6563 },
6564 Value::Bool(b) => match &self.return_type {
6565 DataType::Boolean => Ok(ScalarValue::Boolean(Some(*b))),
6566 _ => {
6567 let bytes = uni_common::cypher_value_codec::encode(val);
6568 Ok(ScalarValue::LargeBinary(Some(bytes)))
6569 }
6570 },
6571 _ => {
6572 let bytes = uni_common::cypher_value_codec::encode(val);
6574 Ok(ScalarValue::LargeBinary(Some(bytes)))
6575 }
6576 }
6577 }
6578 }
6579 }
6580 fn size(&self) -> usize {
6581 std::mem::size_of_val(self) + self.current.as_ref().map_or(0, |_| 64)
6582 }
6583 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6584 Ok(vec![self.evaluate()?])
6585 }
6586 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6587 self.update_batch(states)
6588 }
6589}
6590
6591pub fn create_cypher_min_udaf() -> AggregateUDF {
6592 AggregateUDF::from(CypherMinMaxUdaf::new(false))
6593}
6594
6595pub fn create_cypher_max_udaf() -> AggregateUDF {
6596 AggregateUDF::from(CypherMinMaxUdaf::new(true))
6597}
6598
6599#[derive(Debug, Clone)]
6605struct CypherSumUdaf {
6606 signature: Signature,
6607}
6608
6609impl CypherSumUdaf {
6610 fn new() -> Self {
6611 Self {
6612 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6613 }
6614 }
6615}
6616
6617impl PartialEq for CypherSumUdaf {
6618 fn eq(&self, other: &Self) -> bool {
6619 self.signature == other.signature
6620 }
6621}
6622
6623impl Eq for CypherSumUdaf {}
6624
6625impl Hash for CypherSumUdaf {
6626 fn hash<H: Hasher>(&self, state: &mut H) {
6627 self.name().hash(state);
6628 }
6629}
6630
6631impl AggregateUDFImpl for CypherSumUdaf {
6632 fn as_any(&self) -> &dyn Any {
6633 self
6634 }
6635 fn name(&self) -> &str {
6636 "_cypher_sum"
6637 }
6638 fn signature(&self) -> &Signature {
6639 &self.signature
6640 }
6641 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6642 Ok(DataType::LargeBinary)
6645 }
6646 fn accumulator(
6647 &self,
6648 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6649 ) -> DFResult<Box<dyn DfAccumulator>> {
6650 Ok(Box::new(CypherSumAccumulator {
6651 sum: 0.0,
6652 all_ints: true,
6653 int_sum: 0i64,
6654 has_value: false,
6655 }))
6656 }
6657 fn state_fields(
6658 &self,
6659 args: datafusion::logical_expr::function::StateFieldsArgs,
6660 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6661 Ok(vec![
6662 Arc::new(arrow::datatypes::Field::new(
6663 format!("{}_sum", args.name),
6664 DataType::Float64,
6665 true,
6666 )),
6667 Arc::new(arrow::datatypes::Field::new(
6668 format!("{}_int_sum", args.name),
6669 DataType::Int64,
6670 true,
6671 )),
6672 Arc::new(arrow::datatypes::Field::new(
6673 format!("{}_all_ints", args.name),
6674 DataType::Boolean,
6675 true,
6676 )),
6677 Arc::new(arrow::datatypes::Field::new(
6678 format!("{}_has_value", args.name),
6679 DataType::Boolean,
6680 true,
6681 )),
6682 ])
6683 }
6684}
6685
6686#[derive(Debug)]
6687struct CypherSumAccumulator {
6688 sum: f64,
6689 all_ints: bool,
6690 int_sum: i64,
6691 has_value: bool,
6692}
6693
6694impl DfAccumulator for CypherSumAccumulator {
6695 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6696 let arr = &values[0];
6697 for i in 0..arr.len() {
6698 if arr.is_null(i) {
6699 continue;
6700 }
6701 match arr.data_type() {
6702 DataType::LargeBinary => {
6703 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6704 let bytes = lb.value(i);
6705 use uni_common::cypher_value_codec::{
6706 TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag,
6707 };
6708 match peek_tag(bytes) {
6709 Some(TAG_INT) => {
6710 if let Some(v) = decode_int(bytes) {
6711 self.sum += v as f64;
6712 self.int_sum = self.int_sum.wrapping_add(v);
6713 self.has_value = true;
6714 }
6715 }
6716 Some(TAG_FLOAT) => {
6717 if let Some(v) = decode_float(bytes) {
6718 self.sum += v;
6719 self.all_ints = false;
6720 self.has_value = true;
6721 }
6722 }
6723 _ => {} }
6725 }
6726 DataType::Int64 => {
6727 let a = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6728 let v = a.value(i);
6729 self.sum += v as f64;
6730 self.int_sum = self.int_sum.wrapping_add(v);
6731 self.has_value = true;
6732 }
6733 DataType::Float64 => {
6734 let a = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6735 self.sum += a.value(i);
6736 self.all_ints = false;
6737 self.has_value = true;
6738 }
6739 _ => {}
6740 }
6741 }
6742 Ok(())
6743 }
6744 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6745 if !self.has_value {
6746 return Ok(ScalarValue::LargeBinary(None));
6747 }
6748 let val = if self.all_ints {
6749 Value::Int(self.int_sum)
6750 } else {
6751 Value::Float(self.sum)
6752 };
6753 let bytes = uni_common::cypher_value_codec::encode(&val);
6754 Ok(ScalarValue::LargeBinary(Some(bytes)))
6755 }
6756 fn size(&self) -> usize {
6757 std::mem::size_of_val(self)
6758 }
6759 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6760 Ok(vec![
6761 ScalarValue::Float64(Some(self.sum)),
6762 ScalarValue::Int64(Some(self.int_sum)),
6763 ScalarValue::Boolean(Some(self.all_ints)),
6764 ScalarValue::Boolean(Some(self.has_value)),
6765 ])
6766 }
6767 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6768 let sum_arr = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
6769 let int_sum_arr = states[1].as_any().downcast_ref::<Int64Array>().unwrap();
6770 let all_ints_arr = states[2].as_any().downcast_ref::<BooleanArray>().unwrap();
6771 let has_value_arr = states[3].as_any().downcast_ref::<BooleanArray>().unwrap();
6772 for i in 0..sum_arr.len() {
6773 if !has_value_arr.is_null(i) && has_value_arr.value(i) {
6774 self.sum += sum_arr.value(i);
6775 self.int_sum = self.int_sum.wrapping_add(int_sum_arr.value(i));
6776 if !all_ints_arr.value(i) {
6777 self.all_ints = false;
6778 }
6779 self.has_value = true;
6780 }
6781 }
6782 Ok(())
6783 }
6784}
6785
6786pub fn create_cypher_sum_udaf() -> AggregateUDF {
6787 AggregateUDF::from(CypherSumUdaf::new())
6788}
6789
6790#[derive(Debug, Clone)]
6797struct CypherCollectUdaf {
6798 signature: Signature,
6799}
6800
6801impl CypherCollectUdaf {
6802 fn new() -> Self {
6803 Self {
6804 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6805 }
6806 }
6807}
6808
6809impl PartialEq for CypherCollectUdaf {
6810 fn eq(&self, other: &Self) -> bool {
6811 self.signature == other.signature
6812 }
6813}
6814
6815impl Eq for CypherCollectUdaf {}
6816
6817impl Hash for CypherCollectUdaf {
6818 fn hash<H: Hasher>(&self, state: &mut H) {
6819 self.name().hash(state);
6820 }
6821}
6822
6823impl AggregateUDFImpl for CypherCollectUdaf {
6824 fn as_any(&self) -> &dyn Any {
6825 self
6826 }
6827 fn name(&self) -> &str {
6828 "_cypher_collect"
6829 }
6830 fn signature(&self) -> &Signature {
6831 &self.signature
6832 }
6833 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6834 Ok(DataType::LargeBinary)
6835 }
6836 fn accumulator(
6837 &self,
6838 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6839 ) -> DFResult<Box<dyn DfAccumulator>> {
6840 Ok(Box::new(CypherCollectAccumulator {
6841 values: Vec::new(),
6842 distinct: acc_args.is_distinct,
6843 }))
6844 }
6845 fn state_fields(
6846 &self,
6847 args: datafusion::logical_expr::function::StateFieldsArgs,
6848 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6849 Ok(vec![Arc::new(arrow::datatypes::Field::new(
6850 args.name,
6851 DataType::LargeBinary,
6852 true,
6853 ))])
6854 }
6855}
6856
6857#[derive(Debug)]
6858struct CypherCollectAccumulator {
6859 values: Vec<Value>,
6860 distinct: bool,
6861}
6862
6863impl DfAccumulator for CypherCollectAccumulator {
6864 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6865 let arr = &values[0];
6866 for i in 0..arr.len() {
6867 if arr.is_null(i) {
6868 continue;
6869 }
6870 if let Some(struct_arr) = arr.as_any().downcast_ref::<arrow::array::StructArray>()
6874 && struct_arr.num_columns() > 0
6875 && struct_arr.column(0).is_null(i)
6876 {
6877 continue;
6878 }
6879 let sv = ScalarValue::try_from_array(arr, i)
6880 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
6881 let val = scalar_to_value(&sv)?;
6882 if val.is_null() {
6883 continue;
6884 }
6885 if self.distinct {
6886 let repr = val.to_string();
6888 if self.values.iter().any(|v| v.to_string() == repr) {
6889 continue;
6890 }
6891 }
6892 self.values.push(val);
6893 }
6894 Ok(())
6895 }
6896 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6897 let val = Value::List(self.values.clone());
6899 let bytes = uni_common::cypher_value_codec::encode(&val);
6900 Ok(ScalarValue::LargeBinary(Some(bytes)))
6901 }
6902 fn size(&self) -> usize {
6903 std::mem::size_of_val(self) + self.values.len() * 64
6904 }
6905 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6906 Ok(vec![self.evaluate()?])
6907 }
6908 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6909 let arr = &states[0];
6911 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
6912 for i in 0..lb.len() {
6913 if lb.is_null(i) {
6914 continue;
6915 }
6916 let val = scalar_binary_to_value(lb.value(i));
6917 if let Value::List(items) = val {
6918 for item in items {
6919 if !item.is_null() {
6920 if self.distinct {
6921 let repr = item.to_string();
6922 if self.values.iter().any(|v| v.to_string() == repr) {
6923 continue;
6924 }
6925 }
6926 self.values.push(item);
6927 }
6928 }
6929 }
6930 }
6931 }
6932 Ok(())
6933 }
6934}
6935
6936pub fn create_cypher_collect_udaf() -> AggregateUDF {
6937 AggregateUDF::from(CypherCollectUdaf::new())
6938}
6939
6940pub fn create_cypher_collect_expr(
6942 arg: datafusion::logical_expr::Expr,
6943 distinct: bool,
6944) -> datafusion::logical_expr::Expr {
6945 let udaf = Arc::new(create_cypher_collect_udaf());
6948 if distinct {
6949 datafusion::logical_expr::Expr::AggregateFunction(
6951 datafusion::logical_expr::expr::AggregateFunction::new_udf(
6952 udaf,
6953 vec![arg],
6954 true, None,
6956 vec![],
6957 None,
6958 ),
6959 )
6960 } else {
6961 udaf.call(vec![arg])
6962 }
6963}
6964
6965#[derive(Debug, Clone)]
6971struct CypherPercentileDiscUdaf {
6972 signature: Signature,
6973}
6974
6975impl CypherPercentileDiscUdaf {
6976 fn new() -> Self {
6977 Self {
6978 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
6979 }
6980 }
6981}
6982
6983impl PartialEq for CypherPercentileDiscUdaf {
6984 fn eq(&self, other: &Self) -> bool {
6985 self.signature == other.signature
6986 }
6987}
6988
6989impl Eq for CypherPercentileDiscUdaf {}
6990
6991impl Hash for CypherPercentileDiscUdaf {
6992 fn hash<H: Hasher>(&self, state: &mut H) {
6993 self.name().hash(state);
6994 }
6995}
6996
6997impl AggregateUDFImpl for CypherPercentileDiscUdaf {
6998 fn as_any(&self) -> &dyn Any {
6999 self
7000 }
7001 fn name(&self) -> &str {
7002 "percentiledisc"
7003 }
7004 fn signature(&self) -> &Signature {
7005 &self.signature
7006 }
7007 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7008 Ok(DataType::Float64)
7009 }
7010 fn accumulator(
7011 &self,
7012 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7013 ) -> DFResult<Box<dyn DfAccumulator>> {
7014 Ok(Box::new(CypherPercentileDiscAccumulator {
7015 values: Vec::new(),
7016 percentile: None,
7017 }))
7018 }
7019 fn state_fields(
7020 &self,
7021 args: datafusion::logical_expr::function::StateFieldsArgs,
7022 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7023 Ok(vec![
7024 Arc::new(arrow::datatypes::Field::new(
7025 format!("{}_values", args.name),
7026 DataType::List(Arc::new(arrow::datatypes::Field::new(
7027 "item",
7028 DataType::Float64,
7029 true,
7030 ))),
7031 true,
7032 )),
7033 Arc::new(arrow::datatypes::Field::new(
7034 format!("{}_percentile", args.name),
7035 DataType::Float64,
7036 true,
7037 )),
7038 ])
7039 }
7040}
7041
7042#[derive(Debug)]
7043struct CypherPercentileDiscAccumulator {
7044 values: Vec<f64>,
7045 percentile: Option<f64>,
7046}
7047
7048impl CypherPercentileDiscAccumulator {
7049 fn extract_f64(arr: &ArrayRef, i: usize) -> Option<f64> {
7050 if arr.is_null(i) {
7051 return None;
7052 }
7053 match arr.data_type() {
7054 DataType::LargeBinary => {
7055 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
7056 cv_bytes_as_f64(lb.value(i))
7057 }
7058 DataType::Int64 => {
7059 let a = arr.as_any().downcast_ref::<Int64Array>()?;
7060 Some(a.value(i) as f64)
7061 }
7062 DataType::Float64 => {
7063 let a = arr.as_any().downcast_ref::<Float64Array>()?;
7064 Some(a.value(i))
7065 }
7066 DataType::Int32 => {
7067 let a = arr.as_any().downcast_ref::<Int32Array>()?;
7068 Some(a.value(i) as f64)
7069 }
7070 DataType::Float32 => {
7071 let a = arr.as_any().downcast_ref::<Float32Array>()?;
7072 Some(a.value(i) as f64)
7073 }
7074 _ => None,
7075 }
7076 }
7077
7078 fn extract_percentile(arr: &ArrayRef, i: usize) -> Option<f64> {
7079 if arr.is_null(i) {
7080 return None;
7081 }
7082 match arr.data_type() {
7083 DataType::Float64 => {
7084 let a = arr.as_any().downcast_ref::<Float64Array>()?;
7085 Some(a.value(i))
7086 }
7087 DataType::Int64 => {
7088 let a = arr.as_any().downcast_ref::<Int64Array>()?;
7089 Some(a.value(i) as f64)
7090 }
7091 DataType::LargeBinary => {
7092 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
7093 cv_bytes_as_f64(lb.value(i))
7094 }
7095 _ => None,
7096 }
7097 }
7098}
7099
7100impl DfAccumulator for CypherPercentileDiscAccumulator {
7101 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7102 let expr_arr = &values[0];
7103 let pct_arr = &values[1];
7104 for i in 0..expr_arr.len() {
7105 if self.percentile.is_none()
7107 && let Some(p) = Self::extract_percentile(pct_arr, i)
7108 {
7109 if !(0.0..=1.0).contains(&p) {
7110 return Err(datafusion::error::DataFusionError::Execution(
7111 "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
7112 ));
7113 }
7114 self.percentile = Some(p);
7115 }
7116 if let Some(f) = Self::extract_f64(expr_arr, i) {
7117 self.values.push(f);
7118 }
7119 }
7120 Ok(())
7121 }
7122 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7123 let pct = match self.percentile {
7124 Some(p) if !(0.0..=1.0).contains(&p) => {
7125 return Err(datafusion::error::DataFusionError::Execution(
7126 "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
7127 ));
7128 }
7129 Some(p) => p,
7130 None => 0.0,
7131 };
7132 if self.values.is_empty() {
7133 return Ok(ScalarValue::Float64(None));
7134 }
7135 self.values
7136 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
7137 let n = self.values.len();
7138 let idx = (pct * (n as f64 - 1.0)).round() as usize;
7139 let idx = idx.min(n - 1);
7140 let result = self.values[idx];
7141 Ok(ScalarValue::Float64(Some(result)))
7142 }
7143 fn size(&self) -> usize {
7144 std::mem::size_of_val(self) + self.values.capacity() * 8
7145 }
7146 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7147 let list_values: Vec<ScalarValue> = self
7149 .values
7150 .iter()
7151 .map(|f| ScalarValue::Float64(Some(*f)))
7152 .collect();
7153 let list_scalar = ScalarValue::List(ScalarValue::new_list(
7154 &list_values,
7155 &DataType::Float64,
7156 true,
7157 ));
7158 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
7159 }
7160 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7161 let list_arr = &states[0];
7163 let pct_arr = &states[1];
7164 if self.percentile.is_none()
7166 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
7167 {
7168 for i in 0..f64_arr.len() {
7169 if !f64_arr.is_null(i) {
7170 self.percentile = Some(f64_arr.value(i));
7171 break;
7172 }
7173 }
7174 }
7175 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
7177 for i in 0..list_array.len() {
7178 if list_array.is_null(i) {
7179 continue;
7180 }
7181 let inner = list_array.value(i);
7182 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
7183 for j in 0..f64_arr.len() {
7184 if !f64_arr.is_null(j) {
7185 self.values.push(f64_arr.value(j));
7186 }
7187 }
7188 }
7189 }
7190 }
7191 Ok(())
7192 }
7193}
7194
7195#[derive(Debug, Clone)]
7197struct CypherPercentileContUdaf {
7198 signature: Signature,
7199}
7200
7201impl CypherPercentileContUdaf {
7202 fn new() -> Self {
7203 Self {
7204 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7205 }
7206 }
7207}
7208
7209impl PartialEq for CypherPercentileContUdaf {
7210 fn eq(&self, other: &Self) -> bool {
7211 self.signature == other.signature
7212 }
7213}
7214
7215impl Eq for CypherPercentileContUdaf {}
7216
7217impl Hash for CypherPercentileContUdaf {
7218 fn hash<H: Hasher>(&self, state: &mut H) {
7219 self.name().hash(state);
7220 }
7221}
7222
7223impl AggregateUDFImpl for CypherPercentileContUdaf {
7224 fn as_any(&self) -> &dyn Any {
7225 self
7226 }
7227 fn name(&self) -> &str {
7228 "percentilecont"
7229 }
7230 fn signature(&self) -> &Signature {
7231 &self.signature
7232 }
7233 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7234 Ok(DataType::Float64)
7235 }
7236 fn accumulator(
7237 &self,
7238 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7239 ) -> DFResult<Box<dyn DfAccumulator>> {
7240 Ok(Box::new(CypherPercentileContAccumulator {
7241 values: Vec::new(),
7242 percentile: None,
7243 }))
7244 }
7245 fn state_fields(
7246 &self,
7247 args: datafusion::logical_expr::function::StateFieldsArgs,
7248 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7249 Ok(vec![
7250 Arc::new(arrow::datatypes::Field::new(
7251 format!("{}_values", args.name),
7252 DataType::List(Arc::new(arrow::datatypes::Field::new(
7253 "item",
7254 DataType::Float64,
7255 true,
7256 ))),
7257 true,
7258 )),
7259 Arc::new(arrow::datatypes::Field::new(
7260 format!("{}_percentile", args.name),
7261 DataType::Float64,
7262 true,
7263 )),
7264 ])
7265 }
7266}
7267
7268#[derive(Debug)]
7269struct CypherPercentileContAccumulator {
7270 values: Vec<f64>,
7271 percentile: Option<f64>,
7272}
7273
7274impl DfAccumulator for CypherPercentileContAccumulator {
7275 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7276 let expr_arr = &values[0];
7277 let pct_arr = &values[1];
7278 for i in 0..expr_arr.len() {
7279 if self.percentile.is_none()
7280 && let Some(p) = CypherPercentileDiscAccumulator::extract_percentile(pct_arr, i)
7281 {
7282 if !(0.0..=1.0).contains(&p) {
7283 return Err(datafusion::error::DataFusionError::Execution(
7284 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
7285 ));
7286 }
7287 self.percentile = Some(p);
7288 }
7289 if let Some(f) = CypherPercentileDiscAccumulator::extract_f64(expr_arr, i) {
7290 self.values.push(f);
7291 }
7292 }
7293 Ok(())
7294 }
7295 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7296 let pct = match self.percentile {
7297 Some(p) if !(0.0..=1.0).contains(&p) => {
7298 return Err(datafusion::error::DataFusionError::Execution(
7299 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
7300 ));
7301 }
7302 Some(p) => p,
7303 None => 0.0,
7304 };
7305 if self.values.is_empty() {
7306 return Ok(ScalarValue::Float64(None));
7307 }
7308 self.values
7309 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
7310 let n = self.values.len();
7311 if n == 1 {
7312 return Ok(ScalarValue::Float64(Some(self.values[0])));
7313 }
7314 let pos = pct * (n as f64 - 1.0);
7315 let lower = pos.floor() as usize;
7316 let upper = pos.ceil() as usize;
7317 let lower = lower.min(n - 1);
7318 let upper = upper.min(n - 1);
7319 if lower == upper {
7320 Ok(ScalarValue::Float64(Some(self.values[lower])))
7321 } else {
7322 let frac = pos - lower as f64;
7323 let result = self.values[lower] + frac * (self.values[upper] - self.values[lower]);
7324 Ok(ScalarValue::Float64(Some(result)))
7325 }
7326 }
7327 fn size(&self) -> usize {
7328 std::mem::size_of_val(self) + self.values.capacity() * 8
7329 }
7330 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7331 let list_values: Vec<ScalarValue> = self
7332 .values
7333 .iter()
7334 .map(|f| ScalarValue::Float64(Some(*f)))
7335 .collect();
7336 let list_scalar = ScalarValue::List(ScalarValue::new_list(
7337 &list_values,
7338 &DataType::Float64,
7339 true,
7340 ));
7341 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
7342 }
7343 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7344 let list_arr = &states[0];
7345 let pct_arr = &states[1];
7346 if self.percentile.is_none()
7347 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
7348 {
7349 for i in 0..f64_arr.len() {
7350 if !f64_arr.is_null(i) {
7351 self.percentile = Some(f64_arr.value(i));
7352 break;
7353 }
7354 }
7355 }
7356 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
7357 for i in 0..list_array.len() {
7358 if list_array.is_null(i) {
7359 continue;
7360 }
7361 let inner = list_array.value(i);
7362 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
7363 for j in 0..f64_arr.len() {
7364 if !f64_arr.is_null(j) {
7365 self.values.push(f64_arr.value(j));
7366 }
7367 }
7368 }
7369 }
7370 }
7371 Ok(())
7372 }
7373}
7374
7375pub fn create_cypher_percentile_disc_udaf() -> AggregateUDF {
7376 AggregateUDF::from(CypherPercentileDiscUdaf::new())
7377}
7378
7379pub fn create_cypher_percentile_cont_udaf() -> AggregateUDF {
7380 AggregateUDF::from(CypherPercentileContUdaf::new())
7381}
7382
7383fn invoke_similarity_udf(
7393 func_name: &str,
7394 min_args: usize,
7395 args: ScalarFunctionArgs,
7396) -> DFResult<ColumnarValue> {
7397 let output_type = DataType::Float64;
7398 invoke_cypher_udf(args, &output_type, |val_args| {
7399 if val_args.len() < min_args {
7400 return Err(datafusion::error::DataFusionError::Execution(format!(
7401 "{} requires at least {} arguments",
7402 func_name, min_args
7403 )));
7404 }
7405 crate::similar_to::eval_similar_to_pure(&val_args[0], &val_args[1])
7406 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
7407 })
7408}
7409
7410pub fn create_similar_to_udf() -> ScalarUDF {
7412 ScalarUDF::new_from_impl(SimilarToUdf::new())
7413}
7414
7415#[derive(Debug)]
7416struct SimilarToUdf {
7417 signature: Signature,
7418}
7419
7420impl SimilarToUdf {
7421 fn new() -> Self {
7422 Self {
7423 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
7424 }
7425 }
7426}
7427
7428impl_udf_eq_hash!(SimilarToUdf);
7429
7430impl ScalarUDFImpl for SimilarToUdf {
7431 fn as_any(&self) -> &dyn Any {
7432 self
7433 }
7434
7435 fn name(&self) -> &str {
7436 "similar_to"
7437 }
7438
7439 fn signature(&self) -> &Signature {
7440 &self.signature
7441 }
7442
7443 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7444 Ok(DataType::Float64)
7445 }
7446
7447 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7448 invoke_similarity_udf("similar_to", 2, args)
7449 }
7450}
7451
7452pub fn create_vector_similarity_udf() -> ScalarUDF {
7454 ScalarUDF::new_from_impl(VectorSimilarityUdf::new())
7455}
7456
7457#[derive(Debug)]
7458struct VectorSimilarityUdf {
7459 signature: Signature,
7460}
7461
7462impl VectorSimilarityUdf {
7463 fn new() -> Self {
7464 Self {
7465 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7466 }
7467 }
7468}
7469
7470impl_udf_eq_hash!(VectorSimilarityUdf);
7471
7472impl ScalarUDFImpl for VectorSimilarityUdf {
7473 fn as_any(&self) -> &dyn Any {
7474 self
7475 }
7476
7477 fn name(&self) -> &str {
7478 "vector_similarity"
7479 }
7480
7481 fn signature(&self) -> &Signature {
7482 &self.signature
7483 }
7484
7485 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7486 Ok(DataType::Float64)
7487 }
7488
7489 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7490 invoke_similarity_udf("vector_similarity", 2, args)
7491 }
7492}
7493
7494#[cfg(test)]
7495mod tests {
7496 use super::*;
7497 use datafusion::execution::FunctionRegistry;
7498
7499 #[test]
7500 fn test_register_udfs() {
7501 let ctx = SessionContext::new();
7502 register_cypher_udfs(&ctx).unwrap();
7503
7504 assert!(ctx.udf("id").is_ok());
7507 assert!(ctx.udf("type").is_ok());
7508 assert!(ctx.udf("keys").is_ok());
7509 assert!(ctx.udf("range").is_ok());
7510 assert!(
7511 ctx.udf("_make_cypher_list").is_ok(),
7512 "_make_cypher_list UDF should be registered"
7513 );
7514 assert!(
7515 ctx.udf("_cv_to_bool").is_ok(),
7516 "_cv_to_bool UDF should be registered"
7517 );
7518 }
7519
7520 #[test]
7521 fn test_id_udf_signature() {
7522 let udf = create_id_udf();
7523 assert_eq!(udf.name(), "id");
7524 }
7525
7526 #[test]
7527 fn test_has_null_udf() {
7528 use datafusion::arrow::datatypes::{DataType, Field};
7529 use datafusion::config::ConfigOptions;
7530 use datafusion::scalar::ScalarValue;
7531 use std::sync::Arc;
7532
7533 let udf = create_has_null_udf();
7534
7535 let values = vec![
7537 ScalarValue::Int64(Some(1)),
7538 ScalarValue::Int64(Some(2)),
7539 ScalarValue::Int64(None),
7540 ];
7541
7542 let list_scalar = ScalarValue::List(ScalarValue::new_list(&values, &DataType::Int64, true));
7544
7545 let list_field = Arc::new(Field::new(
7546 "item",
7547 DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
7548 true,
7549 ));
7550
7551 let args = ScalarFunctionArgs {
7552 args: vec![ColumnarValue::Scalar(list_scalar)],
7553 arg_fields: vec![list_field],
7554 number_rows: 1,
7555 return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
7556 config_options: Arc::new(ConfigOptions::default()),
7557 };
7558
7559 let result = udf.invoke_with_args(args).unwrap();
7560
7561 if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
7562 assert!(b, "has_null should return true for list with null");
7563 } else {
7564 panic!("Unexpected result: {:?}", result);
7565 }
7566 }
7567
7568 fn json_to_cv_bytes(val: &serde_json::Value) -> Vec<u8> {
7574 let uni_val: uni_common::Value = val.clone().into();
7575 uni_common::cypher_value_codec::encode(&uni_val)
7576 }
7577
7578 fn make_multi_scalar_args(scalars: Vec<ScalarValue>) -> ScalarFunctionArgs {
7587 make_multi_scalar_args_with_return(scalars, DataType::LargeBinary)
7588 }
7589
7590 fn make_multi_scalar_args_with_return(
7591 scalars: Vec<ScalarValue>,
7592 return_type: DataType,
7593 ) -> ScalarFunctionArgs {
7594 use datafusion::arrow::datatypes::Field;
7595 use datafusion::config::ConfigOptions;
7596
7597 let arg_fields: Vec<_> = scalars
7598 .iter()
7599 .enumerate()
7600 .map(|(i, s)| Arc::new(Field::new(format!("arg{i}"), s.data_type(), true)))
7601 .collect();
7602 let args: Vec<_> = scalars.into_iter().map(ColumnarValue::Scalar).collect();
7603 ScalarFunctionArgs {
7604 args,
7605 arg_fields,
7606 number_rows: 1,
7607 return_field: Arc::new(Field::new("result", return_type, true)),
7608 config_options: Arc::new(ConfigOptions::default()),
7609 }
7610 }
7611
7612 fn decode_cv_scalar(cv: &ColumnarValue) -> serde_json::Value {
7614 match cv {
7615 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
7616 let val = uni_common::cypher_value_codec::decode(bytes)
7617 .expect("failed to decode CypherValue output");
7618 val.into()
7619 }
7620 other => panic!("expected LargeBinary scalar, got {other:?}"),
7621 }
7622 }
7623
7624 #[test]
7625 fn test_make_cypher_list_scalars() {
7626 let udf = create_make_cypher_list_udf();
7627 let args = make_multi_scalar_args(vec![
7628 ScalarValue::Int64(Some(1)),
7629 ScalarValue::Float64(Some(3.21)),
7630 ScalarValue::Utf8(Some("hello".to_string())),
7631 ScalarValue::Boolean(Some(true)),
7632 ScalarValue::Null,
7633 ]);
7634 let result = udf.invoke_with_args(args).unwrap();
7635 let json = decode_cv_scalar(&result);
7636 let arr = json.as_array().expect("should be array");
7637 assert_eq!(arr.len(), 5);
7638 assert_eq!(arr[0], serde_json::json!(1));
7639 assert_eq!(arr[1], serde_json::json!(3.21));
7640 assert_eq!(arr[2], serde_json::json!("hello"));
7641 assert_eq!(arr[3], serde_json::json!(true));
7642 assert!(arr[4].is_null());
7643 }
7644
7645 #[test]
7646 fn test_make_cypher_list_empty() {
7647 let udf = create_make_cypher_list_udf();
7648 let args = make_multi_scalar_args(vec![]);
7649 let result = udf.invoke_with_args(args).unwrap();
7650 let json = decode_cv_scalar(&result);
7651 let arr = json.as_array().expect("should be array");
7652 assert!(arr.is_empty());
7653 }
7654
7655 #[test]
7656 fn test_make_cypher_list_single() {
7657 let udf = create_make_cypher_list_udf();
7658 let args = make_multi_scalar_args(vec![ScalarValue::Int64(Some(42))]);
7659 let result = udf.invoke_with_args(args).unwrap();
7660 let json = decode_cv_scalar(&result);
7661 let arr = json.as_array().expect("should be array");
7662 assert_eq!(arr.len(), 1);
7663 assert_eq!(arr[0], serde_json::json!(42));
7664 }
7665
7666 #[test]
7667 fn test_make_cypher_list_nested_cypher_value() {
7668 let udf = create_make_cypher_list_udf();
7669 let nested_bytes = json_to_cv_bytes(&serde_json::json!([1, 2]));
7671 let args = make_multi_scalar_args(vec![
7672 ScalarValue::LargeBinary(Some(nested_bytes)),
7673 ScalarValue::Int64(Some(3)),
7674 ]);
7675 let result = udf.invoke_with_args(args).unwrap();
7676 let json = decode_cv_scalar(&result);
7677 let arr = json.as_array().expect("should be array");
7678 assert_eq!(arr.len(), 2);
7679 assert_eq!(arr[0], serde_json::json!([1, 2]));
7680 assert_eq!(arr[1], serde_json::json!(3));
7681 }
7682
7683 fn make_cypher_in_args(
7689 element: &serde_json::Value,
7690 list: &serde_json::Value,
7691 ) -> ScalarFunctionArgs {
7692 make_multi_scalar_args_with_return(
7693 vec![
7694 ScalarValue::LargeBinary(Some(json_to_cv_bytes(element))),
7695 ScalarValue::LargeBinary(Some(json_to_cv_bytes(list))),
7696 ],
7697 DataType::Boolean,
7698 )
7699 }
7700
7701 #[test]
7702 fn test_cypher_in_found() {
7703 let udf = create_cypher_in_udf();
7704 let args = make_cypher_in_args(&serde_json::json!(3), &serde_json::json!([1, 2, 3]));
7705 let result = udf.invoke_with_args(args).unwrap();
7706 match result {
7707 ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
7708 other => panic!("expected Boolean(true), got {other:?}"),
7709 }
7710 }
7711
/// `4 IN [1, 2, 3]` evaluates to `false` (no nulls in the list, so the answer is definite).
#[test]
fn test_cypher_in_not_found() {
    let in_udf = create_cypher_in_udf();
    let outcome = in_udf
        .invoke_with_args(make_cypher_in_args(
            &serde_json::json!(4),
            &serde_json::json!([1, 2, 3]),
        ))
        .unwrap();
    match outcome {
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(found))) => assert!(!found),
        other => panic!("expected Boolean(false), got {other:?}"),
    }
}
7722
/// `1 IN null` evaluates to `null`, i.e. `Boolean(None)`.
#[test]
fn test_cypher_in_null_list() {
    let udf = create_cypher_in_udf();
    // Same LargeBinary/LargeBinary -> Boolean call shape as every other IN test.
    let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!(null));
    let result = udf.invoke_with_args(args).unwrap();
    match result {
        ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
        other => panic!("expected Boolean(None) for null list, got {other:?}"),
    }
}
7739
/// `null IN [1, 2]` evaluates to `null` (`Boolean(None)`): a null element searched in a
/// non-empty list has unknown membership.
#[test]
fn test_cypher_in_null_element_nonempty() {
    let udf = create_cypher_in_udf();
    let args = make_cypher_in_args(&serde_json::json!(null), &serde_json::json!([1, 2]));
    let result = udf.invoke_with_args(args).unwrap();
    match result {
        ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
        other => panic!("expected Boolean(None) for null IN non-empty list, got {other:?}"),
    }
}
7756
/// `null IN []` evaluates to a definite `false`: an empty list contains nothing, so a
/// null element cannot make the result unknown.
#[test]
fn test_cypher_in_null_element_empty() {
    let udf = create_cypher_in_udf();
    let args = make_cypher_in_args(&serde_json::json!(null), &serde_json::json!([]));
    let result = udf.invoke_with_args(args).unwrap();
    match result {
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
        other => panic!("expected Boolean(false) for null IN [], got {other:?}"),
    }
}
7773
/// `4 IN [1, null, 3]` evaluates to `null` (`Boolean(None)`): the value is not found,
/// but the null entry might have matched, so the answer is unknown rather than `false`.
#[test]
fn test_cypher_in_not_found_with_null() {
    let udf = create_cypher_in_udf();
    let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, null, 3]));
    let result = udf.invoke_with_args(args).unwrap();
    match result {
        ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {}
        other => panic!("expected Boolean(None) for 4 IN [1,null,3], got {other:?}"),
    }
}
7784
/// Membership compares numbers by value across Int/Float: `1 IN [1.0, 2.0]` is `true`.
#[test]
fn test_cypher_in_cross_type_int_float() {
    let in_udf = create_cypher_in_udf();
    let outcome = in_udf
        .invoke_with_args(make_cypher_in_args(
            &serde_json::json!(1),
            &serde_json::json!([1.0, 2.0]),
        ))
        .unwrap();
    match outcome {
        ColumnarValue::Scalar(ScalarValue::Boolean(Some(found))) => assert!(found),
        other => panic!("expected Boolean(true) for 1 IN [1.0, 2.0], got {other:?}"),
    }
}
7795
/// Concatenating `[1, 2]` and `[3, 4]` yields `[1, 2, 3, 4]`.
#[test]
fn test_list_concat_basic() {
    let concat = create_cypher_list_concat_udf();
    let lhs = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let rhs = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([3, 4]))));
    let output = concat
        .invoke_with_args(make_multi_scalar_args(vec![lhs, rhs]))
        .unwrap();
    assert_eq!(decode_cv_scalar(&output), serde_json::json!([1, 2, 3, 4]));
}
7811
/// An empty left operand is the identity for concatenation: `[]` and `[1]` give `[1]`.
#[test]
fn test_list_concat_empty() {
    let concat = create_cypher_list_concat_udf();
    let lhs = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([]))));
    let rhs = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1]))));
    let output = concat
        .invoke_with_args(make_multi_scalar_args(vec![lhs, rhs]))
        .unwrap();
    assert_eq!(decode_cv_scalar(&output), serde_json::json!([1]));
}
7823
/// Concatenating a null left operand with `[1]` yields null.
///
/// The UDF may report null either as an SQL-null `LargeBinary(None)` or as an encoded
/// CypherValue null; both forms are accepted.
#[test]
fn test_list_concat_null_left() {
    let udf = create_cypher_list_concat_udf();
    let args = make_multi_scalar_args(vec![
        ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
        ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
    ]);
    let result = udf.invoke_with_args(args).unwrap();
    match result {
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
            let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
            let json: serde_json::Value = uni_val.into();
            assert!(json.is_null(), "expected null, got {json}");
        }
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
        other => panic!("expected null result, got {other:?}"),
    }
}
7842
/// Concatenating `[1]` with a null right operand yields null, reported either as
/// `LargeBinary(None)` or as an encoded CypherValue null.
#[test]
fn test_list_concat_null_right() {
    let concat = create_cypher_list_concat_udf();
    let lhs = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1]))));
    let rhs = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));
    let outcome = concat
        .invoke_with_args(make_multi_scalar_args(vec![lhs, rhs]))
        .unwrap();
    match outcome {
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(encoded))) => {
            let decoded: serde_json::Value = uni_common::cypher_value_codec::decode(&encoded)
                .expect("decode")
                .into();
            assert!(decoded.is_null(), "expected null, got {decoded}");
        }
        other => panic!("expected null result, got {other:?}"),
    }
}
7861
/// Appending the scalar `3` to the list `[1, 2]` yields `[1, 2, 3]`.
#[test]
fn test_list_append_scalar() {
    let append = create_cypher_list_append_udf();
    let list_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let scalar_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3))));
    let output = append
        .invoke_with_args(make_multi_scalar_args(vec![list_arg, scalar_arg]))
        .unwrap();
    assert_eq!(decode_cv_scalar(&output), serde_json::json!([1, 2, 3]));
}
7877
/// With the scalar first and the list second, the same append UDF prepends:
/// `3` and `[1, 2]` yield `[3, 1, 2]`.
#[test]
fn test_list_prepend_scalar() {
    let append = create_cypher_list_append_udf();
    let scalar_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3))));
    let list_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let output = append
        .invoke_with_args(make_multi_scalar_args(vec![scalar_arg, list_arg]))
        .unwrap();
    assert_eq!(decode_cv_scalar(&output), serde_json::json!([3, 1, 2]));
}
7889
/// Appending `3` to a null list yields null, reported either as `LargeBinary(None)` or
/// as an encoded CypherValue null.
#[test]
fn test_list_append_null_list() {
    let append = create_cypher_list_append_udf();
    let list_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));
    let scalar_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3))));
    let outcome = append
        .invoke_with_args(make_multi_scalar_args(vec![list_arg, scalar_arg]))
        .unwrap();
    match outcome {
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(encoded))) => {
            let decoded: serde_json::Value = uni_common::cypher_value_codec::decode(&encoded)
                .expect("decode")
                .into();
            assert!(decoded.is_null(), "expected null, got {decoded}");
        }
        other => panic!("expected null result, got {other:?}"),
    }
}
7908
/// Appending a null scalar to `[1, 2]` yields null (not `[1, 2, null]`), reported either
/// as `LargeBinary(None)` or as an encoded CypherValue null.
#[test]
fn test_list_append_null_scalar() {
    let append = create_cypher_list_append_udf();
    let list_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2]))));
    let scalar_arg = ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null))));
    let outcome = append
        .invoke_with_args(make_multi_scalar_args(vec![list_arg, scalar_arg]))
        .unwrap();
    match outcome {
        ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
        ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(encoded))) => {
            let decoded: serde_json::Value = uni_common::cypher_value_codec::decode(&encoded)
                .expect("decode")
                .into();
            assert!(decoded.is_null(), "expected null, got {decoded}");
        }
        other => panic!("expected null result, got {other:?}"),
    }
}
7927
/// Sort keys from `encode_cypher_sort_key` must order values of different kinds as:
/// Map < Node < Edge < List < Path < String < Bool < Temporal < Number < NaN < Null.
#[test]
fn test_sort_key_cross_type_ordering() {
    use uni_common::core::id::{Eid, Vid};
    use uni_common::{Edge, Node, Path, TemporalValue, Value};

    // The node and the path's single node share the same shape.
    let sample_node = || Node {
        vid: Vid::new(1),
        labels: vec!["L".to_string()],
        properties: Default::default(),
    };

    // Listed in the expected ascending order.
    let ordered: Vec<Value> = vec![
        Value::Map([("a".to_string(), Value::String("map".to_string()))].into()),
        Value::Node(sample_node()),
        Value::Edge(Edge {
            eid: Eid::new(1),
            edge_type: "T".to_string(),
            src: Vid::new(1),
            dst: Vid::new(2),
            properties: Default::default(),
        }),
        Value::List(vec![Value::Int(1)]),
        Value::Path(Path {
            nodes: vec![sample_node()],
            edges: vec![],
        }),
        Value::String("hello".to_string()),
        Value::Bool(false),
        Value::Temporal(TemporalValue::Date {
            days_since_epoch: 1000,
        }),
        Value::Int(42),
        Value::Float(f64::NAN),
        Value::Null,
    ];

    let keys: Vec<Vec<u8>> = ordered.iter().map(|v| encode_cypher_sort_key(v)).collect();

    // Every adjacent pair must be strictly increasing.
    for (vals, pair) in ordered.windows(2).zip(keys.windows(2)) {
        assert!(
            pair[0] < pair[1],
            "Expected sort_key({:?}) < sort_key({:?}), but {:?} >= {:?}",
            vals[0],
            vals[1],
            pair[0],
            pair[1]
        );
    }
}
7998
/// Numeric sort keys follow numeric value regardless of Int/Float representation
/// (equal values give identical keys), and NaN sorts above +inf.
#[test]
fn test_sort_key_numbers() {
    let key = |v: Value| encode_cypher_sort_key(&v);

    let neg_inf = key(Value::Float(f64::NEG_INFINITY));
    let neg_100 = key(Value::Float(-100.0));
    let neg_1 = key(Value::Int(-1));
    let zero_int = key(Value::Int(0));
    let zero_float = key(Value::Float(0.0));
    let one_int = key(Value::Int(1));
    let one_float = key(Value::Float(1.0));
    let hundred = key(Value::Int(100));
    let pos_inf = key(Value::Float(f64::INFINITY));
    let nan = key(Value::Float(f64::NAN));

    assert!(neg_inf < neg_100, "-inf < -100");
    assert!(neg_100 < neg_1, "-100 < -1");
    assert!(neg_1 < zero_int, "-1 < 0");
    assert_eq!(zero_int, zero_float, "0 int == 0.0 float");
    assert!(zero_int < one_int, "0 < 1");
    assert_eq!(one_int, one_float, "1 int == 1.0 float");
    assert!(one_int < hundred, "1 < 100");
    assert!(hundred < pos_inf, "100 < +inf");
    assert!(pos_inf < nan, "+inf < NaN");
}
8023
/// `false` sorts before `true`.
#[test]
fn test_sort_key_booleans() {
    let [f, t] = [false, true].map(|flag| encode_cypher_sort_key(&Value::Bool(flag)));
    assert!(f < t, "false < true");
}
8030
/// String keys order lexicographically, and a prefix sorts before its extensions.
#[test]
fn test_sort_key_strings() {
    let [empty, a, ab, b] = ["", "a", "ab", "b"]
        .map(|text| encode_cypher_sort_key(&Value::String(text.to_string())));

    assert!(empty < a, "'' < 'a'");
    assert!(a < ab, "'a' < 'ab'");
    assert!(ab < b, "'ab' < 'b'");
}
8042
/// List keys compare element-wise, with a shorter prefix list sorting first.
#[test]
fn test_sort_key_lists() {
    let list_key = |items: Vec<Value>| encode_cypher_sort_key(&Value::List(items));

    let empty = list_key(vec![]);
    let one = list_key(vec![Value::Int(1)]);
    let one_two = list_key(vec![Value::Int(1), Value::Int(2)]);
    let two = list_key(vec![Value::Int(2)]);

    assert!(empty < one, "[] < [1]");
    assert!(one < one_two, "[1] < [1,2]");
    assert!(one_two < two, "[1,2] < [2]");
}
8054
/// Temporal keys order first by variant and then by value within a variant.
#[test]
fn test_sort_key_temporal() {
    use uni_common::TemporalValue;

    let date_key = |days: i32| {
        encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: days,
        }))
    };

    assert!(date_key(100) < date_key(200), "earlier date < later date");

    // Even the largest possible Date must sort below the smallest LocalTime, so the
    // variant rank has to dominate the payload.
    let local_time_key = encode_cypher_sort_key(&Value::Temporal(TemporalValue::LocalTime {
        nanos_since_midnight: 0,
    }));
    assert!(
        date_key(i32::MAX) < local_time_key,
        "Date < LocalTime (by variant rank)"
    );
}
8076
/// Ordering recurses into nested lists: `[[1]]` sorts before `[[2]]`.
#[test]
fn test_sort_key_nested_lists() {
    let nested = |n| Value::List(vec![Value::List(vec![Value::Int(n)])]);

    let list_a = encode_cypher_sort_key(&nested(1));
    let list_b = encode_cypher_sort_key(&nested(2));

    assert!(list_a < list_b, "[[1]] < [[2]]");
}
8087
/// Null encodes to the single byte `0x0A` and sorts after a number.
#[test]
fn test_sort_key_null_handling() {
    let null_key = encode_cypher_sort_key(&Value::Null);
    assert_eq!(null_key, vec![0x0A], "Null produces [0x0A]");

    assert!(
        encode_cypher_sort_key(&Value::Int(42)) < null_key,
        "number < null"
    );
}
8097
/// Strings containing embedded NUL / low bytes must still produce keys that sort
/// lexicographically. Presumably this is guaranteed by byte-stuffing in
/// `encode_cypher_sort_key`.
///
/// Despite the name, nothing is decoded here; only key ordering is checked.
#[test]
fn test_byte_stuff_roundtrip() {
    let key = |text: &str| encode_cypher_sort_key(&Value::String(text.to_string()));

    let k1 = key("a\x00b");
    let k2 = key("a\x00c");
    let k3 = key("a\x01");

    assert!(k1 < k2, "a\\x00b < a\\x00c");
    assert!(k1 < k3, "a\\x00b < a\\x01");

    // A string must sort before any extension of itself, even when the extension starts
    // with a NUL byte: the end-of-string marker has to compare below an escaped NUL.
    let plain = key("a");
    let nul_suffix = key("a\x00");
    assert!(plain < nul_suffix, "a < a\\x00");
    assert!(nul_suffix < k1, "a\\x00 < a\\x00b");
}
8114
/// `encode_order_preserving_f64` must map increasing floats to byte arrays that compare
/// in the same order.
///
/// Only the signed-zero pair may encode equally, since an encoder may canonicalise
/// `-0.0` to `0.0`. Every other adjacent pair must be strictly ordered; with a blanket
/// `<=` check, a constant encoder would pass.
#[test]
fn test_sort_key_order_preserving_f64() {
    let vals = [f64::NEG_INFINITY, -1.0, -0.0, 0.0, 1.0, f64::INFINITY];
    let encoded: Vec<[u8; 8]> = vals
        .iter()
        .map(|f| encode_order_preserving_f64(*f))
        .collect();

    for (v, e) in vals.windows(2).zip(encoded.windows(2)) {
        // Under IEEE 754, `-0.0 == 0.0`, so float equality identifies the signed-zero pair.
        if v[0] == v[1] {
            assert!(
                e[0] <= e[1],
                "encode({}) should <= encode({}), got {:?} vs {:?}",
                v[0],
                v[1],
                e[0],
                e[1]
            );
        } else {
            assert!(
                e[0] < e[1],
                "encode({}) should < encode({}), got {:?} vs {:?}",
                v[0],
                v[1],
                e[0],
                e[1]
            );
        }
    }
}
8135
/// `"12:35:15+05:00"` parses to `TemporalValue::Time` with a +18000s offset and the
/// wall-clock nanoseconds since midnight.
#[test]
fn test_sort_key_string_as_temporal_time_with_offset() {
    const NANOS_PER_SEC: i64 = 1_000_000_000;

    let parsed = sort_key_string_as_temporal("12:35:15+05:00")
        .expect("should parse Time with positive offset");
    match parsed {
        uni_common::TemporalValue::Time {
            nanos_since_midnight: nanos,
            offset_seconds: offset,
        } => {
            assert_eq!(offset, 5 * 3600, "offset should be +05:00 = 18000s");
            assert_eq!(nanos, (12 * 3600 + 35 * 60 + 15) * NANOS_PER_SEC);
        }
        other => panic!("expected TemporalValue::Time, got {other:?}"),
    }
}
8156
/// `"10:35:00-08:00"` parses to `TemporalValue::Time` with a -28800s offset and the
/// wall-clock nanoseconds since midnight.
#[test]
fn test_sort_key_string_as_temporal_time_negative_offset() {
    const NANOS_PER_SEC: i64 = 1_000_000_000;

    let parsed = sort_key_string_as_temporal("10:35:00-08:00")
        .expect("should parse Time with negative offset");
    match parsed {
        uni_common::TemporalValue::Time {
            nanos_since_midnight: nanos,
            offset_seconds: offset,
        } => {
            assert_eq!(offset, -8 * 3600, "offset should be -08:00 = -28800s");
            assert_eq!(nanos, (10 * 3600 + 35 * 60) * NANOS_PER_SEC);
        }
        other => panic!("expected TemporalValue::Time, got {other:?}"),
    }
}
8177
/// `temporal_from_value` parses the ISO date string `"2024-01-15"` into
/// `TemporalValue::Date` with the exact day count.
#[test]
fn test_sort_key_string_as_temporal_date() {
    use super::super::expr_eval::temporal_from_value;
    let tv = temporal_from_value(&Value::String("2024-01-15".into()))
        .expect("should parse Date string");
    match tv {
        uni_common::TemporalValue::Date { days_since_epoch } => {
            // Assumes the epoch is the Unix epoch, 1970-01-01 (TODO confirm).
            // 2024-01-01 is day 19_723 (1_704_067_200 s / 86_400); +14 days gives
            // 2024-01-15. An exact value catches off-by-one and wrong-epoch bugs that
            // a bare `> 0` check would miss.
            assert_eq!(
                days_since_epoch, 19_737,
                "2024-01-15 should be day 19737 since 1970-01-01"
            );
        }
        other => panic!("expected TemporalValue::Date, got {other:?}"),
    }
}
8191}