1use arrow::array::ArrayRef;
27use arrow::datatypes::DataType;
28use arrow_array::{
29 Array, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array,
30 LargeBinaryArray, LargeStringArray, StringArray, UInt64Array,
31};
32use chrono::Offset;
33use datafusion::error::Result as DFResult;
34use datafusion::logical_expr::{
35 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
36 Volatility,
37};
38use datafusion::prelude::SessionContext;
39use datafusion::scalar::ScalarValue;
40use std::any::Any;
41use std::hash::{Hash, Hasher};
42use std::sync::Arc;
43use uni_common::Value;
44use uni_cypher::ast::BinaryOp;
45use uni_store::storage::arrow_convert::values_to_array;
46
47use super::expr_eval::cypher_eq;
48
/// Implements `PartialEq`, `Eq`, and `Hash` for a UDF wrapper type.
///
/// DataFusion's `ScalarUDFImpl` machinery requires UDF types to be comparable
/// and hashable (e.g. for logical-plan equality checks). Two instances compare
/// equal when their `Signature`s match; hashing uses the UDF's `name()`.
///
/// NOTE(review): `eq` keys on the signature while `hash` keys on the name.
/// For types whose instances all share one signature but carry distinct names
/// (e.g. `CustomScalarUdf`), values can compare equal yet hash differently,
/// violating the `Eq`/`Hash` consistency contract — confirm this cannot reach
/// a hash-based container.
macro_rules! impl_udf_eq_hash {
    ($type:ty) => {
        impl PartialEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self.signature == other.signature
            }
        }

        impl Eq for $type {}

        impl Hash for $type {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.name().hash(state);
            }
        }
    };
}
69
70pub fn register_cypher_udfs(ctx: &SessionContext) -> DFResult<()> {
80 ctx.register_udf(create_id_udf());
81 ctx.register_udf(create_type_udf());
82 ctx.register_udf(create_keys_udf());
83 ctx.register_udf(create_properties_udf());
84 ctx.register_udf(create_labels_udf());
85 ctx.register_udf(create_nodes_udf());
86 ctx.register_udf(create_relationships_udf());
87 ctx.register_udf(create_range_udf());
88 ctx.register_udf(create_index_udf());
89 ctx.register_udf(create_startnode_udf());
90 ctx.register_udf(create_endnode_udf());
91
92 ctx.register_udf(create_to_integer_udf());
94 ctx.register_udf(create_to_float_udf());
95 ctx.register_udf(create_to_boolean_udf());
96
97 ctx.register_udf(create_bitwise_or_udf());
99 ctx.register_udf(create_bitwise_and_udf());
100 ctx.register_udf(create_bitwise_xor_udf());
101 ctx.register_udf(create_bitwise_not_udf());
102 ctx.register_udf(create_shift_left_udf());
103 ctx.register_udf(create_shift_right_udf());
104
105 for name in &[
107 "date",
109 "time",
110 "localtime",
111 "localdatetime",
112 "datetime",
113 "duration",
114 "duration.between",
116 "duration.inmonths",
117 "duration.indays",
118 "duration.inseconds",
119 "datetime.fromepoch",
120 "datetime.fromepochmillis",
121 "date.truncate",
123 "time.truncate",
124 "datetime.truncate",
125 "localdatetime.truncate",
126 "localtime.truncate",
127 "datetime.transaction",
129 "datetime.statement",
130 "datetime.realtime",
131 "date.transaction",
132 "date.statement",
133 "date.realtime",
134 "time.transaction",
135 "time.statement",
136 "time.realtime",
137 "localtime.transaction",
138 "localtime.statement",
139 "localtime.realtime",
140 "localdatetime.transaction",
141 "localdatetime.statement",
142 "localdatetime.realtime",
143 ] {
144 ctx.register_udf(create_temporal_udf(name));
145 }
146
147 ctx.register_udf(create_duration_property_udf());
149 ctx.register_udf(create_temporal_property_udf());
150 ctx.register_udf(create_tostring_udf());
151 ctx.register_udf(create_cypher_sort_key_udf());
152 ctx.register_udf(create_has_null_udf());
153 ctx.register_udf(create_cypher_size_udf());
154
155 ctx.register_udf(create_cypher_starts_with_udf());
157 ctx.register_udf(create_cypher_ends_with_udf());
158 ctx.register_udf(create_cypher_contains_udf());
159
160 ctx.register_udf(create_cypher_list_compare_udf());
162
163 ctx.register_udf(create_cypher_xor_udf());
165
166 ctx.register_udf(create_cypher_equal_udf());
168 ctx.register_udf(create_cypher_not_equal_udf());
169 ctx.register_udf(create_cypher_gt_udf());
170 ctx.register_udf(create_cypher_gt_eq_udf());
171 ctx.register_udf(create_cypher_lt_udf());
172 ctx.register_udf(create_cypher_lt_eq_udf());
173
174 ctx.register_udf(create_cv_to_bool_udf());
176
177 ctx.register_udf(create_cypher_add_udf());
179 ctx.register_udf(create_cypher_sub_udf());
180 ctx.register_udf(create_cypher_mul_udf());
181 ctx.register_udf(create_cypher_div_udf());
182 ctx.register_udf(create_cypher_mod_udf());
183
184 ctx.register_udf(create_map_project_udf());
186
187 ctx.register_udf(create_make_cypher_list_udf());
189
190 ctx.register_udf(create_cypher_in_udf());
192
193 ctx.register_udf(create_cypher_list_concat_udf());
195 ctx.register_udf(create_cypher_list_append_udf());
196 ctx.register_udf(create_cypher_list_slice_udf());
197 ctx.register_udf(create_cypher_tail_udf());
198 ctx.register_udf(create_cypher_head_udf());
199 ctx.register_udf(create_cypher_last_udf());
200 ctx.register_udf(create_cypher_reverse_udf());
201 ctx.register_udf(create_cypher_substring_udf());
202 ctx.register_udf(create_cypher_split_udf());
203 ctx.register_udf(create_cypher_list_to_cv_udf());
204 ctx.register_udf(create_cypher_scalar_to_cv_udf());
205
206 for name in &["year", "month", "day", "hour", "minute", "second"] {
208 ctx.register_udf(create_temporal_udf(name));
209 }
210
211 ctx.register_udf(create_cypher_to_float64_udf());
213
214 ctx.register_udf(create_similar_to_udf());
216 ctx.register_udf(create_vector_similarity_udf());
217
218 ctx.register_udaf(create_cypher_min_udaf());
220 ctx.register_udaf(create_cypher_max_udaf());
221 ctx.register_udaf(create_cypher_sum_udaf());
222 ctx.register_udaf(create_cypher_collect_udaf());
223
224 ctx.register_udaf(create_cypher_percentile_disc_udaf());
226 ctx.register_udaf(create_cypher_percentile_cont_udaf());
227
228 register_btic_scalar_udfs(ctx)?;
230
231 ctx.register_udaf(create_btic_min_udaf());
233 ctx.register_udaf(create_btic_max_udaf());
234 ctx.register_udaf(create_btic_span_agg_udaf());
235 ctx.register_udaf(create_btic_count_at_udaf());
236
237 Ok(())
238}
239
240pub fn register_custom_udfs(
243 ctx: &SessionContext,
244 registry: &super::executor::custom_functions::CustomFunctionRegistry,
245) -> DFResult<()> {
246 for (name, func) in registry.iter() {
247 let lower = name.to_lowercase();
250 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
251 lower,
252 func.clone(),
253 )));
254 ctx.register_udf(ScalarUDF::new_from_impl(CustomScalarUdf::new(
256 name.to_string(),
257 func.clone(),
258 )));
259 }
260 Ok(())
261}
262
/// A DataFusion scalar UDF backed by a user-registered custom function.
struct CustomScalarUdf {
    /// Name the UDF is registered under (may be a lowercased form).
    name: String,
    /// The user-supplied callback; operates on decoded `Value` arguments.
    func: super::executor::custom_functions::CustomScalarFn,
    /// Accepts any number of arguments of any type; `Volatile` because
    /// nothing is known about the callback's purity.
    signature: Signature,
}

impl CustomScalarUdf {
    /// Wraps `func` under `name` with a fully dynamic (VariadicAny) signature.
    fn new(name: String, func: super::executor::custom_functions::CustomScalarFn) -> Self {
        Self {
            signature: Signature::new(TypeSignature::VariadicAny, Volatility::Volatile),
            name,
            func,
        }
    }
}

// Manual Debug: `func` is an opaque callback with no Debug impl, so only the
// name is rendered.
impl std::fmt::Debug for CustomScalarUdf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomScalarUdf")
            .field("name", &self.name)
            .finish()
    }
}

// NOTE(review): all CustomScalarUdf instances share the same VariadicAny
// signature, so the macro's signature-based `eq` makes any two compare equal
// while `hash` uses their distinct names — verify this Eq/Hash inconsistency
// cannot break hash-based plan containers.
impl_udf_eq_hash!(CustomScalarUdf);

impl ScalarUDFImpl for CustomScalarUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always LargeBinary: the callback's output type is not known
    /// statically, so results stay in the serialized-value representation.
    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        Ok(DataType::LargeBinary)
    }

    /// Delegates to `invoke_cypher_udf` with the user callback; any callback
    /// error is surfaced as a DataFusion execution error.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        let func = &self.func;
        invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
            func(vals).map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
        })
    }
}
317
318pub fn create_id_udf() -> ScalarUDF {
326 ScalarUDF::new_from_impl(IdUdf::new())
327}
328
329#[derive(Debug)]
330struct IdUdf {
331 signature: Signature,
332}
333
334impl IdUdf {
335 fn new() -> Self {
336 Self {
337 signature: Signature::new(
338 TypeSignature::Exact(vec![DataType::UInt64]),
339 Volatility::Immutable,
340 ),
341 }
342 }
343}
344
345impl_udf_eq_hash!(IdUdf);
346
347impl ScalarUDFImpl for IdUdf {
348 fn as_any(&self) -> &dyn Any {
349 self
350 }
351
352 fn name(&self) -> &str {
353 "id"
354 }
355
356 fn signature(&self) -> &Signature {
357 &self.signature
358 }
359
360 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
361 Ok(DataType::UInt64)
362 }
363
364 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
365 if args.args.is_empty() {
367 return Err(datafusion::error::DataFusionError::Execution(
368 "id(): requires 1 argument".to_string(),
369 ));
370 }
371 Ok(args.args[0].clone())
372 }
373}
374
375pub fn create_type_udf() -> ScalarUDF {
383 ScalarUDF::new_from_impl(TypeUdf::new())
384}
385
386#[derive(Debug)]
387struct TypeUdf {
388 signature: Signature,
389}
390
391impl TypeUdf {
392 fn new() -> Self {
393 Self {
394 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
398 }
399 }
400}
401
402impl_udf_eq_hash!(TypeUdf);
403
404impl ScalarUDFImpl for TypeUdf {
405 fn as_any(&self) -> &dyn Any {
406 self
407 }
408
409 fn name(&self) -> &str {
410 "type"
411 }
412
413 fn signature(&self) -> &Signature {
414 &self.signature
415 }
416
417 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
418 Ok(DataType::Utf8)
419 }
420
421 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
422 if args.args.is_empty() {
423 return Err(datafusion::error::DataFusionError::Execution(
424 "type(): requires 1 argument".to_string(),
425 ));
426 }
427 let output_type = DataType::Utf8;
428 invoke_cypher_udf(args, &output_type, |val_args| {
429 if val_args.is_empty() {
430 return Err(datafusion::error::DataFusionError::Execution(
431 "type(): requires 1 argument".to_string(),
432 ));
433 }
434 let val = &val_args[0];
435 match val {
436 Value::Map(map) => {
438 if let Some(Value::String(t)) = map.get("_type") {
439 Ok(Value::String(t.clone()))
440 } else {
441 Err(datafusion::error::DataFusionError::Execution(
443 "TypeError: InvalidArgumentValue - type() requires a relationship argument".to_string(),
444 ))
445 }
446 }
447 Value::Null => Ok(Value::Null),
448 _ => Err(datafusion::error::DataFusionError::Execution(
449 "TypeError: InvalidArgumentValue - type() requires a relationship argument"
450 .to_string(),
451 )),
452 }
453 })
454 }
455}
456
457pub fn create_keys_udf() -> ScalarUDF {
465 ScalarUDF::new_from_impl(KeysUdf::new())
466}
467
468#[derive(Debug)]
469struct KeysUdf {
470 signature: Signature,
471}
472
473impl KeysUdf {
474 fn new() -> Self {
475 Self {
476 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
477 }
478 }
479}
480
481impl_udf_eq_hash!(KeysUdf);
482
483impl ScalarUDFImpl for KeysUdf {
484 fn as_any(&self) -> &dyn Any {
485 self
486 }
487
488 fn name(&self) -> &str {
489 "keys"
490 }
491
492 fn signature(&self) -> &Signature {
493 &self.signature
494 }
495
496 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
497 Ok(DataType::List(Arc::new(
498 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
499 )))
500 }
501
502 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
503 let output_type = self.return_type(&[])?;
504 invoke_cypher_udf(args, &output_type, |val_args| {
505 if val_args.is_empty() {
506 return Err(datafusion::error::DataFusionError::Execution(
507 "keys(): requires 1 argument".to_string(),
508 ));
509 }
510
511 let arg = &val_args[0];
512 let keys = match arg {
513 Value::Map(map) => {
514 let (source, is_entity) = match map.get("_all_props") {
524 Some(Value::Map(all)) => (all, true),
525 _ => (map, false),
526 };
527 let mut key_strings: Vec<String> = source
528 .iter()
529 .filter(|(k, v)| !k.starts_with('_') && (!is_entity || !v.is_null()))
530 .map(|(k, _)| k.clone())
531 .collect();
532 key_strings.sort();
533 key_strings
534 .into_iter()
535 .map(Value::String)
536 .collect::<Vec<_>>()
537 }
538 Value::Null => {
539 return Ok(Value::Null);
540 }
541 _ => {
542 vec![]
545 }
546 };
547
548 Ok(Value::List(keys))
549 })
550 }
551}
552
553pub fn create_properties_udf() -> ScalarUDF {
558 ScalarUDF::new_from_impl(PropertiesUdf::new())
559}
560
561#[derive(Debug)]
562struct PropertiesUdf {
563 signature: Signature,
564}
565
566impl PropertiesUdf {
567 fn new() -> Self {
568 Self {
569 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
570 }
571 }
572}
573
574impl_udf_eq_hash!(PropertiesUdf);
575
576impl ScalarUDFImpl for PropertiesUdf {
577 fn as_any(&self) -> &dyn Any {
578 self
579 }
580
581 fn name(&self) -> &str {
582 "properties"
583 }
584
585 fn signature(&self) -> &Signature {
586 &self.signature
587 }
588
589 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
590 Ok(DataType::LargeBinary)
592 }
593
594 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
595 let output_type = self.return_type(&[])?;
596 invoke_cypher_udf(args, &output_type, |val_args| {
597 if val_args.is_empty() {
598 return Err(datafusion::error::DataFusionError::Execution(
599 "properties(): requires 1 argument".to_string(),
600 ));
601 }
602
603 let arg = &val_args[0];
604 match arg {
605 Value::Map(map) => {
606 let identity_null = map
616 .get("_vid")
617 .map(|v| v.is_null())
618 .or_else(|| map.get("_eid").map(|v| v.is_null()))
619 .unwrap_or(false);
620 if identity_null {
621 return Ok(Value::Null);
622 }
623
624 let source = match map.get("_all_props") {
626 Some(Value::Map(all)) => all,
627 _ => map,
628 };
629 let filtered: std::collections::HashMap<String, Value> = source
631 .iter()
632 .filter(|(k, _)| !k.starts_with('_'))
633 .map(|(k, v)| (k.clone(), v.clone()))
634 .collect();
635 Ok(Value::Map(filtered))
636 }
637 _ => Ok(Value::Null),
638 }
639 })
640 }
641}
642
/// Builds the internal `index` UDF used for subscript access
/// (`container[key]`).
pub fn create_index_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(IndexUdf::new())
}

/// `index(container, key)`: dynamic subscript lookup on lists, maps, nodes,
/// and edges with Cypher null-propagation semantics.
#[derive(Debug)]
struct IndexUdf {
    signature: Signature,
}

impl IndexUdf {
    fn new() -> Self {
        Self {
            // Two arguments of any type; validation happens per-row at
            // invocation time.
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(IndexUdf);

impl ScalarUDFImpl for IndexUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "index"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        // The result type depends on the container's contents, so values are
        // returned in the serialized-value (LargeBinary) representation.
        Ok(DataType::LargeBinary)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        let output_type = self.return_type(&[])?;
        invoke_cypher_udf(args, &output_type, |val_args| {
            if val_args.len() != 2 {
                return Err(datafusion::error::DataFusionError::Execution(
                    "index(): requires 2 arguments".to_string(),
                ));
            }

            let container = &val_args[0];
            let index = &val_args[1];

            // Integer view of the key, used for list subscripts.
            let index_as_int = index.as_i64();

            let result = match container {
                Value::List(arr) => {
                    if let Some(i) = index_as_int {
                        // Negative indices count from the end; if still
                        // negative after adjustment, treat as out of range.
                        let idx = if i < 0 {
                            let pos = arr.len() as i64 + i;
                            if pos < 0 { -1 } else { pos }
                        } else {
                            i
                        };
                        if idx >= 0 && (idx as usize) < arr.len() {
                            arr[idx as usize].clone()
                        } else {
                            // Out-of-range list access yields null, not an error.
                            Value::Null
                        }
                    } else if index.is_null() {
                        Value::Null
                    } else {
                        return Err(datafusion::error::DataFusionError::Execution(format!(
                            "TypeError: InvalidArgumentType - list index must be an integer, got: {:?}",
                            index
                        )));
                    }
                }
                Value::Map(map) => {
                    if let Some(key) = index.as_str() {
                        // Direct hit first, then the nested property maps
                        // used by the entity encodings.
                        if let Some(val) = map.get(key) {
                            val.clone()
                        } else if let Some(Value::Map(all_props)) = map.get("_all_props") {
                            all_props.get(key).cloned().unwrap_or(Value::Null)
                        } else if let Some(Value::Map(props)) = map.get("properties") {
                            props.get(key).cloned().unwrap_or(Value::Null)
                        } else {
                            Value::Null
                        }
                    } else if !index.is_null() {
                        return Err(datafusion::error::DataFusionError::Execution(
                            "index(): map index must be a string".to_string(),
                        ));
                    } else {
                        Value::Null
                    }
                }
                Value::Node(node) => {
                    if let Some(key) = index.as_str() {
                        node.properties.get(key).cloned().unwrap_or(Value::Null)
                    } else if !index.is_null() {
                        return Err(datafusion::error::DataFusionError::Execution(
                            "index(): node index must be a string".to_string(),
                        ));
                    } else {
                        Value::Null
                    }
                }
                Value::Edge(edge) => {
                    if let Some(key) = index.as_str() {
                        edge.properties.get(key).cloned().unwrap_or(Value::Null)
                    } else if !index.is_null() {
                        return Err(datafusion::error::DataFusionError::Execution(
                            "index(): edge index must be a string".to_string(),
                        ));
                    } else {
                        Value::Null
                    }
                }
                // Null container propagates null.
                Value::Null => Value::Null,
                _ => {
                    return Err(datafusion::error::DataFusionError::Execution(format!(
                        "TypeError: InvalidArgumentType - cannot index into {:?}",
                        container
                    )));
                }
            };

            Ok(result)
        })
    }
}
781
782pub fn create_labels_udf() -> ScalarUDF {
787 ScalarUDF::new_from_impl(LabelsUdf::new())
788}
789
790#[derive(Debug)]
791struct LabelsUdf {
792 signature: Signature,
793}
794
795impl LabelsUdf {
796 fn new() -> Self {
797 Self {
798 signature: Signature::any(1, Volatility::Immutable),
799 }
800 }
801}
802
803impl_udf_eq_hash!(LabelsUdf);
804
805impl ScalarUDFImpl for LabelsUdf {
806 fn as_any(&self) -> &dyn Any {
807 self
808 }
809
810 fn name(&self) -> &str {
811 "labels"
812 }
813
814 fn signature(&self) -> &Signature {
815 &self.signature
816 }
817
818 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
819 Ok(DataType::List(Arc::new(
820 arrow::datatypes::Field::new_list_field(DataType::Utf8, true),
821 )))
822 }
823
824 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
825 let output_type = self.return_type(&[])?;
826 invoke_cypher_udf(args, &output_type, |val_args| {
827 if val_args.is_empty() {
828 return Err(datafusion::error::DataFusionError::Execution(
829 "labels(): requires 1 argument".to_string(),
830 ));
831 }
832
833 let node = &val_args[0];
834 match node {
835 Value::Map(map) => {
836 if let Some(Value::List(arr)) = map.get("_labels") {
837 Ok(Value::List(arr.clone()))
838 } else {
839 Err(datafusion::error::DataFusionError::Execution(
841 "TypeError: InvalidArgumentValue - labels() requires a node argument"
842 .to_string(),
843 ))
844 }
845 }
846 Value::Null => Ok(Value::Null),
847 _ => Err(datafusion::error::DataFusionError::Execution(
848 "TypeError: InvalidArgumentValue - labels() requires a node argument"
849 .to_string(),
850 )),
851 }
852 })
853 }
854}
855
856pub fn create_nodes_udf() -> ScalarUDF {
861 ScalarUDF::new_from_impl(NodesUdf::new())
862}
863
864#[derive(Debug)]
865struct NodesUdf {
866 signature: Signature,
867}
868
869impl NodesUdf {
870 fn new() -> Self {
871 Self {
872 signature: Signature::any(1, Volatility::Immutable),
873 }
874 }
875}
876
877impl_udf_eq_hash!(NodesUdf);
878
879impl ScalarUDFImpl for NodesUdf {
880 fn as_any(&self) -> &dyn Any {
881 self
882 }
883
884 fn name(&self) -> &str {
885 "nodes"
886 }
887
888 fn signature(&self) -> &Signature {
889 &self.signature
890 }
891
892 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
893 Ok(DataType::LargeBinary)
894 }
895
896 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
897 let output_type = self.return_type(&[])?;
898 invoke_cypher_udf(args, &output_type, |val_args| {
899 if val_args.is_empty() {
900 return Err(datafusion::error::DataFusionError::Execution(
901 "nodes(): requires 1 argument".to_string(),
902 ));
903 }
904
905 let path = &val_args[0];
906 let nodes = match path {
907 Value::Map(map) => map.get("nodes").cloned().unwrap_or(Value::Null),
908 _ => Value::Null,
909 };
910
911 Ok(nodes)
912 })
913 }
914}
915
916pub fn create_relationships_udf() -> ScalarUDF {
921 ScalarUDF::new_from_impl(RelationshipsUdf::new())
922}
923
924#[derive(Debug)]
925struct RelationshipsUdf {
926 signature: Signature,
927}
928
929impl RelationshipsUdf {
930 fn new() -> Self {
931 Self {
932 signature: Signature::any(1, Volatility::Immutable),
933 }
934 }
935}
936
937impl_udf_eq_hash!(RelationshipsUdf);
938
939impl ScalarUDFImpl for RelationshipsUdf {
940 fn as_any(&self) -> &dyn Any {
941 self
942 }
943
944 fn name(&self) -> &str {
945 "relationships"
946 }
947
948 fn signature(&self) -> &Signature {
949 &self.signature
950 }
951
952 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
953 Ok(DataType::LargeBinary)
954 }
955
956 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
957 let output_type = self.return_type(&[])?;
958 invoke_cypher_udf(args, &output_type, |val_args| {
959 if val_args.is_empty() {
960 return Err(datafusion::error::DataFusionError::Execution(
961 "relationships(): requires 1 argument".to_string(),
962 ));
963 }
964
965 let path = &val_args[0];
966 let rels = match path {
967 Value::Map(map) => map.get("relationships").cloned().unwrap_or(Value::Null),
968 _ => Value::Null,
969 };
970
971 Ok(rels)
972 })
973 }
974}
975
976pub fn create_startnode_udf() -> ScalarUDF {
985 ScalarUDF::new_from_impl(StartNodeUdf::new())
986}
987
988#[derive(Debug)]
989struct StartNodeUdf {
990 signature: Signature,
991}
992
993impl StartNodeUdf {
994 fn new() -> Self {
995 Self {
996 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
997 }
998 }
999}
1000
1001impl_udf_eq_hash!(StartNodeUdf);
1002
1003impl ScalarUDFImpl for StartNodeUdf {
1004 fn as_any(&self) -> &dyn Any {
1005 self
1006 }
1007
1008 fn name(&self) -> &str {
1009 "startnode"
1010 }
1011
1012 fn signature(&self) -> &Signature {
1013 &self.signature
1014 }
1015
1016 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1017 Ok(DataType::LargeBinary)
1018 }
1019
1020 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1021 let output_type = DataType::LargeBinary;
1022 invoke_cypher_udf(args, &output_type, |val_args| {
1023 startnode_endnode_impl(val_args, true)
1024 })
1025 }
1026}
1027
1028pub fn create_endnode_udf() -> ScalarUDF {
1034 ScalarUDF::new_from_impl(EndNodeUdf::new())
1035}
1036
1037#[derive(Debug)]
1038struct EndNodeUdf {
1039 signature: Signature,
1040}
1041
1042impl EndNodeUdf {
1043 fn new() -> Self {
1044 Self {
1045 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
1046 }
1047 }
1048}
1049
1050impl_udf_eq_hash!(EndNodeUdf);
1051
1052impl ScalarUDFImpl for EndNodeUdf {
1053 fn as_any(&self) -> &dyn Any {
1054 self
1055 }
1056
1057 fn name(&self) -> &str {
1058 "endnode"
1059 }
1060
1061 fn signature(&self) -> &Signature {
1062 &self.signature
1063 }
1064
1065 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1066 Ok(DataType::LargeBinary)
1067 }
1068
1069 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1070 let output_type = DataType::LargeBinary;
1071 invoke_cypher_udf(args, &output_type, |val_args| {
1072 startnode_endnode_impl(val_args, false)
1073 })
1074 }
1075}
1076
1077fn startnode_endnode_impl(val_args: &[Value], is_start: bool) -> DFResult<Value> {
1082 if val_args.is_empty() {
1083 let fn_name = if is_start { "startNode" } else { "endNode" };
1084 return Err(datafusion::error::DataFusionError::Execution(format!(
1085 "{fn_name}(): requires at least 1 argument"
1086 )));
1087 }
1088
1089 let edge_val = &val_args[0];
1090 let target_vid = extract_endpoint_vid(edge_val, is_start);
1091
1092 let target_vid = match target_vid {
1093 Some(vid) => vid,
1094 None => return Ok(Value::Null),
1095 };
1096
1097 for node_val in val_args.iter().skip(1) {
1099 if let Some(vid) = extract_vid(node_val)
1100 && vid == target_vid
1101 {
1102 return Ok(node_val.clone());
1103 }
1104 }
1105
1106 let mut map = std::collections::HashMap::new();
1108 map.insert("_vid".to_string(), Value::Int(target_vid as i64));
1109 Ok(Value::Map(map))
1110}
1111
1112fn extract_endpoint_vid(val: &Value, is_start: bool) -> Option<u64> {
1114 match val {
1115 Value::Edge(edge) => {
1116 let vid = if is_start { edge.src } else { edge.dst };
1117 Some(vid.as_u64())
1118 }
1119 Value::Map(map) => {
1120 let key = if is_start { "_src_vid" } else { "_dst_vid" };
1122 if let Some(v) = map.get(key) {
1123 return v.as_u64();
1124 }
1125 let key2 = if is_start { "_src" } else { "_dst" };
1127 if let Some(v) = map.get(key2) {
1128 return v.as_u64();
1129 }
1130 let node_key = if is_start { "_startNode" } else { "_endNode" };
1132 if let Some(node_val) = map.get(node_key) {
1133 return extract_vid(node_val);
1134 }
1135 None
1136 }
1137 _ => None,
1138 }
1139}
1140
1141fn extract_vid(val: &Value) -> Option<u64> {
1143 match val {
1144 Value::Map(map) => map.get("_vid").and_then(|v| v.as_u64()),
1145 _ => None,
1146 }
1147}
1148
1149fn extract_i64_range_arg(arg: &ColumnarValue, row_idx: usize, name: &str) -> DFResult<i64> {
1156 match arg {
1157 ColumnarValue::Scalar(sv) => match sv {
1158 ScalarValue::Int8(Some(v)) => Ok(*v as i64),
1159 ScalarValue::Int16(Some(v)) => Ok(*v as i64),
1160 ScalarValue::Int32(Some(v)) => Ok(*v as i64),
1161 ScalarValue::Int64(Some(v)) => Ok(*v),
1162 ScalarValue::UInt8(Some(v)) => Ok(*v as i64),
1163 ScalarValue::UInt16(Some(v)) => Ok(*v as i64),
1164 ScalarValue::UInt32(Some(v)) => Ok(*v as i64),
1165 ScalarValue::UInt64(Some(v)) => Ok(*v as i64),
1166 ScalarValue::LargeBinary(Some(bytes)) => {
1167 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1168 datafusion::error::DataFusionError::Execution(format!(
1169 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1170 name
1171 ))
1172 })
1173 }
1174 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1175 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1176 name
1177 ))),
1178 },
1179 ColumnarValue::Array(arr) => {
1180 if row_idx >= arr.len() || arr.is_null(row_idx) {
1181 return Err(datafusion::error::DataFusionError::Execution(format!(
1182 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1183 name
1184 )));
1185 }
1186 if !arr.is_empty() {
1188 use datafusion::arrow::array::{
1189 Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array,
1190 UInt32Array, UInt64Array,
1191 };
1192 match arr.data_type() {
1193 DataType::Int8 => Ok(arr
1194 .as_any()
1195 .downcast_ref::<Int8Array>()
1196 .unwrap()
1197 .value(row_idx) as i64),
1198 DataType::Int16 => Ok(arr
1199 .as_any()
1200 .downcast_ref::<Int16Array>()
1201 .unwrap()
1202 .value(row_idx) as i64),
1203 DataType::Int32 => Ok(arr
1204 .as_any()
1205 .downcast_ref::<Int32Array>()
1206 .unwrap()
1207 .value(row_idx) as i64),
1208 DataType::Int64 => Ok(arr
1209 .as_any()
1210 .downcast_ref::<Int64Array>()
1211 .unwrap()
1212 .value(row_idx)),
1213 DataType::UInt8 => Ok(arr
1214 .as_any()
1215 .downcast_ref::<UInt8Array>()
1216 .unwrap()
1217 .value(row_idx) as i64),
1218 DataType::UInt16 => Ok(arr
1219 .as_any()
1220 .downcast_ref::<UInt16Array>()
1221 .unwrap()
1222 .value(row_idx) as i64),
1223 DataType::UInt32 => Ok(arr
1224 .as_any()
1225 .downcast_ref::<UInt32Array>()
1226 .unwrap()
1227 .value(row_idx) as i64),
1228 DataType::UInt64 => Ok(arr
1229 .as_any()
1230 .downcast_ref::<UInt64Array>()
1231 .unwrap()
1232 .value(row_idx) as i64),
1233 DataType::LargeBinary => {
1234 let bytes = arr
1235 .as_any()
1236 .downcast_ref::<LargeBinaryArray>()
1237 .unwrap()
1238 .value(row_idx);
1239 scalar_binary_to_value(bytes).as_i64().ok_or_else(|| {
1240 datafusion::error::DataFusionError::Execution(format!(
1241 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1242 name
1243 ))
1244 })
1245 }
1246 _ => Err(datafusion::error::DataFusionError::Execution(format!(
1247 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1248 name
1249 ))),
1250 }
1251 } else {
1252 Err(datafusion::error::DataFusionError::Execution(format!(
1253 "ArgumentError: InvalidArgumentType - range() {} must be an integer",
1254 name
1255 )))
1256 }
1257 }
1258 }
1259}
1260
1261pub fn create_range_udf() -> ScalarUDF {
1263 ScalarUDF::new_from_impl(RangeUdf::new())
1264}
1265
1266#[derive(Debug)]
1267struct RangeUdf {
1268 signature: Signature,
1269}
1270
1271impl RangeUdf {
1272 fn new() -> Self {
1273 Self {
1274 signature: Signature::one_of(
1275 vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1276 Volatility::Immutable,
1277 ),
1278 }
1279 }
1280}
1281
1282impl_udf_eq_hash!(RangeUdf);
1283
1284impl ScalarUDFImpl for RangeUdf {
1285 fn as_any(&self) -> &dyn Any {
1286 self
1287 }
1288
1289 fn name(&self) -> &str {
1290 "range"
1291 }
1292
1293 fn signature(&self) -> &Signature {
1294 &self.signature
1295 }
1296
1297 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1298 Ok(DataType::List(Arc::new(
1299 arrow::datatypes::Field::new_list_field(DataType::Int64, true),
1300 )))
1301 }
1302
1303 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1304 if args.args.len() < 2 || args.args.len() > 3 {
1305 return Err(datafusion::error::DataFusionError::Execution(
1306 "range(): requires 2 or 3 arguments".to_string(),
1307 ));
1308 }
1309
1310 let len = args
1311 .args
1312 .iter()
1313 .find_map(|arg| match arg {
1314 ColumnarValue::Array(arr) => Some(arr.len()),
1315 _ => None,
1316 })
1317 .unwrap_or(1);
1318
1319 let mut list_builder =
1320 arrow_array::builder::ListBuilder::new(arrow_array::builder::Int64Builder::new());
1321
1322 for row_idx in 0..len {
1323 let start = extract_i64_range_arg(&args.args[0], row_idx, "start")?;
1324 let end = extract_i64_range_arg(&args.args[1], row_idx, "end")?;
1325 let step = if args.args.len() == 3 {
1326 extract_i64_range_arg(&args.args[2], row_idx, "step")?
1327 } else {
1328 1
1329 };
1330
1331 if step == 0 {
1332 return Err(datafusion::error::DataFusionError::Execution(
1333 "range(): step cannot be zero".to_string(),
1334 ));
1335 }
1336
1337 if step > 0 && start <= end {
1338 let mut current = start;
1339 while current <= end {
1340 list_builder.values().append_value(current);
1341 current += step;
1342 }
1343 } else if step < 0 && start >= end {
1344 let mut current = start;
1345 while current >= end {
1346 list_builder.values().append_value(current);
1347 current += step;
1348 }
1349 }
1350 list_builder.append(true);
1352 }
1353
1354 let list_arr = Arc::new(list_builder.finish()) as ArrayRef;
1355 if len == 1
1356 && args
1357 .args
1358 .iter()
1359 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
1360 {
1361 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
1362 &list_arr, 0,
1363 )?))
1364 } else {
1365 Ok(ColumnarValue::Array(list_arr))
1366 }
1367 }
1368}
1369
1370fn invoke_binary_bitwise_op<F>(
1378 args: &ScalarFunctionArgs,
1379 name: &str,
1380 op: F,
1381) -> DFResult<ColumnarValue>
1382where
1383 F: Fn(i64, i64) -> i64,
1384{
1385 use arrow_array::Int64Array;
1386 use datafusion::common::ScalarValue;
1387 use datafusion::error::DataFusionError;
1388
1389 if args.args.len() != 2 {
1390 return Err(DataFusionError::Execution(format!(
1391 "{}(): requires exactly 2 arguments",
1392 name
1393 )));
1394 }
1395
1396 let left = &args.args[0];
1397 let right = &args.args[1];
1398
1399 match (left, right) {
1400 (
1401 ColumnarValue::Scalar(ScalarValue::Int64(Some(l))),
1402 ColumnarValue::Scalar(ScalarValue::Int64(Some(r))),
1403 ) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*l, *r))))),
1404 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
1405 let l_arr = l_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1406 DataFusionError::Execution(format!("{}(): left array must be Int64", name))
1407 })?;
1408 let r_arr = r_arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1409 DataFusionError::Execution(format!("{}(): right array must be Int64", name))
1410 })?;
1411
1412 let result: Int64Array = l_arr
1413 .iter()
1414 .zip(r_arr.iter())
1415 .map(|(l, r)| match (l, r) {
1416 (Some(l), Some(r)) => Some(op(l, r)),
1417 _ => None,
1418 })
1419 .collect();
1420
1421 Ok(ColumnarValue::Array(Arc::new(result)))
1422 }
1423 _ => Err(DataFusionError::Execution(format!(
1424 "{}(): mixed scalar/array not supported",
1425 name
1426 ))),
1427 }
1428}
1429
1430fn invoke_unary_bitwise_op<F>(
1434 args: &ScalarFunctionArgs,
1435 name: &str,
1436 op: F,
1437) -> DFResult<ColumnarValue>
1438where
1439 F: Fn(i64) -> i64,
1440{
1441 use arrow_array::Int64Array;
1442 use datafusion::common::ScalarValue;
1443 use datafusion::error::DataFusionError;
1444
1445 if args.args.len() != 1 {
1446 return Err(DataFusionError::Execution(format!(
1447 "{}(): requires exactly 1 argument",
1448 name
1449 )));
1450 }
1451
1452 let operand = &args.args[0];
1453
1454 match operand {
1455 ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => {
1456 Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(op(*v)))))
1457 }
1458 ColumnarValue::Array(arr) => {
1459 let arr = arr.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
1460 DataFusionError::Execution(format!("{}(): array must be Int64", name))
1461 })?;
1462
1463 let result: Int64Array = arr.iter().map(|v| v.map(&op)).collect();
1464
1465 Ok(ColumnarValue::Array(Arc::new(result)))
1466 }
1467 _ => Err(DataFusionError::Execution(format!(
1468 "{}(): invalid argument type",
1469 name
1470 ))),
1471 }
1472}
1473
/// Generates a two-argument Int64 bitwise `ScalarUDFImpl` named `$udf_name`,
/// whose row-wise semantics are supplied by the `$op` closure.
macro_rules! define_binary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            // Cached signature: exactly (Int64, Int64), immutable volatility.
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                let signature = Signature::new(
                    TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
                    Volatility::Immutable,
                );
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_binary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1520
/// Generates a single-argument Int64 bitwise `ScalarUDFImpl` named
/// `$udf_name`, whose row-wise semantics are supplied by the `$op` closure.
macro_rules! define_unary_bitwise_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            // Cached signature: exactly one Int64 argument, immutable.
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                let signature = Signature::new(
                    TypeSignature::Exact(vec![DataType::Int64]),
                    Volatility::Immutable,
                );
                Self { signature }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn name(&self) -> &str {
                $udf_name
            }

            fn signature(&self) -> &Signature {
                &self.signature
            }

            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Int64)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_unary_bitwise_op(&args, $udf_name, $op)
            }
        }
    };
}
1564
// Concrete bitwise UDF types. Each invocation expands to a struct plus its
// ScalarUDFImpl, with the closure providing the row-wise Int64 behaviour.
//
// NOTE(review): `<<` and `>>` panic in debug builds (and mask the shift
// amount in release builds) when the right operand is negative or >= 64 —
// confirm callers validate the shift range.
define_binary_bitwise_udf!(BitwiseOrUdf, "uni.bitwise.or", |l, r| l | r);
define_binary_bitwise_udf!(BitwiseAndUdf, "uni.bitwise.and", |l, r| l & r);
define_binary_bitwise_udf!(BitwiseXorUdf, "uni.bitwise.xor", |l, r| l ^ r);
define_binary_bitwise_udf!(ShiftLeftUdf, "uni.bitwise.shiftLeft", |l, r| l << r);
define_binary_bitwise_udf!(ShiftRightUdf, "uni.bitwise.shiftRight", |l, r| l >> r);

// Bitwise complement is the only unary operator.
define_unary_bitwise_udf!(BitwiseNotUdf, "uni.bitwise.not", |v| !v);
1574
/// Returns the `uni.bitwise.or` scalar UDF: (Int64, Int64) -> Int64.
pub fn create_bitwise_or_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(BitwiseOrUdf::new())
}

/// Returns the `uni.bitwise.and` scalar UDF: (Int64, Int64) -> Int64.
pub fn create_bitwise_and_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(BitwiseAndUdf::new())
}

/// Returns the `uni.bitwise.xor` scalar UDF: (Int64, Int64) -> Int64.
pub fn create_bitwise_xor_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(BitwiseXorUdf::new())
}

/// Returns the `uni.bitwise.not` scalar UDF: Int64 -> Int64.
pub fn create_bitwise_not_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(BitwiseNotUdf::new())
}

/// Returns the `uni.bitwise.shiftLeft` scalar UDF: (Int64, Int64) -> Int64.
pub fn create_shift_left_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(ShiftLeftUdf::new())
}

/// Returns the `uni.bitwise.shiftRight` scalar UDF: (Int64, Int64) -> Int64.
pub fn create_shift_right_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(ShiftRightUdf::new())
}
1604
/// Builds a temporal scalar UDF registered under `name` (e.g. `datetime`,
/// `duration.between`); the name selects the behaviour at invoke time.
fn create_temporal_udf(name: &str) -> ScalarUDF {
    ScalarUDF::new_from_impl(TemporalUdf::new(name.to_string()))
}

/// One UDF type shared by all temporal functions; the stored `name`
/// determines both the return type and which datetime routine runs.
#[derive(Debug)]
struct TemporalUdf {
    // UDF name as registered (case preserved; lowercased/uppercased on use).
    name: String,
    signature: Signature,
}

impl TemporalUdf {
    fn new(name: String) -> Self {
        Self {
            name,
            // Temporal functions take either no arguments (e.g. `datetime()`)
            // or any number of arguments of any type.
            signature: Signature::new(
                TypeSignature::OneOf(vec![
                    TypeSignature::Exact(vec![]),
                    TypeSignature::VariadicAny,
                ]),
                Volatility::Immutable,
            ),
        }
    }
}

impl_udf_eq_hash!(TemporalUdf);
1643
1644impl ScalarUDFImpl for TemporalUdf {
1645 fn as_any(&self) -> &dyn Any {
1646 self
1647 }
1648
1649 fn name(&self) -> &str {
1650 &self.name
1651 }
1652
1653 fn signature(&self) -> &Signature {
1654 &self.signature
1655 }
1656
1657 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1658 let name = self.name.to_lowercase();
1659 match name.as_str() {
1660 "year" | "month" | "day" | "hour" | "minute" | "second" => Ok(DataType::Int64),
1662 "datetime"
1668 | "localdatetime"
1669 | "date"
1670 | "time"
1671 | "localtime"
1672 | "duration"
1673 | "date.truncate"
1674 | "time.truncate"
1675 | "datetime.truncate"
1676 | "localdatetime.truncate"
1677 | "localtime.truncate"
1678 | "duration.between"
1679 | "duration.inmonths"
1680 | "duration.indays"
1681 | "duration.inseconds"
1682 | "datetime.fromepoch"
1683 | "datetime.fromepochmillis"
1684 | "datetime.transaction"
1685 | "datetime.statement"
1686 | "datetime.realtime"
1687 | "date.transaction"
1688 | "date.statement"
1689 | "date.realtime"
1690 | "time.transaction"
1691 | "time.statement"
1692 | "time.realtime"
1693 | "localtime.transaction"
1694 | "localtime.statement"
1695 | "localtime.realtime"
1696 | "localdatetime.transaction"
1697 | "localdatetime.statement"
1698 | "localdatetime.realtime" => Ok(DataType::LargeBinary),
1699 _ => Ok(DataType::Utf8),
1700 }
1701 }
1702
1703 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1704 let func_name = self.name.to_uppercase();
1705 let output_type = self.return_type(&[])?;
1706 invoke_cypher_udf(args, &output_type, |val_args| {
1707 crate::query::datetime::eval_datetime_function(&func_name, val_args).map_err(|e| {
1708 datafusion::error::DataFusionError::Execution(format!("{}(): {}", self.name, e))
1709 })
1710 })
1711 }
1712}
1713
/// Builds the internal `_duration_property` UDF used to extract a named
/// component from a duration.
fn create_duration_property_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(DurationPropertyUdf::new())
}

#[derive(Debug)]
struct DurationPropertyUdf {
    signature: Signature,
}

impl DurationPropertyUdf {
    fn new() -> Self {
        Self {
            // Exactly two Utf8 arguments: the duration string and the
            // component name to extract.
            signature: Signature::new(
                TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
                Volatility::Immutable,
            ),
        }
    }
}

impl_udf_eq_hash!(DurationPropertyUdf);
1739
1740impl ScalarUDFImpl for DurationPropertyUdf {
1741 fn as_any(&self) -> &dyn Any {
1742 self
1743 }
1744
1745 fn name(&self) -> &str {
1746 "_duration_property"
1747 }
1748
1749 fn signature(&self) -> &Signature {
1750 &self.signature
1751 }
1752
1753 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1754 Ok(DataType::Int64)
1755 }
1756
1757 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1758 let output_type = self.return_type(&[])?;
1759 invoke_cypher_udf(args, &output_type, |val_args| {
1760 if val_args.len() != 2 {
1761 return Err(datafusion::error::DataFusionError::Execution(
1762 "_duration_property(): requires 2 arguments (duration_string, component)"
1763 .to_string(),
1764 ));
1765 }
1766
1767 let dur_string_owned;
1768 let dur_str = match &val_args[0] {
1769 Value::String(s) => s.as_str(),
1770 Value::Temporal(uni_common::TemporalValue::Duration { .. }) => {
1771 dur_string_owned = val_args[0].to_string();
1772 &dur_string_owned
1773 }
1774 Value::Null => return Ok(Value::Null),
1775 _ => {
1776 return Err(datafusion::error::DataFusionError::Execution(
1777 "_duration_property(): duration must be a string or temporal duration"
1778 .to_string(),
1779 ));
1780 }
1781 };
1782 let component = match &val_args[1] {
1783 Value::String(s) => s,
1784 _ => {
1785 return Err(datafusion::error::DataFusionError::Execution(
1786 "_duration_property(): component must be a string".to_string(),
1787 ));
1788 }
1789 };
1790
1791 crate::query::datetime::eval_duration_accessor(dur_str, component).map_err(|e| {
1792 datafusion::error::DataFusionError::Execution(format!(
1793 "_duration_property(): {}",
1794 e
1795 ))
1796 })
1797 })
1798 }
1799}
1800
/// Builds the Cypher `toString()` UDF.
fn create_tostring_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(ToStringUdf::new())
}

#[derive(Debug)]
struct ToStringUdf {
    signature: Signature,
}

impl ToStringUdf {
    fn new() -> Self {
        Self {
            // Any argument type is accepted; unsupported types are rejected
            // at invoke time with a Cypher TypeError message.
            signature: Signature::variadic_any(Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(ToStringUdf);
1823
1824impl ScalarUDFImpl for ToStringUdf {
1825 fn as_any(&self) -> &dyn Any {
1826 self
1827 }
1828
1829 fn name(&self) -> &str {
1830 "tostring"
1831 }
1832
1833 fn signature(&self) -> &Signature {
1834 &self.signature
1835 }
1836
1837 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1838 Ok(DataType::Utf8)
1839 }
1840
1841 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1842 let output_type = self.return_type(&[])?;
1843 invoke_cypher_udf(args, &output_type, |val_args| {
1844 if val_args.is_empty() {
1845 return Err(datafusion::error::DataFusionError::Execution(
1846 "toString(): requires 1 argument".to_string(),
1847 ));
1848 }
1849 match &val_args[0] {
1850 Value::Null => Ok(Value::Null),
1851 Value::String(s) => Ok(Value::String(s.clone())),
1852 Value::Int(i) => Ok(Value::String(i.to_string())),
1853 Value::Float(f) => Ok(Value::String(f.to_string())),
1854 Value::Bool(b) => Ok(Value::String(b.to_string())),
1855 Value::Temporal(t) => Ok(Value::String(t.to_string())),
1856 other => {
1857 let type_name = match other {
1858 Value::List(_) => "List",
1859 Value::Map(_) => "Map",
1860 Value::Node { .. } => "Node",
1861 Value::Edge { .. } => "Relationship",
1862 Value::Path { .. } => "Path",
1863 _ => "Unknown",
1864 };
1865 Err(datafusion::error::DataFusionError::Execution(format!(
1866 "TypeError: InvalidArgumentValue - toString() does not accept {} values",
1867 type_name
1868 )))
1869 }
1870 }
1871 })
1872 }
1873}
1874
/// Builds the internal `_temporal_property` UDF used to extract a named
/// component from a temporal value.
fn create_temporal_property_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(TemporalPropertyUdf::new())
}

#[derive(Debug)]
struct TemporalPropertyUdf {
    signature: Signature,
}

impl TemporalPropertyUdf {
    fn new() -> Self {
        Self {
            // Accepts any argument types; arity and types are validated at
            // invoke time.
            signature: Signature::variadic_any(Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(TemporalPropertyUdf);
1897
1898impl ScalarUDFImpl for TemporalPropertyUdf {
1899 fn as_any(&self) -> &dyn Any {
1900 self
1901 }
1902
1903 fn name(&self) -> &str {
1904 "_temporal_property"
1905 }
1906
1907 fn signature(&self) -> &Signature {
1908 &self.signature
1909 }
1910
1911 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
1912 Ok(DataType::LargeBinary)
1913 }
1914
1915 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
1916 let output_type = self.return_type(&[])?;
1917 invoke_cypher_udf(args, &output_type, |val_args| {
1918 if val_args.len() != 2 {
1919 return Err(datafusion::error::DataFusionError::Execution(
1920 "_temporal_property(): requires 2 arguments (temporal_value, component)"
1921 .to_string(),
1922 ));
1923 }
1924
1925 let component = match &val_args[1] {
1926 Value::String(s) => s.clone(),
1927 _ => {
1928 return Err(datafusion::error::DataFusionError::Execution(
1929 "_temporal_property(): component must be a string".to_string(),
1930 ));
1931 }
1932 };
1933
1934 crate::query::datetime::eval_temporal_accessor_value(&val_args[0], &component).map_err(
1935 |e| {
1936 datafusion::error::DataFusionError::Execution(format!(
1937 "_temporal_property(): {}",
1938 e
1939 ))
1940 },
1941 )
1942 })
1943 }
1944}
1945
/// Downcasts a `dyn Array` to the given concrete Arrow array type.
///
/// Expands to an expression that early-returns (via the hidden `?`) with an
/// execution error when the downcast fails, so it can only be used inside
/// functions returning `DFResult`.
macro_rules! downcast_arr {
    ($arr:expr, $array_type:ty) => {
        $arr.as_any().downcast_ref::<$array_type>().ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "Failed to downcast to {}",
                stringify!($array_type)
            ))
        })?
    };
}
1958
1959fn cypher_type_name(val: &Value) -> &'static str {
1961 match val {
1962 Value::Null => "Null",
1963 Value::Bool(_) => "Boolean",
1964 Value::Int(_) => "Integer",
1965 Value::Float(_) => "Float",
1966 Value::String(_) => "String",
1967 Value::Bytes(_) => "Bytes",
1968 Value::List(_) => "List",
1969 Value::Map(_) => "Map",
1970 Value::Node(_) => "Node",
1971 Value::Edge(_) => "Relationship",
1972 Value::Path(_) => "Path",
1973 Value::Vector(_) => "Vector",
1974 Value::Temporal(_) => "Temporal",
1975 _ => "Unknown",
1976 }
1977}
1978
1979fn string_to_value(s: &str) -> Value {
1981 if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
1982 && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
1983 {
1984 return Value::from(obj);
1985 }
1986 Value::String(s.to_string())
1987}
1988
/// Extracts row `row` from an Arrow array as a runtime `Value`.
///
/// Fast paths exist for the common column types; anything else goes through
/// `ScalarValue::try_from_array` + `scalar_to_value`. Null cells always
/// become `Value::Null`.
fn get_value_from_array(arr: &ArrayRef, row: usize) -> DFResult<Value> {
    if arr.is_null(row) {
        return Ok(Value::Null);
    }

    match arr.data_type() {
        DataType::LargeBinary => {
            let typed = downcast_arr!(arr, LargeBinaryArray);
            let bytes = typed.value(row);
            // Preferred encoding is the cypher value codec; fall back to
            // JSON, and to Null when neither decodes.
            if let Ok(val) = uni_common::cypher_value_codec::decode(bytes) {
                return Ok(val);
            }
            Ok(serde_json::from_slice::<serde_json::Value>(bytes)
                .map(Value::from)
                .unwrap_or(Value::Null))
        }
        DataType::Int64 => Ok(Value::Int(downcast_arr!(arr, Int64Array).value(row))),
        DataType::Float64 => Ok(Value::Float(downcast_arr!(arr, Float64Array).value(row))),
        // String cells may carry JSON payloads; string_to_value decodes them.
        DataType::Utf8 => Ok(string_to_value(downcast_arr!(arr, StringArray).value(row))),
        DataType::LargeUtf8 => Ok(string_to_value(
            downcast_arr!(arr, LargeStringArray).value(row),
        )),
        DataType::Boolean => Ok(Value::Bool(downcast_arr!(arr, BooleanArray).value(row))),
        // Narrower / unsigned numerics widen into the Value scalar types.
        DataType::UInt64 => Ok(Value::Int(downcast_arr!(arr, UInt64Array).value(row) as i64)),
        DataType::Int32 => Ok(Value::Int(downcast_arr!(arr, Int32Array).value(row) as i64)),
        DataType::Float32 => Ok(Value::Float(
            downcast_arr!(arr, Float32Array).value(row) as f64
        )),
        _ => {
            // Generic path: lift the cell into a ScalarValue and convert.
            let scalar = ScalarValue::try_from_array(arr, row).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!(
                    "Cannot extract scalar from array at row {}: {}",
                    row, e
                ))
            })?;
            scalar_to_value(&scalar)
        }
    }
}
2036
2037fn get_value_args_for_row(args: &[ColumnarValue], row: usize) -> DFResult<Vec<Value>> {
2039 args.iter()
2040 .map(|arg| match arg {
2041 ColumnarValue::Scalar(scalar) => scalar_to_value(scalar),
2042 ColumnarValue::Array(arr) => get_value_from_array(arr, row),
2043 })
2044 .collect()
2045}
2046
/// Shared driver for value-level Cypher UDFs: converts each row's columnar
/// arguments to `Value`s, applies `f`, and packs the results back into a
/// column of `output_type`.
///
/// All-scalar invocations take a fast path that returns a scalar; batch
/// invocations evaluate row by row and build an array via `values_to_array`.
fn invoke_cypher_udf<F>(
    args: ScalarFunctionArgs,
    output_type: &DataType,
    f: F,
) -> DFResult<ColumnarValue>
where
    F: Fn(&[Value]) -> DFResult<Value>,
{
    // Batch length comes from the first array argument; an all-scalar call
    // has no arrays and is treated as a single row.
    let len = args
        .args
        .iter()
        .find_map(|arg| match arg {
            ColumnarValue::Array(arr) => Some(arr.len()),
            _ => None,
        })
        .unwrap_or(1);

    if len == 1
        && args
            .args
            .iter()
            .all(|a| matches!(a, ColumnarValue::Scalar(_)))
    {
        let row_args = get_value_args_for_row(&args.args, 0)?;
        let res = f(&row_args)?;
        // Encoded outputs (LargeBinary / List) must round-trip through an
        // array so the scalar carries the exact output type.
        if matches!(output_type, DataType::LargeBinary | DataType::List(_)) {
            let arr = values_to_array(&[res], output_type)
                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
            return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&arr, 0)?));
        }
        // A null result is emitted as a typed null matching the declared
        // output type (falling back to a Utf8 null).
        if res.is_null() {
            let typed_null = ScalarValue::try_from(output_type).unwrap_or(ScalarValue::Utf8(None));
            return Ok(ColumnarValue::Scalar(typed_null));
        }
        return value_to_columnar(&res);
    }

    // Batch path: evaluate per row, then build one output array.
    let mut results = Vec::with_capacity(len);
    for i in 0..len {
        let row_args = get_value_args_for_row(&args.args, i)?;
        results.push(f(&row_args)?);
    }

    let arr = values_to_array(&results, output_type)
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
    Ok(ColumnarValue::Array(arr))
}
2097
2098fn scalar_arr_to_value(arr: &dyn arrow::array::Array) -> DFResult<Value> {
2101 if arr.is_empty() || arr.is_null(0) {
2102 Ok(Value::Null)
2103 } else {
2104 Ok(uni_store::storage::arrow_convert::arrow_to_value(
2106 arr, 0, None,
2107 ))
2108 }
2109}
2110
2111fn resolve_timezone_offset(tz_name: &str, nanos_utc: i64) -> i32 {
2113 if tz_name == "UTC" || tz_name == "Z" {
2114 return 0;
2115 }
2116 if let Ok(tz) = tz_name.parse::<chrono_tz::Tz>() {
2117 let dt = chrono::DateTime::from_timestamp_nanos(nanos_utc).with_timezone(&tz);
2118 dt.offset().fix().local_minus_utc()
2119 } else {
2120 0
2121 }
2122}
2123
2124fn duration_micros_to_value(micros: i64) -> Value {
2126 let dur = crate::query::datetime::CypherDuration::from_micros(micros);
2127 Value::Temporal(uni_common::TemporalValue::Duration {
2128 months: dur.months,
2129 days: dur.days,
2130 nanos: dur.nanos,
2131 })
2132}
2133
2134fn timestamp_nanos_to_value(nanos: i64, tz: Option<&Arc<str>>) -> DFResult<Value> {
2136 if let Some(tz_str) = tz {
2137 let offset = resolve_timezone_offset(tz_str.as_ref(), nanos);
2138 let tz_name = if tz_str.as_ref() == "UTC" {
2139 None
2140 } else {
2141 Some(tz_str.to_string())
2142 };
2143 Ok(Value::Temporal(uni_common::TemporalValue::DateTime {
2144 nanos_since_epoch: nanos,
2145 offset_seconds: offset,
2146 timezone_name: tz_name,
2147 }))
2148 } else {
2149 Ok(Value::Temporal(uni_common::TemporalValue::LocalDateTime {
2150 nanos_since_epoch: nanos,
2151 }))
2152 }
2153}
2154
/// Converts a DataFusion `ScalarValue` into a runtime `Value`.
///
/// Strings and LargeBinary payloads are opportunistically decoded (cypher
/// value codec first, then JSON); numeric types widen to Int/Float; temporal
/// scalars map onto `TemporalValue`; every typed null maps to `Value::Null`.
///
/// # Errors
/// Returns an execution error for scalar types with no mapping.
pub(crate) fn scalar_to_value(scalar: &ScalarValue) -> DFResult<Value> {
    match scalar {
        ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
            // Strings that look like JSON and parse cleanly become the
            // structured value (mirrors string_to_value).
            if (s.starts_with('{') || s.starts_with('[') || s.starts_with('"'))
                && let Ok(obj) = serde_json::from_str::<serde_json::Value>(s)
            {
                return Ok(Value::from(obj));
            }
            Ok(Value::String(s.clone()))
        }
        ScalarValue::LargeBinary(Some(b)) => {
            // Preferred encoding is the cypher value codec; fall back to
            // JSON, and to Null when neither decodes.
            if let Ok(val) = uni_common::cypher_value_codec::decode(b) {
                return Ok(val);
            }
            if let Ok(obj) = serde_json::from_slice::<serde_json::Value>(b) {
                Ok(Value::from(obj))
            } else {
                Ok(Value::Null)
            }
        }
        ScalarValue::Int64(Some(i)) => Ok(Value::Int(*i)),
        ScalarValue::Int32(Some(i)) => Ok(Value::Int(*i as i64)),
        ScalarValue::Float64(Some(f)) => {
            Ok(Value::Float(*f))
        }
        ScalarValue::Boolean(Some(b)) => Ok(Value::Bool(*b)),
        // Nested scalars carry a single-element array; convert its row 0.
        ScalarValue::Struct(arr) => scalar_arr_to_value(arr.as_ref()),
        ScalarValue::List(arr) => scalar_arr_to_value(arr.as_ref()),
        ScalarValue::LargeList(arr) => scalar_arr_to_value(arr.as_ref()),
        ScalarValue::FixedSizeList(arr) => scalar_arr_to_value(arr.as_ref()),
        // Remaining integer widths widen to i64 (u64 may wrap to negative).
        ScalarValue::UInt64(Some(u)) => Ok(Value::Int(*u as i64)),
        ScalarValue::UInt32(Some(u)) => Ok(Value::Int(*u as i64)),
        ScalarValue::UInt16(Some(u)) => Ok(Value::Int(*u as i64)),
        ScalarValue::UInt8(Some(u)) => Ok(Value::Int(*u as i64)),
        ScalarValue::Int16(Some(i)) => Ok(Value::Int(*i as i64)),
        ScalarValue::Int8(Some(i)) => Ok(Value::Int(*i as i64)),

        // Temporal scalars map onto the TemporalValue representations.
        ScalarValue::Date32(Some(days)) => Ok(Value::Temporal(uni_common::TemporalValue::Date {
            days_since_epoch: *days,
        })),
        ScalarValue::Date64(Some(millis)) => {
            // Date64 is millis since epoch; truncate to whole days.
            let days = (*millis / 86_400_000) as i32;
            Ok(Value::Temporal(uni_common::TemporalValue::Date {
                days_since_epoch: days,
            }))
        }
        ScalarValue::TimestampNanosecond(Some(nanos), tz) => {
            timestamp_nanos_to_value(*nanos, tz.as_ref())
        }
        ScalarValue::TimestampMicrosecond(Some(micros), tz) => {
            timestamp_nanos_to_value(*micros * 1_000, tz.as_ref())
        }
        ScalarValue::TimestampMillisecond(Some(millis), tz) => {
            timestamp_nanos_to_value(*millis * 1_000_000, tz.as_ref())
        }
        ScalarValue::TimestampSecond(Some(secs), tz) => {
            timestamp_nanos_to_value(*secs * 1_000_000_000, tz.as_ref())
        }
        ScalarValue::Time64Nanosecond(Some(nanos)) => {
            Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
                nanos_since_midnight: *nanos,
            }))
        }
        ScalarValue::Time64Microsecond(Some(micros)) => {
            Ok(Value::Temporal(uni_common::TemporalValue::LocalTime {
                nanos_since_midnight: *micros * 1_000,
            }))
        }
        ScalarValue::IntervalMonthDayNano(Some(v)) => {
            Ok(Value::Temporal(uni_common::TemporalValue::Duration {
                months: v.months as i64,
                days: v.days as i64,
                nanos: v.nanoseconds,
            }))
        }
        ScalarValue::DurationMicrosecond(Some(micros)) => Ok(duration_micros_to_value(*micros)),
        ScalarValue::DurationMillisecond(Some(millis)) => {
            Ok(duration_micros_to_value(*millis * 1_000))
        }
        ScalarValue::DurationSecond(Some(secs)) => Ok(duration_micros_to_value(*secs * 1_000_000)),
        ScalarValue::DurationNanosecond(Some(nanos)) => {
            Ok(Value::Temporal(uni_common::TemporalValue::Duration {
                months: 0,
                days: 0,
                nanos: *nanos,
            }))
        }
        ScalarValue::Float32(Some(f)) => Ok(Value::Float(*f as f64)),

        // Every supported typed null collapses to Value::Null.
        ScalarValue::Null
        | ScalarValue::Utf8(None)
        | ScalarValue::LargeUtf8(None)
        | ScalarValue::LargeBinary(None)
        | ScalarValue::Int64(None)
        | ScalarValue::Int32(None)
        | ScalarValue::Int16(None)
        | ScalarValue::Int8(None)
        | ScalarValue::UInt64(None)
        | ScalarValue::UInt32(None)
        | ScalarValue::UInt16(None)
        | ScalarValue::UInt8(None)
        | ScalarValue::Float64(None)
        | ScalarValue::Float32(None)
        | ScalarValue::Boolean(None)
        | ScalarValue::Date32(None)
        | ScalarValue::Date64(None)
        | ScalarValue::TimestampMicrosecond(None, _)
        | ScalarValue::TimestampMillisecond(None, _)
        | ScalarValue::TimestampSecond(None, _)
        | ScalarValue::TimestampNanosecond(None, _)
        | ScalarValue::Time64Microsecond(None)
        | ScalarValue::Time64Nanosecond(None)
        | ScalarValue::DurationMicrosecond(None)
        | ScalarValue::DurationMillisecond(None)
        | ScalarValue::DurationSecond(None)
        | ScalarValue::DurationNanosecond(None)
        | ScalarValue::IntervalMonthDayNano(None) => Ok(Value::Null),
        other => Err(datafusion::error::DataFusionError::Execution(format!(
            "scalar_to_value(): unsupported scalar type {other:?}"
        ))),
    }
}
2286
/// Converts a scalar runtime `Value` into a scalar `ColumnarValue`.
///
/// Only scalar and temporal values are supported; structural values (lists,
/// maps, nodes, ...) are rejected with an execution error — callers that
/// need those go through `values_to_array` instead. Note that `Value::Null`
/// maps to an untyped Utf8 null.
fn value_to_columnar(val: &Value) -> DFResult<ColumnarValue> {
    let scalar = match val {
        Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
        Value::Int(i) => ScalarValue::Int64(Some(*i)),
        Value::Float(f) => ScalarValue::Float64(Some(*f)),
        Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
        Value::Null => ScalarValue::Utf8(None),
        Value::Temporal(tv) => {
            use uni_common::TemporalValue;
            match tv {
                TemporalValue::Date { days_since_epoch } => {
                    ScalarValue::Date32(Some(*days_since_epoch))
                }
                TemporalValue::LocalTime {
                    nanos_since_midnight,
                } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
                // NOTE(review): a zoned Time drops its offset here and is
                // emitted like a LocalTime — confirm this is intentional.
                TemporalValue::Time {
                    nanos_since_midnight,
                    ..
                } => ScalarValue::Time64Nanosecond(Some(*nanos_since_midnight)),
                TemporalValue::LocalDateTime { nanos_since_epoch } => {
                    ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), None)
                }
                TemporalValue::DateTime {
                    nanos_since_epoch,
                    timezone_name,
                    ..
                } => {
                    // A missing zone name means UTC (see timestamp_nanos_to_value).
                    let tz = timezone_name.as_deref().unwrap_or("UTC");
                    ScalarValue::TimestampNanosecond(Some(*nanos_since_epoch), Some(tz.into()))
                }
                TemporalValue::Duration {
                    months,
                    days,
                    nanos,
                } => ScalarValue::IntervalMonthDayNano(Some(
                    arrow::datatypes::IntervalMonthDayNano {
                        months: *months as i32,
                        days: *days as i32,
                        nanoseconds: *nanos,
                    },
                )),
                TemporalValue::Btic { lo, hi, meta } => {
                    // BTIC values are packed into a 24-byte fixed binary.
                    let btic = uni_btic::Btic::new(*lo, *hi, *meta).map_err(|e| {
                        datafusion::error::DataFusionError::Execution(format!("invalid BTIC: {e}"))
                    })?;
                    let packed = uni_btic::encode::encode(&btic);
                    ScalarValue::FixedSizeBinary(24, Some(packed.to_vec()))
                }
            }
        }
        other => {
            return Err(datafusion::error::DataFusionError::Execution(format!(
                "value_to_columnar(): unsupported type {other:?}"
            )));
        }
    };
    Ok(ColumnarValue::Scalar(scalar))
}
2347
/// Builds the internal `_has_null` UDF, which reports whether a list value
/// contains any null element.
pub fn create_has_null_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(HasNullUdf::new())
}

#[derive(Debug)]
struct HasNullUdf {
    signature: Signature,
}

impl HasNullUdf {
    fn new() -> Self {
        Self {
            // One argument of any type; non-list arrays are rejected at
            // invoke time.
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(HasNullUdf);
2371
impl ScalarUDFImpl for HasNullUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "_has_null"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        Ok(DataType::Boolean)
    }

    /// Returns true per row when the list element at that row contains at
    /// least one null; a null list row yields a null result. Non-list
    /// scalars report false; non-list arrays are an execution error.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        if args.args.len() != 1 {
            return Err(datafusion::error::DataFusionError::Execution(
                "_has_null(): requires 1 argument".to_string(),
            ));
        }

        // Checks whether the list at `idx` contains a null element. Callers
        // only pass in-range indices, so `arr.is_empty()` is effectively a
        // defensive guard.
        fn check_list_nulls<T: arrow_array::OffsetSizeTrait>(
            arr: &arrow_array::GenericListArray<T>,
            idx: usize,
        ) -> bool {
            if arr.is_null(idx) || arr.is_empty() {
                false
            } else {
                arr.value(idx).null_count() > 0
            }
        }

        match &args.args[0] {
            ColumnarValue::Scalar(scalar) => {
                // List-typed scalars wrap a single-element array; inspect
                // element 0. Any non-list scalar reports false.
                let has_null = match scalar {
                    ScalarValue::List(arr) => arr
                        .as_any()
                        .downcast_ref::<arrow::array::ListArray>()
                        .map(|a| !a.is_empty() && a.value(0).null_count() > 0)
                        .unwrap_or(arr.null_count() > 0),
                    ScalarValue::LargeList(arr) => arr.len() > 0 && arr.value(0).null_count() > 0,
                    ScalarValue::FixedSizeList(arr) => {
                        arr.len() > 0 && arr.value(0).null_count() > 0
                    }
                    _ => false,
                };
                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(has_null))))
            }
            ColumnarValue::Array(arr) => {
                use arrow_array::{LargeListArray, ListArray};

                // Row-wise evaluation; null list rows propagate as None.
                let results: arrow::array::BooleanArray =
                    if let Some(list_arr) = arr.as_any().downcast_ref::<ListArray>() {
                        (0..list_arr.len())
                            .map(|i| {
                                if list_arr.is_null(i) {
                                    None
                                } else {
                                    Some(check_list_nulls(list_arr, i))
                                }
                            })
                            .collect()
                    } else if let Some(large) = arr.as_any().downcast_ref::<LargeListArray>() {
                        (0..large.len())
                            .map(|i| {
                                if large.is_null(i) {
                                    None
                                } else {
                                    Some(check_list_nulls(large, i))
                                }
                            })
                            .collect()
                    } else {
                        return Err(datafusion::error::DataFusionError::Execution(
                            "_has_null(): requires list array".to_string(),
                        ));
                    };
                Ok(ColumnarValue::Array(Arc::new(results)))
            }
        }
    }
}
2458
/// Builds the Cypher `toInteger()` UDF.
pub fn create_to_integer_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(ToIntegerUdf::new())
}

#[derive(Debug)]
struct ToIntegerUdf {
    signature: Signature,
}

impl ToIntegerUdf {
    fn new() -> Self {
        Self {
            // One argument of any type; unsupported types are rejected at
            // invoke time.
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(ToIntegerUdf);
2481
2482impl ScalarUDFImpl for ToIntegerUdf {
2483 fn as_any(&self) -> &dyn Any {
2484 self
2485 }
2486
2487 fn name(&self) -> &str {
2488 "tointeger"
2489 }
2490
2491 fn signature(&self) -> &Signature {
2492 &self.signature
2493 }
2494
2495 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2496 Ok(DataType::Int64)
2497 }
2498
2499 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2500 let output_type = self.return_type(&[])?;
2501 invoke_cypher_udf(args, &output_type, |val_args| {
2502 if val_args.is_empty() {
2503 return Err(datafusion::error::DataFusionError::Execution(
2504 "tointeger(): requires 1 argument".to_string(),
2505 ));
2506 }
2507
2508 let val = &val_args[0];
2509 match val {
2510 Value::Int(i) => Ok(Value::Int(*i)),
2511 Value::Float(f) => Ok(Value::Int(*f as i64)),
2512 Value::String(s) => {
2513 if let Ok(i) = s.parse::<i64>() {
2514 Ok(Value::Int(i))
2515 } else if let Ok(f) = s.parse::<f64>() {
2516 Ok(Value::Int(f as i64))
2517 } else {
2518 Ok(Value::Null)
2519 }
2520 }
2521 Value::Null => Ok(Value::Null),
2522 other => Err(datafusion::error::DataFusionError::Execution(format!(
2523 "InvalidArgumentValue: tointeger(): cannot convert {} to integer",
2524 cypher_type_name(other)
2525 ))),
2526 }
2527 })
2528 }
2529}
2530
/// Builds the Cypher `toFloat()` UDF.
pub fn create_to_float_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(ToFloatUdf::new())
}

#[derive(Debug)]
struct ToFloatUdf {
    signature: Signature,
}

impl ToFloatUdf {
    fn new() -> Self {
        Self {
            // One argument of any type; unsupported types are rejected at
            // invoke time.
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(ToFloatUdf);
2553
2554impl ScalarUDFImpl for ToFloatUdf {
2555 fn as_any(&self) -> &dyn Any {
2556 self
2557 }
2558
2559 fn name(&self) -> &str {
2560 "tofloat"
2561 }
2562
2563 fn signature(&self) -> &Signature {
2564 &self.signature
2565 }
2566
2567 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2568 Ok(DataType::Float64)
2569 }
2570
2571 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2572 let output_type = self.return_type(&[])?;
2573 invoke_cypher_udf(args, &output_type, |val_args| {
2574 if val_args.is_empty() {
2575 return Err(datafusion::error::DataFusionError::Execution(
2576 "tofloat(): requires 1 argument".to_string(),
2577 ));
2578 }
2579
2580 let val = &val_args[0];
2581 match val {
2582 Value::Int(i) => Ok(Value::Float(*i as f64)),
2583 Value::Float(f) => Ok(Value::Float(*f)),
2584 Value::String(s) => {
2585 if let Ok(f) = s.parse::<f64>() {
2586 Ok(Value::Float(f))
2587 } else {
2588 Ok(Value::Null)
2589 }
2590 }
2591 Value::Null => Ok(Value::Null),
2592 other => Err(datafusion::error::DataFusionError::Execution(format!(
2593 "InvalidArgumentValue: tofloat(): cannot convert {} to float",
2594 cypher_type_name(other)
2595 ))),
2596 }
2597 })
2598 }
2599}
2600
/// Builds the Cypher `toBoolean()` UDF.
pub fn create_to_boolean_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(ToBooleanUdf::new())
}

#[derive(Debug)]
struct ToBooleanUdf {
    signature: Signature,
}

impl ToBooleanUdf {
    fn new() -> Self {
        Self {
            // One argument of any type; unsupported types are rejected at
            // invoke time.
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(ToBooleanUdf);
2623
2624impl ScalarUDFImpl for ToBooleanUdf {
2625 fn as_any(&self) -> &dyn Any {
2626 self
2627 }
2628
2629 fn name(&self) -> &str {
2630 "toboolean"
2631 }
2632
2633 fn signature(&self) -> &Signature {
2634 &self.signature
2635 }
2636
2637 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2638 Ok(DataType::Boolean)
2639 }
2640
2641 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2642 let output_type = self.return_type(&[])?;
2643 invoke_cypher_udf(args, &output_type, |val_args| {
2644 if val_args.is_empty() {
2645 return Err(datafusion::error::DataFusionError::Execution(
2646 "toboolean(): requires 1 argument".to_string(),
2647 ));
2648 }
2649
2650 let val = &val_args[0];
2651 match val {
2652 Value::Bool(b) => Ok(Value::Bool(*b)),
2653 Value::String(s) => {
2654 let s_lower = s.to_lowercase();
2655 if s_lower == "true" {
2656 Ok(Value::Bool(true))
2657 } else if s_lower == "false" {
2658 Ok(Value::Bool(false))
2659 } else {
2660 Ok(Value::Null)
2661 }
2662 }
2663 Value::Null => Ok(Value::Null),
2664 Value::Int(i) => Ok(Value::Bool(*i != 0)),
2665 other => Err(datafusion::error::DataFusionError::Execution(format!(
2666 "InvalidArgumentValue: toboolean(): cannot convert {} to boolean",
2667 cypher_type_name(other)
2668 ))),
2669 }
2670 })
2671 }
2672}
2673
2674pub fn create_cypher_sort_key_udf() -> ScalarUDF {
2681 ScalarUDF::new_from_impl(CypherSortKeyUdf::new())
2682}
2683
/// Internal `_cypher_sort_key` UDF: maps any Cypher value to an
/// order-preserving binary key so ORDER BY can sort heterogeneous values.
#[derive(Debug)]
struct CypherSortKeyUdf {
    // Accepts any single argument; the encoder handles every value kind.
    signature: Signature,
}

impl CypherSortKeyUdf {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

// Equality by signature / hash by name (shared UDF identity scheme).
impl_udf_eq_hash!(CypherSortKeyUdf);
2698
2699impl ScalarUDFImpl for CypherSortKeyUdf {
2700 fn as_any(&self) -> &dyn Any {
2701 self
2702 }
2703
2704 fn name(&self) -> &str {
2705 "_cypher_sort_key"
2706 }
2707
2708 fn signature(&self) -> &Signature {
2709 &self.signature
2710 }
2711
2712 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
2713 Ok(DataType::LargeBinary)
2714 }
2715
2716 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
2717 if args.args.len() != 1 {
2718 return Err(datafusion::error::DataFusionError::Execution(
2719 "_cypher_sort_key(): requires 1 argument".to_string(),
2720 ));
2721 }
2722
2723 let arg = &args.args[0];
2724 match arg {
2725 ColumnarValue::Scalar(s) => {
2726 let val = if s.is_null() {
2727 Value::Null
2728 } else {
2729 scalar_to_value(s)?
2730 };
2731 let key = encode_cypher_sort_key(&val);
2732 Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(key))))
2733 }
2734 ColumnarValue::Array(arr) => {
2735 let mut keys: Vec<Option<Vec<u8>>> = Vec::with_capacity(arr.len());
2736 for i in 0..arr.len() {
2737 let val = if arr.is_null(i) {
2738 Value::Null
2739 } else {
2740 get_value_from_array(arr, i)?
2741 };
2742 keys.push(Some(encode_cypher_sort_key(&val)));
2743 }
2744 let array = LargeBinaryArray::from(
2745 keys.iter()
2746 .map(|k| k.as_deref())
2747 .collect::<Vec<Option<&[u8]>>>(),
2748 );
2749 Ok(ColumnarValue::Array(Arc::new(array)))
2750 }
2751 }
2752 }
2753}
2754
2755pub fn encode_cypher_sort_key(value: &Value) -> Vec<u8> {
2761 let mut buf = Vec::with_capacity(32);
2762 encode_sort_key_to_buf(value, &mut buf);
2763 buf
2764}
2765
2766fn encode_sort_key_to_buf(value: &Value, buf: &mut Vec<u8>) {
2768 if let Value::Map(map) = value {
2770 if let Some(tv) = sort_key_map_as_temporal(map) {
2771 buf.push(0x07); encode_temporal_payload(&tv, buf);
2773 return;
2774 }
2775 let rank = sort_key_map_rank(map);
2776 if rank != 0 {
2777 buf.push(rank);
2779 match rank {
2780 0x01 => encode_map_as_node_payload(map, buf),
2781 0x02 => encode_map_as_edge_payload(map, buf),
2782 0x04 => encode_map_as_path_payload(map, buf),
2783 _ => {} }
2785 return;
2786 }
2787 }
2788
2789 if let Value::String(s) = value {
2791 if let Some(tv) = sort_key_string_as_temporal(s) {
2792 buf.push(0x07); encode_temporal_payload(&tv, buf);
2794 return;
2795 }
2796 if let Some(temporal_type) = crate::query::datetime::classify_temporal(s) {
2799 buf.push(0x07); if encode_wide_temporal_sort_key(s, temporal_type, buf) {
2801 return;
2802 }
2803 buf.pop();
2805 }
2806 }
2807
2808 let rank = sort_key_type_rank(value);
2809 buf.push(rank);
2810
2811 match value {
2812 Value::Null => {} Value::Float(f) if f.is_nan() => {} Value::Bool(b) => buf.push(if *b { 0x01 } else { 0x00 }),
2815 Value::Int(i) => {
2816 let f = *i as f64;
2817 buf.extend_from_slice(&encode_order_preserving_f64(f));
2818 }
2819 Value::Float(f) => {
2820 buf.extend_from_slice(&encode_order_preserving_f64(*f));
2821 }
2822 Value::String(s) => {
2823 byte_stuff_terminate(s.as_bytes(), buf);
2824 }
2825 Value::Temporal(tv) => {
2826 encode_temporal_payload(tv, buf);
2827 }
2828 Value::List(items) => {
2829 encode_list_payload(items, buf);
2830 }
2831 Value::Map(map) => {
2832 encode_map_payload(map, buf);
2833 }
2834 Value::Node(node) => {
2835 encode_node_payload(node, buf);
2836 }
2837 Value::Edge(edge) => {
2838 encode_edge_payload(edge, buf);
2839 }
2840 Value::Path(path) => {
2841 encode_path_payload(path, buf);
2842 }
2843 Value::Bytes(b) => {
2845 byte_stuff_terminate(b, buf);
2846 }
2847 Value::Vector(v) => {
2848 for f in v {
2849 buf.extend_from_slice(&encode_order_preserving_f64(*f as f64));
2850 }
2851 }
2852 _ => {} }
2854}
2855
2856fn sort_key_type_rank(v: &Value) -> u8 {
2860 match v {
2861 Value::Map(map) => sort_key_map_rank(map),
2862 Value::Node(_) => 0x01,
2863 Value::Edge(_) => 0x02,
2864 Value::List(_) => 0x03,
2865 Value::Path(_) => 0x04,
2866 Value::String(_) => 0x05,
2867 Value::Bool(_) => 0x06,
2868 Value::Temporal(_) => 0x07,
2869 Value::Int(_) => 0x08,
2870 Value::Float(f) if f.is_nan() => 0x09,
2871 Value::Float(_) => 0x08,
2872 Value::Null => 0x0A,
2873 Value::Bytes(_) | Value::Vector(_) => 0x0B,
2874 _ => 0x0B, }
2876}
2877
2878fn sort_key_map_rank(map: &std::collections::HashMap<String, Value>) -> u8 {
2880 if sort_key_map_as_temporal(map).is_some() {
2881 0x07
2882 } else if map.contains_key("nodes")
2883 && (map.contains_key("relationships") || map.contains_key("edges"))
2884 {
2885 0x04 } else if map.contains_key("_eid")
2887 || map.contains_key("_src")
2888 || map.contains_key("_dst")
2889 || map.contains_key("_type")
2890 || map.contains_key("_type_name")
2891 {
2892 0x02 } else if map.contains_key("_vid") || map.contains_key("_labels") || map.contains_key("_label")
2894 {
2895 0x01 } else {
2897 0x00 }
2899}
2900
/// Interprets a map as a temporal value if it has the temporal-map shape.
/// Pure delegation to the shared expr_eval helper.
fn sort_key_map_as_temporal(
    map: &std::collections::HashMap<String, Value>,
) -> Option<uni_common::TemporalValue> {
    super::expr_eval::temporal_from_map_wrapper(map)
}
2909
/// Interprets a string as a temporal literal, if it parses as one.
/// Delegates to the shared expr_eval conversion (requires an owned Value).
fn sort_key_string_as_temporal(s: &str) -> Option<uni_common::TemporalValue> {
    super::expr_eval::temporal_from_value(&Value::String(s.to_string()))
}
2916
/// Encodes a temporal *string* into the same subtag layout as
/// `encode_temporal_payload`, but with i128 "wide" nanoseconds so datetimes
/// far outside the i64 nanosecond range still order correctly.
///
/// Returns `true` on success; on `false`, nothing has been appended to `buf`
/// (each branch pushes its subtag only after parsing succeeds).
fn encode_wide_temporal_sort_key(
    s: &str,
    temporal_type: uni_common::TemporalType,
    buf: &mut Vec<u8>,
) -> bool {
    match temporal_type {
        uni_common::TemporalType::LocalDateTime => {
            if let Some(ndt) = parse_naive_datetime(s) {
                // 0x03 = LocalDateTime subtag (matches encode_temporal_payload).
                buf.push(0x03); let wide_nanos = naive_datetime_to_wide_nanos(&ndt);
                buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
                return true;
            }
            false
        }
        uni_common::TemporalType::DateTime => {
            // Strip a trailing `[Zone/Name]` annotation before offset parsing.
            let base = if let Some(bracket_pos) = s.find('[') {
                &s[..bracket_pos]
            } else {
                s
            };
            // Try the fractional-seconds format first, then whole seconds.
            if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%.f%:z") {
                // 0x04 = DateTime subtag; keys compare in UTC.
                buf.push(0x04); let utc = dt.naive_utc();
                let wide_nanos = naive_datetime_to_wide_nanos(&utc);
                buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
                return true;
            }
            if let Ok(dt) = chrono::DateTime::parse_from_str(base, "%Y-%m-%dT%H:%M:%S%:z") {
                buf.push(0x04); let utc = dt.naive_utc();
                let wide_nanos = naive_datetime_to_wide_nanos(&utc);
                buf.extend_from_slice(&encode_order_preserving_i128(wide_nanos));
                return true;
            }
            false
        }
        uni_common::TemporalType::Date => {
            if let Ok(nd) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
                && let Some(epoch) = chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
            {
                // 0x00 = Date subtag; payload is days since the Unix epoch.
                buf.push(0x00); let days = nd.signed_duration_since(epoch).num_days() as i32;
                buf.extend_from_slice(&encode_order_preserving_i32(days));
                return true;
            }
            false
        }
        // Other temporal kinds have no wide form; the caller falls back.
        _ => false,
    }
}
2975
2976fn parse_naive_datetime(s: &str) -> Option<chrono::NaiveDateTime> {
2978 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
2979 .ok()
2980 .or_else(|| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").ok())
2981}
2982
2983fn naive_datetime_to_wide_nanos(ndt: &chrono::NaiveDateTime) -> i128 {
2986 let secs = ndt.and_utc().timestamp() as i128;
2987 let subsec_nanos = ndt.and_utc().timestamp_subsec_nanos() as i128;
2988 secs * 1_000_000_000 + subsec_nanos
2989}
2990
2991fn encode_map_as_node_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
2993 let mut labels: Vec<String> = Vec::new();
2995 if let Some(Value::List(lbls)) = map.get("_labels") {
2996 for l in lbls {
2997 if let Value::String(s) = l {
2998 labels.push(s.clone());
2999 }
3000 }
3001 } else if let Some(Value::String(lbl)) = map.get("_label") {
3002 labels.push(lbl.clone());
3003 }
3004 labels.sort();
3005
3006 let vid = map.get("_vid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3008
3009 let labels_joined = labels.join("\x01");
3011 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3012
3013 buf.extend_from_slice(&vid.to_be_bytes());
3015
3016 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3018 for (k, v) in map {
3019 if !k.starts_with('_') {
3020 props.insert(k.clone(), v.clone());
3021 }
3022 }
3023 encode_map_payload(&props, buf);
3024}
3025
3026fn encode_map_as_edge_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3028 let edge_type = map
3029 .get("_type")
3030 .or_else(|| map.get("_type_name"))
3031 .and_then(|v| {
3032 if let Value::String(s) = v {
3033 Some(s.as_str())
3034 } else {
3035 None
3036 }
3037 })
3038 .unwrap_or("");
3039
3040 byte_stuff_terminate(edge_type.as_bytes(), buf);
3041
3042 let src = map.get("_src").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3043 let dst = map.get("_dst").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3044 let eid = map.get("_eid").and_then(|v| v.as_i64()).unwrap_or(0) as u64;
3045
3046 buf.extend_from_slice(&src.to_be_bytes());
3047 buf.extend_from_slice(&dst.to_be_bytes());
3048 buf.extend_from_slice(&eid.to_be_bytes());
3049
3050 let mut props: std::collections::HashMap<String, Value> = std::collections::HashMap::new();
3052 for (k, v) in map {
3053 if !k.starts_with('_') {
3054 props.insert(k.clone(), v.clone());
3055 }
3056 }
3057 encode_map_payload(&props, buf);
3058}
3059
3060fn encode_map_as_path_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3062 if let Some(Value::List(nodes)) = map.get("nodes") {
3064 encode_list_payload(nodes, buf);
3065 } else {
3066 buf.push(0x00); }
3068 let edges = map.get("relationships").or_else(|| map.get("edges"));
3070 if let Some(Value::List(edges)) = edges {
3071 encode_list_payload(edges, buf);
3072 } else {
3073 buf.push(0x00); }
3075}
3076
/// Maps an f64 to 8 big-endian bytes whose lexicographic order matches the
/// numeric order: negatives flip all bits (reversing their order), while
/// non-negatives flip only the sign bit (lifting them above all negatives).
fn encode_order_preserving_f64(f: f64) -> [u8; 8] {
    const SIGN: u64 = 1u64 << 63;
    let bits = f.to_bits();
    let encoded = if bits & SIGN != 0 { !bits } else { bits ^ SIGN };
    encoded.to_be_bytes()
}
3094
/// Maps an i64 to 8 big-endian bytes ordered lexicographically like the
/// signed value: XOR-ing the sign bit biases i64::MIN..=i64::MAX onto
/// 0..=u64::MAX.
fn encode_order_preserving_i64(i: i64) -> [u8; 8] {
    let biased = (i as u64) ^ (1u64 << 63);
    biased.to_be_bytes()
}
3100
/// 32-bit variant of the sign-bit-bias order-preserving integer encoding.
fn encode_order_preserving_i32(i: i32) -> [u8; 4] {
    let biased = (i as u32) ^ (1u32 << 31);
    biased.to_be_bytes()
}
3105
/// 128-bit variant of the sign-bit-bias order-preserving integer encoding.
fn encode_order_preserving_i128(i: i128) -> [u8; 16] {
    let biased = (i as u128) ^ (1u128 << 127);
    biased.to_be_bytes()
}
3110
/// Appends `data` byte-stuffed, then the 0x00 0x00 terminator. Because
/// embedded zeros are escaped, the double-zero terminator cannot appear
/// inside the stuffed payload, keeping concatenated keys unambiguous.
fn byte_stuff_terminate(data: &[u8], buf: &mut Vec<u8>) {
    byte_stuff(data, buf);
    buf.extend_from_slice(&[0x00, 0x00]);
}

/// Copies `data` into `buf`, following every 0x00 with an 0xFF escape byte.
fn byte_stuff(data: &[u8], buf: &mut Vec<u8>) {
    for &byte in data {
        buf.push(byte);
        if byte == 0x00 {
            buf.push(0xFF);
        }
    }
}
3130
3131fn encode_list_payload(items: &[Value], buf: &mut Vec<u8>) {
3136 for item in items {
3137 buf.push(0x01); let elem_key = encode_cypher_sort_key(item);
3139 byte_stuff_terminate(&elem_key, buf);
3140 }
3141 buf.push(0x00); }
3143
3144fn encode_map_payload(map: &std::collections::HashMap<String, Value>, buf: &mut Vec<u8>) {
3146 let mut pairs: Vec<(&String, &Value)> = map.iter().collect();
3147 pairs.sort_by_key(|(k, _)| *k);
3148
3149 for (key, value) in pairs {
3150 buf.push(0x01); byte_stuff_terminate(key.as_bytes(), buf);
3152 let val_key = encode_cypher_sort_key(value);
3153 byte_stuff_terminate(&val_key, buf);
3154 }
3155 buf.push(0x00); }
3157
3158fn encode_node_payload(node: &uni_common::Node, buf: &mut Vec<u8>) {
3162 let mut labels = node.labels.clone();
3163 labels.sort();
3164 let labels_joined = labels.join("\x01");
3165 byte_stuff_terminate(labels_joined.as_bytes(), buf);
3166
3167 buf.extend_from_slice(&node.vid.as_u64().to_be_bytes());
3168
3169 encode_map_payload(&node.properties, buf);
3170}
3171
3172fn encode_edge_payload(edge: &uni_common::Edge, buf: &mut Vec<u8>) {
3176 byte_stuff_terminate(edge.edge_type.as_bytes(), buf);
3177
3178 buf.extend_from_slice(&edge.src.as_u64().to_be_bytes());
3179 buf.extend_from_slice(&edge.dst.as_u64().to_be_bytes());
3180 buf.extend_from_slice(&edge.eid.as_u64().to_be_bytes());
3181
3182 encode_map_payload(&edge.properties, buf);
3183}
3184
3185fn encode_path_payload(path: &uni_common::Path, buf: &mut Vec<u8>) {
3189 for node in &path.nodes {
3191 buf.push(0x01); let mut node_key = Vec::new();
3193 node_key.push(0x01); encode_node_payload(node, &mut node_key);
3195 byte_stuff_terminate(&node_key, buf);
3196 }
3197 buf.push(0x00); for edge in &path.edges {
3201 buf.push(0x01); let mut edge_key = Vec::new();
3203 edge_key.push(0x02); encode_edge_payload(edge, &mut edge_key);
3205 byte_stuff_terminate(&edge_key, buf);
3206 }
3207 buf.push(0x00); }
3209
/// Encodes a temporal value as a subtag byte plus an order-preserving
/// numeric payload. Subtags (0x00..=0x06) order the temporal kinds relative
/// to each other; payload widths vary per kind, which is safe because the
/// subtag is compared first.
fn encode_temporal_payload(tv: &uni_common::TemporalValue, buf: &mut Vec<u8>) {
    match tv {
        // 0x00: date as days since the Unix epoch (i32).
        uni_common::TemporalValue::Date { days_since_epoch } => {
            buf.push(0x00); buf.extend_from_slice(&encode_order_preserving_i32(*days_since_epoch));
        }
        // 0x01: local time as nanos since midnight (i64).
        uni_common::TemporalValue::LocalTime {
            nanos_since_midnight,
        } => {
            buf.push(0x01); buf.extend_from_slice(&encode_order_preserving_i64(*nanos_since_midnight));
        }
        // 0x02: zoned time normalized to UTC nanos; widened to i128 because
        // subtracting the offset can leave the i64 nanos-per-day range.
        uni_common::TemporalValue::Time {
            nanos_since_midnight,
            offset_seconds,
        } => {
            buf.push(0x02); let utc_nanos =
                *nanos_since_midnight as i128 - (*offset_seconds as i128) * 1_000_000_000;
            buf.extend_from_slice(&encode_order_preserving_i128(utc_nanos));
        }
        // 0x03: local datetime as nanos since epoch, widened to i128 to match
        // the "wide" string-encoding path.
        uni_common::TemporalValue::LocalDateTime { nanos_since_epoch } => {
            buf.push(0x03); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
        }
        // 0x04: zoned datetime; only the epoch nanos participate in ordering.
        uni_common::TemporalValue::DateTime {
            nanos_since_epoch, ..
        } => {
            buf.push(0x04); buf.extend_from_slice(&encode_order_preserving_i128(*nanos_since_epoch as i128));
        }
        // 0x05: duration ordered by months, then days, then nanos.
        uni_common::TemporalValue::Duration {
            months,
            days,
            nanos,
        } => {
            buf.push(0x05); buf.extend_from_slice(&encode_order_preserving_i64(*months));
            buf.extend_from_slice(&encode_order_preserving_i64(*days));
            buf.extend_from_slice(&encode_order_preserving_i64(*nanos));
        }
        // 0x06: BTIC interval in its canonical encoding; if construction
        // fails, fall back to the raw lo/hi bounds.
        uni_common::TemporalValue::Btic { lo, hi, meta } => {
            buf.push(0x06); if let Ok(btic) = uni_btic::Btic::new(*lo, *hi, *meta) {
                buf.extend_from_slice(&uni_btic::encode::encode(&btic));
            } else {
                buf.extend_from_slice(&encode_order_preserving_i64(*lo));
                buf.extend_from_slice(&encode_order_preserving_i64(*hi));
            }
        }
    }
}
3266
/// Generic scalar UDF shell for the `btic_*` function family: carries only
/// the name, arity, and return type, and routes evaluation through the
/// shared BTIC expression evaluator.
#[derive(Debug)]
struct BticScalarUdf {
    name: String,
    signature: Signature,
    return_type: DataType,
}

impl BticScalarUdf {
    fn new(name: &str, num_args: usize, return_type: DataType) -> Self {
        Self {
            name: name.to_string(),
            // Arguments are untyped here; the evaluator validates at runtime.
            signature: Signature::new(TypeSignature::Any(num_args), Volatility::Immutable),
            return_type,
        }
    }
}

// Equality by signature / hash by name (shared UDF identity scheme).
impl_udf_eq_hash!(BticScalarUdf);
3293
3294impl ScalarUDFImpl for BticScalarUdf {
3295 fn as_any(&self) -> &dyn Any {
3296 self
3297 }
3298 fn name(&self) -> &str {
3299 &self.name
3300 }
3301 fn signature(&self) -> &Signature {
3302 &self.signature
3303 }
3304 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
3305 Ok(self.return_type.clone())
3306 }
3307 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
3308 let fname = self.name.to_uppercase();
3309 let rt = self.return_type.clone();
3310 invoke_cypher_udf(args, &rt, |val_args| {
3311 crate::query::expr_eval::eval_btic_function(&fname, val_args)
3312 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
3313 })
3314 }
3315}
3316
3317fn register_btic_scalar_udfs(ctx: &SessionContext) -> DFResult<()> {
3319 for name in &["btic_lo", "btic_hi"] {
3321 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3322 name,
3323 1,
3324 DataType::LargeBinary,
3325 )));
3326 }
3327 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3329 "btic_duration",
3330 1,
3331 DataType::Int64,
3332 )));
3333 for name in &["btic_is_instant", "btic_is_unbounded", "btic_is_finite"] {
3335 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3336 name,
3337 1,
3338 DataType::Boolean,
3339 )));
3340 }
3341 for name in &[
3343 "btic_granularity",
3344 "btic_lo_granularity",
3345 "btic_hi_granularity",
3346 "btic_certainty",
3347 "btic_lo_certainty",
3348 "btic_hi_certainty",
3349 ] {
3350 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3351 name,
3352 1,
3353 DataType::Utf8,
3354 )));
3355 }
3356 for name in &[
3358 "btic_contains_point",
3359 "btic_overlaps",
3360 "btic_contains",
3361 "btic_before",
3362 "btic_after",
3363 "btic_meets",
3364 "btic_adjacent",
3365 "btic_disjoint",
3366 "btic_equals",
3367 "btic_starts",
3368 "btic_during",
3369 "btic_finishes",
3370 ] {
3371 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3372 name,
3373 2,
3374 DataType::Boolean,
3375 )));
3376 }
3377 for name in &["btic_intersection", "btic_span", "btic_gap"] {
3379 ctx.register_udf(ScalarUDF::new_from_impl(BticScalarUdf::new(
3380 name,
3381 2,
3382 DataType::LargeBinary,
3383 )));
3384 }
3385 Ok(())
3386}
3387
/// Aggregate UDF computing the minimum or maximum BTIC interval of a column;
/// `is_max` selects which extreme this instance tracks.
#[derive(Debug, Clone)]
struct BticMinMaxUdaf {
    name: String,
    signature: Signature,
    // true => btic_max, false => btic_min.
    is_max: bool,
}

impl BticMinMaxUdaf {
    fn new(is_max: bool) -> Self {
        Self {
            name: (if is_max { "btic_max" } else { "btic_min" }).to_string(),
            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
            is_max,
        }
    }
}

// Equality by signature / hash by name (shared UDF identity scheme).
impl_udf_eq_hash!(BticMinMaxUdaf);
3411
3412impl AggregateUDFImpl for BticMinMaxUdaf {
3413 fn as_any(&self) -> &dyn Any {
3414 self
3415 }
3416 fn name(&self) -> &str {
3417 &self.name
3418 }
3419 fn signature(&self) -> &Signature {
3420 &self.signature
3421 }
3422 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3423 Ok(DataType::LargeBinary)
3424 }
3425 fn accumulator(
3426 &self,
3427 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3428 ) -> DFResult<Box<dyn DfAccumulator>> {
3429 Ok(Box::new(BticMinMaxAccumulator {
3430 current: None,
3431 is_max: self.is_max,
3432 }))
3433 }
3434 fn state_fields(
3435 &self,
3436 args: datafusion::logical_expr::function::StateFieldsArgs,
3437 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3438 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3439 args.name,
3440 DataType::LargeBinary,
3441 true,
3442 ))])
3443 }
3444}
3445
/// Running state for `btic_min` / `btic_max`.
#[derive(Debug)]
struct BticMinMaxAccumulator {
    // Best interval seen so far; None until the first decodable input.
    current: Option<uni_btic::Btic>,
    // true => keep the greater interval, false => keep the lesser.
    is_max: bool,
}
3451
3452impl DfAccumulator for BticMinMaxAccumulator {
3453 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3454 let arr = &values[0];
3455 for i in 0..arr.len() {
3456 if arr.is_null(i) {
3457 continue;
3458 }
3459 let Some(btic) = decode_btic_from_array(arr, i)? else {
3460 continue;
3461 };
3462 self.current = Some(match self.current.take() {
3463 None => btic,
3464 Some(cur) => {
3465 if (self.is_max && btic > cur) || (!self.is_max && btic < cur) {
3466 btic
3467 } else {
3468 cur
3469 }
3470 }
3471 });
3472 }
3473 Ok(())
3474 }
3475 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3476 Ok(btic_to_scalar_value(self.current.as_ref()))
3477 }
3478 fn size(&self) -> usize {
3479 std::mem::size_of::<Self>()
3480 }
3481 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3482 Ok(vec![self.evaluate()?])
3483 }
3484 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3485 self.update_batch(states)
3486 }
3487}
3488
/// Aggregate UDF computing the span (convex hull) of all BTIC intervals in
/// a column.
#[derive(Debug, Clone)]
struct BticSpanAggUdaf {
    signature: Signature,
}

impl BticSpanAggUdaf {
    fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
        }
    }
}

// Equality by signature / hash by name (shared UDF identity scheme).
impl_udf_eq_hash!(BticSpanAggUdaf);
3504
3505impl AggregateUDFImpl for BticSpanAggUdaf {
3506 fn as_any(&self) -> &dyn Any {
3507 self
3508 }
3509 fn name(&self) -> &str {
3510 "btic_span_agg"
3511 }
3512 fn signature(&self) -> &Signature {
3513 &self.signature
3514 }
3515 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3516 Ok(DataType::LargeBinary)
3517 }
3518 fn accumulator(
3519 &self,
3520 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3521 ) -> DFResult<Box<dyn DfAccumulator>> {
3522 Ok(Box::new(BticSpanAggAccumulator { current: None }))
3523 }
3524 fn state_fields(
3525 &self,
3526 args: datafusion::logical_expr::function::StateFieldsArgs,
3527 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3528 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3529 args.name,
3530 DataType::LargeBinary,
3531 true,
3532 ))])
3533 }
3534}
3535
/// Running state for `btic_span_agg`.
#[derive(Debug)]
struct BticSpanAggAccumulator {
    // Span of everything seen so far; None until the first decodable input.
    current: Option<uni_btic::Btic>,
}
3540
3541impl DfAccumulator for BticSpanAggAccumulator {
3542 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
3543 let arr = &values[0];
3544 for i in 0..arr.len() {
3545 if arr.is_null(i) {
3546 continue;
3547 }
3548 let Some(btic) = decode_btic_from_array(arr, i)? else {
3549 continue;
3550 };
3551 self.current = Some(match self.current.take() {
3552 None => btic,
3553 Some(cur) => uni_btic::set_ops::span(&cur, &btic),
3554 });
3555 }
3556 Ok(())
3557 }
3558 fn evaluate(&mut self) -> DFResult<ScalarValue> {
3559 Ok(btic_to_scalar_value(self.current.as_ref()))
3560 }
3561 fn size(&self) -> usize {
3562 std::mem::size_of::<Self>()
3563 }
3564 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
3565 Ok(vec![self.evaluate()?])
3566 }
3567 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
3568 self.update_batch(states)
3569 }
3570}
3571
/// Aggregate UDF counting how many BTIC intervals (arg 0) contain a probe
/// time point (arg 1).
#[derive(Debug, Clone)]
struct BticCountAtUdaf {
    signature: Signature,
}

impl BticCountAtUdaf {
    fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

// Equality by signature / hash by name (shared UDF identity scheme).
impl_udf_eq_hash!(BticCountAtUdaf);
3587
3588impl AggregateUDFImpl for BticCountAtUdaf {
3589 fn as_any(&self) -> &dyn Any {
3590 self
3591 }
3592 fn name(&self) -> &str {
3593 "btic_count_at"
3594 }
3595 fn signature(&self) -> &Signature {
3596 &self.signature
3597 }
3598 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
3599 Ok(DataType::Int64)
3600 }
3601 fn accumulator(
3602 &self,
3603 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
3604 ) -> DFResult<Box<dyn DfAccumulator>> {
3605 Ok(Box::new(BticCountAtAccumulator { count: 0 }))
3606 }
3607 fn state_fields(
3608 &self,
3609 args: datafusion::logical_expr::function::StateFieldsArgs,
3610 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
3611 Ok(vec![Arc::new(arrow::datatypes::Field::new(
3612 args.name,
3613 DataType::Int64,
3614 true,
3615 ))])
3616 }
3617}
3618
/// Running state for `btic_count_at`.
#[derive(Debug)]
struct BticCountAtAccumulator {
    // Number of (interval, point) rows where the interval contained the point.
    count: i64,
}
3623
impl DfAccumulator for BticCountAtAccumulator {
    /// Counts rows whose interval (column 0) contains the probe point
    /// (column 1). Rows with null, undecodable, or unsupported values are
    /// skipped rather than treated as errors.
    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
        // Defensive: this aggregate needs both the interval and point columns.
        if values.len() < 2 {
            return Ok(());
        }
        let btic_arr = &values[0];
        let point_arr = &values[1];

        for i in 0..btic_arr.len() {
            if btic_arr.is_null(i) || point_arr.is_null(i) {
                continue;
            }
            let Some(btic) = decode_btic_from_array(btic_arr, i)? else {
                continue;
            };

            // Probe point in epoch milliseconds: either a raw Int64 column,
            // or a cypher-codec LargeBinary holding an Int or a DateTime
            // (nanos are truncated down to millis).
            let point_ms = if let Some(int_arr) = point_arr.as_any().downcast_ref::<Int64Array>() {
                int_arr.value(i)
            } else if let Some(lb) = point_arr.as_any().downcast_ref::<LargeBinaryArray>() {
                let val = scalar_binary_to_value(lb.value(i));
                match &val {
                    Value::Int(ms) => *ms,
                    Value::Temporal(uni_common::TemporalValue::DateTime {
                        nanos_since_epoch,
                        ..
                    }) => nanos_since_epoch / 1_000_000,
                    _ => continue,
                }
            } else {
                continue;
            };

            if uni_btic::predicates::contains_point(&btic, point_ms) {
                self.count += 1;
            }
        }
        Ok(())
    }
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.count)))
    }
    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
    }
    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.count))])
    }
    /// Merging partial states just sums the partial counts.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        let arr = &states[0];
        if let Some(int_arr) = arr.as_any().downcast_ref::<Int64Array>() {
            for i in 0..int_arr.len() {
                if !int_arr.is_null(i) {
                    self.count += int_arr.value(i);
                }
            }
        }
        Ok(())
    }
}
3684
3685fn btic_to_scalar_value(btic: Option<&uni_btic::Btic>) -> ScalarValue {
3687 match btic {
3688 None => ScalarValue::LargeBinary(None),
3689 Some(b) => {
3690 let val = Value::Temporal(uni_common::TemporalValue::Btic {
3691 lo: b.lo(),
3692 hi: b.hi(),
3693 meta: b.meta(),
3694 });
3695 ScalarValue::LargeBinary(Some(uni_common::cypher_value_codec::encode(&val)))
3696 }
3697 }
3698}
3699
3700fn decode_btic_from_array(arr: &ArrayRef, row: usize) -> DFResult<Option<uni_btic::Btic>> {
3702 if let Some(fsb) = arr.as_any().downcast_ref::<FixedSizeBinaryArray>() {
3704 let bytes = fsb.value(row);
3705 return uni_btic::encode::decode_slice(bytes)
3706 .map(Some)
3707 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3708 }
3709 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
3711 let val = scalar_binary_to_value(lb.value(row));
3712 if let Value::Temporal(uni_common::TemporalValue::Btic { lo, hi, meta }) = val {
3713 return uni_btic::Btic::new(lo, hi, meta)
3714 .map(Some)
3715 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()));
3716 }
3717 return Ok(None);
3718 }
3719 Ok(None)
3720}
3721
/// Builds the `btic_min` aggregate (smallest interval).
pub(crate) fn create_btic_min_udaf() -> AggregateUDF {
    AggregateUDF::from(BticMinMaxUdaf::new(false))
}

/// Builds the `btic_max` aggregate (largest interval).
pub(crate) fn create_btic_max_udaf() -> AggregateUDF {
    AggregateUDF::from(BticMinMaxUdaf::new(true))
}

/// Builds the `btic_span_agg` aggregate (convex hull of all intervals).
pub(crate) fn create_btic_span_agg_udaf() -> AggregateUDF {
    AggregateUDF::from(BticSpanAggUdaf::new())
}

/// Builds the `btic_count_at` aggregate (intervals containing a point).
pub(crate) fn create_btic_count_at_udaf() -> AggregateUDF {
    AggregateUDF::from(BticCountAtUdaf::new())
}
3737
3738pub fn invoke_cypher_string_op<F>(
3743 args: &ScalarFunctionArgs,
3744 name: &str,
3745 op: F,
3746) -> DFResult<ColumnarValue>
3747where
3748 F: Fn(&str, &str) -> bool,
3749{
3750 use arrow_array::{BooleanArray, LargeBinaryArray, LargeStringArray, StringArray};
3751 use datafusion::common::ScalarValue;
3752 use datafusion::error::DataFusionError;
3753
3754 if args.args.len() != 2 {
3755 return Err(DataFusionError::Execution(format!(
3756 "{}(): requires exactly 2 arguments",
3757 name
3758 )));
3759 }
3760
3761 let left = &args.args[0];
3762 let right = &args.args[1];
3763
3764 let extract_string = |scalar: &ScalarValue| -> Option<String> {
3766 match scalar {
3767 ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => Some(s.clone()),
3768 ScalarValue::LargeBinary(Some(bytes)) => {
3769 match uni_common::cypher_value_codec::decode(bytes) {
3771 Ok(uni_common::Value::String(s)) => Some(s),
3772 _ => None,
3773 }
3774 }
3775 ScalarValue::Utf8(None)
3776 | ScalarValue::LargeUtf8(None)
3777 | ScalarValue::LargeBinary(None)
3778 | ScalarValue::Null => None,
3779 _ => None,
3780 }
3781 };
3782
3783 match (left, right) {
3784 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Scalar(r_scalar)) => {
3785 let l_str = extract_string(l_scalar);
3786 let r_str = extract_string(r_scalar);
3787
3788 match (l_str, r_str) {
3789 (Some(l), Some(r)) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(op(
3790 &l, &r,
3791 ))))),
3792 _ => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
3793 }
3794 }
3795 (ColumnarValue::Array(l_arr), ColumnarValue::Scalar(r_scalar)) => {
3796 let r_val = extract_string(r_scalar);
3798
3799 if r_val.is_none() {
3800 let nulls = arrow_array::new_null_array(&DataType::Boolean, l_arr.len());
3802 return Ok(ColumnarValue::Array(nulls));
3803 }
3804 let pattern = r_val.unwrap();
3805
3806 let result_array = if let Some(arr) = l_arr.as_any().downcast_ref::<StringArray>() {
3808 arr.iter()
3809 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3810 .collect::<BooleanArray>()
3811 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeStringArray>() {
3812 arr.iter()
3813 .map(|opt_s| opt_s.map(|s| op(s, &pattern)))
3814 .collect::<BooleanArray>()
3815 } else if let Some(arr) = l_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3816 arr.iter()
3818 .map(|opt_bytes| {
3819 opt_bytes.and_then(|bytes| {
3820 match uni_common::cypher_value_codec::decode(bytes) {
3821 Ok(uni_common::Value::String(s)) => Some(op(&s, &pattern)),
3822 _ => None,
3823 }
3824 })
3825 })
3826 .collect::<BooleanArray>()
3827 } else {
3828 arrow_array::new_null_array(&DataType::Boolean, l_arr.len())
3830 .as_any()
3831 .downcast_ref::<BooleanArray>()
3832 .unwrap()
3833 .clone()
3834 };
3835
3836 Ok(ColumnarValue::Array(Arc::new(result_array)))
3837 }
3838 (ColumnarValue::Scalar(l_scalar), ColumnarValue::Array(r_arr)) => {
3839 let l_val = extract_string(l_scalar);
3841
3842 if l_val.is_none() {
3843 let nulls = arrow_array::new_null_array(&DataType::Boolean, r_arr.len());
3844 return Ok(ColumnarValue::Array(nulls));
3845 }
3846 let target = l_val.unwrap();
3847
3848 let result_array = if let Some(arr) = r_arr.as_any().downcast_ref::<StringArray>() {
3849 arr.iter()
3850 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3851 .collect::<BooleanArray>()
3852 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeStringArray>() {
3853 arr.iter()
3854 .map(|opt_s| opt_s.map(|s| op(&target, s)))
3855 .collect::<BooleanArray>()
3856 } else if let Some(arr) = r_arr.as_any().downcast_ref::<LargeBinaryArray>() {
3857 arr.iter()
3859 .map(|opt_bytes| {
3860 opt_bytes.and_then(|bytes| {
3861 match uni_common::cypher_value_codec::decode(bytes) {
3862 Ok(uni_common::Value::String(s)) => Some(op(&target, &s)),
3863 _ => None,
3864 }
3865 })
3866 })
3867 .collect::<BooleanArray>()
3868 } else {
3869 arrow_array::new_null_array(&DataType::Boolean, r_arr.len())
3871 .as_any()
3872 .downcast_ref::<BooleanArray>()
3873 .unwrap()
3874 .clone()
3875 };
3876
3877 Ok(ColumnarValue::Array(Arc::new(result_array)))
3878 }
3879 (ColumnarValue::Array(l_arr), ColumnarValue::Array(r_arr)) => {
3880 if l_arr.len() != r_arr.len() {
3882 return Err(DataFusionError::Execution(format!(
3883 "{}(): array lengths must match",
3884 name
3885 )));
3886 }
3887
3888 let extract_string_at = |arr: &dyn Array, idx: usize| -> Option<String> {
3890 if let Some(str_arr) = arr.as_any().downcast_ref::<StringArray>() {
3891 str_arr.value(idx).to_string().into()
3892 } else if let Some(str_arr) = arr.as_any().downcast_ref::<LargeStringArray>() {
3893 str_arr.value(idx).to_string().into()
3894 } else if let Some(bin_arr) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
3895 if bin_arr.is_null(idx) {
3896 return None;
3897 }
3898 let bytes = bin_arr.value(idx);
3899 match uni_common::cypher_value_codec::decode(bytes) {
3900 Ok(uni_common::Value::String(s)) => Some(s),
3901 _ => None,
3902 }
3903 } else {
3904 None
3905 }
3906 };
3907
3908 let result: BooleanArray = (0..l_arr.len())
3909 .map(|idx| {
3910 match (
3911 extract_string_at(l_arr.as_ref(), idx),
3912 extract_string_at(r_arr.as_ref(), idx),
3913 ) {
3914 (Some(l_str), Some(r_str)) => Some(op(&l_str, &r_str)),
3915 _ => None,
3916 }
3917 })
3918 .collect();
3919
3920 Ok(ColumnarValue::Array(Arc::new(result)))
3921 }
3922 }
3923}
3924
/// Generates a boolean string-predicate UDF type.
///
/// Expands to: a unit-ish struct holding a `Signature::any(2)`, the shared
/// `impl_udf_eq_hash!` identity impls, and a `ScalarUDFImpl` whose
/// `invoke_with_args` delegates to `invoke_cypher_string_op` with the given
/// closure `$op: |lhs: &str, rhs: &str| -> bool`.
macro_rules! define_string_op_udf {
    ($struct_name:ident, $udf_name:literal, $op:expr) => {
        #[derive(Debug)]
        struct $struct_name {
            // Accepts any two argument types; string extraction happens at
            // invoke time inside invoke_cypher_string_op.
            signature: Signature,
        }

        impl $struct_name {
            fn new() -> Self {
                Self {
                    signature: Signature::any(2, Volatility::Immutable),
                }
            }
        }

        impl_udf_eq_hash!($struct_name);

        impl ScalarUDFImpl for $struct_name {
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn name(&self) -> &str {
                $udf_name
            }
            fn signature(&self) -> &Signature {
                &self.signature
            }
            fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
                Ok(DataType::Boolean)
            }

            fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
                invoke_cypher_string_op(&args, $udf_name, $op)
            }
        }
    };
}
3963
// The three Cypher string predicates: STARTS WITH, ENDS WITH, CONTAINS.
// Each expands to a struct + ScalarUDFImpl via define_string_op_udf!.
define_string_op_udf!(CypherStartsWithUdf, "_cypher_starts_with", |s, p| s
    .starts_with(p));
define_string_op_udf!(CypherEndsWithUdf, "_cypher_ends_with", |s, p| s
    .ends_with(p));
define_string_op_udf!(CypherContainsUdf, "_cypher_contains", |s, p| s.contains(p));
3969
3970pub fn create_cypher_starts_with_udf() -> ScalarUDF {
3971 ScalarUDF::new_from_impl(CypherStartsWithUdf::new())
3972}
3973pub fn create_cypher_ends_with_udf() -> ScalarUDF {
3974 ScalarUDF::new_from_impl(CypherEndsWithUdf::new())
3975}
3976pub fn create_cypher_contains_udf() -> ScalarUDF {
3977 ScalarUDF::new_from_impl(CypherContainsUdf::new())
3978}
3979
3980pub fn create_cypher_equal_udf() -> ScalarUDF {
3981 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_equal", BinaryOp::Eq))
3982}
3983pub fn create_cypher_not_equal_udf() -> ScalarUDF {
3984 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_not_equal", BinaryOp::NotEq))
3985}
3986pub fn create_cypher_lt_udf() -> ScalarUDF {
3987 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt", BinaryOp::Lt))
3988}
3989pub fn create_cypher_lt_eq_udf() -> ScalarUDF {
3990 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_lt_eq", BinaryOp::LtEq))
3991}
3992pub fn create_cypher_gt_udf() -> ScalarUDF {
3993 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt", BinaryOp::Gt))
3994}
3995pub fn create_cypher_gt_eq_udf() -> ScalarUDF {
3996 ScalarUDF::new_from_impl(CypherCompareUdf::new("_cypher_gt_eq", BinaryOp::GtEq))
3997}
3998
3999#[expect(clippy::match_like_matches_macro)]
4001fn apply_comparison_op(ord: std::cmp::Ordering, op: &BinaryOp) -> bool {
4002 use std::cmp::Ordering;
4003 match (ord, op) {
4004 (Ordering::Less, BinaryOp::Lt | BinaryOp::LtEq | BinaryOp::NotEq) => true,
4005 (Ordering::Equal, BinaryOp::Eq | BinaryOp::LtEq | BinaryOp::GtEq) => true,
4006 (Ordering::Greater, BinaryOp::Gt | BinaryOp::GtEq | BinaryOp::NotEq) => true,
4007 _ => false,
4008 }
4009}
4010
4011fn compare_f64(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<bool> {
4014 if lhs.is_nan() || rhs.is_nan() {
4015 Some(matches!(op, BinaryOp::NotEq))
4016 } else {
4017 Some(apply_comparison_op(lhs.partial_cmp(&rhs)?, op))
4018 }
4019}
4020
4021fn cv_bytes_as_f64(bytes: &[u8]) -> Option<f64> {
4023 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4024 match peek_tag(bytes)? {
4025 TAG_INT => decode_int(bytes).map(|i| i as f64),
4026 TAG_FLOAT => decode_float(bytes),
4027 _ => None,
4028 }
4029}
4030
4031fn compare_cv_numeric(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<bool> {
4034 use uni_common::cypher_value_codec::{TAG_INT, TAG_NULL, decode_int, peek_tag};
4035 if peek_tag(bytes) == Some(TAG_INT)
4037 && let Some(lhs_int) = decode_int(bytes)
4038 && rhs.fract() == 0.0
4040 && rhs >= i64::MIN as f64
4041 && rhs <= i64::MAX as f64
4042 {
4043 return Some(apply_comparison_op(lhs_int.cmp(&(rhs as i64)), op));
4044 }
4045 if peek_tag(bytes) == Some(TAG_NULL) {
4046 return None;
4047 }
4048 let lhs = cv_bytes_as_f64(bytes)?;
4049 compare_f64(lhs, rhs, op)
4050}
4051
/// Vectorized fast path for the `_cypher_*` comparison UDFs.
///
/// Only handles array-vs-array input where the LEFT column is encoded
/// CypherValue blobs (`LargeBinary`); the right column may be `Int64`,
/// `Float64`, `Utf8`/`LargeUtf8`, or another `LargeBinary`. Returns `None`
/// for any other layout (and for one row-level tag combination — see the
/// blob-vs-blob branch), signalling the caller to fall back to the generic
/// row-at-a-time evaluator.
fn try_fast_compare(
    lhs: &ColumnarValue,
    rhs: &ColumnarValue,
    op: &BinaryOp,
) -> Option<ColumnarValue> {
    use arrow_array::builder::BooleanBuilder;
    use uni_common::cypher_value_codec::{
        TAG_INT, TAG_NULL, TAG_STRING, decode_int, decode_string, peek_tag,
    };

    // Scalars take the generic path; fast path is arrays only.
    let (lhs_arr, rhs_arr) = match (lhs, rhs) {
        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
        _ => return None,
    };

    // Left side must be a column of encoded CypherValue blobs.
    if !matches!(lhs_arr.data_type(), DataType::LargeBinary) {
        return None;
    }

    let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;

    match rhs_arr.data_type() {
        // Blob vs plain Int64 column: numeric comparison per row.
        DataType::Int64 => {
            let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
            let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
            for i in 0..lb_arr.len() {
                if lb_arr.is_null(i) || int_arr.is_null(i) {
                    builder.append_null();
                } else {
                    // Non-numeric or NULL-tagged blobs compare to NULL.
                    match compare_cv_numeric(lb_arr.value(i), int_arr.value(i) as f64, op) {
                        Some(result) => builder.append_value(result),
                        None => builder.append_null(),
                    }
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        // Blob vs plain Float64 column.
        DataType::Float64 => {
            let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
            let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
            for i in 0..lb_arr.len() {
                if lb_arr.is_null(i) || float_arr.is_null(i) {
                    builder.append_null();
                } else {
                    match compare_cv_numeric(lb_arr.value(i), float_arr.value(i), op) {
                        Some(result) => builder.append_value(result),
                        None => builder.append_null(),
                    }
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        // Blob vs string column: only TAG_STRING blobs compare; any other
        // tag is a cross-type comparison and yields NULL for that row.
        DataType::Utf8 | DataType::LargeUtf8 => {
            let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
            for i in 0..lb_arr.len() {
                if lb_arr.is_null(i) || rhs_arr.is_null(i) {
                    builder.append_null();
                } else {
                    let bytes = lb_arr.value(i);
                    let rhs_str = if matches!(rhs_arr.data_type(), DataType::Utf8) {
                        rhs_arr.as_any().downcast_ref::<StringArray>()?.value(i)
                    } else {
                        rhs_arr
                            .as_any()
                            .downcast_ref::<LargeStringArray>()?
                            .value(i)
                    };
                    match peek_tag(bytes) {
                        Some(TAG_STRING) => {
                            if let Some(lhs_str) = decode_string(bytes) {
                                builder.append_value(apply_comparison_op(
                                    lhs_str.as_str().cmp(rhs_str),
                                    op,
                                ));
                            } else {
                                builder.append_null();
                            }
                        }
                        _ => builder.append_null(),
                    }
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        // Blob vs blob: dispatch per row on the decoded tag pair.
        DataType::LargeBinary => {
            let rhs_lb = rhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
            let mut builder = BooleanBuilder::with_capacity(lb_arr.len());
            for i in 0..lb_arr.len() {
                if lb_arr.is_null(i) || rhs_lb.is_null(i) {
                    builder.append_null();
                } else {
                    let lhs_bytes = lb_arr.value(i);
                    let rhs_bytes = rhs_lb.value(i);
                    let lhs_tag = peek_tag(lhs_bytes);
                    let rhs_tag = peek_tag(rhs_bytes);

                    // Encoded NULL on either side propagates NULL.
                    if lhs_tag == Some(TAG_NULL) || rhs_tag == Some(TAG_NULL) {
                        builder.append_null();
                        continue;
                    }

                    // int vs int: exact 64-bit comparison.
                    if lhs_tag == Some(TAG_INT) && rhs_tag == Some(TAG_INT) {
                        if let (Some(l), Some(r)) = (decode_int(lhs_bytes), decode_int(rhs_bytes)) {
                            builder.append_value(apply_comparison_op(l.cmp(&r), op));
                        } else {
                            builder.append_null();
                        }
                        continue;
                    }

                    // string vs string: lexicographic comparison.
                    if lhs_tag == Some(TAG_STRING) && rhs_tag == Some(TAG_STRING) {
                        if let (Some(l), Some(r)) =
                            (decode_string(lhs_bytes), decode_string(rhs_bytes))
                        {
                            builder.append_value(apply_comparison_op(l.cmp(&r), op));
                        } else {
                            builder.append_null();
                        }
                        continue;
                    }

                    // Mixed numeric pairs (int/float): compare as f64.
                    if let (Some(l), Some(r)) =
                        (cv_bytes_as_f64(lhs_bytes), cv_bytes_as_f64(rhs_bytes))
                    {
                        match compare_f64(l, r, op) {
                            Some(result) => builder.append_value(result),
                            None => builder.append_null(),
                        }
                    } else {
                        // Unsupported tag combination: abandon the fast path
                        // for the whole batch (the partially built output is
                        // discarded) and let the generic evaluator run.
                        return None;
                    }
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        _ => None,
    }
}
4209
/// Binary comparison UDF (`=`, `<>`, `<`, `<=`, `>`, `>=`); one instance per
/// operator, created via `CypherCompareUdf::new`.
#[derive(Debug)]
struct CypherCompareUdf {
    // Registered UDF name, e.g. "_cypher_equal".
    name: String,
    op: BinaryOp,
    // any(2): operand typing is resolved at invoke time.
    signature: Signature,
}

impl CypherCompareUdf {
    fn new(name: &str, op: BinaryOp) -> Self {
        Self {
            name: name.to_string(),
            op,
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

// Identity is name-based, deliberately NOT the signature-based impl from
// impl_udf_eq_hash!: all comparison UDFs share the same Signature::any(2),
// so signature equality would conflate e.g. _cypher_lt with _cypher_gt.
impl PartialEq for CypherCompareUdf {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl Eq for CypherCompareUdf {}

impl std::hash::Hash for CypherCompareUdf {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state);
    }
}
4240
4241impl ScalarUDFImpl for CypherCompareUdf {
4242 fn as_any(&self) -> &dyn Any {
4243 self
4244 }
4245 fn name(&self) -> &str {
4246 &self.name
4247 }
4248 fn signature(&self) -> &Signature {
4249 &self.signature
4250 }
4251 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4252 Ok(DataType::Boolean)
4253 }
4254
4255 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4256 if args.args.len() != 2 {
4257 return Err(datafusion::error::DataFusionError::Execution(format!(
4258 "{}(): requires 2 arguments",
4259 self.name
4260 )));
4261 }
4262
4263 if let Some(result) = try_fast_compare(&args.args[0], &args.args[1], &self.op) {
4265 return Ok(result);
4266 }
4267
4268 let output_type = DataType::Boolean;
4270 invoke_cypher_udf(args, &output_type, |val_args| {
4271 crate::query::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4272 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4273 })
4274 }
4275}
4276
4277pub fn create_cypher_add_udf() -> ScalarUDF {
4283 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_add", BinaryOp::Add))
4284}
4285pub fn create_cypher_sub_udf() -> ScalarUDF {
4286 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_sub", BinaryOp::Sub))
4287}
4288pub fn create_cypher_mul_udf() -> ScalarUDF {
4289 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mul", BinaryOp::Mul))
4290}
4291pub fn create_cypher_div_udf() -> ScalarUDF {
4292 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_div", BinaryOp::Div))
4293}
4294pub fn create_cypher_mod_udf() -> ScalarUDF {
4295 ScalarUDF::new_from_impl(CypherArithmeticUdf::new("_cypher_mod", BinaryOp::Mod))
4296}
4297
4298pub fn create_cypher_abs_udf() -> ScalarUDF {
4300 ScalarUDF::new_from_impl(CypherAbsUdf::new())
4301}
4302
4303pub(crate) fn cypher_abs_expr(
4305 arg: datafusion::logical_expr::Expr,
4306) -> datafusion::logical_expr::Expr {
4307 datafusion::logical_expr::Expr::ScalarFunction(
4308 datafusion::logical_expr::expr::ScalarFunction::new_udf(
4309 Arc::new(create_cypher_abs_udf()),
4310 vec![arg],
4311 ),
4312 )
4313}
4314
/// `abs()` UDF over encoded CypherValue arguments.
#[derive(Debug)]
struct CypherAbsUdf {
    // any(1): the argument is type-checked at invoke time.
    signature: Signature,
}

impl CypherAbsUdf {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherAbsUdf);
4329
4330impl ScalarUDFImpl for CypherAbsUdf {
4331 fn as_any(&self) -> &dyn Any {
4332 self
4333 }
4334 fn name(&self) -> &str {
4335 "_cypher_abs"
4336 }
4337 fn signature(&self) -> &Signature {
4338 &self.signature
4339 }
4340 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
4341 Ok(DataType::LargeBinary)
4342 }
4343 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4344 if args.args.len() != 1 {
4345 return Err(datafusion::error::DataFusionError::Execution(
4346 "_cypher_abs requires exactly 1 argument".into(),
4347 ));
4348 }
4349 invoke_cypher_udf(args, &DataType::LargeBinary, |val_args| {
4350 match &val_args[0] {
4351 Value::Int(i) => i.checked_abs().map(Value::Int).ok_or_else(|| {
4352 datafusion::error::DataFusionError::Execution(
4353 "integer overflow in abs()".into(),
4354 )
4355 }),
4356 Value::Float(f) => Ok(Value::Float(f.abs())),
4357 Value::Null => Ok(Value::Null),
4358 other => Err(datafusion::error::DataFusionError::Execution(format!(
4359 "abs() requires a numeric argument, got {other:?}"
4360 ))),
4361 }
4362 })
4363 }
4364}
4365
4366fn apply_int_arithmetic(lhs: i64, rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
4369 use uni_common::cypher_value_codec::encode_int;
4370 match op {
4371 BinaryOp::Add => lhs.checked_add(rhs).map(encode_int),
4372 BinaryOp::Sub => lhs.checked_sub(rhs).map(encode_int),
4373 BinaryOp::Mul => lhs.checked_mul(rhs).map(encode_int),
4374 BinaryOp::Div => {
4375 if rhs == 0 {
4377 None
4378 } else {
4379 lhs.checked_div(rhs).map(encode_int)
4380 }
4381 }
4382 BinaryOp::Mod => {
4383 if rhs == 0 {
4384 None
4385 } else {
4386 lhs.checked_rem(rhs).map(encode_int)
4387 }
4388 }
4389 _ => None,
4390 }
4391}
4392
4393fn apply_float_arithmetic(lhs: f64, rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
4395 use uni_common::cypher_value_codec::encode_float;
4396 let result = match op {
4397 BinaryOp::Add => lhs + rhs,
4398 BinaryOp::Sub => lhs - rhs,
4399 BinaryOp::Mul => lhs * rhs,
4400 BinaryOp::Div => lhs / rhs, BinaryOp::Mod => lhs % rhs,
4402 _ => return None,
4403 };
4404 Some(encode_float(result))
4405}
4406
4407fn cv_arithmetic_int(bytes: &[u8], rhs: i64, op: &BinaryOp) -> Option<Vec<u8>> {
4410 use uni_common::cypher_value_codec::{TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag};
4411 match peek_tag(bytes)? {
4412 TAG_INT => apply_int_arithmetic(decode_int(bytes)?, rhs, op),
4413 TAG_FLOAT => apply_float_arithmetic(decode_float(bytes)?, rhs as f64, op),
4414 _ => None,
4415 }
4416}
4417
4418fn cv_arithmetic_float(bytes: &[u8], rhs: f64, op: &BinaryOp) -> Option<Vec<u8>> {
4421 let lhs = cv_bytes_as_f64(bytes)?;
4422 apply_float_arithmetic(lhs, rhs, op)
4423}
4424
/// Vectorized fast path for the `_cypher_*` arithmetic UDFs.
///
/// Supported layouts (array-vs-array only):
///   * LargeBinary (encoded CypherValue) op Int64
///   * LargeBinary (encoded CypherValue) op Float64
///   * Int64 op Int64
/// Returns `None` for anything else so the caller can fall back to the
/// row-at-a-time evaluator. Per-row failures (overflow, divide-by-zero,
/// non-numeric blobs) become NULL output rows.
fn try_fast_arithmetic(
    lhs: &ColumnarValue,
    rhs: &ColumnarValue,
    op: &BinaryOp,
) -> Option<ColumnarValue> {
    use arrow_array::builder::LargeBinaryBuilder;

    // Scalars take the generic path; fast path is arrays only.
    let (lhs_arr, rhs_arr) = match (lhs, rhs) {
        (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l, r),
        _ => return None,
    };

    match (lhs_arr.data_type(), rhs_arr.data_type()) {
        // Encoded CypherValue op plain Int64 column.
        (DataType::LargeBinary, DataType::Int64) => {
            let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
            let int_arr = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
            let mut builder = LargeBinaryBuilder::new();
            for i in 0..lb_arr.len() {
                if lb_arr.is_null(i) || int_arr.is_null(i) {
                    builder.append_null();
                } else if let Some(bytes) = cv_arithmetic_int(lb_arr.value(i), int_arr.value(i), op)
                {
                    builder.append_value(&bytes);
                } else {
                    // Overflow / div-by-zero / non-numeric blob -> NULL row.
                    builder.append_null();
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        // Encoded CypherValue op plain Float64 column.
        (DataType::LargeBinary, DataType::Float64) => {
            let lb_arr = lhs_arr.as_any().downcast_ref::<LargeBinaryArray>()?;
            let float_arr = rhs_arr.as_any().downcast_ref::<Float64Array>()?;
            let mut builder = LargeBinaryBuilder::new();
            for i in 0..lb_arr.len() {
                if lb_arr.is_null(i) || float_arr.is_null(i) {
                    builder.append_null();
                } else if let Some(bytes) =
                    cv_arithmetic_float(lb_arr.value(i), float_arr.value(i), op)
                {
                    builder.append_value(&bytes);
                } else {
                    builder.append_null();
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        // Plain Int64 op Int64: exact integer arithmetic, re-encoded as blobs.
        (DataType::Int64, DataType::Int64) => {
            let lhs_int = lhs_arr.as_any().downcast_ref::<Int64Array>()?;
            let rhs_int = rhs_arr.as_any().downcast_ref::<Int64Array>()?;
            let mut builder = LargeBinaryBuilder::new();
            for i in 0..lhs_int.len() {
                if lhs_int.is_null(i) || rhs_int.is_null(i) {
                    builder.append_null();
                } else if let Some(bytes) =
                    apply_int_arithmetic(lhs_int.value(i), rhs_int.value(i), op)
                {
                    builder.append_value(&bytes);
                } else {
                    builder.append_null();
                }
            }
            Some(ColumnarValue::Array(Arc::new(builder.finish())))
        }

        _ => None,
    }
}
4500
/// Binary arithmetic UDF (`+`, `-`, `*`, `/`, `%`); one instance per
/// operator, created via `CypherArithmeticUdf::new`.
#[derive(Debug)]
struct CypherArithmeticUdf {
    // Registered UDF name, e.g. "_cypher_add".
    name: String,
    op: BinaryOp,
    // any(2): operand typing is resolved at invoke time.
    signature: Signature,
}

impl CypherArithmeticUdf {
    fn new(name: &str, op: BinaryOp) -> Self {
        Self {
            name: name.to_string(),
            op,
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

// Identity is name-based, deliberately NOT the signature-based impl from
// impl_udf_eq_hash!: all arithmetic UDFs share the same Signature::any(2),
// so signature equality would conflate e.g. _cypher_add with _cypher_sub.
impl PartialEq for CypherArithmeticUdf {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl Eq for CypherArithmeticUdf {}

impl std::hash::Hash for CypherArithmeticUdf {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state);
    }
}
4531
4532impl ScalarUDFImpl for CypherArithmeticUdf {
4533 fn as_any(&self) -> &dyn Any {
4534 self
4535 }
4536 fn name(&self) -> &str {
4537 &self.name
4538 }
4539 fn signature(&self) -> &Signature {
4540 &self.signature
4541 }
4542 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4543 Ok(DataType::LargeBinary) }
4545
4546 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4547 if args.args.len() != 2 {
4548 return Err(datafusion::error::DataFusionError::Execution(format!(
4549 "{}(): requires 2 arguments",
4550 self.name
4551 )));
4552 }
4553
4554 if let Some(result) = try_fast_arithmetic(&args.args[0], &args.args[1], &self.op) {
4556 return Ok(result);
4557 }
4558
4559 let output_type = DataType::LargeBinary;
4561 invoke_cypher_udf(args, &output_type, |val_args| {
4562 crate::query::expr_eval::eval_binary_op(&val_args[0], &self.op, &val_args[1])
4563 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4564 })
4565 }
4566}
4567
/// Builds the `_cypher_xor` UDF (Cypher's logical XOR operator).
pub fn create_cypher_xor_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CypherXorUdf::new())
}

/// Logical XOR over two arguments coerced to booleans at invoke time.
#[derive(Debug)]
struct CypherXorUdf {
    // any(2): boolean coercion happens inside invoke_with_args.
    signature: Signature,
}

impl CypherXorUdf {
    fn new() -> Self {
        Self {
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherXorUdf);
4590
4591impl ScalarUDFImpl for CypherXorUdf {
4592 fn as_any(&self) -> &dyn Any {
4593 self
4594 }
4595 fn name(&self) -> &str {
4596 "_cypher_xor"
4597 }
4598 fn signature(&self) -> &Signature {
4599 &self.signature
4600 }
4601 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4602 Ok(DataType::Boolean)
4603 }
4604
4605 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4606 let output_type = DataType::Boolean;
4607 invoke_cypher_udf(args, &output_type, |val_args| {
4608 if val_args.len() != 2 {
4609 return Err(datafusion::error::DataFusionError::Execution(
4610 "_cypher_xor(): requires 2 arguments".to_string(),
4611 ));
4612 }
4613 let coerce_bool = |v: &Value| -> Value {
4615 match v {
4616 Value::String(s) if s == "true" => Value::Bool(true),
4617 Value::String(s) if s == "false" => Value::Bool(false),
4618 other => other.clone(),
4619 }
4620 };
4621 let left = coerce_bool(&val_args[0]);
4622 let right = coerce_bool(&val_args[1]);
4623 crate::query::expr_eval::eval_binary_op(&left, &BinaryOp::Xor, &right)
4624 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
4625 })
4626 }
4627}
4628
/// Builds the `_cv_to_bool` UDF, which coerces encoded CypherValue blobs to
/// plain booleans.
pub fn create_cv_to_bool_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CvToBoolUdf::new())
}

#[derive(Debug)]
struct CvToBoolUdf {
    // exact(LargeBinary): only accepts encoded CypherValue input.
    signature: Signature,
}

impl CvToBoolUdf {
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::LargeBinary], Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CvToBoolUdf);
4653
4654impl ScalarUDFImpl for CvToBoolUdf {
4655 fn as_any(&self) -> &dyn Any {
4656 self
4657 }
4658 fn name(&self) -> &str {
4659 "_cv_to_bool"
4660 }
4661 fn signature(&self) -> &Signature {
4662 &self.signature
4663 }
4664 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4665 Ok(DataType::Boolean)
4666 }
4667
4668 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4669 if args.args.len() != 1 {
4670 return Err(datafusion::error::DataFusionError::Execution(
4671 "_cv_to_bool() requires exactly 1 argument".to_string(),
4672 ));
4673 }
4674
4675 match &args.args[0] {
4676 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
4677 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4679 let b = match peek_tag(bytes) {
4680 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4681 Some(TAG_NULL) => false,
4682 _ => false, };
4684 Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))))
4685 }
4686 ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))),
4687 ColumnarValue::Array(arr) => {
4688 let lb_arr = arr
4689 .as_any()
4690 .downcast_ref::<arrow_array::LargeBinaryArray>()
4691 .ok_or_else(|| {
4692 datafusion::error::DataFusionError::Execution(format!(
4693 "_cv_to_bool(): expected LargeBinary array, got {:?}",
4694 arr.data_type()
4695 ))
4696 })?;
4697
4698 let mut builder = arrow_array::builder::BooleanBuilder::with_capacity(lb_arr.len());
4699
4700 use uni_common::cypher_value_codec::{TAG_BOOL, TAG_NULL, decode_bool, peek_tag};
4702
4703 for i in 0..lb_arr.len() {
4704 if lb_arr.is_null(i) {
4705 builder.append_null();
4706 } else {
4707 let bytes = lb_arr.value(i);
4708 let b = match peek_tag(bytes) {
4709 Some(TAG_BOOL) => decode_bool(bytes).unwrap_or(false),
4710 Some(TAG_NULL) => false,
4711 _ => false, };
4713 builder.append_value(b);
4714 }
4715 }
4716 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
4717 }
4718 }
4719 }
4720}
4721
/// Builds the `_cypher_size` UDF backing Cypher `size()`/`length()`.
pub fn create_cypher_size_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CypherSizeUdf::new())
}

#[derive(Debug)]
struct CypherSizeUdf {
    // any(1): per-type sizing is dispatched in cypher_size_scalar.
    signature: Signature,
}

impl CypherSizeUdf {
    fn new() -> Self {
        Self {
            signature: Signature::any(1, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherSizeUdf);
4745
4746impl ScalarUDFImpl for CypherSizeUdf {
4747 fn as_any(&self) -> &dyn Any {
4748 self
4749 }
4750
4751 fn name(&self) -> &str {
4752 "_cypher_size"
4753 }
4754
4755 fn signature(&self) -> &Signature {
4756 &self.signature
4757 }
4758
4759 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
4760 Ok(DataType::Int64)
4761 }
4762
4763 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
4764 if args.args.len() != 1 {
4765 return Err(datafusion::error::DataFusionError::Execution(
4766 "_cypher_size() requires exactly 1 argument".to_string(),
4767 ));
4768 }
4769
4770 match &args.args[0] {
4771 ColumnarValue::Scalar(scalar) => {
4772 let result = cypher_size_scalar(scalar)?;
4773 Ok(ColumnarValue::Scalar(result))
4774 }
4775 ColumnarValue::Array(arr) => {
4776 let mut results: Vec<Option<i64>> = Vec::with_capacity(arr.len());
4777 for i in 0..arr.len() {
4778 if arr.is_null(i) {
4779 results.push(None);
4780 } else {
4781 let scalar = ScalarValue::try_from_array(arr, i)?;
4782 match cypher_size_scalar(&scalar)? {
4783 ScalarValue::Int64(v) => results.push(v),
4784 _ => results.push(None),
4785 }
4786 }
4787 }
4788 let arr: ArrayRef = Arc::new(arrow_array::Int64Array::from(results));
4789 Ok(ColumnarValue::Array(arr))
4790 }
4791 }
4792 }
4793}
4794
/// Computes the Cypher `size()`/`length()` of one scalar value.
///
/// Returns `Int64(None)` (NULL) for NULL-ish inputs, the character count for
/// strings, the element count for lists, and the entry count for maps.
/// Node and Relationship inputs are rejected with a TypeError (the error
/// text mirrors a Neo4j-style `length()` restriction).
fn cypher_size_scalar(scalar: &ScalarValue) -> DFResult<ScalarValue> {
    match scalar {
        // Strings: counted in Unicode scalar values, not bytes.
        ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
            Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
        }
        // List scalars wrap a one-row Arrow list array; row 0 holds the value.
        ScalarValue::List(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        ScalarValue::LargeList(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        // Encoded CypherValue blob: decode, then size by the decoded kind.
        ScalarValue::LargeBinary(Some(b)) => {
            if let Ok(uni_val) = uni_common::cypher_value_codec::decode(b) {
                match &uni_val {
                    uni_common::Value::Node(_) => {
                        Err(datafusion::error::DataFusionError::Execution(
                            "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
                        ))
                    }
                    uni_common::Value::Edge(_) => {
                        Err(datafusion::error::DataFusionError::Execution(
                            "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
                        ))
                    }
                    _ => {
                        // Route through serde_json to size lists, strings and
                        // maps uniformly; scalars fall through to NULL.
                        let json_val: serde_json::Value = uni_val.into();
                        match json_val {
                            serde_json::Value::Array(arr) => Ok(ScalarValue::Int64(Some(arr.len() as i64))),
                            serde_json::Value::String(s) => {
                                Ok(ScalarValue::Int64(Some(s.chars().count() as i64)))
                            }
                            serde_json::Value::Object(m) => Ok(ScalarValue::Int64(Some(m.len() as i64))),
                            _ => Ok(ScalarValue::Int64(None)),
                        }
                    }
                }
            } else {
                // Undecodable blob: treated as NULL rather than an error.
                Ok(ScalarValue::Int64(None))
            }
        }
        ScalarValue::Map(arr) => {
            if arr.is_empty() || arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                Ok(ScalarValue::Int64(Some(arr.value(0).len() as i64)))
            }
        }
        ScalarValue::Struct(arr) => {
            // NOTE(review): assumes the struct scalar carries at least one
            // row (arr.is_null(0) would panic on an empty array) — confirm
            // how Struct scalars are constructed upstream.
            if arr.is_null(0) {
                Ok(ScalarValue::Int64(None))
            } else {
                let schema = arr.fields();
                let field_names: Vec<&str> = schema.iter().map(|f| f.name().as_str()).collect();
                // Heuristic: a `_vid` field without `relationships` marks a
                // Node struct -> same TypeError as decoded Node values.
                if field_names.contains(&"_vid") && !field_names.contains(&"relationships") {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "TypeError: InvalidArgumentValue - length() is not supported for Node values".to_string(),
                    ));
                }
                // Heuristic: `_eid`, or a `_src`/`_dst` pair, marks a
                // Relationship struct.
                if field_names.contains(&"_eid")
                    || (field_names.contains(&"_src") && field_names.contains(&"_dst"))
                {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "TypeError: InvalidArgumentValue - length() is not supported for Relationship values".to_string(),
                    ));
                }
                // Path-like structs expose a `relationships` list; its length
                // is the path length. Otherwise fall back to the field count.
                if let Some((rels_idx, _)) = schema
                    .iter()
                    .enumerate()
                    .find(|(_, f)| f.name() == "relationships")
                {
                    let rels_col = arr.column(rels_idx);
                    if let Some(list_arr) =
                        rels_col.as_any().downcast_ref::<arrow_array::ListArray>()
                    {
                        if list_arr.is_null(0) {
                            Ok(ScalarValue::Int64(Some(0)))
                        } else {
                            Ok(ScalarValue::Int64(Some(list_arr.value(0).len() as i64)))
                        }
                    } else {
                        Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
                    }
                } else {
                    Ok(ScalarValue::Int64(Some(arr.num_columns() as i64)))
                }
            }
        }
        // Explicit NULLs of the handled types.
        ScalarValue::Null
        | ScalarValue::Utf8(None)
        | ScalarValue::LargeUtf8(None)
        | ScalarValue::LargeBinary(None) => Ok(ScalarValue::Int64(None)),
        other => Err(datafusion::error::DataFusionError::Execution(format!(
            "_cypher_size(): unsupported type {other:?}"
        ))),
    }
}
4911
/// Builds the `_cypher_list_compare` UDF for ordering comparisons between
/// two list values.
pub fn create_cypher_list_compare_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CypherListCompareUdf::new())
}

#[derive(Debug)]
struct CypherListCompareUdf {
    // any(3): (left list, right list, operator name as a string).
    signature: Signature,
}

impl CypherListCompareUdf {
    fn new() -> Self {
        Self {
            signature: Signature::any(3, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherListCompareUdf);
4935
impl ScalarUDFImpl for CypherListCompareUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "_cypher_list_compare"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        Ok(DataType::Boolean)
    }

    /// Compares two lists with an operator named by the third argument
    /// ("lt" | "lteq" | "gt" | "gteq"). A NULL operand, or an incomparable
    /// pair (cypher_list_cmp returning None), yields NULL.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        let output_type = DataType::Boolean;
        invoke_cypher_udf(args, &output_type, |val_args| {
            if val_args.len() != 3 {
                return Err(datafusion::error::DataFusionError::Execution(
                    "_cypher_list_compare(): requires 3 arguments (left, right, op)".to_string(),
                ));
            }

            let left = &val_args[0];
            let right = &val_args[1];
            // Third argument selects the operator; anything non-string is an error.
            let op_str = match &val_args[2] {
                Value::String(s) => s.as_str(),
                _ => {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "_cypher_list_compare(): op must be a string".to_string(),
                    ));
                }
            };

            // NULL on either side short-circuits to NULL; non-list operands error.
            let (left_items, right_items) = match (left, right) {
                (Value::List(l), Value::List(r)) => (l, r),
                (Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
                _ => {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "_cypher_list_compare(): both arguments must be lists".to_string(),
                    ));
                }
            };

            let cmp = cypher_list_cmp(left_items, right_items);

            let result = match (op_str, cmp) {
                // Incomparable lists -> NULL, regardless of operator.
                (_, None) => Value::Null,
                ("lt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Less),
                ("lteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Greater),
                ("gt", Some(ord)) => Value::Bool(ord == std::cmp::Ordering::Greater),
                ("gteq", Some(ord)) => Value::Bool(ord != std::cmp::Ordering::Less),
                _ => {
                    return Err(datafusion::error::DataFusionError::Execution(format!(
                        "_cypher_list_compare(): unknown op '{}'",
                        op_str
                    )));
                }
            };

            Ok(result)
        })
    }
}
5004
/// Builds the `_map_project` UDF backing Cypher map projections
/// (`m{.a, .b, .*}`).
pub fn create_map_project_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(MapProjectUdf::new())
}

#[derive(Debug)]
struct MapProjectUdf {
    // Variadic: arguments arrive as alternating key/value pairs.
    signature: Signature,
}

impl MapProjectUdf {
    fn new() -> Self {
        Self {
            signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(MapProjectUdf);
5027
5028impl ScalarUDFImpl for MapProjectUdf {
5029 fn as_any(&self) -> &dyn Any {
5030 self
5031 }
5032
5033 fn name(&self) -> &str {
5034 "_map_project"
5035 }
5036
5037 fn signature(&self) -> &Signature {
5038 &self.signature
5039 }
5040
5041 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5042 Ok(DataType::LargeBinary)
5043 }
5044
5045 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5046 let output_type = self.return_type(&[])?;
5047 invoke_cypher_udf(args, &output_type, |val_args| {
5048 let mut result_map = std::collections::HashMap::new();
5049 let mut i = 0;
5050 while i + 1 < val_args.len() {
5051 let key = &val_args[i];
5052 let value = &val_args[i + 1];
5053 if let Some(k) = key.as_str() {
5054 if k == "__all__" {
5055 match value {
5057 Value::Map(map) => {
5058 for (mk, mv) in map {
5059 if !mk.starts_with('_') {
5060 result_map.insert(mk.clone(), mv.clone());
5061 }
5062 }
5063 }
5064 Value::Node(node) => {
5065 for (pk, pv) in &node.properties {
5066 result_map.insert(pk.clone(), pv.clone());
5067 }
5068 }
5069 Value::Edge(edge) => {
5070 for (pk, pv) in &edge.properties {
5071 result_map.insert(pk.clone(), pv.clone());
5072 }
5073 }
5074 _ => {}
5075 }
5076 } else {
5077 result_map.insert(k.to_string(), value.clone());
5078 }
5079 }
5080 i += 2;
5081 }
5082 Ok(Value::Map(result_map))
5083 })
5084 }
5085}
5086
5087pub fn create_make_cypher_list_udf() -> ScalarUDF {
5092 ScalarUDF::new_from_impl(MakeCypherListUdf::new())
5093}
5094
5095#[derive(Debug)]
5096struct MakeCypherListUdf {
5097 signature: Signature,
5098}
5099
5100impl MakeCypherListUdf {
5101 fn new() -> Self {
5102 Self {
5103 signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
5104 }
5105 }
5106}
5107
5108impl_udf_eq_hash!(MakeCypherListUdf);
5109
5110impl ScalarUDFImpl for MakeCypherListUdf {
5111 fn as_any(&self) -> &dyn Any {
5112 self
5113 }
5114
5115 fn name(&self) -> &str {
5116 "_make_cypher_list"
5117 }
5118
5119 fn signature(&self) -> &Signature {
5120 &self.signature
5121 }
5122
5123 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5124 Ok(DataType::LargeBinary)
5125 }
5126
5127 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5128 let output_type = self.return_type(&[])?;
5129 invoke_cypher_udf(args, &output_type, |val_args| {
5130 Ok(Value::List(val_args.to_vec()))
5131 })
5132 }
5133}
5134
/// Builds the `_cypher_in` scalar UDF implementing Cypher's `IN` operator
/// with its three-valued (null-aware) semantics.
pub fn create_cypher_in_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CypherInUdf::new())
}

/// `_cypher_in(element, list)` — membership test; see `invoke_with_args`
/// for the exact null-handling rules.
#[derive(Debug)]
struct CypherInUdf {
    signature: Signature,
}

impl CypherInUdf {
    fn new() -> Self {
        Self {
            // Exactly two arguments of any type: (element, list).
            signature: Signature::any(2, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherInUdf);

impl ScalarUDFImpl for CypherInUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "_cypher_in"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        // Three-valued: true / false / NULL (NULL is expressed as a null bool).
        Ok(DataType::Boolean)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        invoke_cypher_udf(args, &DataType::Boolean, |vals| {
            if vals.len() != 2 {
                return Err(datafusion::error::DataFusionError::Execution(
                    "_cypher_in(): requires 2 arguments".to_string(),
                ));
            }
            let element = &vals[0];
            let list_val = &vals[1];

            // `x IN NULL` is NULL.
            if list_val.is_null() {
                return Ok(Value::Null);
            }

            let items = match list_val {
                Value::List(items) => items.as_slice(),
                _ => {
                    return Err(datafusion::error::DataFusionError::Execution(format!(
                        "_cypher_in(): second argument must be a list, got {:?}",
                        list_val
                    )));
                }
            };

            // openCypher: `NULL IN []` is false, `NULL IN [non-empty]` is NULL
            // (the null could equal any element, so the answer is unknown).
            if element.is_null() {
                return if items.is_empty() {
                    Ok(Value::Bool(false))
                } else {
                    Ok(Value::Null) };
            }

            // Scan the list: a definite match wins; an unknown (null)
            // comparison taints the result unless a match is found later.
            let mut has_null = false;
            for item in items {
                match cypher_eq(element, item) {
                    Some(true) => return Ok(Value::Bool(true)),
                    None => has_null = true,
                    Some(false) => {}
                }
            }

            // No match: NULL if any comparison was undefined, else false.
            if has_null {
                Ok(Value::Null) } else {
                Ok(Value::Bool(false))
            }
        })
    }
}
5240
5241pub fn create_cypher_list_concat_udf() -> ScalarUDF {
5247 ScalarUDF::new_from_impl(CypherListConcatUdf::new())
5248}
5249
5250#[derive(Debug)]
5251struct CypherListConcatUdf {
5252 signature: Signature,
5253}
5254
5255impl CypherListConcatUdf {
5256 fn new() -> Self {
5257 Self {
5258 signature: Signature::any(2, Volatility::Immutable),
5259 }
5260 }
5261}
5262
5263impl_udf_eq_hash!(CypherListConcatUdf);
5264
5265impl ScalarUDFImpl for CypherListConcatUdf {
5266 fn as_any(&self) -> &dyn Any {
5267 self
5268 }
5269
5270 fn name(&self) -> &str {
5271 "_cypher_list_concat"
5272 }
5273
5274 fn signature(&self) -> &Signature {
5275 &self.signature
5276 }
5277
5278 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5279 Ok(DataType::LargeBinary)
5280 }
5281
5282 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5283 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5284 if vals.len() != 2 {
5285 return Err(datafusion::error::DataFusionError::Execution(
5286 "_cypher_list_concat(): requires 2 arguments".to_string(),
5287 ));
5288 }
5289 if vals[0].is_null() || vals[1].is_null() {
5291 return Ok(Value::Null);
5292 }
5293 match (&vals[0], &vals[1]) {
5294 (Value::List(left), Value::List(right)) => {
5295 let mut result = left.clone();
5296 result.extend(right.iter().cloned());
5297 Ok(Value::List(result))
5298 }
5299 (Value::List(list), elem) => {
5302 let mut result = list.clone();
5303 result.push(elem.clone());
5304 Ok(Value::List(result))
5305 }
5306 (elem, Value::List(list)) => {
5307 let mut result = vec![elem.clone()];
5308 result.extend(list.iter().cloned());
5309 Ok(Value::List(result))
5310 }
5311 _ => {
5312 crate::query::expr_eval::eval_binary_op(
5315 &vals[0],
5316 &uni_cypher::ast::BinaryOp::Add,
5317 &vals[1],
5318 )
5319 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
5320 }
5321 }
5322 })
5323 }
5324}
5325
5326pub fn create_cypher_list_append_udf() -> ScalarUDF {
5332 ScalarUDF::new_from_impl(CypherListAppendUdf::new())
5333}
5334
5335#[derive(Debug)]
5336struct CypherListAppendUdf {
5337 signature: Signature,
5338}
5339
5340impl CypherListAppendUdf {
5341 fn new() -> Self {
5342 Self {
5343 signature: Signature::any(2, Volatility::Immutable),
5344 }
5345 }
5346}
5347
5348impl_udf_eq_hash!(CypherListAppendUdf);
5349
5350impl ScalarUDFImpl for CypherListAppendUdf {
5351 fn as_any(&self) -> &dyn Any {
5352 self
5353 }
5354
5355 fn name(&self) -> &str {
5356 "_cypher_list_append"
5357 }
5358
5359 fn signature(&self) -> &Signature {
5360 &self.signature
5361 }
5362
5363 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5364 Ok(DataType::LargeBinary)
5365 }
5366
5367 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5368 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5369 if vals.len() != 2 {
5370 return Err(datafusion::error::DataFusionError::Execution(
5371 "_cypher_list_append(): requires 2 arguments".to_string(),
5372 ));
5373 }
5374 let left = &vals[0];
5375 let right = &vals[1];
5376
5377 if left.is_null() || right.is_null() {
5379 return Ok(Value::Null);
5380 }
5381
5382 match (left, right) {
5383 (Value::List(list), elem) => {
5385 let mut result = list.clone();
5386 result.push(elem.clone());
5387 Ok(Value::List(result))
5388 }
5389 (elem, Value::List(list)) => {
5391 let mut result = vec![elem.clone()];
5392 result.extend(list.iter().cloned());
5393 Ok(Value::List(result))
5394 }
5395 _ => Err(datafusion::error::DataFusionError::Execution(format!(
5396 "_cypher_list_append(): at least one argument must be a list, got {:?} and {:?}",
5397 left, right
5398 ))),
5399 }
5400 })
5401 }
5402}
5403
/// Builds the `_cypher_list_slice` scalar UDF implementing Cypher's
/// `list[start..end]` with Python-style negative indices and clamping.
pub fn create_cypher_list_slice_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CypherListSliceUdf::new())
}

/// `_cypher_list_slice(list, start, end)`; `end == i64::MAX` is used by the
/// planner as an "open end" sentinel (i.e. `list[start..]`).
#[derive(Debug)]
struct CypherListSliceUdf {
    signature: Signature,
}

impl CypherListSliceUdf {
    fn new() -> Self {
        Self {
            // Exactly three arguments: (list, start, end).
            signature: Signature::any(3, Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherListSliceUdf);

impl ScalarUDFImpl for CypherListSliceUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "_cypher_list_slice"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        // The sliced list is returned as an encoded Cypher value.
        Ok(DataType::LargeBinary)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
            if vals.len() != 3 {
                return Err(datafusion::error::DataFusionError::Execution(
                    "_cypher_list_slice(): requires 3 arguments (list, start, end)".to_string(),
                ));
            }
            // NULL list -> NULL result.
            if vals[0].is_null() {
                return Ok(Value::Null);
            }
            let list = match &vals[0] {
                Value::List(l) => l,
                _ => {
                    return Err(datafusion::error::DataFusionError::Execution(format!(
                        "_cypher_list_slice(): first argument must be a list, got {:?}",
                        vals[0]
                    )));
                }
            };
            // NULL bounds also propagate to NULL.
            if vals[1].is_null() || vals[2].is_null() {
                return Ok(Value::Null);
            }

            let len = list.len() as i64;
            // Non-integer bounds silently default to the full range
            // (start = 0, end = len) rather than erroring.
            let raw_start = match &vals[1] {
                Value::Int(i) => *i,
                _ => 0,
            };
            let raw_end = match &vals[2] {
                Value::Int(i) => *i,
                _ => len,
            };

            // Negative indices count from the end; everything is clamped
            // into [0, len].
            let start = if raw_start < 0 {
                (len + raw_start).max(0) as usize
            } else {
                (raw_start).min(len) as usize
            };
            // i64::MAX is the planner's sentinel for an omitted end bound.
            let end = if raw_end == i64::MAX {
                len as usize
            } else if raw_end < 0 {
                (len + raw_end).max(0) as usize
            } else {
                (raw_end).min(len) as usize
            };

            // Empty (or inverted) range -> empty list, never an error.
            if start >= end {
                return Ok(Value::List(vec![]));
            }
            Ok(Value::List(list[start..end.min(list.len())].to_vec()))
        })
    }
}
5501
5502pub fn create_cypher_reverse_udf() -> ScalarUDF {
5513 ScalarUDF::new_from_impl(CypherReverseUdf::new())
5514}
5515
5516#[derive(Debug)]
5517struct CypherReverseUdf {
5518 signature: Signature,
5519}
5520
5521impl CypherReverseUdf {
5522 fn new() -> Self {
5523 Self {
5524 signature: Signature::any(1, Volatility::Immutable),
5525 }
5526 }
5527}
5528
5529impl_udf_eq_hash!(CypherReverseUdf);
5530
5531impl ScalarUDFImpl for CypherReverseUdf {
5532 fn as_any(&self) -> &dyn Any {
5533 self
5534 }
5535
5536 fn name(&self) -> &str {
5537 "_cypher_reverse"
5538 }
5539
5540 fn signature(&self) -> &Signature {
5541 &self.signature
5542 }
5543
5544 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5545 Ok(DataType::LargeBinary)
5546 }
5547
5548 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5549 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5550 if vals.len() != 1 {
5551 return Err(datafusion::error::DataFusionError::Execution(
5552 "_cypher_reverse(): requires exactly 1 argument".to_string(),
5553 ));
5554 }
5555 match &vals[0] {
5556 Value::Null => Ok(Value::Null),
5557 Value::String(s) => Ok(Value::String(s.chars().rev().collect())),
5558 Value::List(l) => {
5559 let mut reversed = l.clone();
5560 reversed.reverse();
5561 Ok(Value::List(reversed))
5562 }
5563 other => Err(datafusion::error::DataFusionError::Execution(format!(
5564 "_cypher_reverse(): expected string or list, got {:?}",
5565 other
5566 ))),
5567 }
5568 })
5569 }
5570}
5571
/// Builds the `_cypher_substring` scalar UDF: `substring(s, start[, length])`
/// operating on Unicode scalar values (chars), not bytes.
pub fn create_cypher_substring_udf() -> ScalarUDF {
    ScalarUDF::new_from_impl(CypherSubstringUdf::new())
}

#[derive(Debug)]
struct CypherSubstringUdf {
    signature: Signature,
}

impl CypherSubstringUdf {
    fn new() -> Self {
        Self {
            // Variadic because the third (length) argument is optional;
            // arity is validated in invoke_with_args.
            signature: Signature::variadic_any(Volatility::Immutable),
        }
    }
}

impl_udf_eq_hash!(CypherSubstringUdf);

impl ScalarUDFImpl for CypherSubstringUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "_cypher_substring"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
        Ok(DataType::Utf8)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
        invoke_cypher_udf(args, &DataType::Utf8, |vals| {
            if vals.len() < 2 || vals.len() > 3 {
                return Err(datafusion::error::DataFusionError::Execution(
                    "_cypher_substring(): requires 2 or 3 arguments".to_string(),
                ));
            }
            // Null propagation: any null argument -> NULL result.
            if vals.iter().any(|v| v.is_null()) {
                return Ok(Value::Null);
            }
            let s = match &vals[0] {
                Value::String(s) => s.as_str(),
                other => {
                    return Err(datafusion::error::DataFusionError::Execution(format!(
                        "_cypher_substring(): first argument must be a string, got {:?}",
                        other
                    )));
                }
            };
            let start = match &vals[1] {
                Value::Int(i) => *i,
                other => {
                    return Err(datafusion::error::DataFusionError::Execution(format!(
                        "_cypher_substring(): second argument must be an integer, got {:?}",
                        other
                    )));
                }
            };

            // Index by chars so multi-byte UTF-8 can't split mid-character.
            let chars: Vec<char> = s.chars().collect();
            let len = chars.len() as i64;

            // Negative start clamps to 0; past-the-end clamps to len.
            let start_idx = start.max(0).min(len) as usize;

            let end_idx = if vals.len() == 3 {
                let length = match &vals[2] {
                    Value::Int(i) => *i,
                    other => {
                        return Err(datafusion::error::DataFusionError::Execution(format!(
                            "_cypher_substring(): third argument must be an integer, got {:?}",
                            other
                        )));
                    }
                };
                // Negative length is an explicit error (openCypher semantics).
                if length < 0 {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "ArgumentError: NegativeIntegerArgument - substring length must be non-negative".to_string(),
                    ));
                }
                // Length past the end is clamped, not an error.
                (start_idx as i64 + length).min(len) as usize
            } else {
                // No length: take everything from start to the end.
                len as usize
            };

            Ok(Value::String(chars[start_idx..end_idx].iter().collect()))
        })
    }
}
5679
5680pub fn create_cypher_split_udf() -> ScalarUDF {
5689 ScalarUDF::new_from_impl(CypherSplitUdf::new())
5690}
5691
5692#[derive(Debug)]
5693struct CypherSplitUdf {
5694 signature: Signature,
5695}
5696
5697impl CypherSplitUdf {
5698 fn new() -> Self {
5699 Self {
5700 signature: Signature::any(2, Volatility::Immutable),
5701 }
5702 }
5703}
5704
5705impl_udf_eq_hash!(CypherSplitUdf);
5706
5707impl ScalarUDFImpl for CypherSplitUdf {
5708 fn as_any(&self) -> &dyn Any {
5709 self
5710 }
5711
5712 fn name(&self) -> &str {
5713 "_cypher_split"
5714 }
5715
5716 fn signature(&self) -> &Signature {
5717 &self.signature
5718 }
5719
5720 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5721 Ok(DataType::LargeBinary)
5722 }
5723
5724 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5725 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5726 if vals.len() != 2 {
5727 return Err(datafusion::error::DataFusionError::Execution(
5728 "_cypher_split(): requires exactly 2 arguments".to_string(),
5729 ));
5730 }
5731 if vals.iter().any(|v| v.is_null()) {
5733 return Ok(Value::Null);
5734 }
5735 let s = match &vals[0] {
5736 Value::String(s) => s.clone(),
5737 other => {
5738 return Err(datafusion::error::DataFusionError::Execution(format!(
5739 "_cypher_split(): first argument must be a string, got {:?}",
5740 other
5741 )));
5742 }
5743 };
5744 let delimiter = match &vals[1] {
5745 Value::String(d) => d.clone(),
5746 other => {
5747 return Err(datafusion::error::DataFusionError::Execution(format!(
5748 "_cypher_split(): second argument must be a string, got {:?}",
5749 other
5750 )));
5751 }
5752 };
5753 let parts: Vec<Value> = s
5754 .split(&delimiter)
5755 .map(|p| Value::String(p.to_string()))
5756 .collect();
5757 Ok(Value::List(parts))
5758 })
5759 }
5760}
5761
5762pub fn create_cypher_list_to_cv_udf() -> ScalarUDF {
5773 ScalarUDF::new_from_impl(CypherListToCvUdf::new())
5774}
5775
5776#[derive(Debug)]
5777struct CypherListToCvUdf {
5778 signature: Signature,
5779}
5780
5781impl CypherListToCvUdf {
5782 fn new() -> Self {
5783 Self {
5784 signature: Signature::any(1, Volatility::Immutable),
5785 }
5786 }
5787}
5788
5789impl_udf_eq_hash!(CypherListToCvUdf);
5790
5791impl ScalarUDFImpl for CypherListToCvUdf {
5792 fn as_any(&self) -> &dyn Any {
5793 self
5794 }
5795
5796 fn name(&self) -> &str {
5797 "_cypher_list_to_cv"
5798 }
5799
5800 fn signature(&self) -> &Signature {
5801 &self.signature
5802 }
5803
5804 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5805 Ok(DataType::LargeBinary)
5806 }
5807
5808 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5809 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5810 if vals.len() != 1 {
5811 return Err(datafusion::error::DataFusionError::Execution(
5812 "_cypher_list_to_cv(): requires exactly 1 argument".to_string(),
5813 ));
5814 }
5815 Ok(vals[0].clone())
5816 })
5817 }
5818}
5819
5820pub fn create_cypher_scalar_to_cv_udf() -> ScalarUDF {
5831 ScalarUDF::new_from_impl(CypherScalarToCvUdf::new())
5832}
5833
5834#[derive(Debug)]
5835struct CypherScalarToCvUdf {
5836 signature: Signature,
5837}
5838
5839impl CypherScalarToCvUdf {
5840 fn new() -> Self {
5841 Self {
5842 signature: Signature::any(1, Volatility::Immutable),
5843 }
5844 }
5845}
5846
5847impl_udf_eq_hash!(CypherScalarToCvUdf);
5848
5849impl ScalarUDFImpl for CypherScalarToCvUdf {
5850 fn as_any(&self) -> &dyn Any {
5851 self
5852 }
5853
5854 fn name(&self) -> &str {
5855 "_cypher_scalar_to_cv"
5856 }
5857
5858 fn signature(&self) -> &Signature {
5859 &self.signature
5860 }
5861
5862 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5863 Ok(DataType::LargeBinary)
5864 }
5865
5866 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5867 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5868 if vals.len() != 1 {
5869 return Err(datafusion::error::DataFusionError::Execution(
5870 "_cypher_scalar_to_cv(): requires exactly 1 argument".to_string(),
5871 ));
5872 }
5873 Ok(vals[0].clone())
5874 })
5875 }
5876}
5877
5878pub fn create_cypher_tail_udf() -> ScalarUDF {
5890 ScalarUDF::new_from_impl(CypherTailUdf::new())
5891}
5892
5893#[derive(Debug)]
5894struct CypherTailUdf {
5895 signature: Signature,
5896}
5897
5898impl CypherTailUdf {
5899 fn new() -> Self {
5900 Self {
5901 signature: Signature::any(1, Volatility::Immutable),
5902 }
5903 }
5904}
5905
5906impl_udf_eq_hash!(CypherTailUdf);
5907
5908impl ScalarUDFImpl for CypherTailUdf {
5909 fn as_any(&self) -> &dyn Any {
5910 self
5911 }
5912
5913 fn name(&self) -> &str {
5914 "_cypher_tail"
5915 }
5916
5917 fn signature(&self) -> &Signature {
5918 &self.signature
5919 }
5920
5921 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5922 Ok(DataType::LargeBinary)
5923 }
5924
5925 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5926 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5927 if vals.len() != 1 {
5928 return Err(datafusion::error::DataFusionError::Execution(
5929 "_cypher_tail(): requires exactly 1 argument".to_string(),
5930 ));
5931 }
5932 match &vals[0] {
5933 Value::Null => Ok(Value::Null),
5934 Value::List(l) => {
5935 if l.is_empty() {
5936 Ok(Value::List(vec![]))
5937 } else {
5938 Ok(Value::List(l[1..].to_vec()))
5939 }
5940 }
5941 other => Err(datafusion::error::DataFusionError::Execution(format!(
5942 "_cypher_tail(): expected list, got {:?}",
5943 other
5944 ))),
5945 }
5946 })
5947 }
5948}
5949
5950pub fn create_cypher_head_udf() -> ScalarUDF {
5961 ScalarUDF::new_from_impl(CypherHeadUdf::new())
5962}
5963
5964#[derive(Debug)]
5965struct CypherHeadUdf {
5966 signature: Signature,
5967}
5968
5969impl CypherHeadUdf {
5970 fn new() -> Self {
5971 Self {
5972 signature: Signature::any(1, Volatility::Immutable),
5973 }
5974 }
5975}
5976
5977impl_udf_eq_hash!(CypherHeadUdf);
5978
5979impl ScalarUDFImpl for CypherHeadUdf {
5980 fn as_any(&self) -> &dyn Any {
5981 self
5982 }
5983
5984 fn name(&self) -> &str {
5985 "head"
5986 }
5987
5988 fn signature(&self) -> &Signature {
5989 &self.signature
5990 }
5991
5992 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
5993 Ok(DataType::LargeBinary)
5994 }
5995
5996 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
5997 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
5998 if vals.len() != 1 {
5999 return Err(datafusion::error::DataFusionError::Execution(
6000 "head(): requires exactly 1 argument".to_string(),
6001 ));
6002 }
6003 match &vals[0] {
6004 Value::Null => Ok(Value::Null),
6005 Value::List(l) => Ok(l.first().cloned().unwrap_or(Value::Null)),
6006 other => Err(datafusion::error::DataFusionError::Execution(format!(
6007 "head(): expected list, got {:?}",
6008 other
6009 ))),
6010 }
6011 })
6012 }
6013}
6014
6015pub fn create_cypher_last_udf() -> ScalarUDF {
6026 ScalarUDF::new_from_impl(CypherLastUdf::new())
6027}
6028
6029#[derive(Debug)]
6030struct CypherLastUdf {
6031 signature: Signature,
6032}
6033
6034impl CypherLastUdf {
6035 fn new() -> Self {
6036 Self {
6037 signature: Signature::any(1, Volatility::Immutable),
6038 }
6039 }
6040}
6041
6042impl_udf_eq_hash!(CypherLastUdf);
6043
6044impl ScalarUDFImpl for CypherLastUdf {
6045 fn as_any(&self) -> &dyn Any {
6046 self
6047 }
6048
6049 fn name(&self) -> &str {
6050 "last"
6051 }
6052
6053 fn signature(&self) -> &Signature {
6054 &self.signature
6055 }
6056
6057 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
6058 Ok(DataType::LargeBinary)
6059 }
6060
6061 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6062 invoke_cypher_udf(args, &DataType::LargeBinary, |vals| {
6063 if vals.len() != 1 {
6064 return Err(datafusion::error::DataFusionError::Execution(
6065 "last(): requires exactly 1 argument".to_string(),
6066 ));
6067 }
6068 match &vals[0] {
6069 Value::Null => Ok(Value::Null),
6070 Value::List(l) => Ok(l.last().cloned().unwrap_or(Value::Null)),
6071 other => Err(datafusion::error::DataFusionError::Execution(format!(
6072 "last(): expected list, got {:?}",
6073 other
6074 ))),
6075 }
6076 })
6077 }
6078}
6079
6080fn cypher_list_cmp(left: &[Value], right: &[Value]) -> Option<std::cmp::Ordering> {
6083 let min_len = left.len().min(right.len());
6084 for i in 0..min_len {
6085 let cmp = cypher_value_cmp(&left[i], &right[i])?;
6086 if cmp != std::cmp::Ordering::Equal {
6087 return Some(cmp);
6088 }
6089 }
6090 Some(left.len().cmp(&right.len()))
6092}
6093
6094fn cypher_value_cmp(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
6097 match (a, b) {
6098 (Value::Null, Value::Null) => Some(std::cmp::Ordering::Equal),
6099 (Value::Null, _) | (_, Value::Null) => None,
6100 (Value::Int(l), Value::Int(r)) => Some(l.cmp(r)),
6101 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r),
6102 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r),
6103 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)),
6104 (Value::String(l), Value::String(r)) => Some(l.cmp(r)),
6105 (Value::Bool(l), Value::Bool(r)) => Some(l.cmp(r)),
6106 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r),
6107 _ => None, }
6109}
6110
6111struct CypherToFloat64Udf {
6119 signature: Signature,
6120}
6121
6122impl CypherToFloat64Udf {
6123 fn new() -> Self {
6124 Self {
6125 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
6126 }
6127 }
6128}
6129
6130impl_udf_eq_hash!(CypherToFloat64Udf);
6131
6132impl std::fmt::Debug for CypherToFloat64Udf {
6133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6134 f.debug_struct("CypherToFloat64Udf").finish()
6135 }
6136}
6137
6138impl ScalarUDFImpl for CypherToFloat64Udf {
6139 fn as_any(&self) -> &dyn Any {
6140 self
6141 }
6142 fn name(&self) -> &str {
6143 "_cypher_to_float64"
6144 }
6145 fn signature(&self) -> &Signature {
6146 &self.signature
6147 }
6148 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6149 Ok(DataType::Float64)
6150 }
6151 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
6152 if args.args.len() != 1 {
6153 return Err(datafusion::error::DataFusionError::Execution(
6154 "_cypher_to_float64 requires exactly 1 argument".into(),
6155 ));
6156 }
6157 match &args.args[0] {
6158 ColumnarValue::Scalar(scalar) => {
6159 let f = match scalar {
6160 ScalarValue::LargeBinary(Some(bytes)) => cv_bytes_as_f64(bytes),
6161 ScalarValue::Int64(Some(i)) => Some(*i as f64),
6162 ScalarValue::Int32(Some(i)) => Some(*i as f64),
6163 ScalarValue::Float64(Some(f)) => Some(*f),
6164 ScalarValue::Float32(Some(f)) => Some(*f as f64),
6165 _ => None,
6166 };
6167 Ok(ColumnarValue::Scalar(ScalarValue::Float64(f)))
6168 }
6169 ColumnarValue::Array(arr) => {
6170 let len = arr.len();
6171 let mut builder = arrow::array::Float64Builder::with_capacity(len);
6172 match arr.data_type() {
6173 DataType::LargeBinary => {
6174 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6175 for i in 0..len {
6176 if lb.is_null(i) {
6177 builder.append_null();
6178 } else {
6179 match cv_bytes_as_f64(lb.value(i)) {
6180 Some(f) => builder.append_value(f),
6181 None => builder.append_null(),
6182 }
6183 }
6184 }
6185 }
6186 DataType::Int64 => {
6187 let int_arr = arr.as_any().downcast_ref::<Int64Array>().unwrap();
6188 for i in 0..len {
6189 if int_arr.is_null(i) {
6190 builder.append_null();
6191 } else {
6192 builder.append_value(int_arr.value(i) as f64);
6193 }
6194 }
6195 }
6196 DataType::Float64 => {
6197 let f_arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
6198 for i in 0..len {
6199 if f_arr.is_null(i) {
6200 builder.append_null();
6201 } else {
6202 builder.append_value(f_arr.value(i));
6203 }
6204 }
6205 }
6206 _ => {
6207 for _ in 0..len {
6208 builder.append_null();
6209 }
6210 }
6211 }
6212 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
6213 }
6214 }
6215 }
6216}
6217
6218fn create_cypher_to_float64_udf() -> ScalarUDF {
6219 ScalarUDF::from(CypherToFloat64Udf::new())
6220}
6221
6222pub(crate) fn cypher_to_float64_expr(
6224 arg: datafusion::logical_expr::Expr,
6225) -> datafusion::logical_expr::Expr {
6226 datafusion::logical_expr::Expr::ScalarFunction(
6227 datafusion::logical_expr::expr::ScalarFunction::new_udf(
6228 Arc::new(create_cypher_to_float64_udf()),
6229 vec![arg],
6230 ),
6231 )
6232}
6233
6234pub(crate) fn cypher_to_float64_udf() -> datafusion::logical_expr::ScalarUDF {
6236 create_cypher_to_float64_udf()
6237}
6238
6239fn cypher_type_rank(val: &Value) -> u8 {
6247 match val {
6248 Value::Null => 0,
6249 Value::List(_) => 1,
6250 Value::String(_) => 2,
6251 Value::Bool(_) => 3,
6252 Value::Int(_) | Value::Float(_) => 4,
6253 _ => 5, }
6255}
6256
6257fn cypher_cross_type_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
6260 use std::cmp::Ordering;
6261 let ra = cypher_type_rank(a);
6262 let rb = cypher_type_rank(b);
6263 if ra != rb {
6264 return ra.cmp(&rb);
6265 }
6266 match (a, b) {
6268 (Value::Int(l), Value::Int(r)) => l.cmp(r),
6269 (Value::Float(l), Value::Float(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
6270 (Value::Int(l), Value::Float(r)) => (*l as f64).partial_cmp(r).unwrap_or(Ordering::Equal),
6271 (Value::Float(l), Value::Int(r)) => l.partial_cmp(&(*r as f64)).unwrap_or(Ordering::Equal),
6272 (Value::String(l), Value::String(r)) => l.cmp(r),
6273 (Value::Bool(l), Value::Bool(r)) => l.cmp(r),
6274 (Value::List(l), Value::List(r)) => cypher_list_cmp(l, r).unwrap_or(Ordering::Equal),
6275 _ => Ordering::Equal,
6276 }
6277}
6278
/// Decodes an encoded Cypher value from its binary representation;
/// undecodable bytes degrade to `Value::Null` rather than erroring.
fn scalar_binary_to_value(bytes: &[u8]) -> Value {
    uni_common::cypher_value_codec::decode(bytes).unwrap_or(Value::Null)
}
6283
6284use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, AggregateUDFImpl};
6285
/// Shared implementation behind the `_cypher_min` / `_cypher_max`
/// aggregates; `is_max` selects which extremum is kept.
#[derive(Debug, Clone)]
struct CypherMinMaxUdaf {
    name: String,
    signature: Signature,
    is_max: bool,
}

impl CypherMinMaxUdaf {
    fn new(is_max: bool) -> Self {
        let name = if is_max { "_cypher_max" } else { "_cypher_min" };
        Self {
            name: name.to_string(),
            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
            is_max,
        }
    }
}

// Deliberately NOT `impl_udf_eq_hash!`: min and max share the exact same
// signature, so the macro's signature-based equality would make the two
// UDAFs compare equal. Identity here is the (unique) name.
impl PartialEq for CypherMinMaxUdaf {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl Eq for CypherMinMaxUdaf {}

impl Hash for CypherMinMaxUdaf {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
    }
}
6318
impl AggregateUDFImpl for CypherMinMaxUdaf {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        &self.name
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    // min/max return whatever type they aggregate over; with no argument
    // type available, fall back to the encoded-value representation.
    fn return_type(&self, args: &[DataType]) -> DFResult<DataType> {
        Ok(args.first().cloned().unwrap_or(DataType::LargeBinary))
    }
    fn accumulator(
        &self,
        acc_args: datafusion::logical_expr::function::AccumulatorArgs,
    ) -> DFResult<Box<dyn DfAccumulator>> {
        Ok(Box::new(CypherMinMaxAccumulator {
            current: None,
            is_max: self.is_max,
            // Remembered so evaluate() can emit a matching ScalarValue.
            return_type: acc_args.return_field.data_type().clone(),
        }))
    }
    // Intermediate aggregation state is declared as a single nullable
    // LargeBinary column (an encoded Cypher value).
    fn state_fields(
        &self,
        args: datafusion::logical_expr::function::StateFieldsArgs,
    ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
        Ok(vec![Arc::new(arrow::datatypes::Field::new(
            args.name,
            DataType::LargeBinary,
            true,
        ))])
    }
}
6354
/// Running state for `_cypher_min` / `_cypher_max`.
#[derive(Debug)]
struct CypherMinMaxAccumulator {
    // Best (min or max) non-null value seen so far; None until first input.
    current: Option<Value>,
    // true -> keep the maximum, false -> keep the minimum.
    is_max: bool,
    // Declared output type; evaluate() shapes its ScalarValue to match.
    return_type: DataType,
}
6361
6362impl DfAccumulator for CypherMinMaxAccumulator {
6363 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6364 let arr = &values[0];
6365 match arr.data_type() {
6366 DataType::LargeBinary => {
6367 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
6368 for i in 0..lb.len() {
6369 if lb.is_null(i) {
6370 continue;
6371 }
6372 let val = scalar_binary_to_value(lb.value(i));
6373 if val.is_null() {
6374 continue;
6375 }
6376 self.current = Some(match self.current.take() {
6377 None => val,
6378 Some(cur) => {
6379 let ord = cypher_cross_type_cmp(&val, &cur);
6380 if (self.is_max && ord == std::cmp::Ordering::Greater)
6381 || (!self.is_max && ord == std::cmp::Ordering::Less)
6382 {
6383 val
6384 } else {
6385 cur
6386 }
6387 }
6388 });
6389 }
6390 }
6391 _ => {
6392 for i in 0..arr.len() {
6394 if arr.is_null(i) {
6395 continue;
6396 }
6397 let sv = ScalarValue::try_from_array(arr, i).map_err(|e| {
6398 datafusion::error::DataFusionError::Execution(e.to_string())
6399 })?;
6400 let val = scalar_to_value(&sv)?;
6401 if val.is_null() {
6402 continue;
6403 }
6404 self.current = Some(match self.current.take() {
6405 None => val,
6406 Some(cur) => {
6407 let ord = cypher_cross_type_cmp(&val, &cur);
6408 if (self.is_max && ord == std::cmp::Ordering::Greater)
6409 || (!self.is_max && ord == std::cmp::Ordering::Less)
6410 {
6411 val
6412 } else {
6413 cur
6414 }
6415 }
6416 });
6417 }
6418 }
6419 }
6420 Ok(())
6421 }
6422 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6423 match &self.current {
6424 None => {
6425 ScalarValue::try_from(&self.return_type)
6427 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
6428 }
6429 Some(val) => {
6430 if matches!(self.return_type, DataType::LargeBinary) {
6432 let bytes = uni_common::cypher_value_codec::encode(val);
6433 return Ok(ScalarValue::LargeBinary(Some(bytes)));
6434 }
6435 match val {
6437 Value::Int(i) => match &self.return_type {
6438 DataType::Int64 => Ok(ScalarValue::Int64(Some(*i))),
6439 DataType::UInt64 => Ok(ScalarValue::UInt64(Some(*i as u64))),
6440 _ => {
6441 let bytes = uni_common::cypher_value_codec::encode(val);
6442 Ok(ScalarValue::LargeBinary(Some(bytes)))
6443 }
6444 },
6445 Value::Float(f) => match &self.return_type {
6446 DataType::Float64 => Ok(ScalarValue::Float64(Some(*f))),
6447 _ => {
6448 let bytes = uni_common::cypher_value_codec::encode(val);
6449 Ok(ScalarValue::LargeBinary(Some(bytes)))
6450 }
6451 },
6452 Value::String(s) => match &self.return_type {
6453 DataType::Utf8 => Ok(ScalarValue::Utf8(Some(s.clone()))),
6454 DataType::LargeUtf8 => Ok(ScalarValue::LargeUtf8(Some(s.clone()))),
6455 _ => {
6456 let bytes = uni_common::cypher_value_codec::encode(val);
6457 Ok(ScalarValue::LargeBinary(Some(bytes)))
6458 }
6459 },
6460 Value::Bool(b) => match &self.return_type {
6461 DataType::Boolean => Ok(ScalarValue::Boolean(Some(*b))),
6462 _ => {
6463 let bytes = uni_common::cypher_value_codec::encode(val);
6464 Ok(ScalarValue::LargeBinary(Some(bytes)))
6465 }
6466 },
6467 _ => {
6468 let bytes = uni_common::cypher_value_codec::encode(val);
6470 Ok(ScalarValue::LargeBinary(Some(bytes)))
6471 }
6472 }
6473 }
6474 }
6475 }
6476 fn size(&self) -> usize {
6477 std::mem::size_of_val(self) + self.current.as_ref().map_or(0, |_| 64)
6478 }
6479 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6480 Ok(vec![self.evaluate()?])
6481 }
6482 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6483 self.update_batch(states)
6484 }
6485}
6486
6487pub(crate) fn create_cypher_min_udaf() -> AggregateUDF {
6488 AggregateUDF::from(CypherMinMaxUdaf::new(false))
6489}
6490
6491pub(crate) fn create_cypher_max_udaf() -> AggregateUDF {
6492 AggregateUDF::from(CypherMinMaxUdaf::new(true))
6493}
6494
/// Aggregate implementing Cypher `sum()`: the accumulator tracks both an
/// integer and a float running total so an all-integer input can yield an
/// integer result (see `CypherSumAccumulator`).
#[derive(Debug, Clone)]
struct CypherSumUdaf {
    signature: Signature,
}

impl CypherSumUdaf {
    fn new() -> Self {
        Self {
            // Any single argument; the accumulator interprets types row by row.
            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
        }
    }
}
6512
6513impl PartialEq for CypherSumUdaf {
6514 fn eq(&self, other: &Self) -> bool {
6515 self.signature == other.signature
6516 }
6517}
6518
6519impl Eq for CypherSumUdaf {}
6520
6521impl Hash for CypherSumUdaf {
6522 fn hash<H: Hasher>(&self, state: &mut H) {
6523 self.name().hash(state);
6524 }
6525}
6526
6527impl AggregateUDFImpl for CypherSumUdaf {
6528 fn as_any(&self) -> &dyn Any {
6529 self
6530 }
6531 fn name(&self) -> &str {
6532 "_cypher_sum"
6533 }
6534 fn signature(&self) -> &Signature {
6535 &self.signature
6536 }
6537 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6538 Ok(DataType::LargeBinary)
6541 }
6542 fn accumulator(
6543 &self,
6544 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6545 ) -> DFResult<Box<dyn DfAccumulator>> {
6546 Ok(Box::new(CypherSumAccumulator {
6547 sum: 0.0,
6548 all_ints: true,
6549 int_sum: 0i64,
6550 has_value: false,
6551 }))
6552 }
6553 fn state_fields(
6554 &self,
6555 args: datafusion::logical_expr::function::StateFieldsArgs,
6556 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6557 Ok(vec![
6558 Arc::new(arrow::datatypes::Field::new(
6559 format!("{}_sum", args.name),
6560 DataType::Float64,
6561 true,
6562 )),
6563 Arc::new(arrow::datatypes::Field::new(
6564 format!("{}_int_sum", args.name),
6565 DataType::Int64,
6566 true,
6567 )),
6568 Arc::new(arrow::datatypes::Field::new(
6569 format!("{}_all_ints", args.name),
6570 DataType::Boolean,
6571 true,
6572 )),
6573 Arc::new(arrow::datatypes::Field::new(
6574 format!("{}_has_value", args.name),
6575 DataType::Boolean,
6576 true,
6577 )),
6578 ])
6579 }
6580}
6581
/// Running state for `_cypher_sum`.
#[derive(Debug)]
struct CypherSumAccumulator {
    sum: f64, // float total, always maintained
    all_ints: bool, // true until the first float input is seen
    int_sum: i64, // integer total (wrapping); used only when all_ints
    has_value: bool, // whether any non-null input was observed
}
6589
impl DfAccumulator for CypherSumAccumulator {
    /// Folds one batch of inputs into the running totals. Accepts raw
    /// Int64/Float64 columns or LargeBinary CypherValue blobs; rows of
    /// any other type are silently ignored.
    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
        let arr = &values[0];
        for i in 0..arr.len() {
            if arr.is_null(i) {
                continue;
            }
            match arr.data_type() {
                DataType::LargeBinary => {
                    let lb = arr.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
                    let bytes = lb.value(i);
                    use uni_common::cypher_value_codec::{
                        TAG_FLOAT, TAG_INT, decode_float, decode_int, peek_tag,
                    };
                    match peek_tag(bytes) {
                        Some(TAG_INT) => {
                            if let Some(v) = decode_int(bytes) {
                                // Maintain both totals: float for mixed input,
                                // wrapping int for the all-integer case.
                                self.sum += v as f64;
                                self.int_sum = self.int_sum.wrapping_add(v);
                                self.has_value = true;
                            }
                        }
                        Some(TAG_FLOAT) => {
                            if let Some(v) = decode_float(bytes) {
                                self.sum += v;
                                self.all_ints = false;
                                self.has_value = true;
                            }
                        }
                        // Non-numeric CypherValues are skipped.
                        _ => {} }
                }
                DataType::Int64 => {
                    let a = arr.as_any().downcast_ref::<Int64Array>().unwrap();
                    let v = a.value(i);
                    self.sum += v as f64;
                    self.int_sum = self.int_sum.wrapping_add(v);
                    self.has_value = true;
                }
                DataType::Float64 => {
                    let a = arr.as_any().downcast_ref::<Float64Array>().unwrap();
                    self.sum += a.value(i);
                    self.all_ints = false;
                    self.has_value = true;
                }
                _ => {}
            }
        }
        Ok(())
    }
    /// Final result: encoded Int when every input was an integer, else
    /// Float. No inputs at all yields a null LargeBinary.
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        if !self.has_value {
            return Ok(ScalarValue::LargeBinary(None));
        }
        let val = if self.all_ints {
            Value::Int(self.int_sum)
        } else {
            Value::Float(self.sum)
        };
        let bytes = uni_common::cypher_value_codec::encode(&val);
        Ok(ScalarValue::LargeBinary(Some(bytes)))
    }
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }
    /// Partial-aggregation state; column order matches
    /// `CypherSumUdaf::state_fields`.
    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
        Ok(vec![
            ScalarValue::Float64(Some(self.sum)),
            ScalarValue::Int64(Some(self.int_sum)),
            ScalarValue::Boolean(Some(self.all_ints)),
            ScalarValue::Boolean(Some(self.has_value)),
        ])
    }
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        let sum_arr = states[0].as_any().downcast_ref::<Float64Array>().unwrap();
        let int_sum_arr = states[1].as_any().downcast_ref::<Int64Array>().unwrap();
        let all_ints_arr = states[2].as_any().downcast_ref::<BooleanArray>().unwrap();
        let has_value_arr = states[3].as_any().downcast_ref::<BooleanArray>().unwrap();
        for i in 0..sum_arr.len() {
            // Only fold in partial states that actually saw a value.
            if !has_value_arr.is_null(i) && has_value_arr.value(i) {
                self.sum += sum_arr.value(i);
                self.int_sum = self.int_sum.wrapping_add(int_sum_arr.value(i));
                if !all_ints_arr.value(i) {
                    self.all_ints = false;
                }
                self.has_value = true;
            }
        }
        Ok(())
    }
}
6681
6682pub(crate) fn create_cypher_sum_udaf() -> AggregateUDF {
6683 AggregateUDF::from(CypherSumUdaf::new())
6684}
6685
/// Aggregate implementing Cypher `collect()`: gathers non-null values into
/// a Cypher list, optionally deduplicating for `collect(DISTINCT ...)`.
#[derive(Debug, Clone)]
struct CypherCollectUdaf {
    signature: Signature,
}

impl CypherCollectUdaf {
    fn new() -> Self {
        Self {
            // Any single argument; values are converted row by row.
            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
        }
    }
}
6704
6705impl PartialEq for CypherCollectUdaf {
6706 fn eq(&self, other: &Self) -> bool {
6707 self.signature == other.signature
6708 }
6709}
6710
6711impl Eq for CypherCollectUdaf {}
6712
6713impl Hash for CypherCollectUdaf {
6714 fn hash<H: Hasher>(&self, state: &mut H) {
6715 self.name().hash(state);
6716 }
6717}
6718
6719impl AggregateUDFImpl for CypherCollectUdaf {
6720 fn as_any(&self) -> &dyn Any {
6721 self
6722 }
6723 fn name(&self) -> &str {
6724 "_cypher_collect"
6725 }
6726 fn signature(&self) -> &Signature {
6727 &self.signature
6728 }
6729 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6730 Ok(DataType::LargeBinary)
6731 }
6732 fn accumulator(
6733 &self,
6734 acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6735 ) -> DFResult<Box<dyn DfAccumulator>> {
6736 Ok(Box::new(CypherCollectAccumulator {
6737 values: Vec::new(),
6738 distinct: acc_args.is_distinct,
6739 }))
6740 }
6741 fn state_fields(
6742 &self,
6743 args: datafusion::logical_expr::function::StateFieldsArgs,
6744 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6745 Ok(vec![Arc::new(arrow::datatypes::Field::new(
6746 args.name,
6747 DataType::LargeBinary,
6748 true,
6749 ))])
6750 }
6751}
6752
/// Running state for `_cypher_collect`.
#[derive(Debug)]
struct CypherCollectAccumulator {
    values: Vec<Value>, // collected values, in arrival order
    distinct: bool, // true for collect(DISTINCT ...)
}
6758
6759impl DfAccumulator for CypherCollectAccumulator {
6760 fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
6761 let arr = &values[0];
6762 for i in 0..arr.len() {
6763 if arr.is_null(i) {
6764 continue;
6765 }
6766 if let Some(struct_arr) = arr.as_any().downcast_ref::<arrow::array::StructArray>()
6770 && struct_arr.num_columns() > 0
6771 && struct_arr.column(0).is_null(i)
6772 {
6773 continue;
6774 }
6775 let sv = ScalarValue::try_from_array(arr, i)
6776 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
6777 let val = scalar_to_value(&sv)?;
6778 if val.is_null() {
6779 continue;
6780 }
6781 if self.distinct {
6782 let repr = val.to_string();
6784 if self.values.iter().any(|v| v.to_string() == repr) {
6785 continue;
6786 }
6787 }
6788 self.values.push(val);
6789 }
6790 Ok(())
6791 }
6792 fn evaluate(&mut self) -> DFResult<ScalarValue> {
6793 let val = Value::List(self.values.clone());
6795 let bytes = uni_common::cypher_value_codec::encode(&val);
6796 Ok(ScalarValue::LargeBinary(Some(bytes)))
6797 }
6798 fn size(&self) -> usize {
6799 std::mem::size_of_val(self) + self.values.len() * 64
6800 }
6801 fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
6802 Ok(vec![self.evaluate()?])
6803 }
6804 fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
6805 let arr = &states[0];
6807 if let Some(lb) = arr.as_any().downcast_ref::<LargeBinaryArray>() {
6808 for i in 0..lb.len() {
6809 if lb.is_null(i) {
6810 continue;
6811 }
6812 let val = scalar_binary_to_value(lb.value(i));
6813 if let Value::List(items) = val {
6814 for item in items {
6815 if !item.is_null() {
6816 if self.distinct {
6817 let repr = item.to_string();
6818 if self.values.iter().any(|v| v.to_string() == repr) {
6819 continue;
6820 }
6821 }
6822 self.values.push(item);
6823 }
6824 }
6825 }
6826 }
6827 }
6828 Ok(())
6829 }
6830}
6831
6832pub(crate) fn create_cypher_collect_udaf() -> AggregateUDF {
6833 AggregateUDF::from(CypherCollectUdaf::new())
6834}
6835
6836pub(crate) fn create_cypher_collect_expr(
6838 arg: datafusion::logical_expr::Expr,
6839 distinct: bool,
6840) -> datafusion::logical_expr::Expr {
6841 let udaf = Arc::new(create_cypher_collect_udaf());
6844 if distinct {
6845 datafusion::logical_expr::Expr::AggregateFunction(
6847 datafusion::logical_expr::expr::AggregateFunction::new_udf(
6848 udaf,
6849 vec![arg],
6850 true, None,
6852 vec![],
6853 None,
6854 ),
6855 )
6856 } else {
6857 udaf.call(vec![arg])
6858 }
6859}
6860
/// Aggregate implementing Cypher `percentileDisc(expr, p)` — discrete
/// percentile (returns an actual input value, no interpolation).
#[derive(Debug, Clone)]
struct CypherPercentileDiscUdaf {
    signature: Signature,
}

impl CypherPercentileDiscUdaf {
    fn new() -> Self {
        Self {
            // Two arguments: the expression column and the percentile.
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}
6878
6879impl PartialEq for CypherPercentileDiscUdaf {
6880 fn eq(&self, other: &Self) -> bool {
6881 self.signature == other.signature
6882 }
6883}
6884
6885impl Eq for CypherPercentileDiscUdaf {}
6886
6887impl Hash for CypherPercentileDiscUdaf {
6888 fn hash<H: Hasher>(&self, state: &mut H) {
6889 self.name().hash(state);
6890 }
6891}
6892
6893impl AggregateUDFImpl for CypherPercentileDiscUdaf {
6894 fn as_any(&self) -> &dyn Any {
6895 self
6896 }
6897 fn name(&self) -> &str {
6898 "percentiledisc"
6899 }
6900 fn signature(&self) -> &Signature {
6901 &self.signature
6902 }
6903 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
6904 Ok(DataType::Float64)
6905 }
6906 fn accumulator(
6907 &self,
6908 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
6909 ) -> DFResult<Box<dyn DfAccumulator>> {
6910 Ok(Box::new(CypherPercentileDiscAccumulator {
6911 values: Vec::new(),
6912 percentile: None,
6913 }))
6914 }
6915 fn state_fields(
6916 &self,
6917 args: datafusion::logical_expr::function::StateFieldsArgs,
6918 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
6919 Ok(vec![
6920 Arc::new(arrow::datatypes::Field::new(
6921 format!("{}_values", args.name),
6922 DataType::List(Arc::new(arrow::datatypes::Field::new(
6923 "item",
6924 DataType::Float64,
6925 true,
6926 ))),
6927 true,
6928 )),
6929 Arc::new(arrow::datatypes::Field::new(
6930 format!("{}_percentile", args.name),
6931 DataType::Float64,
6932 true,
6933 )),
6934 ])
6935 }
6936}
6937
/// Running state for `percentileDisc`.
#[derive(Debug)]
struct CypherPercentileDiscAccumulator {
    values: Vec<f64>, // gathered expression values (sorted lazily at evaluate)
    percentile: Option<f64>, // first non-null percentile argument seen
}
6943
6944impl CypherPercentileDiscAccumulator {
6945 fn extract_f64(arr: &ArrayRef, i: usize) -> Option<f64> {
6946 if arr.is_null(i) {
6947 return None;
6948 }
6949 match arr.data_type() {
6950 DataType::LargeBinary => {
6951 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
6952 cv_bytes_as_f64(lb.value(i))
6953 }
6954 DataType::Int64 => {
6955 let a = arr.as_any().downcast_ref::<Int64Array>()?;
6956 Some(a.value(i) as f64)
6957 }
6958 DataType::Float64 => {
6959 let a = arr.as_any().downcast_ref::<Float64Array>()?;
6960 Some(a.value(i))
6961 }
6962 DataType::Int32 => {
6963 let a = arr.as_any().downcast_ref::<Int32Array>()?;
6964 Some(a.value(i) as f64)
6965 }
6966 DataType::Float32 => {
6967 let a = arr.as_any().downcast_ref::<Float32Array>()?;
6968 Some(a.value(i) as f64)
6969 }
6970 _ => None,
6971 }
6972 }
6973
6974 fn extract_percentile(arr: &ArrayRef, i: usize) -> Option<f64> {
6975 if arr.is_null(i) {
6976 return None;
6977 }
6978 match arr.data_type() {
6979 DataType::Float64 => {
6980 let a = arr.as_any().downcast_ref::<Float64Array>()?;
6981 Some(a.value(i))
6982 }
6983 DataType::Int64 => {
6984 let a = arr.as_any().downcast_ref::<Int64Array>()?;
6985 Some(a.value(i) as f64)
6986 }
6987 DataType::LargeBinary => {
6988 let lb = arr.as_any().downcast_ref::<LargeBinaryArray>()?;
6989 cv_bytes_as_f64(lb.value(i))
6990 }
6991 _ => None,
6992 }
6993 }
6994}
6995
impl DfAccumulator for CypherPercentileDiscAccumulator {
    /// values[0] is the expression column, values[1] the percentile argument.
    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
        let expr_arr = &values[0];
        let pct_arr = &values[1];
        for i in 0..expr_arr.len() {
            // The percentile is captured from the first non-null row and
            // validated eagerly; out-of-range values abort the query.
            if self.percentile.is_none()
                && let Some(p) = Self::extract_percentile(pct_arr, i)
            {
                if !(0.0..=1.0).contains(&p) {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
                    ));
                }
                self.percentile = Some(p);
            }
            if let Some(f) = Self::extract_f64(expr_arr, i) {
                self.values.push(f);
            }
        }
        Ok(())
    }
    /// Discrete percentile: sorts the gathered values and returns the one
    /// at round(p * (n - 1)). NOTE(review): this nearest-rank-style
    /// rounding is this engine's convention — confirm against the intended
    /// Cypher semantics if results look off by one.
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        let pct = match self.percentile {
            Some(p) if !(0.0..=1.0).contains(&p) => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "ArgumentError: NumberOutOfRange - percentileDisc(): percentile value must be between 0.0 and 1.0".to_string(),
                ));
            }
            Some(p) => p,
            None => 0.0,
        };
        if self.values.is_empty() {
            return Ok(ScalarValue::Float64(None));
        }
        // NaNs compare as Equal here, leaving their sorted position unspecified.
        self.values
            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let n = self.values.len();
        let idx = (pct * (n as f64 - 1.0)).round() as usize;
        let idx = idx.min(n - 1);
        let result = self.values[idx];
        Ok(ScalarValue::Float64(Some(result)))
    }
    fn size(&self) -> usize {
        std::mem::size_of_val(self) + self.values.capacity() * 8
    }
    /// State: collected values as a Float64 list plus the percentile.
    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
        let list_values: Vec<ScalarValue> = self
            .values
            .iter()
            .map(|f| ScalarValue::Float64(Some(*f)))
            .collect();
        let list_scalar = ScalarValue::List(ScalarValue::new_list(
            &list_values,
            &DataType::Float64,
            true,
        ));
        Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
    }
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        let list_arr = &states[0];
        let pct_arr = &states[1];
        // Adopt the first non-null percentile from any partial state.
        if self.percentile.is_none()
            && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
        {
            for i in 0..f64_arr.len() {
                if !f64_arr.is_null(i) {
                    self.percentile = Some(f64_arr.value(i));
                    break;
                }
            }
        }
        // Flatten every partial value list into our own buffer.
        if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
            for i in 0..list_array.len() {
                if list_array.is_null(i) {
                    continue;
                }
                let inner = list_array.value(i);
                if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
                    for j in 0..f64_arr.len() {
                        if !f64_arr.is_null(j) {
                            self.values.push(f64_arr.value(j));
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
7090
/// Aggregate implementing Cypher `percentileCont(expr, p)` — continuous
/// percentile with linear interpolation between neighboring values.
#[derive(Debug, Clone)]
struct CypherPercentileContUdaf {
    signature: Signature,
}

impl CypherPercentileContUdaf {
    fn new() -> Self {
        Self {
            // Two arguments: the expression column and the percentile.
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}
7104
7105impl PartialEq for CypherPercentileContUdaf {
7106 fn eq(&self, other: &Self) -> bool {
7107 self.signature == other.signature
7108 }
7109}
7110
7111impl Eq for CypherPercentileContUdaf {}
7112
7113impl Hash for CypherPercentileContUdaf {
7114 fn hash<H: Hasher>(&self, state: &mut H) {
7115 self.name().hash(state);
7116 }
7117}
7118
7119impl AggregateUDFImpl for CypherPercentileContUdaf {
7120 fn as_any(&self) -> &dyn Any {
7121 self
7122 }
7123 fn name(&self) -> &str {
7124 "percentilecont"
7125 }
7126 fn signature(&self) -> &Signature {
7127 &self.signature
7128 }
7129 fn return_type(&self, _args: &[DataType]) -> DFResult<DataType> {
7130 Ok(DataType::Float64)
7131 }
7132 fn accumulator(
7133 &self,
7134 _acc_args: datafusion::logical_expr::function::AccumulatorArgs,
7135 ) -> DFResult<Box<dyn DfAccumulator>> {
7136 Ok(Box::new(CypherPercentileContAccumulator {
7137 values: Vec::new(),
7138 percentile: None,
7139 }))
7140 }
7141 fn state_fields(
7142 &self,
7143 args: datafusion::logical_expr::function::StateFieldsArgs,
7144 ) -> DFResult<Vec<Arc<arrow::datatypes::Field>>> {
7145 Ok(vec![
7146 Arc::new(arrow::datatypes::Field::new(
7147 format!("{}_values", args.name),
7148 DataType::List(Arc::new(arrow::datatypes::Field::new(
7149 "item",
7150 DataType::Float64,
7151 true,
7152 ))),
7153 true,
7154 )),
7155 Arc::new(arrow::datatypes::Field::new(
7156 format!("{}_percentile", args.name),
7157 DataType::Float64,
7158 true,
7159 )),
7160 ])
7161 }
7162}
7163
/// Running state for `percentileCont`.
#[derive(Debug)]
struct CypherPercentileContAccumulator {
    values: Vec<f64>, // gathered expression values (sorted lazily at evaluate)
    percentile: Option<f64>, // first non-null percentile argument seen
}
7169
impl DfAccumulator for CypherPercentileContAccumulator {
    /// values[0] is the expression column, values[1] the percentile argument.
    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
        let expr_arr = &values[0];
        let pct_arr = &values[1];
        for i in 0..expr_arr.len() {
            // Capture and validate the percentile from the first non-null row;
            // extraction helpers are shared with percentileDisc.
            if self.percentile.is_none()
                && let Some(p) = CypherPercentileDiscAccumulator::extract_percentile(pct_arr, i)
            {
                if !(0.0..=1.0).contains(&p) {
                    return Err(datafusion::error::DataFusionError::Execution(
                        "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
                    ));
                }
                self.percentile = Some(p);
            }
            if let Some(f) = CypherPercentileDiscAccumulator::extract_f64(expr_arr, i) {
                self.values.push(f);
            }
        }
        Ok(())
    }
    /// Continuous percentile: linear interpolation between the two values
    /// bracketing position p * (n - 1) in the sorted data.
    fn evaluate(&mut self) -> DFResult<ScalarValue> {
        let pct = match self.percentile {
            Some(p) if !(0.0..=1.0).contains(&p) => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "ArgumentError: NumberOutOfRange - percentileCont(): percentile value must be between 0.0 and 1.0".to_string(),
                ));
            }
            Some(p) => p,
            None => 0.0,
        };
        if self.values.is_empty() {
            return Ok(ScalarValue::Float64(None));
        }
        // NaNs compare as Equal here, leaving their sorted position unspecified.
        self.values
            .sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let n = self.values.len();
        if n == 1 {
            return Ok(ScalarValue::Float64(Some(self.values[0])));
        }
        let pos = pct * (n as f64 - 1.0);
        let lower = pos.floor() as usize;
        let upper = pos.ceil() as usize;
        let lower = lower.min(n - 1);
        let upper = upper.min(n - 1);
        if lower == upper {
            Ok(ScalarValue::Float64(Some(self.values[lower])))
        } else {
            // Interpolate by the fractional distance between the neighbors.
            let frac = pos - lower as f64;
            let result = self.values[lower] + frac * (self.values[upper] - self.values[lower]);
            Ok(ScalarValue::Float64(Some(result)))
        }
    }
    fn size(&self) -> usize {
        std::mem::size_of_val(self) + self.values.capacity() * 8
    }
    /// State: collected values as a Float64 list plus the percentile.
    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
        let list_values: Vec<ScalarValue> = self
            .values
            .iter()
            .map(|f| ScalarValue::Float64(Some(*f)))
            .collect();
        let list_scalar = ScalarValue::List(ScalarValue::new_list(
            &list_values,
            &DataType::Float64,
            true,
        ));
        Ok(vec![list_scalar, ScalarValue::Float64(self.percentile)])
    }
    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
        let list_arr = &states[0];
        let pct_arr = &states[1];
        // Adopt the first non-null percentile from any partial state.
        if self.percentile.is_none()
            && let Some(f64_arr) = pct_arr.as_any().downcast_ref::<Float64Array>()
        {
            for i in 0..f64_arr.len() {
                if !f64_arr.is_null(i) {
                    self.percentile = Some(f64_arr.value(i));
                    break;
                }
            }
        }
        // Flatten every partial value list into our own buffer.
        if let Some(list_array) = list_arr.as_any().downcast_ref::<arrow_array::ListArray>() {
            for i in 0..list_array.len() {
                if list_array.is_null(i) {
                    continue;
                }
                let inner = list_array.value(i);
                if let Some(f64_arr) = inner.as_any().downcast_ref::<Float64Array>() {
                    for j in 0..f64_arr.len() {
                        if !f64_arr.is_null(j) {
                            self.values.push(f64_arr.value(j));
                        }
                    }
                }
            }
        }
        Ok(())
    }
}
7270
7271pub(crate) fn create_cypher_percentile_disc_udaf() -> AggregateUDF {
7272 AggregateUDF::from(CypherPercentileDiscUdaf::new())
7273}
7274
7275pub(crate) fn create_cypher_percentile_cont_udaf() -> AggregateUDF {
7276 AggregateUDF::from(CypherPercentileContUdaf::new())
7277}
7278
7279fn invoke_similarity_udf(
7289 func_name: &str,
7290 min_args: usize,
7291 args: ScalarFunctionArgs,
7292) -> DFResult<ColumnarValue> {
7293 let output_type = DataType::Float64;
7294 invoke_cypher_udf(args, &output_type, |val_args| {
7295 if val_args.len() < min_args {
7296 return Err(datafusion::error::DataFusionError::Execution(format!(
7297 "{} requires at least {} arguments",
7298 func_name, min_args
7299 )));
7300 }
7301 crate::query::similar_to::eval_similar_to_pure(&val_args[0], &val_args[1])
7302 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
7303 })
7304}
7305
7306pub fn create_similar_to_udf() -> ScalarUDF {
7308 ScalarUDF::new_from_impl(SimilarToUdf::new())
7309}
7310
/// Scalar UDF computing a similarity score between two values.
#[derive(Debug)]
struct SimilarToUdf {
    signature: Signature,
}

impl SimilarToUdf {
    fn new() -> Self {
        Self {
            // Variadic: extra arguments beyond the first two are ignored.
            signature: Signature::new(TypeSignature::VariadicAny, Volatility::Immutable),
        }
    }
}

// Shared PartialEq/Eq/Hash boilerplate (compare signatures, hash name()).
impl_udf_eq_hash!(SimilarToUdf);
7325
7326impl ScalarUDFImpl for SimilarToUdf {
7327 fn as_any(&self) -> &dyn Any {
7328 self
7329 }
7330
7331 fn name(&self) -> &str {
7332 "similar_to"
7333 }
7334
7335 fn signature(&self) -> &Signature {
7336 &self.signature
7337 }
7338
7339 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7340 Ok(DataType::Float64)
7341 }
7342
7343 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7344 invoke_similarity_udf("similar_to", 2, args)
7345 }
7346}
7347
7348pub fn create_vector_similarity_udf() -> ScalarUDF {
7350 ScalarUDF::new_from_impl(VectorSimilarityUdf::new())
7351}
7352
/// Scalar UDF computing a similarity score between two vectors.
#[derive(Debug)]
struct VectorSimilarityUdf {
    signature: Signature,
}

impl VectorSimilarityUdf {
    fn new() -> Self {
        Self {
            // Exactly two arguments of any type.
            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
        }
    }
}

// Shared PartialEq/Eq/Hash boilerplate (compare signatures, hash name()).
impl_udf_eq_hash!(VectorSimilarityUdf);
7367
7368impl ScalarUDFImpl for VectorSimilarityUdf {
7369 fn as_any(&self) -> &dyn Any {
7370 self
7371 }
7372
7373 fn name(&self) -> &str {
7374 "vector_similarity"
7375 }
7376
7377 fn signature(&self) -> &Signature {
7378 &self.signature
7379 }
7380
7381 fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
7382 Ok(DataType::Float64)
7383 }
7384
7385 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
7386 invoke_similarity_udf("vector_similarity", 2, args)
7387 }
7388}
7389
7390#[cfg(test)]
7391mod tests {
7392 use super::*;
7393 use datafusion::execution::FunctionRegistry;
7394
    // Smoke test: registration exposes the core Cypher UDFs by name.
    #[test]
    fn test_register_udfs() {
        let ctx = SessionContext::new();
        register_cypher_udfs(&ctx).unwrap();

        assert!(ctx.udf("id").is_ok());
        assert!(ctx.udf("type").is_ok());
        assert!(ctx.udf("keys").is_ok());
        assert!(ctx.udf("range").is_ok());
        assert!(
            ctx.udf("_make_cypher_list").is_ok(),
            "_make_cypher_list UDF should be registered"
        );
        assert!(
            ctx.udf("_cv_to_bool").is_ok(),
            "_cv_to_bool UDF should be registered"
        );
    }
7415
    // The id() UDF reports its Cypher-facing name.
    #[test]
    fn test_id_udf_signature() {
        let udf = create_id_udf();
        assert_eq!(udf.name(), "id");
    }
7421
    // has_null over a list scalar containing a null element returns true.
    #[test]
    fn test_has_null_udf() {
        use datafusion::arrow::datatypes::{DataType, Field};
        use datafusion::config::ConfigOptions;
        use datafusion::scalar::ScalarValue;
        use std::sync::Arc;

        let udf = create_has_null_udf();

        // List with a trailing null element.
        let values = vec![
            ScalarValue::Int64(Some(1)),
            ScalarValue::Int64(Some(2)),
            ScalarValue::Int64(None),
        ];

        let list_scalar = ScalarValue::List(ScalarValue::new_list(&values, &DataType::Int64, true));

        let list_field = Arc::new(Field::new(
            "item",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            true,
        ));

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Scalar(list_scalar)],
            arg_fields: vec![list_field],
            number_rows: 1,
            return_field: Arc::new(Field::new("result", DataType::Boolean, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };

        let result = udf.invoke_with_args(args).unwrap();

        if let ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) = result {
            assert!(b, "has_null should return true for list with null");
        } else {
            panic!("Unexpected result: {:?}", result);
        }
    }
7463
7464 fn json_to_cv_bytes(val: &serde_json::Value) -> Vec<u8> {
7470 let uni_val: uni_common::Value = val.clone().into();
7471 uni_common::cypher_value_codec::encode(&uni_val)
7472 }
7473
    // Convenience wrapper: builds ScalarFunctionArgs with the LargeBinary
    // return type used by most CypherValue-returning UDFs.
    fn make_multi_scalar_args(scalars: Vec<ScalarValue>) -> ScalarFunctionArgs {
        make_multi_scalar_args_with_return(scalars, DataType::LargeBinary)
    }
7485
7486 fn make_multi_scalar_args_with_return(
7487 scalars: Vec<ScalarValue>,
7488 return_type: DataType,
7489 ) -> ScalarFunctionArgs {
7490 use datafusion::arrow::datatypes::Field;
7491 use datafusion::config::ConfigOptions;
7492
7493 let arg_fields: Vec<_> = scalars
7494 .iter()
7495 .enumerate()
7496 .map(|(i, s)| Arc::new(Field::new(format!("arg{i}"), s.data_type(), true)))
7497 .collect();
7498 let args: Vec<_> = scalars.into_iter().map(ColumnarValue::Scalar).collect();
7499 ScalarFunctionArgs {
7500 args,
7501 arg_fields,
7502 number_rows: 1,
7503 return_field: Arc::new(Field::new("result", return_type, true)),
7504 config_options: Arc::new(ConfigOptions::default()),
7505 }
7506 }
7507
    // Decodes a LargeBinary CypherValue scalar result back into JSON for
    // easy assertions; panics on any other ColumnarValue shape.
    fn decode_cv_scalar(cv: &ColumnarValue) -> serde_json::Value {
        match cv {
            ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
                let val = uni_common::cypher_value_codec::decode(bytes)
                    .expect("failed to decode CypherValue output");
                val.into()
            }
            other => panic!("expected LargeBinary scalar, got {other:?}"),
        }
    }
7519
    // _make_cypher_list preserves element order, types, and nulls.
    #[test]
    fn test_make_cypher_list_scalars() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![
            ScalarValue::Int64(Some(1)),
            ScalarValue::Float64(Some(3.21)),
            ScalarValue::Utf8(Some("hello".to_string())),
            ScalarValue::Boolean(Some(true)),
            ScalarValue::Null,
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 5);
        assert_eq!(arr[0], serde_json::json!(1));
        assert_eq!(arr[1], serde_json::json!(3.21));
        assert_eq!(arr[2], serde_json::json!("hello"));
        assert_eq!(arr[3], serde_json::json!(true));
        assert!(arr[4].is_null());
    }
7540
    // Zero arguments produce an empty Cypher list.
    #[test]
    fn test_make_cypher_list_empty() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert!(arr.is_empty());
    }
7550
    // A single argument produces a one-element list.
    #[test]
    fn test_make_cypher_list_single() {
        let udf = create_make_cypher_list_udf();
        let args = make_multi_scalar_args(vec![ScalarValue::Int64(Some(42))]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0], serde_json::json!(42));
    }
7561
    // An already-encoded CypherValue argument nests as a sub-list rather
    // than being treated as raw bytes.
    #[test]
    fn test_make_cypher_list_nested_cypher_value() {
        let udf = create_make_cypher_list_udf();
        let nested_bytes = json_to_cv_bytes(&serde_json::json!([1, 2]));
        let args = make_multi_scalar_args(vec![
            ScalarValue::LargeBinary(Some(nested_bytes)),
            ScalarValue::Int64(Some(3)),
        ]);
        let result = udf.invoke_with_args(args).unwrap();
        let json = decode_cv_scalar(&result);
        let arr = json.as_array().expect("should be array");
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!([1, 2]));
        assert_eq!(arr[1], serde_json::json!(3));
    }
7578
    // Builds args for the IN operator UDF: both element and list are
    // passed as encoded CypherValues; the result type is Boolean.
    fn make_cypher_in_args(
        element: &serde_json::Value,
        list: &serde_json::Value,
    ) -> ScalarFunctionArgs {
        make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(element))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(list))),
            ],
            DataType::Boolean,
        )
    }
7596
    // 3 IN [1, 2, 3] => true
    #[test]
    fn test_cypher_in_found() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(3), &serde_json::json!([1, 2, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
            other => panic!("expected Boolean(true), got {other:?}"),
        }
    }
7607
    // 4 IN [1, 2, 3] => false
    #[test]
    fn test_cypher_in_not_found() {
        let udf = create_cypher_in_udf();
        let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, 2, 3]));
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
            other => panic!("expected Boolean(false), got {other:?}"),
        }
    }
7618
    // x IN null => null (Cypher ternary logic).
    #[test]
    fn test_cypher_in_null_list() {
        let udf = create_cypher_in_udf();
        let args = make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(1)))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
            ],
            DataType::Boolean,
        );
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {} other => panic!("expected Boolean(None) for null list, got {other:?}"),
        }
    }
7635
    // null IN [1, 2] => null (comparison with null is unknown).
    #[test]
    fn test_cypher_in_null_element_nonempty() {
        let udf = create_cypher_in_udf();
        let args = make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
            ],
            DataType::Boolean,
        );
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {} other => panic!("expected Boolean(None) for null IN non-empty list, got {other:?}"),
        }
    }
7652
    // null IN [] => false (nothing to compare against, so definitively false).
    #[test]
    fn test_cypher_in_null_element_empty() {
        let udf = create_cypher_in_udf();
        let args = make_multi_scalar_args_with_return(
            vec![
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
                ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([])))),
            ],
            DataType::Boolean,
        );
        let result = udf.invoke_with_args(args).unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(!b),
            other => panic!("expected Boolean(false) for null IN [], got {other:?}"),
        }
    }
7669
7670 #[test]
7671 fn test_cypher_in_not_found_with_null() {
7672 let udf = create_cypher_in_udf();
7673 let args = make_cypher_in_args(&serde_json::json!(4), &serde_json::json!([1, null, 3]));
7674 let result = udf.invoke_with_args(args).unwrap();
7675 match result {
7676 ColumnarValue::Scalar(ScalarValue::Boolean(None)) => {} other => panic!("expected Boolean(None) for 4 IN [1,null,3], got {other:?}"),
7678 }
7679 }
7680
7681 #[test]
7682 fn test_cypher_in_cross_type_int_float() {
7683 let udf = create_cypher_in_udf();
7684 let args = make_cypher_in_args(&serde_json::json!(1), &serde_json::json!([1.0, 2.0]));
7685 let result = udf.invoke_with_args(args).unwrap();
7686 match result {
7687 ColumnarValue::Scalar(ScalarValue::Boolean(Some(b))) => assert!(b),
7688 other => panic!("expected Boolean(true) for 1 IN [1.0, 2.0], got {other:?}"),
7689 }
7690 }
7691
7692 #[test]
7697 fn test_list_concat_basic() {
7698 let udf = create_cypher_list_concat_udf();
7699 let args = make_multi_scalar_args(vec![
7700 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
7701 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([3, 4])))),
7702 ]);
7703 let result = udf.invoke_with_args(args).unwrap();
7704 let json = decode_cv_scalar(&result);
7705 assert_eq!(json, serde_json::json!([1, 2, 3, 4]));
7706 }
7707
7708 #[test]
7709 fn test_list_concat_empty() {
7710 let udf = create_cypher_list_concat_udf();
7711 let args = make_multi_scalar_args(vec![
7712 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([])))),
7713 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
7714 ]);
7715 let result = udf.invoke_with_args(args).unwrap();
7716 let json = decode_cv_scalar(&result);
7717 assert_eq!(json, serde_json::json!([1]));
7718 }
7719
7720 #[test]
7721 fn test_list_concat_null_left() {
7722 let udf = create_cypher_list_concat_udf();
7723 let args = make_multi_scalar_args(vec![
7724 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
7725 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
7726 ]);
7727 let result = udf.invoke_with_args(args).unwrap();
7728 match result {
7729 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
7730 let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
7731 let json: serde_json::Value = uni_val.into();
7732 assert!(json.is_null(), "expected null, got {json}");
7733 }
7734 ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {} other => panic!("expected null result, got {other:?}"),
7736 }
7737 }
7738
7739 #[test]
7740 fn test_list_concat_null_right() {
7741 let udf = create_cypher_list_concat_udf();
7742 let args = make_multi_scalar_args(vec![
7743 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1])))),
7744 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
7745 ]);
7746 let result = udf.invoke_with_args(args).unwrap();
7747 match result {
7748 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
7749 let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
7750 let json: serde_json::Value = uni_val.into();
7751 assert!(json.is_null(), "expected null, got {json}");
7752 }
7753 ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
7754 other => panic!("expected null result, got {other:?}"),
7755 }
7756 }
7757
7758 #[test]
7763 fn test_list_append_scalar() {
7764 let udf = create_cypher_list_append_udf();
7765 let args = make_multi_scalar_args(vec![
7766 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
7767 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
7768 ]);
7769 let result = udf.invoke_with_args(args).unwrap();
7770 let json = decode_cv_scalar(&result);
7771 assert_eq!(json, serde_json::json!([1, 2, 3]));
7772 }
7773
7774 #[test]
7775 fn test_list_prepend_scalar() {
7776 let udf = create_cypher_list_append_udf();
7777 let args = make_multi_scalar_args(vec![
7778 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
7779 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
7780 ]);
7781 let result = udf.invoke_with_args(args).unwrap();
7782 let json = decode_cv_scalar(&result);
7783 assert_eq!(json, serde_json::json!([3, 1, 2]));
7784 }
7785
7786 #[test]
7787 fn test_list_append_null_list() {
7788 let udf = create_cypher_list_append_udf();
7789 let args = make_multi_scalar_args(vec![
7790 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
7791 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(3)))),
7792 ]);
7793 let result = udf.invoke_with_args(args).unwrap();
7794 match result {
7795 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
7796 let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
7797 let json: serde_json::Value = uni_val.into();
7798 assert!(json.is_null(), "expected null, got {json}");
7799 }
7800 ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
7801 other => panic!("expected null result, got {other:?}"),
7802 }
7803 }
7804
7805 #[test]
7806 fn test_list_append_null_scalar() {
7807 let udf = create_cypher_list_append_udf();
7808 let args = make_multi_scalar_args(vec![
7809 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!([1, 2])))),
7810 ScalarValue::LargeBinary(Some(json_to_cv_bytes(&serde_json::json!(null)))),
7811 ]);
7812 let result = udf.invoke_with_args(args).unwrap();
7813 match result {
7814 ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(bytes))) => {
7815 let uni_val = uni_common::cypher_value_codec::decode(&bytes).expect("decode");
7816 let json: serde_json::Value = uni_val.into();
7817 assert!(json.is_null(), "expected null, got {json}");
7818 }
7819 ColumnarValue::Scalar(ScalarValue::LargeBinary(None)) => {}
7820 other => panic!("expected null result, got {other:?}"),
7821 }
7822 }
7823
7824 #[test]
7829 fn test_sort_key_cross_type_ordering() {
7830 use uni_common::core::id::{Eid, Vid};
7833 use uni_common::{Edge, Node, Path, TemporalValue, Value};
7834
7835 let map_val = Value::Map([("a".to_string(), Value::String("map".to_string()))].into());
7836 let node_val = Value::Node(Node {
7837 vid: Vid::new(1),
7838 labels: vec!["L".to_string()],
7839 properties: Default::default(),
7840 });
7841 let edge_val = Value::Edge(Edge {
7842 eid: Eid::new(1),
7843 edge_type: "T".to_string(),
7844 src: Vid::new(1),
7845 dst: Vid::new(2),
7846 properties: Default::default(),
7847 });
7848 let list_val = Value::List(vec![Value::Int(1)]);
7849 let path_val = Value::Path(Path {
7850 nodes: vec![Node {
7851 vid: Vid::new(1),
7852 labels: vec!["L".to_string()],
7853 properties: Default::default(),
7854 }],
7855 edges: vec![],
7856 });
7857 let string_val = Value::String("hello".to_string());
7858 let bool_val = Value::Bool(false);
7859 let temporal_val = Value::Temporal(TemporalValue::Date {
7860 days_since_epoch: 1000,
7861 });
7862 let number_val = Value::Int(42);
7863 let nan_val = Value::Float(f64::NAN);
7864 let null_val = Value::Null;
7865
7866 let values = vec![
7867 &map_val,
7868 &node_val,
7869 &edge_val,
7870 &list_val,
7871 &path_val,
7872 &string_val,
7873 &bool_val,
7874 &temporal_val,
7875 &number_val,
7876 &nan_val,
7877 &null_val,
7878 ];
7879
7880 let keys: Vec<Vec<u8>> = values.iter().map(|v| encode_cypher_sort_key(v)).collect();
7881
7882 for i in 0..keys.len() - 1 {
7884 assert!(
7885 keys[i] < keys[i + 1],
7886 "Expected sort_key({:?}) < sort_key({:?}), but {:?} >= {:?}",
7887 values[i],
7888 values[i + 1],
7889 keys[i],
7890 keys[i + 1]
7891 );
7892 }
7893 }
7894
7895 #[test]
7896 fn test_sort_key_numbers() {
7897 let neg_inf = encode_cypher_sort_key(&Value::Float(f64::NEG_INFINITY));
7898 let neg_100 = encode_cypher_sort_key(&Value::Float(-100.0));
7899 let neg_1 = encode_cypher_sort_key(&Value::Int(-1));
7900 let zero_int = encode_cypher_sort_key(&Value::Int(0));
7901 let zero_float = encode_cypher_sort_key(&Value::Float(0.0));
7902 let one_int = encode_cypher_sort_key(&Value::Int(1));
7903 let one_float = encode_cypher_sort_key(&Value::Float(1.0));
7904 let hundred = encode_cypher_sort_key(&Value::Int(100));
7905 let pos_inf = encode_cypher_sort_key(&Value::Float(f64::INFINITY));
7906 let nan = encode_cypher_sort_key(&Value::Float(f64::NAN));
7907
7908 assert!(neg_inf < neg_100, "-inf < -100");
7909 assert!(neg_100 < neg_1, "-100 < -1");
7910 assert!(neg_1 < zero_int, "-1 < 0");
7911 assert_eq!(zero_int, zero_float, "0 int == 0.0 float");
7912 assert!(zero_int < one_int, "0 < 1");
7913 assert_eq!(one_int, one_float, "1 int == 1.0 float");
7914 assert!(one_int < hundred, "1 < 100");
7915 assert!(hundred < pos_inf, "100 < +inf");
7916 assert!(pos_inf < nan, "+inf < NaN");
7918 }
7919
7920 #[test]
7921 fn test_sort_key_booleans() {
7922 let f = encode_cypher_sort_key(&Value::Bool(false));
7923 let t = encode_cypher_sort_key(&Value::Bool(true));
7924 assert!(f < t, "false < true");
7925 }
7926
7927 #[test]
7928 fn test_sort_key_strings() {
7929 let empty = encode_cypher_sort_key(&Value::String(String::new()));
7930 let a = encode_cypher_sort_key(&Value::String("a".to_string()));
7931 let ab = encode_cypher_sort_key(&Value::String("ab".to_string()));
7932 let b = encode_cypher_sort_key(&Value::String("b".to_string()));
7933
7934 assert!(empty < a, "'' < 'a'");
7935 assert!(a < ab, "'a' < 'ab'");
7936 assert!(ab < b, "'ab' < 'b'");
7937 }
7938
7939 #[test]
7940 fn test_sort_key_lists() {
7941 let empty = encode_cypher_sort_key(&Value::List(vec![]));
7942 let one = encode_cypher_sort_key(&Value::List(vec![Value::Int(1)]));
7943 let one_two = encode_cypher_sort_key(&Value::List(vec![Value::Int(1), Value::Int(2)]));
7944 let two = encode_cypher_sort_key(&Value::List(vec![Value::Int(2)]));
7945
7946 assert!(empty < one, "[] < [1]");
7947 assert!(one < one_two, "[1] < [1,2]");
7948 assert!(one_two < two, "[1,2] < [2]");
7949 }
7950
7951 #[test]
7952 fn test_sort_key_temporal() {
7953 use uni_common::TemporalValue;
7954
7955 let date1 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
7956 days_since_epoch: 100,
7957 }));
7958 let date2 = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
7959 days_since_epoch: 200,
7960 }));
7961 assert!(date1 < date2, "earlier date < later date");
7962
7963 let date = encode_cypher_sort_key(&Value::Temporal(TemporalValue::Date {
7965 days_since_epoch: i32::MAX,
7966 }));
7967 let local_time = encode_cypher_sort_key(&Value::Temporal(TemporalValue::LocalTime {
7968 nanos_since_midnight: 0,
7969 }));
7970 assert!(date < local_time, "Date < LocalTime (by variant rank)");
7971 }
7972
7973 #[test]
7974 fn test_sort_key_nested_lists() {
7975 let inner_a = Value::List(vec![Value::Int(1)]);
7976 let inner_b = Value::List(vec![Value::Int(2)]);
7977
7978 let list_a = encode_cypher_sort_key(&Value::List(vec![inner_a.clone()]));
7979 let list_b = encode_cypher_sort_key(&Value::List(vec![inner_b.clone()]));
7980
7981 assert!(list_a < list_b, "[[1]] < [[2]]");
7982 }
7983
7984 #[test]
7985 fn test_sort_key_null_handling() {
7986 let null_key = encode_cypher_sort_key(&Value::Null);
7987 assert_eq!(null_key, vec![0x0A], "Null produces [0x0A]");
7988
7989 let number_key = encode_cypher_sort_key(&Value::Int(42));
7991 assert!(number_key < null_key, "number < null");
7992 }
7993
7994 #[test]
7995 fn test_byte_stuff_roundtrip() {
7996 let s1 = Value::String("a\x00b".to_string());
7998 let s2 = Value::String("a\x00c".to_string());
7999 let s3 = Value::String("a\x01".to_string());
8000
8001 let k1 = encode_cypher_sort_key(&s1);
8002 let k2 = encode_cypher_sort_key(&s2);
8003 let k3 = encode_cypher_sort_key(&s3);
8004
8005 assert!(k1 < k2, "a\\x00b < a\\x00c");
8006 assert!(k1 < k3, "a\\x00b < a\\x01");
8009 }
8010
8011 #[test]
8012 fn test_sort_key_order_preserving_f64() {
8013 let vals = [f64::NEG_INFINITY, -1.0, -0.0, 0.0, 1.0, f64::INFINITY];
8015 let encoded: Vec<[u8; 8]> = vals
8016 .iter()
8017 .map(|f| encode_order_preserving_f64(*f))
8018 .collect();
8019
8020 for i in 0..encoded.len() - 1 {
8021 assert!(
8022 encoded[i] <= encoded[i + 1],
8023 "encode({}) should <= encode({}), got {:?} vs {:?}",
8024 vals[i],
8025 vals[i + 1],
8026 encoded[i],
8027 encoded[i + 1]
8028 );
8029 }
8030 }
8031
8032 #[test]
8036 fn test_sort_key_string_as_temporal_time_with_offset() {
8037 let tv = sort_key_string_as_temporal("12:35:15+05:00")
8038 .expect("should parse Time with positive offset");
8039 match tv {
8040 uni_common::TemporalValue::Time {
8041 nanos_since_midnight,
8042 offset_seconds,
8043 } => {
8044 assert_eq!(offset_seconds, 5 * 3600, "offset should be +05:00 = 18000s");
8045 let expected_nanos = (12 * 3600 + 35 * 60 + 15) * 1_000_000_000i64;
8047 assert_eq!(nanos_since_midnight, expected_nanos);
8048 }
8049 other => panic!("expected TemporalValue::Time, got {other:?}"),
8050 }
8051 }
8052
8053 #[test]
8054 fn test_sort_key_string_as_temporal_time_negative_offset() {
8055 let tv = sort_key_string_as_temporal("10:35:00-08:00")
8056 .expect("should parse Time with negative offset");
8057 match tv {
8058 uni_common::TemporalValue::Time {
8059 nanos_since_midnight,
8060 offset_seconds,
8061 } => {
8062 assert_eq!(
8063 offset_seconds,
8064 -8 * 3600,
8065 "offset should be -08:00 = -28800s"
8066 );
8067 let expected_nanos = (10 * 3600 + 35 * 60) * 1_000_000_000i64;
8068 assert_eq!(nanos_since_midnight, expected_nanos);
8069 }
8070 other => panic!("expected TemporalValue::Time, got {other:?}"),
8071 }
8072 }
8073
8074 #[test]
8075 fn test_sort_key_string_as_temporal_date() {
8076 use super::super::expr_eval::temporal_from_value;
8077 let tv = temporal_from_value(&Value::String("2024-01-15".into()))
8078 .expect("should parse Date string");
8079 match tv {
8080 uni_common::TemporalValue::Date { days_since_epoch } => {
8081 assert!(days_since_epoch > 0, "2024-01-15 should be after epoch");
8083 }
8084 other => panic!("expected TemporalValue::Date, got {other:?}"),
8085 }
8086 }
8087}