Skip to main content

uni_query/query/df_graph/
vector_knn.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Vector KNN search execution plan for DataFusion.
5//!
6//! This module provides [`GraphVectorKnnExec`], a DataFusion [`ExecutionPlan`] that
7//! performs vector similarity search using the underlying vector index.
8//!
9//! # Example
10//!
11//! ```text
12//! CALL uni.vector.query('Person', 'embedding', [0.1, 0.2, ...], 10)
13//! YIELD node, score
14//! ```
15
16use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34use uni_plugin::traits::index::{IndexHandle, IndexKind};
35
36use crate::query::df_graph::GraphExecutionContext;
37use crate::query::df_graph::common::{
38    arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
39};
40use crate::query::df_graph::scan::{property_field, resolve_property_type};
41
/// Vector-retrieval source for a [`GraphVectorKnnExec`].
///
/// The exec is kind-agnostic above the retrieval step: threshold filter,
/// score normalization, label / vid emission, and property hydration all
/// run identically on the `Vec<(Vid, f32)>` produced by retrieval. Only
/// the retrieval call itself differs between the variants:
///
/// - [`VectorSource::Native`] dispatches to
///   `StorageManager::vector_search`, which routes through the built-in
///   vector backend (Lance / memory / etc.).
/// - [`VectorSource::Plugin`] dispatches to
///   [`IndexHandle::probe`] on a host-registered plugin handle (see
///   `PluginRegistry::register_index_handle`). The planner picks this
///   variant when an index-name lookup against the plugin registry
///   succeeds; this preserves the "no behavior change for built-ins"
///   invariant — native indexes never register a handle so the
///   fall-through is `Native`.
///
/// Cloning is cheap: the plugin handle is held behind an [`Arc`], so a
/// clone shares the same handle rather than duplicating it.
#[derive(Clone)]
pub(crate) enum VectorSource {
    /// Native built-in vector backend (default).
    Native,
    /// Plugin-supplied live handle.
    Plugin {
        /// Kind that produced the handle. Informational; kept so the
        /// planner-level dispatch log can include it. Nothing in this
        /// module reads it outside the `Debug` impl, hence the
        /// `dead_code` allowance.
        #[allow(dead_code)]
        kind: IndexKind,
        /// The handle to probe.
        handle: Arc<dyn IndexHandle>,
    },
}
73
74impl fmt::Debug for VectorSource {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        match self {
77            Self::Native => f.write_str("Native"),
78            Self::Plugin { kind, .. } => f.debug_struct("Plugin").field("kind", kind).finish(),
79        }
80    }
81}
82
/// Vector KNN search execution plan.
///
/// Queries the vector index for the K nearest neighbors to a query vector,
/// returning matching vertex IDs and similarity scores, plus any requested
/// vertex properties as extra columns. The plan is a leaf: it has no child
/// plans and produces its rows from the configured [`VectorSource`].
pub struct GraphVectorKnnExec {
    /// Graph execution context for storage access.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Label ID to search in.
    label_id: u16,

    /// Label name. Used for schema lookups, for the retrieval call, for the
    /// emitted `_labels` column, and for display.
    label_name: String,

    /// Variable name for result vertices; also the prefix of every output
    /// column name.
    variable: String,

    /// Property name containing vector embeddings.
    property: String,

    /// Query vector expression. Evaluated once per `execute` call.
    query_expr: Expr,

    /// Number of results to return.
    k: usize,

    /// Optional similarity threshold. Rows whose similarity is below this
    /// value are dropped from the output.
    threshold: Option<f32>,

    /// Query parameters for expression evaluation.
    params: HashMap<String, Value>,

    /// Target vertex properties to materialize.
    target_properties: Vec<String>,

    /// Output schema.
    schema: SchemaRef,

    /// Plan properties.
    properties: Arc<PlanProperties>,

    /// Vector-retrieval source. `Native` for the built-in path;
    /// `Plugin { handle, .. }` when the planner found a registered
    /// `IndexHandle` for this index's name in `PluginRegistry`.
    source: VectorSource,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
132
133impl fmt::Debug for GraphVectorKnnExec {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("GraphVectorKnnExec")
136            .field("label_id", &self.label_id)
137            .field("variable", &self.variable)
138            .field("property", &self.property)
139            .field("k", &self.k)
140            .field("threshold", &self.threshold)
141            .finish()
142    }
143}
144
145impl GraphVectorKnnExec {
146    /// Create a new vector KNN search execution plan.
147    ///
148    /// # Arguments
149    ///
150    /// * `graph_ctx` - Graph execution context
151    /// * `label_id` - Label ID to search
152    /// * `label_name` - Label name for display
153    /// * `variable` - Variable name for results
154    /// * `property` - Property containing vectors
155    /// * `query_expr` - Expression evaluating to query vector
156    /// * `k` - Number of results
157    /// * `threshold` - Optional similarity threshold
158    /// * `params` - Query parameters
159    #[expect(clippy::too_many_arguments)]
160    pub fn new(
161        graph_ctx: Arc<GraphExecutionContext>,
162        label_id: u16,
163        label_name: impl Into<String>,
164        variable: impl Into<String>,
165        property: impl Into<String>,
166        query_expr: Expr,
167        k: usize,
168        threshold: Option<f32>,
169        params: HashMap<String, Value>,
170        target_properties: Vec<String>,
171    ) -> Self {
172        let variable = variable.into();
173        let property = property.into();
174        let label_name = label_name.into();
175
176        // Resolve property types from schema
177        let uni_schema = graph_ctx.storage().schema_manager().schema();
178        let label_props = uni_schema.properties.get(label_name.as_str());
179
180        let schema = Self::build_schema(&variable, &target_properties, label_props);
181        let properties = compute_plan_properties(schema.clone());
182
183        Self {
184            graph_ctx,
185            label_id,
186            label_name,
187            variable,
188            property,
189            query_expr,
190            k,
191            threshold,
192            params,
193            target_properties,
194            schema,
195            properties,
196            source: VectorSource::Native,
197            metrics: ExecutionPlanMetricsSet::new(),
198        }
199    }
200
201    /// Create a new vector KNN search execution plan that dispatches
202    /// retrieval through a plugin-registered [`IndexHandle`] instead of
203    /// the native storage path.
204    ///
205    /// All other behavior (threshold, scoring, property hydration) is
206    /// identical to [`Self::new`].
207    #[expect(clippy::too_many_arguments)]
208    pub fn with_plugin_source(
209        graph_ctx: Arc<GraphExecutionContext>,
210        label_id: u16,
211        label_name: impl Into<String>,
212        variable: impl Into<String>,
213        property: impl Into<String>,
214        query_expr: Expr,
215        k: usize,
216        threshold: Option<f32>,
217        params: HashMap<String, Value>,
218        target_properties: Vec<String>,
219        kind: IndexKind,
220        handle: Arc<dyn IndexHandle>,
221    ) -> Self {
222        let mut exec = Self::new(
223            graph_ctx,
224            label_id,
225            label_name,
226            variable,
227            property,
228            query_expr,
229            k,
230            threshold,
231            params,
232            target_properties,
233        );
234        exec.source = VectorSource::Plugin { kind, handle };
235        exec
236    }
237
238    /// Build the output schema.
239    ///
240    /// Schema contains:
241    /// - `{variable}._vid` - Vertex ID
242    /// - `{variable}` - Variable identifier (as string for now)
243    /// - `{variable}._score` - Similarity score
244    /// - `{variable}.{prop}` - Property columns
245    fn build_schema(
246        variable: &str,
247        target_properties: &[String],
248        label_props: Option<&HashMap<String, PropertyMeta>>,
249    ) -> SchemaRef {
250        let mut fields = vec![
251            Field::new(format!("{}._vid", variable), DataType::UInt64, false),
252            Field::new(variable, DataType::Utf8, false),
253            Field::new(format!("{}._labels", variable), labels_data_type(), true),
254            Field::new(format!("{}._score", variable), DataType::Float32, true),
255        ];
256
257        // Add property columns
258        for prop_name in target_properties {
259            let col_name = format!("{}.{}", variable, prop_name);
260            let arrow_type = resolve_property_type(prop_name, label_props);
261            let uni_type = label_props
262                .and_then(|p| p.get(prop_name))
263                .map(|m| &m.r#type);
264            fields.push(property_field(&col_name, arrow_type, uni_type));
265        }
266
267        Arc::new(Schema::new(fields))
268    }
269
270    /// Evaluate the query expression to extract the query vector.
271    fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
272        let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
273
274        match value {
275            Value::Vector(vec) => Ok(vec),
276            Value::List(arr) => {
277                let mut vec = Vec::with_capacity(arr.len());
278                for v in arr {
279                    if let Some(f) = v.as_f64() {
280                        vec.push(f as f32);
281                    } else {
282                        return Err(datafusion::error::DataFusionError::Execution(
283                            "Query vector must contain numbers".to_string(),
284                        ));
285                    }
286                }
287                Ok(vec)
288            }
289            _ => Err(datafusion::error::DataFusionError::Execution(
290                "Query vector must be a list or vector".to_string(),
291            )),
292        }
293    }
294
295    /// Evaluate the query expression to a multi-vector (a list of token vectors).
296    fn evaluate_query_multivector(&self) -> DFResult<Vec<Vec<f32>>> {
297        let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
298        let Value::List(tokens) = value else {
299            return Err(datafusion::error::DataFusionError::Execution(
300                "Multi-vector query must be a list of vectors".to_string(),
301            ));
302        };
303        tokens
304            .into_iter()
305            .map(|tok| match tok {
306                Value::Vector(v) => Ok(v),
307                Value::List(inner) => inner
308                    .iter()
309                    .map(|x| {
310                        x.as_f64().map(|f| f as f32).ok_or_else(|| {
311                            datafusion::error::DataFusionError::Execution(
312                                "Multi-vector query token must contain numbers".to_string(),
313                            )
314                        })
315                    })
316                    .collect(),
317                _ => Err(datafusion::error::DataFusionError::Execution(
318                    "Multi-vector query must be a list of vectors".to_string(),
319                )),
320            })
321            .collect()
322    }
323
324    /// Whether the queried property is a multi-vector (`List<FixedSizeList>`) column.
325    fn is_multivector_property(&self) -> bool {
326        let uni_schema = self.graph_ctx.storage().schema_manager().schema();
327        let label_props = uni_schema.properties.get(self.label_name.as_str());
328        matches!(
329            resolve_property_type(&self.property, label_props),
330            DataType::List(ref inner)
331                if matches!(inner.data_type(), DataType::FixedSizeList(_, _))
332        )
333    }
334}
335
336impl DisplayAs for GraphVectorKnnExec {
337    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338        write!(
339            f,
340            "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
341            self.label_name, self.property, self.k, self.variable
342        )
343    }
344}
345
346impl ExecutionPlan for GraphVectorKnnExec {
347    fn name(&self) -> &str {
348        "GraphVectorKnnExec"
349    }
350
351    fn as_any(&self) -> &dyn Any {
352        self
353    }
354
355    fn schema(&self) -> SchemaRef {
356        self.schema.clone()
357    }
358
359    fn properties(&self) -> &Arc<PlanProperties> {
360        &self.properties
361    }
362
363    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
364        vec![]
365    }
366
367    fn with_new_children(
368        self: Arc<Self>,
369        children: Vec<Arc<dyn ExecutionPlan>>,
370    ) -> DFResult<Arc<dyn ExecutionPlan>> {
371        if !children.is_empty() {
372            return Err(datafusion::error::DataFusionError::Internal(
373                "GraphVectorKnnExec has no children".to_string(),
374            ));
375        }
376        Ok(self)
377    }
378
379    fn execute(
380        &self,
381        partition: usize,
382        _context: Arc<TaskContext>,
383    ) -> DFResult<SendableRecordBatchStream> {
384        let metrics = BaselineMetrics::new(&self.metrics, partition);
385
386        // Evaluate the query upfront: a multi-vector (ColBERT) property takes a
387        // list of token vectors and routes to MaxSim retrieval; a dense property
388        // takes a single vector.
389        let (query_vector, multivec_query) = if self.is_multivector_property() {
390            (Vec::new(), Some(self.evaluate_query_multivector()?))
391        } else {
392            (self.evaluate_query_vector()?, None)
393        };
394
395        Ok(Box::pin(VectorKnnStream::new(
396            self.graph_ctx.clone(),
397            self.label_name.clone(),
398            self.variable.clone(),
399            self.property.clone(),
400            query_vector,
401            multivec_query,
402            self.k,
403            self.threshold,
404            self.target_properties.clone(),
405            self.schema.clone(),
406            self.source.clone(),
407            metrics,
408        )))
409    }
410
411    fn metrics(&self) -> Option<MetricsSet> {
412        Some(self.metrics.clone_inner())
413    }
414}
415
/// State machine for vector KNN stream.
///
/// The stream moves `Init` -> `Executing` -> `Done` and yields at most one
/// item along the way.
enum VectorKnnState {
    /// Initial state, ready to start search.
    Init,
    /// Executing the async search. The boxed future resolves to the single
    /// optional result batch, or to the error that aborted the search.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// Stream is done.
    Done,
}
425
/// Stream that executes vector KNN search.
///
/// The search is started lazily on the first poll; the inputs held here are
/// cloned into the search future at that point.
struct VectorKnnStream {
    /// Graph execution context.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Label name to search.
    label_name: String,

    /// Variable name for results.
    variable: String,

    /// Property name containing vectors.
    property: String,

    /// Query vector (dense path). Empty when `multivec_query` is `Some`.
    query_vector: Vec<f32>,

    /// Query multi-vector (ColBERT / MaxSim path); `Some` when the queried
    /// property is a `List<Vector>` column.
    multivec_query: Option<Vec<Vec<f32>>>,

    /// Number of results.
    k: usize,

    /// Similarity threshold.
    threshold: Option<f32>,

    /// Target vertex properties to materialize.
    target_properties: Vec<String>,

    /// Output schema.
    schema: SchemaRef,

    /// Vector-retrieval source (native or plugin handle).
    source: VectorSource,

    /// Stream state.
    state: VectorKnnState,

    /// Metrics.
    metrics: BaselineMetrics,
}
468
469impl VectorKnnStream {
470    #[expect(clippy::too_many_arguments)]
471    fn new(
472        graph_ctx: Arc<GraphExecutionContext>,
473        label_name: String,
474        variable: String,
475        property: String,
476        query_vector: Vec<f32>,
477        multivec_query: Option<Vec<Vec<f32>>>,
478        k: usize,
479        threshold: Option<f32>,
480        target_properties: Vec<String>,
481        schema: SchemaRef,
482        source: VectorSource,
483        metrics: BaselineMetrics,
484    ) -> Self {
485        Self {
486            graph_ctx,
487            label_name,
488            variable,
489            property,
490            query_vector,
491            multivec_query,
492            k,
493            threshold,
494            target_properties,
495            schema,
496            source,
497            state: VectorKnnState::Init,
498            metrics,
499        }
500    }
501}
502
503impl Stream for VectorKnnStream {
504    type Item = DFResult<RecordBatch>;
505
506    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
507        let metrics = self.metrics.clone();
508        let _timer = metrics.elapsed_compute().timer();
509        loop {
510            let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
511
512            match state {
513                VectorKnnState::Init => {
514                    // Clone data for async block
515                    let graph_ctx = self.graph_ctx.clone();
516                    let label_name = self.label_name.clone();
517                    let variable = self.variable.clone();
518                    let property = self.property.clone();
519                    let query_vector = self.query_vector.clone();
520                    let multivec_query = self.multivec_query.clone();
521                    let k = self.k;
522                    let threshold = self.threshold;
523                    let target_properties = self.target_properties.clone();
524                    let schema = self.schema.clone();
525                    let source = self.source.clone();
526
527                    let fut = async move {
528                        // Check timeout
529                        graph_ctx.check_timeout().map_err(|e| {
530                            datafusion::error::DataFusionError::Execution(e.to_string())
531                        })?;
532
533                        execute_vector_search(
534                            &graph_ctx,
535                            &label_name,
536                            &variable,
537                            &property,
538                            &query_vector,
539                            multivec_query.as_deref(),
540                            k,
541                            threshold,
542                            &target_properties,
543                            &schema,
544                            &source,
545                        )
546                        .await
547                    };
548
549                    self.state = VectorKnnState::Executing(Box::pin(fut));
550                    // Continue loop to poll the future
551                }
552                VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
553                    Poll::Ready(Ok(batch)) => {
554                        self.state = VectorKnnState::Done;
555                        self.metrics
556                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
557                        return Poll::Ready(batch.map(Ok));
558                    }
559                    Poll::Ready(Err(e)) => {
560                        self.state = VectorKnnState::Done;
561                        return Poll::Ready(Some(Err(e)));
562                    }
563                    Poll::Pending => {
564                        self.state = VectorKnnState::Executing(fut);
565                        return Poll::Pending;
566                    }
567                },
568                VectorKnnState::Done => {
569                    return Poll::Ready(None);
570                }
571            }
572        }
573    }
574}
575
576impl RecordBatchStream for VectorKnnStream {
577    fn schema(&self) -> SchemaRef {
578        self.schema.clone()
579    }
580}
581
582/// Execute the vector search and build results.
583#[expect(clippy::too_many_arguments)]
584async fn execute_vector_search(
585    graph_ctx: &GraphExecutionContext,
586    label_name: &str,
587    variable: &str,
588    property: &str,
589    query_vector: &[f32],
590    multivec_query: Option<&[Vec<f32>]>,
591    k: usize,
592    threshold: Option<f32>,
593    target_properties: &[String],
594    schema: &SchemaRef,
595    source: &VectorSource,
596) -> DFResult<Option<RecordBatch>> {
597    let storage = graph_ctx.storage();
598
599    // Retrieve `(vid, distance)` pairs via the configured source.
600    let results = retrieve_vid_scores(
601        graph_ctx,
602        label_name,
603        property,
604        query_vector,
605        multivec_query,
606        k,
607        source,
608    )
609    .await?;
610
611    // Look up the distance metric for this vector property so we can
612    // convert raw distances into normalised similarity scores correctly.
613    // Multi-vector (ColBERT) defaults to Cosine; dense defaults to L2.
614    let default_metric = if multivec_query.is_some() {
615        DistanceMetric::Cosine
616    } else {
617        DistanceMetric::L2
618    };
619    let metric = storage
620        .schema_manager()
621        .schema()
622        .vector_index_for_property(label_name, property)
623        .map(|cfg| cfg.metric.clone())
624        .unwrap_or(default_metric);
625
626    // Filter by threshold and build result
627    let mut vids = Vec::new();
628    let mut scores = Vec::new();
629
630    for (vid, value) in results {
631        // Multi-vector (ColBERT) results are already exact MaxSim similarities
632        // (higher is better); dense distances are converted to a similarity here.
633        let similarity = if multivec_query.is_some() {
634            value
635        } else {
636            calculate_score(value, &metric)
637        };
638
639        if let Some(thresh) = threshold
640            && similarity < thresh
641        {
642            continue;
643        }
644
645        vids.push(vid);
646        scores.push(similarity);
647    }
648
649    if vids.is_empty() {
650        return Ok(Some(RecordBatch::new_empty(schema.clone())));
651    }
652
653    // Build the base record batch (VID, variable, score)
654    let batch = build_result_batch(
655        &vids,
656        &scores,
657        variable,
658        target_properties,
659        label_name,
660        graph_ctx,
661        schema,
662    )
663    .await?;
664    Ok(Some(batch))
665}
666
667/// Retrieve `(Vid, distance)` pairs for the configured [`VectorSource`].
668///
669/// - [`VectorSource::Native`] delegates to `StorageManager::vector_search`,
670///   which routes through the built-in vector backend (Lance / memory).
671/// - [`VectorSource::Plugin`] builds a 1-row probe batch carrying the
672///   query vector as `FixedSizeList<Float32>`, calls
673///   [`IndexHandle::probe`], then extracts the `(vid: Int64, distance:
674///   Float32)` columns from the result. Plugin handles emit vids as
675///   `i64`; we widen via `as u64` because graph vids are stored as
676///   non-negative `u64` and test fixtures (and any sane real index) only
677///   produce non-negative integers.
678async fn retrieve_vid_scores(
679    graph_ctx: &GraphExecutionContext,
680    label_name: &str,
681    property: &str,
682    query_vector: &[f32],
683    multivec_query: Option<&[Vec<f32>]>,
684    k: usize,
685    source: &VectorSource,
686) -> DFResult<Vec<(Vid, f32)>> {
687    match source {
688        VectorSource::Native => {
689            let storage = graph_ctx.storage();
690            let query_ctx = graph_ctx.query_context();
691            // A multi-vector property routes to MaxSim retrieval with L0
692            // visibility: Lance generates candidates over flushed data and the
693            // shared re-ranker merges live L0 rows and re-scores by exact MaxSim.
694            // The inline predicate path uses default ANN tuning (nprobes/refine
695            // are set via the `uni.vector.query` options map, which a predicate
696            // cannot express) and the default over-fetch.
697            if let Some(mv) = multivec_query {
698                let property_manager = graph_ctx.property_manager();
699                let metric = storage
700                    .schema_manager()
701                    .schema()
702                    .vector_index_for_property(label_name, property)
703                    .map(|cfg| cfg.metric.clone())
704                    .unwrap_or(DistanceMetric::Cosine);
705                let retrieval_k = k
706                    .saturating_mul(
707                        crate::query::df_graph::search_procedures::MULTIVECTOR_OVER_FETCH,
708                    )
709                    .max(k);
710                let (ranked, _props) =
711                    crate::query::df_graph::search_procedures::multivector_rerank(
712                        storage,
713                        property_manager,
714                        &query_ctx,
715                        label_name,
716                        property,
717                        mv,
718                        k,
719                        retrieval_k,
720                        None,
721                        uni_store::VectorQueryOpts::default(),
722                        &metric,
723                    )
724                    .await?;
725                return Ok(ranked);
726            }
727            storage
728                .vector_search(
729                    label_name,
730                    property,
731                    query_vector,
732                    k,
733                    None,
734                    uni_store::VectorQueryOpts::default(),
735                    Some(&query_ctx),
736                )
737                .await
738                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
739        }
740        VectorSource::Plugin { handle, .. } => {
741            // Build a single-row query batch:
742            //     [ vector: FixedSizeList<Float32, dim> ]
743            let dim = i32::try_from(query_vector.len()).map_err(|_| {
744                datafusion::error::DataFusionError::Execution(
745                    "query vector exceeds i32::MAX dimensions".to_string(),
746                )
747            })?;
748            let item_field = Arc::new(Field::new("item", DataType::Float32, true));
749            let mut fsl_builder =
750                FixedSizeListBuilder::new(Float32Builder::with_capacity(query_vector.len()), dim)
751                    .with_field(Arc::clone(&item_field));
752            for &v in query_vector {
753                fsl_builder.values().append_value(v);
754            }
755            fsl_builder.append(true);
756            let fsl: FixedSizeListArray = fsl_builder.finish();
757
758            let query_schema = Arc::new(Schema::new(vec![Field::new(
759                "vector",
760                DataType::FixedSizeList(item_field, dim),
761                false,
762            )]));
763            let query_batch =
764                RecordBatch::try_new(query_schema, vec![Arc::new(fsl)]).map_err(arrow_err)?;
765
766            let result = handle.probe(&query_batch, k).map_err(|e| {
767                datafusion::error::DataFusionError::Execution(format!(
768                    "IndexHandle::probe failed: {e:?}"
769                ))
770            })?;
771
772            // Result schema is `[vid: Int64, distance: Float32]` per the
773            // `IndexHandle` trait contract.
774            let vid_col = result
775                .column_by_name("vid")
776                .ok_or_else(|| {
777                    datafusion::error::DataFusionError::Execution(
778                        "IndexHandle::probe result missing `vid` column".to_string(),
779                    )
780                })?
781                .as_any()
782                .downcast_ref::<Int64Array>()
783                .ok_or_else(|| {
784                    datafusion::error::DataFusionError::Execution(
785                        "IndexHandle::probe result `vid` column is not Int64".to_string(),
786                    )
787                })?;
788            let dist_col = result
789                .column_by_name("distance")
790                .ok_or_else(|| {
791                    datafusion::error::DataFusionError::Execution(
792                        "IndexHandle::probe result missing `distance` column".to_string(),
793                    )
794                })?
795                .as_any()
796                .downcast_ref::<Float32Array>()
797                .ok_or_else(|| {
798                    datafusion::error::DataFusionError::Execution(
799                        "IndexHandle::probe result `distance` column is not Float32".to_string(),
800                    )
801                })?;
802
803            let mut pairs = Vec::with_capacity(result.num_rows());
804            for i in 0..result.num_rows() {
805                if vid_col.is_null(i) {
806                    continue;
807                }
808                let vid_i64 = vid_col.value(i);
809                let dist = if dist_col.is_null(i) {
810                    f32::INFINITY
811                } else {
812                    dist_col.value(i)
813                };
814                pairs.push((Vid::from(vid_i64 as u64), dist));
815            }
816            Ok(pairs)
817        }
818    }
819}
820
821/// Build a result batch from VIDs and scores, including hydrated properties.
822async fn build_result_batch(
823    vids: &[Vid],
824    scores: &[f32],
825    _variable: &str,
826    target_properties: &[String],
827    label_name: &str,
828    graph_ctx: &GraphExecutionContext,
829    schema: &SchemaRef,
830) -> DFResult<RecordBatch> {
831    let num_rows = vids.len();
832
833    // Build _vid column
834    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
835    for vid in vids {
836        vid_builder.append_value(vid.as_u64());
837    }
838
839    // Build variable column (VID as string for now)
840    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
841    for vid in vids {
842        var_builder.append_value(vid.to_string());
843    }
844
845    // Build _labels column
846    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
847    for _vid in vids {
848        labels_builder.values().append_value(label_name);
849        labels_builder.append(true);
850    }
851
852    // Build score column
853    let mut score_builder = Float32Builder::with_capacity(num_rows);
854    for &score in scores {
855        score_builder.append_value(score);
856    }
857
858    let mut columns: Vec<ArrayRef> = vec![
859        Arc::new(vid_builder.finish()),
860        Arc::new(var_builder.finish()),
861        Arc::new(labels_builder.finish()),
862        Arc::new(score_builder.finish()),
863    ];
864
865    // Hydrate property columns
866    if !target_properties.is_empty() {
867        let property_manager = graph_ctx.property_manager();
868        let query_ctx = graph_ctx.query_context();
869
870        let props_map = property_manager
871            .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
872            .await
873            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
874
875        let uni_schema = graph_ctx.storage().schema_manager().schema();
876        let label_props = uni_schema.properties.get(label_name);
877
878        for prop_name in target_properties {
879            let data_type = resolve_property_type(prop_name, label_props);
880            let column = crate::query::df_graph::scan::build_property_column_static(
881                vids, &props_map, prop_name, &data_type,
882            )?;
883            columns.push(column);
884        }
885    }
886
887    RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// With no extra properties the schema holds exactly the four base
    /// fields, in a fixed order.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        assert_eq!(schema.fields().len(), 4);

        let expected_names = ["n._vid", "n", "n._labels", "n._score"];
        for (position, expected) in expected_names.into_iter().enumerate() {
            assert_eq!(schema.field(position).name(), expected);
        }
    }

    /// A literal list of floats evaluates to a list of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let elements = [0.1, 0.2, 0.3]
            .into_iter()
            .map(|f| Expr::Literal(CypherLiteral::Float(f)))
            .collect();
        let expr = Expr::List(elements);

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &HashMap::new()).unwrap();
        if let Value::List(items) = &result {
            assert_eq!(items.len(), 3);
        } else {
            panic!("Expected list");
        }
    }

    /// A parameter expression resolves to the value supplied in the
    /// parameter map.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let query_vector = Value::List(vec![Value::Float(0.1), Value::Float(0.2)]);
        let mut params = HashMap::new();
        params.insert("query".to_string(), query_vector);

        let result = evaluate_simple_expr(&expr, &params, &HashMap::new()).unwrap();
        if let Value::List(items) = &result {
            assert_eq!(items.len(), 2);
        } else {
            panic!("Expected list");
        }
    }

    /// Requested properties show up as `<variable>.<property>` fields next
    /// to the base fields.
    #[test]
    fn test_build_schema_with_extra_properties() {
        let extra_props = vec!["name".to_string(), "embedding".to_string()];
        let schema = GraphVectorKnnExec::build_schema("doc", &extra_props, None);

        // Base fields are still present.
        for base_field in ["doc._vid", "doc", "doc._score"] {
            assert!(schema.field_with_name(base_field).is_ok());
        }

        // Each extra property gets its own field.
        let has_name = schema.field_with_name("doc.name").is_ok();
        assert!(has_name, "Extra property 'name' should be in schema");

        let has_embedding = schema.field_with_name("doc.embedding").is_ok();
        assert!(
            has_embedding,
            "Extra property 'embedding' should be in schema"
        );
    }

    /// A variable expression resolves to the value bound in the variable
    /// map.
    #[test]
    fn test_evaluate_variable() {
        let expr = Expr::Variable("x".to_string());
        let bound_value = Value::List(vec![Value::Float(0.5), Value::Float(0.6)]);
        let mut variables = HashMap::new();
        variables.insert("x".to_string(), bound_value);

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &variables).unwrap();
        if let Value::List(items) = &result {
            assert_eq!(items.len(), 2);
        } else {
            panic!("Expected list, got {:?}", result);
        }
    }
}