1use arrow_array::builder::{BooleanBuilder, Float64Builder, Int64Builder, StringBuilder};
17use arrow_array::{ArrayRef, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_cypher::ast::Expr;
32
33use crate::query::df_graph::GraphExecutionContext;
34use crate::query::df_graph::common::{
35 arrow_err, compute_plan_properties, evaluate_simple_expr, labels_data_type,
36};
37use crate::query::df_graph::scan::resolve_property_type;
38
/// Maps a user-written `YIELD` name onto the canonical output column it refers to.
///
/// The name is lowercased before matching, so the lookup is case-insensitive. Known aliases
/// collapse onto `"vid"`, `"distance"`, `"score"`, `"vector_score"`, `"fts_score"`,
/// `"raw_score"` or `"rerank_score"`; every other name maps to `"node"`.
pub(crate) fn map_yield_to_canonical(yield_name: &str) -> &'static str {
    // Each entry pairs the accepted spellings with the canonical name they stand for.
    const ALIASES: &[(&[&str], &str)] = &[
        (&["vid", "_vid"], "vid"),
        (&["distance", "dist", "_distance"], "distance"),
        (&["score", "_score"], "score"),
        (&["vector_score"], "vector_score"),
        (&["fts_score"], "fts_score"),
        (&["raw_score"], "raw_score"),
        (&["rerank_score", "_rerank_score"], "rerank_score"),
    ];

    let lowered = yield_name.to_lowercase();
    ALIASES
        .iter()
        .find(|(spellings, _)| spellings.iter().any(|s| *s == lowered.as_str()))
        .map_or("node", |(_, canonical)| *canonical)
}
63
/// Procedure names matched (exactly, case-sensitively) by
/// [`is_node_yield_procedure_static`].
///
/// NOTE(review): judging by the name, these are presumably the built-in procedures whose
/// yields may be expanded into node columns — confirm against the planner.
pub(crate) const NODE_YIELD_PROCEDURE_NAMES: &[&str] = &[
    "uni.vector.query",
    "uni.fts.query",
    "uni.search",
    "uni.create.vNode",
];
83
84pub(crate) fn is_node_yield_procedure_static(name: &str) -> bool {
88 NODE_YIELD_PROCEDURE_NAMES.contains(&name)
89}
90
91pub(crate) fn canonical_search_type(canonical: &str) -> DataType {
96 match canonical {
97 "distance" => DataType::Float64,
98 "score" | "vector_score" | "fts_score" | "raw_score" | "rerank_score" => DataType::Float32,
99 "vid" => DataType::Int64,
100 _ => DataType::Utf8,
101 }
102}
103
104fn expand_node_yield_fields(
111 output_name: &str,
112 target_properties: &HashMap<String, Vec<String>>,
113 graph_ctx: &GraphExecutionContext,
114 fields: &mut Vec<Field>,
115) {
116 fields.push(Field::new(
117 format!("{}._vid", output_name),
118 DataType::UInt64,
119 false,
120 ));
121 fields.push(Field::new(output_name, DataType::Utf8, false));
122 fields.push(Field::new(
123 format!("{}._labels", output_name),
124 labels_data_type(),
125 true,
126 ));
127
128 if let Some(props) = target_properties.get(output_name) {
129 let uni_schema = graph_ctx.storage().schema_manager().schema();
130 for prop_name in props {
131 let col_name = format!("{}.{}", output_name, prop_name);
132 let arrow_type = resolve_property_type(prop_name, None);
133 let resolved_type = uni_schema
134 .properties
135 .values()
136 .find_map(|label_props| {
137 label_props
138 .get(prop_name.as_str())
139 .map(|_| resolve_property_type(prop_name, Some(label_props)))
140 })
141 .unwrap_or(arrow_type);
142 fields.push(Field::new(&col_name, resolved_type, true));
143 }
144 }
145}
146
147fn field_from_signature(col_name: &str, sig_field: &Field) -> Field {
151 let mut new_field = Field::new(
152 col_name,
153 sig_field.data_type().clone(),
154 sig_field.is_nullable(),
155 );
156 if !sig_field.metadata().is_empty() {
157 new_field = new_field.with_metadata(sig_field.metadata().clone());
158 }
159 new_field
160}
161
/// Leaf physical operator that invokes a procedure and exposes its result as a record-batch
/// stream.
pub struct GraphProcedureCallExec {
    /// Shared execution context; consulted for the procedure registry and storage schema, and
    /// handed to the stream on `execute`.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Name of the procedure to call.
    procedure_name: String,

    /// Unevaluated argument expressions; they are evaluated in `execute`.
    arguments: Vec<Expr>,

    /// `(yield name, optional alias)` pairs; the alias, when present, names the output column.
    yield_items: Vec<(String, Option<String>)>,

    /// Values passed to argument evaluation together with `outer_values`.
    /// NOTE(review): presumably the query parameters — confirm against the planner.
    params: HashMap<String, Value>,

    /// Values passed to argument evaluation together with `params`.
    /// NOTE(review): presumably bindings from an enclosing scope — confirm against the planner.
    outer_values: HashMap<String, Value>,

    /// Property names to expose per yielded node, keyed by output column name.
    target_properties: HashMap<String, Vec<String>>,

    /// Output schema, computed once in `new`.
    schema: SchemaRef,

    /// Plan properties derived from `schema`.
    properties: Arc<PlanProperties>,

    /// Metrics registry for this operator.
    metrics: ExecutionPlanMetricsSet,
}
197
198impl fmt::Debug for GraphProcedureCallExec {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 f.debug_struct("GraphProcedureCallExec")
201 .field("procedure_name", &self.procedure_name)
202 .field("yield_items", &self.yield_items)
203 .finish()
204 }
205}
206
207impl GraphProcedureCallExec {
208 pub fn new(
210 graph_ctx: Arc<GraphExecutionContext>,
211 procedure_name: String,
212 arguments: Vec<Expr>,
213 yield_items: Vec<(String, Option<String>)>,
214 params: HashMap<String, Value>,
215 outer_values: HashMap<String, Value>,
216 target_properties: HashMap<String, Vec<String>>,
217 ) -> Self {
218 let schema = Self::build_schema(
219 &procedure_name,
220 &yield_items,
221 &target_properties,
222 &graph_ctx,
223 );
224 let properties = compute_plan_properties(schema.clone());
225
226 Self {
227 graph_ctx,
228 procedure_name,
229 arguments,
230 yield_items,
231 params,
232 outer_values,
233 target_properties,
234 schema,
235 properties,
236 metrics: ExecutionPlanMetricsSet::new(),
237 }
238 }
239
240 fn build_schema(
254 procedure_name: &str,
255 yield_items: &[(String, Option<String>)],
256 target_properties: &HashMap<String, Vec<String>>,
257 graph_ctx: &GraphExecutionContext,
258 ) -> SchemaRef {
259 let mut fields = Vec::new();
260
261 if let Some(registry) = graph_ctx.procedure_registry()
262 && let Some(entry) = registry.resolve_user_procedure(procedure_name)
263 {
264 let supports_node_yield = entry.signature.yields.iter().any(|f| {
265 f.metadata()
266 .get("_yield_kind")
267 .is_some_and(|v| v == "node_vid_source")
268 });
269
270 for (yield_name, alias) in yield_items {
271 let col_name = alias.as_ref().unwrap_or(yield_name);
272
273 if supports_node_yield {
274 let canonical = map_yield_to_canonical(yield_name);
275 if canonical == "node" {
276 expand_node_yield_fields(
277 col_name,
278 target_properties,
279 graph_ctx,
280 &mut fields,
281 );
282 continue;
283 }
284 if let Some(sig_field) = entry
290 .signature
291 .yields
292 .iter()
293 .find(|f| f.name() == canonical)
294 {
295 fields.push(field_from_signature(col_name, sig_field));
296 } else {
297 fields.push(Field::new(col_name, canonical_search_type(canonical), true));
298 }
299 continue;
300 }
301
302 let field = entry
306 .signature
307 .yields
308 .iter()
309 .find(|f| f.name() == yield_name.as_str())
310 .map(|f| field_from_signature(col_name, f))
311 .unwrap_or_else(|| Field::new(col_name, DataType::Utf8, true));
312 fields.push(field);
313 }
314 } else if let Some(registry) = graph_ctx.procedure_registry()
315 && let Some(proc_def) = registry.get(procedure_name)
316 {
317 for (name, alias) in yield_items {
318 let col_name = alias.as_ref().unwrap_or(name);
319 let data_type = proc_def
320 .outputs
321 .iter()
322 .find(|o| o.name == *name)
323 .map(|o| procedure_value_type_to_arrow(&o.output_type))
324 .unwrap_or(DataType::Utf8);
325 fields.push(Field::new(col_name, data_type, true));
326 }
327 } else if yield_items.is_empty() {
328 } else {
330 for (name, alias) in yield_items {
331 let col_name = alias.as_ref().unwrap_or(name);
332 fields.push(Field::new(col_name, DataType::Utf8, true));
333 }
334 }
335
336 Arc::new(Schema::new(fields))
337 }
338}
339
340pub(crate) fn value_type_to_arrow(vt: &uni_algo::algo::procedures::ValueType) -> DataType {
342 use uni_algo::algo::procedures::ValueType;
343 match vt {
344 ValueType::Int => DataType::Int64,
345 ValueType::Float => DataType::Float64,
346 ValueType::String => DataType::Utf8,
347 ValueType::Bool => DataType::Boolean,
348 ValueType::List
349 | ValueType::Map
350 | ValueType::Node
351 | ValueType::Relationship
352 | ValueType::Path
353 | ValueType::Any => DataType::Utf8,
354 }
355}
356
357pub(crate) fn is_complex_value_type(vt: &uni_algo::algo::procedures::ValueType) -> bool {
360 use uni_algo::algo::procedures::ValueType;
361 matches!(
362 vt,
363 ValueType::List
364 | ValueType::Map
365 | ValueType::Node
366 | ValueType::Relationship
367 | ValueType::Path
368 )
369}
370
371fn procedure_value_type_to_arrow(
373 vt: &crate::query::executor::procedure::ProcedureValueType,
374) -> DataType {
375 use crate::query::executor::procedure::ProcedureValueType;
376 match vt {
377 ProcedureValueType::Integer => DataType::Int64,
378 ProcedureValueType::Float | ProcedureValueType::Number => DataType::Float64,
379 ProcedureValueType::Boolean => DataType::Boolean,
380 ProcedureValueType::String | ProcedureValueType::Any => DataType::Utf8,
381 }
382}
383
384impl DisplayAs for GraphProcedureCallExec {
385 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
386 write!(
387 f,
388 "GraphProcedureCallExec: procedure={}",
389 self.procedure_name
390 )
391 }
392}
393
394impl ExecutionPlan for GraphProcedureCallExec {
395 fn name(&self) -> &str {
396 "GraphProcedureCallExec"
397 }
398
399 fn as_any(&self) -> &dyn Any {
400 self
401 }
402
403 fn schema(&self) -> SchemaRef {
404 self.schema.clone()
405 }
406
407 fn properties(&self) -> &Arc<PlanProperties> {
408 &self.properties
409 }
410
411 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
412 vec![]
413 }
414
415 fn with_new_children(
416 self: Arc<Self>,
417 children: Vec<Arc<dyn ExecutionPlan>>,
418 ) -> DFResult<Arc<dyn ExecutionPlan>> {
419 if !children.is_empty() {
420 return Err(datafusion::error::DataFusionError::Internal(
421 "GraphProcedureCallExec has no children".to_string(),
422 ));
423 }
424 Ok(self)
425 }
426
427 fn execute(
428 &self,
429 partition: usize,
430 _context: Arc<TaskContext>,
431 ) -> DFResult<SendableRecordBatchStream> {
432 let metrics = BaselineMetrics::new(&self.metrics, partition);
433
434 let mut evaluated_args = Vec::with_capacity(self.arguments.len());
436 for arg in &self.arguments {
437 evaluated_args.push(evaluate_simple_expr(arg, &self.params, &self.outer_values)?);
438 }
439
440 Ok(Box::pin(ProcedureCallStream::new(
441 self.graph_ctx.clone(),
442 self.procedure_name.clone(),
443 evaluated_args,
444 self.yield_items.clone(),
445 self.target_properties.clone(),
446 self.schema.clone(),
447 metrics,
448 )))
449 }
450
451 fn metrics(&self) -> Option<MetricsSet> {
452 Some(self.metrics.clone_inner())
453 }
454}
455
/// State machine driving [`ProcedureCallStream`].
enum ProcedureCallState {
    /// Nothing has run yet; the first poll builds the execution future.
    Init,
    /// The procedure future is in flight and is polled until it resolves.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// The result (or error) has been delivered; further polls yield `None`.
    Done,
}
469
/// Record-batch stream that runs one procedure call and yields at most one batch.
struct ProcedureCallStream {
    /// Shared execution context, cloned into the execution future.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Name of the procedure to call.
    procedure_name: String,
    /// Argument values, already evaluated by the operator.
    evaluated_args: Vec<Value>,
    /// `(yield name, optional alias)` pairs requested by the query.
    yield_items: Vec<(String, Option<String>)>,
    /// Property names to expose per yielded node, keyed by output column name.
    target_properties: HashMap<String, Vec<String>>,
    /// Schema reported by this stream.
    schema: SchemaRef,
    /// Current position in the `Init` → `Executing` → `Done` lifecycle.
    state: ProcedureCallState,
    /// Baseline metrics: compute time and output rows are recorded while polling.
    metrics: BaselineMetrics,
}
481
482impl ProcedureCallStream {
483 fn new(
484 graph_ctx: Arc<GraphExecutionContext>,
485 procedure_name: String,
486 evaluated_args: Vec<Value>,
487 yield_items: Vec<(String, Option<String>)>,
488 target_properties: HashMap<String, Vec<String>>,
489 schema: SchemaRef,
490 metrics: BaselineMetrics,
491 ) -> Self {
492 Self {
493 graph_ctx,
494 procedure_name,
495 evaluated_args,
496 yield_items,
497 target_properties,
498 schema,
499 state: ProcedureCallState::Init,
500 metrics,
501 }
502 }
503}
504
505impl Stream for ProcedureCallStream {
506 type Item = DFResult<RecordBatch>;
507
508 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
509 let metrics = self.metrics.clone();
510 let _timer = metrics.elapsed_compute().timer();
511 loop {
512 let state = std::mem::replace(&mut self.state, ProcedureCallState::Done);
513
514 match state {
515 ProcedureCallState::Init => {
516 let graph_ctx = self.graph_ctx.clone();
517 let procedure_name = self.procedure_name.clone();
518 let evaluated_args = self.evaluated_args.clone();
519 let yield_items = self.yield_items.clone();
520 let target_properties = self.target_properties.clone();
521 let schema = self.schema.clone();
522
523 let fut = async move {
524 graph_ctx.check_timeout().map_err(|e| {
525 datafusion::error::DataFusionError::Execution(e.to_string())
526 })?;
527
528 execute_procedure(
529 &graph_ctx,
530 &procedure_name,
531 &evaluated_args,
532 &yield_items,
533 &target_properties,
534 &schema,
535 )
536 .await
537 };
538
539 self.state = ProcedureCallState::Executing(Box::pin(fut));
540 }
541 ProcedureCallState::Executing(mut fut) => match fut.as_mut().poll(cx) {
542 Poll::Ready(Ok(batch)) => {
543 self.state = ProcedureCallState::Done;
544 self.metrics
545 .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
546 return Poll::Ready(batch.map(Ok));
547 }
548 Poll::Ready(Err(e)) => {
549 self.state = ProcedureCallState::Done;
550 return Poll::Ready(Some(Err(e)));
551 }
552 Poll::Pending => {
553 self.state = ProcedureCallState::Executing(fut);
554 return Poll::Pending;
555 }
556 },
557 ProcedureCallState::Done => {
558 return Poll::Ready(None);
559 }
560 }
561 }
562 }
563}
564
565impl RecordBatchStream for ProcedureCallStream {
566 fn schema(&self) -> SchemaRef {
567 self.schema.clone()
568 }
569}
570
/// Runs `procedure_name` with already-evaluated `args`, choosing between the two procedure
/// paths.
///
/// When the context has a procedure registry and that registry resolves `procedure_name` as
/// a user procedure, the call is delegated to [`execute_plugin_procedure`]. Otherwise it
/// falls through to [`execute_registered_procedure`].
///
/// # Errors
/// Propagates whatever error the selected path returns.
async fn execute_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    // User procedures resolved by the registry take precedence.
    if let Some(registry) = graph_ctx.procedure_registry()
        && let Some(entry) = registry.resolve_user_procedure(procedure_name)
    {
        return execute_plugin_procedure(
            graph_ctx,
            procedure_name,
            &entry,
            args,
            yield_items,
            target_properties,
            schema,
        )
        .await;
    }

    // No user procedure matched: use the registered-procedure path, which reports its own
    // "no registry" / "unknown procedure" errors.
    execute_registered_procedure(graph_ctx, procedure_name, args, yield_items, schema).await
}
621
622async fn execute_plugin_procedure(
632 graph_ctx: &GraphExecutionContext,
633 procedure_name: &str,
634 entry: &uni_plugin::registry::ProcedureEntry,
635 args: &[Value],
636 yield_items: &[(String, Option<String>)],
637 target_properties: &HashMap<String, Vec<String>>,
638 schema: &SchemaRef,
639) -> DFResult<Option<RecordBatch>> {
640 use datafusion::logical_expr::ColumnarValue;
641 use futures::StreamExt;
642
643 let mut columnar_args: Vec<ColumnarValue> = Vec::with_capacity(args.len());
649 for v in args {
650 columnar_args.push(value_to_columnar(v).map_err(|e| {
651 datafusion::error::DataFusionError::Execution(format!(
652 "Procedure '{procedure_name}': argument conversion failed: {e}"
653 ))
654 })?);
655 }
656
657 let mut host =
658 crate::query::executor::procedure_host::QueryProcedureHost::from_graph_ctx_with_request(
659 graph_ctx,
660 target_properties.clone(),
661 yield_items.to_vec(),
662 Some(schema.clone()),
663 );
664 if let Some(writer) = graph_ctx.writer() {
669 host = host.with_writer(std::sync::Arc::clone(writer));
670 }
671 let principal = crate::current_principal();
678 let ctx = uni_plugin::host::build_procedure_context(&host, principal.as_deref());
679 let mut stream = entry.procedure.invoke(ctx, &columnar_args).map_err(|e| {
680 datafusion::error::DataFusionError::Execution(format!("Procedure '{procedure_name}': {e}"))
681 })?;
682
683 let mut batches: Vec<RecordBatch> = Vec::new();
687 while let Some(item) = stream.next().await {
688 let batch = item.map_err(|e| {
689 datafusion::error::DataFusionError::Execution(format!(
690 "Procedure '{procedure_name}' stream error: {e}"
691 ))
692 })?;
693 batches.push(batch);
694 }
695
696 if batches.is_empty() {
697 return Ok(Some(create_empty_batch(schema.clone())?));
700 }
701
702 let plugin_schema = batches[0].schema();
705 let combined = if batches.len() == 1 {
706 batches.pop().unwrap()
707 } else {
708 arrow::compute::concat_batches(&plugin_schema, &batches).map_err(arrow_err)?
709 };
710
711 if combined.schema().fields() == schema.fields() {
715 return Ok(Some(combined));
716 }
717
718 if yield_items.is_empty()
721 || (yield_items.len() == combined.num_columns()
722 && yield_items
723 .iter()
724 .zip(combined.schema().fields().iter())
725 .all(|((name, _alias), field)| name == field.name()))
726 {
727 return Ok(Some(combined));
728 }
729
730 let mut projected_cols: Vec<ArrayRef> = Vec::with_capacity(yield_items.len());
731 let mut projected_fields: Vec<Field> = Vec::with_capacity(yield_items.len());
732 for (name, _alias) in yield_items {
733 let idx = combined.schema().index_of(name).map_err(|_| {
734 datafusion::error::DataFusionError::Execution(format!(
735 "Procedure '{procedure_name}': YIELD column `{name}` not in plugin output schema"
736 ))
737 })?;
738 projected_cols.push(combined.column(idx).clone());
739 projected_fields.push(combined.schema().field(idx).clone());
740 }
741 let projected_schema = Arc::new(Schema::new(projected_fields));
742 let projected = RecordBatch::try_new(projected_schema, projected_cols).map_err(arrow_err)?;
743 Ok(Some(projected))
744}
745
746pub(crate) fn value_to_columnar(
750 v: &Value,
751) -> Result<datafusion::logical_expr::ColumnarValue, String> {
752 use datafusion::logical_expr::ColumnarValue;
753 use datafusion::scalar::ScalarValue;
754
755 let scalar = match v {
756 Value::Null => ScalarValue::Null,
757 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
758 Value::Int(i) => ScalarValue::Int64(Some(*i)),
759 Value::Float(f) => ScalarValue::Float64(Some(*f)),
760 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
761 Value::Bytes(b) => ScalarValue::Binary(Some(b.clone())),
762 other => {
763 let json = serde_json::to_vec(other)
770 .map_err(|e| format!("plugin arg encoding failed for {other:?}: {e}"))?;
771 ScalarValue::LargeBinary(Some(json))
772 }
773 };
774 Ok(ColumnarValue::Scalar(scalar))
775}
776
777pub(crate) fn build_typed_column<'a>(
791 values: impl Iterator<Item = Option<&'a Value>>,
792 num_rows: usize,
793 data_type: &DataType,
794) -> ArrayRef {
795 match data_type {
796 DataType::UInt64 => {
797 let mut builder = arrow_array::builder::UInt64Builder::with_capacity(num_rows);
798 for val in values {
799 match val.and_then(uni_common::Value::as_u64) {
800 Some(u) => builder.append_value(u),
801 None => builder.append_null(),
802 }
803 }
804 Arc::new(builder.finish())
805 }
806 DataType::Struct(fields) if is_edge_struct_shape(fields) => {
807 build_edge_struct_column(values, num_rows, fields)
808 }
809 DataType::Int64 => {
810 let mut builder = Int64Builder::with_capacity(num_rows);
811 for val in values {
812 match val.and_then(|v| v.as_i64()) {
813 Some(i) => builder.append_value(i),
814 None => builder.append_null(),
815 }
816 }
817 Arc::new(builder.finish())
818 }
819 DataType::Float64 => {
820 let mut builder = Float64Builder::with_capacity(num_rows);
821 for val in values {
822 match val.and_then(|v| v.as_f64()) {
823 Some(f) => builder.append_value(f),
824 None => builder.append_null(),
825 }
826 }
827 Arc::new(builder.finish())
828 }
829 DataType::Boolean => {
830 let mut builder = BooleanBuilder::with_capacity(num_rows);
831 for val in values {
832 match val.and_then(|v| v.as_bool()) {
833 Some(b) => builder.append_value(b),
834 None => builder.append_null(),
835 }
836 }
837 Arc::new(builder.finish())
838 }
839 _ => {
840 let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
842 for val in values {
843 match val {
844 Some(Value::String(s)) => builder.append_value(s),
845 Some(v) => builder.append_value(format!("{v}")),
846 None => builder.append_null(),
847 }
848 }
849 Arc::new(builder.finish())
850 }
851 }
852}
853
854fn is_edge_struct_shape(fields: &arrow_schema::Fields) -> bool {
860 let names: std::collections::HashSet<&str> = fields.iter().map(|f| f.name().as_str()).collect();
861 names.contains("_eid")
862 && names.contains("_type_name")
863 && names.contains("_src")
864 && names.contains("_dst")
865 && names.contains("properties")
866}
867
868fn build_edge_struct_column<'a>(
873 values: impl Iterator<Item = Option<&'a Value>>,
874 _num_rows: usize,
875 fields: &arrow_schema::Fields,
876) -> ArrayRef {
877 use arrow_array::builder::{LargeBinaryBuilder, StringBuilder, UInt64Builder};
878 use uni_common::Value as V;
879
880 let mut eid_b = UInt64Builder::new();
881 let mut type_b = StringBuilder::new();
882 let mut src_b = UInt64Builder::new();
883 let mut dst_b = UInt64Builder::new();
884 let mut props_b = LargeBinaryBuilder::new();
885 let mut validity: Vec<bool> = Vec::new();
886
887 for val in values {
888 match val {
889 Some(V::Edge(e)) => {
890 eid_b.append_value(e.eid.as_u64());
891 type_b.append_value(&e.edge_type);
892 src_b.append_value(e.src.as_u64());
893 dst_b.append_value(e.dst.as_u64());
894 let props_value = V::Map(e.properties.clone());
895 let bytes = uni_common::cypher_value_codec::encode(&props_value);
896 props_b.append_value(&bytes);
897 validity.push(true);
898 }
899 _ => {
900 eid_b.append_null();
901 type_b.append_null();
902 src_b.append_null();
903 dst_b.append_null();
904 props_b.append_null();
905 validity.push(false);
906 }
907 }
908 }
909
910 let arrays: Vec<ArrayRef> = vec![
911 Arc::new(eid_b.finish()),
912 Arc::new(type_b.finish()),
913 Arc::new(src_b.finish()),
914 Arc::new(dst_b.finish()),
915 Arc::new(props_b.finish()),
916 ];
917 let canonical: [&str; 5] = ["_eid", "_type_name", "_src", "_dst", "properties"];
922 let mut ordered: Vec<ArrayRef> = Vec::with_capacity(fields.len());
923 for f in fields.iter() {
924 let idx = canonical
925 .iter()
926 .position(|n| *n == f.name().as_str())
927 .expect("is_edge_struct_shape vetted these field names");
928 ordered.push(arrays[idx].clone());
929 }
930 let nulls = arrow::buffer::NullBuffer::from(validity);
931 Arc::new(
932 arrow_array::StructArray::try_new(fields.clone(), ordered, Some(nulls))
933 .expect("StructArray construction with vetted shape"),
934 )
935}
936
937pub(crate) fn create_empty_batch(schema: SchemaRef) -> DFResult<RecordBatch> {
943 if schema.fields().is_empty() {
944 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
945 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(arrow_err)
946 } else {
947 Ok(RecordBatch::new_empty(schema))
948 }
949}
950
951async fn execute_registered_procedure(
960 graph_ctx: &GraphExecutionContext,
961 procedure_name: &str,
962 args: &[Value],
963 yield_items: &[(String, Option<String>)],
964 schema: &SchemaRef,
965) -> DFResult<Option<RecordBatch>> {
966 let registry = graph_ctx.procedure_registry().ok_or_else(|| {
967 datafusion::error::DataFusionError::Execution(format!(
968 "Procedure '{}' not supported in DataFusion engine (no procedure registry)",
969 procedure_name
970 ))
971 })?;
972
973 let proc_def = registry.get(procedure_name).ok_or_else(|| {
974 datafusion::error::DataFusionError::Execution(format!(
975 "ProcedureNotFound: Unknown procedure '{}'",
976 procedure_name
977 ))
978 })?;
979
980 if args.len() != proc_def.params.len() {
982 return Err(datafusion::error::DataFusionError::Execution(format!(
983 "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
984 proc_def.name,
985 proc_def.params.len(),
986 args.len()
987 )));
988 }
989
990 for (i, (arg_val, param)) in args.iter().zip(&proc_def.params).enumerate() {
992 if !arg_val.is_null() && !check_proc_type_compatible(arg_val, ¶m.param_type) {
993 return Err(datafusion::error::DataFusionError::Execution(format!(
994 "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
995 i, param.name, proc_def.name
996 )));
997 }
998 }
999
1000 let filtered: Vec<&HashMap<String, Value>> = proc_def
1002 .data
1003 .iter()
1004 .filter(|row| {
1005 for (param, arg_val) in proc_def.params.iter().zip(args) {
1006 if let Some(row_val) = row.get(¶m.name)
1007 && !proc_values_match(row_val, arg_val)
1008 {
1009 return false;
1010 }
1011 }
1012 true
1013 })
1014 .collect();
1015
1016 if yield_items.is_empty() {
1018 return Ok(Some(create_empty_batch(schema.clone())?));
1019 }
1020
1021 if filtered.is_empty() {
1022 return Ok(Some(create_empty_batch(schema.clone())?));
1023 }
1024
1025 let num_rows = filtered.len();
1028 let mut columns: Vec<ArrayRef> = Vec::new();
1029
1030 for (idx, (name, _alias)) in yield_items.iter().enumerate() {
1031 let field = schema.field(idx);
1032 let values = filtered.iter().map(|row| row.get(name.as_str()));
1033 columns.push(build_typed_column(values, num_rows, field.data_type()));
1034 }
1035
1036 let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
1037 Ok(Some(batch))
1038}
1039
1040fn check_proc_type_compatible(
1042 val: &Value,
1043 expected: &crate::query::executor::procedure::ProcedureValueType,
1044) -> bool {
1045 use crate::query::executor::procedure::ProcedureValueType;
1046 match expected {
1047 ProcedureValueType::Any => true,
1048 ProcedureValueType::String => val.is_string(),
1049 ProcedureValueType::Boolean => val.is_bool(),
1050 ProcedureValueType::Integer => val.is_i64(),
1051 ProcedureValueType::Float => val.is_f64() || val.is_i64(),
1052 ProcedureValueType::Number => val.is_number(),
1053 }
1054}
1055
1056fn proc_values_match(row_val: &Value, arg_val: &Value) -> bool {
1058 if arg_val.is_null() || row_val.is_null() {
1059 return arg_val.is_null() && row_val.is_null();
1060 }
1061 if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
1063 return (a - b).abs() < f64::EPSILON;
1064 }
1065 row_val == arg_val
1066}
1067
1068pub(crate) fn json_to_value(jv: &serde_json::Value) -> Value {
1070 match jv {
1071 serde_json::Value::Null => Value::Null,
1072 serde_json::Value::Bool(b) => Value::Bool(*b),
1073 serde_json::Value::Number(n) => {
1074 if let Some(i) = n.as_i64() {
1075 Value::Int(i)
1076 } else if let Some(f) = n.as_f64() {
1077 Value::Float(f)
1078 } else {
1079 Value::Null
1080 }
1081 }
1082 serde_json::Value::String(s) => Value::String(s.clone()),
1083 other => Value::String(other.to_string()),
1084 }
1085}
1086
1087pub(crate) fn require_string_arg(
1093 args: &[Value],
1094 index: usize,
1095 description: &str,
1096) -> DFResult<String> {
1097 args.get(index)
1098 .and_then(|v| v.as_str())
1099 .map(|s| s.to_string())
1100 .ok_or_else(|| {
1101 datafusion::error::DataFusionError::Execution(format!("{description} must be a string"))
1102 })
1103}
1104
1105pub(crate) fn extract_optional_filter(args: &[Value], index: usize) -> Option<String> {
1108 args.get(index).and_then(|v| {
1109 if v.is_null() {
1110 None
1111 } else {
1112 v.as_str().map(|s| s.to_string())
1113 }
1114 })
1115}