// uni_query/query/df_graph/procedure_call.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Procedure call execution plan for DataFusion.
5//!
6//! This module provides [`GraphProcedureCallExec`], a DataFusion [`ExecutionPlan`] that
7//! executes Cypher `CALL` procedures natively within the DataFusion engine.
8//!
9//! Used for composite queries where a `CALL` is followed by `MATCH`, e.g.:
10//! ```text
11//! CALL uni.schema.labels() YIELD label
12//! MATCH (n:Person) WHERE label = 'Person'
13//! RETURN n.name, label
14//! ```
15
16use crate::query::df_graph::GraphExecutionContext;
17use crate::query::df_graph::common::{
18    compute_plan_properties, evaluate_simple_expr, labels_data_type,
19};
20use crate::query::df_graph::scan::resolve_property_type;
21use arrow_array::builder::{
22    BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, StringBuilder, UInt64Builder,
23};
24use arrow_array::{ArrayRef, RecordBatch};
25use arrow_schema::{DataType, Field, Schema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use std::any::Any;
32use std::collections::HashMap;
33use std::fmt;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37use uni_common::Value;
38use uni_common::core::id::Vid;
39use uni_cypher::ast::Expr;
40
/// Maps a user-provided yield name to a canonical name.
///
/// - "vid", "_vid" → "vid"
/// - "distance", "dist", "_distance" → "distance"
/// - "score", "_score" → "score"
/// - anything else → "node" (treated as node variable)
///
/// Matching is case-insensitive; the canonical name is always returned in
/// lowercase.
pub(crate) fn map_yield_to_canonical(yield_name: &str) -> String {
    // Normalize once so "VID", "Vid", etc. all hit the same arm.
    let lowered = yield_name.to_lowercase();
    let canonical = match lowered.as_str() {
        "vid" | "_vid" => "vid",
        "distance" | "dist" | "_distance" => "distance",
        "score" | "_score" => "score",
        "vector_score" => "vector_score",
        "fts_score" => "fts_score",
        "raw_score" => "raw_score",
        // Any other identifier is assumed to name a node variable.
        _ => "node",
    };
    canonical.to_string()
}
59
/// Procedure call execution plan for DataFusion.
///
/// Executes Cypher CALL procedures (schema introspection, vector search, FTS, etc.)
/// and emits results as Arrow RecordBatches.
///
/// The output schema is computed once at construction time (see `build_schema`)
/// so downstream DataFusion operators can be planned before execution.
pub struct GraphProcedureCallExec {
    /// Graph execution context for storage access.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Fully qualified procedure name (e.g. "uni.schema.labels").
    procedure_name: String,

    /// Argument expressions from the CALL clause.
    /// Evaluated eagerly (with `params`) when `execute` is called.
    arguments: Vec<Expr>,

    /// Yield items: (original_name, optional_alias).
    /// The alias, when present, becomes the output column name.
    yield_items: Vec<(String, Option<String>)>,

    /// Query parameters for expression evaluation.
    params: HashMap<String, Value>,

    /// Target properties per variable (for node-like yields).
    target_properties: HashMap<String, Vec<String>>,

    /// Output schema.
    schema: SchemaRef,

    /// Plan properties.
    properties: PlanProperties,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
92
93impl fmt::Debug for GraphProcedureCallExec {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.debug_struct("GraphProcedureCallExec")
96            .field("procedure_name", &self.procedure_name)
97            .field("yield_items", &self.yield_items)
98            .finish()
99    }
100}
101
impl GraphProcedureCallExec {
    /// Create a new procedure call execution plan.
    ///
    /// The output schema is derived eagerly from `procedure_name`,
    /// `yield_items` and `target_properties` so that downstream operators
    /// can be planned before this node executes.
    pub fn new(
        graph_ctx: Arc<GraphExecutionContext>,
        procedure_name: String,
        arguments: Vec<Expr>,
        yield_items: Vec<(String, Option<String>)>,
        params: HashMap<String, Value>,
        target_properties: HashMap<String, Vec<String>>,
    ) -> Self {
        let schema = Self::build_schema(
            &procedure_name,
            &yield_items,
            &target_properties,
            &graph_ctx,
        );
        let properties = compute_plan_properties(schema.clone());

        Self {
            graph_ctx,
            procedure_name,
            arguments,
            yield_items,
            params,
            target_properties,
            schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

    /// Build the output schema based on the procedure name and yield items.
    ///
    /// Column types are chosen per procedure family:
    /// - schema procedures emit scalar Utf8/Int64/Boolean columns,
    /// - search procedures emit node-like column groups plus score columns,
    /// - algo/registered procedures resolve types from their registries,
    ///   falling back to Utf8 when no type information is available.
    fn build_schema(
        procedure_name: &str,
        yield_items: &[(String, Option<String>)],
        target_properties: &HashMap<String, Vec<String>>,
        graph_ctx: &GraphExecutionContext,
    ) -> SchemaRef {
        let mut fields = Vec::new();

        match procedure_name {
            "uni.schema.labels" => {
                // Schema procedure yields scalar columns
                for (name, alias) in yield_items {
                    // Alias (YIELD x AS y) wins as the output column name.
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "label" => DataType::Utf8,
                        "propertyCount" | "nodeCount" | "indexCount" => DataType::Int64,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "type" | "relationshipType" => DataType::Utf8,
                        "propertyCount" => DataType::Int64,
                        "sourceLabels" | "targetLabels" => DataType::Utf8, // JSON string
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.indexes" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    // All index-description columns are textual.
                    let data_type = match name.as_str() {
                        "name" | "type" | "label" | "state" | "properties" => DataType::Utf8,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.constraints" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "enabled" => DataType::Boolean,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.labelInfo" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "property" | "dataType" => DataType::Utf8,
                        "nullable" | "indexed" | "unique" => DataType::Boolean,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.vector.query" | "uni.fts.query" | "uni.search" => {
                // Search procedures yield node-like and scalar columns
                for (name, alias) in yield_items {
                    let output_name = alias.as_ref().unwrap_or(name);
                    let canonical = map_yield_to_canonical(name);

                    match canonical.as_str() {
                        "node" => {
                            // Node-like yield: emit _vid, variable, _label, and properties
                            fields.push(Field::new(
                                format!("{}._vid", output_name),
                                DataType::UInt64,
                                false,
                            ));
                            fields.push(Field::new(output_name, DataType::Utf8, false));
                            fields.push(Field::new(
                                format!("{}._labels", output_name),
                                labels_data_type(),
                                true,
                            ));

                            // Add property columns
                            if let Some(props) = target_properties.get(output_name.as_str()) {
                                let uni_schema = graph_ctx.storage().schema_manager().schema();
                                // We don't know the exact label yet at planning time,
                                // but we can try to resolve property types from any label
                                for prop_name in props {
                                    let col_name = format!("{}.{}", output_name, prop_name);
                                    // Fallback type when no label declares this property.
                                    let arrow_type = resolve_property_type(prop_name, None);
                                    // Try to resolve from all labels in the schema
                                    let resolved_type = uni_schema
                                        .properties
                                        .values()
                                        .find_map(|label_props| {
                                            label_props.get(prop_name.as_str()).map(|_| {
                                                resolve_property_type(prop_name, Some(label_props))
                                            })
                                        })
                                        .unwrap_or(arrow_type);
                                    fields.push(Field::new(&col_name, resolved_type, true));
                                }
                            }
                        }
                        "distance" => {
                            fields.push(Field::new(output_name, DataType::Float64, true));
                        }
                        "score" | "vector_score" | "fts_score" | "raw_score" => {
                            fields.push(Field::new(output_name, DataType::Float32, true));
                        }
                        "vid" => {
                            fields.push(Field::new(output_name, DataType::Int64, true));
                        }
                        _ => {
                            fields.push(Field::new(output_name, DataType::Utf8, true));
                        }
                    }
                }
            }
            name if name.starts_with("uni.algo.") => {
                // Algorithm procedures: resolve yield types from the algo registry.
                if let Some(registry) = graph_ctx.algo_registry()
                    && let Some(procedure) = registry.get(name)
                {
                    let sig = procedure.signature();
                    for (yield_name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(yield_name);
                        let yield_vt = sig.yields.iter().find(|(n, _)| *n == yield_name.as_str());
                        let data_type = yield_vt
                            .map(|(_, vt)| value_type_to_arrow(vt))
                            .unwrap_or(DataType::Utf8);
                        let mut field = Field::new(col_name, data_type, true);
                        // Tag complex types (List, Map, etc.) so record_batches_to_rows
                        // can parse the JSON string back to the original type.
                        if yield_vt.is_some_and(|(_, vt)| is_complex_value_type(vt)) {
                            let mut metadata = std::collections::HashMap::new();
                            metadata.insert("cv_encoded".to_string(), "true".to_string());
                            field = field.with_metadata(metadata);
                        }
                        fields.push(field);
                    }
                } else {
                    // Unknown algo or no registry: fallback to Utf8
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        fields.push(Field::new(col_name, DataType::Utf8, true));
                    }
                }
            }
            _ => {
                // Check external procedure registry for type information
                if let Some(registry) = graph_ctx.procedure_registry()
                    && let Some(proc_def) = registry.get(procedure_name)
                {
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        // Find the output type from the procedure definition
                        let data_type = proc_def
                            .outputs
                            .iter()
                            .find(|o| o.name == *name)
                            .map(|o| procedure_value_type_to_arrow(&o.output_type))
                            .unwrap_or(DataType::Utf8);
                        fields.push(Field::new(col_name, data_type, true));
                    }
                } else if yield_items.is_empty() {
                    // Void procedure (no YIELD) — no output columns
                } else {
                    // Unknown procedure without registry: fallback to Utf8
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        fields.push(Field::new(col_name, DataType::Utf8, true));
                    }
                }
            }
        }

        Arc::new(Schema::new(fields))
    }
}
316
317/// Convert an algorithm `ValueType` to an Arrow `DataType`.
318fn value_type_to_arrow(vt: &uni_algo::algo::procedures::ValueType) -> DataType {
319    use uni_algo::algo::procedures::ValueType;
320    match vt {
321        ValueType::Int => DataType::Int64,
322        ValueType::Float => DataType::Float64,
323        ValueType::String => DataType::Utf8,
324        ValueType::Bool => DataType::Boolean,
325        ValueType::List
326        | ValueType::Map
327        | ValueType::Node
328        | ValueType::Relationship
329        | ValueType::Path
330        | ValueType::Any => DataType::Utf8,
331    }
332}
333
334/// Returns true if the ValueType is a complex type that should be JSON-encoded as Utf8
335/// and tagged with `cv_encoded=true` metadata for downstream parsing.
336fn is_complex_value_type(vt: &uni_algo::algo::procedures::ValueType) -> bool {
337    use uni_algo::algo::procedures::ValueType;
338    matches!(
339        vt,
340        ValueType::List
341            | ValueType::Map
342            | ValueType::Node
343            | ValueType::Relationship
344            | ValueType::Path
345    )
346}
347
348/// Convert a `ProcedureValueType` to an Arrow `DataType`.
349fn procedure_value_type_to_arrow(
350    vt: &crate::query::executor::procedure::ProcedureValueType,
351) -> DataType {
352    use crate::query::executor::procedure::ProcedureValueType;
353    match vt {
354        ProcedureValueType::Integer => DataType::Int64,
355        ProcedureValueType::Float | ProcedureValueType::Number => DataType::Float64,
356        ProcedureValueType::Boolean => DataType::Boolean,
357        ProcedureValueType::String | ProcedureValueType::Any => DataType::Utf8,
358    }
359}
360
361impl DisplayAs for GraphProcedureCallExec {
362    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363        write!(
364            f,
365            "GraphProcedureCallExec: procedure={}",
366            self.procedure_name
367        )
368    }
369}
370
371impl ExecutionPlan for GraphProcedureCallExec {
372    fn name(&self) -> &str {
373        "GraphProcedureCallExec"
374    }
375
376    fn as_any(&self) -> &dyn Any {
377        self
378    }
379
380    fn schema(&self) -> SchemaRef {
381        self.schema.clone()
382    }
383
384    fn properties(&self) -> &PlanProperties {
385        &self.properties
386    }
387
388    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
389        vec![]
390    }
391
392    fn with_new_children(
393        self: Arc<Self>,
394        children: Vec<Arc<dyn ExecutionPlan>>,
395    ) -> DFResult<Arc<dyn ExecutionPlan>> {
396        if !children.is_empty() {
397            return Err(datafusion::error::DataFusionError::Internal(
398                "GraphProcedureCallExec has no children".to_string(),
399            ));
400        }
401        Ok(self)
402    }
403
404    fn execute(
405        &self,
406        partition: usize,
407        _context: Arc<TaskContext>,
408    ) -> DFResult<SendableRecordBatchStream> {
409        let metrics = BaselineMetrics::new(&self.metrics, partition);
410
411        // Evaluate arguments upfront
412        let mut evaluated_args = Vec::with_capacity(self.arguments.len());
413        for arg in &self.arguments {
414            evaluated_args.push(evaluate_simple_expr(arg, &self.params)?);
415        }
416
417        Ok(Box::pin(ProcedureCallStream::new(
418            self.graph_ctx.clone(),
419            self.procedure_name.clone(),
420            evaluated_args,
421            self.yield_items.clone(),
422            self.target_properties.clone(),
423            self.schema.clone(),
424            metrics,
425        )))
426    }
427
428    fn metrics(&self) -> Option<MetricsSet> {
429        Some(self.metrics.clone_inner())
430    }
431}
432
433// ---------------------------------------------------------------------------
434// Stream implementation
435// ---------------------------------------------------------------------------
436
/// State machine for procedure call stream.
///
/// Transitions: `Init` → `Executing` → `Done`; the whole result is produced
/// by a single future, so at most one batch is ever emitted.
enum ProcedureCallState {
    /// Initial state, ready to start execution.
    Init,
    /// Executing the async procedure.
    /// The future resolves to at most one RecordBatch (or None for void procedures).
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// Stream is done.
    Done,
}
446
/// Stream that executes a procedure call.
///
/// Emits zero or one RecordBatch, produced by `execute_procedure`.
struct ProcedureCallStream {
    // Storage/context handle cloned into the execution future.
    graph_ctx: Arc<GraphExecutionContext>,
    // Fully qualified procedure name (e.g. "uni.schema.labels").
    procedure_name: String,
    // CALL arguments, already evaluated in `ExecutionPlan::execute`.
    evaluated_args: Vec<Value>,
    // (original_name, optional_alias) pairs from the YIELD clause.
    yield_items: Vec<(String, Option<String>)>,
    // Requested properties per node-like yield variable.
    target_properties: HashMap<String, Vec<String>>,
    // Output schema (must match batches produced by execute_procedure).
    schema: SchemaRef,
    // Current position in the Init → Executing → Done state machine.
    state: ProcedureCallState,
    // Baseline metrics; output row count is recorded on completion.
    metrics: BaselineMetrics,
}
458
impl ProcedureCallStream {
    /// Create a stream in the `Init` state; no work happens until first poll.
    fn new(
        graph_ctx: Arc<GraphExecutionContext>,
        procedure_name: String,
        evaluated_args: Vec<Value>,
        yield_items: Vec<(String, Option<String>)>,
        target_properties: HashMap<String, Vec<String>>,
        schema: SchemaRef,
        metrics: BaselineMetrics,
    ) -> Self {
        Self {
            graph_ctx,
            procedure_name,
            evaluated_args,
            yield_items,
            target_properties,
            schema,
            state: ProcedureCallState::Init,
            metrics,
        }
    }
}
481
impl Stream for ProcedureCallStream {
    type Item = DFResult<RecordBatch>;

    /// Drive the Init → Executing → Done state machine.
    ///
    /// The loop runs at most twice per call: Init builds the future and
    /// falls through to poll it; Executing/Done return immediately.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Take ownership of the state, leaving `Done` in place so the
            // stream is terminal if anything below returns early.
            let state = std::mem::replace(&mut self.state, ProcedureCallState::Done);

            match state {
                ProcedureCallState::Init => {
                    // Clone everything the async block captures so the
                    // future is self-contained ('static + Send).
                    let graph_ctx = self.graph_ctx.clone();
                    let procedure_name = self.procedure_name.clone();
                    let evaluated_args = self.evaluated_args.clone();
                    let yield_items = self.yield_items.clone();
                    let target_properties = self.target_properties.clone();
                    let schema = self.schema.clone();

                    let fut = async move {
                        // Honor the query deadline before doing any work.
                        graph_ctx.check_timeout().map_err(|e| {
                            datafusion::error::DataFusionError::Execution(e.to_string())
                        })?;

                        execute_procedure(
                            &graph_ctx,
                            &procedure_name,
                            &evaluated_args,
                            &yield_items,
                            &target_properties,
                            &schema,
                        )
                        .await
                    };

                    self.state = ProcedureCallState::Executing(Box::pin(fut));
                }
                ProcedureCallState::Executing(mut fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batch)) => {
                        self.state = ProcedureCallState::Done;
                        // Record output rows (0 for void procedures).
                        self.metrics
                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
                        return Poll::Ready(batch.map(Ok));
                    }
                    Poll::Ready(Err(e)) => {
                        self.state = ProcedureCallState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => {
                        // Put the future back so the next poll resumes it.
                        self.state = ProcedureCallState::Executing(fut);
                        return Poll::Pending;
                    }
                },
                ProcedureCallState::Done => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
539
540impl RecordBatchStream for ProcedureCallStream {
541    fn schema(&self) -> SchemaRef {
542        self.schema.clone()
543    }
544}
545
546// ---------------------------------------------------------------------------
547// Procedure execution dispatch
548// ---------------------------------------------------------------------------
549
/// Execute a procedure and build a RecordBatch result.
///
/// Dispatches on the fully qualified procedure name; `uni.algo.*` names go
/// to the algorithm registry and anything else falls through to the external
/// procedure registry. Returns `None` for void procedures.
async fn execute_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    match procedure_name {
        "uni.schema.labels" => execute_schema_labels(graph_ctx, yield_items, schema).await,
        "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => {
            execute_schema_edge_types(graph_ctx, yield_items, schema).await
        }
        "uni.schema.indexes" => execute_schema_indexes(graph_ctx, yield_items, schema).await,
        "uni.schema.constraints" => {
            execute_schema_constraints(graph_ctx, yield_items, schema).await
        }
        "uni.schema.labelInfo" => {
            execute_schema_label_info(graph_ctx, args, yield_items, schema).await
        }
        "uni.vector.query" => {
            execute_vector_query(graph_ctx, args, yield_items, target_properties, schema).await
        }
        "uni.fts.query" => {
            execute_fts_query(graph_ctx, args, yield_items, target_properties, schema).await
        }
        "uni.search" => {
            execute_hybrid_search(graph_ctx, args, yield_items, target_properties, schema).await
        }
        name if name.starts_with("uni.algo.") => {
            execute_algo_procedure(graph_ctx, name, args, yield_items, schema).await
        }
        _ => {
            execute_registered_procedure(graph_ctx, procedure_name, args, yield_items, schema).await
        }
    }
}
588
589// ---------------------------------------------------------------------------
590// Schema procedures
591// ---------------------------------------------------------------------------
592
593async fn execute_schema_labels(
594    graph_ctx: &GraphExecutionContext,
595    yield_items: &[(String, Option<String>)],
596    schema: &SchemaRef,
597) -> DFResult<Option<RecordBatch>> {
598    let uni_schema = graph_ctx.storage().schema_manager().schema();
599    let storage = graph_ctx.storage();
600
601    // Collect rows: one per label
602    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
603    for label_name in uni_schema.labels.keys() {
604        let mut row = HashMap::new();
605        row.insert("label".to_string(), Value::String(label_name.clone()));
606
607        let prop_count = uni_schema
608            .properties
609            .get(label_name)
610            .map(|p| p.len())
611            .unwrap_or(0);
612        row.insert("propertyCount".to_string(), Value::Int(prop_count as i64));
613
614        let node_count = if let Ok(ds) = storage.vertex_dataset(label_name) {
615            if let Ok(raw) = ds.open_raw().await {
616                raw.count_rows(None).await.unwrap_or(0)
617            } else {
618                0
619            }
620        } else {
621            0
622        };
623        row.insert("nodeCount".to_string(), Value::Int(node_count as i64));
624
625        let idx_count = uni_schema
626            .indexes
627            .iter()
628            .filter(|i| i.label() == label_name)
629            .count();
630        row.insert("indexCount".to_string(), Value::Int(idx_count as i64));
631
632        rows.push(row);
633    }
634
635    build_scalar_batch(&rows, yield_items, schema)
636}
637
638async fn execute_schema_edge_types(
639    graph_ctx: &GraphExecutionContext,
640    yield_items: &[(String, Option<String>)],
641    schema: &SchemaRef,
642) -> DFResult<Option<RecordBatch>> {
643    let uni_schema = graph_ctx.storage().schema_manager().schema();
644
645    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
646    for (type_name, meta) in &uni_schema.edge_types {
647        let mut row = HashMap::new();
648        row.insert("type".to_string(), Value::String(type_name.clone()));
649        row.insert(
650            "relationshipType".to_string(),
651            Value::String(type_name.clone()),
652        );
653        row.insert(
654            "sourceLabels".to_string(),
655            Value::String(format!("{:?}", meta.src_labels)),
656        );
657        row.insert(
658            "targetLabels".to_string(),
659            Value::String(format!("{:?}", meta.dst_labels)),
660        );
661
662        let prop_count = uni_schema
663            .properties
664            .get(type_name)
665            .map(|p| p.len())
666            .unwrap_or(0);
667        row.insert("propertyCount".to_string(), Value::Int(prop_count as i64));
668
669        rows.push(row);
670    }
671
672    build_scalar_batch(&rows, yield_items, schema)
673}
674
675async fn execute_schema_indexes(
676    graph_ctx: &GraphExecutionContext,
677    yield_items: &[(String, Option<String>)],
678    schema: &SchemaRef,
679) -> DFResult<Option<RecordBatch>> {
680    let uni_schema = graph_ctx.storage().schema_manager().schema();
681
682    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
683    for idx in &uni_schema.indexes {
684        use uni_common::core::schema::IndexDefinition;
685
686        // Extract type name and properties JSON per variant
687        let (type_name, properties_json) = match &idx {
688            IndexDefinition::Vector(v) => (
689                "VECTOR",
690                serde_json::to_string(&[&v.property]).unwrap_or_default(),
691            ),
692            IndexDefinition::FullText(f) => (
693                "FULLTEXT",
694                serde_json::to_string(&f.properties).unwrap_or_default(),
695            ),
696            IndexDefinition::Scalar(s) => (
697                "SCALAR",
698                serde_json::to_string(&s.properties).unwrap_or_default(),
699            ),
700            IndexDefinition::JsonFullText(j) => (
701                "JSON_FTS",
702                serde_json::to_string(&[&j.column]).unwrap_or_default(),
703            ),
704            IndexDefinition::Inverted(inv) => (
705                "INVERTED",
706                serde_json::to_string(&[&inv.property]).unwrap_or_default(),
707            ),
708            _ => ("UNKNOWN", String::new()),
709        };
710
711        let row = HashMap::from([
712            ("state".to_string(), Value::String("ONLINE".to_string())),
713            ("name".to_string(), Value::String(idx.name().to_string())),
714            ("type".to_string(), Value::String(type_name.to_string())),
715            ("label".to_string(), Value::String(idx.label().to_string())),
716            ("properties".to_string(), Value::String(properties_json)),
717        ]);
718        rows.push(row);
719    }
720
721    build_scalar_batch(&rows, yield_items, schema)
722}
723
724async fn execute_schema_constraints(
725    graph_ctx: &GraphExecutionContext,
726    yield_items: &[(String, Option<String>)],
727    schema: &SchemaRef,
728) -> DFResult<Option<RecordBatch>> {
729    let uni_schema = graph_ctx.storage().schema_manager().schema();
730
731    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
732    for c in &uni_schema.constraints {
733        let mut row = HashMap::new();
734        row.insert("name".to_string(), Value::String(c.name.clone()));
735        row.insert("enabled".to_string(), Value::Bool(c.enabled));
736
737        match &c.constraint_type {
738            uni_common::core::schema::ConstraintType::Unique { properties } => {
739                row.insert("type".to_string(), Value::String("UNIQUE".to_string()));
740                row.insert(
741                    "properties".to_string(),
742                    Value::String(serde_json::to_string(&properties).unwrap_or_default()),
743                );
744            }
745            uni_common::core::schema::ConstraintType::Exists { property } => {
746                row.insert("type".to_string(), Value::String("EXISTS".to_string()));
747                row.insert(
748                    "properties".to_string(),
749                    Value::String(serde_json::to_string(&[&property]).unwrap_or_default()),
750                );
751            }
752            uni_common::core::schema::ConstraintType::Check { expression } => {
753                row.insert("type".to_string(), Value::String("CHECK".to_string()));
754                row.insert("expression".to_string(), Value::String(expression.clone()));
755            }
756            _ => {
757                row.insert("type".to_string(), Value::String("UNKNOWN".to_string()));
758            }
759        }
760
761        match &c.target {
762            uni_common::core::schema::ConstraintTarget::Label(l) => {
763                row.insert("label".to_string(), Value::String(l.clone()));
764            }
765            uni_common::core::schema::ConstraintTarget::EdgeType(t) => {
766                row.insert("relationshipType".to_string(), Value::String(t.clone()));
767            }
768            _ => {
769                row.insert("target".to_string(), Value::String("UNKNOWN".to_string()));
770            }
771        }
772
773        rows.push(row);
774    }
775
776    build_scalar_batch(&rows, yield_items, schema)
777}
778
779async fn execute_schema_label_info(
780    graph_ctx: &GraphExecutionContext,
781    args: &[Value],
782    yield_items: &[(String, Option<String>)],
783    schema: &SchemaRef,
784) -> DFResult<Option<RecordBatch>> {
785    let label_name = require_string_arg(args, 0, "uni.schema.labelInfo: first argument (label)")?;
786
787    let uni_schema = graph_ctx.storage().schema_manager().schema();
788
789    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
790    if let Some(props) = uni_schema.properties.get(&label_name) {
791        for (prop_name, prop_meta) in props {
792            let mut row = HashMap::new();
793            row.insert("property".to_string(), Value::String(prop_name.clone()));
794            row.insert(
795                "dataType".to_string(),
796                Value::String(format!("{:?}", prop_meta.r#type)),
797            );
798            row.insert("nullable".to_string(), Value::Bool(prop_meta.nullable));
799
800            let is_indexed = uni_schema.indexes.iter().any(|idx| match idx {
801                uni_common::core::schema::IndexDefinition::Vector(v) => {
802                    v.label == label_name && v.property == *prop_name
803                }
804                uni_common::core::schema::IndexDefinition::Scalar(s) => {
805                    s.label == label_name && s.properties.contains(prop_name)
806                }
807                uni_common::core::schema::IndexDefinition::FullText(f) => {
808                    f.label == label_name && f.properties.contains(prop_name)
809                }
810                uni_common::core::schema::IndexDefinition::Inverted(inv) => {
811                    inv.label == label_name && inv.property == *prop_name
812                }
813                uni_common::core::schema::IndexDefinition::JsonFullText(j) => j.label == label_name,
814                _ => false,
815            });
816            row.insert("indexed".to_string(), Value::Bool(is_indexed));
817
818            let unique = uni_schema.constraints.iter().any(|c| {
819                if let uni_common::core::schema::ConstraintTarget::Label(l) = &c.target
820                    && l == &label_name
821                    && c.enabled
822                    && let uni_common::core::schema::ConstraintType::Unique { properties } =
823                        &c.constraint_type
824                {
825                    return properties.contains(prop_name);
826                }
827                false
828            });
829            row.insert("unique".to_string(), Value::Bool(unique));
830
831            rows.push(row);
832        }
833    }
834
835    build_scalar_batch(&rows, yield_items, schema)
836}
837
838/// Build a typed Arrow column from an iterator of optional `Value`s.
839///
840/// Dispatches on `data_type` to build the appropriate Arrow array. For types
841/// not explicitly handled (Utf8 fallback), values are stringified.
842fn build_typed_column<'a>(
843    values: impl Iterator<Item = Option<&'a Value>>,
844    num_rows: usize,
845    data_type: &DataType,
846) -> ArrayRef {
847    match data_type {
848        DataType::Int64 => {
849            let mut builder = Int64Builder::with_capacity(num_rows);
850            for val in values {
851                match val.and_then(|v| v.as_i64()) {
852                    Some(i) => builder.append_value(i),
853                    None => builder.append_null(),
854                }
855            }
856            Arc::new(builder.finish())
857        }
858        DataType::Float64 => {
859            let mut builder = Float64Builder::with_capacity(num_rows);
860            for val in values {
861                match val.and_then(|v| v.as_f64()) {
862                    Some(f) => builder.append_value(f),
863                    None => builder.append_null(),
864                }
865            }
866            Arc::new(builder.finish())
867        }
868        DataType::Boolean => {
869            let mut builder = BooleanBuilder::with_capacity(num_rows);
870            for val in values {
871                match val.and_then(|v| v.as_bool()) {
872                    Some(b) => builder.append_value(b),
873                    None => builder.append_null(),
874                }
875            }
876            Arc::new(builder.finish())
877        }
878        _ => {
879            // Utf8 fallback: stringify values
880            let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
881            for val in values {
882                match val {
883                    Some(Value::String(s)) => builder.append_value(s),
884                    Some(v) => builder.append_value(format!("{v}")),
885                    None => builder.append_null(),
886                }
887            }
888            Arc::new(builder.finish())
889        }
890    }
891}
892
893/// Create an empty RecordBatch for the given schema.
894///
895/// When a schema has zero fields, `RecordBatch::new_empty()` panics because it
896/// cannot determine the row count from an empty array. This helper handles that
897/// edge case by using `RecordBatchOptions::with_row_count(0)`.
898fn create_empty_batch(schema: SchemaRef) -> DFResult<RecordBatch> {
899    if schema.fields().is_empty() {
900        let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
901        RecordBatch::try_new_with_options(schema, vec![], &options)
902            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
903    } else {
904        Ok(RecordBatch::new_empty(schema))
905    }
906}
907
908/// Build a RecordBatch from scalar-valued rows for schema procedures.
909fn build_scalar_batch(
910    rows: &[HashMap<String, Value>],
911    yield_items: &[(String, Option<String>)],
912    schema: &SchemaRef,
913) -> DFResult<Option<RecordBatch>> {
914    if rows.is_empty() {
915        return Ok(Some(create_empty_batch(schema.clone())?));
916    }
917
918    let num_rows = rows.len();
919    let mut columns: Vec<ArrayRef> = Vec::new();
920
921    for (idx, (name, _alias)) in yield_items.iter().enumerate() {
922        let field = schema.field(idx);
923        let values = rows.iter().map(|row| row.get(name));
924        columns.push(build_typed_column(values, num_rows, field.data_type()));
925    }
926
927    let batch = RecordBatch::try_new(schema.clone(), columns)
928        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
929    Ok(Some(batch))
930}
931
932// ---------------------------------------------------------------------------
933// External/registered procedures
934// ---------------------------------------------------------------------------
935
/// Execute an externally registered procedure (e.g., TCK test procedures).
///
/// Looks up the procedure in the `ProcedureRegistry`, evaluates arguments,
/// filters data rows by matching input columns, and projects output columns.
///
/// # Errors
///
/// Returns an execution error when no registry is configured, the procedure
/// name is unknown, the argument count differs from the declared parameter
/// list, or a non-null argument fails the type-compatibility check.
async fn execute_registered_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let registry = graph_ctx.procedure_registry().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "Procedure '{}' not supported in DataFusion engine (no procedure registry)",
            procedure_name
        ))
    })?;

    let proc_def = registry.get(procedure_name).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "ProcedureNotFound: Unknown procedure '{}'",
            procedure_name
        ))
    })?;

    // Validate argument count
    if args.len() != proc_def.params.len() {
        return Err(datafusion::error::DataFusionError::Execution(format!(
            "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
            proc_def.name,
            proc_def.params.len(),
            args.len()
        )));
    }

    // Validate argument types; null is accepted for any declared type.
    for (i, (arg_val, param)) in args.iter().zip(&proc_def.params).enumerate() {
        if !arg_val.is_null() && !check_proc_type_compatible(arg_val, &param.param_type) {
            return Err(datafusion::error::DataFusionError::Execution(format!(
                "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
                i, param.name, proc_def.name
            )));
        }
    }

    // Filter data rows: keep rows where input columns match the provided args.
    // A row that has no column for a given parameter is NOT filtered on it.
    let filtered: Vec<&HashMap<String, Value>> = proc_def
        .data
        .iter()
        .filter(|row| {
            for (param, arg_val) in proc_def.params.iter().zip(args) {
                if let Some(row_val) = row.get(&param.name)
                    && !proc_values_match(row_val, arg_val)
                {
                    return false;
                }
            }
            true
        })
        .collect();

    // If the procedure has no yield items (void procedure), return empty batch
    if yield_items.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    if filtered.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Project output columns based on yield items.
    // The yield name is looked up directly as the row's column key, so it must
    // match the output column names in the procedure definition.
    let num_rows = filtered.len();
    let mut columns: Vec<ArrayRef> = Vec::new();

    for (idx, (name, _alias)) in yield_items.iter().enumerate() {
        let field = schema.field(idx);
        let values = filtered.iter().map(|row| row.get(name.as_str()));
        columns.push(build_typed_column(values, num_rows, field.data_type()));
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1021
1022/// Checks whether a value is compatible with a procedure type (DF engine version).
1023fn check_proc_type_compatible(
1024    val: &Value,
1025    expected: &crate::query::executor::procedure::ProcedureValueType,
1026) -> bool {
1027    use crate::query::executor::procedure::ProcedureValueType;
1028    match expected {
1029        ProcedureValueType::Any => true,
1030        ProcedureValueType::String => val.is_string(),
1031        ProcedureValueType::Boolean => val.is_bool(),
1032        ProcedureValueType::Integer => val.is_i64(),
1033        ProcedureValueType::Float => val.is_f64() || val.is_i64(),
1034        ProcedureValueType::Number => val.is_number(),
1035    }
1036}
1037
1038/// Checks whether two values match for input-column filtering (DF engine version).
1039fn proc_values_match(row_val: &Value, arg_val: &Value) -> bool {
1040    if arg_val.is_null() || row_val.is_null() {
1041        return arg_val.is_null() && row_val.is_null();
1042    }
1043    // Compare numbers by f64 to handle int/float cross-comparison
1044    if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
1045        return (a - b).abs() < f64::EPSILON;
1046    }
1047    row_val == arg_val
1048}
1049
1050// ---------------------------------------------------------------------------
1051// Algorithm procedures
1052// ---------------------------------------------------------------------------
1053
/// Execute a registered graph-algorithm procedure via the algorithm registry.
///
/// Arguments are forwarded as JSON to the algorithm's own `execute()`, which
/// validates them and fills defaults itself. The result stream is fully
/// collected into one batch; the query timeout is re-checked every 1000
/// collected rows (and before the first).
async fn execute_algo_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    use futures::StreamExt;
    use uni_algo::algo::procedures::AlgoContext;

    let registry = graph_ctx.algo_registry().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "Algorithm registry not available".to_string(),
        )
    })?;

    let procedure = registry.get(procedure_name).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "Unknown algorithm: {}",
            procedure_name
        ))
    })?;

    let signature = procedure.signature();

    // Convert uni_common::Value args to serde_json::Value for algo crate.
    // Note: do NOT call validate_args here — the procedure's own execute()
    // already validates and fills defaults internally.
    let serde_args: Vec<serde_json::Value> = args.iter().cloned().map(|v| v.into()).collect();

    // Build AlgoContext — no L0Manager in the DF path (read-only snapshot)
    let algo_ctx = AlgoContext::new(graph_ctx.storage().clone(), None);

    // Execute and collect stream
    let mut stream = procedure.execute(algo_ctx, serde_args);
    let mut rows = Vec::new();
    while let Some(row_res) = stream.next().await {
        // Check timeout periodically (every 1000 rows)
        if rows.len() % 1000 == 0 {
            graph_ctx
                .check_timeout()
                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        }
        let row =
            row_res.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        rows.push(row);
    }

    build_algo_batch(&rows, &signature, yield_items, schema)
}
1104
1105/// Convert a `serde_json::Value` to a `uni_common::Value` for column building.
1106fn json_to_value(jv: &serde_json::Value) -> Value {
1107    match jv {
1108        serde_json::Value::Null => Value::Null,
1109        serde_json::Value::Bool(b) => Value::Bool(*b),
1110        serde_json::Value::Number(n) => {
1111            if let Some(i) = n.as_i64() {
1112                Value::Int(i)
1113            } else if let Some(f) = n.as_f64() {
1114                Value::Float(f)
1115            } else {
1116                Value::Null
1117            }
1118        }
1119        serde_json::Value::String(s) => Value::String(s.clone()),
1120        other => Value::String(other.to_string()),
1121    }
1122}
1123
1124/// Build a RecordBatch from algorithm result rows.
1125fn build_algo_batch(
1126    rows: &[uni_algo::algo::procedures::AlgoResultRow],
1127    signature: &uni_algo::algo::procedures::ProcedureSignature,
1128    yield_items: &[(String, Option<String>)],
1129    schema: &SchemaRef,
1130) -> DFResult<Option<RecordBatch>> {
1131    if rows.is_empty() {
1132        return Ok(Some(create_empty_batch(schema.clone())?));
1133    }
1134
1135    let num_rows = rows.len();
1136    let mut columns: Vec<ArrayRef> = Vec::new();
1137
1138    for (idx, (yield_name, _alias)) in yield_items.iter().enumerate() {
1139        let sig_idx = signature
1140            .yields
1141            .iter()
1142            .position(|(n, _)| *n == yield_name.as_str());
1143
1144        // Convert serde_json values to uni_common::Value for the shared column builder
1145        let uni_values: Vec<Value> = rows
1146            .iter()
1147            .map(|row| match sig_idx {
1148                Some(si) => json_to_value(&row.values[si]),
1149                None => Value::Null,
1150            })
1151            .collect();
1152
1153        let field = schema.field(idx);
1154        let values = uni_values.iter().map(Some);
1155        columns.push(build_typed_column(values, num_rows, field.data_type()));
1156    }
1157
1158    let batch = RecordBatch::try_new(schema.clone(), columns)
1159        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
1160    Ok(Some(batch))
1161}
1162
1163// ---------------------------------------------------------------------------
1164// Shared search argument helpers
1165// ---------------------------------------------------------------------------
1166
1167/// Extract a required string argument from the argument list at a given position.
1168fn require_string_arg(args: &[Value], index: usize, description: &str) -> DFResult<String> {
1169    args.get(index)
1170        .and_then(|v| v.as_str())
1171        .map(|s| s.to_string())
1172        .ok_or_else(|| {
1173            datafusion::error::DataFusionError::Execution(format!("{description} must be a string"))
1174        })
1175}
1176
1177/// Extract an optional filter string from the argument list.
1178/// Returns `None` if the argument is missing, null, or not a string.
1179fn extract_optional_filter(args: &[Value], index: usize) -> Option<String> {
1180    args.get(index).and_then(|v| {
1181        if v.is_null() {
1182            None
1183        } else {
1184            v.as_str().map(|s| s.to_string())
1185        }
1186    })
1187}
1188
1189/// Extract an optional float threshold from the argument list.
1190/// Returns `None` if the argument is missing or null.
1191fn extract_optional_threshold(args: &[Value], index: usize) -> Option<f64> {
1192    args.get(index)
1193        .and_then(|v| if v.is_null() { None } else { v.as_f64() })
1194}
1195
1196/// Extract a required integer argument from the argument list at a given position.
1197fn require_int_arg(args: &[Value], index: usize, description: &str) -> DFResult<usize> {
1198    args.get(index)
1199        .and_then(|v| v.as_u64())
1200        .map(|v| v as usize)
1201        .ok_or_else(|| {
1202            datafusion::error::DataFusionError::Execution(format!(
1203                "{description} must be an integer"
1204            ))
1205        })
1206}
1207
1208// ---------------------------------------------------------------------------
1209// Vector/FTS/Hybrid search procedures
1210// ---------------------------------------------------------------------------
1211
/// Auto-embed a text query using the vector index's embedding configuration.
///
/// Looks up the embedding config from the index on `label.property` and uses
/// it to embed the provided text query into a vector.
///
/// # Errors
///
/// Fails when the index has no `embedding_config`, the Uni-Xervo runtime is
/// not configured, the embedder lookup or embed call fails, or the embedding
/// service returns an empty result set.
async fn auto_embed_text(
    graph_ctx: &GraphExecutionContext,
    label: &str,
    property: &str,
    query_text: &str,
) -> DFResult<Vec<f32>> {
    let storage = graph_ctx.storage();
    let uni_schema = storage.schema_manager().schema();
    let index_config = uni_schema.vector_index_for_property(label, property);

    let embedding_config = index_config
        .and_then(|cfg| cfg.embedding_config.as_ref())
        .ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "Cannot auto-embed: vector index for {label}.{property} has no embedding_config. \
                 Either provide a pre-computed vector or create the index with embedding options."
            ))
        })?;

    let runtime = graph_ctx.xervo_runtime().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "Cannot auto-embed: Uni-Xervo runtime not configured".to_string(),
        )
    })?;

    // Resolve the embedder by the alias stored in the index's embedding config.
    let embedder = runtime
        .embedding(&embedding_config.alias)
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
    let embeddings = embedder
        .embed(vec![query_text])
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
    // One input text was sent, so exactly one embedding is expected back.
    embeddings.into_iter().next().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "Embedding service returned no results".to_string(),
        )
    })
}
1255
/// Execute `uni.vector.query(label, property, query, k[, filter[, threshold]])`.
///
/// `query` (arg 2) may be a text string — auto-embedded via the index's
/// embedding config — or a pre-computed vector. `threshold` (arg 5), when
/// present, is applied as a post-filter keeping only hits with
/// distance <= threshold.
async fn execute_vector_query(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.vector.query: first argument (label)")?;
    let property = require_string_arg(args, 1, "uni.vector.query: second argument (property)")?;

    let query_val = args.get(2).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "uni.vector.query: third argument (query) is required".to_string(),
        )
    })?;

    let storage = graph_ctx.storage();

    // String query → embed; anything else must already be a vector.
    let query_vector: Vec<f32> = if let Some(query_text) = query_val.as_str() {
        auto_embed_text(graph_ctx, &label, &property, query_text).await?
    } else {
        extract_vector(query_val)?
    };

    let k = require_int_arg(args, 3, "uni.vector.query: fourth argument (k)")?;
    let filter = extract_optional_filter(args, 4);
    let threshold = extract_optional_threshold(args, 5);
    let query_ctx = graph_ctx.query_context();

    let mut results = storage
        .vector_search(
            &label,
            &property,
            &query_vector,
            k,
            filter.as_deref(),
            Some(&query_ctx),
        )
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;

    // Apply threshold post-filter (on distance)
    if let Some(max_dist) = threshold {
        results.retain(|(_, dist)| *dist <= max_dist as f32);
    }

    if results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Calculate scores using the same logic as the old executor; the distance
    // metric comes from the index definition, defaulting to L2.
    let schema_manager = storage.schema_manager();
    let uni_schema = schema_manager.schema();
    let metric = uni_schema
        .vector_index_for_property(&label, &property)
        .map(|config| config.metric.clone())
        .unwrap_or(uni_common::core::schema::DistanceMetric::L2);

    build_search_result_batch(
        &results,
        &label,
        &metric,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1325
1326// ---------------------------------------------------------------------------
1327// FTS search procedure
1328// ---------------------------------------------------------------------------
1329
/// Execute `uni.fts.query(label, property, search_term, k[, filter[, threshold]])`.
///
/// Runs a full-text search. Unlike the vector query's threshold (a distance
/// upper bound), `threshold` here is a score LOWER bound: only hits with
/// score >= threshold are kept.
async fn execute_fts_query(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.fts.query: first argument (label)")?;
    let property = require_string_arg(args, 1, "uni.fts.query: second argument (property)")?;
    let search_term = require_string_arg(args, 2, "uni.fts.query: third argument (search_term)")?;
    let k = require_int_arg(args, 3, "uni.fts.query: fourth argument (k)")?;
    let filter = extract_optional_filter(args, 4);
    let threshold = extract_optional_threshold(args, 5);

    let storage = graph_ctx.storage();
    let query_ctx = graph_ctx.query_context();

    let mut results = storage
        .fts_search(
            &label,
            &property,
            &search_term,
            k,
            filter.as_deref(),
            Some(&query_ctx),
        )
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;

    // Threshold is a minimum score for FTS results.
    if let Some(min_score) = threshold {
        results.retain(|(_, score)| *score as f64 >= min_score);
    }

    if results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // FTS uses a "fake" L2 metric for the batch builder — scores are already BM25
    // We use L2 as a placeholder; the actual score column is built differently.
    build_search_result_batch(
        &results,
        &label,
        &uni_common::core::schema::DistanceMetric::L2,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1380
1381// ---------------------------------------------------------------------------
1382// Hybrid search procedure
1383// ---------------------------------------------------------------------------
1384
/// Execute `uni.search(label, properties, query_text, query_vector, k[, filter[, options]])`.
///
/// Runs vector and/or FTS search — as selected by `properties` (an object
/// `{vector: ..., fts: ...}` or a single property-name string meaning both) —
/// over-fetching `over_fetch * k` candidates from each, then fuses the two
/// result lists via RRF (default) or weighted fusion and returns the top-k.
///
/// Options (arg 6, object): `method` ("rrf" | "weighted"), `alpha` (weighted
/// mix, default 0.5), `over_fetch` (default 2.0), `rrf_k` (default 60).
async fn execute_hybrid_search(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.search: first argument (label)")?;

    // Parse properties: {vector: '...', fts: '...'} or just a string
    let properties_val = args.get(1).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "uni.search: second argument (properties) is required".to_string(),
        )
    })?;

    let (vector_prop, fts_prop) = if let Some(obj) = properties_val.as_object() {
        let vec_prop = obj
            .get("vector")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let fts_prop = obj
            .get("fts")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        (vec_prop, fts_prop)
    } else if let Some(prop) = properties_val.as_str() {
        // Shorthand: just property name means both vector and FTS
        (Some(prop.to_string()), Some(prop.to_string()))
    } else {
        return Err(datafusion::error::DataFusionError::Execution(
            "Properties must be an object {vector: '...', fts: '...'} or a string".to_string(),
        ));
    };

    let query_text = require_string_arg(args, 2, "uni.search: third argument (query_text)")?;

    // Arg 3: query vector (optional, can be null).
    // NOTE(review): non-numeric elements in the array are silently dropped
    // here — confirm that's intended rather than an error.
    let query_vector: Option<Vec<f32>> = args.get(3).and_then(|v| {
        if v.is_null() {
            return None;
        }
        v.as_array().map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect()
        })
    });

    let k = require_int_arg(args, 4, "uni.search: fifth argument (k)")?;
    let filter = extract_optional_filter(args, 5);

    // Arg 6: options (optional)
    let options_val = args.get(6);
    let options_map = options_val.and_then(|v| v.as_object());
    let fusion_method = options_map
        .and_then(|m| m.get("method"))
        .and_then(|v| v.as_str())
        .unwrap_or("rrf")
        .to_string();
    let alpha = options_map
        .and_then(|m| m.get("alpha"))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.5) as f32;
    let over_fetch_factor = options_map
        .and_then(|m| m.get("over_fetch"))
        .and_then(|v| v.as_f64())
        .unwrap_or(2.0) as f32;
    let rrf_k = options_map
        .and_then(|m| m.get("rrf_k"))
        .and_then(|v| v.as_u64())
        .unwrap_or(60) as usize;

    // Over-fetch so fusion has enough candidates from each modality.
    let over_fetch_k = (k as f32 * over_fetch_factor).ceil() as usize;

    let storage = graph_ctx.storage();
    let query_ctx = graph_ctx.query_context();

    // Execute vector search if configured
    let mut vector_results: Vec<(Vid, f32)> = Vec::new();
    if let Some(ref vec_prop) = vector_prop {
        // Get or generate query vector
        let qvec = if let Some(ref v) = query_vector {
            v.clone()
        } else {
            // Auto-embed the query text if embedding config exists.
            // NOTE(review): embedding errors are swallowed by
            // unwrap_or_default(), so the search silently degrades to
            // FTS-only — confirm this best-effort behavior is intended.
            auto_embed_text(graph_ctx, &label, vec_prop, &query_text)
                .await
                .unwrap_or_default()
        };

        if !qvec.is_empty() {
            vector_results = storage
                .vector_search(
                    &label,
                    vec_prop,
                    &qvec,
                    over_fetch_k,
                    filter.as_deref(),
                    Some(&query_ctx),
                )
                .await
                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        }
    }

    // Execute FTS search if configured
    let mut fts_results: Vec<(Vid, f32)> = Vec::new();
    if let Some(ref fts_prop) = fts_prop {
        fts_results = storage
            .fts_search(
                &label,
                fts_prop,
                &query_text,
                over_fetch_k,
                filter.as_deref(),
                Some(&query_ctx),
            )
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
    }

    // Fuse results; any method other than "weighted" falls back to RRF.
    let fused_results = match fusion_method.as_str() {
        "weighted" => fuse_weighted(&vector_results, &fts_results, alpha),
        _ => fuse_rrf(&vector_results, &fts_results, rrf_k),
    };

    // Limit to k results
    let final_results: Vec<_> = fused_results.into_iter().take(k).collect();

    if final_results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Build lookup maps for original scores
    let vec_score_map: HashMap<Vid, f32> = vector_results.iter().cloned().collect();
    let fts_score_map: HashMap<Vid, f32> = fts_results.iter().cloned().collect();
    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);

    // Get distance metric for vector score normalization
    let uni_schema = storage.schema_manager().schema();
    let metric = vector_prop
        .as_ref()
        .and_then(|vp| {
            uni_schema
                .vector_index_for_property(&label, vp)
                .map(|config| config.metric.clone())
        })
        .unwrap_or(uni_common::core::schema::DistanceMetric::L2);

    let score_ctx = HybridScoreContext {
        vec_score_map: &vec_score_map,
        fts_score_map: &fts_score_map,
        fts_max,
        metric: &metric,
    };

    build_hybrid_search_batch(
        &final_results,
        &score_ctx,
        &label,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1554
1555/// Reciprocal Rank Fusion (RRF) for combining search results.
1556/// RRF score = sum(1 / (k + rank)) for each result list.
1557fn fuse_rrf(vec_results: &[(Vid, f32)], fts_results: &[(Vid, f32)], k: usize) -> Vec<(Vid, f32)> {
1558    let mut scores: HashMap<Vid, f32> = HashMap::new();
1559
1560    for (rank, (vid, _)) in vec_results.iter().enumerate() {
1561        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
1562        *scores.entry(*vid).or_default() += rrf_score;
1563    }
1564
1565    for (rank, (vid, _)) in fts_results.iter().enumerate() {
1566        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
1567        *scores.entry(*vid).or_default() += rrf_score;
1568    }
1569
1570    let mut results: Vec<_> = scores.into_iter().collect();
1571    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1572    results
1573}
1574
1575/// Weighted fusion: alpha * vec_score + (1 - alpha) * fts_score.
1576/// Both score sets are normalized to [0, 1] range.
1577fn fuse_weighted(
1578    vec_results: &[(Vid, f32)],
1579    fts_results: &[(Vid, f32)],
1580    alpha: f32,
1581) -> Vec<(Vid, f32)> {
1582    // Normalize vector scores (distance -> similarity)
1583    let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
1584    let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
1585    let vec_range = if vec_max > vec_min {
1586        vec_max - vec_min
1587    } else {
1588        1.0
1589    };
1590
1591    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
1592
1593    let vec_scores: HashMap<Vid, f32> = vec_results
1594        .iter()
1595        .map(|(vid, dist)| {
1596            let norm = 1.0 - (dist - vec_min) / vec_range;
1597            (*vid, norm)
1598        })
1599        .collect();
1600
1601    let fts_scores: HashMap<Vid, f32> = fts_results
1602        .iter()
1603        .map(|(vid, score)| {
1604            let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
1605            (*vid, norm)
1606        })
1607        .collect();
1608
1609    let all_vids: std::collections::HashSet<Vid> = vec_scores
1610        .keys()
1611        .chain(fts_scores.keys())
1612        .cloned()
1613        .collect();
1614
1615    let mut results: Vec<(Vid, f32)> = all_vids
1616        .into_iter()
1617        .map(|vid| {
1618            let vec_score = *vec_scores.get(&vid).unwrap_or(&0.0);
1619            let fts_score = *fts_scores.get(&vid).unwrap_or(&0.0);
1620            let fused = alpha * vec_score + (1.0 - alpha) * fts_score;
1621            (vid, fused)
1622        })
1623        .collect();
1624
1625    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1626    results
1627}
1628
/// Precomputed score context for hybrid search batch building.
///
/// Borrows the per-source score maps so [`build_hybrid_search_batch`] can emit
/// `vector_score` / `fts_score` / `distance` columns without re-running either search.
struct HybridScoreContext<'a> {
    /// Raw vector distances keyed by vertex id (converted to scores via `calculate_score`).
    vec_score_map: &'a HashMap<Vid, f32>,
    /// Raw full-text search scores keyed by vertex id.
    fts_score_map: &'a HashMap<Vid, f32>,
    /// Maximum FTS score; used to normalize `fts_score_map` values into [0, 1].
    fts_max: f32,
    /// Distance metric of the vector index; drives the distance-to-score conversion.
    metric: &'a uni_common::core::schema::DistanceMetric,
}
1636
/// Build a RecordBatch for hybrid search results with fused, vector, and FTS scores.
///
/// `results` holds the fused ranking as `(vid, fused_score)` pairs; per-source raw
/// scores are looked up in `scores`. Columns are emitted in `yield_items` order, and
/// each node-like yield expands into multiple columns — this ordering must line up
/// with the field order of `schema` or `RecordBatch::try_new` below will fail.
async fn build_hybrid_search_batch(
    results: &[(Vid, f32)],
    scores: &HybridScoreContext<'_>,
    label: &str,
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    graph_ctx: &GraphExecutionContext,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let num_rows = results.len();
    let vids: Vec<Vid> = results.iter().map(|(vid, _)| *vid).collect();
    let fused_scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();

    // Pre-load properties for node-like yields
    let property_manager = graph_ctx.property_manager();
    let query_ctx = graph_ctx.query_context();
    let uni_schema = graph_ctx.storage().schema_manager().schema();
    let label_props = uni_schema.properties.get(label);

    let has_node_yield = yield_items
        .iter()
        .any(|(name, _)| map_yield_to_canonical(name) == "node");

    // Only hit storage when a node yield actually needs property values.
    let props_map = if has_node_yield {
        property_manager
            .get_batch_vertex_props_for_label(&vids, label, Some(&query_ctx))
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?
    } else {
        HashMap::new()
    };

    let mut columns: Vec<ArrayRef> = Vec::new();

    for (name, alias) in yield_items {
        let output_name = alias.as_ref().unwrap_or(name);
        let canonical = map_yield_to_canonical(name);

        match canonical.as_str() {
            // Node yield expands to _vid / variable / _labels / property columns.
            "node" => {
                columns.extend(build_node_yield_columns(
                    &vids,
                    label,
                    output_name,
                    target_properties,
                    &props_map,
                    label_props,
                )?);
            }
            "vid" => {
                // NOTE(review): `as i64` wraps for ids above i64::MAX; matches
                // the cast used in build_search_result_batch.
                let mut builder = Int64Builder::with_capacity(num_rows);
                for vid in &vids {
                    builder.append_value(vid.as_u64() as i64);
                }
                columns.push(Arc::new(builder.finish()));
            }
            // "score" is the fused score produced by RRF or weighted fusion.
            "score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for score in &fused_scores {
                    builder.append_value(*score);
                }
                columns.push(Arc::new(builder.finish()));
            }
            // Vector-side score: metric-aware conversion of the raw distance;
            // null when the row came only from the FTS side.
            "vector_score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for vid in &vids {
                    if let Some(&dist) = scores.vec_score_map.get(vid) {
                        let score = calculate_score(dist, scores.metric);
                        builder.append_value(score);
                    } else {
                        builder.append_null();
                    }
                }
                columns.push(Arc::new(builder.finish()));
            }
            // FTS-side score normalized by the maximum raw score; null when the
            // row came only from the vector side.
            "fts_score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for vid in &vids {
                    if let Some(&raw_score) = scores.fts_score_map.get(vid) {
                        let norm = if scores.fts_max > 0.0 {
                            raw_score / scores.fts_max
                        } else {
                            0.0
                        };
                        builder.append_value(norm);
                    } else {
                        builder.append_null();
                    }
                }
                columns.push(Arc::new(builder.finish()));
            }
            "distance" => {
                // For hybrid search, distance is the vector distance if available
                let mut builder = Float64Builder::with_capacity(num_rows);
                for vid in &vids {
                    if let Some(&dist) = scores.vec_score_map.get(vid) {
                        builder.append_value(dist as f64);
                    } else {
                        builder.append_null();
                    }
                }
                columns.push(Arc::new(builder.finish()));
            }
            // Unknown yield name: emit an all-null utf8 placeholder column.
            _ => {
                let mut builder = StringBuilder::with_capacity(num_rows, 0);
                for _ in 0..num_rows {
                    builder.append_null();
                }
                columns.push(Arc::new(builder.finish()));
            }
        }
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1755
1756// ---------------------------------------------------------------------------
1757// Shared search result batch builder
1758// ---------------------------------------------------------------------------
1759
/// Build a RecordBatch for search procedures (vector, FTS) that yield
/// both node-like and scalar columns.
///
/// `results` is the ranked `(vid, distance)` list. Columns are emitted in
/// `yield_items` order (node yields expand into multiple columns), which must
/// match the field order of `schema` or `RecordBatch::try_new` below will fail.
async fn build_search_result_batch(
    results: &[(Vid, f32)],
    label: &str,
    metric: &uni_common::core::schema::DistanceMetric,
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    graph_ctx: &GraphExecutionContext,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let num_rows = results.len();
    let vids: Vec<Vid> = results.iter().map(|(vid, _)| *vid).collect();
    let distances: Vec<f32> = results.iter().map(|(_, d)| *d).collect();

    // Pre-compute scores
    // (metric-aware distance -> similarity conversion, one per row)
    let scores: Vec<f32> = distances
        .iter()
        .map(|dist| calculate_score(*dist, metric))
        .collect();

    // Pre-load properties for all node-like yields
    let property_manager = graph_ctx.property_manager();
    let query_ctx = graph_ctx.query_context();
    let uni_schema = graph_ctx.storage().schema_manager().schema();
    let label_props = uni_schema.properties.get(label);

    // Load properties if any node-like yield needs them
    let has_node_yield = yield_items
        .iter()
        .any(|(name, _)| map_yield_to_canonical(name) == "node");

    let props_map = if has_node_yield {
        property_manager
            .get_batch_vertex_props_for_label(&vids, label, Some(&query_ctx))
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?
    } else {
        HashMap::new()
    };

    // Build columns in schema order
    let mut columns: Vec<ArrayRef> = Vec::new();

    for (name, alias) in yield_items {
        let output_name = alias.as_ref().unwrap_or(name);
        let canonical = map_yield_to_canonical(name);

        match canonical.as_str() {
            // Node yield expands to _vid / variable / _labels / property columns.
            "node" => {
                columns.extend(build_node_yield_columns(
                    &vids,
                    label,
                    output_name,
                    target_properties,
                    &props_map,
                    label_props,
                )?);
            }
            "distance" => {
                let mut builder = Float64Builder::with_capacity(num_rows);
                for dist in &distances {
                    builder.append_value(*dist as f64);
                }
                columns.push(Arc::new(builder.finish()));
            }
            "score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for score in &scores {
                    builder.append_value(*score);
                }
                columns.push(Arc::new(builder.finish()));
            }
            "vid" => {
                // NOTE(review): `as i64` wraps for ids above i64::MAX; matches
                // the cast used in build_hybrid_search_batch.
                let mut builder = Int64Builder::with_capacity(num_rows);
                for vid in &vids {
                    builder.append_value(vid.as_u64() as i64);
                }
                columns.push(Arc::new(builder.finish()));
            }
            _ => {
                // Unknown yield — emit nulls
                let mut builder = StringBuilder::with_capacity(num_rows, 0);
                for _ in 0..num_rows {
                    builder.append_null();
                }
                columns.push(Arc::new(builder.finish()));
            }
        }
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1855
1856// ---------------------------------------------------------------------------
1857// Helpers
1858// ---------------------------------------------------------------------------
1859
1860/// Build the node-yield columns (_vid, variable, _labels, property columns) shared by
1861/// search result batch builders. Returns the columns to append.
1862fn build_node_yield_columns(
1863    vids: &[Vid],
1864    label: &str,
1865    output_name: &str,
1866    target_properties: &HashMap<String, Vec<String>>,
1867    props_map: &HashMap<Vid, uni_common::Properties>,
1868    label_props: Option<&std::collections::HashMap<String, uni_common::core::schema::PropertyMeta>>,
1869) -> DFResult<Vec<ArrayRef>> {
1870    let num_rows = vids.len();
1871    let mut columns = Vec::new();
1872
1873    // _vid column
1874    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
1875    for vid in vids {
1876        vid_builder.append_value(vid.as_u64());
1877    }
1878    columns.push(Arc::new(vid_builder.finish()) as ArrayRef);
1879
1880    // variable column (VID as string)
1881    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
1882    for vid in vids {
1883        var_builder.append_value(vid.to_string());
1884    }
1885    columns.push(Arc::new(var_builder.finish()) as ArrayRef);
1886
1887    // _labels column
1888    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
1889    for _ in 0..num_rows {
1890        labels_builder.values().append_value(label);
1891        labels_builder.append(true);
1892    }
1893    columns.push(Arc::new(labels_builder.finish()) as ArrayRef);
1894
1895    // Property columns
1896    if let Some(props) = target_properties.get(output_name) {
1897        for prop_name in props {
1898            let data_type = resolve_property_type(prop_name, label_props);
1899            let column = crate::query::df_graph::scan::build_property_column_static(
1900                vids, props_map, prop_name, &data_type,
1901            )?;
1902            columns.push(column);
1903        }
1904    }
1905
1906    Ok(columns)
1907}
1908
1909/// Extract a vector from a Value.
1910fn extract_vector(val: &Value) -> DFResult<Vec<f32>> {
1911    match val {
1912        Value::Vector(vec) => Ok(vec.clone()),
1913        Value::List(arr) => {
1914            let mut vec = Vec::with_capacity(arr.len());
1915            for v in arr {
1916                if let Some(f) = v.as_f64() {
1917                    vec.push(f as f32);
1918                } else {
1919                    return Err(datafusion::error::DataFusionError::Execution(
1920                        "Query vector must contain numbers".to_string(),
1921                    ));
1922                }
1923            }
1924            Ok(vec)
1925        }
1926        _ => Err(datafusion::error::DataFusionError::Execution(
1927            "Query vector must be a list or vector".to_string(),
1928        )),
1929    }
1930}
1931
1932/// Calculate normalized score from distance based on distance metric.
1933fn calculate_score(distance: f32, metric: &uni_common::core::schema::DistanceMetric) -> f32 {
1934    match metric {
1935        uni_common::core::schema::DistanceMetric::Cosine => {
1936            // Cosine distance → similarity: (2 - d) / 2
1937            (2.0 - distance) / 2.0
1938        }
1939        uni_common::core::schema::DistanceMetric::Dot => distance,
1940        _ => {
1941            // L2 and others: 1 / (1 + d)
1942            1.0 / (1.0 + distance)
1943        }
1944    }
1945}