Skip to main content

uni_query/query/df_graph/
vector_knn.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Vector KNN search execution plan for DataFusion.
5//!
6//! This module provides [`GraphVectorKnnExec`], a DataFusion [`ExecutionPlan`] that
7//! performs vector similarity search using the underlying vector index.
8//!
9//! # Example
10//!
11//! ```text
12//! CALL uni.vector.query('Person', 'embedding', [0.1, 0.2, ...], 10)
13//! YIELD node, score
14//! ```
15
16use arrow_array::builder::{Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{ArrayRef, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34
35use crate::query::df_graph::GraphExecutionContext;
36use crate::query::df_graph::common::{
37    arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
38};
39use crate::query::df_graph::scan::resolve_property_type;
40
41/// Vector KNN search execution plan.
42///
43/// Queries the vector index for the K nearest neighbors to a query vector,
44/// returning matching vertex IDs and similarity scores.
45pub struct GraphVectorKnnExec {
46    /// Graph execution context for storage access.
47    graph_ctx: Arc<GraphExecutionContext>,
48
49    /// Label ID to search in.
50    label_id: u16,
51
52    /// Label name for display.
53    label_name: String,
54
55    /// Variable name for result vertices.
56    variable: String,
57
58    /// Property name containing vector embeddings.
59    property: String,
60
61    /// Query vector expression.
62    query_expr: Expr,
63
64    /// Number of results to return.
65    k: usize,
66
67    /// Optional similarity threshold.
68    threshold: Option<f32>,
69
70    /// Query parameters for expression evaluation.
71    params: HashMap<String, Value>,
72
73    /// Target vertex properties to materialize.
74    target_properties: Vec<String>,
75
76    /// Output schema.
77    schema: SchemaRef,
78
79    /// Plan properties.
80    properties: PlanProperties,
81
82    /// Execution metrics.
83    metrics: ExecutionPlanMetricsSet,
84}
85
86impl fmt::Debug for GraphVectorKnnExec {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("GraphVectorKnnExec")
89            .field("label_id", &self.label_id)
90            .field("variable", &self.variable)
91            .field("property", &self.property)
92            .field("k", &self.k)
93            .field("threshold", &self.threshold)
94            .finish()
95    }
96}
97
98impl GraphVectorKnnExec {
99    /// Create a new vector KNN search execution plan.
100    ///
101    /// # Arguments
102    ///
103    /// * `graph_ctx` - Graph execution context
104    /// * `label_id` - Label ID to search
105    /// * `label_name` - Label name for display
106    /// * `variable` - Variable name for results
107    /// * `property` - Property containing vectors
108    /// * `query_expr` - Expression evaluating to query vector
109    /// * `k` - Number of results
110    /// * `threshold` - Optional similarity threshold
111    /// * `params` - Query parameters
112    #[expect(clippy::too_many_arguments)]
113    pub fn new(
114        graph_ctx: Arc<GraphExecutionContext>,
115        label_id: u16,
116        label_name: impl Into<String>,
117        variable: impl Into<String>,
118        property: impl Into<String>,
119        query_expr: Expr,
120        k: usize,
121        threshold: Option<f32>,
122        params: HashMap<String, Value>,
123        target_properties: Vec<String>,
124    ) -> Self {
125        let variable = variable.into();
126        let property = property.into();
127        let label_name = label_name.into();
128
129        // Resolve property types from schema
130        let uni_schema = graph_ctx.storage().schema_manager().schema();
131        let label_props = uni_schema.properties.get(label_name.as_str());
132
133        let schema = Self::build_schema(&variable, &target_properties, label_props);
134        let properties = compute_plan_properties(schema.clone());
135
136        Self {
137            graph_ctx,
138            label_id,
139            label_name,
140            variable,
141            property,
142            query_expr,
143            k,
144            threshold,
145            params,
146            target_properties,
147            schema,
148            properties,
149            metrics: ExecutionPlanMetricsSet::new(),
150        }
151    }
152
153    /// Build the output schema.
154    ///
155    /// Schema contains:
156    /// - `{variable}._vid` - Vertex ID
157    /// - `{variable}` - Variable identifier (as string for now)
158    /// - `{variable}._score` - Similarity score
159    /// - `{variable}.{prop}` - Property columns
160    fn build_schema(
161        variable: &str,
162        target_properties: &[String],
163        label_props: Option<&HashMap<String, PropertyMeta>>,
164    ) -> SchemaRef {
165        let mut fields = vec![
166            Field::new(format!("{}._vid", variable), DataType::UInt64, false),
167            Field::new(variable, DataType::Utf8, false),
168            Field::new(format!("{}._labels", variable), labels_data_type(), true),
169            Field::new(format!("{}._score", variable), DataType::Float32, true),
170        ];
171
172        // Add property columns
173        for prop_name in target_properties {
174            let col_name = format!("{}.{}", variable, prop_name);
175            let arrow_type = resolve_property_type(prop_name, label_props);
176            fields.push(Field::new(&col_name, arrow_type, true));
177        }
178
179        Arc::new(Schema::new(fields))
180    }
181
182    /// Evaluate the query expression to extract the query vector.
183    fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
184        let value = evaluate_simple_expr(&self.query_expr, &self.params)?;
185
186        match value {
187            Value::Vector(vec) => Ok(vec),
188            Value::List(arr) => {
189                let mut vec = Vec::with_capacity(arr.len());
190                for v in arr {
191                    if let Some(f) = v.as_f64() {
192                        vec.push(f as f32);
193                    } else {
194                        return Err(datafusion::error::DataFusionError::Execution(
195                            "Query vector must contain numbers".to_string(),
196                        ));
197                    }
198                }
199                Ok(vec)
200            }
201            _ => Err(datafusion::error::DataFusionError::Execution(
202                "Query vector must be a list or vector".to_string(),
203            )),
204        }
205    }
206}
207
208impl DisplayAs for GraphVectorKnnExec {
209    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        write!(
211            f,
212            "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
213            self.label_name, self.property, self.k, self.variable
214        )
215    }
216}
217
218impl ExecutionPlan for GraphVectorKnnExec {
219    fn name(&self) -> &str {
220        "GraphVectorKnnExec"
221    }
222
223    fn as_any(&self) -> &dyn Any {
224        self
225    }
226
227    fn schema(&self) -> SchemaRef {
228        self.schema.clone()
229    }
230
231    fn properties(&self) -> &PlanProperties {
232        &self.properties
233    }
234
235    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
236        vec![]
237    }
238
239    fn with_new_children(
240        self: Arc<Self>,
241        children: Vec<Arc<dyn ExecutionPlan>>,
242    ) -> DFResult<Arc<dyn ExecutionPlan>> {
243        if !children.is_empty() {
244            return Err(datafusion::error::DataFusionError::Internal(
245                "GraphVectorKnnExec has no children".to_string(),
246            ));
247        }
248        Ok(self)
249    }
250
251    fn execute(
252        &self,
253        partition: usize,
254        _context: Arc<TaskContext>,
255    ) -> DFResult<SendableRecordBatchStream> {
256        let metrics = BaselineMetrics::new(&self.metrics, partition);
257
258        // Evaluate query vector upfront
259        let query_vector = self.evaluate_query_vector()?;
260
261        Ok(Box::pin(VectorKnnStream::new(
262            self.graph_ctx.clone(),
263            self.label_name.clone(),
264            self.variable.clone(),
265            self.property.clone(),
266            query_vector,
267            self.k,
268            self.threshold,
269            self.target_properties.clone(),
270            self.schema.clone(),
271            metrics,
272        )))
273    }
274
275    fn metrics(&self) -> Option<MetricsSet> {
276        Some(self.metrics.clone_inner())
277    }
278}
279
280/// State machine for vector KNN stream.
281enum VectorKnnState {
282    /// Initial state, ready to start search.
283    Init,
284    /// Executing the async search.
285    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
286    /// Stream is done.
287    Done,
288}
289
290/// Stream that executes vector KNN search.
291struct VectorKnnStream {
292    /// Graph execution context.
293    graph_ctx: Arc<GraphExecutionContext>,
294
295    /// Label name to search.
296    label_name: String,
297
298    /// Variable name for results.
299    variable: String,
300
301    /// Property name containing vectors.
302    property: String,
303
304    /// Query vector.
305    query_vector: Vec<f32>,
306
307    /// Number of results.
308    k: usize,
309
310    /// Similarity threshold.
311    threshold: Option<f32>,
312
313    /// Target vertex properties to materialize.
314    target_properties: Vec<String>,
315
316    /// Output schema.
317    schema: SchemaRef,
318
319    /// Stream state.
320    state: VectorKnnState,
321
322    /// Metrics.
323    metrics: BaselineMetrics,
324}
325
326impl VectorKnnStream {
327    #[expect(clippy::too_many_arguments)]
328    fn new(
329        graph_ctx: Arc<GraphExecutionContext>,
330        label_name: String,
331        variable: String,
332        property: String,
333        query_vector: Vec<f32>,
334        k: usize,
335        threshold: Option<f32>,
336        target_properties: Vec<String>,
337        schema: SchemaRef,
338        metrics: BaselineMetrics,
339    ) -> Self {
340        Self {
341            graph_ctx,
342            label_name,
343            variable,
344            property,
345            query_vector,
346            k,
347            threshold,
348            target_properties,
349            schema,
350            state: VectorKnnState::Init,
351            metrics,
352        }
353    }
354}
355
356impl Stream for VectorKnnStream {
357    type Item = DFResult<RecordBatch>;
358
359    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
360        loop {
361            let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
362
363            match state {
364                VectorKnnState::Init => {
365                    // Clone data for async block
366                    let graph_ctx = self.graph_ctx.clone();
367                    let label_name = self.label_name.clone();
368                    let variable = self.variable.clone();
369                    let property = self.property.clone();
370                    let query_vector = self.query_vector.clone();
371                    let k = self.k;
372                    let threshold = self.threshold;
373                    let target_properties = self.target_properties.clone();
374                    let schema = self.schema.clone();
375
376                    let fut = async move {
377                        // Check timeout
378                        graph_ctx.check_timeout().map_err(|e| {
379                            datafusion::error::DataFusionError::Execution(e.to_string())
380                        })?;
381
382                        execute_vector_search(
383                            &graph_ctx,
384                            &label_name,
385                            &variable,
386                            &property,
387                            &query_vector,
388                            k,
389                            threshold,
390                            &target_properties,
391                            &schema,
392                        )
393                        .await
394                    };
395
396                    self.state = VectorKnnState::Executing(Box::pin(fut));
397                    // Continue loop to poll the future
398                }
399                VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
400                    Poll::Ready(Ok(batch)) => {
401                        self.state = VectorKnnState::Done;
402                        self.metrics
403                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
404                        return Poll::Ready(batch.map(Ok));
405                    }
406                    Poll::Ready(Err(e)) => {
407                        self.state = VectorKnnState::Done;
408                        return Poll::Ready(Some(Err(e)));
409                    }
410                    Poll::Pending => {
411                        self.state = VectorKnnState::Executing(fut);
412                        return Poll::Pending;
413                    }
414                },
415                VectorKnnState::Done => {
416                    return Poll::Ready(None);
417                }
418            }
419        }
420    }
421}
422
423impl RecordBatchStream for VectorKnnStream {
424    fn schema(&self) -> SchemaRef {
425        self.schema.clone()
426    }
427}
428
429/// Execute the vector search and build results.
430#[expect(clippy::too_many_arguments)]
431async fn execute_vector_search(
432    graph_ctx: &GraphExecutionContext,
433    label_name: &str,
434    variable: &str,
435    property: &str,
436    query_vector: &[f32],
437    k: usize,
438    threshold: Option<f32>,
439    target_properties: &[String],
440    schema: &SchemaRef,
441) -> DFResult<Option<RecordBatch>> {
442    let storage = graph_ctx.storage();
443    let query_ctx = graph_ctx.query_context();
444
445    // Execute vector search
446    let results = storage
447        .vector_search(
448            label_name,
449            property,
450            query_vector,
451            k,
452            None,
453            Some(&query_ctx),
454        )
455        .await
456        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
457
458    // Look up the distance metric for this vector property so we can
459    // convert raw distances into normalised similarity scores correctly.
460    let metric = storage
461        .schema_manager()
462        .schema()
463        .vector_index_for_property(label_name, property)
464        .map(|cfg| cfg.metric.clone())
465        .unwrap_or(DistanceMetric::L2);
466
467    // Filter by threshold and build result
468    let mut vids = Vec::new();
469    let mut scores = Vec::new();
470
471    for (vid, distance) in results {
472        let similarity = calculate_score(distance, &metric);
473
474        if let Some(thresh) = threshold
475            && similarity < thresh
476        {
477            continue;
478        }
479
480        vids.push(vid);
481        scores.push(similarity);
482    }
483
484    if vids.is_empty() {
485        return Ok(Some(RecordBatch::new_empty(schema.clone())));
486    }
487
488    // Build the base record batch (VID, variable, score)
489    let batch = build_result_batch(
490        &vids,
491        &scores,
492        variable,
493        target_properties,
494        label_name,
495        graph_ctx,
496        schema,
497    )
498    .await?;
499    Ok(Some(batch))
500}
501
502/// Build a result batch from VIDs and scores, including hydrated properties.
503async fn build_result_batch(
504    vids: &[Vid],
505    scores: &[f32],
506    _variable: &str,
507    target_properties: &[String],
508    label_name: &str,
509    graph_ctx: &GraphExecutionContext,
510    schema: &SchemaRef,
511) -> DFResult<RecordBatch> {
512    let num_rows = vids.len();
513
514    // Build _vid column
515    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
516    for vid in vids {
517        vid_builder.append_value(vid.as_u64());
518    }
519
520    // Build variable column (VID as string for now)
521    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
522    for vid in vids {
523        var_builder.append_value(vid.to_string());
524    }
525
526    // Build _labels column
527    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
528    for _vid in vids {
529        labels_builder.values().append_value(label_name);
530        labels_builder.append(true);
531    }
532
533    // Build score column
534    let mut score_builder = Float32Builder::with_capacity(num_rows);
535    for &score in scores {
536        score_builder.append_value(score);
537    }
538
539    let mut columns: Vec<ArrayRef> = vec![
540        Arc::new(vid_builder.finish()),
541        Arc::new(var_builder.finish()),
542        Arc::new(labels_builder.finish()),
543        Arc::new(score_builder.finish()),
544    ];
545
546    // Hydrate property columns
547    if !target_properties.is_empty() {
548        let property_manager = graph_ctx.property_manager();
549        let query_ctx = graph_ctx.query_context();
550
551        let props_map = property_manager
552            .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
553            .await
554            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
555
556        let uni_schema = graph_ctx.storage().schema_manager().schema();
557        let label_props = uni_schema.properties.get(label_name);
558
559        for prop_name in target_properties {
560            let data_type = resolve_property_type(prop_name, label_props);
561            let column = crate::query::df_graph::scan::build_property_column_static(
562                vids, &props_map, prop_name, &data_type,
563            )?;
564            columns.push(column);
565        }
566    }
567
568    RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// With no target properties the schema holds exactly the four fixed columns.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(names, ["n._vid", "n", "n._labels", "n._score"]);
    }

    /// A literal list of floats evaluates to a list value of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let float_lit = |x: f64| Expr::Literal(CypherLiteral::Float(x));
        let expr = Expr::List(vec![float_lit(0.1), float_lit(0.2), float_lit(0.3)]);

        let result = evaluate_simple_expr(&expr, &HashMap::new()).unwrap();
        let Value::List(arr) = result else {
            panic!("Expected list");
        };
        assert_eq!(arr.len(), 3);
    }

    /// A parameter reference resolves to the list bound under that name.
    #[test]
    fn test_evaluate_parameter() {
        let params = HashMap::from([(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        )]);
        let expr = Expr::Parameter("query".to_string());

        let result = evaluate_simple_expr(&expr, &params).unwrap();
        let Value::List(arr) = result else {
            panic!("Expected list");
        };
        assert_eq!(arr.len(), 2);
    }
}