1use arrow_array::builder::{BooleanBuilder, Float64Builder, Int64Builder, StringBuilder};
17use arrow_array::{ArrayRef, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_cypher::ast::Expr;
32
33use crate::query::df_graph::GraphExecutionContext;
34use crate::query::df_graph::common::{
35 arrow_err, compute_plan_properties, evaluate_simple_expr, labels_data_type,
36};
37use crate::query::df_graph::scan::{property_field, resolve_property_type};
38
/// Maps a user-facing `YIELD` column name onto its canonical search-result name.
///
/// The comparison is case-insensitive: the input is lowercased with
/// [`str::to_lowercase`] before lookup. Known aliases (`_vid`, `dist`, `_distance`,
/// `_score`, `_rerank_score`, ...) collapse onto their canonical spelling. Any name
/// that is not a recognised vid/distance/score column maps to `"node"`.
pub(crate) fn map_yield_to_canonical(yield_name: &str) -> &'static str {
    // (accepted lowercase spellings, canonical name)
    const ALIASES: &[(&[&str], &str)] = &[
        (&["vid", "_vid"], "vid"),
        (&["distance", "dist", "_distance"], "distance"),
        (&["score", "_score"], "score"),
        (&["vector_score"], "vector_score"),
        (&["fts_score"], "fts_score"),
        (&["sparse_score"], "sparse_score"),
        (&["raw_score"], "raw_score"),
        (&["rerank_score", "_rerank_score"], "rerank_score"),
    ];

    let lowered = yield_name.to_lowercase();
    ALIASES
        .iter()
        .find(|(spellings, _)| spellings.contains(&lowered.as_str()))
        .map_or("node", |&(_, canonical)| canonical)
}
64
/// Static list of procedure names recognised by [`is_node_yield_procedure_static`].
pub(crate) const NODE_YIELD_PROCEDURE_NAMES: &[&str] = &[
    "uni.vector.query",
    "uni.fts.query",
    "uni.sparse.query",
    "uni.search",
    "uni.create.vNode",
];

/// Returns `true` when `name` matches an entry of [`NODE_YIELD_PROCEDURE_NAMES`]
/// exactly (the comparison is case-sensitive).
pub(crate) fn is_node_yield_procedure_static(name: &str) -> bool {
    NODE_YIELD_PROCEDURE_NAMES
        .iter()
        .any(|&candidate| candidate == name)
}
92
93pub(crate) fn canonical_search_type(canonical: &str) -> DataType {
98 match canonical {
99 "distance" => DataType::Float64,
100 "score" | "vector_score" | "fts_score" | "sparse_score" | "raw_score" | "rerank_score" => {
101 DataType::Float32
102 }
103 "vid" => DataType::Int64,
104 _ => DataType::Utf8,
105 }
106}
107
108fn expand_node_yield_fields(
115 output_name: &str,
116 target_properties: &HashMap<String, Vec<String>>,
117 graph_ctx: &GraphExecutionContext,
118 fields: &mut Vec<Field>,
119) {
120 fields.push(Field::new(
121 format!("{}._vid", output_name),
122 DataType::UInt64,
123 false,
124 ));
125 fields.push(Field::new(output_name, DataType::Utf8, false));
126 fields.push(Field::new(
127 format!("{}._labels", output_name),
128 labels_data_type(),
129 true,
130 ));
131
132 if let Some(props) = target_properties.get(output_name) {
133 let uni_schema = graph_ctx.storage().schema_manager().schema();
134 for prop_name in props {
135 let col_name = format!("{}.{}", output_name, prop_name);
136 let arrow_type = resolve_property_type(prop_name, None);
137 let resolved_type = uni_schema
138 .properties
139 .values()
140 .find_map(|label_props| {
141 label_props
142 .get(prop_name.as_str())
143 .map(|_| resolve_property_type(prop_name, Some(label_props)))
144 })
145 .unwrap_or(arrow_type);
146 let uni_type = uni_schema
147 .properties
148 .values()
149 .find_map(|label_props| label_props.get(prop_name.as_str()).map(|m| &m.r#type));
150 fields.push(property_field(&col_name, resolved_type, uni_type));
151 }
152 }
153}
154
155fn field_from_signature(col_name: &str, sig_field: &Field) -> Field {
159 let mut new_field = Field::new(
160 col_name,
161 sig_field.data_type().clone(),
162 sig_field.is_nullable(),
163 );
164 if !sig_field.metadata().is_empty() {
165 new_field = new_field.with_metadata(sig_field.metadata().clone());
166 }
167 new_field
168}
169
/// Leaf physical operator that executes a procedure call and emits its `YIELD`
/// columns as Arrow record batches.
///
/// Plugin procedures (found via the registry's `resolve_user_procedure`) are tried
/// first. Otherwise the registry's declared procedures are used.
pub struct GraphProcedureCallExec {
    /// Shared graph execution context. Provides the procedure registry, storage
    /// schema, and optional writer.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Name of the procedure being called.
    procedure_name: String,

    /// Unevaluated argument expressions. `execute` evaluates them against `params`
    /// and `outer_values`.
    arguments: Vec<Expr>,

    /// `(yield name, optional alias)` pairs. When present, the alias names the
    /// output column.
    yield_items: Vec<(String, Option<String>)>,

    /// Query parameters available while evaluating `arguments`.
    params: HashMap<String, Value>,

    /// Extra variable bindings available while evaluating `arguments`.
    /// Presumably these are values from an enclosing scope — confirm with the planner.
    outer_values: HashMap<String, Value>,

    /// Output column name -> property names to materialize as extra columns for
    /// node-valued yields.
    target_properties: HashMap<String, Vec<String>>,

    /// Output schema computed by `build_schema` at construction time.
    schema: SchemaRef,

    /// Plan properties derived from `schema` via `compute_plan_properties`.
    properties: Arc<PlanProperties>,

    /// Execution metrics for this operator.
    metrics: ExecutionPlanMetricsSet,
}
205
206impl fmt::Debug for GraphProcedureCallExec {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 f.debug_struct("GraphProcedureCallExec")
209 .field("procedure_name", &self.procedure_name)
210 .field("yield_items", &self.yield_items)
211 .finish()
212 }
213}
214
215impl GraphProcedureCallExec {
216 pub fn new(
218 graph_ctx: Arc<GraphExecutionContext>,
219 procedure_name: String,
220 arguments: Vec<Expr>,
221 yield_items: Vec<(String, Option<String>)>,
222 params: HashMap<String, Value>,
223 outer_values: HashMap<String, Value>,
224 target_properties: HashMap<String, Vec<String>>,
225 ) -> Self {
226 let schema = Self::build_schema(
227 &procedure_name,
228 &yield_items,
229 &target_properties,
230 &graph_ctx,
231 );
232 let properties = compute_plan_properties(schema.clone());
233
234 Self {
235 graph_ctx,
236 procedure_name,
237 arguments,
238 yield_items,
239 params,
240 outer_values,
241 target_properties,
242 schema,
243 properties,
244 metrics: ExecutionPlanMetricsSet::new(),
245 }
246 }
247
248 fn build_schema(
262 procedure_name: &str,
263 yield_items: &[(String, Option<String>)],
264 target_properties: &HashMap<String, Vec<String>>,
265 graph_ctx: &GraphExecutionContext,
266 ) -> SchemaRef {
267 let mut fields = Vec::new();
268
269 if let Some(registry) = graph_ctx.procedure_registry()
270 && let Some(entry) = registry.resolve_user_procedure(procedure_name)
271 {
272 let supports_node_yield = entry.signature.yields.iter().any(|f| {
273 f.metadata()
274 .get("_yield_kind")
275 .is_some_and(|v| v == "node_vid_source")
276 });
277
278 for (yield_name, alias) in yield_items {
279 let col_name = alias.as_ref().unwrap_or(yield_name);
280
281 if supports_node_yield {
282 let canonical = map_yield_to_canonical(yield_name);
283 if canonical == "node" {
284 expand_node_yield_fields(
285 col_name,
286 target_properties,
287 graph_ctx,
288 &mut fields,
289 );
290 continue;
291 }
292 if let Some(sig_field) = entry
298 .signature
299 .yields
300 .iter()
301 .find(|f| f.name() == canonical)
302 {
303 fields.push(field_from_signature(col_name, sig_field));
304 } else {
305 fields.push(Field::new(col_name, canonical_search_type(canonical), true));
306 }
307 continue;
308 }
309
310 let field = entry
314 .signature
315 .yields
316 .iter()
317 .find(|f| f.name() == yield_name.as_str())
318 .map(|f| field_from_signature(col_name, f))
319 .unwrap_or_else(|| Field::new(col_name, DataType::Utf8, true));
320 fields.push(field);
321 }
322 } else if let Some(registry) = graph_ctx.procedure_registry()
323 && let Some(proc_def) = registry.get(procedure_name)
324 {
325 for (name, alias) in yield_items {
326 let col_name = alias.as_ref().unwrap_or(name);
327 let data_type = proc_def
328 .outputs
329 .iter()
330 .find(|o| o.name == *name)
331 .map(|o| procedure_value_type_to_arrow(&o.output_type))
332 .unwrap_or(DataType::Utf8);
333 fields.push(Field::new(col_name, data_type, true));
334 }
335 } else if yield_items.is_empty() {
336 } else {
338 for (name, alias) in yield_items {
339 let col_name = alias.as_ref().unwrap_or(name);
340 fields.push(Field::new(col_name, DataType::Utf8, true));
341 }
342 }
343
344 Arc::new(Schema::new(fields))
345 }
346}
347
348pub(crate) fn value_type_to_arrow(vt: &uni_algo::algo::procedures::ValueType) -> DataType {
350 use uni_algo::algo::procedures::ValueType;
351 match vt {
352 ValueType::Int => DataType::Int64,
353 ValueType::Float => DataType::Float64,
354 ValueType::String => DataType::Utf8,
355 ValueType::Bool => DataType::Boolean,
356 ValueType::List
357 | ValueType::Map
358 | ValueType::Node
359 | ValueType::Relationship
360 | ValueType::Path
361 | ValueType::Any => DataType::Utf8,
362 }
363}
364
365pub(crate) fn is_complex_value_type(vt: &uni_algo::algo::procedures::ValueType) -> bool {
368 use uni_algo::algo::procedures::ValueType;
369 matches!(
370 vt,
371 ValueType::List
372 | ValueType::Map
373 | ValueType::Node
374 | ValueType::Relationship
375 | ValueType::Path
376 )
377}
378
379fn procedure_value_type_to_arrow(
381 vt: &crate::query::executor::procedure::ProcedureValueType,
382) -> DataType {
383 use crate::query::executor::procedure::ProcedureValueType;
384 match vt {
385 ProcedureValueType::Integer => DataType::Int64,
386 ProcedureValueType::Float | ProcedureValueType::Number => DataType::Float64,
387 ProcedureValueType::Boolean => DataType::Boolean,
388 ProcedureValueType::String | ProcedureValueType::Any => DataType::Utf8,
389 }
390}
391
392impl DisplayAs for GraphProcedureCallExec {
393 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394 write!(
395 f,
396 "GraphProcedureCallExec: procedure={}",
397 self.procedure_name
398 )
399 }
400}
401
402impl ExecutionPlan for GraphProcedureCallExec {
403 fn name(&self) -> &str {
404 "GraphProcedureCallExec"
405 }
406
407 fn as_any(&self) -> &dyn Any {
408 self
409 }
410
411 fn schema(&self) -> SchemaRef {
412 self.schema.clone()
413 }
414
415 fn properties(&self) -> &Arc<PlanProperties> {
416 &self.properties
417 }
418
419 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
420 vec![]
421 }
422
423 fn with_new_children(
424 self: Arc<Self>,
425 children: Vec<Arc<dyn ExecutionPlan>>,
426 ) -> DFResult<Arc<dyn ExecutionPlan>> {
427 if !children.is_empty() {
428 return Err(datafusion::error::DataFusionError::Internal(
429 "GraphProcedureCallExec has no children".to_string(),
430 ));
431 }
432 Ok(self)
433 }
434
435 fn execute(
436 &self,
437 partition: usize,
438 _context: Arc<TaskContext>,
439 ) -> DFResult<SendableRecordBatchStream> {
440 let metrics = BaselineMetrics::new(&self.metrics, partition);
441
442 let mut evaluated_args = Vec::with_capacity(self.arguments.len());
444 for arg in &self.arguments {
445 evaluated_args.push(evaluate_simple_expr(arg, &self.params, &self.outer_values)?);
446 }
447
448 Ok(Box::pin(ProcedureCallStream::new(
449 self.graph_ctx.clone(),
450 self.procedure_name.clone(),
451 evaluated_args,
452 self.yield_items.clone(),
453 self.target_properties.clone(),
454 self.schema.clone(),
455 metrics,
456 )))
457 }
458
459 fn metrics(&self) -> Option<MetricsSet> {
460 Some(self.metrics.clone_inner())
461 }
462}
463
/// Lifecycle of a [`ProcedureCallStream`]: `Init -> Executing -> Done`.
enum ProcedureCallState {
    /// Procedure not started; the first poll builds the execution future.
    Init,
    /// Procedure future in flight. It resolves to at most one output batch.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// Result or error already emitted; every further poll yields `None`.
    Done,
}
477
/// Stream that runs a single procedure call and yields its result as at most one batch.
struct ProcedureCallStream {
    /// Graph context handed to the procedure execution.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Procedure to call.
    procedure_name: String,
    /// Arguments already evaluated by `GraphProcedureCallExec::execute`.
    evaluated_args: Vec<Value>,
    /// `(yield name, optional alias)` pairs requested by the query.
    yield_items: Vec<(String, Option<String>)>,
    /// Output column name -> property names for node-valued yields.
    target_properties: HashMap<String, Vec<String>>,
    /// Output schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Current lifecycle state.
    state: ProcedureCallState,
    /// Baseline metrics; this stream records compute time and output rows.
    metrics: BaselineMetrics,
}
489
490impl ProcedureCallStream {
491 fn new(
492 graph_ctx: Arc<GraphExecutionContext>,
493 procedure_name: String,
494 evaluated_args: Vec<Value>,
495 yield_items: Vec<(String, Option<String>)>,
496 target_properties: HashMap<String, Vec<String>>,
497 schema: SchemaRef,
498 metrics: BaselineMetrics,
499 ) -> Self {
500 Self {
501 graph_ctx,
502 procedure_name,
503 evaluated_args,
504 yield_items,
505 target_properties,
506 schema,
507 state: ProcedureCallState::Init,
508 metrics,
509 }
510 }
511}
512
impl Stream for ProcedureCallStream {
    type Item = DFResult<RecordBatch>;

    /// Drives the single-shot state machine `Init -> Executing -> Done`.
    ///
    /// The procedure's whole result is emitted as at most one batch. After the batch
    /// (or an error) has been returned, every later poll returns `Poll::Ready(None)`.
    /// The time spent inside this method is recorded as elapsed compute.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Clone the metrics handle so the timer guard does not hold a borrow of `self`
        // while `self.state` is mutated below.
        let metrics = self.metrics.clone();
        let _timer = metrics.elapsed_compute().timer();
        loop {
            // Move the current state out, leaving `Done` as a placeholder. Every arm
            // that must continue explicitly writes the next state back.
            let state = std::mem::replace(&mut self.state, ProcedureCallState::Done);

            match state {
                ProcedureCallState::Init => {
                    // Clone the inputs so the future owns everything it uses.
                    let graph_ctx = self.graph_ctx.clone();
                    let procedure_name = self.procedure_name.clone();
                    let evaluated_args = self.evaluated_args.clone();
                    let yield_items = self.yield_items.clone();
                    let target_properties = self.target_properties.clone();
                    let schema = self.schema.clone();

                    let fut = async move {
                        // Abort before doing any work if the query has already timed out.
                        graph_ctx.check_timeout().map_err(|e| {
                            datafusion::error::DataFusionError::Execution(e.to_string())
                        })?;

                        execute_procedure(
                            &graph_ctx,
                            &procedure_name,
                            &evaluated_args,
                            &yield_items,
                            &target_properties,
                            &schema,
                        )
                        .await
                    };

                    self.state = ProcedureCallState::Executing(Box::pin(fut));
                    // Loop again so the new future is polled immediately.
                }
                ProcedureCallState::Executing(mut fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batch)) => {
                        self.state = ProcedureCallState::Done;
                        // `Ok(None)` (no batch) counts as zero output rows and ends the stream.
                        self.metrics
                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
                        return Poll::Ready(batch.map(Ok));
                    }
                    Poll::Ready(Err(e)) => {
                        self.state = ProcedureCallState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => {
                        // Put the in-flight future back; the waker in `cx` will re-poll us.
                        self.state = ProcedureCallState::Executing(fut);
                        return Poll::Pending;
                    }
                },
                ProcedureCallState::Done => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
572
573impl RecordBatchStream for ProcedureCallStream {
574 fn schema(&self) -> SchemaRef {
575 self.schema.clone()
576 }
577}
578
/// Executes a procedure and returns its result batch (if any).
///
/// If the registry resolves `procedure_name` to a plugin procedure via
/// `resolve_user_procedure`, the plugin path runs. Otherwise the call is delegated to
/// the registry-declared procedure path. That path reports a missing registry or an
/// unknown procedure as an error.
async fn execute_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    // Plugin procedures take precedence over registry-declared ones.
    if let Some(registry) = graph_ctx.procedure_registry()
        && let Some(entry) = registry.resolve_user_procedure(procedure_name)
    {
        return execute_plugin_procedure(
            graph_ctx,
            procedure_name,
            &entry,
            args,
            yield_items,
            target_properties,
            schema,
        )
        .await;
    }

    // `target_properties` is only consumed by the plugin path.
    execute_registered_procedure(graph_ctx, procedure_name, args, yield_items, schema).await
}
629
630async fn execute_plugin_procedure(
640 graph_ctx: &GraphExecutionContext,
641 procedure_name: &str,
642 entry: &uni_plugin::registry::ProcedureEntry,
643 args: &[Value],
644 yield_items: &[(String, Option<String>)],
645 target_properties: &HashMap<String, Vec<String>>,
646 schema: &SchemaRef,
647) -> DFResult<Option<RecordBatch>> {
648 use datafusion::logical_expr::ColumnarValue;
649 use futures::StreamExt;
650
651 let mut columnar_args: Vec<ColumnarValue> = Vec::with_capacity(args.len());
657 for v in args {
658 columnar_args.push(value_to_columnar(v).map_err(|e| {
659 datafusion::error::DataFusionError::Execution(format!(
660 "Procedure '{procedure_name}': argument conversion failed: {e}"
661 ))
662 })?);
663 }
664
665 let mut host =
666 crate::query::executor::procedure_host::QueryProcedureHost::from_graph_ctx_with_request(
667 graph_ctx,
668 target_properties.clone(),
669 yield_items.to_vec(),
670 Some(schema.clone()),
671 );
672 if let Some(writer) = graph_ctx.writer() {
677 host = host.with_writer(std::sync::Arc::clone(writer));
678 }
679 let principal = crate::current_principal();
686 let ctx = uni_plugin::host::build_procedure_context(&host, principal.as_deref());
687 let mut stream = entry.procedure.invoke(ctx, &columnar_args).map_err(|e| {
688 datafusion::error::DataFusionError::Execution(format!("Procedure '{procedure_name}': {e}"))
689 })?;
690
691 let mut batches: Vec<RecordBatch> = Vec::new();
695 while let Some(item) = stream.next().await {
696 let batch = item.map_err(|e| {
697 datafusion::error::DataFusionError::Execution(format!(
698 "Procedure '{procedure_name}' stream error: {e}"
699 ))
700 })?;
701 batches.push(batch);
702 }
703
704 if batches.is_empty() {
705 return Ok(Some(create_empty_batch(schema.clone())?));
708 }
709
710 let plugin_schema = batches[0].schema();
713 let combined = if batches.len() == 1 {
714 batches.pop().unwrap()
715 } else {
716 arrow::compute::concat_batches(&plugin_schema, &batches).map_err(arrow_err)?
717 };
718
719 if combined.schema().fields() == schema.fields() {
723 return Ok(Some(combined));
724 }
725
726 if yield_items.is_empty()
729 || (yield_items.len() == combined.num_columns()
730 && yield_items
731 .iter()
732 .zip(combined.schema().fields().iter())
733 .all(|((name, _alias), field)| name == field.name()))
734 {
735 return Ok(Some(combined));
736 }
737
738 let mut projected_cols: Vec<ArrayRef> = Vec::with_capacity(yield_items.len());
739 let mut projected_fields: Vec<Field> = Vec::with_capacity(yield_items.len());
740 for (name, _alias) in yield_items {
741 let idx = combined.schema().index_of(name).map_err(|_| {
742 datafusion::error::DataFusionError::Execution(format!(
743 "Procedure '{procedure_name}': YIELD column `{name}` not in plugin output schema"
744 ))
745 })?;
746 projected_cols.push(combined.column(idx).clone());
747 projected_fields.push(combined.schema().field(idx).clone());
748 }
749 let projected_schema = Arc::new(Schema::new(projected_fields));
750 let projected = RecordBatch::try_new(projected_schema, projected_cols).map_err(arrow_err)?;
751 Ok(Some(projected))
752}
753
754pub(crate) fn value_to_columnar(
758 v: &Value,
759) -> Result<datafusion::logical_expr::ColumnarValue, String> {
760 use datafusion::logical_expr::ColumnarValue;
761 use datafusion::scalar::ScalarValue;
762
763 let scalar = match v {
764 Value::Null => ScalarValue::Null,
765 Value::Bool(b) => ScalarValue::Boolean(Some(*b)),
766 Value::Int(i) => ScalarValue::Int64(Some(*i)),
767 Value::Float(f) => ScalarValue::Float64(Some(*f)),
768 Value::String(s) => ScalarValue::Utf8(Some(s.clone())),
769 Value::Bytes(b) => ScalarValue::Binary(Some(b.clone())),
770 other => {
771 let json = serde_json::to_vec(other)
778 .map_err(|e| format!("plugin arg encoding failed for {other:?}: {e}"))?;
779 ScalarValue::LargeBinary(Some(json))
780 }
781 };
782 Ok(ColumnarValue::Scalar(scalar))
783}
784
785pub(crate) fn build_typed_column<'a>(
799 values: impl Iterator<Item = Option<&'a Value>>,
800 num_rows: usize,
801 data_type: &DataType,
802) -> ArrayRef {
803 match data_type {
804 DataType::UInt64 => {
805 let mut builder = arrow_array::builder::UInt64Builder::with_capacity(num_rows);
806 for val in values {
807 match val.and_then(uni_common::Value::as_u64) {
808 Some(u) => builder.append_value(u),
809 None => builder.append_null(),
810 }
811 }
812 Arc::new(builder.finish())
813 }
814 DataType::Struct(fields) if is_edge_struct_shape(fields) => {
815 build_edge_struct_column(values, num_rows, fields)
816 }
817 DataType::Int64 => {
818 let mut builder = Int64Builder::with_capacity(num_rows);
819 for val in values {
820 match val.and_then(|v| v.as_i64()) {
821 Some(i) => builder.append_value(i),
822 None => builder.append_null(),
823 }
824 }
825 Arc::new(builder.finish())
826 }
827 DataType::Float64 => {
828 let mut builder = Float64Builder::with_capacity(num_rows);
829 for val in values {
830 match val.and_then(|v| v.as_f64()) {
831 Some(f) => builder.append_value(f),
832 None => builder.append_null(),
833 }
834 }
835 Arc::new(builder.finish())
836 }
837 DataType::Boolean => {
838 let mut builder = BooleanBuilder::with_capacity(num_rows);
839 for val in values {
840 match val.and_then(|v| v.as_bool()) {
841 Some(b) => builder.append_value(b),
842 None => builder.append_null(),
843 }
844 }
845 Arc::new(builder.finish())
846 }
847 _ => {
848 let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
850 for val in values {
851 match val {
852 Some(Value::String(s)) => builder.append_value(s),
853 Some(v) => builder.append_value(format!("{v}")),
854 None => builder.append_null(),
855 }
856 }
857 Arc::new(builder.finish())
858 }
859 }
860}
861
862fn is_edge_struct_shape(fields: &arrow_schema::Fields) -> bool {
868 let names: std::collections::HashSet<&str> = fields.iter().map(|f| f.name().as_str()).collect();
869 names.contains("_eid")
870 && names.contains("_type_name")
871 && names.contains("_src")
872 && names.contains("_dst")
873 && names.contains("properties")
874}
875
876fn build_edge_struct_column<'a>(
881 values: impl Iterator<Item = Option<&'a Value>>,
882 _num_rows: usize,
883 fields: &arrow_schema::Fields,
884) -> ArrayRef {
885 use arrow_array::builder::{LargeBinaryBuilder, StringBuilder, UInt64Builder};
886 use uni_common::Value as V;
887
888 let mut eid_b = UInt64Builder::new();
889 let mut type_b = StringBuilder::new();
890 let mut src_b = UInt64Builder::new();
891 let mut dst_b = UInt64Builder::new();
892 let mut props_b = LargeBinaryBuilder::new();
893 let mut validity: Vec<bool> = Vec::new();
894
895 for val in values {
896 match val {
897 Some(V::Edge(e)) => {
898 eid_b.append_value(e.eid.as_u64());
899 type_b.append_value(&e.edge_type);
900 src_b.append_value(e.src.as_u64());
901 dst_b.append_value(e.dst.as_u64());
902 let props_value = V::Map(e.properties.clone());
903 let bytes = uni_common::cypher_value_codec::encode(&props_value);
904 props_b.append_value(&bytes);
905 validity.push(true);
906 }
907 _ => {
908 eid_b.append_null();
909 type_b.append_null();
910 src_b.append_null();
911 dst_b.append_null();
912 props_b.append_null();
913 validity.push(false);
914 }
915 }
916 }
917
918 let arrays: Vec<ArrayRef> = vec![
919 Arc::new(eid_b.finish()),
920 Arc::new(type_b.finish()),
921 Arc::new(src_b.finish()),
922 Arc::new(dst_b.finish()),
923 Arc::new(props_b.finish()),
924 ];
925 let canonical: [&str; 5] = ["_eid", "_type_name", "_src", "_dst", "properties"];
930 let mut ordered: Vec<ArrayRef> = Vec::with_capacity(fields.len());
931 for f in fields.iter() {
932 let idx = canonical
933 .iter()
934 .position(|n| *n == f.name().as_str())
935 .expect("is_edge_struct_shape vetted these field names");
936 ordered.push(arrays[idx].clone());
937 }
938 let nulls = arrow::buffer::NullBuffer::from(validity);
939 Arc::new(
940 arrow_array::StructArray::try_new(fields.clone(), ordered, Some(nulls))
941 .expect("StructArray construction with vetted shape"),
942 )
943}
944
945pub(crate) fn create_empty_batch(schema: SchemaRef) -> DFResult<RecordBatch> {
951 if schema.fields().is_empty() {
952 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
953 RecordBatch::try_new_with_options(schema, vec![], &options).map_err(arrow_err)
954 } else {
955 Ok(RecordBatch::new_empty(schema))
956 }
957}
958
959async fn execute_registered_procedure(
968 graph_ctx: &GraphExecutionContext,
969 procedure_name: &str,
970 args: &[Value],
971 yield_items: &[(String, Option<String>)],
972 schema: &SchemaRef,
973) -> DFResult<Option<RecordBatch>> {
974 let registry = graph_ctx.procedure_registry().ok_or_else(|| {
975 datafusion::error::DataFusionError::Execution(format!(
976 "Procedure '{}' not supported in DataFusion engine (no procedure registry)",
977 procedure_name
978 ))
979 })?;
980
981 let proc_def = registry.get(procedure_name).ok_or_else(|| {
982 datafusion::error::DataFusionError::Execution(format!(
983 "ProcedureNotFound: Unknown procedure '{}'",
984 procedure_name
985 ))
986 })?;
987
988 if args.len() != proc_def.params.len() {
990 return Err(datafusion::error::DataFusionError::Execution(format!(
991 "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
992 proc_def.name,
993 proc_def.params.len(),
994 args.len()
995 )));
996 }
997
998 for (i, (arg_val, param)) in args.iter().zip(&proc_def.params).enumerate() {
1000 if !arg_val.is_null() && !check_proc_type_compatible(arg_val, ¶m.param_type) {
1001 return Err(datafusion::error::DataFusionError::Execution(format!(
1002 "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
1003 i, param.name, proc_def.name
1004 )));
1005 }
1006 }
1007
1008 let filtered: Vec<&HashMap<String, Value>> = proc_def
1010 .data
1011 .iter()
1012 .filter(|row| {
1013 for (param, arg_val) in proc_def.params.iter().zip(args) {
1014 if let Some(row_val) = row.get(¶m.name)
1015 && !proc_values_match(row_val, arg_val)
1016 {
1017 return false;
1018 }
1019 }
1020 true
1021 })
1022 .collect();
1023
1024 if yield_items.is_empty() {
1026 return Ok(Some(create_empty_batch(schema.clone())?));
1027 }
1028
1029 if filtered.is_empty() {
1030 return Ok(Some(create_empty_batch(schema.clone())?));
1031 }
1032
1033 let num_rows = filtered.len();
1036 let mut columns: Vec<ArrayRef> = Vec::new();
1037
1038 for (idx, (name, _alias)) in yield_items.iter().enumerate() {
1039 let field = schema.field(idx);
1040 let values = filtered.iter().map(|row| row.get(name.as_str()));
1041 columns.push(build_typed_column(values, num_rows, field.data_type()));
1042 }
1043
1044 let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
1045 Ok(Some(batch))
1046}
1047
1048fn check_proc_type_compatible(
1050 val: &Value,
1051 expected: &crate::query::executor::procedure::ProcedureValueType,
1052) -> bool {
1053 use crate::query::executor::procedure::ProcedureValueType;
1054 match expected {
1055 ProcedureValueType::Any => true,
1056 ProcedureValueType::String => val.is_string(),
1057 ProcedureValueType::Boolean => val.is_bool(),
1058 ProcedureValueType::Integer => val.is_i64(),
1059 ProcedureValueType::Float => val.is_f64() || val.is_i64(),
1060 ProcedureValueType::Number => val.is_number(),
1061 }
1062}
1063
1064fn proc_values_match(row_val: &Value, arg_val: &Value) -> bool {
1066 if arg_val.is_null() || row_val.is_null() {
1067 return arg_val.is_null() && row_val.is_null();
1068 }
1069 if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
1071 return (a - b).abs() < f64::EPSILON;
1072 }
1073 row_val == arg_val
1074}
1075
1076pub(crate) fn json_to_value(jv: &serde_json::Value) -> Value {
1078 match jv {
1079 serde_json::Value::Null => Value::Null,
1080 serde_json::Value::Bool(b) => Value::Bool(*b),
1081 serde_json::Value::Number(n) => {
1082 if let Some(i) = n.as_i64() {
1083 Value::Int(i)
1084 } else if let Some(f) = n.as_f64() {
1085 Value::Float(f)
1086 } else {
1087 Value::Null
1088 }
1089 }
1090 serde_json::Value::String(s) => Value::String(s.clone()),
1091 other => Value::String(other.to_string()),
1092 }
1093}
1094
1095pub(crate) fn require_string_arg(
1101 args: &[Value],
1102 index: usize,
1103 description: &str,
1104) -> DFResult<String> {
1105 args.get(index)
1106 .and_then(|v| v.as_str())
1107 .map(|s| s.to_string())
1108 .ok_or_else(|| {
1109 datafusion::error::DataFusionError::Execution(format!("{description} must be a string"))
1110 })
1111}
1112
1113pub(crate) fn extract_optional_filter(args: &[Value], index: usize) -> Option<String> {
1116 args.get(index).and_then(|v| {
1117 if v.is_null() {
1118 None
1119 } else {
1120 v.as_str().map(|s| s.to_string())
1121 }
1122 })
1123}