1use arrow::array::ArrayRef;
27use arrow::datatypes::DataType;
28use arrow_array::{
29 Array, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array,
30 LargeBinaryArray, LargeStringArray, StringArray, UInt64Array,
31};
32use chrono::Offset;
33use datafusion::error::Result as DFResult;
34use datafusion::logical_expr::{
35 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
36 Volatility,
37};
38use datafusion::prelude::SessionContext;
39use datafusion::scalar::ScalarValue;
40use std::any::Any;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43use uni_common::Value;
44use uni_cypher::ast::BinaryOp;
45use uni_store::storage::arrow_convert::values_to_array;
46
47use super::expr_eval::cypher_eq;
48
/// Implements `PartialEq`, `Eq` and `Hash` for a UDF struct.
///
/// DataFusion relies on these impls when it compares expressions that call
/// the UDF, so equality must tell apart UDFs that merely share a signature.
/// Examples are `created_at` / `updated_at`, which have the same `Exact`
/// timestamp signature, and any two custom functions, which are all
/// `VariadicAny`. Two instances are therefore equal only when both their name
/// and their signature match.
///
/// The hash covers the name alone. Equal values always share a name, so the
/// required `a == b => hash(a) == hash(b)` invariant holds.
///
/// The target type must provide `name()` (normally via `ScalarUDFImpl`) and
/// a `signature` field.
macro_rules! impl_udf_eq_hash {
    ($type:ty) => {
        impl PartialEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self.name() == other.name() && self.signature == other.signature
            }
        }

        impl Eq for $type {}

        impl Hash for $type {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.name().hash(state);
            }
        }
    };
}
69
70pub fn register_cypher_udfs(ctx: &SessionContext) -> DFResult<()> {
80 ctx.register_udf(create_id_udf());
81 ctx.register_udf(create_created_at_udf());
82 ctx.register_udf(create_updated_at_udf());
83 ctx.register_udf(create_type_udf());
84 ctx.register_udf(create_keys_udf());
85 ctx.register_udf(create_properties_udf());
86 ctx.register_udf(create_labels_udf());
87 ctx.register_udf(create_nodes_udf());
88 ctx.register_udf(create_relationships_udf());
89 ctx.register_udf(create_range_udf());
90 ctx.register_udf(create_index_udf());
91 ctx.register_udf(create_startnode_udf());
92 ctx.register_udf(create_endnode_udf());
93
94 ctx.register_udf(create_to_integer_udf());
96 ctx.register_udf(create_to_float_udf());
97 ctx.register_udf(create_to_boolean_udf());
98
99 ctx.register_udf(create_bitwise_or_udf());
101 ctx.register_udf(create_bitwise_and_udf());
102 ctx.register_udf(create_bitwise_xor_udf());
103 ctx.register_udf(create_bitwise_not_udf());
104 ctx.register_udf(create_shift_left_udf());
105 ctx.register_udf(create_shift_right_udf());
106
107 for name in &[
109 "date",
111 "time",
112 "localtime",
113 "localdatetime",
114 "datetime",
115 "duration",
116 "btic",
117 "duration.between",
119 "duration.inmonths",
120 "duration.indays",
121 "duration.inseconds",
122 "datetime.fromepoch",
123 "datetime.fromepochmillis",
124 "date.truncate",
126 "time.truncate",
127 "datetime.truncate",
128 "localdatetime.truncate",
129 "localtime.truncate",
130 "datetime.transaction",
132 "datetime.statement",
133 "datetime.realtime",
134 "date.transaction",
135 "date.statement",
136 "date.realtime",
137 "time.transaction",
138 "time.statement",
139 "time.realtime",
140 "localtime.transaction",
141 "localtime.statement",
142 "localtime.realtime",
143 "localdatetime.transaction",
144 "localdatetime.statement",
145 "localdatetime.realtime",
146 ] {
147 ctx.register_udf(create_temporal_udf(name));
148 }
149
150 ctx.register_udf(create_duration_property_udf());
152 ctx.register_udf(create_temporal_property_udf());
153 ctx.register_udf(create_tostring_udf());
154 ctx.register_udf(create_cypher_sort_key_udf());
155 ctx.register_udf(create_has_null_udf());
156 ctx.register_udf(create_cypher_size_udf());
157
158 ctx.register_udf(create_cypher_starts_with_udf());
160 ctx.register_udf(create_cypher_ends_with_udf());
161 ctx.register_udf(create_cypher_contains_udf());
162
163 ctx.register_udf(create_cypher_list_compare_udf());
165
166 ctx.register_udf(create_cypher_xor_udf());
168
169 ctx.register_udf(create_cypher_equal_udf());
171 ctx.register_udf(create_cypher_not_equal_udf());
172 ctx.register_udf(create_cypher_gt_udf());
173 ctx.register_udf(create_cypher_gt_eq_udf());
174 ctx.register_udf(create_cypher_lt_udf());
175 ctx.register_udf(create_cypher_lt_eq_udf());
176
177 ctx.register_udf(create_cv_to_bool_udf());
179
180 ctx.register_udf(create_cypher_add_udf());
182 ctx.register_udf(create_cypher_sub_udf());
183 ctx.register_udf(create_cypher_mul_udf());
184 ctx.register_udf(create_cypher_div_udf());
185 ctx.register_udf(create_cypher_mod_udf());
186
187 ctx.register_udf(create_map_project_udf());
189
190 ctx.register_udf(create_make_cypher_list_udf());
192
193 ctx.register_udf(create_cypher_in_udf());
195
196 ctx.register_udf(create_cypher_list_concat_udf());
198 ctx.register_udf(create_cypher_list_append_udf());
199 ctx.register_udf(create_cypher_list_slice_udf());
200 ctx.register_udf(create_cypher_tail_udf());
201 ctx.register_udf(create_cypher_head_udf());
202 ctx.register_udf(create_cypher_last_udf());
203 ctx.register_udf(create_cypher_reverse_udf());
204 ctx.register_udf(create_cypher_substring_udf());
205 ctx.register_udf(create_cypher_split_udf());
206 ctx.register_udf(create_cypher_list_to_cv_udf());
207 ctx.register_udf(create_cypher_scalar_to_cv_udf());
208
209 for name in &["year", "month", "day", "hour", "minute", "second"] {
211 ctx.register_udf(create_temporal_udf(name));
212 }
213
214 ctx.register_udf(create_cypher_to_float64_udf());
216
217 ctx.register_udf(create_similar_to_udf());
219 ctx.register_udf(create_vector_similarity_udf());
220 ctx.register_udf(create_sparse_similar_to_udf());
221
222 ctx.register_udaf(create_cypher_min_udaf());
224 ctx.register_udaf(create_cypher_max_udaf());
225 ctx.register_udaf(create_cypher_sum_udaf());
226 ctx.register_udaf(create_cypher_collect_udaf());
227
228 ctx.register_udaf(create_cypher_percentile_disc_udaf());
230 ctx.register_udaf(create_cypher_percentile_cont_udaf());
231
232 register_btic_scalar_udfs(ctx)?;
234
235 ctx.register_udaf(create_btic_min_udaf());
237 ctx.register_udaf(create_btic_max_udaf());
238 ctx.register_udaf(create_btic_span_agg_udaf());
239 ctx.register_udaf(create_btic_count_at_udaf());
240
241 Ok(())
242}
243
244pub fn register_custom_udfs(
252 ctx: &SessionContext,
253 registry: &crate::custom_functions::CustomFunctionRegistry,
254) -> DFResult<()> {
255 for (name, func) in registry.iter() {
256 let lower = name.to_lowercase();
259 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
260 lower,
261 func.clone(),
262 )));
263 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
265 name.to_string(),
266 func.clone(),
267 )));
268 }
269 Ok(())
270}
271
272struct CustomScalarUdf {
277 name: String,
278 func: crate::custom_functions::CustomScalarFn,
279 signature: Signature,
280}
281
282impl CustomScalarUdf {
283 fn new(name: String, func: crate::custom_functions::CustomScalarFn) -> Self {
284 Self {
285 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
286 name,
287 func,
288 }
289 }
290}
291
292impl std::fmt::Debug for CustomScalarUdf {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 f.debug_struct("CustomScalarUdf")
295 .field("name", &self.name)
296 .finish()
297 }
298}
299
300impl_udf_eq_hash!(CustomScalarUdf);
301
302impl ScalarUDFImpl for CustomScalarUdf {
303 fn as_any(&self) -> &dyn Any {
304 self
305 }
306
307 fn name(&self) -> &str {
308 &self.name
309 }
310
311 fn signature(&self) -> &Signature {
312 &self.signature
313 }
314
315 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
316 Ok(DataType::LargeBinary)
317 }
318
319 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
320 let func = &self.func;
321 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
322 func(vals).map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
323 })
324 }
325}
326
327pub fn create_id_udf() -> ScalarUDF {
335 ScalarUDF::new_from_impl(IdUdf::new())
336}
337
338#[derive(Debug)]
339struct IdUdf {
340 signature: Signature,
341}
342
343impl IdUdf {
344 fn new() -> Self {
345 Self {
346 signature: Signature::new(
347 TypeSignature::Exact(vec![DataType::UInt64]),
348 Volatility::Immutable,
349 ),
350 }
351 }
352}
353
354impl_udf_eq_hash!(IdUdf);
355
356impl ScalarUDFImpl for IdUdf {
357 fn as_any(&self) -> &dyn Any {
358 self
359 }
360
361 fn name(&self) -> &str {
362 "id"
363 }
364
365 fn signature(&self) -> &Signature {
366 &self.signature
367 }
368
369 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
370 Ok(DataType::UInt64)
371 }
372
373 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
374 if args.args.is_empty() {
376 return Err(datafusion::error::DataFusionError::Execution(
377 "id(): requires 1 argument".to_string(),
378 ));
379 }
380 Ok(args.args[0].clone())
381 }
382}
383
384pub fn create_created_at_udf() -> ScalarUDF {
397 ScalarUDF::new_from_impl(SystemTimestampUdf::new("created_at"))
398}
399
400pub fn create_updated_at_udf() -> ScalarUDF {
407 ScalarUDF::new_from_impl(SystemTimestampUdf::new("updated_at"))
408}
409
410#[derive(Debug)]
411struct SystemTimestampUdf {
412 name: &'static str,
413 signature: Signature,
414}
415
416impl SystemTimestampUdf {
417 fn new(name: &'static str) -> Self {
418 Self {
419 name,
420 signature: Signature::new(
421 TypeSignature::Exact(vec![DataType::Timestamp(
422 arrow_schema::TimeUnit::Nanosecond,
423 Some("UTC".into()),
424 )]),
425 Volatility::Immutable,
426 ),
427 }
428 }
429}
430
431impl_udf_eq_hash!(SystemTimestampUdf);
432
433impl ScalarUDFImpl for SystemTimestampUdf {
434 fn as_any(&self) -> &dyn Any {
435 self
436 }
437
438 fn name(&self) -> &str {
439 self.name
440 }
441
442 fn signature(&self) -> &Signature {
443 &self.signature
444 }
445
446 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
447 Ok(DataType::Timestamp(
448 arrow_schema::TimeUnit::Nanosecond,
449 Some("UTC".into()),
450 ))
451 }
452
453 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
454 if args.args.is_empty() {
459 return Err(datafusion::error::DataFusionError::Execution(format!(
460 "{}(): requires 1 argument",
461 self.name
462 )));
463 }
464 Ok(args.args[0].clone())
465 }
466}
467
468pub fn create_type_udf() -> ScalarUDF {
476 ScalarUDF::new_from_impl(TypeUdf::new())
477}
478
479#[derive(Debug)]
480struct TypeUdf {
481 signature: Signature,
482}
483
484impl TypeUdf {
485 fn new() -> Self {
486 Self {
487 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
491 }
492 }
493}
494
495impl_udf_eq_hash!(TypeUdf);
496
497impl ScalarUDFImpl for TypeUdf {
498 fn as_any(&self) -> &dyn Any {
499 self
500 }
501
502 fn name(&self) -> &str {
503 "type"
504 }
505
506 fn signature(&self) -> &Signature {
507 &self.signature
508 }
509
510 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
511 Ok(DataType::Utf8)
512 }
513
514 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
515 if args.args.is_empty() {
516 return Err(datafusion::error::DataFusionError::Execution(
517 "type(): requires 1 argument".to_string(),
518 ));
519 }
520 let output_type = DataType::Utf8;
521 invoke_cypher_udf(args, &output_type, |val_args| {
522 if val_args.is_empty() {
523 return Err(datafusion::error::DataFusionError::Execution(
524 "type(): requires 1 argument".to_string(),
525 ));
526 }
527 let val = &val_args[0];
528 match val {
529 Value::Map(map) => {
531 if let Some(Value::String(t)) = map.get("_type") {
532 Ok(Value::String(t.clone()))
533 } else {
534 Err(datafusion::error::DataFusionError::Execution(
536 "TypeError: InvalidArgumentValue - type() requires a relationship argument".to_string(),
537 ))
538 }
539 }
540 Value::Null => Ok(Value::Null),
541 _ => Err(datafusion::error::DataFusionError::Execution(
542 "TypeError: InvalidArgumentValue - type() requires a relationship argument"
543 .to_string(),
544 )),
545 }
546 })
547 }
548}
549
550pub fn create_keys_udf() -> ScalarUDF {
558 ScalarUDF::new_from_impl(KeysUdf::new())
559}
560
561#[derive(Debug)]
562struct KeysUdf {
563 signature: Signature,
564}
565
566impl KeysUdf {
567 fn new() -> Self {
568 Self {
569 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
570 }
571 }
572}
573
574impl_udf_eq_hash!(KeysUdf);
575
576impl ScalarUDFImpl for KeysUdf {
577 fn as_any(&self) -> &dyn Any {
578 self
579 }
580
581 fn name(&self) -> &str {
582 "keys"
583 }
584
585 fn signature(&self) -> &Signature {
586 &self.signature
587 }
588
589 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
590 Ok(DataType::List(Arc::new(
591 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
592 )))
593 }
594
595 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
596 let output_type = self.return_type(&[])?;
597 invoke_cypher_udf(args, &output_type, |val_args| {
598 if val_args.is_empty() {
599 return Err(datafusion::error::DataFusionError::Execution(
600 "keys(): requires 1 argument".to_string(),
601 ));
602 }
603
604 let arg = &val_args[0];
605 let keys = match arg {
606 Value::Map(map) => {
607 let (source, is_entity) = match map.get("_all_props") {
617 Some(Value::Map(all)) => (all, true),
618 _ => (map, false),
619 };
620 let mut key_strings: Vec<String> = source
621 .iter()
622 .filter(|(k, v)| !k.starts_with('_') && (!is_entity || !v.is_null()))
623 .map(|(k, _)| k.clone())
624 .collect();
625 key_strings.sort();
626 key_strings
627 .into_iter()
628 .map(Value::String)
629 .collect::<Vec<_>>()
630 }
631 Value::Null => {
632 return Ok(Value::Null);
633 }
634 _ => {
635 vec![]
638 }
639 };
640
641 Ok(Value::List(keys))
642 })
643 }
644}
645
646pub fn create_properties_udf() -> ScalarUDF {
651 ScalarUDF::new_from_impl(PropertiesUdf::new())
652}
653
654#[derive(Debug)]
655struct PropertiesUdf {
656 signature: Signature,
657}
658
659impl PropertiesUdf {
660 fn new() -> Self {
661 Self {
662 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
663 }
664 }
665}
666
667impl_udf_eq_hash!(PropertiesUdf);
668
669impl ScalarUDFImpl for PropertiesUdf {
670 fn as_any(&self) -> &dyn Any {
671 self
672 }
673
674 fn name(&self) -> &str {
675 "properties"
676 }
677
678 fn signature(&self) -> &Signature {
679 &self.signature
680 }
681
682 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
683 Ok(DataType::LargeBinary)
685 }
686
687 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
688 let output_type = self.return_type(&[])?;
689 invoke_cypher_udf(args, &output_type, |val_args| {
690 if val_args.is_empty() {
691 return Err(datafusion::error::DataFusionError::Execution(
692 "properties(): requires 1 argument".to_string(),
693 ));
694 }
695
696 let arg = &val_args[0];
697 match arg {
698 Value::Map(map) => {
699 let identity_null = map
709 .get("_vid")
710 .map(|v| v.is_null())
711 .or_else(|| map.get("_eid").map(|v| v.is_null()))
712 .unwrap_or(false);
713 if identity_null {
714 return Ok(Value::Null);
715 }
716
717 let source = match map.get("_all_props") {
719 Some(Value::Map(all)) => all,
720 _ => map,
721 };
722 let filtered: std::collections::HashMap<String, Value> = source
724 .iter()
725 .filter(|(k, _)| !k.starts_with('_'))
726 .map(|(k, v)| (k.clone(), v.clone()))
727 .collect();
728 Ok(Value::Map(filtered))
729 }
730 _ => Ok(Value::Null),
731 }
732 })
733 }
734}
735
736pub fn create_index_udf() -> ScalarUDF {
741 ScalarUDF::new_from_impl(IndexUdf::new())
742}
743
744#[derive(Debug)]
745struct IndexUdf {
746 signature: Signature,
747}
748
749impl IndexUdf {
750 fn new() -> Self {
751 Self {
752 signature: Signature::any(2, Volatility::Immutable),
753 }
754 }
755}
756
757impl_udf_eq_hash!(IndexUdf);
758
759impl ScalarUDFImpl for IndexUdf {
760 fn as_any(&self) -> &dyn Any {
761 self
762 }
763
764 fn name(&self) -> &str {
765 "index"
766 }
767
768 fn signature(&self) -> &Signature {
769 &self.signature
770 }
771
772 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
773 Ok(DataType::LargeBinary)
775 }
776
777 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
778 let output_type = self.return_type(&[])?;
779 invoke_cypher_udf(args, &output_type, |val_args| {
780 if val_args.len() != 2 {
781 return Err(datafusion::error::DataFusionError::Execution(
782 "index(): requires 2 arguments".to_string(),
783 ));
784 }
785
786 let container = &val_args[0];
787 let index = &val_args[1];
788
789 let index_as_int = index.as_i64();
793
794 let result = match container {
795 Value::List(arr) => {
796 if let Some(i) = index_as_int {
797 let idx = if i < 0 {
798 let pos = arr.len() as i64 + i;
799 if pos < 0 { -1 } else { pos }
800 } else {
801 i
802 };
803 if idx >= 0 && (idx as usize) < arr.len() {
804 arr[idx as usize].clone()
805 } else {
806 Value::Null
807 }
808 } else if index.is_null() {
809 Value::Null
810 } else {
811 return Err(datafusion::error::DataFusionError::Execution(format!(
812 "TypeError: InvalidArgumentType - list index must be an integer, got: {:?}",
813 index
814 )));
815 }
816 }
817 Value::Map(map) => {
818 if let Some(key) = index.as_str() {
819 if let Some(val) = map.get(key) {
821 val.clone()
822 } else if let Some(Value::Map(all_props)) = map.get("_all_props") {
823 all_props.get(key).cloned().unwrap_or(Value::Null)
825 } else if let Some(Value::Map(props)) = map.get("properties") {
826 props.get(key).cloned().unwrap_or(Value::Null)
828 } else {
829 Value::Null
830 }
831 } else if !index.is_null() {
832 return Err(datafusion::error::DataFusionError::Execution(
833 "index(): map index must be a string".to_string(),
834 ));
835 } else {
836 Value::Null
837 }
838 }
839 Value::Node(node) => {
840 if let Some(key) = index.as_str() {
841 node.properties.get(key).cloned().unwrap_or(Value::Null)
842 } else if !index.is_null() {
843 return Err(datafusion::error::DataFusionError::Execution(
844 "index(): node index must be a string".to_string(),
845 ));
846 } else {
847 Value::Null
848 }
849 }
850 Value::Edge(edge) => {
851 if let Some(key) = index.as_str() {
852 edge.properties.get(key).cloned().unwrap_or(Value::Null)
853 } else if !index.is_null() {
854 return Err(datafusion::error::DataFusionError::Execution(
855 "index(): edge index must be a string".to_string(),
856 ));
857 } else {
858 Value::Null
859 }
860 }
861 Value::Null => Value::Null,
862 _ => {
863 return Err(datafusion::error::DataFusionError::Execution(format!(
864 "TypeError: InvalidArgumentType - cannot index into {:?}",
865 container
866 )));
867 }
868 };
869
870 Ok(result)
871 })
872 }
873}
874
875pub fn create_labels_udf() -> ScalarUDF {
880 ScalarUDF::new_from_impl(LabelsUdf::new())
881}
882
883#[derive(Debug)]
884struct LabelsUdf {
885 signature: Signature,
886}
887
888impl LabelsUdf {
889 fn new() -> Self {
890 Self {
891 signature: Signature::any(1, Volatility::Immutable),
892 }
893 }
894}
895
896impl_udf_eq_hash!(LabelsUdf);
897
898impl ScalarUDFImpl for LabelsUdf {
899 fn as_any(&self) -> &dyn Any {
900 self
901 }
902
903 fn name(&self) -> &str {
904 "labels"
905 }
906
907 fn signature(&self) -> &Signature {
908 &self.signature
909 }
910
911 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
912 Ok(DataType::List(Arc::new(
913 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
914 )))
915 }
916
917 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
918 let output_type = self.return_type(&[])?;
919 invoke_cypher_udf(args, &output_type, |val_args| {
920 if val_args.is_empty() {
921 return Err(datafusion::error::DataFusionError::Execution(
922 "labels(): requires 1 argument".to_string(),
923 ));
924 }
925
926 let node = &val_args[0];
927 match node {
928 Value::Map(map) => {
929 if let Some(Value::List(arr)) = map.get("_labels") {
930 Ok(Value::List(arr.clone()))
931 } else {
932 Err(datafusion::error::DataFusionError::Execution(
934 "TypeError: InvalidArgumentValue - labels() requires a node argument"
935 .to_string(),
936 ))
937 }
938 }
939 Value::Null => Ok(Value::Null),
940 _ => Err(datafusion::error::DataFusionError::Execution(
941 "TypeError: InvalidArgumentValue - labels() requires a node argument"
942 .to_string(),
943 )),
944 }
945 })
946 }
947}
948
949pub fn create_nodes_udf() -> ScalarUDF {
954 ScalarUDF::new_from_impl(NodesUdf::new())
955}
956
957#[derive(Debug)]
958struct NodesUdf {
959 signature: Signature,
960}
961
962impl NodesUdf {
963 fn new() -> Self {
964 Self {
965 signature: Signature::any(1, Volatility::Immutable),
966 }
967 }
968}
969
970impl_udf_eq_hash!(NodesUdf);
971
972impl ScalarUDFImpl for NodesUdf {
973 fn as_any(&self) -> &dyn Any {
974 self
975 }
976
977 fn name(&self) -> &str {
978 "nodes"
979 }
980
981 fn signature(&self) -> &Signature {
982 &self.signature
983 }
984
985 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
986 Ok(DataType::LargeBinary)
987 }
988
989 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
990 let output_type = self.return_type(&[])?;
991 invoke_cypher_udf(args, &output_type, |val_args| {
992 if val_args.is_empty() {
993 return Err(datafusion::error::DataFusionError::Execution(
994 "nodes(): requires 1 argument".to_string(),
995 ));
996 }
997
998 let path = &val_args[0];
999 let nodes = match path {
1000 Value::Map(map) => map.get("nodes").cloned().unwrap_or(Value::Null),
1001 _ => Value::Null,
1002 };
1003
1004 Ok(nodes)
1005 })
1006 }
1007}
1008
1009pub fn create_relationships_udf() -> ScalarUDF {
1014 ScalarUDF::new_from_impl(RelationshipsUdf::new())
1015}
1016
1017#[derive(Debug)]
1018struct RelationshipsUdf {
1019 signature: Signature,
1020}
1021
1022impl RelationshipsUdf {
1023 fn new() -> Self {
1024 Self {
1025 signature: Signature::any(1, Volatility::Immutable),
1026 }
1027 }
1028}
1029
1030impl_udf_eq_hash!(RelationshipsUdf);
1031
1032impl ScalarUDFImpl for RelationshipsUdf {
1033 fn as_any(&self) -> &dyn Any {
1034 self
1035 }
1036
1037 fn name(&self) -> &str {
1038 "relationships"
1039 }
1040
1041 fn signature(&self) -> &Signature {
1042 &self.signature
1043 }
1044
1045 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1046 Ok(DataType::LargeBinary)
1047 }
1048
1049 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1050 let output_type = self.return_type(&[])?;
1051 invoke_cypher_udf(args, &output_type, |val_args| {
1052 if val_args.is_empty() {
1053 return Err(datafusion::error::DataFusionError::Execution(
1054 "relationships(): requires 1 argument".to_string(),
1055 ));
1056 }
1057
1058 let path = &val_args[0];
1059 let rels = match path {
1060 Value::Map(map) => map.get("relationships").cloned().unwrap_or(Value::Null),
1061 _ => Value::Null,
1062 };
1063
1064 Ok(rels)
1065 })
1066 }
1067}
1068
1069pub fn create_startnode_udf() -> ScalarUDF {
1078 ScalarUDF::new_from_impl(StartNodeUdf::new())
1079}
1080
1081#[derive(Debug)]
1082struct StartNodeUdf {
1083 signature: Signature,
1084}
1085
1086impl StartNodeUdf {
1087 fn new() -> Self {
1088 Self {
1089 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1090 }
1091 }
1092}
1093
1094impl_udf_eq_hash!(StartNodeUdf);
1095
1096impl ScalarUDFImpl for StartNodeUdf {
1097 fn as_any(&self) -> &dyn Any {
1098 self
1099 }
1100
1101 fn name(&self) -> &str {
1102 "startnode"
1103 }
1104
1105 fn signature(&self) -> &Signature {
1106 &self.signature
1107 }
1108
1109 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1110 Ok(DataType::LargeBinary)
1111 }
1112
1113 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1114 let output_type = DataType::LargeBinary;
1115 invoke_cypher_udf(args, &output_type, |val_args| {
1116 startnode_endnode_impl(val_args, true)
1117 })
1118 }
1119}
1120
1121pub fn create_endnode_udf() -> ScalarUDF {
1127 ScalarUDF::new_from_impl(EndNodeUdf::new())
1128}
1129
1130#[derive(Debug)]
1131struct EndNodeUdf {
1132 signature: Signature,
1133}
1134
1135impl EndNodeUdf {
1136 fn new() -> Self {
1137 Self {
1138 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1139 }
1140 }
1141}
1142
1143impl_udf_eq_hash!(EndNodeUdf);
1144
1145impl ScalarUDFImpl for EndNodeUdf {
1146 fn as_any(&self) -> &dyn Any {
1147 self
1148 }
1149
1150 fn name(&self) -> &str {
1151 "endnode"
1152 }
1153
1154 fn signature(&self) -> &Signature {
1155 &self.signature
1156 }
1157
1158 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1159 Ok(DataType::LargeBinary)
1160 }
1161
1162 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1163 let output_type = DataType::LargeBinary;
1164 invoke_cypher_udf(args, &output_type, |val_args| {
1165 startnode_endnode_impl(val_args, false)
1166 })
1167 }
1168}
1169
1170fn startnode_endnode_impl(val_args: &[Value], is_start: bool) -> DFResult<Value> {
1175 if val_args.is_empty() {
1176 let fn_name = if is_start { "startNode" } else { "endNode" };
1177 return Err(datafusion::error::DataFusionError::Execution(format!(
1178 "{fn_name}(): requires at least 1 argument"
1179 )));
1180 }
1181
1182 let edge_val = &val_args[0];
1183 let target_vid = extract_endpoint_vid(edge_val, is_start);
1184
1185 let target_vid = match target_vid {
1186 Some(vid) => vid,
1187 None => return Ok(Value::Null),
1188 };
1189
1190 for node_val in val_args.iter().skip(1) {
1192 if let Some(vid) = extract_vid(node_val)
1193 && vid == target_vid
1194 {
1195 return Ok(node_val.clone());
1196 }
1197 }
1198
1199 let mut map = std::collections::HashMap::new();
1201 map.insert("_vid".to_string(), Value::Int(target_vid as i64));
1202 Ok(Value::Map(map))
1203}
1204
1205fn extract_endpoint_vid(val: &Value, is_start: bool) -> Option<u64> {
1207 match val {
1208 Value::Edge(edge) => {
1209 let vid = if is_start { edge.src } else { edge.dst };
1210 Some(vid.as_u64())
1211 }
1212 Value::Map(map) => {
1213 let key = if is_start { "_src_vid" } else { "_dst_vid" };
1215 if let Some(v) = map.get(key) {
1216 return v.as_u64();
1217 }
1218 let key2 = if is_start { "_src" } else { "_dst" };
1220 if let Some(v) = map.get(key2) {
1221 return v.as_u64();
1222 }
1223 let node_key = if is_start { "_startNode" } else { "_endNode" };
1225 if let Some(node_val) = map.get(node_key) {
1226 return extract_vid(node_val);
1227 }
1228 None
1229 }
1230 _ => None,
1231 }
1232}
1233
1234fn extract_vid(val: &Value) -> Option<u64> {
1236 match val {
1237 Value::Map(map) => map.get("_vid").and_then(|v| v.as_u64()),
1238 _ => None,
1239 }
1240}
1241
1242fn extract_i64_range_arg(arg: &ColumnarValue, row_idx: usize, name: &str) -> DFResult<i64> {
1249 match arg {
1250 ColumnarValue::Scalar(sv) => match sv {
1251 ScalarValue::Int8(Some(v)) => Ok(*v as i64),
1252 ScalarValue::Int16(Some(v)) => Ok(*v as i64),
1253 ScalarValue::Int32(Some(v)) => Ok(*v as i64),
1254 ScalarValue::Int64(Some(v)) => Ok(*v),
1255 ScalarValue::UInt8(Some(v)) => Ok(*v as i64),
1256 ScalarValue::UInt16(Some(v)) => Ok(*v as i64),
1257 ScalarValue::UInt32(Some(v)) => Ok(*v as i64),
1258 ScalarValue::UInt64(Some(v)) => Ok(*v as i64),
1259 ScalarValue::LargeBinary(Some(bytes)) => {
1260 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1261 datafusion::error::DataFusionError::Execution(format!(
1262 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1263 name
1264 ))
1265 })
1266 }
1267 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1268 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1269 name
1270 ))),
1271 },
1272 ColumnarValue::Array(arr) => {
1273 if row_idx >= arr.len() || arr.is_null(row_idx) {
1274 return Err(datafusion::error::DataFusionError::Execution(format!(
1275 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1276 name
1277 )));
1278 }
1279 if !arr.is_empty() {
1281 use datafusion::arrow::array::{
1282 Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array,
1283 UInt32Array, UInt64Array,
1284 };
1285 match arr.data_type() {
1286 DataType::Int8 => Ok(arr
1287 .as_any()
1288 .downcast_ref::<Int8Array>()
1289 .unwrap()
1290 .value(row_idx) as i64),
1291 DataType::Int16 => Ok(arr
1292 .as_any()
1293 .downcast_ref::<Int16Array>()
1294 .unwrap()
1295 .value(row_idx) as i64),
1296 DataType::Int32 => Ok(arr
1297 .as_any()
1298 .downcast_ref::<Int32Array>()
1299 .unwrap()
1300 .value(row_idx) as i64),
1301 DataType::Int64 => Ok(arr
1302 .as_any()
1303 .downcast_ref::<Int64Array>()
1304 .unwrap()
1305 .value(row_idx)),
1306 DataType::UInt8 => Ok(arr
1307 .as_any()
1308 .downcast_ref::<UInt8Array>()
1309 .unwrap()
1310 .value(row_idx) as i64),
1311 DataType::UInt16 => Ok(arr
1312 .as_any()
1313 .downcast_ref::<UInt16Array>()
1314 .unwrap()
1315 .value(row_idx) as i64),
1316 DataType::UInt32 => Ok(arr
1317 .as_any()
1318 .downcast_ref::<UInt32Array>()
1319 .unwrap()
1320 .value(row_idx) as i64),
1321 DataType::UInt64 => Ok(arr
1322 .as_any()
1323 .downcast_ref::<UInt64Array>()
1324 .unwrap()
1325 .value(row_idx) as i64),
1326 DataType::LargeBinary => {
1327 let bytes = arr
1328 .as_any()
1329 .downcast_ref::<LargeBinaryArray>()
1330 .unwrap()
1331 .value(row_idx);
1332 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1333 datafusion::error::DataFusionError::Execution(format!(
1334 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1335 name
1336 ))
1337 })
1338 }
1339 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1340 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1341 name
1342 ))),
1343 }
1344 } else {
1345 Err(datafusion::error::DataFusionError::Execution(format!(
1346 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1347 name
1348 )))
1349 }
1350 }
1351 }
1352}
1353
1354pub fn create_range_udf() -> ScalarUDF {
1356 ScalarUDF::new_from_impl(RangeUdf::new())
1357}
1358
1359#[derive(Debug)]
1360struct RangeUdf {
1361 signature: Signature,
1362}
1363
1364impl RangeUdf {
1365 fn new() -> Self {
1366 Self {
1367 signature: Signature::one_of(
1368 vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1369 Volatility::Immutable,
1370 ),
1371 }
1372 }
1373}
1374
1375impl_udf_eq_hash!(RangeUdf);
1376
1377impl ScalarUDFImpl for RangeUdf {
1378 fn as_any(&self) -> &dyn Any {
1379 self
1380 }
1381
1382 fn name(&self) -> &str {
1383 "range"
1384 }
1385
1386 fn signature(&self) -> &Signature {
1387 &self.signature
1388 }
1389
1390 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1391 Ok(DataType::List(Arc::new(
1392 arrow::datatypes::Field::new_list_field(DataType::Int64, true),
1393 )))
1394 }
1395
1396 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1397 if args.args.len() < 2 || args.args.len() > 3 {
1398 return Err(datafusion::error::DataFusionError::Execution(
1399 "range(): requires 2 or 3 arguments".to_string(),
1400 ));
1401 }
1402
1403 let len = args
1404 .args
1405 .iter()
1406 .find_map(|arg| match arg {
1407 ColumnarValue::Array(arr) => Some(arr.len()),
1408 _ => None,
1409 })
1410 .unwrap_or(1);
1411
1412 let mut list_builder =
1413 arrow_array::builder::ListBuilder::new(arrow_array::builder::Int64Builder::new());
1414
1415 for row_idx in 0..len {
1416 let start = extract_i64_range_arg(&args.args[0], row_idx, "start")?;
1417 let end = extract_i64_range_arg(&args.args[1], row_idx, "end")?;
1418 let step = if args.args.len() == 3 {
1419 extract_i64_range_arg(&args.args[2], row_idx, "step")?
1420 } else {
1421 1
1422 };
1423
1424 if step == 0 {
1425 return Err(datafusion::error::DataFusionError::Execution(
1426 "range(): step cannot be zero".to_string(),
1427 ));
1428 }
1429
1430 if step > 0 && start <= end {
1431 let mut current = start;
1432 while current <= end {
1433 list_builder.values().append_value(current);
1434 current += step;
1435 }
1436 } else if step < 0 && start >= end {
1437 let mut current = start;
1438 while current >= end {
1439 list_builder.values().append_value(current);
1440 current += step;
1441 }
1442 }
1443 list_builder.append(true);
1445 }
1446
1447 let list_arr = Arc::new(list_builder.finish()) as ArrayRef;
1448 if len == 1
1449 && args
1450 .args
1451 .iter()
1452 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
1453 {
1454 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
1455 &list_arr, 0,
1456 )?))
1457 } else {
1458 Ok(ColumnarValue::Array(list_arr))
1459 }
1460 }
1461}
1462
1463fn invoke_binary_bitwise_op<F>(
1471 args: &ScalarFunctionArgs,
1472 name: &str,
1473 op: F,
1474) -> DFResult<ColumnarValue>
1475where
1476 F: Fn(i64, i64) -> i64,
1477{
1478 use arrow_array::Int64Array;
1479 use datafusion::common::ScalarValue;
1480 use datafusion::error::DataFusionError;
1481
1482 if args.args.len() != 2 {
1483 return Err(DataFusionError::Execution(format!(
1484 "{}(): requires exactly 2 arguments",
1485 name
1486 )));
1487 }
1488
1489 let left = &args.args[0];
1490 let right = &args.args[1];
1491
1492 match (left, right) {
1493 (
1494 ColumnarValue::Scalar(ScalarValue::Int64(Some(l))),
1495 ColumnarValue::Scalar(ScalarValue::Int64(Some(r))),
1496 ) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*l, *r))))),
1497 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
1498 let l_arr = l_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1499 DataFusionError::Execution(format!("{}(): left array must be Int64", name))
1500 })?;
1501 let r_arr = r_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1502 DataFusionError::Execution(format!("{}(): right array must be Int64", name))
1503 })?;
1504
1505 let result: Int64Array = l_arr
1506 .iter()
1507 .zip(r_arr.iter())
1508 .map(|(l, r)| match (l, r) {
1509 (Some(l), Some(r)) => Some(op(l, r)),
1510 _ => None,
1511 })
1512 .collect();
1513
1514 Ok(ColumnarValue::Array(Arc::new(result)))
1515 }
1516 _ => Err(DataFusionError::Execution(format!(
1517 "{}(): mixed scalar/array not supported",
1518 name
1519 ))),
1520 }
1521}
1522
1523fn invoke_unary_bitwise_op<F>(
1527 args: &ScalarFunctionArgs,
1528 name: &str,
1529 op: F,
1530) -> DFResult<ColumnarValue>
1531where
1532 F: Fn(i64) -> i64,
1533{
1534 use arrow_array::Int64Array;
1535 use datafusion::common::ScalarValue;
1536 use datafusion::error::DataFusionError;
1537
1538 if args.args.len() != 1 {
1539 return Err(DataFusionError::Execution(format!(
1540 "{}(): requires exactly 1 argument",
1541 name
1542 )));
1543 }
1544
1545 let operand = &args.args[0];
1546
1547 match operand {
1548 ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
1549 Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*v)))))
1550 }
1551 ColumnarValue::Array(arr) => {
1552 let arr = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1553 DataFusionError::Execution(format!("{}(): array must be Int64", name))
1554 })?;
1555
1556 let result: Int64Array = arr.iter().map(|v| v.map(&op)).collect();
1557
1558 Ok(ColumnarValue::Array(Arc::new(result)))
1559 }
1560 _ => Err(DataFusionError::Execution(format!(
1561 "{}(): invalid argument type",
1562 name
1563 ))),
1564 }
1565}
1566
/// Generates a private `Int64 x Int64 -> Int64` scalar UDF struct named `$udf_type`,
/// registered as `$sql_name`, whose evaluation is `invoke_binary_bitwise_op` with the
/// closure `$operation`.
macro_rules! define_binary_bitwise_udf {
    ($udf_type:ident, $sql_name:literal, $operation:expr) => {
        /// Binary bitwise UDF generated by `define_binary_bitwise_udf!`.
        #[derive(Debug)]
        struct $udf_type {
            signature: Signature,
        }

        impl $udf_type {
            fn new() -> Self {
                // Exactly two Int64 operands.
                let operand_types = vec![DataType::Int64; 2];
                Self {
                    signature: Signature::exact(operand_types, Volatility::Immutable),
                }
            }
        }

        impl_udf_eq_hash!($udf_type);

        impl ScalarUDFImpl for $udf_type {
            fn name(&self) -> &str {
                $sql_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_binary_bitwise_op(&args, $sql_name, $operation)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
1613
/// Generates a private `Int64 -> Int64` scalar UDF struct named `$udf_type`, registered
/// as `$sql_name`, whose evaluation is `invoke_unary_bitwise_op` with the closure
/// `$operation`.
macro_rules! define_unary_bitwise_udf {
    ($udf_type:ident, $sql_name:literal, $operation:expr) => {
        /// Unary bitwise UDF generated by `define_unary_bitwise_udf!`.
        #[derive(Debug)]
        struct $udf_type {
            signature: Signature,
        }

        impl $udf_type {
            fn new() -> Self {
                // Exactly one Int64 operand.
                let operand_types = vec![DataType::Int64];
                Self {
                    signature: Signature::exact(operand_types, Volatility::Immutable),
                }
            }
        }

        impl_udf_eq_hash!($udf_type);

        impl ScalarUDFImpl for $udf_type {
            fn name(&self) -> &str {
                $sql_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_unary_bitwise_op(&args, $sql_name, $operation)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }
        }
    };
}
1657
1658define_binary_bitwise_udf!(BitwiseOrUdf, "uni.bitwise.or", |l, r| l | r);
1660define_binary_bitwise_udf!(BitwiseAndUdf, "uni.bitwise.and", |l, r| l & r);
1661define_binary_bitwise_udf!(BitwiseXorUdf, "uni.bitwise.xor", |l, r| l ^ r);
1662define_binary_bitwise_udf!(ShiftLeftUdf, "uni.bitwise.shiftLeft", |l, r| l << r);
1663define_binary_bitwise_udf!(ShiftRightUdf, "uni.bitwise.shiftRight", |l, r| l >> r);
1664
1665define_unary_bitwise_udf!(BitwiseNotUdf, "uni.bitwise.not", |v| !v);
1667
1668pub fn create_bitwise_or_udf() -> ScalarUDF {
1670 ScalarUDF::new_from_impl(BitwiseOrUdf::new())
1671}
1672
1673pub fn create_bitwise_and_udf() -> ScalarUDF {
1675 ScalarUDF::new_from_impl(BitwiseAndUdf::new())
1676}
1677
1678pub fn create_bitwise_xor_udf() -> ScalarUDF {
1680 ScalarUDF::new_from_impl(BitwiseXorUdf::new())
1681}
1682
1683pub fn create_bitwise_not_udf() -> ScalarUDF {
1685 ScalarUDF::new_from_impl(BitwiseNotUdf::new())
1686}
1687
1688pub fn create_shift_left_udf() -> ScalarUDF {
1690 ScalarUDF::new_from_impl(ShiftLeftUdf::new())
1691}
1692
1693pub fn create_shift_right_udf() -> ScalarUDF {
1695 ScalarUDF::new_from_impl(ShiftRightUdf::new())
1696}
1697
1698fn create_temporal_udf(name: &str) -> ScalarUDF {
1709 ScalarUDF::new_from_impl(TemporalUdf::new(name.to_string()))
1710}
1711
1712#[derive(Debug)]
1713struct TemporalUdf {
1714 name: String,
1715 signature: Signature,
1716}
1717
1718impl TemporalUdf {
1719 fn new(name: String) -> Self {
1720 Self {
1721 name,
1722 signature: Signature::new(
1725 TypeSignature::OneOf(vec![
1726 TypeSignature::Exact(vec![]),
1727 TypeSignature::VariadicAny,
1728 ]),
1729 Volatility::Immutable,
1730 ),
1731 }
1732 }
1733}
1734
1735impl_udf_eq_hash!(TemporalUdf);
1736
1737impl ScalarUDFImpl for TemporalUdf {
1738 fn as_any(&self) -> &dyn Any {
1739 self
1740 }
1741
1742 fn name(&self) -> &str {
1743 &self.name
1744 }
1745
1746 fn signature(&self) -> &Signature {
1747 &self.signature
1748 }
1749
1750 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1751 let name = self.name.to_lowercase();
1752 match name.as_str() {
1753 "year" | "month" | "day" | "hour" | "minute" | "second" => Ok(DataType::Int64),
1755 "datetime"
1761 | "localdatetime"
1762 | "date"
1763 | "time"
1764 | "localtime"
1765 | "duration"
1766 | "date.truncate"
1767 | "time.truncate"
1768 | "datetime.truncate"
1769 | "localdatetime.truncate"
1770 | "localtime.truncate"
1771 | "duration.between"
1772 | "duration.inmonths"
1773 | "duration.indays"
1774 | "duration.inseconds"
1775 | "datetime.fromepoch"
1776 | "datetime.fromepochmillis"
1777 | "datetime.transaction"
1778 | "datetime.statement"
1779 | "datetime.realtime"
1780 | "date.transaction"
1781 | "date.statement"
1782 | "date.realtime"
1783 | "time.transaction"
1784 | "time.statement"
1785 | "time.realtime"
1786 | "localtime.transaction"
1787 | "localtime.statement"
1788 | "localtime.realtime"
1789 | "localdatetime.transaction"
1790 | "localdatetime.statement"
1791 | "localdatetime.realtime"
1792 | "btic" => Ok(DataType::LargeBinary),
1793 _ => Ok(DataType::Utf8),
1794 }
1795 }
1796
1797 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1798 let func_name = self.name.to_uppercase();
1799 let output_type = self.return_type(&[])?;
1800 invoke_cypher_udf(args, &output_type, |val_args| {
1801 crate::datetime::eval_datetime_function(&func_name, val_args).map_err(|e| {
1802 datafusion::error::DataFusionError::Execution(format!("{}(): {}", self.name, e))
1803 })
1804 })
1805 }
1806}
1807
1808fn create_duration_property_udf() -> ScalarUDF {
1813 ScalarUDF::new_from_impl(DurationPropertyUdf::new())
1814}
1815
1816#[derive(Debug)]
1817struct DurationPropertyUdf {
1818 signature: Signature,
1819}
1820
1821impl DurationPropertyUdf {
1822 fn new() -> Self {
1823 Self {
1824 signature: Signature::new(
1825 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
1826 Volatility::Immutable,
1827 ),
1828 }
1829 }
1830}
1831
1832impl_udf_eq_hash!(DurationPropertyUdf);
1833
1834impl ScalarUDFImpl for DurationPropertyUdf {
1835 fn as_any(&self) -> &dyn Any {
1836 self
1837 }
1838
1839 fn name(&self) -> &str {
1840 "_duration_property"
1841 }
1842
1843 fn signature(&self) -> &Signature {
1844 &self.signature
1845 }
1846
1847 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1848 Ok(DataType::Int64)
1849 }
1850
1851 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1852 let output_type = self.return_type(&[])?;
1853 invoke_cypher_udf(args, &output_type, |val_args| {
1854 if val_args.len() != 2 {
1855 return Err(datafusion::error::DataFusionError::Execution(
1856 "_duration_property(): requires 2 arguments (duration_string, component)"
1857 .to_string(),
1858 ));
1859 }
1860
1861 let dur_string_owned;
1862 let dur_str = match &val_args[0] {
1863 Value::String(s) => s.as_str(),
1864 Value::Temporal(uni_common::TemporalValue::Duration { .. }) => {
1865 dur_string_owned = val_args[0].to_string();
1866 &dur_string_owned
1867 }
1868 Value::Null => return Ok(Value::Null),
1869 _ => {
1870 return Err(datafusion::error::DataFusionError::Execution(
1871 "_duration_property(): duration must be a string or temporal duration"
1872 .to_string(),
1873 ));
1874 }
1875 };
1876 let component = match &val_args[1] {
1877 Value::String(s) => s,
1878 _ => {
1879 return Err(datafusion::error::DataFusionError::Execution(
1880 "_duration_property(): component must be a string".to_string(),
1881 ));
1882 }
1883 };
1884
1885 crate::datetime::eval_duration_accessor(dur_str, component).map_err(|e| {
1886 datafusion::error::DataFusionError::Execution(format!(
1887 "_duration_property(): {}",
1888 e
1889 ))
1890 })
1891 })
1892 }
1893}
1894
1895fn create_tostring_udf() -> ScalarUDF {
1900 ScalarUDF::new_from_impl(ToStringUdf::new())
1901}
1902
1903#[derive(Debug)]
1904struct ToStringUdf {
1905 signature: Signature,
1906}
1907
1908impl ToStringUdf {
1909 fn new() -> Self {
1910 Self {
1911 signature: Signature::variadic_any(Volatility::Immutable),
1912 }
1913 }
1914}
1915
1916impl_udf_eq_hash!(ToStringUdf);
1917
1918impl ScalarUDFImpl for ToStringUdf {
1919 fn as_any(&self) -> &dyn Any {
1920 self
1921 }
1922
1923 fn name(&self) -> &str {
1924 "tostring"
1925 }
1926
1927 fn signature(&self) -> &Signature {
1928 &self.signature
1929 }
1930
1931 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1932 Ok(DataType::Utf8)
1933 }
1934
1935 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1936 let output_type = self.return_type(&[])?;
1937 invoke_cypher_udf(args, &output_type, |val_args| {
1938 if val_args.is_empty() {
1939 return Err(datafusion::error::DataFusionError::Execution(
1940 "toString(): requires 1 argument".to_string(),
1941 ));
1942 }
1943 match &val_args[0] {
1944 Value::Null => Ok(Value::Null),
1945 Value::String(s) => Ok(Value::String(s.clone())),
1946 Value::Int(i) => Ok(Value::String(i.to_string())),
1947 Value::Float(f) => Ok(Value::String(f.to_string())),
1948 Value::Bool(b) => Ok(Value::String(b.to_string())),
1949 Value::Temporal(t) => Ok(Value::String(t.to_string())),
1950 other => {
1951 let type_name = match other {
1952 Value::List(_) => "List",
1953 Value::Map(_) => "Map",
1954 Value::Node { .. } => "Node",
1955 Value::Edge { .. } => "Relationship",
1956 Value::Path { .. } => "Path",
1957 _ => "Unknown",
1958 };
1959 Err(datafusion::error::DataFusionError::Execution(format!(
1960 "TypeError: InvalidArgumentValue - toString() does not accept {} values",
1961 type_name
1962 )))
1963 }
1964 }
1965 })
1966 }
1967}
1968
1969fn create_temporal_property_udf() -> ScalarUDF {
1974 ScalarUDF::new_from_impl(TemporalPropertyUdf::new())
1975}
1976
1977#[derive(Debug)]
1978struct TemporalPropertyUdf {
1979 signature: Signature,
1980}
1981
1982impl TemporalPropertyUdf {
1983 fn new() -> Self {
1984 Self {
1985 signature: Signature::variadic_any(Volatility::Immutable),
1986 }
1987 }
1988}
1989
1990impl_udf_eq_hash!(TemporalPropertyUdf);
1991
1992impl ScalarUDFImpl for TemporalPropertyUdf {
1993 fn as_any(&self) -> &dyn Any {
1994 self
1995 }
1996
1997 fn name(&self) -> &str {
1998 "_temporal_property"
1999 }
2000
2001 fn signature(&self) -> &Signature {
2002 &self.signature
2003 }
2004
2005 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2006 Ok(DataType::LargeBinary)
2007 }
2008
2009 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2010 let output_type = self.return_type(&[])?;
2011 invoke_cypher_udf(args, &output_type, |val_args| {
2012 if val_args.len() != 2 {
2013 return Err(datafusion::error::DataFusionError::Execution(
2014 "_temporal_property(): requires 2 arguments (temporal_value, component)"
2015 .to_string(),
2016 ));
2017 }
2018
2019 let component = match &val_args[1] {
2020 Value::String(s) => s.clone(),
2021 _ => {
2022 return Err(datafusion::error::DataFusionError::Execution(
2023 "_temporal_property(): component must be a string".to_string(),
2024 ));
2025 }
2026 };
2027
2028 crate::datetime::eval_temporal_accessor_value(&val_args[0], &component).map_err(|e| {
2029 datafusion::error::DataFusionError::Execution(format!(
2030 "_temporal_property(): {}",
2031 e
2032 ))
2033 })
2034 })
2035 }
2036}
2037
/// Downcasts `$arr` (anything with `as_any()`) to `&$array_type`.
///
/// Expands to an expression ending in `?`. On a failed downcast it returns early from the
/// enclosing function with `DataFusionError::Execution("Failed to downcast to <type>")`,
/// so it can only be used where that error can be propagated.
macro_rules! downcast_arr {
    ($arr:expr, $array_type:ty) => {
        $arr.as_any().downcast_ref::<$array_type>().ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "Failed to downcast to {}",
                stringify!($array_type)
            ))
        })?
    };
}
2050
2051fn cypher_type_name(val: &Value) -> &'static str {
2053 match val {
2054 Value::Null => "Null",
2055 Value::Bool(_) => "Boolean",
2056 Value::Int(_) => "Integer",
2057 Value::Float(_) => "Float",
2058 Value::String(_) => "String",
2059 Value::Bytes(_) => "Bytes",
2060 Value::List(_) => "List",
2061 Value::Map(_) => "Map",
2062 Value::Node(_) => "Node",
2063 Value::Edge(_) => "Relationship",
2064 Value::Path(_) => "Path",
2065 Value::Vector(_) => "Vector",
2066 Value::Temporal(_) => "Temporal",
2067 _ => "Unknown",
2068 }
2069}
2070
2071fn string_to_value(s: &str) -> Value {
2073 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2074 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2075 {
2076 return Value::from(obj);
2077 }
2078 Value::String(s.to_string())
2079}
2080
2081fn get_value_from_array(
2087 arr: &ArrayRef,
2088 row: usize,
2089 field: Option<&arrow::datatypes::Field>,
2090) -> DFResult<Value> {
2091 if arr.is_null(row) {
2092 return Ok(Value::Null);
2093 }
2094
2095 match arr.data_type() {
2096 DataType::LargeBinary => {
2097 let typed = downcast_arr!(arr, LargeBinaryArray);
2098 let bytes = typed.value(row);
2099 if field.is_some_and(|f| {
2103 f.metadata()
2104 .get("uni_raw_bytes")
2105 .is_some_and(|v| v == "true")
2106 }) {
2107 return Ok(Value::Bytes(bytes.to_vec()));
2108 }
2109 if let Ok(val) = uni_common::cypher_value_codec::decode(bytes) {
2110 return Ok(val);
2111 }
2112 Ok(serde_json::from_slice::<serde_json::Value>(bytes)
2114 .map(Value::from)
2115 .unwrap_or(Value::Null))
2116 }
2117 DataType::Int64 => Ok(Value::Int(downcast_arr!(arr, Int64Array).value(row))),
2118 DataType::Float64 => Ok(Value::Float(downcast_arr!(arr, Float64Array).value(row))),
2119 DataType::Utf8 => Ok(string_to_value(downcast_arr!(arr, StringArray).value(row))),
2120 DataType::LargeUtf8 => Ok(string_to_value(
2121 downcast_arr!(arr, LargeStringArray).value(row),
2122 )),
2123 DataType::Boolean => Ok(Value::Bool(downcast_arr!(arr, BooleanArray).value(row))),
2124 DataType::UInt64 => Ok(Value::Int(downcast_arr!(arr, UInt64Array).value(row) as i64)),
2125 DataType::Int32 => Ok(Value::Int(downcast_arr!(arr, Int32Array).value(row) as i64)),
2126 DataType::Float32 => Ok(Value::Float(
2127 downcast_arr!(arr, Float32Array).value(row) as f64
2128 )),
2129 DataType::List(_) | DataType::LargeList(_) => Ok(
2134 uni_store::storage::arrow_convert::arrow_to_value(arr.as_ref(), row, None),
2135 ),
2136 _ => {
2139 let scalar = ScalarValue::try_from_array(arr, row).map_err(|e| {
2140 datafusion::error::DataFusionError::Execution(format!(
2141 "Cannot extract scalar from array at row {}: {}",
2142 row, e
2143 ))
2144 })?;
2145 scalar_to_value(&scalar)
2146 }
2147 }
2148}
2149
2150fn get_value_args_for_row(
2156 args: &[ColumnarValue],
2157 arg_fields: &[arrow::datatypes::FieldRef],
2158 row: usize,
2159) -> DFResult<Vec<Value>> {
2160 args.iter()
2161 .enumerate()
2162 .map(|(idx, arg)| match arg {
2163 ColumnarValue::Scalar(scalar) => scalar_to_value(scalar),
2164 ColumnarValue::Array(arr) => {
2165 get_value_from_array(arr, row, arg_fields.get(idx).map(|f| f.as_ref()))
2166 }
2167 })
2168 .collect()
2169}
2170
2171fn invoke_cypher_udf<F>(
2173 args: ScalarFunctionArgs,
2174 output_type: &DataType,
2175 f: F,
2176) -> DFResult<ColumnarValue>
2177where
2178 F: Fn(&[Value]) -> DFResult<Value>,
2179{
2180 let len = args
2181 .args
2182 .iter()
2183 .find_map(|arg| match arg {
2184 ColumnarValue::Array(arr) => Some(arr.len()),
2185 _ => None,
2186 })
2187 .unwrap_or(1);
2188
2189 if len == 1
2190 && args
2191 .args
2192 .iter()
2193 .all(|a| matches!(a, ColumnarValue::Scalar(_)))
2194 {
2195 let row_args = get_value_args_for_row(&args.args, &args.arg_fields, 0)?;
2196 let res = f(&row_args)?;
2197 if matches!(output_type, DataType::LargeBinary | DataType::List(_)) {
2198 let arr = values_to_array(&[res], output_type)
2200 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2201 return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, 0)?));
2202 }
2203 if res.is_null() {
2205 let typed_null = ScalarValue::try_from(output_type).unwrap_or(ScalarValue::Utf8(None));
2206 return Ok(ColumnarValue::Scalar(typed_null));
2207 }
2208 return value_to_columnar(&res);
2209 }
2210
2211 let mut results = Vec::with_capacity(len);
2212 for i in 0..len {
2213 let row_args = get_value_args_for_row(&args.args, &args.arg_fields, i)?;
2214 results.push(f(&row_args)?);
2215 }
2216
2217 let arr = values_to_array(&results, output_type)
2218 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2219 Ok(ColumnarValue::Array(arr))
2220}
2221
2222fn scalar_arr_to_value(arr: &dyn arrow::array::Array) -> DFResult<Value> {
2225 if arr.is_empty() || arr.is_null(0) {
2226 Ok(Value::Null)
2227 } else {
2228 Ok(uni_store::storage::arrow_convert::arrow_to_value(
2230 arr, 0, None,
2231 ))
2232 }
2233}
2234
2235fn resolve_timezone_offset(tz_name: &str, nanos_utc: i64) -> i32 {
2237 if tz_name == "UTC" || tz_name == "Z" {
2238 return 0;
2239 }
2240 if let Ok(tz) = tz_name.parse::<chrono_tz::Tz>() {
2241 let dt = chrono::DateTime::from_timestamp_nanos(nanos_utc).with_timezone(&tz);
2242 dt.offset().fix().local_minus_utc()
2243 } else {
2244 0
2245 }
2246}
2247
2248fn duration_micros_to_value(micros: i64) -> Value {
2250 let dur = crate::datetime::CypherDuration::from_micros(micros);
2251 Value::Temporal(uni_common::TemporalValue::Duration {
2252 months: dur.months,
2253 days: dur.days,
2254 nanos: dur.nanos,
2255 })
2256}
2257
2258fn timestamp_nanos_to_value(nanos: i64, tz: Option<&Arc<str>>) -> DFResult<Value> {
2260 if let Some(tz_str) = tz {
2261 let offset = resolve_timezone_offset(tz_str.as_ref(), nanos);
2262 let tz_name = if tz_str.as_ref() == "UTC" {
2263 None
2264 } else {
2265 Some(tz_str.to_string())
2266 };
2267 Ok(Value::Temporal(uni_common::TemporalValue::DateTime {
2268 nanos_since_epoch: nanos,
2269 offset_seconds: offset,
2270 timezone_name: tz_name,
2271 }))
2272 } else {
2273 Ok(Value::Temporal(uni_common::TemporalValue::LocalDateTime {
2274 nanos_since_epoch: nanos,
2275 }))
2276 }
2277}
2278
2279pub fn scalar_to_value(scalar: &ScalarValue) -> DFResult<Value> {
2281 match scalar {
2282 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
2283 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
2286 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
2287 {
2288 return Ok(Value::from(obj));
2289 }
2290 Ok(Value::String(s.clone()))
2291 }
2292 ScalarValue::LargeBinary(Some(b)) => {
2293 if let Ok(val) = uni_common::cypher_value_codec::decode(b) {
2296 return Ok(val);
2297 }
2298 if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(b) {
2299 Ok(Value::from(obj))
2300 } else {
2301 Ok(Value::Null)
2302 }
2303 }
2304 ScalarValue::Int64(Some(i)) => Ok(Value::Int(*i)),
2305 ScalarValue::Int32(Some(i)) => Ok(Value::Int(*i as i64)),
2306 ScalarValue::Float64(Some(f)) => {
2307 Ok(Value::Float(*f))
2309 }
2310 ScalarValue::Boolean(Some(b)) => Ok(Value::Bool(*b)),
2311 ScalarValue::Struct(arr) => scalar_arr_to_value(arr.as_ref()),
2312 ScalarValue::List(arr) => scalar_arr_to_value(arr.as_ref()),
2313 ScalarValue::LargeList(arr) => scalar_arr_to_value(arr.as_ref()),
2314 ScalarValue::FixedSizeList(arr) => scalar_arr_to_value(arr.as_ref()),
2315 ScalarValue::UInt64(Some(u)) => Ok(Value::Int(*u as i64)),
2317 ScalarValue::UInt32(Some(u)) => Ok(Value::Int(*u as i64)),
2318 ScalarValue::UInt16(Some(u)) => Ok(Value::Int(*u as i64)),
2319 ScalarValue::UInt8(Some(u)) => Ok(Value::Int(*u as i64)),
2320 ScalarValue::Int16(Some(i)) => Ok(Value::Int(*i as i64)),
2321 ScalarValue::Int8(Some(i)) => Ok(Value::Int(*i as i64)),
2322
2323 ScalarValue::Date32(Some(days)) => Ok(Value::Temporal(uni_common::TemporalValue::Date {
2325 days_since_epoch: *days,
2326 })),
2327 ScalarValue::Date64(Some(millis)) => {
2328 let days = (*millis / 86_400_000) as i32;
2329 Ok(Value::Temporal(uni_common::TemporalValue::Date {
2330 days_since_epoch: days,
2331 }))
2332 }
2333 ScalarValue::TimestampNanosecond(Some(nanos), tz) => {
2334 timestamp_nanos_to_value(*nanos, tz.as_ref())
2335 }
2336 ScalarValue::TimestampMicrosecond(Some(micros), tz) => {
2337 timestamp_nanos_to_value(*micros * 1_000, tz.as_ref())
2338 }
2339 ScalarValue::TimestampMillisecond(Some(millis), tz) => {
2340 timestamp_nanos_to_value(*millis * 1_000_000, tz.as_ref())
2341 }
2342 ScalarValue::TimestampSecond(Some(secs), tz) => {
2343 timestamp_nanos_to_value(*secs * 1_000_000_000, tz.as_ref())
2344 }
2345 ScalarValue::Time64Nanosecond(Some(nanos)) => {
2346 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2347 nanos_since_midnight: *nanos,
2348 }))
2349 }
2350 ScalarValue::Time64Microsecond(Some(micros)) => {
2351 Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
2352 nanos_since_midnight: *micros * 1_000,
2353 }))
2354 }
2355 ScalarValue::IntervalMonthDayNano(Some(v)) => {
2356 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2357 months: v.months as i64,
2358 days: v.days as i64,
2359 nanos: v.nanoseconds,
2360 }))
2361 }
2362 ScalarValue::DurationMicrosecond(Some(micros)) => Ok(duration_micros_to_value(*micros)),
2363 ScalarValue::DurationMillisecond(Some(millis)) => {
2364 Ok(duration_micros_to_value(*millis * 1_000))
2365 }
2366 ScalarValue::DurationSecond(Some(secs)) => Ok(duration_micros_to_value(*secs * 1_000_000)),
2367 ScalarValue::DurationNanosecond(Some(nanos)) => {
2368 Ok(Value::Temporal(uni_common::TemporalValue::Duration {
2369 months: 0,
2370 days: 0,
2371 nanos: *nanos,
2372 }))
2373 }
2374 ScalarValue::Float32(Some(f)) => Ok(Value::Float(*f as f64)),
2375
2376 ScalarValue::FixedSizeBinary(24, Some(bytes)) => {
2378 match uni_btic::encode::decode_slice(bytes) {
2379 Ok(btic) => Ok(Value::Temporal(uni_common::TemporalValue::Btic {
2380 lo: btic.lo(),
2381 hi: btic.hi(),
2382 meta: btic.meta(),
2383 })),
2384 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
2385 "BTIC decode error: {e}"
2386 ))),
2387 }
2388 }
2389
2390 ScalarValue::Null
2392 | ScalarValue::Utf8(None)
2393 | ScalarValue::LargeUtf8(None)
2394 | ScalarValue::LargeBinary(None)
2395 | ScalarValue::Int64(None)
2396 | ScalarValue::Int32(None)
2397 | ScalarValue::Int16(None)
2398 | ScalarValue::Int8(None)
2399 | ScalarValue::UInt64(None)
2400 | ScalarValue::UInt32(None)
2401 | ScalarValue::UInt16(None)
2402 | ScalarValue::UInt8(None)
2403 | ScalarValue::Float64(None)
2404 | ScalarValue::Float32(None)
2405 | ScalarValue::Boolean(None)
2406 | ScalarValue::Date32(None)
2407 | ScalarValue::Date64(None)
2408 | ScalarValue::TimestampMicrosecond(None, _)
2409 | ScalarValue::TimestampMillisecond(None, _)
2410 | ScalarValue::TimestampSecond(None, _)
2411 | ScalarValue::TimestampNanosecond(None, _)
2412 | ScalarValue::Time64Microsecond(None)
2413 | ScalarValue::Time64Nanosecond(None)
2414 | ScalarValue::DurationMicrosecond(None)
2415 | ScalarValue::DurationMillisecond(None)
2416 | ScalarValue::DurationSecond(None)
2417 | ScalarValue::DurationNanosecond(None)
2418 | ScalarValue::IntervalMonthDayNano(None)
2419 | ScalarValue::FixedSizeBinary(_, None) => Ok(Value::Null),
2420 other => Err(datafusion::error::DataFusionError::Execution(format!(
2421 "scalar_to_value(): unsupported scalar type {other:?}"
2422 ))),
2423 }
2424}
2425
2426fn value_to_columnar(val: &Value) -> DFResult<ColumnarValue> {
2428 let scalar = match val {
2429 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
2430 Value::Int(i) => ScalarValue::Int64(Some(*i)),
2431 Value::Float(f) => ScalarValue::Float64(Some(*f)),
2432 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
2433 Value::Null => ScalarValue::Utf8(None),
2434 Value::Temporal(tv) => {
2435 use uni_common::TemporalValue;
2436 match tv {
2437 TemporalValue::Date { days_since_epoch } => {
2438 ScalarValue::Date32(Some(*days_since_epoch))
2439 }
2440 TemporalValue::LocalTime {
2441 nanos_since_midnight,
2442 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2443 TemporalValue::Time {
2444 nanos_since_midnight,
2445 ..
2446 } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
2447 TemporalValue::LocalDateTime { nanos_since_epoch } => {
2448 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None)
2449 }
2450 TemporalValue::DateTime {
2451 nanos_since_epoch,
2452 timezone_name,
2453 ..
2454 } => {
2455 let tz = timezone_name.as_deref().unwrap_or("UTC");
2456 ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), Some(tz.into()))
2457 }
2458 TemporalValue::Duration {
2459 months,
2460 days,
2461 nanos,
2462 } => ScalarValue::IntervalMonthDayNano(Some(
2463 arrow::datatypes::IntervalMonthDayNano {
2464 months: *months as i32,
2465 days: *days as i32,
2466 nanoseconds: *nanos,
2467 },
2468 )),
2469 TemporalValue::Btic { lo, hi, meta } => {
2470 let btic = uni_btic::Btic::new(*lo, *hi, *meta).map_err(|e| {
2471 datafusion::error::DataFusionError::Execution(format!("invalid BTIC: {e}"))
2472 })?;
2473 let packed = uni_btic::encode::encode(&btic);
2474 ScalarValue::FixedSizeBinary(24, Some(packed.to_vec()))
2475 }
2476 }
2477 }
2478 other => {
2479 return Err(datafusion::error::DataFusionError::Execution(format!(
2480 "value_to_columnar(): unsupported type {other:?}"
2481 )));
2482 }
2483 };
2484 Ok(ColumnarValue::Scalar(scalar))
2485}
2486
2487pub fn create_has_null_udf() -> ScalarUDF {
2493 ScalarUDF::new_from_impl(HasNullUdf::new())
2494}
2495
2496#[derive(Debug)]
2497struct HasNullUdf {
2498 signature: Signature,
2499}
2500
2501impl HasNullUdf {
2502 fn new() -> Self {
2503 Self {
2504 signature: Signature::any(1, Volatility::Immutable),
2505 }
2506 }
2507}
2508
2509impl_udf_eq_hash!(HasNullUdf);
2510
2511impl ScalarUDFImpl for HasNullUdf {
2512 fn as_any(&self) -> &dyn Any {
2513 self
2514 }
2515
2516 fn name(&self) -> &str {
2517 "_has_null"
2518 }
2519
2520 fn signature(&self) -> &Signature {
2521 &self.signature
2522 }
2523
2524 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2525 Ok(DataType::Boolean)
2526 }
2527
2528 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2529 if args.args.len() != 1 {
2530 return Err(datafusion::error::DataFusionError::Execution(
2531 "_has_null(): requires 1 argument".to_string(),
2532 ));
2533 }
2534
2535 fn check_list_nulls<T: arrow_array::OffsetSizeTrait>(
2537 arr: &arrow_array::GenericListArray<T>,
2538 idx: usize,
2539 ) -> bool {
2540 if arr.is_null(idx) || arr.is_empty() {
2541 false
2542 } else {
2543 arr.value(idx).null_count() > 0
2544 }
2545 }
2546
2547 match &args.args[0] {
2548 ColumnarValue::Scalar(scalar) => {
2549 let has_null = match scalar {
2550 ScalarValue::List(arr) => arr
2551 .as_any()
2552 .downcast_ref::<arrow::array::ListArray>()
2553 .map(|a| !a.is_empty() && a.value(0).null_count() > 0)
2554 .unwrap_or(arr.null_count() > 0),
2555 ScalarValue::LargeList(arr) => arr.len() > 0 && arr.value(0).null_count() > 0,
2556 ScalarValue::FixedSizeList(arr) => {
2557 arr.len() > 0 && arr.value(0).null_count() > 0
2558 }
2559 _ => false,
2560 };
2561 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(has_null))))
2562 }
2563 ColumnarValue::Array(arr) => {
2564 use arrow_array::{LargeListArray, ListArray};
2565
2566 let results: arrow::array::BooleanArray =
2567 if let Some(list_arr) = arr.as_any().downcast_ref::<ListArray>() {
2568 (0..list_arr.len())
2569 .map(|i| {
2570 if list_arr.is_null(i) {
2571 None
2572 } else {
2573 Some(check_list_nulls(list_arr, i))
2574 }
2575 })
2576 .collect()
2577 } else if let Some(large) = arr.as_any().downcast_ref::<LargeListArray>() {
2578 (0..large.len())
2579 .map(|i| {
2580 if large.is_null(i) {
2581 None
2582 } else {
2583 Some(check_list_nulls(large, i))
2584 }
2585 })
2586 .collect()
2587 } else {
2588 return Err(datafusion::error::DataFusionError::Execution(
2589 "_has_null(): requires list array".to_string(),
2590 ));
2591 };
2592 Ok(ColumnarValue::Array(Arc::new(results)))
2593 }
2594 }
2595 }
2596}
2597
2598pub fn create_to_integer_udf() -> ScalarUDF {
2603 ScalarUDF::new_from_impl(ToIntegerUdf::new())
2604}
2605
2606#[derive(Debug)]
2607struct ToIntegerUdf {
2608 signature: Signature,
2609}
2610
2611impl ToIntegerUdf {
2612 fn new() -> Self {
2613 Self {
2614 signature: Signature::any(1, Volatility::Immutable),
2615 }
2616 }
2617}
2618
2619impl_udf_eq_hash!(ToIntegerUdf);
2620
2621impl ScalarUDFImpl for ToIntegerUdf {
2622 fn as_any(&self) -> &dyn Any {
2623 self
2624 }
2625
2626 fn name(&self) -> &str {
2627 "tointeger"
2628 }
2629
2630 fn signature(&self) -> &Signature {
2631 &self.signature
2632 }
2633
2634 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2635 Ok(DataType::Int64)
2636 }
2637
2638 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2639 let output_type = self.return_type(&[])?;
2640 invoke_cypher_udf(args, &output_type, |val_args| {
2641 if val_args.is_empty() {
2642 return Err(datafusion::error::DataFusionError::Execution(
2643 "tointeger(): requires 1 argument".to_string(),
2644 ));
2645 }
2646
2647 let val = &val_args[0];
2648 match val {
2649 Value::Int(i) => Ok(Value::Int(*i)),
2650 Value::Float(f) => Ok(Value::Int(*f as i64)),
2651 Value::String(s) => {
2652 if let Ok(i) = s.parse::<i64>() {
2653 Ok(Value::Int(i))
2654 } else if let Ok(f) = s.parse::<f64>() {
2655 Ok(Value::Int(f as i64))
2656 } else {
2657 Ok(Value::Null)
2658 }
2659 }
2660 Value::Null => Ok(Value::Null),
2661 other => Err(datafusion::error::DataFusionError::Execution(format!(
2662 "InvalidArgumentValue: tointeger(): cannot convert {} to integer",
2663 cypher_type_name(other)
2664 ))),
2665 }
2666 })
2667 }
2668}
2669
2670pub fn create_to_float_udf() -> ScalarUDF {
2675 ScalarUDF::new_from_impl(ToFloatUdf::new())
2676}
2677
2678#[derive(Debug)]
2679struct ToFloatUdf {
2680 signature: Signature,
2681}
2682
2683impl ToFloatUdf {
2684 fn new() -> Self {
2685 Self {
2686 signature: Signature::any(1, Volatility::Immutable),
2687 }
2688 }
2689}
2690
2691impl_udf_eq_hash!(ToFloatUdf);
2692
2693impl ScalarUDFImpl for ToFloatUdf {
2694 fn as_any(&self) -> &dyn Any {
2695 self
2696 }
2697
2698 fn name(&self) -> &str {
2699 "tofloat"
2700 }
2701
2702 fn signature(&self) -> &Signature {
2703 &self.signature
2704 }
2705
2706 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2707 Ok(DataType::Float64)
2708 }
2709
2710 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2711 let output_type = self.return_type(&[])?;
2712 invoke_cypher_udf(args, &output_type, |val_args| {
2713 if val_args.is_empty() {
2714 return Err(datafusion::error::DataFusionError::Execution(
2715 "tofloat(): requires 1 argument".to_string(),
2716 ));
2717 }
2718
2719 let val = &val_args[0];
2720 match val {
2721 Value::Int(i) => Ok(Value::Float(*i as f64)),
2722 Value::Float(f) => Ok(Value::Float(*f)),
2723 Value::String(s) => {
2724 if let Ok(f) = s.parse::<f64>() {
2725 Ok(Value::Float(f))
2726 } else {
2727 Ok(Value::Null)
2728 }
2729 }
2730 Value::Null => Ok(Value::Null),
2731 other => Err(datafusion::error::DataFusionError::Execution(format!(
2732 "InvalidArgumentValue: tofloat(): cannot convert {} to float",
2733 cypher_type_name(other)
2734 ))),
2735 }
2736 })
2737 }
2738}
2739
2740pub fn create_to_boolean_udf() -> ScalarUDF {
2745 ScalarUDF::new_from_impl(ToBooleanUdf::new())
2746}
2747
2748#[derive(Debug)]
2749struct ToBooleanUdf {
2750 signature: Signature,
2751}
2752
2753impl ToBooleanUdf {
2754 fn new() -> Self {
2755 Self {
2756 signature: Signature::any(1, Volatility::Immutable),
2757 }
2758 }
2759}
2760
2761impl_udf_eq_hash!(ToBooleanUdf);
2762
2763impl ScalarUDFImpl for ToBooleanUdf {
2764 fn as_any(&self) -> &dyn Any {
2765 self
2766 }
2767
2768 fn name(&self) -> &str {
2769 "toboolean"
2770 }
2771
2772 fn signature(&self) -> &Signature {
2773 &self.signature
2774 }
2775
2776 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2777 Ok(DataType::Boolean)
2778 }
2779
2780 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2781 let output_type = self.return_type(&[])?;
2782 invoke_cypher_udf(args, &output_type, |val_args| {
2783 if val_args.is_empty() {
2784 return Err(datafusion::error::DataFusionError::Execution(
2785 "toboolean(): requires 1 argument".to_string(),
2786 ));
2787 }
2788
2789 let val = &val_args[0];
2790 match val {
2791 Value::Bool(b) => Ok(Value::Bool(*b)),
2792 Value::String(s) => {
2793 let s_lower = s.to_lowercase();
2794 if s_lower == "true" {
2795 Ok(Value::Bool(true))
2796 } else if s_lower == "false" {
2797 Ok(Value::Bool(false))
2798 } else {
2799 Ok(Value::Null)
2800 }
2801 }
2802 Value::Null => Ok(Value::Null),
2803 Value::Int(i) => Ok(Value::Bool(*i != 0)),
2804 other => Err(datafusion::error::DataFusionError::Execution(format!(
2805 "InvalidArgumentValue: toboolean(): cannot convert {} to boolean",
2806 cypher_type_name(other)
2807 ))),
2808 }
2809 })
2810 }
2811}
2812
2813pub fn create_cypher_sort_key_udf() -> ScalarUDF {
2820 ScalarUDF::new_from_impl(CypherSortKeyUdf::new())
2821}
2822
2823#[derive(Debug)]
2824struct CypherSortKeyUdf {
2825 signature: Signature,
2826}
2827
2828impl CypherSortKeyUdf {
2829 fn new() -> Self {
2830 Self {
2831 signature: Signature::any(1, Volatility::Immutable),
2832 }
2833 }
2834}
2835
2836impl_udf_eq_hash!(CypherSortKeyUdf);
2837
2838impl ScalarUDFImpl for CypherSortKeyUdf {
2839 fn as_any(&self) -> &dyn Any {
2840 self
2841 }
2842
2843 fn name(&self) -> &str {
2844 "_cypher_sort_key"
2845 }
2846
2847 fn signature(&self) -> &Signature {
2848 &self.signature
2849 }
2850
2851 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2852 Ok(DataType::LargeBinary)
2853 }
2854
2855 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2856 if args.args.len() != 1 {
2857 return Err(datafusion::error::DataFusionError::Execution(
2858 "_cypher_sort_key(): requires 1 argument".to_string(),
2859 ));
2860 }
2861
2862 let arg = &args.args[0];
2863 match arg {
2864 ColumnarValue::Scalar(s) => {
2865 let val = if s.is_null() {
2866 Value::Null
2867 } else {
2868 scalar_to_value(s)?
2869 };
2870 let key = encode_cypher_sort_key(&val);
2871 Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(key))))
2872 }
2873 ColumnarValue::Array(arr) => {
2874 let field = args.arg_fields.first().map(|f| f.as_ref());
2875 let mut keys: Vec<Option<Vec<u8>>> = Vec::with_capacity(arr.len());
2876 for i in 0..arr.len() {
2877 let val = if arr.is_null(i) {
2878 Value::Null
2879 } else {
2880 get_value_from_array(arr, i, field)?
2881 };
2882 keys.push(Some(encode_cypher_sort_key(&val)));
2883 }
2884 let array = LargeBinaryArray::from(
2885 keys.iter()
2886 .map(|k| k.as_deref())
2887 .collect::<Vec<Option<&[u8]>>>(),
2888 );
2889 Ok(ColumnarValue::Array(Arc::new(array)))
2890 }
2891 }
2892 }
2893}
2894
2895pub fn encode_cypher_sort_key(value: &Value) -> Vec<u8> {
2901 let mut buf = Vec::with_capacity(32);
2902 encode_sort_key_to_buf(value, &mut buf);
2903 buf
2904}
2905
2906fn encode_sort_key_to_buf(value: &Value, buf: &mut Vec<u8>) {
2908 if let Value::Map(map) = value {
2910 if let Some(tv) = sort_key_map_as_temporal(map) {
2911 buf.push(0x07); encode_temporal_payload(&tv, buf);
2913 return;
2914 }
2915 let rank = sort_key_map_rank(map);
2916 if rank != 0 {
2917 buf.push(rank);
2919 match rank {
2920 0x01 => encode_map_as_node_payload(map, buf),
2921 0x02 => encode_map_as_edge_payload(map, buf),
2922 0x04 => encode_map_as_path_payload(map, buf),
2923 _ => {} }
2925 return;
2926 }
2927 }
2928
2929 if let Value::String(s) = value {
2931 if let Some(tv) = sort_key_string_as_temporal(s) {
2932 buf.push(0x07); encode_temporal_payload(&tv, buf);
2934 return;
2935 }
2936 if let Some(temporal_type) = crate::datetime::classify_temporal(s) {
2939 buf.push(0x07); if encode_wide_temporal_sort_key(s, temporal_type, buf) {
2941 return;
2942 }
2943 buf.pop();
2945 }
2946 }
2947
2948 let rank = sort_key_type_rank(value);
2949 buf.push(rank);
2950
2951 match value {
2952 Value::Null => {} Value::Float(f) if f.is_nan() => {} Value::Bool(b) => buf.push(if *b { 0x01 } else { 0x00 }),
2955 Value::Int(i) => {
2956 let f = *i as f64;
2957 buf.extend_from_slice(&encode_order_preserving_f64(f));
2958 }
2959 Value::Float(f) => {
2960 buf.extend_from_slice(&encode_order_preserving_f64(*f));
2961 }
2962 Value::String(s) => {
2963 byte_stuff_terminate(s.as_bytes(), buf);
2964 }
2965 Value::Temporal(tv) => {
2966 encode_temporal_payload(tv, buf);
2967 }
2968 Value::List(items) => {
2969 encode_list_payload(items, buf);
2970 }
2971 Value::Map(map) => {
2972 encode_map_payload(map, buf);
2973 }
2974 Value::Node(node) => {
2975 encode_node_payload(node, buf);
2976 }
2977 Value::Edge(edge) => {
2978 encode_edge_payload(edge, buf);
2979 }
2980 Value::Path(path) => {
2981 encode_path_payload(path, buf);
2982 }
2983 Value::Bytes(b) => {
2985 byte_stuff_terminate(b, buf);
2986 }
2987 Value::Vector(v) => {
2988 for f in v {
2989 buf.extend_from_slice(&encode_order_preserving_f64(*f as f64));
2990 }
2991 }
2992 _ => {} }
2994}
2995
2996fn sort_key_type_rank(v: &Value) -> u8 {
3000 match v {
3001 Value::Map(map) => sort_key_map_rank(map),
3002 Value::Node(_) => 0x01,
3003 Value::Edge(_) => 0x02,
3004 Value::List(_) => 0x03,
3005 Value::Path(_) => 0x04,
3006 Value::String(_) => 0x05,
3007 Value::Bool(_) => 0x06,
3008 Value::Temporal(_) => 0x07,
3009 Value::Int(_) => 0x08,
3010 Value::Float(f) if f.is_nan() => 0x09,
3011 Value::Float(_) => 0x08,
3012 Value::Null => 0x0A,
3013 Value::Bytes(_) | Value::Vector(_) => 0x0B,
3014 _ => 0x0B, }
3016}
3017
3018fn sort_key_map_rank(map: &std::collections::HashMap<String, Value>) -> u8 {
3020 if sort_key_map_as_temporal(map).is_some() {
3021 0x07
3022 } else if map.contains_key("nodes")
3023 && (map.contains_key("relationships") || map.contains_key("edges"))
3024 {
3025 0x04 } else if map.contains_key("_eid")
3027 || map.contains_key("_src")
3028 || map.contains_key("_dst")
3029 || map.contains_key("_type")
3030 || map.contains_key("_type_name")
3031 {
3032 0x02 } else if map.contains_key("_vid") || map.contains_key("_labels") || map.contains_key("_label")
3034 {
3035 0x01 } else {
3037 0x00 }
3039}
3040
3041fn sort_key_map_as_temporal(
3045 map: &std::collections::HashMap<String, Value>,
3046) -> Option<uni_common::TemporalValue> {
3047 super::expr_eval::temporal_from_map_wrapper(map)
3048}
3049
3050fn sort_key_string_as_temporal(s: &str) -> Option<uni_common::TemporalValue> {
3054 super::expr_eval::temporal_from_value(&Value::String(s.to_string()))
3055}
3056
3057fn encode_wide_temporal_sort_key(
3064 s: &str,
3065 temporal_type: uni_common::TemporalType,
3066 buf: &mut Vec<u8>,
3067) -> bool {
3068 match temporal_type {
3069 uni_common::TemporalType::LocalDateTime => {
3070 if let Some(ndt) = parse_naive_datetime(s) {
3071 buf.push(0x03); let wide_nanos = naive_datetime_to_wide_nanos(&ndt);
3073 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3074 return true;
3075 }
3076 false
3077 }
3078 uni_common::TemporalType::DateTime => {
3079 let base = if let Some(bracket_pos) = s.find('[') {
3081 &s[..bracket_pos]
3082 } else {
3083 s
3084 };
3085 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%.f%:z") {
3086 buf.push(0x04); let utc = dt.naive_utc();
3088 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
3089 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3090 return true;
3091 }
3092 if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%:z") {
3093 buf.push(0x04); let utc = dt.naive_utc();
3095 let wide_nanos = naive_datetime_to_wide_nanos(&utc);
3096 buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
3097 return true;
3098 }
3099 false
3100 }
3101 uni_common::TemporalType::Date => {
3102 if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
3103 && let Some(epoch) = chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
3104 {
3105 buf.push(0x00); let days = nd.signed_duration_since(epoch).num_days() as i32;
3107 buf.extend_from_slice(&encode_order_preserving_i32(days));
3108 return true;
3109 }
3110 false
3111 }
3112 _ => false,
3113 }
3114}
3115
3116fn parse_naive_datetime(s: &str) -> Option<chrono::NaiveDateTime> {
3118 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
3119 .ok()
3120 .or_else(|| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").ok())
3121}
3122
3123fn naive_datetime_to_wide_nanos(ndt: &chrono::NaiveDateTime) -> i128 {
3126 let secs = ndt.and_utc().timestamp() as i128;
3127 let subsec_nanos = ndt.and_utc().timestamp_subsec_nanos() as i128;
3128 secs * 1_000_000_000 + subsec_nanos
3129}
3130
3131fn encode_map_as_node_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3133 let mut labels: Vec<String> = Vec::new();
3135 if let Some(Value::List(lbls)) = map.get("_labels") {
3136 for l in lbls {
3137 if let Value::String(s) = l {
3138 labels.push(s.clone());
3139 }
3140 }
3141 } else if let Some(Value::String(lbl)) = map.get("_label") {
3142 labels.push(lbl.clone());
3143 }
3144 labels.sort();
3145
3146 let vid = map.get("_vid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3148
3149 let labels_joined = labels.join("\x01");
3151 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3152
3153 buf.extend_from_slice(&vid.to_be_bytes());
3155
3156 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3158 for (k, v) in map {
3159 if !k.starts_with('_') {
3160 props.insert(k.clone(), v.clone());
3161 }
3162 }
3163 encode_map_payload(&props, buf);
3164}
3165
3166fn encode_map_as_edge_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3168 let edge_type = map
3169 .get("_type")
3170 .or_else(|| map.get("_type_name"))
3171 .and_then(|v| {
3172 if let Value::String(s) = v {
3173 Some(s.as_str())
3174 } else {
3175 None
3176 }
3177 })
3178 .unwrap_or("");
3179
3180 byte_stuff_terminate(edge_type.as_bytes(), buf);
3181
3182 let src = map.get("_src").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3183 let dst = map.get("_dst").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3184 let eid = map.get("_eid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3185
3186 buf.extend_from_slice(&src.to_be_bytes());
3187 buf.extend_from_slice(&dst.to_be_bytes());
3188 buf.extend_from_slice(&eid.to_be_bytes());
3189
3190 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3192 for (k, v) in map {
3193 if !k.starts_with('_') {
3194 props.insert(k.clone(), v.clone());
3195 }
3196 }
3197 encode_map_payload(&props, buf);
3198}
3199
3200fn encode_map_as_path_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3202 if let Some(Value::List(nodes)) = map.get("nodes") {
3204 encode_list_payload(nodes, buf);
3205 } else {
3206 buf.push(0x00); }
3208 let edges = map.get("relationships").or_else(|| map.get("edges"));
3210 if let Some(Value::List(edges)) = edges {
3211 encode_list_payload(edges, buf);
3212 } else {
3213 buf.push(0x00); }
3215}
3216
/// Big-endian bytes of `f` transformed so that unsigned byte comparison matches
/// numeric order (with -0.0 ordered just below +0.0).
fn encode_order_preserving_f64(f: f64) -> [u8; 8] {
    // Negative values: invert every bit (reverses their order and puts them below
    // positives). Non-negative values: flip only the sign bit.
    let mask = if f.is_sign_negative() {
        u64::MAX
    } else {
        1u64 << 63
    };
    (f.to_bits() ^ mask).to_be_bytes()
}

/// Big-endian bytes of `i` with the sign bit flipped, so byte order equals numeric order.
fn encode_order_preserving_i64(i: i64) -> [u8; 8] {
    (i ^ i64::MIN).to_be_bytes()
}

/// 32-bit variant of the sign-flipped big-endian encoding.
fn encode_order_preserving_i32(i: i32) -> [u8; 4] {
    (i ^ i32::MIN).to_be_bytes()
}

/// 128-bit variant of the sign-flipped big-endian encoding.
fn encode_order_preserving_i128(i: i128) -> [u8; 16] {
    (i ^ i128::MIN).to_be_bytes()
}
3250
/// Appends `data` with every 0x00 escaped as 0x00 0xFF, followed by the 0x00 0x00
/// terminator. A value that is a prefix of another therefore sorts first, and embedded
/// zero bytes cannot be mistaken for the terminator.
fn byte_stuff_terminate(data: &[u8], buf: &mut Vec<u8>) {
    byte_stuff(data, buf);
    buf.extend_from_slice(&[0x00, 0x00]);
}

/// Appends `data`, following every 0x00 byte with an 0xFF escape byte.
fn byte_stuff(data: &[u8], buf: &mut Vec<u8>) {
    // Each chunk ends at (and includes) a zero byte, except possibly the last one.
    for chunk in data.split_inclusive(|&byte| byte == 0x00) {
        buf.extend_from_slice(chunk);
        if chunk.last() == Some(&0x00) {
            buf.push(0xFF);
        }
    }
}
3270
3271fn encode_list_payload(items: &[Value], buf: &mut Vec<u8>) {
3276 for item in items {
3277 buf.push(0x01); let elem_key = encode_cypher_sort_key(item);
3279 byte_stuff_terminate(&elem_key, buf);
3280 }
3281 buf.push(0x00); }
3283
3284fn encode_map_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3286 let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
3287 pairs.sort_by_key(|(k, _)| *k);
3288
3289 for (key, value) in pairs {
3290 buf.push(0x01); byte_stuff_terminate(key.as_bytes(), buf);
3292 let val_key = encode_cypher_sort_key(value);
3293 byte_stuff_terminate(&val_key, buf);
3294 }
3295 buf.push(0x00); }
3297
3298fn encode_node_payload(node: &uni_common::Node, buf: &mut Vec<u8>) {
3302 let mut labels = node.labels.clone();
3303 labels.sort();
3304 let labels_joined = labels.join("\x01");
3305 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3306
3307 buf.extend_from_slice(&node.vid.as_u64().to_be_bytes());
3308
3309 encode_map_payload(&node.properties, buf);
3310}
3311
3312fn encode_edge_payload(edge: &uni_common::Edge, buf: &mut Vec<u8>) {
3316 byte_stuff_terminate(edge.edge_type.as_bytes(), buf);
3317
3318 buf.extend_from_slice(&edge.src.as_u64().to_be_bytes());
3319 buf.extend_from_slice(&edge.dst.as_u64().to_be_bytes());
3320 buf.extend_from_slice(&edge.eid.as_u64().to_be_bytes());
3321
3322 encode_map_payload(&edge.properties, buf);
3323}
3324
3325fn encode_path_payload(path: &uni_common::Path, buf: &mut Vec<u8>) {
3329 for node in &path.nodes {
3331 buf.push(0x01); let mut node_key = Vec::new();
3333 node_key.push(0x01); encode_node_payload(node, &mut node_key);
3335 byte_stuff_terminate(&node_key, buf);
3336 }
3337 buf.push(0x00); for edge in &path.edges {
3341 buf.push(0x01); let mut edge_key = Vec::new();
3343 edge_key.push(0x02); encode_edge_payload(edge, &mut edge_key);
3345 byte_stuff_terminate(&edge_key, buf);
3346 }
3347 buf.push(0x00); }
3349
3350fn encode_temporal_payload(tv: &uni_common::TemporalValue, buf: &mut Vec<u8>) {
3352 match tv {
3353 uni_common::TemporalValue::Date { days_since_epoch } => {
3354 buf.push(0x00); buf.extend_from_slice(&encode_order_preserving_i32(*days_since_epoch));
3356 }
3357 uni_common::TemporalValue::LocalTime {
3358 nanos_since_midnight,
3359 } => {
3360 buf.push(0x01); buf.extend_from_slice(&encode_order_preserving_i64(*nanos_since_midnight));
3362 }
3363 uni_common::TemporalValue::Time {
3364 nanos_since_midnight,
3365 offset_seconds,
3366 } => {
3367 buf.push(0x02); let utc_nanos =
3369 *nanos_since_midnight as i128 - (*offset_seconds as i128) * 1_000_000_000;
3370 buf.extend_from_slice(&encode_order_preserving_i128(utc_nanos));
3371 }
3372 uni_common::TemporalValue::LocalDateTime { nanos_since_epoch } => {
3373 buf.push(0x03); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3376 }
3377 uni_common::TemporalValue::DateTime {
3378 nanos_since_epoch, ..
3379 } => {
3380 buf.push(0x04); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
3383 }
3384 uni_common::TemporalValue::Duration {
3385 months,
3386 days,
3387 nanos,
3388 } => {
3389 buf.push(0x05); buf.extend_from_slice(&encode_order_preserving_i64(*months));
3391 buf.extend_from_slice(&encode_order_preserving_i64(*days));
3392 buf.extend_from_slice(&encode_order_preserving_i64(*nanos));
3393 }
3394 uni_common::TemporalValue::Btic { lo, hi, meta } => {
3395 buf.push(0x06); if let Ok(btic) = uni_btic::Btic::new(*lo, *hi, *meta) {
3398 buf.extend_from_slice(&uni_btic::encode::encode(&btic));
3399 } else {
3400 buf.extend_from_slice(&encode_order_preserving_i64(*lo));
3401 buf.extend_from_slice(&encode_order_preserving_i64(*hi));
3402 }
3403 }
3404 }
3405}
3406
3407#[derive(Debug)]
3416struct BticScalarUdf {
3417 name: String,
3418 signature: Signature,
3419 return_type: DataType,
3420}
3421
3422impl BticScalarUdf {
3423 fn new(name: &str, num_args: usize, return_type: DataType) -> Self {
3424 Self {
3425 name: name.to_string(),
3426 signature: Signature::new(TypeSignature::Any(num_args), Volatility::Immutable),
3427 return_type,
3428 }
3429 }
3430}
3431
3432impl_udf_eq_hash!(BticScalarUdf);
3433
3434impl ScalarUDFImpl for BticScalarUdf {
3435 fn as_any(&self) -> &dyn Any {
3436 self
3437 }
3438 fn name(&self) -> &str {
3439 &self.name
3440 }
3441 fn signature(&self) -> &Signature {
3442 &self.signature
3443 }
3444 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
3445 Ok(self.return_type.clone())
3446 }
3447 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
3448 let fname = self.name.to_uppercase();
3449 let rt = self.return_type.clone();
3450 invoke_cypher_udf(args, &rt, |val_args| {
3451 crate::expr_eval::eval_btic_function(&fname, val_args)
3452 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
3453 })
3454 }
3455}
3456
3457fn register_btic_scalar_udfs(ctx: &SessionContext) -> DFResult<()> {
3459 for name in &["btic_lo", "btic_hi"] {
3461 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3462 name,
3463 1,
3464 DataType::LargeBinary,
3465 )));
3466 }
3467 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3469 "btic_duration",
3470 1,
3471 DataType::Int64,
3472 )));
3473 for name in &["btic_is_instant", "btic_is_unbounded", "btic_is_finite"] {
3475 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3476 name,
3477 1,
3478 DataType::Boolean,
3479 )));
3480 }
3481 for name in &[
3483 "btic_granularity",
3484 "btic_lo_granularity",
3485 "btic_hi_granularity",
3486 "btic_certainty",
3487 "btic_lo_certainty",
3488 "btic_hi_certainty",
3489 ] {
3490 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3491 name,
3492 1,
3493 DataType::Utf8,
3494 )));
3495 }
3496 for name in &[
3498 "btic_contains_point",
3499 "btic_overlaps",
3500 "btic_contains",
3501 "btic_before",
3502 "btic_after",
3503 "btic_meets",
3504 "btic_adjacent",
3505 "btic_disjoint",
3506 "btic_equals",
3507 "btic_starts",
3508 "btic_during",
3509 "btic_finishes",
3510 ] {
3511 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3512 name,
3513 2,
3514 DataType::Boolean,
3515 )));
3516 }
3517 for name in &["btic_intersection", "btic_span", "btic_gap"] {
3519 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3520 name,
3521 2,
3522 DataType::LargeBinary,
3523 )));
3524 }
3525 Ok(())
3526}
3527
3528#[derive(Debug, Clone)]
3534struct BticMinMaxUdaf {
3535 name: String,
3536 signature: Signature,
3537 is_max: bool,
3538}
3539
3540impl BticMinMaxUdaf {
3541 fn new(is_max: bool) -> Self {
3542 Self {
3543 name: (if is_max { "btic_max" } else { "btic_min" }).to_string(),
3544 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
3545 is_max,
3546 }
3547 }
3548}
3549
3550impl_udf_eq_hash!(BticMinMaxUdaf);
3551
3552impl AggregateUDFImpl for BticMinMaxUdaf {
3553 fn as_any(&self) -> &dyn Any {
3554 self
3555 }
3556 fn name(&self) -> &str {
3557 &self.name
3558 }
3559 fn signature(&self) -> &Signature {
3560 &self.signature
3561 }
3562 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3563 Ok(DataType::LargeBinary)
3564 }
3565 fn accumulator(
3566 &self,
3567 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3568 ) -> DFResult<Box<dyn DfAccumulator>> {
3569 Ok(Box::new(BticMinMaxAccumulator {
3570 current: None,
3571 is_max: self.is_max,
3572 }))
3573 }
3574 fn state_fields(
3575 &self,
3576 args: datafusion::logical_expr::function::StateFieldsArgs,
3577 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3578 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3579 args.name,
3580 DataType::LargeBinary,
3581 true,
3582 ))])
3583 }
3584}
3585
3586#[derive(Debug)]
3587struct BticMinMaxAccumulator {
3588 current: Option<uni_btic::Btic>,
3589 is_max: bool,
3590}
3591
3592impl DfAccumulator for BticMinMaxAccumulator {
3593 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3594 let arr = &values[0];
3595 for i in 0..arr.len() {
3596 if arr.is_null(i) {
3597 continue;
3598 }
3599 let Some(btic) = decode_btic_from_array(arr, i)? else {
3600 continue;
3601 };
3602 self.current = Some(match self.current.take() {
3603 None => btic,
3604 Some(cur) => {
3605 if (self.is_max && btic > cur) || (!self.is_max && btic < cur) {
3606 btic
3607 } else {
3608 cur
3609 }
3610 }
3611 });
3612 }
3613 Ok(())
3614 }
3615 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3616 Ok(btic_to_scalar_value(self.current.as_ref()))
3617 }
3618 fn size(&self) -> usize {
3619 std::mem::size_of::<Self>()
3620 }
3621 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3622 Ok(vec![self.evaluate()?])
3623 }
3624 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3625 self.update_batch(states)
3626 }
3627}
3628
3629#[derive(Debug, Clone)]
3631struct BticSpanAggUdaf {
3632 signature: Signature,
3633}
3634
3635impl BticSpanAggUdaf {
3636 fn new() -> Self {
3637 Self {
3638 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
3639 }
3640 }
3641}
3642
3643impl_udf_eq_hash!(BticSpanAggUdaf);
3644
3645impl AggregateUDFImpl for BticSpanAggUdaf {
3646 fn as_any(&self) -> &dyn Any {
3647 self
3648 }
3649 fn name(&self) -> &str {
3650 "btic_span_agg"
3651 }
3652 fn signature(&self) -> &Signature {
3653 &self.signature
3654 }
3655 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3656 Ok(DataType::LargeBinary)
3657 }
3658 fn accumulator(
3659 &self,
3660 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3661 ) -> DFResult<Box<dyn DfAccumulator>> {
3662 Ok(Box::new(BticSpanAggAccumulator { current: None }))
3663 }
3664 fn state_fields(
3665 &self,
3666 args: datafusion::logical_expr::function::StateFieldsArgs,
3667 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3668 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3669 args.name,
3670 DataType::LargeBinary,
3671 true,
3672 ))])
3673 }
3674}
3675
3676#[derive(Debug)]
3677struct BticSpanAggAccumulator {
3678 current: Option<uni_btic::Btic>,
3679}
3680
3681impl DfAccumulator for BticSpanAggAccumulator {
3682 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3683 let arr = &values[0];
3684 for i in 0..arr.len() {
3685 if arr.is_null(i) {
3686 continue;
3687 }
3688 let Some(btic) = decode_btic_from_array(arr, i)? else {
3689 continue;
3690 };
3691 self.current = Some(match self.current.take() {
3692 None => btic,
3693 Some(cur) => uni_btic::set_ops::span(&cur, &btic),
3694 });
3695 }
3696 Ok(())
3697 }
3698 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3699 Ok(btic_to_scalar_value(self.current.as_ref()))
3700 }
3701 fn size(&self) -> usize {
3702 std::mem::size_of::<Self>()
3703 }
3704 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3705 Ok(vec![self.evaluate()?])
3706 }
3707 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3708 self.update_batch(states)
3709 }
3710}
3711
3712#[derive(Debug, Clone)]
3714struct BticCountAtUdaf {
3715 signature: Signature,
3716}
3717
3718impl BticCountAtUdaf {
3719 fn new() -> Self {
3720 Self {
3721 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
3722 }
3723 }
3724}
3725
3726impl_udf_eq_hash!(BticCountAtUdaf);
3727
3728impl AggregateUDFImpl for BticCountAtUdaf {
3729 fn as_any(&self) -> &dyn Any {
3730 self
3731 }
3732 fn name(&self) -> &str {
3733 "btic_count_at"
3734 }
3735 fn signature(&self) -> &Signature {
3736 &self.signature
3737 }
3738 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3739 Ok(DataType::Int64)
3740 }
3741 fn accumulator(
3742 &self,
3743 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3744 ) -> DFResult<Box<dyn DfAccumulator>> {
3745 Ok(Box::new(BticCountAtAccumulator { count: 0 }))
3746 }
3747 fn state_fields(
3748 &self,
3749 args: datafusion::logical_expr::function::StateFieldsArgs,
3750 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3751 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3752 args.name,
3753 DataType::Int64,
3754 true,
3755 ))])
3756 }
3757}
3758
3759#[derive(Debug)]
3760struct BticCountAtAccumulator {
3761 count: i64,
3762}
3763
3764impl DfAccumulator for BticCountAtAccumulator {
3765 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3766 if values.len() < 2 {
3767 return Ok(());
3768 }
3769 let btic_arr = &values[0];
3770 let point_arr = &values[1];
3771
3772 for i in 0..btic_arr.len() {
3773 if btic_arr.is_null(i) || point_arr.is_null(i) {
3774 continue;
3775 }
3776 let Some(btic) = decode_btic_from_array(btic_arr, i)? else {
3777 continue;
3778 };
3779
3780 let point_ms = if let Some(int_arr) = point_arr.as_any().downcast_ref::<Int64Array>() {
3782 int_arr.value(i)
3783 } else if let Some(lb) = point_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3784 let val = scalar_binary_to_value(lb.value(i));
3785 match &val {
3786 Value::Int(ms) => *ms,
3787 Value::Temporal(uni_common::TemporalValue::DateTime {
3788 nanos_since_epoch,
3789 ..
3790 }) => nanos_since_epoch / 1_000_000,
3791 _ => continue,
3792 }
3793 } else {
3794 continue;
3795 };
3796
3797 if uni_btic::predicates::contains_point(&btic, point_ms) {
3798 self.count += 1;
3799 }
3800 }
3801 Ok(())
3802 }
3803 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3804 Ok(ScalarValue::Int64(Some(self.count)))
3805 }
3806 fn size(&self) -> usize {
3807 std::mem::size_of::<Self>()
3808 }
3809 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3810 Ok(vec![ScalarValue::Int64(Some(self.count))])
3811 }
3812 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3813 let arr = &states[0];
3814 if let Some(int_arr) = arr.as_any().downcast_ref::<Int64Array>() {
3815 for i in 0..int_arr.len() {
3816 if !int_arr.is_null(i) {
3817 self.count += int_arr.value(i);
3818 }
3819 }
3820 }
3821 Ok(())
3822 }
3823}
3824
3825fn btic_to_scalar_value(btic: Option<&uni_btic::Btic>) -> ScalarValue {
3827 match btic {
3828 None => ScalarValue::LargeBinary(None),
3829 Some(b) => {
3830 let val = Value::Temporal(uni_common::TemporalValue::Btic {
3831 lo: b.lo(),
3832 hi: b.hi(),
3833 meta: b.meta(),
3834 });
3835 ScalarValue::LargeBinary(Some(uni_common::cypher_value_codec::encode(&val)))
3836 }
3837 }
3838}
3839
3840fn decode_btic_from_array(arr: &ArrayRef, row: usize) -> DFResult<Option<uni_btic::Btic>> {
3842 if let Some(fsb) = arr.as_any().downcast_ref::<FixedSizeBinaryArray>() {
3844 let bytes = fsb.value(row);
3845 return uni_btic::encode::decode_slice(bytes)
3846 .map(Some)
3847 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3848 }
3849 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
3851 let val = scalar_binary_to_value(lb.value(row));
3852 if let Value::Temporal(uni_common::TemporalValue::Btic { lo, hi, meta }) = val {
3853 return uni_btic::Btic::new(lo, hi, meta)
3854 .map(Some)
3855 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3856 }
3857 return Ok(None);
3858 }
3859 Ok(None)
3860}
3861
3862pub fn create_btic_min_udaf() -> AggregateUDF {
3863 AggregateUDF::from(BticMinMaxUdaf::new(false))
3864}
3865
3866pub fn create_btic_max_udaf() -> AggregateUDF {
3867 AggregateUDF::from(BticMinMaxUdaf::new(true))
3868}
3869
3870pub fn create_btic_span_agg_udaf() -> AggregateUDF {
3871 AggregateUDF::from(BticSpanAggUdaf::new())
3872}
3873
3874pub fn create_btic_count_at_udaf() -> AggregateUDF {
3875 AggregateUDF::from(BticCountAtUdaf::new())
3876}
3877
3878pub fn invoke_cypher_string_op<F>(
3883 args: &ScalarFunctionArgs,
3884 name: &str,
3885 op: F,
3886) -> DFResult<ColumnarValue>
3887where
3888 F: Fn(&str, &str) -> bool,
3889{
3890 use arrow_array::{BooleanArray, LargeBinaryArray, LargeStringArray, StringArray};
3891 use datafusion::common::ScalarValue;
3892 use datafusion::error::DataFusionError;
3893
3894 if args.args.len() != 2 {
3895 return Err(DataFusionError::Execution(format!(
3896 "{}(): requires exactly 2 arguments",
3897 name
3898 )));
3899 }
3900
3901 let left = &args.args[0];
3902 let right = &args.args[1];
3903
3904 let extract_string = |scalar: &ScalarValue| -> Option<String> {
3906 match scalar {
3907 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Some(s.clone()),
3908 ScalarValue::LargeBinary(Some(bytes)) => {
3909 match uni_common::cypher_value_codec::decode(bytes) {
3911 Ok(uni_common::Value::String(s)) => Some(s),
3912 _ => None,
3913 }
3914 }
3915 ScalarValue::Utf8(None)
3916 | ScalarValue::LargeUtf8(None)
3917 | ScalarValue::LargeBinary(None)
3918 | ScalarValue::Null => None,
3919 _ => None,
3920 }
3921 };
3922
3923 match (left, right) {
3924 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Scalar(r_scalar)) => {
3925 let l_str = extract_string(l_scalar);
3926 let r_str = extract_string(r_scalar);
3927
3928 match (l_str, r_str) {
3929 (Some(l), Some(r)) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(op(
3930 &l, &r,
3931 ))))),
3932 _ => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
3933 }
3934 }
3935 (ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_scalar)) => {
3936 let r_val = extract_string(r_scalar);
3938
3939 if r_val.is_none() {
3940 let nulls = arrow_array::new_null_array(&DataType::Boolean, l_arr.len());
3942 return Ok(ColumnarValue::Array(nulls));
3943 }
3944 let pattern = r_val.unwrap();
3945
3946 let result_array = if let Some(arr) = l_arr.as_any().downcast_ref::<StringArray>() {
3948 arr.iter()
3949 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3950 .collect::<BooleanArray>()
3951 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeStringArray>() {
3952 arr.iter()
3953 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3954 .collect::<BooleanArray>()
3955 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3956 arr.iter()
3958 .map(|opt_bytes| {
3959 opt_bytes.and_then(|bytes| {
3960 match uni_common::cypher_value_codec::decode(bytes) {
3961 Ok(uni_common::Value::String(s)) => Some(op(&s, &pattern)),
3962 _ => None,
3963 }
3964 })
3965 })
3966 .collect::<BooleanArray>()
3967 } else {
3968 arrow_array::new_null_array(&DataType::Boolean, l_arr.len())
3970 .as_any()
3971 .downcast_ref::<BooleanArray>()
3972 .unwrap()
3973 .clone()
3974 };
3975
3976 Ok(ColumnarValue::Array(Arc::new(result_array)))
3977 }
3978 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Array(r_arr)) => {
3979 let l_val = extract_string(l_scalar);
3981
3982 if l_val.is_none() {
3983 let nulls = arrow_array::new_null_array(&DataType::Boolean, r_arr.len());
3984 return Ok(ColumnarValue::Array(nulls));
3985 }
3986 let target = l_val.unwrap();
3987
3988 let result_array = if let Some(arr) = r_arr.as_any().downcast_ref::<StringArray>() {
3989 arr.iter()
3990 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3991 .collect::<BooleanArray>()
3992 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeStringArray>() {
3993 arr.iter()
3994 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3995 .collect::<BooleanArray>()
3996 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3997 arr.iter()
3999 .map(|opt_bytes| {
4000 opt_bytes.and_then(|bytes| {
4001 match uni_common::cypher_value_codec::decode(bytes) {
4002 Ok(uni_common::Value::String(s)) => Some(op(&target, &s)),
4003 _ => None,
4004 }
4005 })
4006 })
4007 .collect::<BooleanArray>()
4008 } else {
4009 arrow_array::new_null_array(&DataType::Boolean, r_arr.len())
4011 .as_any()
4012 .downcast_ref::<BooleanArray>()
4013 .unwrap()
4014 .clone()
4015 };
4016
4017 Ok(ColumnarValue::Array(Arc::new(result_array)))
4018 }
4019 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
4020 if l_arr.len() != r_arr.len() {
4022 return Err(DataFusionError::Execution(format!(
4023 "{}(): array lengths must match",
4024 name
4025 )));
4026 }
4027
4028 let extract_string_at = |arr: &dyn Array, idx: usize| -> Option<String> {
4030 if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
4031 str_arr.value(idx).to_string().into()
4032 } else if let Some(str_arr) = arr.as_any().downcast_ref::<LargeStringArray>() {
4033 str_arr.value(idx).to_string().into()
4034 } else if let Some(bin_arr) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
4035 if bin_arr.is_null(idx) {
4036 return None;
4037 }
4038 let bytes = bin_arr.value(idx);
4039 match uni_common::cypher_value_codec::decode(bytes) {
4040 Ok(uni_common::Value::String(s)) => Some(s),
4041 _ => None,
4042 }
4043 } else {
4044 None
4045 }
4046 };
4047
4048 let result: BooleanArray = (0..l_arr.len())
4049 .map(|idx| {
4050 match (
4051 extract_string_at(l_arr.as_ref(), idx),
4052 extract_string_at(r_arr.as_ref(), idx),
4053 ) {
4054 (Some(l_str), Some(r_str)) => Some(op(&l_str, &r_str)),
4055 _ => None,
4056 }
4057 })
4058 .collect();
4059
4060 Ok(ColumnarValue::Array(Arc::new(result)))
4061 }
4062 }
4063}
4064
/// Generates a two-argument Cypher string-predicate UDF.
///
/// `$udf` names the generated struct, `$fn_name` is the registered function name,
/// and `$predicate` is the `(left, right) -> bool` closure forwarded to
/// `invoke_cypher_string_op`. The generated UDF always reports a `Boolean`
/// return type.
macro_rules! define_string_op_udf {
    ($udf:ident, $fn_name:literal, $predicate:expr) => {
        #[derive(Debug)]
        struct $udf {
            signature: Signature,
        }

        impl $udf {
            fn new() -> Self {
                let signature = Signature::any(2, Volatility::Immutable);
                Self { signature }
            }
        }

        impl_udf_eq_hash!($udf);

        impl ScalarUDFImpl for $udf {
            fn as_any(&self) -> &dyn Any {
                self
            }

            fn name(&self) -> &str {
                $fn_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Boolean)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                // All argument handling (arity, null propagation, decoding) lives in
                // the shared helper; only the predicate differs per UDF.
                invoke_cypher_string_op(&args, $fn_name, $predicate)
            }
        }
    };
}
4103
4104define_string_op_udf!(CypherStartsWithUdf, "_cypher_starts_with", |s, p| s
4105 .starts_with(p));
4106define_string_op_udf!(CypherEndsWithUdf, "_cypher_ends_with", |s, p| s
4107 .ends_with(p));
4108define_string_op_udf!(CypherContainsUdf, "_cypher_contains", |s, p| s.contains(p));
4109
4110pub fn create_cypher_starts_with_udf() -> ScalarUDF {
4111 ScalarUDF::new_from_impl(CypherStartsWithUdf::new())
4112}
4113pub fn create_cypher_ends_with_udf() -> ScalarUDF {
4114 ScalarUDF::new_from_impl(CypherEndsWithUdf::new())
4115}
4116pub fn create_cypher_contains_udf() -> ScalarUDF {
4117 ScalarUDF::new_from_impl(CypherContainsUdf::new())
4118}
4119
4120pub fn create_cypher_equal_udf() -> ScalarUDF {
4121 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_equal", BinaryOp::Eq))
4122}
4123pub fn create_cypher_not_equal_udf() -> ScalarUDF {
4124 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_not_equal", BinaryOp::NotEq))
4125}
4126pub fn create_cypher_lt_udf() -> ScalarUDF {
4127 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt", BinaryOp::Lt))
4128}
4129pub fn create_cypher_lt_eq_udf() -> ScalarUDF {
4130 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt_eq", BinaryOp::LtEq))
4131}
4132pub fn create_cypher_gt_udf() -> ScalarUDF {
4133 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt", BinaryOp::Gt))
4134}
4135pub fn create_cypher_gt_eq_udf() -> ScalarUDF {
4136 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt_eq", BinaryOp::GtEq))
4137}
4138
4139#[expect(clippy::match_like_matches_macro)]
4141fn apply_comparison_op(ord: std::cmp::Ordering, op: &BinaryOp) -> bool {
4142 use std::cmp::Ordering;
4143 match (ord, op) {
4144 (Ordering::Less, BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::NotEq) => true,
4145 (Ordering::Equal, BinaryOp::Eq | BinaryOp::LtEq | BinaryOp::GtEq) => true,
4146 (Ordering::Greater, BinaryOp::Gt | BinaryOp::GtEq | BinaryOp::NotEq) => true,
4147 _ => false,
4148 }
4149}
4150
4151fn compare_f64(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<bool> {
4154 if lhs.is_nan() || rhs.is_nan() {
4155 Some(matches!(op, BinaryOp::NotEq))
4156 } else {
4157 Some(apply_comparison_op(lhs.partial_cmp(&rhs)?, op))
4158 }
4159}
4160
4161fn cv_bytes_as_f64(bytes: &[u8]) -> Option<f64> {
4163 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4164 match peek_tag(bytes)? {
4165 TAG_INT => decode_int(bytes).map(|i| i as f64),
4166 TAG_FLOAT => decode_float(bytes),
4167 _ => None,
4168 }
4169}
4170
4171fn compare_cv_numeric(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<bool> {
4174 use uni_common::cypher_value_codec::{TAG_INT, TAG_NULL, decode_int, peek_tag};
4175 if peek_tag(bytes) == Some(TAG_INT)
4177 && let Some(lhs_int) = decode_int(bytes)
4178 && rhs.fract() == 0.0
4180 && rhs >= i64::MIN as f64
4181 && rhs <= i64::MAX as f64
4182 {
4183 return Some(apply_comparison_op(lhs_int.cmp(&(rhs as i64)), op));
4184 }
4185 if peek_tag(bytes) == Some(TAG_NULL) {
4186 return None;
4187 }
4188 let lhs = cv_bytes_as_f64(bytes)?;
4189 compare_f64(lhs, rhs, op)
4190}
4191
4192fn try_fast_compare(
4196 lhs: &ColumnarValue,
4197 rhs: &ColumnarValue,
4198 op: &BinaryOp,
4199) -> Option<ColumnarValue> {
4200 use arrow_array::builder::BooleanBuilder;
4201 use uni_common::cypher_value_codec::{
4202 TAG_INT, TAG_NULL, TAG_STRING, decode_int, decode_string, peek_tag,
4203 };
4204
4205 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
4206 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
4207 _ => return None,
4208 };
4209
4210 if !matches!(lhs_arr.data_type(), DataType::LargeBinary) {
4212 return None;
4213 }
4214
4215 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4216
4217 match rhs_arr.data_type() {
4218 DataType::Int64 => {
4220 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4221 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4222 for i in 0..lb_arr.len() {
4223 if lb_arr.is_null(i) || int_arr.is_null(i) {
4224 builder.append_null();
4225 } else {
4226 match compare_cv_numeric(lb_arr.value(i), int_arr.value(i) as f64, op) {
4227 Some(result) => builder.append_value(result),
4228 None => builder.append_null(),
4229 }
4230 }
4231 }
4232 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4233 }
4234
4235 DataType::Float64 => {
4237 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
4238 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4239 for i in 0..lb_arr.len() {
4240 if lb_arr.is_null(i) || float_arr.is_null(i) {
4241 builder.append_null();
4242 } else {
4243 match compare_cv_numeric(lb_arr.value(i), float_arr.value(i), op) {
4244 Some(result) => builder.append_value(result),
4245 None => builder.append_null(),
4246 }
4247 }
4248 }
4249 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4250 }
4251
4252 DataType::Utf8 | DataType::LargeUtf8 => {
4254 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4255 for i in 0..lb_arr.len() {
4256 if lb_arr.is_null(i) || rhs_arr.is_null(i) {
4257 builder.append_null();
4258 } else {
4259 let bytes = lb_arr.value(i);
4260 let rhs_str = if matches!(rhs_arr.data_type(), DataType::Utf8) {
4261 rhs_arr.as_any().downcast_ref::<StringArray>()?.value(i)
4262 } else {
4263 rhs_arr
4264 .as_any()
4265 .downcast_ref::<LargeStringArray>()?
4266 .value(i)
4267 };
4268 match peek_tag(bytes) {
4269 Some(TAG_STRING) => {
4270 if let Some(lhs_str) = decode_string(bytes) {
4271 builder.append_value(apply_comparison_op(
4272 lhs_str.as_str().cmp(rhs_str),
4273 op,
4274 ));
4275 } else {
4276 builder.append_null();
4277 }
4278 }
4279 _ => builder.append_null(),
4280 }
4281 }
4282 }
4283 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4284 }
4285
4286 DataType::LargeBinary => {
4288 let rhs_lb = rhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4289 let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
4290 for i in 0..lb_arr.len() {
4291 if lb_arr.is_null(i) || rhs_lb.is_null(i) {
4292 builder.append_null();
4293 } else {
4294 let lhs_bytes = lb_arr.value(i);
4295 let rhs_bytes = rhs_lb.value(i);
4296 let lhs_tag = peek_tag(lhs_bytes);
4297 let rhs_tag = peek_tag(rhs_bytes);
4298
4299 if lhs_tag == Some(TAG_NULL) || rhs_tag == Some(TAG_NULL) {
4301 builder.append_null();
4302 continue;
4303 }
4304
4305 if lhs_tag == Some(TAG_INT) && rhs_tag == Some(TAG_INT) {
4307 if let (Some(l), Some(r)) = (decode_int(lhs_bytes), decode_int(rhs_bytes)) {
4308 builder.append_value(apply_comparison_op(l.cmp(&r), op));
4309 } else {
4310 builder.append_null();
4311 }
4312 continue;
4313 }
4314
4315 if lhs_tag == Some(TAG_STRING) && rhs_tag == Some(TAG_STRING) {
4317 if let (Some(l), Some(r)) =
4318 (decode_string(lhs_bytes), decode_string(rhs_bytes))
4319 {
4320 builder.append_value(apply_comparison_op(l.cmp(&r), op));
4321 } else {
4322 builder.append_null();
4323 }
4324 continue;
4325 }
4326
4327 if let (Some(l), Some(r)) =
4329 (cv_bytes_as_f64(lhs_bytes), cv_bytes_as_f64(rhs_bytes))
4330 {
4331 match compare_f64(l, r, op) {
4332 Some(result) => builder.append_value(result),
4333 None => builder.append_null(),
4334 }
4335 } else {
4336 return None;
4340 }
4341 }
4342 }
4343 Some(ColumnarValue::Array(Arc::new(builder.finish())))
4344 }
4345
4346 _ => None, }
4348}
4349
4350#[derive(Debug)]
4351struct CypherCompareUdf {
4352 name: String,
4353 op: BinaryOp,
4354 signature: Signature,
4355}
4356
4357impl CypherCompareUdf {
4358 fn new(name: &str, op: BinaryOp) -> Self {
4359 Self {
4360 name: name.to_string(),
4361 op,
4362 signature: Signature::any(2, Volatility::Immutable),
4363 }
4364 }
4365}
4366
4367impl PartialEq for CypherCompareUdf {
4368 fn eq(&self, other: &Self) -> bool {
4369 self.name == other.name
4370 }
4371}
4372
4373impl Eq for CypherCompareUdf {}
4374
4375impl std::hash::Hash for CypherCompareUdf {
4376 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4377 self.name.hash(state);
4378 }
4379}
4380
4381impl ScalarUDFImpl for CypherCompareUdf {
4382 fn as_any(&self) -> &dyn Any {
4383 self
4384 }
4385 fn name(&self) -> &str {
4386 &self.name
4387 }
4388 fn signature(&self) -> &Signature {
4389 &self.signature
4390 }
4391 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4392 Ok(DataType::Boolean)
4393 }
4394
4395 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4396 if args.args.len() != 2 {
4397 return Err(datafusion::error::DataFusionError::Execution(format!(
4398 "{}(): requires 2 arguments",
4399 self.name
4400 )));
4401 }
4402
4403 if let Some(result) = try_fast_compare(&args.args[0], &args.args[1], &self.op) {
4405 return Ok(result);
4406 }
4407
4408 let output_type = DataType::Boolean;
4410 invoke_cypher_udf(args, &output_type, |val_args| {
4411 crate::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4412 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4413 })
4414 }
4415}
4416
4417pub fn create_cypher_add_udf() -> ScalarUDF {
4423 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_add", BinaryOp::Add))
4424}
4425pub fn create_cypher_sub_udf() -> ScalarUDF {
4426 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_sub", BinaryOp::Sub))
4427}
4428pub fn create_cypher_mul_udf() -> ScalarUDF {
4429 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mul", BinaryOp::Mul))
4430}
4431pub fn create_cypher_div_udf() -> ScalarUDF {
4432 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_div", BinaryOp::Div))
4433}
4434pub fn create_cypher_mod_udf() -> ScalarUDF {
4435 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mod", BinaryOp::Mod))
4436}
4437
4438pub fn create_cypher_abs_udf() -> ScalarUDF {
4440 ScalarUDF::new_from_impl(CypherAbsUdf::new())
4441}
4442
4443pub fn cypher_abs_expr(arg: datafusion::logical_expr::Expr) -> datafusion::logical_expr::Expr {
4445 datafusion::logical_expr::Expr::ScalarFunction(
4446 datafusion::logical_expr::expr::ScalarFunction::new_udf(
4447 Arc::new(create_cypher_abs_udf()),
4448 vec![arg],
4449 ),
4450 )
4451}
4452
4453#[derive(Debug)]
4454struct CypherAbsUdf {
4455 signature: Signature,
4456}
4457
4458impl CypherAbsUdf {
4459 fn new() -> Self {
4460 Self {
4461 signature: Signature::any(1, Volatility::Immutable),
4462 }
4463 }
4464}
4465
4466impl_udf_eq_hash!(CypherAbsUdf);
4467
4468impl ScalarUDFImpl for CypherAbsUdf {
4469 fn as_any(&self) -> &dyn Any {
4470 self
4471 }
4472 fn name(&self) -> &str {
4473 "_cypher_abs"
4474 }
4475 fn signature(&self) -> &Signature {
4476 &self.signature
4477 }
4478 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
4479 Ok(DataType::LargeBinary)
4480 }
4481 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4482 if args.args.len() != 1 {
4483 return Err(datafusion::error::DataFusionError::Execution(
4484 "_cypher_abs requires exactly 1 argument".into(),
4485 ));
4486 }
4487 invoke_cypher_udf(args, &DataType::LargeBinary, |val_args| {
4488 match &val_args[0] {
4489 Value::Int(i) => i.checked_abs().map(Value::Int).ok_or_else(|| {
4490 datafusion::error::DataFusionError::Execution(
4491 "integer overflow in abs()".into(),
4492 )
4493 }),
4494 Value::Float(f) => Ok(Value::Float(f.abs())),
4495 Value::Null => Ok(Value::Null),
4496 other => Err(datafusion::error::DataFusionError::Execution(format!(
4497 "abs() requires a numeric argument, got {other:?}"
4498 ))),
4499 }
4500 })
4501 }
4502}
4503
4504pub(crate) fn checked_int_op(lhs: i64, rhs: i64, op: &BinaryOp) -> Option<i64> {
4515 match op {
4516 BinaryOp::Add => lhs.checked_add(rhs),
4517 BinaryOp::Sub => lhs.checked_sub(rhs),
4518 BinaryOp::Mul => lhs.checked_mul(rhs),
4519 BinaryOp::Div => {
4521 if rhs == 0 {
4522 None
4523 } else {
4524 lhs.checked_div(rhs)
4525 }
4526 }
4527 BinaryOp::Mod => {
4528 if rhs == 0 {
4529 None
4530 } else {
4531 lhs.checked_rem(rhs)
4532 }
4533 }
4534 _ => None,
4535 }
4536}
4537
/// Per-row result of arithmetic on an encoded CypherValue operand.
enum CvArithOutcome {
    /// Successfully computed result, already CypherValue-encoded.
    Value(Vec<u8>),
    /// Null result (e.g. a non-numeric operand or a non-arithmetic operator).
    Null,
    /// Row-level failure (e.g. division by zero, integer overflow); the fast path
    /// stops and returns this error for the whole batch.
    Error(datafusion::error::DataFusionError),
}
4552
4553fn apply_int_arithmetic(lhs: i64, rhs: i64, op: &BinaryOp) -> CvArithOutcome {
4560 use uni_common::cypher_value_codec::encode_int;
4561 if !matches!(
4562 op,
4563 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod
4564 ) {
4565 return CvArithOutcome::Null;
4566 }
4567 match checked_int_op(lhs, rhs, op) {
4568 Some(v) => CvArithOutcome::Value(encode_int(v)),
4569 None => {
4570 let msg = if matches!(op, BinaryOp::Div | BinaryOp::Mod) && rhs == 0 {
4571 "division by zero"
4572 } else {
4573 "integer overflow"
4574 };
4575 CvArithOutcome::Error(datafusion::error::DataFusionError::Execution(msg.into()))
4576 }
4577 }
4578}
4579
4580fn apply_float_arithmetic(lhs: f64, rhs: f64, op: &BinaryOp) -> CvArithOutcome {
4586 use uni_common::cypher_value_codec::encode_float;
4587 let result = match op {
4588 BinaryOp::Add => lhs + rhs,
4589 BinaryOp::Sub => lhs - rhs,
4590 BinaryOp::Mul => lhs * rhs,
4591 BinaryOp::Div => lhs / rhs, BinaryOp::Mod => lhs % rhs,
4593 _ => return CvArithOutcome::Null,
4594 };
4595 CvArithOutcome::Value(encode_float(result))
4596}
4597
4598fn cv_arithmetic_int(bytes: &[u8], rhs: i64, op: &BinaryOp) -> CvArithOutcome {
4604 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4605 match peek_tag(bytes) {
4606 Some(TAG_INT) => match decode_int(bytes) {
4607 Some(lhs) => apply_int_arithmetic(lhs, rhs, op),
4608 None => CvArithOutcome::Null,
4609 },
4610 Some(TAG_FLOAT) => match decode_float(bytes) {
4611 Some(lhs) => apply_float_arithmetic(lhs, rhs as f64, op),
4612 None => CvArithOutcome::Null,
4613 },
4614 _ => CvArithOutcome::Null,
4615 }
4616}
4617
4618fn cv_arithmetic_float(bytes: &[u8], rhs: f64, op: &BinaryOp) -> CvArithOutcome {
4624 match cv_bytes_as_f64(bytes) {
4625 Some(lhs) => apply_float_arithmetic(lhs, rhs, op),
4626 None => CvArithOutcome::Null,
4627 }
4628}
4629
4630pub(crate) fn cypher_arith_return_type(arg_types: &[DataType]) -> DataType {
4645 let any_large_binary = arg_types.iter().any(|t| matches!(t, DataType::LargeBinary));
4646 if any_large_binary {
4647 return DataType::LargeBinary;
4648 }
4649 let any_float = arg_types.iter().any(|t| matches!(t, DataType::Float64));
4650 if any_float {
4651 return DataType::Float64;
4652 }
4653 let all_int_or_null = !arg_types.is_empty()
4657 && arg_types
4658 .iter()
4659 .all(|t| matches!(t, DataType::Int64 | DataType::Null));
4660 let any_int = arg_types.iter().any(|t| matches!(t, DataType::Int64));
4661 if all_int_or_null && any_int {
4662 DataType::Int64
4663 } else {
4664 DataType::LargeBinary
4665 }
4666}
4667
4668fn try_fast_arithmetic(
4678 lhs: &ColumnarValue,
4679 rhs: &ColumnarValue,
4680 op: &BinaryOp,
4681) -> Option<DFResult<ColumnarValue>> {
4682 use arrow_array::builder::LargeBinaryBuilder;
4683
4684 let (lhs_arr, rhs_arr) = match (lhs, rhs) {
4685 (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
4686 _ => return None,
4687 };
4688
4689 match (lhs_arr.data_type(), rhs_arr.data_type()) {
4690 (DataType::Int64, DataType::Int64) => Some(int_kernel_arithmetic(lhs_arr, rhs_arr, op)),
4694
4695 (DataType::Int64, DataType::Float64)
4698 | (DataType::Float64, DataType::Int64)
4699 | (DataType::Float64, DataType::Float64) => {
4700 Some(float_kernel_arithmetic(lhs_arr, rhs_arr, op))
4701 }
4702
4703 (DataType::LargeBinary, DataType::Int64) => {
4705 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4706 let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
4707 let mut builder = LargeBinaryBuilder::new();
4708 for i in 0..lb_arr.len() {
4709 if lb_arr.is_null(i) || int_arr.is_null(i) {
4710 builder.append_null();
4711 } else {
4712 match cv_arithmetic_int(lb_arr.value(i), int_arr.value(i), op) {
4713 CvArithOutcome::Value(bytes) => builder.append_value(&bytes),
4714 CvArithOutcome::Null => builder.append_null(),
4715 CvArithOutcome::Error(e) => return Some(Err(e)),
4716 }
4717 }
4718 }
4719 Some(Ok(ColumnarValue::Array(Arc::new(builder.finish()))))
4720 }
4721
4722 (DataType::LargeBinary, DataType::Float64) => {
4724 let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
4725 let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
4726 let mut builder = LargeBinaryBuilder::new();
4727 for i in 0..lb_arr.len() {
4728 if lb_arr.is_null(i) || float_arr.is_null(i) {
4729 builder.append_null();
4730 } else {
4731 match cv_arithmetic_float(lb_arr.value(i), float_arr.value(i), op) {
4732 CvArithOutcome::Value(bytes) => builder.append_value(&bytes),
4733 CvArithOutcome::Null => builder.append_null(),
4734 CvArithOutcome::Error(e) => return Some(Err(e)),
4735 }
4736 }
4737 }
4738 Some(Ok(ColumnarValue::Array(Arc::new(builder.finish()))))
4739 }
4740
4741 _ => None, }
4743}
4744
4745fn int_kernel_arithmetic(
4753 lhs_arr: &ArrayRef,
4754 rhs_arr: &ArrayRef,
4755 op: &BinaryOp,
4756) -> DFResult<ColumnarValue> {
4757 use arrow::compute::kernels::numeric::{add, div, mul, rem, sub};
4758
4759 let result = match op {
4760 BinaryOp::Add => add(lhs_arr, rhs_arr),
4761 BinaryOp::Sub => sub(lhs_arr, rhs_arr),
4762 BinaryOp::Mul => mul(lhs_arr, rhs_arr),
4763 BinaryOp::Div => div(lhs_arr, rhs_arr),
4764 BinaryOp::Mod => rem(lhs_arr, rhs_arr),
4765 other => {
4766 return Err(datafusion::error::DataFusionError::Execution(format!(
4767 "unsupported integer arithmetic operator: {other:?}"
4768 )));
4769 }
4770 };
4771 result
4772 .map(ColumnarValue::Array)
4773 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4774}
4775
4776fn float_kernel_arithmetic(
4782 lhs_arr: &ArrayRef,
4783 rhs_arr: &ArrayRef,
4784 op: &BinaryOp,
4785) -> DFResult<ColumnarValue> {
4786 use arrow::compute::cast;
4787 use arrow::compute::kernels::numeric::{add, div, mul, rem, sub};
4788
4789 let lhs_f = cast(lhs_arr, &DataType::Float64)
4790 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
4791 let rhs_f = cast(rhs_arr, &DataType::Float64)
4792 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
4793
4794 let result = match op {
4795 BinaryOp::Add => add(&lhs_f, &rhs_f),
4796 BinaryOp::Sub => sub(&lhs_f, &rhs_f),
4797 BinaryOp::Mul => mul(&lhs_f, &rhs_f),
4798 BinaryOp::Div => div(&lhs_f, &rhs_f),
4799 BinaryOp::Mod => rem(&lhs_f, &rhs_f),
4800 other => {
4801 return Err(datafusion::error::DataFusionError::Execution(format!(
4802 "unsupported float arithmetic operator: {other:?}"
4803 )));
4804 }
4805 };
4806 result
4807 .map(ColumnarValue::Array)
4808 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4809}
4810
4811#[derive(Debug)]
4812struct CypherArithmeticUdf {
4813 name: String,
4814 op: BinaryOp,
4815 signature: Signature,
4816}
4817
4818impl CypherArithmeticUdf {
4819 fn new(name: &str, op: BinaryOp) -> Self {
4820 Self {
4821 name: name.to_string(),
4822 op,
4823 signature: Signature::any(2, Volatility::Immutable),
4824 }
4825 }
4826}
4827
4828impl PartialEq for CypherArithmeticUdf {
4829 fn eq(&self, other: &Self) -> bool {
4830 self.name == other.name
4831 }
4832}
4833
4834impl Eq for CypherArithmeticUdf {}
4835
4836impl std::hash::Hash for CypherArithmeticUdf {
4837 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
4838 self.name.hash(state);
4839 }
4840}
4841
4842impl ScalarUDFImpl for CypherArithmeticUdf {
4843 fn as_any(&self) -> &dyn Any {
4844 self
4845 }
4846 fn name(&self) -> &str {
4847 &self.name
4848 }
4849 fn signature(&self) -> &Signature {
4850 &self.signature
4851 }
4852 fn return_type(&self, arg_types: &[DataType]) -> DFResult<DataType> {
4853 Ok(cypher_arith_return_type(arg_types))
4857 }
4858
4859 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4860 if args.args.len() != 2 {
4861 return Err(datafusion::error::DataFusionError::Execution(format!(
4862 "{}(): requires 2 arguments",
4863 self.name
4864 )));
4865 }
4866
4867 if let Some(result) = try_fast_arithmetic(&args.args[0], &args.args[1], &self.op) {
4870 return result;
4871 }
4872
4873 let output_type = args.return_field.data_type().clone();
4880 invoke_cypher_udf(args, &output_type, |val_args| {
4881 crate::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4882 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4883 })
4884 }
4885}
4886
4887pub fn create_cypher_xor_udf() -> ScalarUDF {
4892 ScalarUDF::new_from_impl(CypherXorUdf::new())
4893}
4894
4895#[derive(Debug)]
4896struct CypherXorUdf {
4897 signature: Signature,
4898}
4899
4900impl CypherXorUdf {
4901 fn new() -> Self {
4902 Self {
4903 signature: Signature::any(2, Volatility::Immutable),
4904 }
4905 }
4906}
4907
4908impl_udf_eq_hash!(CypherXorUdf);
4909
4910impl ScalarUDFImpl for CypherXorUdf {
4911 fn as_any(&self) -> &dyn Any {
4912 self
4913 }
4914 fn name(&self) -> &str {
4915 "_cypher_xor"
4916 }
4917 fn signature(&self) -> &Signature {
4918 &self.signature
4919 }
4920 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4921 Ok(DataType::Boolean)
4922 }
4923
4924 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4925 let output_type = DataType::Boolean;
4926 invoke_cypher_udf(args, &output_type, |val_args| {
4927 if val_args.len() != 2 {
4928 return Err(datafusion::error::DataFusionError::Execution(
4929 "_cypher_xor(): requires 2 arguments".to_string(),
4930 ));
4931 }
4932 let coerce_bool = |v: &Value| -> Value {
4934 match v {
4935 Value::String(s) if s == "true" => Value::Bool(true),
4936 Value::String(s) if s == "false" => Value::Bool(false),
4937 other => other.clone(),
4938 }
4939 };
4940 let left = coerce_bool(&val_args[0]);
4941 let right = coerce_bool(&val_args[1]);
4942 crate::expr_eval::eval_binary_op(&left, &BinaryOp::Xor, &right)
4943 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4944 })
4945 }
4946}
4947
4948pub fn create_cv_to_bool_udf() -> ScalarUDF {
4955 ScalarUDF::new_from_impl(CvToBoolUdf::new())
4956}
4957
4958#[derive(Debug)]
4959struct CvToBoolUdf {
4960 signature: Signature,
4961}
4962
4963impl CvToBoolUdf {
4964 fn new() -> Self {
4965 Self {
4966 signature: Signature::exact(vec![DataType::LargeBinary], Volatility::Immutable),
4967 }
4968 }
4969}
4970
4971impl_udf_eq_hash!(CvToBoolUdf);
4972
4973impl ScalarUDFImpl for CvToBoolUdf {
4974 fn as_any(&self) -> &dyn Any {
4975 self
4976 }
4977 fn name(&self) -> &str {
4978 "_cv_to_bool"
4979 }
4980 fn signature(&self) -> &Signature {
4981 &self.signature
4982 }
4983 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4984 Ok(DataType::Boolean)
4985 }
4986
4987 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4988 if args.args.len() != 1 {
4989 return Err(datafusion::error::DataFusionError::Execution(
4990 "_cv_to_bool() requires exactly 1 argument".to_string(),
4991 ));
4992 }
4993
4994 match &args.args[0] {
4995 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
4996 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4998 let b = match peek_tag(bytes) {
4999 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
5000 Some(TAG_NULL) => false,
5001 _ => false, };
5003 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))))
5004 }
5005 ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
5006 ColumnarValue::Array(arr) => {
5007 let lb_arr = arr
5008 .as_any()
5009 .downcast_ref::<arrow_array::LargeBinaryArray>()
5010 .ok_or_else(|| {
5011 datafusion::error::DataFusionError::Execution(format!(
5012 "_cv_to_bool(): expected LargeBinary array, got {:?}",
5013 arr.data_type()
5014 ))
5015 })?;
5016
5017 let mut builder = arrow_array::builder::BooleanBuilder::with_capacity(lb_arr.len());
5018
5019 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
5021
5022 for i in 0..lb_arr.len() {
5023 if lb_arr.is_null(i) {
5024 builder.append_null();
5025 } else {
5026 let bytes = lb_arr.value(i);
5027 let b = match peek_tag(bytes) {
5028 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
5029 Some(TAG_NULL) => false,
5030 _ => false, };
5032 builder.append_value(b);
5033 }
5034 }
5035 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
5036 }
5037 }
5038 }
5039}
5040
5041pub fn create_cypher_size_udf() -> ScalarUDF {
5047 ScalarUDF::new_from_impl(CypherSizeUdf::new())
5048}
5049
5050#[derive(Debug)]
5051struct CypherSizeUdf {
5052 signature: Signature,
5053}
5054
5055impl CypherSizeUdf {
5056 fn new() -> Self {
5057 Self {
5058 signature: Signature::any(1, Volatility::Immutable),
5059 }
5060 }
5061}
5062
5063impl_udf_eq_hash!(CypherSizeUdf);
5064
5065impl ScalarUDFImpl for CypherSizeUdf {
5066 fn as_any(&self) -> &dyn Any {
5067 self
5068 }
5069
5070 fn name(&self) -> &str {
5071 "_cypher_size"
5072 }
5073
5074 fn signature(&self) -> &Signature {
5075 &self.signature
5076 }
5077
5078 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5079 Ok(DataType::Int64)
5080 }
5081
5082 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5083 if args.args.len() != 1 {
5084 return Err(datafusion::error::DataFusionError::Execution(
5085 "_cypher_size() requires exactly 1 argument".to_string(),
5086 ));
5087 }
5088
5089 match &args.args[0] {
5090 ColumnarValue::Scalar(scalar) => {
5091 let result = cypher_size_scalar(scalar)?;
5092 Ok(ColumnarValue::Scalar(result))
5093 }
5094 ColumnarValue::Array(arr) => {
5095 let mut results: Vec<Option<i64>> = Vec::with_capacity(arr.len());
5096 for i in 0..arr.len() {
5097 if arr.is_null(i) {
5098 results.push(None);
5099 } else {
5100 let scalar = ScalarValue::try_from_array(arr, i)?;
5101 match cypher_size_scalar(&scalar)? {
5102 ScalarValue::Int64(v) => results.push(v),
5103 _ => results.push(None),
5104 }
5105 }
5106 }
5107 let arr: ArrayRef = Arc::new(arrow_array::Int64Array::from(results));
5108 Ok(ColumnarValue::Array(arr))
5109 }
5110 }
5111 }
5112}
5113
/// Computes Cypher `size()` for a single scalar value, returning an `Int64` scalar.
///
/// - Strings: number of Unicode scalar values (`chars().count()`), not bytes.
/// - Lists and maps: number of elements or entries.
/// - CypherValue-encoded `LargeBinary`: decoded, then sized via its `serde_json`
///   form (array length, string char count, or object key count). Undecodable
///   bytes or other JSON kinds yield null.
/// - Structs: node- and relationship-shaped structs are rejected with a
///   `TypeError`. A struct with a `relationships` `ListArray` column reports that
///   list's length (presumably a path — confirm). Any other struct reports its
///   field count.
/// - Nulls: null.
/// - Anything else: an execution error.
///
/// NOTE(review): the error texts mention `length()`; this helper is named for
/// `size()` — confirm whether both functions share it.
fn cypher_size_scalar(scalar: &ScalarValue) -> DFResult<ScalarValue> {
    match scalar {
        ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
            // Character count, not byte length.
            Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
        }
        ScalarValue::List(arr) => {
            // A list scalar is a one-row ListArray; size is the length of row 0.
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        ScalarValue::LargeList(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        ScalarValue::LargeBinary(Some(b)) => {
            // Encoded CypherValue: decode first; undecodable bytes size to null.
            if let Ok(uni_val) = uni_common::cypher_value_codec::decode(b) {
                match &uni_val {
                    uni_common::Value::Node(_) => {
                        Err(datafusion::error::DataFusionError::Execution(
                            "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
                        ))
                    }
                    uni_common::Value::Edge(_) => {
                        Err(datafusion::error::DataFusionError::Execution(
                            "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
                        ))
                    }
                    _ => {
                        // Size via the JSON representation: arrays by length,
                        // strings by char count, objects by key count.
                        let json_val: serde_json::Value = uni_val.into();
                        match json_val {
                            serde_json::Value::Array(arr) => Ok(ScalarValue::Int64(Some(arr.len() as i64))),
                            serde_json::Value::String(s) => {
                                Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
                            }
                            serde_json::Value::Object(m) => Ok(ScalarValue::Int64(Some(m.len() as i64))),
                            _ => Ok(ScalarValue::Int64(None)),
                        }
                    }
                }
            } else {
                Ok(ScalarValue::Int64(None))
            }
        }
        ScalarValue::Map(arr) => {
            // One-row MapArray: size is the number of entries in row 0.
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        ScalarValue::Struct(arr) => {
            // Struct scalars are classified by their field names.
            if arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                let schema = arr.fields();
                let field_names: Vec<&str> = schema.iter().map(|f| f.name().as_str()).collect();
                // `_vid` without `relationships`: treated as a node.
                if field_names.contains(&"_vid") && !field_names.contains(&"relationships") {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
                    ));
                }
                // `_eid`, or both `_src` and `_dst`: treated as a relationship.
                if field_names.contains(&"_eid")
                    || (field_names.contains(&"_src") && field_names.contains(&"_dst"))
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
                    ));
                }
                // A `relationships` column (presumably a path): size is its list
                // length, with a null list counting as 0.
                if let Some((rels_idx, _)) = schema
                    .iter()
                    .enumerate()
                    .find(|(_, f)| f.name() == "relationships")
                {
                    let rels_col = arr.column(rels_idx);
                    // Only a `ListArray` is recognised; other column types fall back
                    // to the field count.
                    if let Some(list_arr) =
                        rels_col.as_any().downcast_ref::<arrow_array::ListArray>()
                    {
                        if list_arr.is_null(0) {
                            Ok(ScalarValue::Int64(Some(0)))
                        } else {
                            Ok(ScalarValue::Int64(Some(list_arr.value(0).len() as i64)))
                        }
                    } else {
                        Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
                    }
                } else {
                    // Any other struct is sized like a map: its number of fields.
                    Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
                }
            }
        }
        ScalarValue::Null
        | ScalarValue::Utf8(None)
        | ScalarValue::LargeUtf8(None)
        | ScalarValue::LargeBinary(None) => Ok(ScalarValue::Int64(None)),
        other => Err(datafusion::error::DataFusionError::Execution(format!(
            "_cypher_size(): unsupported type {other:?}"
        ))),
    }
}
5230
5231pub fn create_cypher_list_compare_udf() -> ScalarUDF {
5237 ScalarUDF::new_from_impl(CypherListCompareUdf::new())
5238}
5239
5240#[derive(Debug)]
5241struct CypherListCompareUdf {
5242 signature: Signature,
5243}
5244
5245impl CypherListCompareUdf {
5246 fn new() -> Self {
5247 Self {
5248 signature: Signature::any(3, Volatility::Immutable),
5249 }
5250 }
5251}
5252
5253impl_udf_eq_hash!(CypherListCompareUdf);
5254
5255impl ScalarUDFImpl for CypherListCompareUdf {
5256 fn as_any(&self) -> &dyn Any {
5257 self
5258 }
5259
5260 fn name(&self) -> &str {
5261 "_cypher_list_compare"
5262 }
5263
5264 fn signature(&self) -> &Signature {
5265 &self.signature
5266 }
5267
5268 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5269 Ok(DataType::Boolean)
5270 }
5271
5272 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5273 let output_type = DataType::Boolean;
5274 invoke_cypher_udf(args, &output_type, |val_args| {
5275 if val_args.len() != 3 {
5276 return Err(datafusion::error::DataFusionError::Execution(
5277 "_cypher_list_compare(): requires 3 arguments (left, right, op)".to_string(),
5278 ));
5279 }
5280
5281 let left = &val_args[0];
5282 let right = &val_args[1];
5283 let op_str = match &val_args[2] {
5284 Value::String(s) => s.as_str(),
5285 _ => {
5286 return Err(datafusion::error::DataFusionError::Execution(
5287 "_cypher_list_compare(): op must be a string".to_string(),
5288 ));
5289 }
5290 };
5291
5292 let (left_items, right_items) = match (left, right) {
5293 (Value::List(l), Value::List(r)) => (l, r),
5294 (Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
5295 _ => {
5296 return Err(datafusion::error::DataFusionError::Execution(
5297 "_cypher_list_compare(): both arguments must be lists".to_string(),
5298 ));
5299 }
5300 };
5301
5302 let cmp = cypher_list_cmp(left_items, right_items);
5304
5305 let result = match (op_str, cmp) {
5306 (_, None) => Value::Null,
5307 ("lt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Less),
5308 ("lteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Greater),
5309 ("gt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Greater),
5310 ("gteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Less),
5311 _ => {
5312 return Err(datafusion::error::DataFusionError::Execution(format!(
5313 "_cypher_list_compare(): unknown op '{}'",
5314 op_str
5315 )));
5316 }
5317 };
5318
5319 Ok(result)
5320 })
5321 }
5322}
5323
5324pub fn create_map_project_udf() -> ScalarUDF {
5329 ScalarUDF::new_from_impl(MapProjectUdf::new())
5330}
5331
5332#[derive(Debug)]
5333struct MapProjectUdf {
5334 signature: Signature,
5335}
5336
5337impl MapProjectUdf {
5338 fn new() -> Self {
5339 Self {
5340 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5341 }
5342 }
5343}
5344
5345impl_udf_eq_hash!(MapProjectUdf);
5346
5347impl ScalarUDFImpl for MapProjectUdf {
5348 fn as_any(&self) -> &dyn Any {
5349 self
5350 }
5351
5352 fn name(&self) -> &str {
5353 "_map_project"
5354 }
5355
5356 fn signature(&self) -> &Signature {
5357 &self.signature
5358 }
5359
5360 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5361 Ok(DataType::LargeBinary)
5362 }
5363
5364 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5365 let output_type = self.return_type(&[])?;
5366 invoke_cypher_udf(args, &output_type, |val_args| {
5367 let mut result_map = std::collections::HashMap::new();
5368 let mut i = 0;
5369 while i + 1 < val_args.len() {
5370 let key = &val_args[i];
5371 let value = &val_args[i + 1];
5372 if let Some(k) = key.as_str() {
5373 if k == "__all__" {
5374 match value {
5376 Value::Map(map) => {
5377 let source = match map.get("_all_props") {
5384 Some(Value::Map(all)) => all,
5385 _ => map,
5386 };
5387 for (mk, mv) in source {
5388 if !mk.starts_with('_') {
5389 result_map.insert(mk.clone(), mv.clone());
5390 }
5391 }
5392 }
5393 Value::Node(node) => {
5394 for (pk, pv) in &node.properties {
5395 result_map.insert(pk.clone(), pv.clone());
5396 }
5397 }
5398 Value::Edge(edge) => {
5399 for (pk, pv) in &edge.properties {
5400 result_map.insert(pk.clone(), pv.clone());
5401 }
5402 }
5403 _ => {}
5404 }
5405 } else {
5406 result_map.insert(k.to_string(), value.clone());
5407 }
5408 }
5409 i += 2;
5410 }
5411 Ok(Value::Map(result_map))
5412 })
5413 }
5414}
5415
5416pub fn create_make_cypher_list_udf() -> ScalarUDF {
5421 ScalarUDF::new_from_impl(MakeCypherListUdf::new())
5422}
5423
5424#[derive(Debug)]
5425struct MakeCypherListUdf {
5426 signature: Signature,
5427}
5428
5429impl MakeCypherListUdf {
5430 fn new() -> Self {
5431 Self {
5432 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5433 }
5434 }
5435}
5436
5437impl_udf_eq_hash!(MakeCypherListUdf);
5438
5439impl ScalarUDFImpl for MakeCypherListUdf {
5440 fn as_any(&self) -> &dyn Any {
5441 self
5442 }
5443
5444 fn name(&self) -> &str {
5445 "_make_cypher_list"
5446 }
5447
5448 fn signature(&self) -> &Signature {
5449 &self.signature
5450 }
5451
5452 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5453 Ok(DataType::LargeBinary)
5454 }
5455
5456 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5457 let output_type = self.return_type(&[])?;
5458 invoke_cypher_udf(args, &output_type, |val_args| {
5459 Ok(Value::List(val_args.to_vec()))
5460 })
5461 }
5462}
5463
5464pub fn create_cypher_in_udf() -> ScalarUDF {
5481 ScalarUDF::new_from_impl(CypherInUdf::new())
5482}
5483
5484#[derive(Debug)]
5485struct CypherInUdf {
5486 signature: Signature,
5487}
5488
5489impl CypherInUdf {
5490 fn new() -> Self {
5491 Self {
5492 signature: Signature::any(2, Volatility::Immutable),
5493 }
5494 }
5495}
5496
5497impl_udf_eq_hash!(CypherInUdf);
5498
5499impl ScalarUDFImpl for CypherInUdf {
5500 fn as_any(&self) -> &dyn Any {
5501 self
5502 }
5503
5504 fn name(&self) -> &str {
5505 "_cypher_in"
5506 }
5507
5508 fn signature(&self) -> &Signature {
5509 &self.signature
5510 }
5511
5512 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5513 Ok(DataType::Boolean)
5514 }
5515
5516 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5517 invoke_cypher_udf(args, &DataType::Boolean, |vals| {
5518 if vals.len() != 2 {
5519 return Err(datafusion::error::DataFusionError::Execution(
5520 "_cypher_in(): requires 2 arguments".to_string(),
5521 ));
5522 }
5523 let element = &vals[0];
5524 let list_val = &vals[1];
5525
5526 if list_val.is_null() {
5528 return Ok(Value::Null);
5529 }
5530
5531 let items = match list_val {
5533 Value::List(items) => items.as_slice(),
5534 _ => {
5535 return Err(datafusion::error::DataFusionError::Execution(format!(
5536 "_cypher_in(): second argument must be a list, got {:?}",
5537 list_val
5538 )));
5539 }
5540 };
5541
5542 if element.is_null() {
5544 return if items.is_empty() {
5545 Ok(Value::Bool(false))
5546 } else {
5547 Ok(Value::Null) };
5549 }
5550
5551 let mut has_null = false;
5553 for item in items {
5554 match cypher_eq(element, item) {
5555 Some(true) => return Ok(Value::Bool(true)),
5556 None => has_null = true,
5557 Some(false) => {}
5558 }
5559 }
5560
5561 if has_null {
5562 Ok(Value::Null) } else {
5564 Ok(Value::Bool(false))
5565 }
5566 })
5567 }
5568}
5569
5570pub fn create_cypher_list_concat_udf() -> ScalarUDF {
5576 ScalarUDF::new_from_impl(CypherListConcatUdf::new())
5577}
5578
5579#[derive(Debug)]
5580struct CypherListConcatUdf {
5581 signature: Signature,
5582}
5583
5584impl CypherListConcatUdf {
5585 fn new() -> Self {
5586 Self {
5587 signature: Signature::any(2, Volatility::Immutable),
5588 }
5589 }
5590}
5591
5592impl_udf_eq_hash!(CypherListConcatUdf);
5593
5594impl ScalarUDFImpl for CypherListConcatUdf {
5595 fn as_any(&self) -> &dyn Any {
5596 self
5597 }
5598
5599 fn name(&self) -> &str {
5600 "_cypher_list_concat"
5601 }
5602
5603 fn signature(&self) -> &Signature {
5604 &self.signature
5605 }
5606
5607 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5608 Ok(DataType::LargeBinary)
5609 }
5610
5611 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5612 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5613 if vals.len() != 2 {
5614 return Err(datafusion::error::DataFusionError::Execution(
5615 "_cypher_list_concat(): requires 2 arguments".to_string(),
5616 ));
5617 }
5618 if vals[0].is_null() || vals[1].is_null() {
5620 return Ok(Value::Null);
5621 }
5622 match (&vals[0], &vals[1]) {
5623 (Value::List(left), Value::List(right)) => {
5624 let mut result = left.clone();
5625 result.extend(right.iter().cloned());
5626 Ok(Value::List(result))
5627 }
5628 (Value::List(list), elem) => {
5631 let mut result = list.clone();
5632 result.push(elem.clone());
5633 Ok(Value::List(result))
5634 }
5635 (elem, Value::List(list)) => {
5636 let mut result = vec![elem.clone()];
5637 result.extend(list.iter().cloned());
5638 Ok(Value::List(result))
5639 }
5640 _ => {
5641 crate::expr_eval::eval_binary_op(
5644 &vals[0],
5645 &uni_cypher::ast::BinaryOp::Add,
5646 &vals[1],
5647 )
5648 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
5649 }
5650 }
5651 })
5652 }
5653}
5654
5655pub fn create_cypher_list_append_udf() -> ScalarUDF {
5661 ScalarUDF::new_from_impl(CypherListAppendUdf::new())
5662}
5663
5664#[derive(Debug)]
5665struct CypherListAppendUdf {
5666 signature: Signature,
5667}
5668
5669impl CypherListAppendUdf {
5670 fn new() -> Self {
5671 Self {
5672 signature: Signature::any(2, Volatility::Immutable),
5673 }
5674 }
5675}
5676
5677impl_udf_eq_hash!(CypherListAppendUdf);
5678
5679impl ScalarUDFImpl for CypherListAppendUdf {
5680 fn as_any(&self) -> &dyn Any {
5681 self
5682 }
5683
5684 fn name(&self) -> &str {
5685 "_cypher_list_append"
5686 }
5687
5688 fn signature(&self) -> &Signature {
5689 &self.signature
5690 }
5691
5692 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5693 Ok(DataType::LargeBinary)
5694 }
5695
5696 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5697 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5698 if vals.len() != 2 {
5699 return Err(datafusion::error::DataFusionError::Execution(
5700 "_cypher_list_append(): requires 2 arguments".to_string(),
5701 ));
5702 }
5703 let left = &vals[0];
5704 let right = &vals[1];
5705
5706 if left.is_null() || right.is_null() {
5708 return Ok(Value::Null);
5709 }
5710
5711 match (left, right) {
5712 (Value::List(list), elem) => {
5714 let mut result = list.clone();
5715 result.push(elem.clone());
5716 Ok(Value::List(result))
5717 }
5718 (elem, Value::List(list)) => {
5720 let mut result = vec![elem.clone()];
5721 result.extend(list.iter().cloned());
5722 Ok(Value::List(result))
5723 }
5724 _ => Err(datafusion::error::DataFusionError::Execution(format!(
5725 "_cypher_list_append(): at least one argument must be a list, got {:?} and {:?}",
5726 left, right
5727 ))),
5728 }
5729 })
5730 }
5731}
5732
5733pub fn create_cypher_list_slice_udf() -> ScalarUDF {
5739 ScalarUDF::new_from_impl(CypherListSliceUdf::new())
5740}
5741
5742#[derive(Debug)]
5743struct CypherListSliceUdf {
5744 signature: Signature,
5745}
5746
5747impl CypherListSliceUdf {
5748 fn new() -> Self {
5749 Self {
5750 signature: Signature::any(3, Volatility::Immutable),
5751 }
5752 }
5753}
5754
5755impl_udf_eq_hash!(CypherListSliceUdf);
5756
5757impl ScalarUDFImpl for CypherListSliceUdf {
5758 fn as_any(&self) -> &dyn Any {
5759 self
5760 }
5761
5762 fn name(&self) -> &str {
5763 "_cypher_list_slice"
5764 }
5765
5766 fn signature(&self) -> &Signature {
5767 &self.signature
5768 }
5769
5770 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5771 Ok(DataType::LargeBinary)
5772 }
5773
5774 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5775 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5776 if vals.len() != 3 {
5777 return Err(datafusion::error::DataFusionError::Execution(
5778 "_cypher_list_slice(): requires 3 arguments (list, start, end)".to_string(),
5779 ));
5780 }
5781 if vals[0].is_null() {
5783 return Ok(Value::Null);
5784 }
5785 let list = match &vals[0] {
5786 Value::List(l) => l,
5787 _ => {
5788 return Err(datafusion::error::DataFusionError::Execution(format!(
5789 "_cypher_list_slice(): first argument must be a list, got {:?}",
5790 vals[0]
5791 )));
5792 }
5793 };
5794 if vals[1].is_null() || vals[2].is_null() {
5796 return Ok(Value::Null);
5797 }
5798
5799 let len = list.len() as i64;
5800 let raw_start = match &vals[1] {
5801 Value::Int(i) => *i,
5802 _ => 0,
5803 };
5804 let raw_end = match &vals[2] {
5805 Value::Int(i) => *i,
5806 _ => len,
5807 };
5808
5809 let start = if raw_start < 0 {
5811 (len + raw_start).max(0) as usize
5812 } else {
5813 (raw_start).min(len) as usize
5814 };
5815 let end = if raw_end == i64::MAX {
5816 len as usize
5817 } else if raw_end < 0 {
5818 (len + raw_end).max(0) as usize
5819 } else {
5820 (raw_end).min(len) as usize
5821 };
5822
5823 if start >= end {
5824 return Ok(Value::List(vec![]));
5825 }
5826 Ok(Value::List(list[start..end.min(list.len())].to_vec()))
5827 })
5828 }
5829}
5830
5831pub fn create_cypher_reverse_udf() -> ScalarUDF {
5842 ScalarUDF::new_from_impl(CypherReverseUdf::new())
5843}
5844
5845#[derive(Debug)]
5846struct CypherReverseUdf {
5847 signature: Signature,
5848}
5849
5850impl CypherReverseUdf {
5851 fn new() -> Self {
5852 Self {
5853 signature: Signature::any(1, Volatility::Immutable),
5854 }
5855 }
5856}
5857
5858impl_udf_eq_hash!(CypherReverseUdf);
5859
5860impl ScalarUDFImpl for CypherReverseUdf {
5861 fn as_any(&self) -> &dyn Any {
5862 self
5863 }
5864
5865 fn name(&self) -> &str {
5866 "_cypher_reverse"
5867 }
5868
5869 fn signature(&self) -> &Signature {
5870 &self.signature
5871 }
5872
5873 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5874 Ok(DataType::LargeBinary)
5875 }
5876
5877 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5878 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5879 if vals.len() != 1 {
5880 return Err(datafusion::error::DataFusionError::Execution(
5881 "_cypher_reverse(): requires exactly 1 argument".to_string(),
5882 ));
5883 }
5884 match &vals[0] {
5885 Value::Null => Ok(Value::Null),
5886 Value::String(s) => Ok(Value::String(s.chars().rev().collect())),
5887 Value::List(l) => {
5888 let mut reversed = l.clone();
5889 reversed.reverse();
5890 Ok(Value::List(reversed))
5891 }
5892 other => Err(datafusion::error::DataFusionError::Execution(format!(
5893 "_cypher_reverse(): expected string or list, got {:?}",
5894 other
5895 ))),
5896 }
5897 })
5898 }
5899}
5900
5901pub fn create_cypher_substring_udf() -> ScalarUDF {
5912 ScalarUDF::new_from_impl(CypherSubstringUdf::new())
5913}
5914
5915#[derive(Debug)]
5916struct CypherSubstringUdf {
5917 signature: Signature,
5918}
5919
5920impl CypherSubstringUdf {
5921 fn new() -> Self {
5922 Self {
5923 signature: Signature::variadic_any(Volatility::Immutable),
5924 }
5925 }
5926}
5927
5928impl_udf_eq_hash!(CypherSubstringUdf);
5929
5930impl ScalarUDFImpl for CypherSubstringUdf {
5931 fn as_any(&self) -> &dyn Any {
5932 self
5933 }
5934
5935 fn name(&self) -> &str {
5936 "_cypher_substring"
5937 }
5938
5939 fn signature(&self) -> &Signature {
5940 &self.signature
5941 }
5942
5943 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5944 Ok(DataType::Utf8)
5945 }
5946
5947 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5948 invoke_cypher_udf(args, &DataType::Utf8, |vals| {
5949 if vals.len() < 2 || vals.len() > 3 {
5950 return Err(datafusion::error::DataFusionError::Execution(
5951 "_cypher_substring(): requires 2 or 3 arguments".to_string(),
5952 ));
5953 }
5954 if vals.iter().any(|v| v.is_null()) {
5956 return Ok(Value::Null);
5957 }
5958 let s = match &vals[0] {
5959 Value::String(s) => s.as_str(),
5960 other => {
5961 return Err(datafusion::error::DataFusionError::Execution(format!(
5962 "_cypher_substring(): first argument must be a string, got {:?}",
5963 other
5964 )));
5965 }
5966 };
5967 let start = match &vals[1] {
5968 Value::Int(i) => *i,
5969 other => {
5970 return Err(datafusion::error::DataFusionError::Execution(format!(
5971 "_cypher_substring(): second argument must be an integer, got {:?}",
5972 other
5973 )));
5974 }
5975 };
5976
5977 let chars: Vec<char> = s.chars().collect();
5979 let len = chars.len() as i64;
5980
5981 let start_idx = start.max(0).min(len) as usize;
5983
5984 let end_idx = if vals.len() == 3 {
5985 let length = match &vals[2] {
5986 Value::Int(i) => *i,
5987 other => {
5988 return Err(datafusion::error::DataFusionError::Execution(format!(
5989 "_cypher_substring(): third argument must be an integer, got {:?}",
5990 other
5991 )));
5992 }
5993 };
5994 if length < 0 {
5995 return Err(datafusion::error::DataFusionError::Execution(
5996 "ArgumentError: NegativeIntegerArgument - substring length must be non-negative".to_string(),
5997 ));
5998 }
5999 (start_idx as i64 + length).min(len) as usize
6000 } else {
6001 len as usize
6002 };
6003
6004 Ok(Value::String(chars[start_idx..end_idx].iter().collect()))
6005 })
6006 }
6007}
6008
6009pub fn create_cypher_split_udf() -> ScalarUDF {
6018 ScalarUDF::new_from_impl(CypherSplitUdf::new())
6019}
6020
6021#[derive(Debug)]
6022struct CypherSplitUdf {
6023 signature: Signature,
6024}
6025
6026impl CypherSplitUdf {
6027 fn new() -> Self {
6028 Self {
6029 signature: Signature::any(2, Volatility::Immutable),
6030 }
6031 }
6032}
6033
6034impl_udf_eq_hash!(CypherSplitUdf);
6035
6036impl ScalarUDFImpl for CypherSplitUdf {
6037 fn as_any(&self) -> &dyn Any {
6038 self
6039 }
6040
6041 fn name(&self) -> &str {
6042 "_cypher_split"
6043 }
6044
6045 fn signature(&self) -> &Signature {
6046 &self.signature
6047 }
6048
6049 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6050 Ok(DataType::LargeBinary)
6051 }
6052
6053 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6054 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6055 if vals.len() != 2 {
6056 return Err(datafusion::error::DataFusionError::Execution(
6057 "_cypher_split(): requires exactly 2 arguments".to_string(),
6058 ));
6059 }
6060 if vals.iter().any(|v| v.is_null()) {
6062 return Ok(Value::Null);
6063 }
6064 let s = match &vals[0] {
6065 Value::String(s) => s.clone(),
6066 other => {
6067 return Err(datafusion::error::DataFusionError::Execution(format!(
6068 "_cypher_split(): first argument must be a string, got {:?}",
6069 other
6070 )));
6071 }
6072 };
6073 let delimiter = match &vals[1] {
6074 Value::String(d) => d.clone(),
6075 other => {
6076 return Err(datafusion::error::DataFusionError::Execution(format!(
6077 "_cypher_split(): second argument must be a string, got {:?}",
6078 other
6079 )));
6080 }
6081 };
6082 let parts: Vec<Value> = s
6083 .split(&delimiter)
6084 .map(|p| Value::String(p.to_string()))
6085 .collect();
6086 Ok(Value::List(parts))
6087 })
6088 }
6089}
6090
6091pub fn create_cypher_list_to_cv_udf() -> ScalarUDF {
6102 ScalarUDF::new_from_impl(CypherListToCvUdf::new())
6103}
6104
6105#[derive(Debug)]
6106struct CypherListToCvUdf {
6107 signature: Signature,
6108}
6109
6110impl CypherListToCvUdf {
6111 fn new() -> Self {
6112 Self {
6113 signature: Signature::any(1, Volatility::Immutable),
6114 }
6115 }
6116}
6117
6118impl_udf_eq_hash!(CypherListToCvUdf);
6119
6120impl ScalarUDFImpl for CypherListToCvUdf {
6121 fn as_any(&self) -> &dyn Any {
6122 self
6123 }
6124
6125 fn name(&self) -> &str {
6126 "_cypher_list_to_cv"
6127 }
6128
6129 fn signature(&self) -> &Signature {
6130 &self.signature
6131 }
6132
6133 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6134 Ok(DataType::LargeBinary)
6135 }
6136
6137 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6138 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6139 if vals.len() != 1 {
6140 return Err(datafusion::error::DataFusionError::Execution(
6141 "_cypher_list_to_cv(): requires exactly 1 argument".to_string(),
6142 ));
6143 }
6144 Ok(vals[0].clone())
6145 })
6146 }
6147}
6148
6149pub fn create_cypher_scalar_to_cv_udf() -> ScalarUDF {
6160 ScalarUDF::new_from_impl(CypherScalarToCvUdf::new())
6161}
6162
6163#[derive(Debug)]
6164struct CypherScalarToCvUdf {
6165 signature: Signature,
6166}
6167
6168impl CypherScalarToCvUdf {
6169 fn new() -> Self {
6170 Self {
6171 signature: Signature::any(1, Volatility::Immutable),
6172 }
6173 }
6174}
6175
6176impl_udf_eq_hash!(CypherScalarToCvUdf);
6177
6178impl ScalarUDFImpl for CypherScalarToCvUdf {
6179 fn as_any(&self) -> &dyn Any {
6180 self
6181 }
6182
6183 fn name(&self) -> &str {
6184 "_cypher_scalar_to_cv"
6185 }
6186
6187 fn signature(&self) -> &Signature {
6188 &self.signature
6189 }
6190
6191 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6192 Ok(DataType::LargeBinary)
6193 }
6194
6195 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6196 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6197 if vals.len() != 1 {
6198 return Err(datafusion::error::DataFusionError::Execution(
6199 "_cypher_scalar_to_cv(): requires exactly 1 argument".to_string(),
6200 ));
6201 }
6202 Ok(vals[0].clone())
6203 })
6204 }
6205}
6206
6207pub fn create_cypher_tail_udf() -> ScalarUDF {
6219 ScalarUDF::new_from_impl(CypherTailUdf::new())
6220}
6221
6222#[derive(Debug)]
6223struct CypherTailUdf {
6224 signature: Signature,
6225}
6226
6227impl CypherTailUdf {
6228 fn new() -> Self {
6229 Self {
6230 signature: Signature::any(1, Volatility::Immutable),
6231 }
6232 }
6233}
6234
6235impl_udf_eq_hash!(CypherTailUdf);
6236
6237impl ScalarUDFImpl for CypherTailUdf {
6238 fn as_any(&self) -> &dyn Any {
6239 self
6240 }
6241
6242 fn name(&self) -> &str {
6243 "_cypher_tail"
6244 }
6245
6246 fn signature(&self) -> &Signature {
6247 &self.signature
6248 }
6249
6250 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6251 Ok(DataType::LargeBinary)
6252 }
6253
6254 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6255 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6256 if vals.len() != 1 {
6257 return Err(datafusion::error::DataFusionError::Execution(
6258 "_cypher_tail(): requires exactly 1 argument".to_string(),
6259 ));
6260 }
6261 match &vals[0] {
6262 Value::Null => Ok(Value::Null),
6263 Value::List(l) => {
6264 if l.is_empty() {
6265 Ok(Value::List(vec![]))
6266 } else {
6267 Ok(Value::List(l[1..].to_vec()))
6268 }
6269 }
6270 other => Err(datafusion::error::DataFusionError::Execution(format!(
6271 "_cypher_tail(): expected list, got {:?}",
6272 other
6273 ))),
6274 }
6275 })
6276 }
6277}
6278
6279pub fn create_cypher_head_udf() -> ScalarUDF {
6290 ScalarUDF::new_from_impl(CypherHeadUdf::new())
6291}
6292
6293#[derive(Debug)]
6294struct CypherHeadUdf {
6295 signature: Signature,
6296}
6297
6298impl CypherHeadUdf {
6299 fn new() -> Self {
6300 Self {
6301 signature: Signature::any(1, Volatility::Immutable),
6302 }
6303 }
6304}
6305
6306impl_udf_eq_hash!(CypherHeadUdf);
6307
6308impl ScalarUDFImpl for CypherHeadUdf {
6309 fn as_any(&self) -> &dyn Any {
6310 self
6311 }
6312
6313 fn name(&self) -> &str {
6314 "head"
6315 }
6316
6317 fn signature(&self) -> &Signature {
6318 &self.signature
6319 }
6320
6321 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6322 Ok(DataType::LargeBinary)
6323 }
6324
6325 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6326 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6327 if vals.len() != 1 {
6328 return Err(datafusion::error::DataFusionError::Execution(
6329 "head(): requires exactly 1 argument".to_string(),
6330 ));
6331 }
6332 match &vals[0] {
6333 Value::Null => Ok(Value::Null),
6334 Value::List(l) => Ok(l.first().cloned().unwrap_or(Value::Null)),
6335 other => Err(datafusion::error::DataFusionError::Execution(format!(
6336 "head(): expected list, got {:?}",
6337 other
6338 ))),
6339 }
6340 })
6341 }
6342}
6343
6344pub fn create_cypher_last_udf() -> ScalarUDF {
6355 ScalarUDF::new_from_impl(CypherLastUdf::new())
6356}
6357
6358#[derive(Debug)]
6359struct CypherLastUdf {
6360 signature: Signature,
6361}
6362
6363impl CypherLastUdf {
6364 fn new() -> Self {
6365 Self {
6366 signature: Signature::any(1, Volatility::Immutable),
6367 }
6368 }
6369}
6370
6371impl_udf_eq_hash!(CypherLastUdf);
6372
6373impl ScalarUDFImpl for CypherLastUdf {
6374 fn as_any(&self) -> &dyn Any {
6375 self
6376 }
6377
6378 fn name(&self) -> &str {
6379 "last"
6380 }
6381
6382 fn signature(&self) -> &Signature {
6383 &self.signature
6384 }
6385
6386 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6387 Ok(DataType::LargeBinary)
6388 }
6389
6390 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6391 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6392 if vals.len() != 1 {
6393 return Err(datafusion::error::DataFusionError::Execution(
6394 "last(): requires exactly 1 argument".to_string(),
6395 ));
6396 }
6397 match &vals[0] {
6398 Value::Null => Ok(Value::Null),
6399 Value::List(l) => Ok(l.last().cloned().unwrap_or(Value::Null)),
6400 other => Err(datafusion::error::DataFusionError::Execution(format!(
6401 "last(): expected list, got {:?}",
6402 other
6403 ))),
6404 }
6405 })
6406 }
6407}
6408
6409fn cypher_list_cmp(left: &[Value], right: &[Value]) -> Option<std::cmp::Ordering> {
6412 let min_len = left.len().min(right.len());
6413 for i in 0..min_len {
6414 let cmp = cypher_value_cmp(&left[i], &right[i])?;
6415 if cmp != std::cmp::Ordering::Equal {
6416 return Some(cmp);
6417 }
6418 }
6419 Some(left.len().cmp(&right.len()))
6421}
6422
6423fn cypher_value_cmp(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
6426 match (a, b) {
6427 (Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
6428 (Value::Null, _) | (_, Value::Null) => None,
6429 (Value::Int(l), Value::Int(r)) => Some(l.cmp(r)),
6430 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r),
6431 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r),
6432 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)),
6433 (Value::String(l), Value::String(r)) => Some(l.cmp(r)),
6434 (Value::Bool(l), Value::Bool(r)) => Some(l.cmp(r)),
6435 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r),
6436 _ => None, }
6438}
6439
6440struct CypherToFloat64Udf {
6448 signature: Signature,
6449}
6450
6451impl CypherToFloat64Udf {
6452 fn new() -> Self {
6453 Self {
6454 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6455 }
6456 }
6457}
6458
6459impl_udf_eq_hash!(CypherToFloat64Udf);
6460
6461impl std::fmt::Debug for CypherToFloat64Udf {
6462 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6463 f.debug_struct("CypherToFloat64Udf").finish()
6464 }
6465}
6466
6467impl ScalarUDFImpl for CypherToFloat64Udf {
6468 fn as_any(&self) -> &dyn Any {
6469 self
6470 }
6471 fn name(&self) -> &str {
6472 "_cypher_to_float64"
6473 }
6474 fn signature(&self) -> &Signature {
6475 &self.signature
6476 }
6477 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6478 Ok(DataType::Float64)
6479 }
6480 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6481 if args.args.len() != 1 {
6482 return Err(datafusion::error::DataFusionError::Execution(
6483 "_cypher_to_float64 requires exactly 1 argument".into(),
6484 ));
6485 }
6486 match &args.args[0] {
6487 ColumnarValue::Scalar(scalar) => {
6488 let f = match scalar {
6489 ScalarValue::LargeBinary(Some(bytes)) => cv_bytes_as_f64(bytes),
6490 ScalarValue::Int64(Some(i)) => Some(*i as f64),
6491 ScalarValue::Int32(Some(i)) => Some(*i as f64),
6492 ScalarValue::Float64(Some(f)) => Some(*f),
6493 ScalarValue::Float32(Some(f)) => Some(*f as f64),
6494 _ => None,
6495 };
6496 Ok(ColumnarValue::Scalar(ScalarValue::Float64(f)))
6497 }
6498 ColumnarValue::Array(arr) => {
6499 let len = arr.len();
6500 let mut builder = arrow::array::Float64Builder::with_capacity(len);
6501 match arr.data_type() {
6502 DataType::LargeBinary => {
6503 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6504 for i in 0..len {
6505 if lb.is_null(i) {
6506 builder.append_null();
6507 } else {
6508 match cv_bytes_as_f64(lb.value(i)) {
6509 Some(f) => builder.append_value(f),
6510 None => builder.append_null(),
6511 }
6512 }
6513 }
6514 }
6515 DataType::Int64 => {
6516 let int_arr = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6517 for i in 0..len {
6518 if int_arr.is_null(i) {
6519 builder.append_null();
6520 } else {
6521 builder.append_value(int_arr.value(i) as f64);
6522 }
6523 }
6524 }
6525 DataType::Float64 => {
6526 let f_arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6527 for i in 0..len {
6528 if f_arr.is_null(i) {
6529 builder.append_null();
6530 } else {
6531 builder.append_value(f_arr.value(i));
6532 }
6533 }
6534 }
6535 _ => {
6536 for _ in 0..len {
6537 builder.append_null();
6538 }
6539 }
6540 }
6541 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
6542 }
6543 }
6544 }
6545}
6546
6547fn create_cypher_to_float64_udf() -> ScalarUDF {
6548 ScalarUDF::from(CypherToFloat64Udf::new())
6549}
6550
6551pub fn cypher_to_float64_expr(
6553 arg: datafusion::logical_expr::Expr,
6554) -> datafusion::logical_expr::Expr {
6555 datafusion::logical_expr::Expr::ScalarFunction(
6556 datafusion::logical_expr::expr::ScalarFunction::new_udf(
6557 Arc::new(create_cypher_to_float64_udf()),
6558 vec![arg],
6559 ),
6560 )
6561}
6562
6563pub fn cypher_to_float64_udf() -> datafusion::logical_expr::ScalarUDF {
6565 create_cypher_to_float64_udf()
6566}
6567
6568fn cypher_type_rank(val: &Value) -> u8 {
6576 match val {
6577 Value::Null => 0,
6578 Value::List(_) => 1,
6579 Value::String(_) => 2,
6580 Value::Bool(_) => 3,
6581 Value::Int(_) | Value::Float(_) => 4,
6582 _ => 5, }
6584}
6585
6586fn cypher_cross_type_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
6589 use std::cmp::Ordering;
6590 let ra = cypher_type_rank(a);
6591 let rb = cypher_type_rank(b);
6592 if ra != rb {
6593 return ra.cmp(&rb);
6594 }
6595 match (a, b) {
6597 (Value::Int(l), Value::Int(r)) => l.cmp(r),
6598 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
6599 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r).unwrap_or(Ordering::Equal),
6600 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)).unwrap_or(Ordering::Equal),
6601 (Value::String(l), Value::String(r)) => l.cmp(r),
6602 (Value::Bool(l), Value::Bool(r)) => l.cmp(r),
6603 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r).unwrap_or(Ordering::Equal),
6604 _ => Ordering::Equal,
6605 }
6606}
6607
6608fn scalar_binary_to_value(bytes: &[u8]) -> Value {
6610 uni_common::cypher_value_codec::decode(bytes).unwrap_or(Value::Null)
6611}
6612
6613use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, AggregateUDFImpl};
6614
6615#[derive(Debug, Clone)]
6617struct CypherMinMaxUdaf {
6618 name: String,
6619 signature: Signature,
6620 is_max: bool,
6621}
6622
6623impl CypherMinMaxUdaf {
6624 fn new(is_max: bool) -> Self {
6625 let name = if is_max { "_cypher_max" } else { "_cypher_min" };
6626 Self {
6627 name: name.to_string(),
6628 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6629 is_max,
6630 }
6631 }
6632}
6633
6634impl PartialEq for CypherMinMaxUdaf {
6635 fn eq(&self, other: &Self) -> bool {
6636 self.name == other.name
6637 }
6638}
6639
6640impl Eq for CypherMinMaxUdaf {}
6641
6642impl Hash for CypherMinMaxUdaf {
6643 fn hash<H: Hasher>(&self, state: &mut H) {
6644 self.name.hash(state);
6645 }
6646}
6647
6648impl AggregateUDFImpl for CypherMinMaxUdaf {
6649 fn as_any(&self) -> &dyn Any {
6650 self
6651 }
6652 fn name(&self) -> &str {
6653 &self.name
6654 }
6655 fn signature(&self) -> &Signature {
6656 &self.signature
6657 }
6658 fn return_type(&self, args: &[DataType]) -> DFResult<DataType> {
6659 Ok(args.first().cloned().unwrap_or(DataType::LargeBinary))
6661 }
6662 fn accumulator(
6663 &self,
6664 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6665 ) -> DFResult<Box<dyn DfAccumulator>> {
6666 let raw_bytes = acc_args
6672 .expr_fields
6673 .first()
6674 .and_then(|field| field.metadata().get("uni_raw_bytes"))
6675 .is_some_and(|v| v == "true");
6676 Ok(Box::new(CypherMinMaxAccumulator {
6677 current: None,
6678 is_max: self.is_max,
6679 return_type: acc_args.return_field.data_type().clone(),
6680 raw_bytes,
6681 }))
6682 }
6683 fn state_fields(
6684 &self,
6685 args: datafusion::logical_expr::function::StateFieldsArgs,
6686 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6687 Ok(vec![Arc::new(arrow::datatypes::Field::new(
6688 args.name,
6689 DataType::LargeBinary,
6690 true,
6691 ))])
6692 }
6693}
6694
6695#[derive(Debug)]
6696struct CypherMinMaxAccumulator {
6697 current: Option<Value>,
6698 is_max: bool,
6699 return_type: DataType,
6700 raw_bytes: bool,
6702}
6703
6704impl CypherMinMaxAccumulator {
6705 fn accumulate(&mut self, val: Value) {
6707 if val.is_null() {
6708 return;
6709 }
6710 self.current = Some(match self.current.take() {
6711 None => val,
6712 Some(cur) => {
6713 let ord = cypher_cross_type_cmp(&val, &cur);
6714 if (self.is_max && ord == std::cmp::Ordering::Greater)
6715 || (!self.is_max && ord == std::cmp::Ordering::Less)
6716 {
6717 val
6718 } else {
6719 cur
6720 }
6721 }
6722 });
6723 }
6724}
6725
6726impl DfAccumulator for CypherMinMaxAccumulator {
6727 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6728 let arr = &values[0];
6729 match arr.data_type() {
6730 DataType::LargeBinary => {
6731 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6732 for i in 0..lb.len() {
6733 if lb.is_null(i) {
6734 continue;
6735 }
6736 let val = if self.raw_bytes {
6739 Value::Bytes(lb.value(i).to_vec())
6740 } else {
6741 scalar_binary_to_value(lb.value(i))
6742 };
6743 self.accumulate(val);
6744 }
6745 }
6746 _ => {
6747 for i in 0..arr.len() {
6749 if arr.is_null(i) {
6750 continue;
6751 }
6752 let sv = ScalarValue::try_from_array(arr, i).map_err(|e| {
6753 datafusion::error::DataFusionError::Execution(e.to_string())
6754 })?;
6755 self.accumulate(scalar_to_value(&sv)?);
6756 }
6757 }
6758 }
6759 Ok(())
6760 }
6761 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6762 match &self.current {
6763 None => {
6764 ScalarValue::try_from(&self.return_type)
6766 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
6767 }
6768 Some(val) => {
6769 if matches!(self.return_type, DataType::LargeBinary) {
6771 let bytes = uni_common::cypher_value_codec::encode(val);
6772 return Ok(ScalarValue::LargeBinary(Some(bytes)));
6773 }
6774 match val {
6776 Value::Int(i) => match &self.return_type {
6777 DataType::Int64 => Ok(ScalarValue::Int64(Some(*i))),
6778 DataType::UInt64 => Ok(ScalarValue::UInt64(Some(*i as u64))),
6779 _ => {
6780 let bytes = uni_common::cypher_value_codec::encode(val);
6781 Ok(ScalarValue::LargeBinary(Some(bytes)))
6782 }
6783 },
6784 Value::Float(f) => match &self.return_type {
6785 DataType::Float64 => Ok(ScalarValue::Float64(Some(*f))),
6786 _ => {
6787 let bytes = uni_common::cypher_value_codec::encode(val);
6788 Ok(ScalarValue::LargeBinary(Some(bytes)))
6789 }
6790 },
6791 Value::String(s) => match &self.return_type {
6792 DataType::Utf8 => Ok(ScalarValue::Utf8(Some(s.clone()))),
6793 DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
6794 _ => {
6795 let bytes = uni_common::cypher_value_codec::encode(val);
6796 Ok(ScalarValue::LargeBinary(Some(bytes)))
6797 }
6798 },
6799 Value::Bool(b) => match &self.return_type {
6800 DataType::Boolean => Ok(ScalarValue::Boolean(Some(*b))),
6801 _ => {
6802 let bytes = uni_common::cypher_value_codec::encode(val);
6803 Ok(ScalarValue::LargeBinary(Some(bytes)))
6804 }
6805 },
6806 _ => {
6807 let bytes = uni_common::cypher_value_codec::encode(val);
6809 Ok(ScalarValue::LargeBinary(Some(bytes)))
6810 }
6811 }
6812 }
6813 }
6814 }
6815 fn size(&self) -> usize {
6816 std::mem::size_of_val(self) + self.current.as_ref().map_or(0, |_| 64)
6817 }
6818 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6819 Ok(vec![self.evaluate()?])
6820 }
6821 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6822 let arr = &states[0];
6826 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
6827 for i in 0..lb.len() {
6828 if lb.is_null(i) {
6829 continue;
6830 }
6831 self.accumulate(scalar_binary_to_value(lb.value(i)));
6832 }
6833 }
6834 Ok(())
6835 }
6836}
6837
6838pub fn create_cypher_min_udaf() -> AggregateUDF {
6839 AggregateUDF::from(CypherMinMaxUdaf::new(false))
6840}
6841
6842pub fn create_cypher_max_udaf() -> AggregateUDF {
6843 AggregateUDF::from(CypherMinMaxUdaf::new(true))
6844}
6845
6846#[derive(Debug, Clone)]
6852struct CypherSumUdaf {
6853 signature: Signature,
6854}
6855
6856impl CypherSumUdaf {
6857 fn new() -> Self {
6858 Self {
6859 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6860 }
6861 }
6862}
6863
6864impl PartialEq for CypherSumUdaf {
6865 fn eq(&self, other: &Self) -> bool {
6866 self.signature == other.signature
6867 }
6868}
6869
6870impl Eq for CypherSumUdaf {}
6871
6872impl Hash for CypherSumUdaf {
6873 fn hash<H: Hasher>(&self, state: &mut H) {
6874 self.name().hash(state);
6875 }
6876}
6877
6878impl AggregateUDFImpl for CypherSumUdaf {
6879 fn as_any(&self) -> &dyn Any {
6880 self
6881 }
6882 fn name(&self) -> &str {
6883 "_cypher_sum"
6884 }
6885 fn signature(&self) -> &Signature {
6886 &self.signature
6887 }
6888 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6889 Ok(DataType::LargeBinary)
6892 }
6893 fn accumulator(
6894 &self,
6895 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6896 ) -> DFResult<Box<dyn DfAccumulator>> {
6897 Ok(Box::new(CypherSumAccumulator {
6898 sum: 0.0,
6899 all_ints: true,
6900 int_sum: 0i64,
6901 has_value: false,
6902 }))
6903 }
6904 fn state_fields(
6905 &self,
6906 args: datafusion::logical_expr::function::StateFieldsArgs,
6907 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6908 Ok(vec![
6909 Arc::new(arrow::datatypes::Field::new(
6910 format!("{}_sum", args.name),
6911 DataType::Float64,
6912 true,
6913 )),
6914 Arc::new(arrow::datatypes::Field::new(
6915 format!("{}_int_sum", args.name),
6916 DataType::Int64,
6917 true,
6918 )),
6919 Arc::new(arrow::datatypes::Field::new(
6920 format!("{}_all_ints", args.name),
6921 DataType::Boolean,
6922 true,
6923 )),
6924 Arc::new(arrow::datatypes::Field::new(
6925 format!("{}_has_value", args.name),
6926 DataType::Boolean,
6927 true,
6928 )),
6929 ])
6930 }
6931}
6932
/// Running state for the `_cypher_sum` aggregate.
#[derive(Debug)]
struct CypherSumAccumulator {
    /// Sum of every term as `f64` (integer terms are converted); used for the
    /// result once any float term has been seen.
    sum: f64,
    /// Stays `true` until a float term is folded in; selects between
    /// `int_sum` and `sum` at evaluation time.
    all_ints: bool,
    /// Exact integer sum; accumulated with `wrapping_add`, so it wraps on
    /// `i64` overflow.
    int_sum: i64,
    /// Whether any numeric term was seen; when `false` the result is NULL.
    has_value: bool,
}
6940
6941impl DfAccumulator for CypherSumAccumulator {
6942 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6943 let arr = &values[0];
6944 for i in 0..arr.len() {
6945 if arr.is_null(i) {
6946 continue;
6947 }
6948 match arr.data_type() {
6949 DataType::LargeBinary => {
6950 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6951 let bytes = lb.value(i);
6952 use uni_common::cypher_value_codec::{
6953 TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag,
6954 };
6955 match peek_tag(bytes) {
6956 Some(TAG_INT) => {
6957 if let Some(v) = decode_int(bytes) {
6958 self.sum += v as f64;
6959 self.int_sum = self.int_sum.wrapping_add(v);
6960 self.has_value = true;
6961 }
6962 }
6963 Some(TAG_FLOAT) => {
6964 if let Some(v) = decode_float(bytes) {
6965 self.sum += v;
6966 self.all_ints = false;
6967 self.has_value = true;
6968 }
6969 }
6970 _ => {} }
6972 }
6973 DataType::Int64 => {
6974 let a = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6975 let v = a.value(i);
6976 self.sum += v as f64;
6977 self.int_sum = self.int_sum.wrapping_add(v);
6978 self.has_value = true;
6979 }
6980 DataType::Float64 => {
6981 let a = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6982 self.sum += a.value(i);
6983 self.all_ints = false;
6984 self.has_value = true;
6985 }
6986 _ => {}
6987 }
6988 }
6989 Ok(())
6990 }
6991 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6992 if !self.has_value {
6993 return Ok(ScalarValue::LargeBinary(None));
6994 }
6995 let val = if self.all_ints {
6996 Value::Int(self.int_sum)
6997 } else {
6998 Value::Float(self.sum)
6999 };
7000 let bytes = uni_common::cypher_value_codec::encode(&val);
7001 Ok(ScalarValue::LargeBinary(Some(bytes)))
7002 }
7003 fn size(&self) -> usize {
7004 std::mem::size_of_val(self)
7005 }
7006 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7007 Ok(vec![
7008 ScalarValue::Float64(Some(self.sum)),
7009 ScalarValue::Int64(Some(self.int_sum)),
7010 ScalarValue::Boolean(Some(self.all_ints)),
7011 ScalarValue::Boolean(Some(self.has_value)),
7012 ])
7013 }
7014 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7015 let sum_arr = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
7016 let int_sum_arr = states[1].as_any().downcast_ref::<Int64Array>().unwrap();
7017 let all_ints_arr = states[2].as_any().downcast_ref::<BooleanArray>().unwrap();
7018 let has_value_arr = states[3].as_any().downcast_ref::<BooleanArray>().unwrap();
7019 for i in 0..sum_arr.len() {
7020 if !has_value_arr.is_null(i) && has_value_arr.value(i) {
7021 self.sum += sum_arr.value(i);
7022 self.int_sum = self.int_sum.wrapping_add(int_sum_arr.value(i));
7023 if !all_ints_arr.value(i) {
7024 self.all_ints = false;
7025 }
7026 self.has_value = true;
7027 }
7028 }
7029 Ok(())
7030 }
7031}
7032
7033pub fn create_cypher_sum_udaf() -> AggregateUDF {
7034 AggregateUDF::from(CypherSumUdaf::new())
7035}
7036
7037#[derive(Debug, Clone)]
7044struct CypherCollectUdaf {
7045 signature: Signature,
7046}
7047
7048impl CypherCollectUdaf {
7049 fn new() -> Self {
7050 Self {
7051 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
7052 }
7053 }
7054}
7055
7056impl PartialEq for CypherCollectUdaf {
7057 fn eq(&self, other: &Self) -> bool {
7058 self.signature == other.signature
7059 }
7060}
7061
7062impl Eq for CypherCollectUdaf {}
7063
7064impl Hash for CypherCollectUdaf {
7065 fn hash<H: Hasher>(&self, state: &mut H) {
7066 self.name().hash(state);
7067 }
7068}
7069
7070impl AggregateUDFImpl for CypherCollectUdaf {
7071 fn as_any(&self) -> &dyn Any {
7072 self
7073 }
7074 fn name(&self) -> &str {
7075 "_cypher_collect"
7076 }
7077 fn signature(&self) -> &Signature {
7078 &self.signature
7079 }
7080 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7081 Ok(DataType::LargeBinary)
7082 }
7083 fn accumulator(
7084 &self,
7085 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7086 ) -> DFResult<Box<dyn DfAccumulator>> {
7087 let raw_bytes = acc_args
7094 .expr_fields
7095 .first()
7096 .and_then(|field| field.metadata().get("uni_raw_bytes"))
7097 .is_some_and(|v| v == "true");
7098 Ok(Box::new(CypherCollectAccumulator {
7099 values: Vec::new(),
7100 distinct: acc_args.is_distinct,
7101 raw_bytes,
7102 }))
7103 }
7104 fn state_fields(
7105 &self,
7106 args: datafusion::logical_expr::function::StateFieldsArgs,
7107 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7108 Ok(vec![Arc::new(arrow::datatypes::Field::new(
7109 args.name,
7110 DataType::LargeBinary,
7111 true,
7112 ))])
7113 }
7114}
7115
7116#[derive(Debug)]
7117struct CypherCollectAccumulator {
7118 values: Vec<Value>,
7119 distinct: bool,
7120 raw_bytes: bool,
7123}
7124
7125impl CypherCollectAccumulator {
7126 fn push_value(&mut self, val: Value) {
7128 if self.distinct {
7129 let repr = val.to_string();
7131 if self.values.iter().any(|v| v.to_string() == repr) {
7132 return;
7133 }
7134 }
7135 self.values.push(val);
7136 }
7137}
7138
7139impl DfAccumulator for CypherCollectAccumulator {
7140 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7141 let arr = &values[0];
7142 if self.raw_bytes
7145 && let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>()
7146 {
7147 for i in 0..lb.len() {
7148 if lb.is_null(i) {
7149 continue;
7150 }
7151 self.push_value(Value::Bytes(lb.value(i).to_vec()));
7152 }
7153 return Ok(());
7154 }
7155 for i in 0..arr.len() {
7156 if arr.is_null(i) {
7157 continue;
7158 }
7159 if let Some(struct_arr) = arr.as_any().downcast_ref::<arrow::array::StructArray>()
7163 && struct_arr.num_columns() > 0
7164 && struct_arr.column(0).is_null(i)
7165 {
7166 continue;
7167 }
7168 let sv = ScalarValue::try_from_array(arr, i)
7169 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
7170 let val = scalar_to_value(&sv)?;
7171 if val.is_null() {
7172 continue;
7173 }
7174 self.push_value(val);
7175 }
7176 Ok(())
7177 }
7178 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7179 let val = Value::List(self.values.clone());
7181 let bytes = uni_common::cypher_value_codec::encode(&val);
7182 Ok(ScalarValue::LargeBinary(Some(bytes)))
7183 }
7184 fn size(&self) -> usize {
7185 std::mem::size_of_val(self) + self.values.len() * 64
7186 }
7187 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7188 Ok(vec![self.evaluate()?])
7189 }
7190 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7191 let arr = &states[0];
7193 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
7194 for i in 0..lb.len() {
7195 if lb.is_null(i) {
7196 continue;
7197 }
7198 let val = scalar_binary_to_value(lb.value(i));
7199 if let Value::List(items) = val {
7200 for item in items {
7201 if !item.is_null() {
7202 self.push_value(item);
7203 }
7204 }
7205 }
7206 }
7207 }
7208 Ok(())
7209 }
7210}
7211
7212pub fn create_cypher_collect_udaf() -> AggregateUDF {
7213 AggregateUDF::from(CypherCollectUdaf::new())
7214}
7215
7216pub fn create_cypher_collect_expr(
7218 arg: datafusion::logical_expr::Expr,
7219 distinct: bool,
7220) -> datafusion::logical_expr::Expr {
7221 let udaf = Arc::new(create_cypher_collect_udaf());
7224 if distinct {
7225 datafusion::logical_expr::Expr::AggregateFunction(
7227 datafusion::logical_expr::expr::AggregateFunction::new_udf(
7228 udaf,
7229 vec![arg],
7230 true, None,
7232 vec![],
7233 None,
7234 ),
7235 )
7236 } else {
7237 udaf.call(vec![arg])
7238 }
7239}
7240
7241#[derive(Debug, Clone)]
7247struct CypherPercentileDiscUdaf {
7248 signature: Signature,
7249}
7250
7251impl CypherPercentileDiscUdaf {
7252 fn new() -> Self {
7253 Self {
7254 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7255 }
7256 }
7257}
7258
7259impl PartialEq for CypherPercentileDiscUdaf {
7260 fn eq(&self, other: &Self) -> bool {
7261 self.signature == other.signature
7262 }
7263}
7264
7265impl Eq for CypherPercentileDiscUdaf {}
7266
7267impl Hash for CypherPercentileDiscUdaf {
7268 fn hash<H: Hasher>(&self, state: &mut H) {
7269 self.name().hash(state);
7270 }
7271}
7272
7273impl AggregateUDFImpl for CypherPercentileDiscUdaf {
7274 fn as_any(&self) -> &dyn Any {
7275 self
7276 }
7277 fn name(&self) -> &str {
7278 "percentiledisc"
7279 }
7280 fn signature(&self) -> &Signature {
7281 &self.signature
7282 }
7283 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7284 Ok(DataType::Float64)
7285 }
7286 fn accumulator(
7287 &self,
7288 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7289 ) -> DFResult<Box<dyn DfAccumulator>> {
7290 Ok(Box::new(CypherPercentileDiscAccumulator {
7291 values: Vec::new(),
7292 percentile: None,
7293 }))
7294 }
7295 fn state_fields(
7296 &self,
7297 args: datafusion::logical_expr::function::StateFieldsArgs,
7298 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7299 Ok(vec![
7300 Arc::new(arrow::datatypes::Field::new(
7301 format!("{}_values", args.name),
7302 DataType::List(Arc::new(arrow::datatypes::Field::new(
7303 "item",
7304 DataType::Float64,
7305 true,
7306 ))),
7307 true,
7308 )),
7309 Arc::new(arrow::datatypes::Field::new(
7310 format!("{}_percentile", args.name),
7311 DataType::Float64,
7312 true,
7313 )),
7314 ])
7315 }
7316}
7317
/// Running state for `percentileDisc`.
#[derive(Debug)]
struct CypherPercentileDiscAccumulator {
    /// Every non-null numeric input, converted to `f64`.
    values: Vec<f64>,
    /// Percentile taken from the first row that supplies one; `None` until then.
    percentile: Option<f64>,
}
7323
7324impl CypherPercentileDiscAccumulator {
7325 fn extract_f64(arr: &ArrayRef, i: usize) -> Option<f64> {
7326 if arr.is_null(i) {
7327 return None;
7328 }
7329 match arr.data_type() {
7330 DataType::LargeBinary => {
7331 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
7332 cv_bytes_as_f64(lb.value(i))
7333 }
7334 DataType::Int64 => {
7335 let a = arr.as_any().downcast_ref::<Int64Array>()?;
7336 Some(a.value(i) as f64)
7337 }
7338 DataType::Float64 => {
7339 let a = arr.as_any().downcast_ref::<Float64Array>()?;
7340 Some(a.value(i))
7341 }
7342 DataType::Int32 => {
7343 let a = arr.as_any().downcast_ref::<Int32Array>()?;
7344 Some(a.value(i) as f64)
7345 }
7346 DataType::Float32 => {
7347 let a = arr.as_any().downcast_ref::<Float32Array>()?;
7348 Some(a.value(i) as f64)
7349 }
7350 _ => None,
7351 }
7352 }
7353
7354 fn extract_percentile(arr: &ArrayRef, i: usize) -> Option<f64> {
7355 if arr.is_null(i) {
7356 return None;
7357 }
7358 match arr.data_type() {
7359 DataType::Float64 => {
7360 let a = arr.as_any().downcast_ref::<Float64Array>()?;
7361 Some(a.value(i))
7362 }
7363 DataType::Int64 => {
7364 let a = arr.as_any().downcast_ref::<Int64Array>()?;
7365 Some(a.value(i) as f64)
7366 }
7367 DataType::LargeBinary => {
7368 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
7369 cv_bytes_as_f64(lb.value(i))
7370 }
7371 _ => None,
7372 }
7373 }
7374}
7375
impl DfAccumulator for CypherPercentileDiscAccumulator {
    /// Buffers numeric inputs and captures the percentile argument.
    ///
    /// The percentile is read from the first row that supplies a non-null
    /// one, and must lie in `[0.0, 1.0]`; otherwise an `ArgumentError` is
    /// returned.
    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
        let expr_arr = &values[0];
        let pct_arr = &values[1];
        for i in 0..expr_arr.len() {
            // Only the first supplied percentile is kept; later rows are not re-checked.
            if self.percentile.is_none()
                && let Some(p) = Self::extract_percentile(pct_arr, i)
            {
                if !(0.0..=1.0).contains(&p) {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
                    ));
                }
                self.percentile = Some(p);
            }
            if let Some(f) = Self::extract_f64(expr_arr, i) {
                self.values.push(f);
            }
        }
        Ok(())
    }
    /// Picks one buffered value at the requested percentile.
    ///
    /// A missing percentile is treated as `0.0`. Returns a NULL `Float64`
    /// when no values were buffered. Sorts `self.values` in place (NaNs
    /// compare as equal).
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        // Re-validate here because a percentile may arrive via `merge_batch`,
        // which does not check the range.
        let pct = match self.percentile {
            Some(p) if !(0.0..=1.0).contains(&p) => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
                ));
            }
            Some(p) => p,
            None => 0.0,
        };
        if self.values.is_empty() {
            return Ok(ScalarValue::Float64(None));
        }
        self.values
            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let n = self.values.len();
        // Index = round(pct * (n - 1)), clamped to the last element.
        // NOTE(review): other Cypher engines may use a different rank rule for
        // discrete percentiles (e.g. on even-sized inputs) — confirm against the TCK.
        let idx = (pct * (n as f64 - 1.0)).round() as usize;
        let idx = idx.min(n - 1);
        let result = self.values[idx];
        Ok(ScalarValue::Float64(Some(result)))
    }
    fn size(&self) -> usize {
        std::mem::size_of_val(self) + self.values.capacity() * 8
    }
    /// Partial state: the buffered values as a `List<Float64>` plus the
    /// (possibly NULL) percentile, matching `state_fields`.
    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
        let list_values: Vec<ScalarValue> = self
            .values
            .iter()
            .map(|f| ScalarValue::Float64(Some(*f)))
            .collect();
        let list_scalar = ScalarValue::List(ScalarValue::new_list(
            &list_values,
            &DataType::Float64,
            true,
        ));
        Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
    }
    /// Merges partial states.
    ///
    /// Adopts the first non-null percentile if none is known yet, then
    /// appends every non-null buffered value. State columns of unexpected
    /// types are ignored.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        let list_arr = &states[0];
        let pct_arr = &states[1];
        if self.percentile.is_none()
            && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
        {
            for i in 0..f64_arr.len() {
                if !f64_arr.is_null(i) {
                    self.percentile = Some(f64_arr.value(i));
                    break;
                }
            }
        }
        if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
            for i in 0..list_array.len() {
                if list_array.is_null(i) {
                    continue;
                }
                let inner = list_array.value(i);
                if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
                    for j in 0..f64_arr.len() {
                        if !f64_arr.is_null(j) {
                            self.values.push(f64_arr.value(j));
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
7470
7471#[derive(Debug, Clone)]
7473struct CypherPercentileContUdaf {
7474 signature: Signature,
7475}
7476
7477impl CypherPercentileContUdaf {
7478 fn new() -> Self {
7479 Self {
7480 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7481 }
7482 }
7483}
7484
7485impl PartialEq for CypherPercentileContUdaf {
7486 fn eq(&self, other: &Self) -> bool {
7487 self.signature == other.signature
7488 }
7489}
7490
7491impl Eq for CypherPercentileContUdaf {}
7492
7493impl Hash for CypherPercentileContUdaf {
7494 fn hash<H: Hasher>(&self, state: &mut H) {
7495 self.name().hash(state);
7496 }
7497}
7498
7499impl AggregateUDFImpl for CypherPercentileContUdaf {
7500 fn as_any(&self) -> &dyn Any {
7501 self
7502 }
7503 fn name(&self) -> &str {
7504 "percentilecont"
7505 }
7506 fn signature(&self) -> &Signature {
7507 &self.signature
7508 }
7509 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7510 Ok(DataType::Float64)
7511 }
7512 fn accumulator(
7513 &self,
7514 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7515 ) -> DFResult<Box<dyn DfAccumulator>> {
7516 Ok(Box::new(CypherPercentileContAccumulator {
7517 values: Vec::new(),
7518 percentile: None,
7519 }))
7520 }
7521 fn state_fields(
7522 &self,
7523 args: datafusion::logical_expr::function::StateFieldsArgs,
7524 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7525 Ok(vec![
7526 Arc::new(arrow::datatypes::Field::new(
7527 format!("{}_values", args.name),
7528 DataType::List(Arc::new(arrow::datatypes::Field::new(
7529 "item",
7530 DataType::Float64,
7531 true,
7532 ))),
7533 true,
7534 )),
7535 Arc::new(arrow::datatypes::Field::new(
7536 format!("{}_percentile", args.name),
7537 DataType::Float64,
7538 true,
7539 )),
7540 ])
7541 }
7542}
7543
/// Running state for `percentileCont`.
#[derive(Debug)]
struct CypherPercentileContAccumulator {
    /// Every non-null numeric input, converted to `f64`.
    values: Vec<f64>,
    /// Percentile taken from the first row that supplies one; `None` until then.
    percentile: Option<f64>,
}
7549
7550impl DfAccumulator for CypherPercentileContAccumulator {
7551 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
7552 let expr_arr = &values[0];
7553 let pct_arr = &values[1];
7554 for i in 0..expr_arr.len() {
7555 if self.percentile.is_none()
7556 && let Some(p) = CypherPercentileDiscAccumulator::extract_percentile(pct_arr, i)
7557 {
7558 if !(0.0..=1.0).contains(&p) {
7559 return Err(datafusion::error::DataFusionError::Execution(
7560 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
7561 ));
7562 }
7563 self.percentile = Some(p);
7564 }
7565 if let Some(f) = CypherPercentileDiscAccumulator::extract_f64(expr_arr, i) {
7566 self.values.push(f);
7567 }
7568 }
7569 Ok(())
7570 }
7571 fn evaluate(&mut self) -> DFResult<ScalarValue> {
7572 let pct = match self.percentile {
7573 Some(p) if !(0.0..=1.0).contains(&p) => {
7574 return Err(datafusion::error::DataFusionError::Execution(
7575 "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
7576 ));
7577 }
7578 Some(p) => p,
7579 None => 0.0,
7580 };
7581 if self.values.is_empty() {
7582 return Ok(ScalarValue::Float64(None));
7583 }
7584 self.values
7585 .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
7586 let n = self.values.len();
7587 if n == 1 {
7588 return Ok(ScalarValue::Float64(Some(self.values[0])));
7589 }
7590 let pos = pct * (n as f64 - 1.0);
7591 let lower = pos.floor() as usize;
7592 let upper = pos.ceil() as usize;
7593 let lower = lower.min(n - 1);
7594 let upper = upper.min(n - 1);
7595 if lower == upper {
7596 Ok(ScalarValue::Float64(Some(self.values[lower])))
7597 } else {
7598 let frac = pos - lower as f64;
7599 let result = self.values[lower] + frac * (self.values[upper] - self.values[lower]);
7600 Ok(ScalarValue::Float64(Some(result)))
7601 }
7602 }
7603 fn size(&self) -> usize {
7604 std::mem::size_of_val(self) + self.values.capacity() * 8
7605 }
7606 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
7607 let list_values: Vec<ScalarValue> = self
7608 .values
7609 .iter()
7610 .map(|f| ScalarValue::Float64(Some(*f)))
7611 .collect();
7612 let list_scalar = ScalarValue::List(ScalarValue::new_list(
7613 &list_values,
7614 &DataType::Float64,
7615 true,
7616 ));
7617 Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
7618 }
7619 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
7620 let list_arr = &states[0];
7621 let pct_arr = &states[1];
7622 if self.percentile.is_none()
7623 && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
7624 {
7625 for i in 0..f64_arr.len() {
7626 if !f64_arr.is_null(i) {
7627 self.percentile = Some(f64_arr.value(i));
7628 break;
7629 }
7630 }
7631 }
7632 if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
7633 for i in 0..list_array.len() {
7634 if list_array.is_null(i) {
7635 continue;
7636 }
7637 let inner = list_array.value(i);
7638 if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
7639 for j in 0..f64_arr.len() {
7640 if !f64_arr.is_null(j) {
7641 self.values.push(f64_arr.value(j));
7642 }
7643 }
7644 }
7645 }
7646 }
7647 Ok(())
7648 }
7649}
7650
7651pub fn create_cypher_percentile_disc_udaf() -> AggregateUDF {
7652 AggregateUDF::from(CypherPercentileDiscUdaf::new())
7653}
7654
7655pub fn create_cypher_percentile_cont_udaf() -> AggregateUDF {
7656 AggregateUDF::from(CypherPercentileContUdaf::new())
7657}
7658
7659fn invoke_similarity_udf(
7669 func_name: &str,
7670 min_args: usize,
7671 args: ScalarFunctionArgs,
7672) -> DFResult<ColumnarValue> {
7673 let output_type = DataType::Float64;
7674 invoke_cypher_udf(args, &output_type, |val_args| {
7675 if val_args.len() < min_args {
7676 return Err(datafusion::error::DataFusionError::Execution(format!(
7677 "{} requires at least {} arguments",
7678 func_name, min_args
7679 )));
7680 }
7681 crate::similar_to::eval_similar_to_pure(&val_args[0], &val_args[1])
7682 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
7683 })
7684}
7685
7686pub fn create_similar_to_udf() -> ScalarUDF {
7688 ScalarUDF::new_from_impl(SimilarToUdf::new())
7689}
7690
7691#[derive(Debug)]
7692struct SimilarToUdf {
7693 signature: Signature,
7694}
7695
7696impl SimilarToUdf {
7697 fn new() -> Self {
7698 Self {
7699 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
7700 }
7701 }
7702}
7703
7704impl_udf_eq_hash!(SimilarToUdf);
7705
7706impl ScalarUDFImpl for SimilarToUdf {
7707 fn as_any(&self) -> &dyn Any {
7708 self
7709 }
7710
7711 fn name(&self) -> &str {
7712 "similar_to"
7713 }
7714
7715 fn signature(&self) -> &Signature {
7716 &self.signature
7717 }
7718
7719 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7720 Ok(DataType::Float64)
7721 }
7722
7723 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7724 invoke_similarity_udf("similar_to", 2, args)
7725 }
7726}
7727
7728pub fn create_vector_similarity_udf() -> ScalarUDF {
7730 ScalarUDF::new_from_impl(VectorSimilarityUdf::new())
7731}
7732
7733#[derive(Debug)]
7734struct VectorSimilarityUdf {
7735 signature: Signature,
7736}
7737
7738impl VectorSimilarityUdf {
7739 fn new() -> Self {
7740 Self {
7741 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7742 }
7743 }
7744}
7745
7746impl_udf_eq_hash!(VectorSimilarityUdf);
7747
7748impl ScalarUDFImpl for VectorSimilarityUdf {
7749 fn as_any(&self) -> &dyn Any {
7750 self
7751 }
7752
7753 fn name(&self) -> &str {
7754 "vector_similarity"
7755 }
7756
7757 fn signature(&self) -> &Signature {
7758 &self.signature
7759 }
7760
7761 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7762 Ok(DataType::Float64)
7763 }
7764
7765 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7766 invoke_similarity_udf("vector_similarity", 2, args)
7767 }
7768}
7769
7770fn invoke_sparse_similarity_udf(args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7775 let output_type = DataType::Float64;
7776 invoke_cypher_udf(args, &output_type, |val_args| {
7777 if val_args.len() != 2 {
7778 return Err(datafusion::error::DataFusionError::Execution(
7779 "sparse_similar_to takes 2 arguments".to_string(),
7780 ));
7781 }
7782 crate::similar_to::eval_sparse_similar_to_pure(&val_args[0], &val_args[1])
7783 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
7784 })
7785}
7786
7787pub fn create_sparse_similar_to_udf() -> ScalarUDF {
7789 ScalarUDF::new_from_impl(SparseSimilarToUdf::new())
7790}
7791
7792#[derive(Debug)]
7793struct SparseSimilarToUdf {
7794 signature: Signature,
7795}
7796
7797impl SparseSimilarToUdf {
7798 fn new() -> Self {
7799 Self {
7800 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
7801 }
7802 }
7803}
7804
7805impl_udf_eq_hash!(SparseSimilarToUdf);
7806
7807impl ScalarUDFImpl for SparseSimilarToUdf {
7808 fn as_any(&self) -> &dyn Any {
7809 self
7810 }
7811
7812 fn name(&self) -> &str {
7813 "sparse_similar_to"
7814 }
7815
7816 fn signature(&self) -> &Signature {
7817 &self.signature
7818 }
7819
7820 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7821 Ok(DataType::Float64)
7822 }
7823
7824 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7825 invoke_sparse_similarity_udf(args)
7826 }
7827}
7828
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::execution::FunctionRegistry;

    // ------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------

    /// Encodes a JSON value as CypherValue bytes (via `uni_common::Value`).
    fn json_to_cv_bytes(val: &serde_json::Value) -> Vec<u8> {
        let uni_val: uni_common::Value = val.clone().into();
        uni_common::cypher_value_codec::encode(&uni_val)
    }

    /// Wraps a JSON value as a non-null `LargeBinary` scalar holding its CypherValue encoding.
    fn cv_scalar(val: &serde_json::Value) -> ScalarValue {
        ScalarValue::LargeBinary(Some(json_to_cv_bytes(val)))
    }

    /// Builds one-row, scalar-only UDF args whose declared return type is `LargeBinary`.
    fn make_multi_scalar_args(scalars: Vec<ScalarValue>) -> ScalarFunctionArgs {
        make_multi_scalar_args_with_return(scalars, DataType::LargeBinary)
    }

    /// Builds one-row, scalar-only UDF args.
    ///
    /// Each argument gets a nullable field named `arg{i}`, typed after its
    /// scalar. The return field is the nullable `result` of type `return_type`.
    fn make_multi_scalar_args_with_return(
        scalars: Vec<ScalarValue>,
        return_type: DataType,
    ) -> ScalarFunctionArgs {
        use datafusion::arrow::datatypes::Field;
        use datafusion::config::ConfigOptions;

        let arg_fields: Vec<_> = scalars
            .iter()
            .enumerate()
            .map(|(i, s)| Arc::new(Field::new(format!("arg{i}"), s.data_type(), true)))
            .collect();
        let args: Vec<_> = scalars.into_iter().map(ColumnarValue::Scalar).collect();
        ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Arc::new(Field::new("result", return_type, true)),
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    /// Builds `(element, list)` CypherValue args with a `Boolean` return type.
    ///
    /// These are the args for the UDF built by `create_cypher_in_udf`.
    fn make_cypher_in_args(
        element: &serde_json::Value,
        list: &serde_json::Value,
    ) -> ScalarFunctionArgs {
        make_multi_scalar_args_with_return(
            vec![cv_scalar(element), cv_scalar(list)],
            DataType::Boolean,
        )
    }

    /// Decodes a non-null `LargeBinary` CypherValue scalar result into JSON.
    ///
    /// Panics on any other result shape or on a decode failure.
    fn decode_cv_scalar(cv: &ColumnarValue) -> serde_json::Value {
        match cv {
            ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
                let val = uni_common::cypher_value_codec::decode(bytes)
                    .expect("failed to decode CypherValue output");
                val.into()
            }
            other => panic!("expected LargeBinary scalar, got {other:?}"),
        }
    }

    /// Asserts that `result` is a scalar `Boolean` equal to `expected`.
    ///
    /// `None` means null. `context` names the case in failure messages.
    fn assert_bool_result(result: ColumnarValue, expected: Option<bool>, context: &str) {
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(actual)) => {
                assert_eq!(actual, expected, "unexpected result for {context}");
            }
            other => panic!("expected Boolean({expected:?}) for {context}, got {other:?}"),
        }
    }

    /// Asserts that a CypherValue result represents null.
    ///
    /// Either form is accepted: a null `LargeBinary` scalar, or encoded bytes
    /// that decode to null.
    fn assert_cv_null_result(result: &ColumnarValue) {
        match result {
            ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
                let uni_val = uni_common::cypher_value_codec::decode(bytes).expect("decode");
                let json: serde_json::Value = uni_val.into();
                assert!(json.is_null(), "expected null, got {json}");
            }
            ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
            other => panic!("expected null result, got {other:?}"),
        }
    }

    // ------------------------------------------------------------------
    // Registration
    // ------------------------------------------------------------------

    #[test]
    fn test_register_udfs() {
        let ctx = SessionContext::new();
        register_cypher_udfs(&ctx).unwrap();

        assert!(ctx.udf("id").is_ok());
        assert!(ctx.udf("type").is_ok());
        assert!(ctx.udf("keys").is_ok());
        assert!(ctx.udf("range").is_ok());
        assert!(
            ctx.udf("_make_cypher_list").is_ok(),
            "_make_cypher_list UDF should be registered"
        );
        assert!(
            ctx.udf("_cv_to_bool").is_ok(),
            "_cv_to_bool UDF should be registered"
        );
    }

    #[test]
    fn test_id_udf_signature() {
        let udf = create_id_udf();
        assert_eq!(udf.name(), "id");
    }

    #[test]
    fn test_has_null_udf() {
        use datafusion::arrow::datatypes::Field;
        use datafusion::config::ConfigOptions;

        let udf = create_has_null_udf();

        // [1, 2, null]
        let values = vec![
            ScalarValue::Int64(Some(1)),
            ScalarValue::Int64(Some(2)),
            ScalarValue::Int64(None),
        ];
        let list_scalar = ScalarValue::List(ScalarValue::new_list(&values, &DataType::Int64, true));

        let list_field = Arc::new(Field::new(
            "item",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            true,
        ));

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Scalar(list_scalar)],
            arg_fields: vec![list_field],
            number_rows: 1,
            return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, Some(true), "has_null on a list containing null");
    }

    // ------------------------------------------------------------------
    // _make_cypher_list
    // ------------------------------------------------------------------

    #[test]
    fn test_make_cypher_list_scalars() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::Int64(Some(1)),
            ScalarValue::Float64(Some(3.21)),
            ScalarValue::Utf8(Some("hello".to_string())),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Null,
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 5);
        assert_eq!(arr[0], serde_json::json!(1));
        assert_eq!(arr[1], serde_json::json!(3.21));
        assert_eq!(arr[2], serde_json::json!("hello"));
        assert_eq!(arr[3], serde_json::json!(true));
        assert!(arr[4].is_null());
    }

    #[test]
    fn test_make_cypher_list_empty() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert!(arr.is_empty());
    }

    #[test]
    fn test_make_cypher_list_single() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![ScalarValue::Int64(Some(42))]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0], serde_json::json!(42));
    }

    #[test]
    fn test_make_cypher_list_nested_cypher_value() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!([1, 2])),
            ScalarValue::Int64(Some(3)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!([1, 2]));
        assert_eq!(arr[1], serde_json::json!(3));
    }

    // ------------------------------------------------------------------
    // IN
    // ------------------------------------------------------------------

    #[test]
    fn test_cypher_in_found() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(3), &serde_json::json!([1, 2, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, Some(true), "3 IN [1, 2, 3]");
    }

    #[test]
    fn test_cypher_in_not_found() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, 2, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, Some(false), "4 IN [1, 2, 3]");
    }

    #[test]
    fn test_cypher_in_null_list() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!(null));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, None, "1 IN null");
    }

    #[test]
    fn test_cypher_in_null_element_nonempty() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(null), &serde_json::json!([1, 2]));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, None, "null IN non-empty list");
    }

    #[test]
    fn test_cypher_in_null_element_empty() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(null), &serde_json::json!([]));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, Some(false), "null IN []");
    }

    #[test]
    fn test_cypher_in_not_found_with_null() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, null, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, None, "4 IN [1, null, 3]");
    }

    #[test]
    fn test_cypher_in_cross_type_int_float() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!([1.0, 2.0]));
        let result = udf.invoke_with_args(args).unwrap();
        assert_bool_result(result, Some(true), "1 IN [1.0, 2.0]");
    }

    // ------------------------------------------------------------------
    // List concatenation
    // ------------------------------------------------------------------

    #[test]
    fn test_list_concat_basic() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!([1, 2])),
            cv_scalar(&serde_json::json!([3, 4])),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_eq!(decode_cv_scalar(&result), serde_json::json!([1, 2, 3, 4]));
    }

    #[test]
    fn test_list_concat_empty() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!([])),
            cv_scalar(&serde_json::json!([1])),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_eq!(decode_cv_scalar(&result), serde_json::json!([1]));
    }

    #[test]
    fn test_list_concat_null_left() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!(null)),
            cv_scalar(&serde_json::json!([1])),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_cv_null_result(&result);
    }

    #[test]
    fn test_list_concat_null_right() {
        let udf = create_cypher_list_concat_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!([1])),
            cv_scalar(&serde_json::json!(null)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_cv_null_result(&result);
    }

    // ------------------------------------------------------------------
    // List append / prepend
    // ------------------------------------------------------------------

    #[test]
    fn test_list_append_scalar() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!([1, 2])),
            cv_scalar(&serde_json::json!(3)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_eq!(decode_cv_scalar(&result), serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_list_prepend_scalar() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!(3)),
            cv_scalar(&serde_json::json!([1, 2])),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_eq!(decode_cv_scalar(&result), serde_json::json!([3, 1, 2]));
    }

    #[test]
    fn test_list_append_null_list() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!(null)),
            cv_scalar(&serde_json::json!(3)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_cv_null_result(&result);
    }

    #[test]
    fn test_list_append_null_scalar() {
        let udf = create_cypher_list_append_udf();
        let args = make_multi_scalar_args(vec![
            cv_scalar(&serde_json::json!([1, 2])),
            cv_scalar(&serde_json::json!(null)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        assert_cv_null_result(&result);
    }

    // ------------------------------------------------------------------
    // Sort keys
    // ------------------------------------------------------------------

    #[test]
    fn test_sort_key_cross_type_ordering() {
        use uni_common::core::id::{Eid, Vid};
        use uni_common::{Edge, Node, Path, TemporalValue, Value};

        let map_val = Value::Map([("a".to_string(), Value::String("map".to_string()))].into());
        let node_val = Value::Node(Node {
            vid: Vid::new(1),
            labels: vec!["L".to_string()],
            properties: Default::default(),
        });
        let edge_val = Value::Edge(Edge {
            eid: Eid::new(1),
            edge_type: "T".to_string(),
            src: Vid::new(1),
            dst: Vid::new(2),
            properties: Default::default(),
        });
        let list_val = Value::List(vec![Value::Int(1)]);
        let path_val = Value::Path(Path {
            nodes: vec![Node {
                vid: Vid::new(1),
                labels: vec!["L".to_string()],
                properties: Default::default(),
            }],
            edges: vec![],
        });
        let string_val = Value::String("hello".to_string());
        let bool_val = Value::Bool(false);
        let temporal_val = Value::Temporal(TemporalValue::Date {
            days_since_epoch: 1000,
        });
        let number_val = Value::Int(42);
        let nan_val = Value::Float(f64::NAN);
        let null_val = Value::Null;

        // Expected ascending order of the type classes.
        let values = vec![
            &map_val,
            &node_val,
            &edge_val,
            &list_val,
            &path_val,
            &string_val,
            &bool_val,
            &temporal_val,
            &number_val,
            &nan_val,
            &null_val,
        ];

        let keys: Vec<Vec<u8>> = values.iter().map(|v| encode_cypher_sort_key(v)).collect();

        for (vals, ks) in values.windows(2).zip(keys.windows(2)) {
            assert!(
                ks[0] < ks[1],
                "Expected sort_key({:?}) < sort_key({:?}), but {:?} >= {:?}",
                vals[0],
                vals[1],
                ks[0],
                ks[1]
            );
        }
    }

    #[test]
    fn test_sort_key_numbers() {
        let neg_inf = encode_cypher_sort_key(&Value::Float(f64::NEG_INFINITY));
        let neg_100 = encode_cypher_sort_key(&Value::Float(-100.0));
        let neg_1 = encode_cypher_sort_key(&Value::Int(-1));
        let zero_int = encode_cypher_sort_key(&Value::Int(0));
        let zero_float = encode_cypher_sort_key(&Value::Float(0.0));
        let one_int = encode_cypher_sort_key(&Value::Int(1));
        let one_float = encode_cypher_sort_key(&Value::Float(1.0));
        let hundred = encode_cypher_sort_key(&Value::Int(100));
        let pos_inf = encode_cypher_sort_key(&Value::Float(f64::INFINITY));
        let nan = encode_cypher_sort_key(&Value::Float(f64::NAN));

        assert!(neg_inf < neg_100, "-inf < -100");
        assert!(neg_100 < neg_1, "-100 < -1");
        assert!(neg_1 < zero_int, "-1 < 0");
        assert_eq!(zero_int, zero_float, "0 int == 0.0 float");
        assert!(zero_int < one_int, "0 < 1");
        assert_eq!(one_int, one_float, "1 int == 1.0 float");
        assert!(one_int < hundred, "1 < 100");
        assert!(hundred < pos_inf, "100 < +inf");
        assert!(pos_inf < nan, "+inf < NaN");
    }

    #[test]
    fn test_sort_key_booleans() {
        let f = encode_cypher_sort_key(&Value::Bool(false));
        let t = encode_cypher_sort_key(&Value::Bool(true));
        assert!(f < t, "false < true");
    }

    #[test]
    fn test_sort_key_strings() {
        let empty = encode_cypher_sort_key(&Value::String(String::new()));
        let a = encode_cypher_sort_key(&Value::String("a".to_string()));
        let ab = encode_cypher_sort_key(&Value::String("ab".to_string()));
        let b = encode_cypher_sort_key(&Value::String("b".to_string()));

        assert!(empty < a, "'' < 'a'");
        assert!(a < ab, "'a' < 'ab'");
        assert!(ab < b, "'ab' < 'b'");
    }

    #[test]
    fn test_sort_key_lists() {
        let empty = encode_cypher_sort_key(&Value::List(vec![]));
        let one = encode_cypher_sort_key(&Value::List(vec![Value::Int(1)]));
        let one_two = encode_cypher_sort_key(&Value::List(vec![Value::Int(1), Value::Int(2)]));
        let two = encode_cypher_sort_key(&Value::List(vec![Value::Int(2)]));

        assert!(empty < one, "[] < [1]");
        assert!(one < one_two, "[1] < [1,2]");
        assert!(one_two < two, "[1,2] < [2]");
    }

    #[test]
    fn test_sort_key_temporal() {
        use uni_common::TemporalValue;

        let date1 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: 100,
        }));
        let date2 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: 200,
        }));
        assert!(date1 < date2, "earlier date < later date");

        // Even the largest Date must sort before the smallest LocalTime.
        let date = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
            days_since_epoch: i32::MAX,
        }));
        let local_time = encode_cypher_sort_key(&Value::Temporal(TemporalValue::LocalTime {
            nanos_since_midnight: 0,
        }));
        assert!(date < local_time, "Date < LocalTime (by variant rank)");
    }

    #[test]
    fn test_sort_key_nested_lists() {
        let inner_a = Value::List(vec![Value::Int(1)]);
        let inner_b = Value::List(vec![Value::Int(2)]);

        let list_a = encode_cypher_sort_key(&Value::List(vec![inner_a]));
        let list_b = encode_cypher_sort_key(&Value::List(vec![inner_b]));

        assert!(list_a < list_b, "[[1]] < [[2]]");
    }

    #[test]
    fn test_sort_key_null_handling() {
        let null_key = encode_cypher_sort_key(&Value::Null);
        assert_eq!(null_key, vec![0x0A], "Null produces [0x0A]");

        let number_key = encode_cypher_sort_key(&Value::Int(42));
        assert!(number_key < null_key, "number < null");
    }

    #[test]
    fn test_byte_stuff_roundtrip() {
        // Embedded NUL / 0x01 bytes must not break lexicographic ordering.
        let s1 = Value::String("a\x00b".to_string());
        let s2 = Value::String("a\x00c".to_string());
        let s3 = Value::String("a\x01".to_string());

        let k1 = encode_cypher_sort_key(&s1);
        let k2 = encode_cypher_sort_key(&s2);
        let k3 = encode_cypher_sort_key(&s3);

        assert!(k1 < k2, "a\\x00b < a\\x00c");
        assert!(k1 < k3, "a\\x00b < a\\x01");
    }

    #[test]
    fn test_sort_key_order_preserving_f64() {
        let vals = [f64::NEG_INFINITY, -1.0, -0.0, 0.0, 1.0, f64::INFINITY];
        let encoded: Vec<[u8; 8]> = vals
            .iter()
            .map(|&f| encode_order_preserving_f64(f))
            .collect();

        for (vs, ks) in vals.windows(2).zip(encoded.windows(2)) {
            assert!(
                ks[0] <= ks[1],
                "encode({}) should <= encode({}), got {:?} vs {:?}",
                vs[0],
                vs[1],
                ks[0],
                ks[1]
            );
        }
    }

    // ------------------------------------------------------------------
    // String -> temporal parsing used by sort keys
    // ------------------------------------------------------------------

    #[test]
    fn test_sort_key_string_as_temporal_time_with_offset() {
        let tv = sort_key_string_as_temporal("12:35:15+05:00")
            .expect("should parse Time with positive offset");
        match tv {
            uni_common::TemporalValue::Time {
                nanos_since_midnight,
                offset_seconds,
            } => {
                assert_eq!(offset_seconds, 5 * 3600, "offset should be +05:00 = 18000s");
                let expected_nanos = (12 * 3600 + 35 * 60 + 15) * 1_000_000_000i64;
                assert_eq!(nanos_since_midnight, expected_nanos);
            }
            other => panic!("expected TemporalValue::Time, got {other:?}"),
        }
    }

    #[test]
    fn test_sort_key_string_as_temporal_time_negative_offset() {
        let tv = sort_key_string_as_temporal("10:35:00-08:00")
            .expect("should parse Time with negative offset");
        match tv {
            uni_common::TemporalValue::Time {
                nanos_since_midnight,
                offset_seconds,
            } => {
                assert_eq!(
                    offset_seconds,
                    -8 * 3600,
                    "offset should be -08:00 = -28800s"
                );
                let expected_nanos = (10 * 3600 + 35 * 60) * 1_000_000_000i64;
                assert_eq!(nanos_since_midnight, expected_nanos);
            }
            other => panic!("expected TemporalValue::Time, got {other:?}"),
        }
    }

    #[test]
    fn test_sort_key_string_as_temporal_date() {
        use super::super::expr_eval::temporal_from_value;
        let tv = temporal_from_value(&Value::String("2024-01-15".into()))
            .expect("should parse Date string");
        match tv {
            uni_common::TemporalValue::Date { days_since_epoch } => {
                assert!(days_since_epoch > 0, "2024-01-15 should be after epoch");
            }
            other => panic!("expected TemporalValue::Date, got {other:?}"),
        }
    }
}
8526}