Skip to main content

uni_query/query/df_graph/
procedure_call.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Procedure call execution plan for DataFusion.
5//!
6//! This module provides [`GraphProcedureCallExec`], a DataFusion [`ExecutionPlan`] that
7//! executes Cypher `CALL` procedures natively within the DataFusion engine.
8//!
9//! Used for composite queries where a `CALL` is followed by `MATCH`, e.g.:
10//! ```text
11//! CALL uni.schema.labels() YIELD label
12//! MATCH (n:Person) WHERE label = 'Person'
13//! RETURN n.name, label
14//! ```
15
16use arrow_array::builder::{
17    BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, StringBuilder, UInt64Builder,
18};
19use arrow_array::{ArrayRef, RecordBatch};
20use arrow_schema::{DataType, Field, Schema, SchemaRef};
21use datafusion::common::Result as DFResult;
22use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
23use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
24use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
25use futures::Stream;
26use std::any::Any;
27use std::collections::HashMap;
28use std::fmt;
29use std::pin::Pin;
30use std::sync::Arc;
31use std::task::{Context, Poll};
32use uni_common::Value;
33use uni_common::core::id::Vid;
34use uni_common::core::schema::DistanceMetric;
35use uni_cypher::ast::Expr;
36
37use crate::query::df_graph::GraphExecutionContext;
38use crate::query::df_graph::common::{
39    arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
40};
41use crate::query::df_graph::scan::resolve_property_type;
42
/// Maps a user-provided yield name to a canonical name.
///
/// Matching is case-insensitive:
/// - "vid", "_vid" → "vid"
/// - "distance", "dist", "_distance" → "distance"
/// - "score", "_score" → "score"
/// - "vector_score", "fts_score", "raw_score" → themselves
/// - anything else → "node" (treated as node variable)
pub(crate) fn map_yield_to_canonical(yield_name: &str) -> String {
    // Score-family names that pass through unchanged.
    const PASSTHROUGH: [&str; 3] = ["vector_score", "fts_score", "raw_score"];

    let lowered = yield_name.to_lowercase();
    if matches!(lowered.as_str(), "vid" | "_vid") {
        "vid".to_string()
    } else if matches!(lowered.as_str(), "distance" | "dist" | "_distance") {
        "distance".to_string()
    } else if matches!(lowered.as_str(), "score" | "_score") {
        "score".to_string()
    } else if PASSTHROUGH.contains(&lowered.as_str()) {
        // Already canonical once lowercased.
        lowered
    } else {
        "node".to_string()
    }
}
61
/// Procedure call execution plan for DataFusion.
///
/// Executes Cypher CALL procedures (schema introspection, vector search, FTS, etc.)
/// and emits results as Arrow RecordBatches.
///
/// The output schema is computed eagerly at construction (see `build_schema`);
/// the procedure itself runs lazily when the plan is executed and polled.
pub struct GraphProcedureCallExec {
    /// Graph execution context for storage access.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Fully qualified procedure name (e.g. "uni.schema.labels").
    procedure_name: String,

    /// Argument expressions from the CALL clause (evaluated at execute time
    /// against `params`).
    arguments: Vec<Expr>,

    /// Yield items: (original_name, optional_alias).
    yield_items: Vec<(String, Option<String>)>,

    /// Query parameters for expression evaluation.
    params: HashMap<String, Value>,

    /// Target properties per variable (for node-like yields).
    target_properties: HashMap<String, Vec<String>>,

    /// Output schema.
    schema: SchemaRef,

    /// Plan properties.
    properties: PlanProperties,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
94
// Manual Debug impl: only the cheap, informative fields are printed;
// context, schema, properties, and metrics are intentionally omitted.
impl fmt::Debug for GraphProcedureCallExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GraphProcedureCallExec")
            .field("procedure_name", &self.procedure_name)
            .field("yield_items", &self.yield_items)
            .finish()
    }
}
103
impl GraphProcedureCallExec {
    /// Create a new procedure call execution plan.
    ///
    /// The output schema and plan properties are derived immediately from the
    /// procedure name and yield items; no storage I/O happens here.
    pub fn new(
        graph_ctx: Arc<GraphExecutionContext>,
        procedure_name: String,
        arguments: Vec<Expr>,
        yield_items: Vec<(String, Option<String>)>,
        params: HashMap<String, Value>,
        target_properties: HashMap<String, Vec<String>>,
    ) -> Self {
        let schema = Self::build_schema(
            &procedure_name,
            &yield_items,
            &target_properties,
            &graph_ctx,
        );
        let properties = compute_plan_properties(schema.clone());

        Self {
            graph_ctx,
            procedure_name,
            arguments,
            yield_items,
            params,
            target_properties,
            schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

    /// Build the output schema based on the procedure name and yield items.
    ///
    /// Column names come from the yield alias when present, otherwise the
    /// yield name itself. Unknown yield names fall back to nullable Utf8 so
    /// the plan still produces a valid schema.
    fn build_schema(
        procedure_name: &str,
        yield_items: &[(String, Option<String>)],
        target_properties: &HashMap<String, Vec<String>>,
        graph_ctx: &GraphExecutionContext,
    ) -> SchemaRef {
        let mut fields = Vec::new();

        match procedure_name {
            "uni.schema.labels" => {
                // Schema procedure yields scalar columns
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "label" => DataType::Utf8,
                        "propertyCount" | "nodeCount" | "indexCount" => DataType::Int64,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "type" | "relationshipType" => DataType::Utf8,
                        "propertyCount" => DataType::Int64,
                        "sourceLabels" | "targetLabels" => DataType::Utf8, // JSON string
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.indexes" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    // All index columns are string-typed ("properties" is a
                    // JSON-encoded list).
                    let data_type = match name.as_str() {
                        "name" | "type" | "label" | "state" | "properties" => DataType::Utf8,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.constraints" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "enabled" => DataType::Boolean,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.labelInfo" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "property" | "dataType" => DataType::Utf8,
                        "nullable" | "indexed" | "unique" => DataType::Boolean,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.vector.query" | "uni.fts.query" | "uni.search" => {
                // Search procedures yield node-like and scalar columns
                for (name, alias) in yield_items {
                    let output_name = alias.as_ref().unwrap_or(name);
                    let canonical = map_yield_to_canonical(name);

                    match canonical.as_str() {
                        "node" => {
                            // Node-like yield: emit _vid, variable, _label, and properties.
                            // `_vid` and the variable column are non-nullable; everything
                            // else in this schema is nullable.
                            fields.push(Field::new(
                                format!("{}._vid", output_name),
                                DataType::UInt64,
                                false,
                            ));
                            fields.push(Field::new(output_name, DataType::Utf8, false));
                            fields.push(Field::new(
                                format!("{}._labels", output_name),
                                labels_data_type(),
                                true,
                            ));

                            // Add property columns
                            if let Some(props) = target_properties.get(output_name.as_str()) {
                                let uni_schema = graph_ctx.storage().schema_manager().schema();
                                // We don't know the exact label yet at planning time,
                                // but we can try to resolve property types from any label
                                for prop_name in props {
                                    let col_name = format!("{}.{}", output_name, prop_name);
                                    let arrow_type = resolve_property_type(prop_name, None);
                                    // Try to resolve from all labels in the schema;
                                    // first label that declares the property wins.
                                    let resolved_type = uni_schema
                                        .properties
                                        .values()
                                        .find_map(|label_props| {
                                            label_props.get(prop_name.as_str()).map(|_| {
                                                resolve_property_type(prop_name, Some(label_props))
                                            })
                                        })
                                        .unwrap_or(arrow_type);
                                    fields.push(Field::new(&col_name, resolved_type, true));
                                }
                            }
                        }
                        "distance" => {
                            fields.push(Field::new(output_name, DataType::Float64, true));
                        }
                        "score" | "vector_score" | "fts_score" | "raw_score" => {
                            fields.push(Field::new(output_name, DataType::Float32, true));
                        }
                        "vid" => {
                            // NOTE(review): standalone `vid` yields Int64 while node
                            // `_vid` columns above are UInt64 — confirm this asymmetry
                            // is intentional (e.g. Cypher integer semantics).
                            fields.push(Field::new(output_name, DataType::Int64, true));
                        }
                        _ => {
                            fields.push(Field::new(output_name, DataType::Utf8, true));
                        }
                    }
                }
            }
            name if name.starts_with("uni.algo.") => {
                // Graph-algorithm procedures: resolve yield types from the
                // registered procedure's signature when available.
                if let Some(registry) = graph_ctx.algo_registry()
                    && let Some(procedure) = registry.get(name)
                {
                    let sig = procedure.signature();
                    for (yield_name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(yield_name);
                        let yield_vt = sig.yields.iter().find(|(n, _)| *n == yield_name.as_str());
                        let data_type = yield_vt
                            .map(|(_, vt)| value_type_to_arrow(vt))
                            .unwrap_or(DataType::Utf8);
                        let mut field = Field::new(col_name, data_type, true);
                        // Tag complex types (List, Map, etc.) so record_batches_to_rows
                        // can parse the JSON string back to the original type.
                        if yield_vt.is_some_and(|(_, vt)| is_complex_value_type(vt)) {
                            let mut metadata = std::collections::HashMap::new();
                            metadata.insert("cv_encoded".to_string(), "true".to_string());
                            field = field.with_metadata(metadata);
                        }
                        fields.push(field);
                    }
                } else {
                    // Unknown algo or no registry: fallback to Utf8
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        fields.push(Field::new(col_name, DataType::Utf8, true));
                    }
                }
            }
            _ => {
                // Check external procedure registry for type information
                if let Some(registry) = graph_ctx.procedure_registry()
                    && let Some(proc_def) = registry.get(procedure_name)
                {
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        // Find the output type from the procedure definition
                        let data_type = proc_def
                            .outputs
                            .iter()
                            .find(|o| o.name == *name)
                            .map(|o| procedure_value_type_to_arrow(&o.output_type))
                            .unwrap_or(DataType::Utf8);
                        fields.push(Field::new(col_name, data_type, true));
                    }
                } else if yield_items.is_empty() {
                    // Void procedure (no YIELD) — no output columns
                } else {
                    // Unknown procedure without registry: fallback to Utf8
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        fields.push(Field::new(col_name, DataType::Utf8, true));
                    }
                }
            }
        }

        Arc::new(Schema::new(fields))
    }
}
318
319/// Convert an algorithm `ValueType` to an Arrow `DataType`.
320fn value_type_to_arrow(vt: &uni_algo::algo::procedures::ValueType) -> DataType {
321    use uni_algo::algo::procedures::ValueType;
322    match vt {
323        ValueType::Int => DataType::Int64,
324        ValueType::Float => DataType::Float64,
325        ValueType::String => DataType::Utf8,
326        ValueType::Bool => DataType::Boolean,
327        ValueType::List
328        | ValueType::Map
329        | ValueType::Node
330        | ValueType::Relationship
331        | ValueType::Path
332        | ValueType::Any => DataType::Utf8,
333    }
334}
335
336/// Returns true if the ValueType is a complex type that should be JSON-encoded as Utf8
337/// and tagged with `cv_encoded=true` metadata for downstream parsing.
338fn is_complex_value_type(vt: &uni_algo::algo::procedures::ValueType) -> bool {
339    use uni_algo::algo::procedures::ValueType;
340    matches!(
341        vt,
342        ValueType::List
343            | ValueType::Map
344            | ValueType::Node
345            | ValueType::Relationship
346            | ValueType::Path
347    )
348}
349
350/// Convert a `ProcedureValueType` to an Arrow `DataType`.
351fn procedure_value_type_to_arrow(
352    vt: &crate::query::executor::procedure::ProcedureValueType,
353) -> DataType {
354    use crate::query::executor::procedure::ProcedureValueType;
355    match vt {
356        ProcedureValueType::Integer => DataType::Int64,
357        ProcedureValueType::Float | ProcedureValueType::Number => DataType::Float64,
358        ProcedureValueType::Boolean => DataType::Boolean,
359        ProcedureValueType::String | ProcedureValueType::Any => DataType::Utf8,
360    }
361}
362
// EXPLAIN rendering: the same one-line form is used for every display
// format (the `DisplayFormatType` argument is intentionally ignored).
impl DisplayAs for GraphProcedureCallExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "GraphProcedureCallExec: procedure={}",
            self.procedure_name
        )
    }
}
372
impl ExecutionPlan for GraphProcedureCallExec {
    fn name(&self) -> &str {
        "GraphProcedureCallExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: a procedure call has no input plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Optimizer rewrites must not attach children to a leaf.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Internal(
                "GraphProcedureCallExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Start execution: evaluate CALL arguments eagerly (so bad arguments
    /// fail before any stream work), then hand everything to the stream,
    /// which runs the procedure on first poll.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Evaluate arguments upfront; the first failure aborts execution.
        let mut evaluated_args = Vec::with_capacity(self.arguments.len());
        for arg in &self.arguments {
            evaluated_args.push(evaluate_simple_expr(arg, &self.params)?);
        }

        Ok(Box::pin(ProcedureCallStream::new(
            self.graph_ctx.clone(),
            self.procedure_name.clone(),
            evaluated_args,
            self.yield_items.clone(),
            self.target_properties.clone(),
            self.schema.clone(),
            metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
434
435// ---------------------------------------------------------------------------
436// Stream implementation
437// ---------------------------------------------------------------------------
438
/// State machine for procedure call stream.
///
/// Transitions linearly: `Init` → `Executing` (future created on first poll)
/// → `Done` (after the single result — or an error — has been emitted).
enum ProcedureCallState {
    /// Initial state, ready to start execution.
    Init,
    /// Executing the async procedure. Boxed because the future type is
    /// unnameable and must be `Send` for DataFusion's stream contract.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// Stream is done.
    Done,
}
448
/// Stream that executes a procedure call.
///
/// Yields at most one `RecordBatch` (the procedure's full, materialized
/// result) and then terminates.
struct ProcedureCallStream {
    /// Shared storage/execution context, cloned from the plan.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Fully qualified procedure name used for dispatch.
    procedure_name: String,
    /// CALL arguments, already evaluated to concrete values.
    evaluated_args: Vec<Value>,
    /// Yield items: (original_name, optional_alias).
    yield_items: Vec<(String, Option<String>)>,
    /// Requested properties per node-like yield variable.
    target_properties: HashMap<String, Vec<String>>,
    /// Output schema; every emitted batch must match it.
    schema: SchemaRef,
    /// Current position in the Init → Executing → Done state machine.
    state: ProcedureCallState,
    /// Baseline metrics (records output row counts).
    metrics: BaselineMetrics,
}
460
461impl ProcedureCallStream {
462    fn new(
463        graph_ctx: Arc<GraphExecutionContext>,
464        procedure_name: String,
465        evaluated_args: Vec<Value>,
466        yield_items: Vec<(String, Option<String>)>,
467        target_properties: HashMap<String, Vec<String>>,
468        schema: SchemaRef,
469        metrics: BaselineMetrics,
470    ) -> Self {
471        Self {
472            graph_ctx,
473            procedure_name,
474            evaluated_args,
475            yield_items,
476            target_properties,
477            schema,
478            state: ProcedureCallState::Init,
479            metrics,
480        }
481    }
482}
483
impl Stream for ProcedureCallStream {
    type Item = DFResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Take the state by value, leaving `Done` in its place. Every arm
            // below writes the correct next state back before returning, so
            // `Done` only "sticks" if an arm exits early (e.g. on error).
            let state = std::mem::replace(&mut self.state, ProcedureCallState::Done);

            match state {
                ProcedureCallState::Init => {
                    // Clone everything the future captures so it owns its data
                    // ('static) independent of `self`.
                    let graph_ctx = self.graph_ctx.clone();
                    let procedure_name = self.procedure_name.clone();
                    let evaluated_args = self.evaluated_args.clone();
                    let yield_items = self.yield_items.clone();
                    let target_properties = self.target_properties.clone();
                    let schema = self.schema.clone();

                    let fut = async move {
                        // Abort early if the query deadline has already passed.
                        graph_ctx.check_timeout().map_err(|e| {
                            datafusion::error::DataFusionError::Execution(e.to_string())
                        })?;

                        execute_procedure(
                            &graph_ctx,
                            &procedure_name,
                            &evaluated_args,
                            &yield_items,
                            &target_properties,
                            &schema,
                        )
                        .await
                    };

                    // Loop again so the new future is polled immediately.
                    self.state = ProcedureCallState::Executing(Box::pin(fut));
                }
                ProcedureCallState::Executing(mut fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batch)) => {
                        // Single-shot: after emitting the batch (or None for a
                        // void procedure) the stream is exhausted.
                        self.state = ProcedureCallState::Done;
                        self.metrics
                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
                        return Poll::Ready(batch.map(Ok));
                    }
                    Poll::Ready(Err(e)) => {
                        self.state = ProcedureCallState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => {
                        // Put the in-flight future back so the next poll resumes it.
                        self.state = ProcedureCallState::Executing(fut);
                        return Poll::Pending;
                    }
                },
                ProcedureCallState::Done => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
541
542impl RecordBatchStream for ProcedureCallStream {
543    fn schema(&self) -> SchemaRef {
544        self.schema.clone()
545    }
546}
547
548// ---------------------------------------------------------------------------
549// Procedure execution dispatch
550// ---------------------------------------------------------------------------
551
/// Execute a procedure and build a RecordBatch result.
///
/// Dispatches on the fully qualified name: built-in schema procedures,
/// search procedures (vector/FTS/hybrid), `uni.algo.*` algorithms, and
/// finally the external procedure registry for everything else.
/// Returns `Ok(None)` when the procedure yields no rows.
async fn execute_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    match procedure_name {
        "uni.schema.labels" => execute_schema_labels(graph_ctx, yield_items, schema).await,
        "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => {
            execute_schema_edge_types(graph_ctx, yield_items, schema).await
        }
        "uni.schema.indexes" => execute_schema_indexes(graph_ctx, yield_items, schema).await,
        "uni.schema.constraints" => {
            execute_schema_constraints(graph_ctx, yield_items, schema).await
        }
        "uni.schema.labelInfo" => {
            execute_schema_label_info(graph_ctx, args, yield_items, schema).await
        }
        "uni.vector.query" => {
            execute_vector_query(graph_ctx, args, yield_items, target_properties, schema).await
        }
        "uni.fts.query" => {
            execute_fts_query(graph_ctx, args, yield_items, target_properties, schema).await
        }
        "uni.search" => {
            execute_hybrid_search(graph_ctx, args, yield_items, target_properties, schema).await
        }
        name if name.starts_with("uni.algo.") => {
            execute_algo_procedure(graph_ctx, name, args, yield_items, schema).await
        }
        _ => {
            execute_registered_procedure(graph_ctx, procedure_name, args, yield_items, schema).await
        }
    }
}
590
591// ---------------------------------------------------------------------------
592// Schema procedures
593// ---------------------------------------------------------------------------
594
595async fn execute_schema_labels(
596    graph_ctx: &GraphExecutionContext,
597    yield_items: &[(String, Option<String>)],
598    schema: &SchemaRef,
599) -> DFResult<Option<RecordBatch>> {
600    let uni_schema = graph_ctx.storage().schema_manager().schema();
601    let storage = graph_ctx.storage();
602
603    // Collect rows: one per label
604    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
605    for label_name in uni_schema.labels.keys() {
606        let mut row = HashMap::new();
607        row.insert("label".to_string(), Value::String(label_name.clone()));
608
609        let prop_count = uni_schema
610            .properties
611            .get(label_name)
612            .map(|p| p.len())
613            .unwrap_or(0);
614        row.insert("propertyCount".to_string(), Value::Int(prop_count as i64));
615
616        let node_count = if let Ok(ds) = storage.vertex_dataset(label_name) {
617            if let Ok(raw) = ds.open_raw().await {
618                raw.count_rows(None).await.unwrap_or(0)
619            } else {
620                0
621            }
622        } else {
623            0
624        };
625        row.insert("nodeCount".to_string(), Value::Int(node_count as i64));
626
627        let idx_count = uni_schema
628            .indexes
629            .iter()
630            .filter(|i| i.label() == label_name)
631            .count();
632        row.insert("indexCount".to_string(), Value::Int(idx_count as i64));
633
634        rows.push(row);
635    }
636
637    build_scalar_batch(&rows, yield_items, schema)
638}
639
640async fn execute_schema_edge_types(
641    graph_ctx: &GraphExecutionContext,
642    yield_items: &[(String, Option<String>)],
643    schema: &SchemaRef,
644) -> DFResult<Option<RecordBatch>> {
645    let uni_schema = graph_ctx.storage().schema_manager().schema();
646
647    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
648    for (type_name, meta) in &uni_schema.edge_types {
649        let mut row = HashMap::new();
650        row.insert("type".to_string(), Value::String(type_name.clone()));
651        row.insert(
652            "relationshipType".to_string(),
653            Value::String(type_name.clone()),
654        );
655        row.insert(
656            "sourceLabels".to_string(),
657            Value::String(format!("{:?}", meta.src_labels)),
658        );
659        row.insert(
660            "targetLabels".to_string(),
661            Value::String(format!("{:?}", meta.dst_labels)),
662        );
663
664        let prop_count = uni_schema
665            .properties
666            .get(type_name)
667            .map(|p| p.len())
668            .unwrap_or(0);
669        row.insert("propertyCount".to_string(), Value::Int(prop_count as i64));
670
671        rows.push(row);
672    }
673
674    build_scalar_batch(&rows, yield_items, schema)
675}
676
/// `uni.schema.indexes`: emit one row per index definition with its name,
/// kind, label, state, and JSON-encoded property list.
async fn execute_schema_indexes(
    graph_ctx: &GraphExecutionContext,
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let uni_schema = graph_ctx.storage().schema_manager().schema();

    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
    for idx in &uni_schema.indexes {
        use uni_common::core::schema::IndexDefinition;

        // Extract type name and properties JSON per variant
        // (single-property variants are wrapped in a one-element list so the
        // "properties" column is uniformly a JSON array).
        let (type_name, properties_json) = match &idx {
            IndexDefinition::Vector(v) => (
                "VECTOR",
                serde_json::to_string(&[&v.property]).unwrap_or_default(),
            ),
            IndexDefinition::FullText(f) => (
                "FULLTEXT",
                serde_json::to_string(&f.properties).unwrap_or_default(),
            ),
            IndexDefinition::Scalar(s) => (
                "SCALAR",
                serde_json::to_string(&s.properties).unwrap_or_default(),
            ),
            IndexDefinition::JsonFullText(j) => (
                "JSON_FTS",
                serde_json::to_string(&[&j.column]).unwrap_or_default(),
            ),
            IndexDefinition::Inverted(inv) => (
                "INVERTED",
                serde_json::to_string(&[&inv.property]).unwrap_or_default(),
            ),
            _ => ("UNKNOWN", String::new()),
        };

        // NOTE(review): "state" is hard-coded to ONLINE — confirm index build
        // state is not tracked elsewhere.
        let row = HashMap::from([
            ("state".to_string(), Value::String("ONLINE".to_string())),
            ("name".to_string(), Value::String(idx.name().to_string())),
            ("type".to_string(), Value::String(type_name.to_string())),
            ("label".to_string(), Value::String(idx.label().to_string())),
            ("properties".to_string(), Value::String(properties_json)),
        ]);
        rows.push(row);
    }

    build_scalar_batch(&rows, yield_items, schema)
}
725
726async fn execute_schema_constraints(
727    graph_ctx: &GraphExecutionContext,
728    yield_items: &[(String, Option<String>)],
729    schema: &SchemaRef,
730) -> DFResult<Option<RecordBatch>> {
731    let uni_schema = graph_ctx.storage().schema_manager().schema();
732
733    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
734    for c in &uni_schema.constraints {
735        let mut row = HashMap::new();
736        row.insert("name".to_string(), Value::String(c.name.clone()));
737        row.insert("enabled".to_string(), Value::Bool(c.enabled));
738
739        match &c.constraint_type {
740            uni_common::core::schema::ConstraintType::Unique { properties } => {
741                row.insert("type".to_string(), Value::String("UNIQUE".to_string()));
742                row.insert(
743                    "properties".to_string(),
744                    Value::String(serde_json::to_string(&properties).unwrap_or_default()),
745                );
746            }
747            uni_common::core::schema::ConstraintType::Exists { property } => {
748                row.insert("type".to_string(), Value::String("EXISTS".to_string()));
749                row.insert(
750                    "properties".to_string(),
751                    Value::String(serde_json::to_string(&[&property]).unwrap_or_default()),
752                );
753            }
754            uni_common::core::schema::ConstraintType::Check { expression } => {
755                row.insert("type".to_string(), Value::String("CHECK".to_string()));
756                row.insert("expression".to_string(), Value::String(expression.clone()));
757            }
758            _ => {
759                row.insert("type".to_string(), Value::String("UNKNOWN".to_string()));
760            }
761        }
762
763        match &c.target {
764            uni_common::core::schema::ConstraintTarget::Label(l) => {
765                row.insert("label".to_string(), Value::String(l.clone()));
766            }
767            uni_common::core::schema::ConstraintTarget::EdgeType(t) => {
768                row.insert("relationshipType".to_string(), Value::String(t.clone()));
769            }
770            _ => {
771                row.insert("target".to_string(), Value::String("UNKNOWN".to_string()));
772            }
773        }
774
775        rows.push(row);
776    }
777
778    build_scalar_batch(&rows, yield_items, schema)
779}
780
781async fn execute_schema_label_info(
782    graph_ctx: &GraphExecutionContext,
783    args: &[Value],
784    yield_items: &[(String, Option<String>)],
785    schema: &SchemaRef,
786) -> DFResult<Option<RecordBatch>> {
787    let label_name = require_string_arg(args, 0, "uni.schema.labelInfo: first argument (label)")?;
788
789    let uni_schema = graph_ctx.storage().schema_manager().schema();
790
791    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
792    if let Some(props) = uni_schema.properties.get(&label_name) {
793        for (prop_name, prop_meta) in props {
794            let mut row = HashMap::new();
795            row.insert("property".to_string(), Value::String(prop_name.clone()));
796            row.insert(
797                "dataType".to_string(),
798                Value::String(format!("{:?}", prop_meta.r#type)),
799            );
800            row.insert("nullable".to_string(), Value::Bool(prop_meta.nullable));
801
802            let is_indexed = uni_schema.indexes.iter().any(|idx| match idx {
803                uni_common::core::schema::IndexDefinition::Vector(v) => {
804                    v.label == label_name && v.property == *prop_name
805                }
806                uni_common::core::schema::IndexDefinition::Scalar(s) => {
807                    s.label == label_name && s.properties.contains(prop_name)
808                }
809                uni_common::core::schema::IndexDefinition::FullText(f) => {
810                    f.label == label_name && f.properties.contains(prop_name)
811                }
812                uni_common::core::schema::IndexDefinition::Inverted(inv) => {
813                    inv.label == label_name && inv.property == *prop_name
814                }
815                uni_common::core::schema::IndexDefinition::JsonFullText(j) => j.label == label_name,
816                _ => false,
817            });
818            row.insert("indexed".to_string(), Value::Bool(is_indexed));
819
820            let unique = uni_schema.constraints.iter().any(|c| {
821                if let uni_common::core::schema::ConstraintTarget::Label(l) = &c.target
822                    && l == &label_name
823                    && c.enabled
824                    && let uni_common::core::schema::ConstraintType::Unique { properties } =
825                        &c.constraint_type
826                {
827                    return properties.contains(prop_name);
828                }
829                false
830            });
831            row.insert("unique".to_string(), Value::Bool(unique));
832
833            rows.push(row);
834        }
835    }
836
837    build_scalar_batch(&rows, yield_items, schema)
838}
839
840/// Build a typed Arrow column from an iterator of optional `Value`s.
841///
842/// Dispatches on `data_type` to build the appropriate Arrow array. For types
843/// not explicitly handled (Utf8 fallback), values are stringified.
844fn build_typed_column<'a>(
845    values: impl Iterator<Item = Option<&'a Value>>,
846    num_rows: usize,
847    data_type: &DataType,
848) -> ArrayRef {
849    match data_type {
850        DataType::Int64 => {
851            let mut builder = Int64Builder::with_capacity(num_rows);
852            for val in values {
853                match val.and_then(|v| v.as_i64()) {
854                    Some(i) => builder.append_value(i),
855                    None => builder.append_null(),
856                }
857            }
858            Arc::new(builder.finish())
859        }
860        DataType::Float64 => {
861            let mut builder = Float64Builder::with_capacity(num_rows);
862            for val in values {
863                match val.and_then(|v| v.as_f64()) {
864                    Some(f) => builder.append_value(f),
865                    None => builder.append_null(),
866                }
867            }
868            Arc::new(builder.finish())
869        }
870        DataType::Boolean => {
871            let mut builder = BooleanBuilder::with_capacity(num_rows);
872            for val in values {
873                match val.and_then(|v| v.as_bool()) {
874                    Some(b) => builder.append_value(b),
875                    None => builder.append_null(),
876                }
877            }
878            Arc::new(builder.finish())
879        }
880        _ => {
881            // Utf8 fallback: stringify values
882            let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
883            for val in values {
884                match val {
885                    Some(Value::String(s)) => builder.append_value(s),
886                    Some(v) => builder.append_value(format!("{v}")),
887                    None => builder.append_null(),
888                }
889            }
890            Arc::new(builder.finish())
891        }
892    }
893}
894
895/// Create an empty RecordBatch for the given schema.
896///
897/// When a schema has zero fields, `RecordBatch::new_empty()` panics because it
898/// cannot determine the row count from an empty array. This helper handles that
899/// edge case by using `RecordBatchOptions::with_row_count(0)`.
900fn create_empty_batch(schema: SchemaRef) -> DFResult<RecordBatch> {
901    if schema.fields().is_empty() {
902        let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
903        RecordBatch::try_new_with_options(schema, vec![], &options).map_err(arrow_err)
904    } else {
905        Ok(RecordBatch::new_empty(schema))
906    }
907}
908
909/// Build a RecordBatch from scalar-valued rows for schema procedures.
910fn build_scalar_batch(
911    rows: &[HashMap<String, Value>],
912    yield_items: &[(String, Option<String>)],
913    schema: &SchemaRef,
914) -> DFResult<Option<RecordBatch>> {
915    if rows.is_empty() {
916        return Ok(Some(create_empty_batch(schema.clone())?));
917    }
918
919    let num_rows = rows.len();
920    let mut columns: Vec<ArrayRef> = Vec::new();
921
922    for (idx, (name, _alias)) in yield_items.iter().enumerate() {
923        let field = schema.field(idx);
924        let values = rows.iter().map(|row| row.get(name));
925        columns.push(build_typed_column(values, num_rows, field.data_type()));
926    }
927
928    let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
929    Ok(Some(batch))
930}
931
932// ---------------------------------------------------------------------------
933// External/registered procedures
934// ---------------------------------------------------------------------------
935
/// Execute an externally registered procedure (e.g., TCK test procedures).
///
/// Looks up `procedure_name` in the `ProcedureRegistry`, validates the
/// argument count and types against the declared parameters, filters the
/// procedure's static data rows by matching input columns against `args`,
/// and projects the columns named by `yield_items` into a `RecordBatch`
/// matching `schema`.
///
/// Error messages use TCK-style prefixes (`ProcedureNotFound`,
/// `InvalidNumberOfArguments`, `InvalidArgumentType`) so conformance tests
/// can match on them — do not reword them casually.
async fn execute_registered_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    // The registry is optional on the context; without it no registered
    // procedure can run at all.
    let registry = graph_ctx.procedure_registry().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "Procedure '{}' not supported in DataFusion engine (no procedure registry)",
            procedure_name
        ))
    })?;

    let proc_def = registry.get(procedure_name).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "ProcedureNotFound: Unknown procedure '{}'",
            procedure_name
        ))
    })?;

    // Validate argument count
    if args.len() != proc_def.params.len() {
        return Err(datafusion::error::DataFusionError::Execution(format!(
            "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
            proc_def.name,
            proc_def.params.len(),
            args.len()
        )));
    }

    // Validate argument types; a null argument is accepted for any declared type.
    for (i, (arg_val, param)) in args.iter().zip(&proc_def.params).enumerate() {
        if !arg_val.is_null() && !check_proc_type_compatible(arg_val, &param.param_type) {
            return Err(datafusion::error::DataFusionError::Execution(format!(
                "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
                i, param.name, proc_def.name
            )));
        }
    }

    // Filter data rows: keep rows where input columns match the provided args.
    // A row that lacks a given parameter column is unconstrained and kept.
    let filtered: Vec<&HashMap<String, Value>> = proc_def
        .data
        .iter()
        .filter(|row| {
            for (param, arg_val) in proc_def.params.iter().zip(args) {
                if let Some(row_val) = row.get(&param.name)
                    && !proc_values_match(row_val, arg_val)
                {
                    return false;
                }
            }
            true
        })
        .collect();

    // If the procedure has no yield items (void procedure), return empty batch
    if yield_items.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    if filtered.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Project output columns based on yield items
    // We need to map yield names back to output column names in the procedure definition
    let num_rows = filtered.len();
    let mut columns: Vec<ArrayRef> = Vec::new();

    for (idx, (name, _alias)) in yield_items.iter().enumerate() {
        let field = schema.field(idx);
        // Yield names absent from a row become nulls inside `build_typed_column`.
        let values = filtered.iter().map(|row| row.get(name.as_str()));
        columns.push(build_typed_column(values, num_rows, field.data_type()));
    }

    let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
    Ok(Some(batch))
}
1020
1021/// Checks whether a value is compatible with a procedure type (DF engine version).
1022fn check_proc_type_compatible(
1023    val: &Value,
1024    expected: &crate::query::executor::procedure::ProcedureValueType,
1025) -> bool {
1026    use crate::query::executor::procedure::ProcedureValueType;
1027    match expected {
1028        ProcedureValueType::Any => true,
1029        ProcedureValueType::String => val.is_string(),
1030        ProcedureValueType::Boolean => val.is_bool(),
1031        ProcedureValueType::Integer => val.is_i64(),
1032        ProcedureValueType::Float => val.is_f64() || val.is_i64(),
1033        ProcedureValueType::Number => val.is_number(),
1034    }
1035}
1036
1037/// Checks whether two values match for input-column filtering (DF engine version).
1038fn proc_values_match(row_val: &Value, arg_val: &Value) -> bool {
1039    if arg_val.is_null() || row_val.is_null() {
1040        return arg_val.is_null() && row_val.is_null();
1041    }
1042    // Compare numbers by f64 to handle int/float cross-comparison
1043    if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
1044        return (a - b).abs() < f64::EPSILON;
1045    }
1046    row_val == arg_val
1047}
1048
1049// ---------------------------------------------------------------------------
1050// Algorithm procedures
1051// ---------------------------------------------------------------------------
1052
1053async fn execute_algo_procedure(
1054    graph_ctx: &GraphExecutionContext,
1055    procedure_name: &str,
1056    args: &[Value],
1057    yield_items: &[(String, Option<String>)],
1058    schema: &SchemaRef,
1059) -> DFResult<Option<RecordBatch>> {
1060    use futures::StreamExt;
1061    use uni_algo::algo::procedures::AlgoContext;
1062
1063    let registry = graph_ctx.algo_registry().ok_or_else(|| {
1064        datafusion::error::DataFusionError::Execution(
1065            "Algorithm registry not available".to_string(),
1066        )
1067    })?;
1068
1069    let procedure = registry.get(procedure_name).ok_or_else(|| {
1070        datafusion::error::DataFusionError::Execution(format!(
1071            "Unknown algorithm: {}",
1072            procedure_name
1073        ))
1074    })?;
1075
1076    let signature = procedure.signature();
1077
1078    // Convert uni_common::Value args to serde_json::Value for algo crate.
1079    // Note: do NOT call validate_args here — the procedure's own execute()
1080    // already validates and fills defaults internally.
1081    let serde_args: Vec<serde_json::Value> = args.iter().cloned().map(|v| v.into()).collect();
1082
1083    // Build AlgoContext — no L0Manager in the DF path (read-only snapshot)
1084    let algo_ctx = AlgoContext::new(graph_ctx.storage().clone(), None);
1085
1086    // Execute and collect stream
1087    let mut stream = procedure.execute(algo_ctx, serde_args);
1088    let mut rows = Vec::new();
1089    while let Some(row_res) = stream.next().await {
1090        // Check timeout periodically
1091        if rows.len() % 1000 == 0 {
1092            graph_ctx
1093                .check_timeout()
1094                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
1095        }
1096        let row =
1097            row_res.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
1098        rows.push(row);
1099    }
1100
1101    build_algo_batch(&rows, &signature, yield_items, schema)
1102}
1103
1104/// Convert a `serde_json::Value` to a `uni_common::Value` for column building.
1105fn json_to_value(jv: &serde_json::Value) -> Value {
1106    match jv {
1107        serde_json::Value::Null => Value::Null,
1108        serde_json::Value::Bool(b) => Value::Bool(*b),
1109        serde_json::Value::Number(n) => {
1110            if let Some(i) = n.as_i64() {
1111                Value::Int(i)
1112            } else if let Some(f) = n.as_f64() {
1113                Value::Float(f)
1114            } else {
1115                Value::Null
1116            }
1117        }
1118        serde_json::Value::String(s) => Value::String(s.clone()),
1119        other => Value::String(other.to_string()),
1120    }
1121}
1122
1123/// Build a RecordBatch from algorithm result rows.
1124fn build_algo_batch(
1125    rows: &[uni_algo::algo::procedures::AlgoResultRow],
1126    signature: &uni_algo::algo::procedures::ProcedureSignature,
1127    yield_items: &[(String, Option<String>)],
1128    schema: &SchemaRef,
1129) -> DFResult<Option<RecordBatch>> {
1130    if rows.is_empty() {
1131        return Ok(Some(create_empty_batch(schema.clone())?));
1132    }
1133
1134    let num_rows = rows.len();
1135    let mut columns: Vec<ArrayRef> = Vec::new();
1136
1137    for (idx, (yield_name, _alias)) in yield_items.iter().enumerate() {
1138        let sig_idx = signature
1139            .yields
1140            .iter()
1141            .position(|(n, _)| *n == yield_name.as_str());
1142
1143        // Convert serde_json values to uni_common::Value for the shared column builder
1144        let uni_values: Vec<Value> = rows
1145            .iter()
1146            .map(|row| match sig_idx {
1147                Some(si) => json_to_value(&row.values[si]),
1148                None => Value::Null,
1149            })
1150            .collect();
1151
1152        let field = schema.field(idx);
1153        let values = uni_values.iter().map(Some);
1154        columns.push(build_typed_column(values, num_rows, field.data_type()));
1155    }
1156
1157    let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
1158    Ok(Some(batch))
1159}
1160
1161// ---------------------------------------------------------------------------
1162// Shared search argument helpers
1163// ---------------------------------------------------------------------------
1164
1165/// Extract a required string argument from the argument list at a given position.
1166fn require_string_arg(args: &[Value], index: usize, description: &str) -> DFResult<String> {
1167    args.get(index)
1168        .and_then(|v| v.as_str())
1169        .map(|s| s.to_string())
1170        .ok_or_else(|| {
1171            datafusion::error::DataFusionError::Execution(format!("{description} must be a string"))
1172        })
1173}
1174
1175/// Extract an optional filter string from the argument list.
1176/// Returns `None` if the argument is missing, null, or not a string.
1177fn extract_optional_filter(args: &[Value], index: usize) -> Option<String> {
1178    args.get(index).and_then(|v| {
1179        if v.is_null() {
1180            None
1181        } else {
1182            v.as_str().map(|s| s.to_string())
1183        }
1184    })
1185}
1186
1187/// Extract an optional float threshold from the argument list.
1188/// Returns `None` if the argument is missing or null.
1189fn extract_optional_threshold(args: &[Value], index: usize) -> Option<f64> {
1190    args.get(index)
1191        .and_then(|v| if v.is_null() { None } else { v.as_f64() })
1192}
1193
1194/// Extract a required integer argument from the argument list at a given position.
1195fn require_int_arg(args: &[Value], index: usize, description: &str) -> DFResult<usize> {
1196    args.get(index)
1197        .and_then(|v| v.as_u64())
1198        .map(|v| v as usize)
1199        .ok_or_else(|| {
1200            datafusion::error::DataFusionError::Execution(format!(
1201                "{description} must be an integer"
1202            ))
1203        })
1204}
1205
1206// ---------------------------------------------------------------------------
1207// Vector/FTS/Hybrid search procedures
1208// ---------------------------------------------------------------------------
1209
1210/// Auto-embed a text query using the vector index's embedding configuration.
1211///
1212/// Looks up the embedding config from the index on `label.property` and uses
1213/// it to embed the provided text query into a vector.
1214async fn auto_embed_text(
1215    graph_ctx: &GraphExecutionContext,
1216    label: &str,
1217    property: &str,
1218    query_text: &str,
1219) -> DFResult<Vec<f32>> {
1220    let storage = graph_ctx.storage();
1221    let uni_schema = storage.schema_manager().schema();
1222    let index_config = uni_schema.vector_index_for_property(label, property);
1223
1224    let embedding_config = index_config
1225        .and_then(|cfg| cfg.embedding_config.as_ref())
1226        .ok_or_else(|| {
1227            datafusion::error::DataFusionError::Execution(format!(
1228                "Cannot auto-embed: vector index for {label}.{property} has no embedding_config. \
1229                 Either provide a pre-computed vector or create the index with embedding options."
1230            ))
1231        })?;
1232
1233    let runtime = graph_ctx.xervo_runtime().ok_or_else(|| {
1234        datafusion::error::DataFusionError::Execution(
1235            "Cannot auto-embed: Uni-Xervo runtime not configured".to_string(),
1236        )
1237    })?;
1238
1239    let embedder = runtime
1240        .embedding(&embedding_config.alias)
1241        .await
1242        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
1243    let embeddings = embedder
1244        .embed(vec![query_text])
1245        .await
1246        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
1247    embeddings.into_iter().next().ok_or_else(|| {
1248        datafusion::error::DataFusionError::Execution(
1249            "Embedding service returned no results".to_string(),
1250        )
1251    })
1252}
1253
/// Execute `uni.vector.query(label, property, query, k [, filter [, threshold]])`.
///
/// `query` may be a pre-computed numeric vector or a text string; text is
/// auto-embedded via the index's embedding configuration ([`auto_embed_text`]).
/// ANN results are `(vid, distance)` pairs, optionally post-filtered by
/// `threshold` (a MAXIMUM distance — contrast with the minimum-score
/// threshold of `uni.fts.query`), then handed to `build_search_result_batch`
/// together with the index's distance metric for score computation.
async fn execute_vector_query(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.vector.query: first argument (label)")?;
    let property = require_string_arg(args, 1, "uni.vector.query: second argument (property)")?;

    let query_val = args.get(2).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "uni.vector.query: third argument (query) is required".to_string(),
        )
    })?;

    let storage = graph_ctx.storage();

    // A string query is auto-embedded; anything else must parse as a vector.
    let query_vector: Vec<f32> = if let Some(query_text) = query_val.as_str() {
        auto_embed_text(graph_ctx, &label, &property, query_text).await?
    } else {
        extract_vector(query_val)?
    };

    let k = require_int_arg(args, 3, "uni.vector.query: fourth argument (k)")?;
    let filter = extract_optional_filter(args, 4);
    let threshold = extract_optional_threshold(args, 5);
    let query_ctx = graph_ctx.query_context();

    let mut results = storage
        .vector_search(
            &label,
            &property,
            &query_vector,
            k,
            filter.as_deref(),
            Some(&query_ctx),
        )
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;

    // Apply threshold post-filter (on distance): keep hits at or below it.
    if let Some(max_dist) = threshold {
        results.retain(|(_, dist)| *dist <= max_dist as f32);
    }

    if results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Calculate scores using the same logic as the old executor; unknown
    // indexes fall back to the L2 metric.
    let schema_manager = storage.schema_manager();
    let uni_schema = schema_manager.schema();
    let metric = uni_schema
        .vector_index_for_property(&label, &property)
        .map(|config| config.metric.clone())
        .unwrap_or(DistanceMetric::L2);

    build_search_result_batch(
        &results,
        &label,
        &metric,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1323
1324// ---------------------------------------------------------------------------
1325// FTS search procedure
1326// ---------------------------------------------------------------------------
1327
/// Execute `uni.fts.query(label, property, search_term, k [, filter [, threshold]])`,
/// a full-text search over a single label/property.
///
/// Results are `(vid, score)` pairs. `threshold`, when present, is a MINIMUM
/// score (hits below it are dropped) — the opposite sense of the maximum
/// distance threshold in `uni.vector.query`.
async fn execute_fts_query(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.fts.query: first argument (label)")?;
    let property = require_string_arg(args, 1, "uni.fts.query: second argument (property)")?;
    let search_term = require_string_arg(args, 2, "uni.fts.query: third argument (search_term)")?;
    let k = require_int_arg(args, 3, "uni.fts.query: fourth argument (k)")?;
    let filter = extract_optional_filter(args, 4);
    let threshold = extract_optional_threshold(args, 5);

    let storage = graph_ctx.storage();
    let query_ctx = graph_ctx.query_context();

    let mut results = storage
        .fts_search(
            &label,
            &property,
            &search_term,
            k,
            filter.as_deref(),
            Some(&query_ctx),
        )
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;

    // Minimum-score cutoff (see doc comment above).
    if let Some(min_score) = threshold {
        results.retain(|(_, score)| *score as f64 >= min_score);
    }

    if results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // FTS uses a "fake" L2 metric for the batch builder — scores are already BM25
    // We use L2 as a placeholder; the actual score column is built differently.
    build_search_result_batch(
        &results,
        &label,
        &DistanceMetric::L2,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1378
1379// ---------------------------------------------------------------------------
1380// Hybrid search procedure
1381// ---------------------------------------------------------------------------
1382
/// Execute `uni.search(label, properties, query_text, query_vector, k [, filter [, options]])`,
/// a hybrid vector + full-text search with rank fusion.
///
/// `properties` selects the property per modality: either an object
/// `{vector: '...', fts: '...'}` or a plain string used for both legs. Each
/// configured leg is over-fetched (`over_fetch` option, default 2.0 × `k`),
/// the two ranked lists are fused via RRF (`method: "rrf"`, the default,
/// smoothing constant `rrf_k` default 60) or weighted fusion
/// (`method: "weighted"`, `alpha` default 0.5), and the top `k` fused hits
/// are materialized through `build_hybrid_search_batch`.
async fn execute_hybrid_search(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.search: first argument (label)")?;

    // Parse properties: {vector: '...', fts: '...'} or just a string
    let properties_val = args.get(1).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "uni.search: second argument (properties) is required".to_string(),
        )
    })?;

    let (vector_prop, fts_prop) = if let Some(obj) = properties_val.as_object() {
        let vec_prop = obj
            .get("vector")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let fts_prop = obj
            .get("fts")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        (vec_prop, fts_prop)
    } else if let Some(prop) = properties_val.as_str() {
        // Shorthand: just property name means both vector and FTS
        (Some(prop.to_string()), Some(prop.to_string()))
    } else {
        return Err(datafusion::error::DataFusionError::Execution(
            "Properties must be an object {vector: '...', fts: '...'} or a string".to_string(),
        ));
    };

    let query_text = require_string_arg(args, 2, "uni.search: third argument (query_text)")?;

    // Arg 3: query vector (optional, can be null)
    let query_vector: Option<Vec<f32>> = args.get(3).and_then(|v| {
        if v.is_null() {
            return None;
        }
        v.as_array().map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect()
        })
    });

    let k = require_int_arg(args, 4, "uni.search: fifth argument (k)")?;
    let filter = extract_optional_filter(args, 5);

    // Arg 6: options (optional) — fusion method and tuning knobs
    let options_val = args.get(6);
    let options_map = options_val.and_then(|v| v.as_object());
    let fusion_method = options_map
        .and_then(|m| m.get("method"))
        .and_then(|v| v.as_str())
        .unwrap_or("rrf")
        .to_string();
    let alpha = options_map
        .and_then(|m| m.get("alpha"))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.5) as f32;
    let over_fetch_factor = options_map
        .and_then(|m| m.get("over_fetch"))
        .and_then(|v| v.as_f64())
        .unwrap_or(2.0) as f32;
    let rrf_k = options_map
        .and_then(|m| m.get("rrf_k"))
        .and_then(|v| v.as_u64())
        .unwrap_or(60) as usize;

    // Over-fetch each leg so fusion has enough candidates to pick k from.
    let over_fetch_k = (k as f32 * over_fetch_factor).ceil() as usize;

    let storage = graph_ctx.storage();
    let query_ctx = graph_ctx.query_context();

    // Execute vector search if configured
    let mut vector_results: Vec<(Vid, f32)> = Vec::new();
    if let Some(ref vec_prop) = vector_prop {
        // Get or generate query vector
        let qvec = if let Some(ref v) = query_vector {
            v.clone()
        } else {
            // Auto-embed the query text if embedding config exists. On failure
            // this falls back to an empty vector, which skips the vector leg
            // below (best-effort: FTS alone can still answer).
            auto_embed_text(graph_ctx, &label, vec_prop, &query_text)
                .await
                .unwrap_or_default()
        };

        if !qvec.is_empty() {
            vector_results = storage
                .vector_search(
                    &label,
                    vec_prop,
                    &qvec,
                    over_fetch_k,
                    filter.as_deref(),
                    Some(&query_ctx),
                )
                .await
                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        }
    }

    // Execute FTS search if configured
    let mut fts_results: Vec<(Vid, f32)> = Vec::new();
    if let Some(ref fts_prop) = fts_prop {
        fts_results = storage
            .fts_search(
                &label,
                fts_prop,
                &query_text,
                over_fetch_k,
                filter.as_deref(),
                Some(&query_ctx),
            )
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
    }

    // Fuse results; any method other than "weighted" falls back to RRF.
    let fused_results = match fusion_method.as_str() {
        "weighted" => fuse_weighted(&vector_results, &fts_results, alpha),
        _ => fuse_rrf(&vector_results, &fts_results, rrf_k),
    };

    // Limit to k results
    let final_results: Vec<_> = fused_results.into_iter().take(k).collect();

    if final_results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Build lookup maps for original (pre-fusion) per-leg scores
    let vec_score_map: HashMap<Vid, f32> = vector_results.iter().cloned().collect();
    let fts_score_map: HashMap<Vid, f32> = fts_results.iter().cloned().collect();
    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);

    // Get distance metric for vector score normalization
    let uni_schema = storage.schema_manager().schema();
    let metric = vector_prop
        .as_ref()
        .and_then(|vp| {
            uni_schema
                .vector_index_for_property(&label, vp)
                .map(|config| config.metric.clone())
        })
        .unwrap_or(DistanceMetric::L2);

    let score_ctx = HybridScoreContext {
        vec_score_map: &vec_score_map,
        fts_score_map: &fts_score_map,
        fts_max,
        metric: &metric,
    };

    build_hybrid_search_batch(
        &final_results,
        &score_ctx,
        &label,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1552
/// Reciprocal Rank Fusion (RRF) for combining search results.
///
/// `k` here is the RRF smoothing constant (the `rrf_k` option, default 60 at
/// the call site), NOT a result limit. Delegates to the shared `fusion` module.
fn fuse_rrf(vec_results: &[(Vid, f32)], fts_results: &[(Vid, f32)], k: usize) -> Vec<(Vid, f32)> {
    crate::query::fusion::fuse_rrf(vec_results, fts_results, k)
}
1558
/// Weighted fusion: `alpha * vec_score + (1 - alpha) * fts_score`.
///
/// `alpha` comes from the `alpha` search option (default 0.5 at the call
/// site). Delegates to the shared `fusion` module.
fn fuse_weighted(
    vec_results: &[(Vid, f32)],
    fts_results: &[(Vid, f32)],
    alpha: f32,
) -> Vec<(Vid, f32)> {
    crate::query::fusion::fuse_weighted(vec_results, fts_results, alpha)
}
1568
/// Precomputed score context for hybrid search batch building.
struct HybridScoreContext<'a> {
    /// Per-vertex results from the vector leg (raw distances), keyed by vid.
    vec_score_map: &'a HashMap<Vid, f32>,
    /// Per-vertex results from the FTS leg (raw scores), keyed by vid.
    fts_score_map: &'a HashMap<Vid, f32>,
    /// Maximum FTS score over all FTS hits (0.0 when the FTS leg was empty).
    fts_max: f32,
    /// Distance metric of the vector index, used for score normalization.
    metric: &'a DistanceMetric,
}
1576
1577/// Build a RecordBatch for hybrid search results with fused, vector, and FTS scores.
1578async fn build_hybrid_search_batch(
1579    results: &[(Vid, f32)],
1580    scores: &HybridScoreContext<'_>,
1581    label: &str,
1582    yield_items: &[(String, Option<String>)],
1583    target_properties: &HashMap<String, Vec<String>>,
1584    graph_ctx: &GraphExecutionContext,
1585    schema: &SchemaRef,
1586) -> DFResult<Option<RecordBatch>> {
1587    let num_rows = results.len();
1588    let vids: Vec<Vid> = results.iter().map(|(vid, _)| *vid).collect();
1589    let fused_scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
1590
1591    // Pre-load properties for node-like yields
1592    let property_manager = graph_ctx.property_manager();
1593    let query_ctx = graph_ctx.query_context();
1594    let uni_schema = graph_ctx.storage().schema_manager().schema();
1595    let label_props = uni_schema.properties.get(label);
1596
1597    let has_node_yield = yield_items
1598        .iter()
1599        .any(|(name, _)| map_yield_to_canonical(name) == "node");
1600
1601    let props_map = if has_node_yield {
1602        property_manager
1603            .get_batch_vertex_props_for_label(&vids, label, Some(&query_ctx))
1604            .await
1605            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?
1606    } else {
1607        HashMap::new()
1608    };
1609
1610    let mut columns: Vec<ArrayRef> = Vec::new();
1611
1612    for (name, alias) in yield_items {
1613        let output_name = alias.as_ref().unwrap_or(name);
1614        let canonical = map_yield_to_canonical(name);
1615
1616        match canonical.as_str() {
1617            "node" => {
1618                columns.extend(build_node_yield_columns(
1619                    &vids,
1620                    label,
1621                    output_name,
1622                    target_properties,
1623                    &props_map,
1624                    label_props,
1625                )?);
1626            }
1627            "vid" => {
1628                let mut builder = Int64Builder::with_capacity(num_rows);
1629                for vid in &vids {
1630                    builder.append_value(vid.as_u64() as i64);
1631                }
1632                columns.push(Arc::new(builder.finish()));
1633            }
1634            "score" => {
1635                let mut builder = Float32Builder::with_capacity(num_rows);
1636                for score in &fused_scores {
1637                    builder.append_value(*score);
1638                }
1639                columns.push(Arc::new(builder.finish()));
1640            }
1641            "vector_score" => {
1642                let mut builder = Float32Builder::with_capacity(num_rows);
1643                for vid in &vids {
1644                    if let Some(&dist) = scores.vec_score_map.get(vid) {
1645                        let score = calculate_score(dist, scores.metric);
1646                        builder.append_value(score);
1647                    } else {
1648                        builder.append_null();
1649                    }
1650                }
1651                columns.push(Arc::new(builder.finish()));
1652            }
1653            "fts_score" => {
1654                let mut builder = Float32Builder::with_capacity(num_rows);
1655                for vid in &vids {
1656                    if let Some(&raw_score) = scores.fts_score_map.get(vid) {
1657                        let norm = if scores.fts_max > 0.0 {
1658                            raw_score / scores.fts_max
1659                        } else {
1660                            0.0
1661                        };
1662                        builder.append_value(norm);
1663                    } else {
1664                        builder.append_null();
1665                    }
1666                }
1667                columns.push(Arc::new(builder.finish()));
1668            }
1669            "distance" => {
1670                // For hybrid search, distance is the vector distance if available
1671                let mut builder = Float64Builder::with_capacity(num_rows);
1672                for vid in &vids {
1673                    if let Some(&dist) = scores.vec_score_map.get(vid) {
1674                        builder.append_value(dist as f64);
1675                    } else {
1676                        builder.append_null();
1677                    }
1678                }
1679                columns.push(Arc::new(builder.finish()));
1680            }
1681            _ => {
1682                let mut builder = StringBuilder::with_capacity(num_rows, 0);
1683                for _ in 0..num_rows {
1684                    builder.append_null();
1685                }
1686                columns.push(Arc::new(builder.finish()));
1687            }
1688        }
1689    }
1690
1691    let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
1692    Ok(Some(batch))
1693}
1694
1695// ---------------------------------------------------------------------------
1696// Shared search result batch builder
1697// ---------------------------------------------------------------------------
1698
1699/// Build a RecordBatch for search procedures (vector, FTS) that yield
1700/// both node-like and scalar columns.
1701async fn build_search_result_batch(
1702    results: &[(Vid, f32)],
1703    label: &str,
1704    metric: &DistanceMetric,
1705    yield_items: &[(String, Option<String>)],
1706    target_properties: &HashMap<String, Vec<String>>,
1707    graph_ctx: &GraphExecutionContext,
1708    schema: &SchemaRef,
1709) -> DFResult<Option<RecordBatch>> {
1710    let num_rows = results.len();
1711    let vids: Vec<Vid> = results.iter().map(|(vid, _)| *vid).collect();
1712    let distances: Vec<f32> = results.iter().map(|(_, d)| *d).collect();
1713
1714    // Pre-compute scores
1715    let scores: Vec<f32> = distances
1716        .iter()
1717        .map(|dist| calculate_score(*dist, metric))
1718        .collect();
1719
1720    // Pre-load properties for all node-like yields
1721    let property_manager = graph_ctx.property_manager();
1722    let query_ctx = graph_ctx.query_context();
1723    let uni_schema = graph_ctx.storage().schema_manager().schema();
1724    let label_props = uni_schema.properties.get(label);
1725
1726    // Load properties if any node-like yield needs them
1727    let has_node_yield = yield_items
1728        .iter()
1729        .any(|(name, _)| map_yield_to_canonical(name) == "node");
1730
1731    let props_map = if has_node_yield {
1732        property_manager
1733            .get_batch_vertex_props_for_label(&vids, label, Some(&query_ctx))
1734            .await
1735            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?
1736    } else {
1737        HashMap::new()
1738    };
1739
1740    // Build columns in schema order
1741    let mut columns: Vec<ArrayRef> = Vec::new();
1742
1743    for (name, alias) in yield_items {
1744        let output_name = alias.as_ref().unwrap_or(name);
1745        let canonical = map_yield_to_canonical(name);
1746
1747        match canonical.as_str() {
1748            "node" => {
1749                columns.extend(build_node_yield_columns(
1750                    &vids,
1751                    label,
1752                    output_name,
1753                    target_properties,
1754                    &props_map,
1755                    label_props,
1756                )?);
1757            }
1758            "distance" => {
1759                let mut builder = Float64Builder::with_capacity(num_rows);
1760                for dist in &distances {
1761                    builder.append_value(*dist as f64);
1762                }
1763                columns.push(Arc::new(builder.finish()));
1764            }
1765            "score" => {
1766                let mut builder = Float32Builder::with_capacity(num_rows);
1767                for score in &scores {
1768                    builder.append_value(*score);
1769                }
1770                columns.push(Arc::new(builder.finish()));
1771            }
1772            "vid" => {
1773                let mut builder = Int64Builder::with_capacity(num_rows);
1774                for vid in &vids {
1775                    builder.append_value(vid.as_u64() as i64);
1776                }
1777                columns.push(Arc::new(builder.finish()));
1778            }
1779            _ => {
1780                // Unknown yield — emit nulls
1781                let mut builder = StringBuilder::with_capacity(num_rows, 0);
1782                for _ in 0..num_rows {
1783                    builder.append_null();
1784                }
1785                columns.push(Arc::new(builder.finish()));
1786            }
1787        }
1788    }
1789
1790    let batch = RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)?;
1791    Ok(Some(batch))
1792}
1793
1794// ---------------------------------------------------------------------------
1795// Helpers
1796// ---------------------------------------------------------------------------
1797
1798/// Build the node-yield columns (_vid, variable, _labels, property columns) shared by
1799/// search result batch builders. Returns the columns to append.
1800fn build_node_yield_columns(
1801    vids: &[Vid],
1802    label: &str,
1803    output_name: &str,
1804    target_properties: &HashMap<String, Vec<String>>,
1805    props_map: &HashMap<Vid, uni_common::Properties>,
1806    label_props: Option<&std::collections::HashMap<String, uni_common::core::schema::PropertyMeta>>,
1807) -> DFResult<Vec<ArrayRef>> {
1808    let num_rows = vids.len();
1809    let mut columns = Vec::new();
1810
1811    // _vid column
1812    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
1813    for vid in vids {
1814        vid_builder.append_value(vid.as_u64());
1815    }
1816    columns.push(Arc::new(vid_builder.finish()) as ArrayRef);
1817
1818    // variable column (VID as string)
1819    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
1820    for vid in vids {
1821        var_builder.append_value(vid.to_string());
1822    }
1823    columns.push(Arc::new(var_builder.finish()) as ArrayRef);
1824
1825    // _labels column
1826    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
1827    for _ in 0..num_rows {
1828        labels_builder.values().append_value(label);
1829        labels_builder.append(true);
1830    }
1831    columns.push(Arc::new(labels_builder.finish()) as ArrayRef);
1832
1833    // Property columns
1834    if let Some(props) = target_properties.get(output_name) {
1835        for prop_name in props {
1836            let data_type = resolve_property_type(prop_name, label_props);
1837            let column = crate::query::df_graph::scan::build_property_column_static(
1838                vids, props_map, prop_name, &data_type,
1839            )?;
1840            columns.push(column);
1841        }
1842    }
1843
1844    Ok(columns)
1845}
1846
1847/// Extract a vector from a Value.
1848fn extract_vector(val: &Value) -> DFResult<Vec<f32>> {
1849    match val {
1850        Value::Vector(vec) => Ok(vec.clone()),
1851        Value::List(arr) => {
1852            let mut vec = Vec::with_capacity(arr.len());
1853            for v in arr {
1854                if let Some(f) = v.as_f64() {
1855                    vec.push(f as f32);
1856                } else {
1857                    return Err(datafusion::error::DataFusionError::Execution(
1858                        "Query vector must contain numbers".to_string(),
1859                    ));
1860                }
1861            }
1862            Ok(vec)
1863        }
1864        _ => Err(datafusion::error::DataFusionError::Execution(
1865            "Query vector must be a list or vector".to_string(),
1866        )),
1867    }
1868}