Skip to main content

uni_query/query/df_graph/
vector_knn.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Vector KNN search execution plan for DataFusion.
5//!
6//! This module provides [`GraphVectorKnnExec`], a DataFusion [`ExecutionPlan`] that
7//! performs vector similarity search using the underlying vector index.
8//!
9//! # Example
10//!
11//! ```text
12//! CALL uni.vector.query('Person', 'embedding', [0.1, 0.2, ...], 10)
13//! YIELD node, score
14//! ```
15
16use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34use uni_plugin::traits::index::{IndexHandle, IndexKind};
35
36use crate::query::df_graph::GraphExecutionContext;
37use crate::query::df_graph::common::{
38    arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
39};
40use crate::query::df_graph::scan::{property_field, resolve_property_type};
41
/// Vector-retrieval source for a [`GraphVectorKnnExec`].
///
/// The exec is kind-agnostic above the retrieval step: threshold filter,
/// score normalization, label / vid emission, and property hydration all
/// run identically on the `Vec<(Vid, f32)>` produced here. Only the
/// retrieval call differs:
///
/// - [`VectorSource::Native`] dispatches to
///   `StorageManager::vector_search`, which routes through the built-in
///   vector backend (Lance / memory / etc.).
/// - [`VectorSource::Plugin`] dispatches to
///   [`IndexHandle::probe`] on a host-registered plugin handle (see
///   `PluginRegistry::register_index_handle`). The planner picks this
///   variant when an index-name lookup against the plugin registry
///   succeeds; this preserves the "no behavior change for built-ins"
///   invariant — native indexes never register a handle so the
///   fall-through is `Native`.
///
/// The enum is `Clone` because every `execute` call hands its own copy to
/// the stream it creates. Cloning a `Plugin` value only clones the `Arc`,
/// so all copies probe the same underlying handle.
#[derive(Clone)]
pub(crate) enum VectorSource {
    /// Native built-in vector backend (default).
    ///
    /// This is the value `GraphVectorKnnExec::new` installs.
    Native,
    /// Plugin-supplied live handle.
    ///
    /// Installed by `GraphVectorKnnExec::with_plugin_source`.
    Plugin {
        /// Kind that produced the handle. Informational; kept so the
        /// planner-level dispatch log can include it.
        #[allow(dead_code)]
        kind: IndexKind,
        /// The handle to probe.
        handle: Arc<dyn IndexHandle>,
    },
}
73
74impl fmt::Debug for VectorSource {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        match self {
77            Self::Native => f.write_str("Native"),
78            Self::Plugin { kind, .. } => f.debug_struct("Plugin").field("kind", kind).finish(),
79        }
80    }
81}
82
/// Vector KNN search execution plan.
///
/// Queries the vector index for the K nearest neighbors to a query vector,
/// returning matching vertex IDs and similarity scores.
///
/// This is a leaf node: it reports no children and rejects any attempt to
/// attach some. Each `execute` call evaluates `query_expr` eagerly and then
/// returns a stream that performs the actual search when first polled.
pub struct GraphVectorKnnExec {
    /// Graph execution context for storage access.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Label ID to search in.
    ///
    /// Stored and shown in `Debug` output; the search itself is issued by
    /// label name (see `label_name`).
    label_id: u16,

    /// Label name for display.
    ///
    /// Also used to resolve property metadata when the schema is built and
    /// passed to the stream as the label to search.
    label_name: String,

    /// Variable name for result vertices.
    ///
    /// Used as the prefix of every output column name.
    variable: String,

    /// Property name containing vector embeddings.
    property: String,

    /// Query vector expression.
    ///
    /// Must evaluate to a vector or a list of numbers.
    query_expr: Expr,

    /// Number of results to return.
    k: usize,

    /// Optional similarity threshold.
    ///
    /// Hits whose computed score is below this value are dropped.
    threshold: Option<f32>,

    /// Query parameters for expression evaluation.
    params: HashMap<String, Value>,

    /// Target vertex properties to materialize.
    ///
    /// Each entry adds one `{variable}.{prop}` column to the output.
    target_properties: Vec<String>,

    /// Output schema.
    schema: SchemaRef,

    /// Plan properties.
    properties: Arc<PlanProperties>,

    /// Vector-retrieval source. `Native` for the built-in path;
    /// `Plugin { handle, .. }` when the planner found a registered
    /// `IndexHandle` for this index's name in `PluginRegistry`.
    source: VectorSource,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
132
133impl fmt::Debug for GraphVectorKnnExec {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("GraphVectorKnnExec")
136            .field("label_id", &self.label_id)
137            .field("variable", &self.variable)
138            .field("property", &self.property)
139            .field("k", &self.k)
140            .field("threshold", &self.threshold)
141            .finish()
142    }
143}
144
145impl GraphVectorKnnExec {
146    /// Create a new vector KNN search execution plan.
147    ///
148    /// # Arguments
149    ///
150    /// * `graph_ctx` - Graph execution context
151    /// * `label_id` - Label ID to search
152    /// * `label_name` - Label name for display
153    /// * `variable` - Variable name for results
154    /// * `property` - Property containing vectors
155    /// * `query_expr` - Expression evaluating to query vector
156    /// * `k` - Number of results
157    /// * `threshold` - Optional similarity threshold
158    /// * `params` - Query parameters
159    #[expect(clippy::too_many_arguments)]
160    pub fn new(
161        graph_ctx: Arc<GraphExecutionContext>,
162        label_id: u16,
163        label_name: impl Into<String>,
164        variable: impl Into<String>,
165        property: impl Into<String>,
166        query_expr: Expr,
167        k: usize,
168        threshold: Option<f32>,
169        params: HashMap<String, Value>,
170        target_properties: Vec<String>,
171    ) -> Self {
172        let variable = variable.into();
173        let property = property.into();
174        let label_name = label_name.into();
175
176        // Resolve property types from schema
177        let uni_schema = graph_ctx.storage().schema_manager().schema();
178        let label_props = uni_schema.properties.get(label_name.as_str());
179
180        let schema = Self::build_schema(&variable, &target_properties, label_props);
181        let properties = compute_plan_properties(schema.clone());
182
183        Self {
184            graph_ctx,
185            label_id,
186            label_name,
187            variable,
188            property,
189            query_expr,
190            k,
191            threshold,
192            params,
193            target_properties,
194            schema,
195            properties,
196            source: VectorSource::Native,
197            metrics: ExecutionPlanMetricsSet::new(),
198        }
199    }
200
201    /// Create a new vector KNN search execution plan that dispatches
202    /// retrieval through a plugin-registered [`IndexHandle`] instead of
203    /// the native storage path.
204    ///
205    /// All other behavior (threshold, scoring, property hydration) is
206    /// identical to [`Self::new`].
207    #[expect(clippy::too_many_arguments)]
208    pub fn with_plugin_source(
209        graph_ctx: Arc<GraphExecutionContext>,
210        label_id: u16,
211        label_name: impl Into<String>,
212        variable: impl Into<String>,
213        property: impl Into<String>,
214        query_expr: Expr,
215        k: usize,
216        threshold: Option<f32>,
217        params: HashMap<String, Value>,
218        target_properties: Vec<String>,
219        kind: IndexKind,
220        handle: Arc<dyn IndexHandle>,
221    ) -> Self {
222        let mut exec = Self::new(
223            graph_ctx,
224            label_id,
225            label_name,
226            variable,
227            property,
228            query_expr,
229            k,
230            threshold,
231            params,
232            target_properties,
233        );
234        exec.source = VectorSource::Plugin { kind, handle };
235        exec
236    }
237
238    /// Build the output schema.
239    ///
240    /// Schema contains:
241    /// - `{variable}._vid` - Vertex ID
242    /// - `{variable}` - Variable identifier (as string for now)
243    /// - `{variable}._score` - Similarity score
244    /// - `{variable}.{prop}` - Property columns
245    fn build_schema(
246        variable: &str,
247        target_properties: &[String],
248        label_props: Option<&HashMap<String, PropertyMeta>>,
249    ) -> SchemaRef {
250        let mut fields = vec![
251            Field::new(format!("{}._vid", variable), DataType::UInt64, false),
252            Field::new(variable, DataType::Utf8, false),
253            Field::new(format!("{}._labels", variable), labels_data_type(), true),
254            Field::new(format!("{}._score", variable), DataType::Float32, true),
255        ];
256
257        // Add property columns
258        for prop_name in target_properties {
259            let col_name = format!("{}.{}", variable, prop_name);
260            let arrow_type = resolve_property_type(prop_name, label_props);
261            let uni_type = label_props
262                .and_then(|p| p.get(prop_name))
263                .map(|m| &m.r#type);
264            fields.push(property_field(&col_name, arrow_type, uni_type));
265        }
266
267        Arc::new(Schema::new(fields))
268    }
269
270    /// Evaluate the query expression to extract the query vector.
271    fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
272        let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
273
274        match value {
275            Value::Vector(vec) => Ok(vec),
276            Value::List(arr) => {
277                let mut vec = Vec::with_capacity(arr.len());
278                for v in arr {
279                    if let Some(f) = v.as_f64() {
280                        vec.push(f as f32);
281                    } else {
282                        return Err(datafusion::error::DataFusionError::Execution(
283                            "Query vector must contain numbers".to_string(),
284                        ));
285                    }
286                }
287                Ok(vec)
288            }
289            _ => Err(datafusion::error::DataFusionError::Execution(
290                "Query vector must be a list or vector".to_string(),
291            )),
292        }
293    }
294}
295
296impl DisplayAs for GraphVectorKnnExec {
297    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        write!(
299            f,
300            "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
301            self.label_name, self.property, self.k, self.variable
302        )
303    }
304}
305
306impl ExecutionPlan for GraphVectorKnnExec {
307    fn name(&self) -> &str {
308        "GraphVectorKnnExec"
309    }
310
311    fn as_any(&self) -> &dyn Any {
312        self
313    }
314
315    fn schema(&self) -> SchemaRef {
316        self.schema.clone()
317    }
318
319    fn properties(&self) -> &Arc<PlanProperties> {
320        &self.properties
321    }
322
323    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
324        vec![]
325    }
326
327    fn with_new_children(
328        self: Arc<Self>,
329        children: Vec<Arc<dyn ExecutionPlan>>,
330    ) -> DFResult<Arc<dyn ExecutionPlan>> {
331        if !children.is_empty() {
332            return Err(datafusion::error::DataFusionError::Internal(
333                "GraphVectorKnnExec has no children".to_string(),
334            ));
335        }
336        Ok(self)
337    }
338
339    fn execute(
340        &self,
341        partition: usize,
342        _context: Arc<TaskContext>,
343    ) -> DFResult<SendableRecordBatchStream> {
344        let metrics = BaselineMetrics::new(&self.metrics, partition);
345
346        // Evaluate query vector upfront
347        let query_vector = self.evaluate_query_vector()?;
348
349        Ok(Box::pin(VectorKnnStream::new(
350            self.graph_ctx.clone(),
351            self.label_name.clone(),
352            self.variable.clone(),
353            self.property.clone(),
354            query_vector,
355            self.k,
356            self.threshold,
357            self.target_properties.clone(),
358            self.schema.clone(),
359            self.source.clone(),
360            metrics,
361        )))
362    }
363
364    fn metrics(&self) -> Option<MetricsSet> {
365        Some(self.metrics.clone_inner())
366    }
367}
368
/// State machine for vector KNN stream.
///
/// A stream starts in `Init`, moves to `Executing` on its first poll, and
/// ends in `Done` once the search future has resolved (with a batch or an
/// error). It never leaves `Done`.
enum VectorKnnState {
    /// Initial state, ready to start search.
    Init,
    /// Executing the async search.
    ///
    /// Holds the boxed search future; it resolves to the single result
    /// batch (or `None`) or to an error.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// Stream is done.
    Done,
}
378
/// Stream that executes vector KNN search.
///
/// The search is started lazily on the first poll. The stream yields at
/// most one item — the result batch or an error — and then terminates.
struct VectorKnnStream {
    /// Graph execution context.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Label name to search.
    label_name: String,

    /// Variable name for results.
    variable: String,

    /// Property name containing vectors.
    property: String,

    /// Query vector.
    ///
    /// Already evaluated by the plan node before the stream was created.
    query_vector: Vec<f32>,

    /// Number of results.
    k: usize,

    /// Similarity threshold.
    threshold: Option<f32>,

    /// Target vertex properties to materialize.
    target_properties: Vec<String>,

    /// Output schema.
    schema: SchemaRef,

    /// Vector-retrieval source (native or plugin handle).
    source: VectorSource,

    /// Stream state.
    state: VectorKnnState,

    /// Metrics.
    ///
    /// Records compute time per poll and the number of output rows.
    metrics: BaselineMetrics,
}
417
418impl VectorKnnStream {
419    #[expect(clippy::too_many_arguments)]
420    fn new(
421        graph_ctx: Arc<GraphExecutionContext>,
422        label_name: String,
423        variable: String,
424        property: String,
425        query_vector: Vec<f32>,
426        k: usize,
427        threshold: Option<f32>,
428        target_properties: Vec<String>,
429        schema: SchemaRef,
430        source: VectorSource,
431        metrics: BaselineMetrics,
432    ) -> Self {
433        Self {
434            graph_ctx,
435            label_name,
436            variable,
437            property,
438            query_vector,
439            k,
440            threshold,
441            target_properties,
442            schema,
443            source,
444            state: VectorKnnState::Init,
445            metrics,
446        }
447    }
448}
449
450impl Stream for VectorKnnStream {
451    type Item = DFResult<RecordBatch>;
452
453    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
454        let metrics = self.metrics.clone();
455        let _timer = metrics.elapsed_compute().timer();
456        loop {
457            let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
458
459            match state {
460                VectorKnnState::Init => {
461                    // Clone data for async block
462                    let graph_ctx = self.graph_ctx.clone();
463                    let label_name = self.label_name.clone();
464                    let variable = self.variable.clone();
465                    let property = self.property.clone();
466                    let query_vector = self.query_vector.clone();
467                    let k = self.k;
468                    let threshold = self.threshold;
469                    let target_properties = self.target_properties.clone();
470                    let schema = self.schema.clone();
471                    let source = self.source.clone();
472
473                    let fut = async move {
474                        // Check timeout
475                        graph_ctx.check_timeout().map_err(|e| {
476                            datafusion::error::DataFusionError::Execution(e.to_string())
477                        })?;
478
479                        execute_vector_search(
480                            &graph_ctx,
481                            &label_name,
482                            &variable,
483                            &property,
484                            &query_vector,
485                            k,
486                            threshold,
487                            &target_properties,
488                            &schema,
489                            &source,
490                        )
491                        .await
492                    };
493
494                    self.state = VectorKnnState::Executing(Box::pin(fut));
495                    // Continue loop to poll the future
496                }
497                VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
498                    Poll::Ready(Ok(batch)) => {
499                        self.state = VectorKnnState::Done;
500                        self.metrics
501                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
502                        return Poll::Ready(batch.map(Ok));
503                    }
504                    Poll::Ready(Err(e)) => {
505                        self.state = VectorKnnState::Done;
506                        return Poll::Ready(Some(Err(e)));
507                    }
508                    Poll::Pending => {
509                        self.state = VectorKnnState::Executing(fut);
510                        return Poll::Pending;
511                    }
512                },
513                VectorKnnState::Done => {
514                    return Poll::Ready(None);
515                }
516            }
517        }
518    }
519}
520
521impl RecordBatchStream for VectorKnnStream {
522    fn schema(&self) -> SchemaRef {
523        self.schema.clone()
524    }
525}
526
527/// Execute the vector search and build results.
528#[expect(clippy::too_many_arguments)]
529async fn execute_vector_search(
530    graph_ctx: &GraphExecutionContext,
531    label_name: &str,
532    variable: &str,
533    property: &str,
534    query_vector: &[f32],
535    k: usize,
536    threshold: Option<f32>,
537    target_properties: &[String],
538    schema: &SchemaRef,
539    source: &VectorSource,
540) -> DFResult<Option<RecordBatch>> {
541    let storage = graph_ctx.storage();
542
543    // Retrieve `(vid, distance)` pairs via the configured source.
544    let results =
545        retrieve_vid_scores(graph_ctx, label_name, property, query_vector, k, source).await?;
546
547    // Look up the distance metric for this vector property so we can
548    // convert raw distances into normalised similarity scores correctly.
549    let metric = storage
550        .schema_manager()
551        .schema()
552        .vector_index_for_property(label_name, property)
553        .map(|cfg| cfg.metric.clone())
554        .unwrap_or(DistanceMetric::L2);
555
556    // Filter by threshold and build result
557    let mut vids = Vec::new();
558    let mut scores = Vec::new();
559
560    for (vid, distance) in results {
561        let similarity = calculate_score(distance, &metric);
562
563        if let Some(thresh) = threshold
564            && similarity < thresh
565        {
566            continue;
567        }
568
569        vids.push(vid);
570        scores.push(similarity);
571    }
572
573    if vids.is_empty() {
574        return Ok(Some(RecordBatch::new_empty(schema.clone())));
575    }
576
577    // Build the base record batch (VID, variable, score)
578    let batch = build_result_batch(
579        &vids,
580        &scores,
581        variable,
582        target_properties,
583        label_name,
584        graph_ctx,
585        schema,
586    )
587    .await?;
588    Ok(Some(batch))
589}
590
591/// Retrieve `(Vid, distance)` pairs for the configured [`VectorSource`].
592///
593/// - [`VectorSource::Native`] delegates to `StorageManager::vector_search`,
594///   which routes through the built-in vector backend (Lance / memory).
595/// - [`VectorSource::Plugin`] builds a 1-row probe batch carrying the
596///   query vector as `FixedSizeList<Float32>`, calls
597///   [`IndexHandle::probe`], then extracts the `(vid: Int64, distance:
598///   Float32)` columns from the result. Plugin handles emit vids as
599///   `i64`; we widen via `as u64` because graph vids are stored as
600///   non-negative `u64` and test fixtures (and any sane real index) only
601///   produce non-negative integers.
602async fn retrieve_vid_scores(
603    graph_ctx: &GraphExecutionContext,
604    label_name: &str,
605    property: &str,
606    query_vector: &[f32],
607    k: usize,
608    source: &VectorSource,
609) -> DFResult<Vec<(Vid, f32)>> {
610    match source {
611        VectorSource::Native => {
612            let storage = graph_ctx.storage();
613            let query_ctx = graph_ctx.query_context();
614            storage
615                .vector_search(
616                    label_name,
617                    property,
618                    query_vector,
619                    k,
620                    None,
621                    Some(&query_ctx),
622                )
623                .await
624                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
625        }
626        VectorSource::Plugin { handle, .. } => {
627            // Build a single-row query batch:
628            //     [ vector: FixedSizeList<Float32, dim> ]
629            let dim = i32::try_from(query_vector.len()).map_err(|_| {
630                datafusion::error::DataFusionError::Execution(
631                    "query vector exceeds i32::MAX dimensions".to_string(),
632                )
633            })?;
634            let item_field = Arc::new(Field::new("item", DataType::Float32, true));
635            let mut fsl_builder =
636                FixedSizeListBuilder::new(Float32Builder::with_capacity(query_vector.len()), dim)
637                    .with_field(Arc::clone(&item_field));
638            for &v in query_vector {
639                fsl_builder.values().append_value(v);
640            }
641            fsl_builder.append(true);
642            let fsl: FixedSizeListArray = fsl_builder.finish();
643
644            let query_schema = Arc::new(Schema::new(vec![Field::new(
645                "vector",
646                DataType::FixedSizeList(item_field, dim),
647                false,
648            )]));
649            let query_batch =
650                RecordBatch::try_new(query_schema, vec![Arc::new(fsl)]).map_err(arrow_err)?;
651
652            let result = handle.probe(&query_batch, k).map_err(|e| {
653                datafusion::error::DataFusionError::Execution(format!(
654                    "IndexHandle::probe failed: {e:?}"
655                ))
656            })?;
657
658            // Result schema is `[vid: Int64, distance: Float32]` per the
659            // `IndexHandle` trait contract.
660            let vid_col = result
661                .column_by_name("vid")
662                .ok_or_else(|| {
663                    datafusion::error::DataFusionError::Execution(
664                        "IndexHandle::probe result missing `vid` column".to_string(),
665                    )
666                })?
667                .as_any()
668                .downcast_ref::<Int64Array>()
669                .ok_or_else(|| {
670                    datafusion::error::DataFusionError::Execution(
671                        "IndexHandle::probe result `vid` column is not Int64".to_string(),
672                    )
673                })?;
674            let dist_col = result
675                .column_by_name("distance")
676                .ok_or_else(|| {
677                    datafusion::error::DataFusionError::Execution(
678                        "IndexHandle::probe result missing `distance` column".to_string(),
679                    )
680                })?
681                .as_any()
682                .downcast_ref::<Float32Array>()
683                .ok_or_else(|| {
684                    datafusion::error::DataFusionError::Execution(
685                        "IndexHandle::probe result `distance` column is not Float32".to_string(),
686                    )
687                })?;
688
689            let mut pairs = Vec::with_capacity(result.num_rows());
690            for i in 0..result.num_rows() {
691                if vid_col.is_null(i) {
692                    continue;
693                }
694                let vid_i64 = vid_col.value(i);
695                let dist = if dist_col.is_null(i) {
696                    f32::INFINITY
697                } else {
698                    dist_col.value(i)
699                };
700                pairs.push((Vid::from(vid_i64 as u64), dist));
701            }
702            Ok(pairs)
703        }
704    }
705}
706
707/// Build a result batch from VIDs and scores, including hydrated properties.
708async fn build_result_batch(
709    vids: &[Vid],
710    scores: &[f32],
711    _variable: &str,
712    target_properties: &[String],
713    label_name: &str,
714    graph_ctx: &GraphExecutionContext,
715    schema: &SchemaRef,
716) -> DFResult<RecordBatch> {
717    let num_rows = vids.len();
718
719    // Build _vid column
720    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
721    for vid in vids {
722        vid_builder.append_value(vid.as_u64());
723    }
724
725    // Build variable column (VID as string for now)
726    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
727    for vid in vids {
728        var_builder.append_value(vid.to_string());
729    }
730
731    // Build _labels column
732    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
733    for _vid in vids {
734        labels_builder.values().append_value(label_name);
735        labels_builder.append(true);
736    }
737
738    // Build score column
739    let mut score_builder = Float32Builder::with_capacity(num_rows);
740    for &score in scores {
741        score_builder.append_value(score);
742    }
743
744    let mut columns: Vec<ArrayRef> = vec![
745        Arc::new(vid_builder.finish()),
746        Arc::new(var_builder.finish()),
747        Arc::new(labels_builder.finish()),
748        Arc::new(score_builder.finish()),
749    ];
750
751    // Hydrate property columns
752    if !target_properties.is_empty() {
753        let property_manager = graph_ctx.property_manager();
754        let query_ctx = graph_ctx.query_context();
755
756        let props_map = property_manager
757            .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
758            .await
759            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
760
761        let uni_schema = graph_ctx.storage().schema_manager().schema();
762        let label_props = uni_schema.properties.get(label_name);
763
764        for prop_name in target_properties {
765            let data_type = resolve_property_type(prop_name, label_props);
766            let column = crate::query::df_graph::scan::build_property_column_static(
767                vids, &props_map, prop_name, &data_type,
768            )?;
769            columns.push(column);
770        }
771    }
772
773    RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
774}
775
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// With no extra properties the schema is exactly the four base columns.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(names, ["n._vid", "n", "n._labels", "n._score"]);
    }

    /// A literal list of floats evaluates to a list of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(
            [0.1, 0.2, 0.3]
                .into_iter()
                .map(|x| Expr::Literal(CypherLiteral::Float(x)))
                .collect(),
        );

        let evaluated = evaluate_simple_expr(&expr, &HashMap::new(), &HashMap::new()).unwrap();
        let Value::List(items) = evaluated else {
            panic!("Expected list");
        };
        assert_eq!(items.len(), 3);
    }

    /// A parameter reference resolves to the parameter's value.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let params = HashMap::from([(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        )]);

        let evaluated = evaluate_simple_expr(&expr, &params, &HashMap::new()).unwrap();
        let Value::List(items) = evaluated else {
            panic!("Expected list");
        };
        assert_eq!(items.len(), 2);
    }

    /// Requested properties show up as `{variable}.{prop}` columns next to
    /// the base columns.
    #[test]
    fn test_build_schema_with_extra_properties() {
        let extra_props = vec!["name".to_string(), "embedding".to_string()];
        let schema = GraphVectorKnnExec::build_schema("doc", &extra_props, None);

        for base_column in ["doc._vid", "doc", "doc._score"] {
            assert!(schema.field_with_name(base_column).is_ok());
        }
        for prop in ["name", "embedding"] {
            assert!(
                schema.field_with_name(&format!("doc.{prop}")).is_ok(),
                "Extra property '{prop}' should be in schema"
            );
        }
    }

    /// A variable reference resolves to the variable's bound value.
    #[test]
    fn test_evaluate_variable() {
        let expr = Expr::Variable("x".to_string());
        let variables = HashMap::from([(
            "x".to_string(),
            Value::List(vec![Value::Float(0.5), Value::Float(0.6)]),
        )]);

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &variables).unwrap();
        // Match through a reference so `result` stays available for the
        // failure message.
        let Value::List(items) = &result else {
            panic!("Expected list, got {:?}", result);
        };
        assert_eq!(items.len(), 2);
    }
}