1use arrow_array::builder::{BooleanBuilder, Float64Builder, Int64Builder, StringBuilder};
17use arrow_array::{ArrayRef, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_cypher::ast::Expr;
32
33use crate::query::df_graph::GraphExecutionContext;
34use crate::query::df_graph::common::{
35 arrow_err, compute_plan_properties, evaluate_simple_expr, labels_data_type,
36};
37use crate::query::df_graph::scan::{property_field, resolve_property_type};
38
/// Maps a user-written `YIELD` name onto its canonical output column name.
///
/// The name is lowercased before comparison, so matching is case-insensitive,
/// and the underscore-prefixed spellings listed in the table are accepted as
/// aliases. Any name that matches no alias resolves to `"node"`.
pub(crate) fn map_yield_to_canonical(yield_name: &str) -> &'static str {
    /// Accepted spellings paired with the canonical name they resolve to.
    const ALIASES: &[(&[&str], &str)] = &[
        (&["vid", "_vid"], "vid"),
        (&["distance", "dist", "_distance"], "distance"),
        (&["score", "_score"], "score"),
        (&["vector_score"], "vector_score"),
        (&["fts_score"], "fts_score"),
        (&["raw_score"], "raw_score"),
        (&["rerank_score", "_rerank_score"], "rerank_score"),
    ];

    let lowered = yield_name.to_lowercase();
    ALIASES
        .iter()
        .find(|(spellings, _)| spellings.contains(&lowered.as_str()))
        .map_or("node", |&(_, canonical)| canonical)
}
63
/// Procedure names recognised by [`is_node_yield_procedure_static`].
///
/// NOTE(review): judging by the identifier these are the built-in procedures
/// whose yields can be node-valued — confirm against the callers.
pub(crate) const NODE_YIELD_PROCEDURE_NAMES: &[&str] = &[
    "uni.vector.query",
    "uni.fts.query",
    "uni.search",
    "uni.create.vNode",
];
83
84pub(crate) fn is_node_yield_procedure_static(name: &str) -> bool {
88 NODE_YIELD_PROCEDURE_NAMES.contains(&name)
89}
90
91pub(crate) fn canonical_search_type(canonical: &str) -> DataType {
96 match canonical {
97 "distance" => DataType::Float64,
98 "score" | "vector_score" | "fts_score" | "raw_score" | "rerank_score" => DataType::Float32,
99 "vid" => DataType::Int64,
100 _ => DataType::Utf8,
101 }
102}
103
104fn expand_node_yield_fields(
111 output_name: &str,
112 target_properties: &HashMap<String, Vec<String>>,
113 graph_ctx: &GraphExecutionContext,
114 fields: &mut Vec<Field>,
115) {
116 fields.push(Field::new(
117 format!("{}._vid", output_name),
118 DataType::UInt64,
119 false,
120 ));
121 fields.push(Field::new(output_name, DataType::Utf8, false));
122 fields.push(Field::new(
123 format!("{}._labels", output_name),
124 labels_data_type(),
125 true,
126 ));
127
128 if let Some(props) = target_properties.get(output_name) {
129 let uni_schema = graph_ctx.storage().schema_manager().schema();
130 for prop_name in props {
131 let col_name = format!("{}.{}", output_name, prop_name);
132 let arrow_type = resolve_property_type(prop_name, None);
133 let resolved_type = uni_schema
134 .properties
135 .values()
136 .find_map(|label_props| {
137 label_props
138 .get(prop_name.as_str())
139 .map(|_| resolve_property_type(prop_name, Some(label_props)))
140 })
141 .unwrap_or(arrow_type);
142 let uni_type = uni_schema
143 .properties
144 .values()
145 .find_map(|label_props| label_props.get(prop_name.as_str()).map(|m| &m.r#type));
146 fields.push(property_field(&col_name, resolved_type, uni_type));
147 }
148 }
149}
150
151fn field_from_signature(col_name: &str, sig_field: &Field) -> Field {
155 let mut new_field = Field::new(
156 col_name,
157 sig_field.data_type().clone(),
158 sig_field.is_nullable(),
159 );
160 if !sig_field.metadata().is_empty() {
161 new_field = new_field.with_metadata(sig_field.metadata().clone());
162 }
163 new_field
164}
165
/// Physical plan node that runs a procedure call and exposes the yielded
/// columns as a record batch stream.
///
/// The output schema is computed once, at construction, from the procedure
/// name, the yield items and the requested node properties.
pub struct GraphProcedureCallExec {
    /// Execution context; supplies the procedure registry and storage schema.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Name used to look the procedure up in the registry.
    procedure_name: String,

    /// Unevaluated argument expressions; evaluated on every `execute` call.
    arguments: Vec<Expr>,

    /// `(yield name, optional alias)` pairs. When an alias is present it names
    /// the output column instead of the yield name.
    yield_items: Vec<(String, Option<String>)>,

    /// Values handed to argument evaluation.
    /// NOTE(review): presumably the query parameters — confirm with the planner.
    params: HashMap<String, Value>,

    /// Values handed to argument evaluation alongside `params`.
    /// NOTE(review): presumably bindings from an enclosing scope — confirm.
    outer_values: HashMap<String, Value>,

    /// Property names to expose for node-valued outputs, keyed by output
    /// column name.
    target_properties: HashMap<String, Vec<String>>,

    /// Output schema of this operator.
    schema: SchemaRef,

    /// Plan properties derived from `schema` via `compute_plan_properties`.
    properties: Arc<PlanProperties>,

    /// Metric set backing `metrics()` and the per-partition baseline metrics.
    metrics: ExecutionPlanMetricsSet,
}
201
202impl fmt::Debug for GraphProcedureCallExec {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 f.debug_struct("GraphProcedureCallExec")
205 .field("procedure_name", &self.procedure_name)
206 .field("yield_items", &self.yield_items)
207 .finish()
208 }
209}
210
211impl GraphProcedureCallExec {
212 pub fn new(
214 graph_ctx: Arc<GraphExecutionContext>,
215 procedure_name: String,
216 arguments: Vec<Expr>,
217 yield_items: Vec<(String, Option<String>)>,
218 params: HashMap<String, Value>,
219 outer_values: HashMap<String, Value>,
220 target_properties: HashMap<String, Vec<String>>,
221 ) -> Self {
222 let schema = Self::build_schema(
223 &procedure_name,
224 &yield_items,
225 &target_properties,
226 &graph_ctx,
227 );
228 let properties = compute_plan_properties(schema.clone());
229
230 Self {
231 graph_ctx,
232 procedure_name,
233 arguments,
234 yield_items,
235 params,
236 outer_values,
237 target_properties,
238 schema,
239 properties,
240 metrics: ExecutionPlanMetricsSet::new(),
241 }
242 }
243
244 fn build_schema(
258 procedure_name: &str,
259 yield_items: &[(String, Option<String>)],
260 target_properties: &HashMap<String, Vec<String>>,
261 graph_ctx: &GraphExecutionContext,
262 ) -> SchemaRef {
263 let mut fields = Vec::new();
264
265 if let Some(registry) = graph_ctx.procedure_registry()
266 && let Some(entry) = registry.resolve_user_procedure(procedure_name)
267 {
268 let supports_node_yield = entry.signature.yields.iter().any(|f| {
269 f.metadata()
270 .get("_yield_kind")
271 .is_some_and(|v| v == "node_vid_source")
272 });
273
274 for (yield_name, alias) in yield_items {
275 let col_name = alias.as_ref().unwrap_or(yield_name);
276
277 if supports_node_yield {
278 let canonical = map_yield_to_canonical(yield_name);
279 if canonical == "node" {
280 expand_node_yield_fields(
281 col_name,
282 target_properties,
283 graph_ctx,
284 &mut fields,
285 );
286 continue;
287 }
288 if let Some(sig_field) = entry
294 .signature
295 .yields
296 .iter()
297 .find(|f| f.name() == canonical)
298 {
299 fields.push(field_from_signature(col_name, sig_field));
300 } else {
301 fields.push(Field::new(col_name, canonical_search_type(canonical), true));
302 }
303 continue;
304 }
305
306 let field = entry
310 .signature
311 .yields
312 .iter()
313 .find(|f| f.name() == yield_name.as_str())
314 .map(|f| field_from_signature(col_name, f))
315 .unwrap_or_else(|| Field::new(col_name, DataType::Utf8, true));
316 fields.push(field);
317 }
318 } else if let Some(registry) = graph_ctx.procedure_registry()
319 && let Some(proc_def) = registry.get(procedure_name)
320 {
321 for (name, alias) in yield_items {
322 let col_name = alias.as_ref().unwrap_or(name);
323 let data_type = proc_def
324 .outputs
325 .iter()
326 .find(|o| o.name == *name)
327 .map(|o| procedure_value_type_to_arrow(&o.output_type))
328 .unwrap_or(DataType::Utf8);
329 fields.push(Field::new(col_name, data_type, true));
330 }
331 } else if yield_items.is_empty() {
332 } else {
334 for (name, alias) in yield_items {
335 let col_name = alias.as_ref().unwrap_or(name);
336 fields.push(Field::new(col_name, DataType::Utf8, true));
337 }
338 }
339
340 Arc::new(Schema::new(fields))
341 }
342}
343
344pub(crate) fn value_type_to_arrow(vt: &uni_algo::algo::procedures::ValueType) -> DataType {
346 use uni_algo::algo::procedures::ValueType;
347 match vt {
348 ValueType::Int => DataType::Int64,
349 ValueType::Float => DataType::Float64,
350 ValueType::String => DataType::Utf8,
351 ValueType::Bool => DataType::Boolean,
352 ValueType::List
353 | ValueType::Map
354 | ValueType::Node
355 | ValueType::Relationship
356 | ValueType::Path
357 | ValueType::Any => DataType::Utf8,
358 }
359}
360
361pub(crate) fn is_complex_value_type(vt: &uni_algo::algo::procedures::ValueType) -> bool {
364 use uni_algo::algo::procedures::ValueType;
365 matches!(
366 vt,
367 ValueType::List
368 | ValueType::Map
369 | ValueType::Node
370 | ValueType::Relationship
371 | ValueType::Path
372 )
373}
374
375fn procedure_value_type_to_arrow(
377 vt: &crate::query::executor::procedure::ProcedureValueType,
378) -> DataType {
379 use crate::query::executor::procedure::ProcedureValueType;
380 match vt {
381 ProcedureValueType::Integer => DataType::Int64,
382 ProcedureValueType::Float | ProcedureValueType::Number => DataType::Float64,
383 ProcedureValueType::Boolean => DataType::Boolean,
384 ProcedureValueType::String | ProcedureValueType::Any => DataType::Utf8,
385 }
386}
387
388impl DisplayAs for GraphProcedureCallExec {
389 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390 write!(
391 f,
392 "GraphProcedureCallExec: procedure={}",
393 self.procedure_name
394 )
395 }
396}
397
398impl ExecutionPlan for GraphProcedureCallExec {
399 fn name(&self) -> &str {
400 "GraphProcedureCallExec"
401 }
402
403 fn as_any(&self) -> &dyn Any {
404 self
405 }
406
407 fn schema(&self) -> SchemaRef {
408 self.schema.clone()
409 }
410
411 fn properties(&self) -> &Arc<PlanProperties> {
412 &self.properties
413 }
414
415 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
416 vec![]
417 }
418
419 fn with_new_children(
420 self: Arc<Self>,
421 children: Vec<Arc<dyn ExecutionPlan>>,
422 ) -> DFResult<Arc<dyn ExecutionPlan>> {
423 if !children.is_empty() {
424 return Err(datafusion::error::DataFusionError::Internal(
425 "GraphProcedureCallExec has no children".to_string(),
426 ));
427 }
428 Ok(self)
429 }
430
431 fn execute(
432 &self,
433 partition: usize,
434 _context: Arc<TaskContext>,
435 ) -> DFResult<SendableRecordBatchStream> {
436 let metrics = BaselineMetrics::new(&self.metrics, partition);
437
438 let mut evaluated_args = Vec::with_capacity(self.arguments.len());
440 for arg in &self.arguments {
441 evaluated_args.push(evaluate_simple_expr(arg, &self.params, &self.outer_values)?);
442 }
443
444 Ok(Box::pin(ProcedureCallStream::new(
445 self.graph_ctx.clone(),
446 self.procedure_name.clone(),
447 evaluated_args,
448 self.yield_items.clone(),
449 self.target_properties.clone(),
450 self.schema.clone(),
451 metrics,
452 )))
453 }
454
455 fn metrics(&self) -> Option<MetricsSet> {
456 Some(self.metrics.clone_inner())
457 }
458}
459
/// Lifecycle of a [`ProcedureCallStream`].
enum ProcedureCallState {
    /// Nothing has run yet; the first poll builds the procedure future.
    Init,
    /// The procedure future is in flight and is polled until it resolves.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// The result (or error) has been delivered; further polls end the stream.
    Done,
}
473
/// Record batch stream that runs one procedure call and yields at most one
/// batch before ending.
struct ProcedureCallStream {
    /// Execution context handed to the procedure on first poll.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Name of the procedure to run.
    procedure_name: String,
    /// Argument values, already evaluated by the plan node.
    evaluated_args: Vec<Value>,
    /// `(yield name, optional alias)` pairs requested by the query.
    yield_items: Vec<(String, Option<String>)>,
    /// Property names requested per node-valued output column.
    target_properties: HashMap<String, Vec<String>>,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Current position in the init → executing → done lifecycle.
    state: ProcedureCallState,
    /// Baseline metrics: compute time and output row count.
    metrics: BaselineMetrics,
}
485
486impl ProcedureCallStream {
487 fn new(
488 graph_ctx: Arc<GraphExecutionContext>,
489 procedure_name: String,
490 evaluated_args: Vec<Value>,
491 yield_items: Vec<(String, Option<String>)>,
492 target_properties: HashMap<String, Vec<String>>,
493 schema: SchemaRef,
494 metrics: BaselineMetrics,
495 ) -> Self {
496 Self {
497 graph_ctx,
498 procedure_name,
499 evaluated_args,
500 yield_items,
501 target_properties,
502 schema,
503 state: ProcedureCallState::Init,
504 metrics,
505 }
506 }
507}
508
509impl Stream for ProcedureCallStream {
510 type Item = DFResult<RecordBatch>;
511
512 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
513 let metrics = self.metrics.clone();
514 let _timer = metrics.elapsed_compute().timer();
515 loop {
516 let state = std::mem::replace(&mut self.state, ProcedureCallState::Done);
517
518 match state {
519 ProcedureCallState::Init => {
520 let graph_ctx = self.graph_ctx.clone();
521 let procedure_name = self.procedure_name.clone();
522 let evaluated_args = self.evaluated_args.clone();
523 let yield_items = self.yield_items.clone();
524 let target_properties = self.target_properties.clone();
525 let schema = self.schema.clone();
526
527 let fut = async move {
528 graph_ctx.check_timeout().map_err(|e| {
529 datafusion::error::DataFusionError::Execution(e.to_string())
530 })?;
531
532 execute_procedure(
533 &graph_ctx,
534 &procedure_name,
535 &evaluated_args,
536 &yield_items,
537 &target_properties,
538 &schema,
539 )
540 .await
541 };
542
543 self.state = ProcedureCallState::Executing(Box::pin(fut));
544 }
545 ProcedureCallState::Executing(mut fut) => match fut.as_mut().poll(cx) {
546 Poll::Ready(Ok(batch)) => {
547 self.state = ProcedureCallState::Done;
548 self.metrics
549 .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
550 return Poll::Ready(batch.map(Ok));
551 }
552 Poll::Ready(Err(e)) => {
553 self.state = ProcedureCallState::Done;
554 return Poll::Ready(Some(Err(e)));
555 }
556 Poll::Pending => {
557 self.state = ProcedureCallState::Executing(fut);
558 return Poll::Pending;
559 }
560 },
561 ProcedureCallState::Done => {
562 return Poll::Ready(None);
563 }
564 }
565 }
566 }
567}
568
569impl RecordBatchStream for ProcedureCallStream {
570 fn schema(&self) -> SchemaRef {
571 self.schema.clone()
572 }
573}
574
575async fn execute_procedure(
596 graph_ctx: &GraphExecutionContext,
597 procedure_name: &str,
598 args: &[Value],
599 yield_items: &[(String, Option<String>)],
600 target_properties: &HashMap<String, Vec<String>>,
601 schema: &SchemaRef,
602) -> DFResult<Option<RecordBatch>> {
603 if let Some(registry) = graph_ctx.procedure_registry()
609 && let Some(entry) = registry.resolve_user_procedure(procedure_name)
610 {
611 return execute_plugin_procedure(
612 graph_ctx,
613 procedure_name,
614 &entry,
615 args,
616 yield_items,
617 target_properties,
618 schema,
619 )
620 .await;
621 }
622
623 execute_registered_procedure(graph_ctx, procedure_name, args, yield_items, schema).await
624}
625
626async fn execute_plugin_procedure(
636 graph_ctx: &GraphExecutionContext,
637 procedure_name: &str,
638 entry: &uni_plugin::registry::ProcedureEntry,
639 args: &[Value],
640 yield_items: &[(String, Option<String>)],
641 target_properties: &HashMap<String, Vec<String>>,
642 schema: &SchemaRef,
643) -> DFResult<Option<RecordBatch>> {
644 use datafusion::logical_expr::ColumnarValue;
645 use futures::StreamExt;
646
647 let mut columnar_args: Vec<ColumnarValue> = Vec::with_capacity(args.len());
653 for v in args {
654 columnar_args.push(value_to_columnar(v).map_err(|e| {
655 datafusion::error::DataFusionError::Execution(format!(
656 "Procedure '{procedure_name}': argument conversion failed: {e}"
657 ))
658 })?);
659 }
660
661 let mut host =
662 crate::query::executor::procedure_host::QueryProcedureHost::from_graph_ctx_with_request(
663 graph_ctx,
664 target_properties.clone(),
665 yield_items.to_vec(),
666 Some(schema.clone()),
667 );
668 if let Some(writer) = graph_ctx.writer() {
673 host = host.with_writer(std::sync::Arc::clone(writer));
674 }
675 let principal = crate::current_principal();
682 let ctx = uni_plugin::host::build_procedure_context(&host, principal.as_deref());
683 let mut stream = entry.procedure.invoke(ctx, &columnar_args).map_err(|e| {
684 datafusion::error::DataFusionError::Execution(format!("Procedure '{procedure_name}': {e}"))
685 })?;
686
687 let mut batches: Vec<RecordBatch> = Vec::new();
691 while let Some(item) = stream.next().await {
692 let batch = item.map_err(|e| {
693 datafusion::error::DataFusionError::Execution(format!(
694 "Procedure '{procedure_name}' stream error: {e}"
695 ))
696 })?;
697 batches.push(batch);
698 }
699
700 if batches.is_empty() {
701 return Ok(Some(create_empty_batch(schema.clone())?));
704 }
705
706 let plugin_schema = batches[0].schema();
709 let combined = if batches.len() == 1 {
710 batches.pop().unwrap()
711 } else {
712 arrow::compute::concat_batches(&plugin_schema, &batches).map_err(arrow_err)?
713 };
714
715 if combined.schema().fields() == schema.fields() {
719 return Ok(Some(combined));
720 }
721
722 if yield_items.is_empty()
725 || (yield_items.len() == combined.num_columns()
726 && yield_items
727 .iter()
728 .zip(combined.schema().fields().iter())
729 .all(|((name, _alias), field)| name == field.name()))
730 {
731 return Ok(Some(combined));
732 }
733
734 let mut projected_cols: Vec<ArrayRef> = Vec::with_capacity(yield_items.len());
735 let mut projected_fields: Vec<Field> = Vec::with_capacity(yield_items.len());
736 for (name, _alias) in yield_items {
737 let idx = combined.schema().index_of(name).map_err(|_| {
738 datafusion::error::DataFusionError::Execution(format!(
739 "Procedure '{procedure_name}': YIELD column `{name}` not in plugin output schema"
740 ))
741 })?;
742 projected_cols.push(combined.column(idx).clone());
743 projected_fields.push(combined.schema().field(idx).clone());
744 }
745 let projected_schema = Arc::new(Schema::new(projected_fields));
746 let projected = RecordBatch::try_new(projected_schema, projected_cols).map_err(arrow_err)?;
747 Ok(Some(projected))
748}
749
750pub(crate) fn value_to_columnar(
754 v: &Value,
755) -> Result<datafusion::logical_expr::ColumnarValue, String> {
756 use datafusion::logical_expr::ColumnarValue;
757 use datafusion::scalar::ScalarValue;
758
759 let scalar = match v {
760 Value::Null => ScalarValue::Null,
761 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
762 Value::Int(i) => ScalarValue::Int64(Some(*i)),
763 Value::Float(f) => ScalarValue::Float64(Some(*f)),
764 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
765 Value::Bytes(b) => ScalarValue::Binary(Some(b.clone())),
766 other => {
767 let json = serde_json::to_vec(other)
774 .map_err(|e| format!("plugin arg encoding failed for {other:?}: {e}"))?;
775 ScalarValue::LargeBinary(Some(json))
776 }
777 };
778 Ok(ColumnarValue::Scalar(scalar))
779}
780
781pub(crate) fn build_typed_column<'a>(
795 values: impl Iterator<Item = Option<&'a Value>>,
796 num_rows: usize,
797 data_type: &DataType,
798) -> ArrayRef {
799 match data_type {
800 DataType::UInt64 => {
801 let mut builder = arrow_array::builder::UInt64Builder::with_capacity(num_rows);
802 for val in values {
803 match val.and_then(uni_common::Value::as_u64) {
804 Some(u) => builder.append_value(u),
805 None => builder.append_null(),
806 }
807 }
808 Arc::new(builder.finish())
809 }
810 DataType::Struct(fields) if is_edge_struct_shape(fields) => {
811 build_edge_struct_column(values, num_rows, fields)
812 }
813 DataType::Int64 => {
814 let mut builder = Int64Builder::with_capacity(num_rows);
815 for val in values {
816 match val.and_then(|v| v.as_i64()) {
817 Some(i) => builder.append_value(i),
818 None => builder.append_null(),
819 }
820 }
821 Arc::new(builder.finish())
822 }
823 DataType::Float64 => {
824 let mut builder = Float64Builder::with_capacity(num_rows);
825 for val in values {
826 match val.and_then(|v| v.as_f64()) {
827 Some(f) => builder.append_value(f),
828 None => builder.append_null(),
829 }
830 }
831 Arc::new(builder.finish())
832 }
833 DataType::Boolean => {
834 let mut builder = BooleanBuilder::with_capacity(num_rows);
835 for val in values {
836 match val.and_then(|v| v.as_bool()) {
837 Some(b) => builder.append_value(b),
838 None => builder.append_null(),
839 }
840 }
841 Arc::new(builder.finish())
842 }
843 _ => {
844 let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
846 for val in values {
847 match val {
848 Some(Value::String(s)) => builder.append_value(s),
849 Some(v) => builder.append_value(format!("{v}")),
850 None => builder.append_null(),
851 }
852 }
853 Arc::new(builder.finish())
854 }
855 }
856}
857
858fn is_edge_struct_shape(fields: &arrow_schema::Fields) -> bool {
864 let names: std::collections::HashSet<&str> = fields.iter().map(|f| f.name().as_str()).collect();
865 names.contains("_eid")
866 && names.contains("_type_name")
867 && names.contains("_src")
868 && names.contains("_dst")
869 && names.contains("properties")
870}
871
872fn build_edge_struct_column<'a>(
877 values: impl Iterator<Item = Option<&'a Value>>,
878 _num_rows: usize,
879 fields: &arrow_schema::Fields,
880) -> ArrayRef {
881 use arrow_array::builder::{LargeBinaryBuilder, StringBuilder, UInt64Builder};
882 use uni_common::Value as V;
883
884 let mut eid_b = UInt64Builder::new();
885 let mut type_b = StringBuilder::new();
886 let mut src_b = UInt64Builder::new();
887 let mut dst_b = UInt64Builder::new();
888 let mut props_b = LargeBinaryBuilder::new();
889 let mut validity: Vec<bool> = Vec::new();
890
891 for val in values {
892 match val {
893 Some(V::Edge(e)) => {
894 eid_b.append_value(e.eid.as_u64());
895 type_b.append_value(&e.edge_type);
896 src_b.append_value(e.src.as_u64());
897 dst_b.append_value(e.dst.as_u64());
898 let props_value = V::Map(e.properties.clone());
899 let bytes = uni_common::cypher_value_codec::encode(&props_value);
900 props_b.append_value(&bytes);
901 validity.push(true);
902 }
903 _ => {
904 eid_b.append_null();
905 type_b.append_null();
906 src_b.append_null();
907 dst_b.append_null();
908 props_b.append_null();
909 validity.push(false);
910 }
911 }
912 }
913
914 let arrays: Vec<ArrayRef> = vec![
915 Arc::new(eid_b.finish()),
916 Arc::new(type_b.finish()),
917 Arc::new(src_b.finish()),
918 Arc::new(dst_b.finish()),
919 Arc::new(props_b.finish()),
920 ];
921 let canonical: [&str; 5] = ["_eid", "_type_name", "_src", "_dst", "properties"];
926 let mut ordered: Vec<ArrayRef> = Vec::with_capacity(fields.len());
927 for f in fields.iter() {
928 let idx = canonical
929 .iter()
930 .position(|n| *n == f.name().as_str())
931 .expect("is_edge_struct_shape vetted these field names");
932 ordered.push(arrays[idx].clone());
933 }
934 let nulls = arrow::buffer::NullBuffer::from(validity);
935 Arc::new(
936 arrow_array::StructArray::try_new(fields.clone(), ordered, Some(nulls))
937 .expect("StructArray construction with vetted shape"),
938 )
939}
940
941pub(crate) fn create_empty_batch(schema: SchemaRef) -> DFResult<RecordBatch> {
947 if schema.fields().is_empty() {
948 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
949 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(arrow_err)
950 } else {
951 Ok(RecordBatch::new_empty(schema))
952 }
953}
954
955async fn execute_registered_procedure(
964 graph_ctx: &GraphExecutionContext,
965 procedure_name: &str,
966 args: &[Value],
967 yield_items: &[(String, Option<String>)],
968 schema: &SchemaRef,
969) -> DFResult<Option<RecordBatch>> {
970 let registry = graph_ctx.procedure_registry().ok_or_else(|| {
971 datafusion::error::DataFusionError::Execution(format!(
972 "Procedure '{}' not supported in DataFusion engine (no procedure registry)",
973 procedure_name
974 ))
975 })?;
976
977 let proc_def = registry.get(procedure_name).ok_or_else(|| {
978 datafusion::error::DataFusionError::Execution(format!(
979 "ProcedureNotFound: Unknown procedure '{}'",
980 procedure_name
981 ))
982 })?;
983
984 if args.len() != proc_def.params.len() {
986 return Err(datafusion::error::DataFusionError::Execution(format!(
987 "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
988 proc_def.name,
989 proc_def.params.len(),
990 args.len()
991 )));
992 }
993
994 for (i, (arg_val, param)) in args.iter().zip(&proc_def.params).enumerate() {
996 if !arg_val.is_null() && !check_proc_type_compatible(arg_val, ¶m.param_type) {
997 return Err(datafusion::error::DataFusionError::Execution(format!(
998 "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
999 i, param.name, proc_def.name
1000 )));
1001 }
1002 }
1003
1004 let filtered: Vec<&HashMap<String, Value>> = proc_def
1006 .data
1007 .iter()
1008 .filter(|row| {
1009 for (param, arg_val) in proc_def.params.iter().zip(args) {
1010 if let Some(row_val) = row.get(¶m.name)
1011 && !proc_values_match(row_val, arg_val)
1012 {
1013 return false;
1014 }
1015 }
1016 true
1017 })
1018 .collect();
1019
1020 if yield_items.is_empty() {
1022 return Ok(Some(create_empty_batch(schema.clone())?));
1023 }
1024
1025 if filtered.is_empty() {
1026 return Ok(Some(create_empty_batch(schema.clone())?));
1027 }
1028
1029 let num_rows = filtered.len();
1032 let mut columns: Vec<ArrayRef> = Vec::new();
1033
1034 for (idx, (name, _alias)) in yield_items.iter().enumerate() {
1035 let field = schema.field(idx);
1036 let values = filtered.iter().map(|row| row.get(name.as_str()));
1037 columns.push(build_typed_column(values, num_rows, field.data_type()));
1038 }
1039
1040 let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
1041 Ok(Some(batch))
1042}
1043
1044fn check_proc_type_compatible(
1046 val: &Value,
1047 expected: &crate::query::executor::procedure::ProcedureValueType,
1048) -> bool {
1049 use crate::query::executor::procedure::ProcedureValueType;
1050 match expected {
1051 ProcedureValueType::Any => true,
1052 ProcedureValueType::String => val.is_string(),
1053 ProcedureValueType::Boolean => val.is_bool(),
1054 ProcedureValueType::Integer => val.is_i64(),
1055 ProcedureValueType::Float => val.is_f64() || val.is_i64(),
1056 ProcedureValueType::Number => val.is_number(),
1057 }
1058}
1059
1060fn proc_values_match(row_val: &Value, arg_val: &Value) -> bool {
1062 if arg_val.is_null() || row_val.is_null() {
1063 return arg_val.is_null() && row_val.is_null();
1064 }
1065 if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
1067 return (a - b).abs() < f64::EPSILON;
1068 }
1069 row_val == arg_val
1070}
1071
1072pub(crate) fn json_to_value(jv: &serde_json::Value) -> Value {
1074 match jv {
1075 serde_json::Value::Null => Value::Null,
1076 serde_json::Value::Bool(b) => Value::Bool(*b),
1077 serde_json::Value::Number(n) => {
1078 if let Some(i) = n.as_i64() {
1079 Value::Int(i)
1080 } else if let Some(f) = n.as_f64() {
1081 Value::Float(f)
1082 } else {
1083 Value::Null
1084 }
1085 }
1086 serde_json::Value::String(s) => Value::String(s.clone()),
1087 other => Value::String(other.to_string()),
1088 }
1089}
1090
1091pub(crate) fn require_string_arg(
1097 args: &[Value],
1098 index: usize,
1099 description: &str,
1100) -> DFResult<String> {
1101 args.get(index)
1102 .and_then(|v| v.as_str())
1103 .map(|s| s.to_string())
1104 .ok_or_else(|| {
1105 datafusion::error::DataFusionError::Execution(format!("{description} must be a string"))
1106 })
1107}
1108
1109pub(crate) fn extract_optional_filter(args: &[Value], index: usize) -> Option<String> {
1112 args.get(index).and_then(|v| {
1113 if v.is_null() {
1114 None
1115 } else {
1116 v.as_str().map(|s| s.to_string())
1117 }
1118 })
1119}