Skip to main content

uni_query/query/df_graph/
vector_knn.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Vector KNN search execution plan for DataFusion.
5//!
6//! This module provides [`GraphVectorKnnExec`], a DataFusion [`ExecutionPlan`] that
7//! performs vector similarity search using the underlying vector index.
8//!
9//! # Example
10//!
11//! ```text
12//! CALL uni.vector.query('Person', 'embedding', [0.1, 0.2, ...], 10)
13//! YIELD node, score
14//! ```
15
16use crate::query::df_graph::GraphExecutionContext;
17use crate::query::df_graph::common::{
18    compute_plan_properties, evaluate_simple_expr, labels_data_type,
19};
20use crate::query::df_graph::scan::resolve_property_type;
21use arrow_array::builder::{Float32Builder, StringBuilder, UInt64Builder};
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_schema::{DataType, Field, Schema, SchemaRef};
24use datafusion::common::Result as DFResult;
25use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
26use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
27use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
28use futures::Stream;
29use std::any::Any;
30use std::collections::HashMap;
31use std::fmt;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::task::{Context, Poll};
35use uni_common::Value;
36use uni_common::core::id::Vid;
37use uni_cypher::ast::Expr;
38
39/// Vector KNN search execution plan.
40///
41/// Queries the vector index for the K nearest neighbors to a query vector,
42/// returning matching vertex IDs and similarity scores.
43pub struct GraphVectorKnnExec {
44    /// Graph execution context for storage access.
45    graph_ctx: Arc<GraphExecutionContext>,
46
47    /// Label ID to search in.
48    label_id: u16,
49
50    /// Label name for display.
51    label_name: String,
52
53    /// Variable name for result vertices.
54    variable: String,
55
56    /// Property name containing vector embeddings.
57    property: String,
58
59    /// Query vector expression.
60    query_expr: Expr,
61
62    /// Number of results to return.
63    k: usize,
64
65    /// Optional similarity threshold.
66    threshold: Option<f32>,
67
68    /// Query parameters for expression evaluation.
69    params: HashMap<String, Value>,
70
71    /// Target vertex properties to materialize.
72    target_properties: Vec<String>,
73
74    /// Output schema.
75    schema: SchemaRef,
76
77    /// Plan properties.
78    properties: PlanProperties,
79
80    /// Execution metrics.
81    metrics: ExecutionPlanMetricsSet,
82}
83
84impl fmt::Debug for GraphVectorKnnExec {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        f.debug_struct("GraphVectorKnnExec")
87            .field("label_id", &self.label_id)
88            .field("variable", &self.variable)
89            .field("property", &self.property)
90            .field("k", &self.k)
91            .field("threshold", &self.threshold)
92            .finish()
93    }
94}
95
96impl GraphVectorKnnExec {
97    /// Create a new vector KNN search execution plan.
98    ///
99    /// # Arguments
100    ///
101    /// * `graph_ctx` - Graph execution context
102    /// * `label_id` - Label ID to search
103    /// * `label_name` - Label name for display
104    /// * `variable` - Variable name for results
105    /// * `property` - Property containing vectors
106    /// * `query_expr` - Expression evaluating to query vector
107    /// * `k` - Number of results
108    /// * `threshold` - Optional similarity threshold
109    /// * `params` - Query parameters
110    #[expect(clippy::too_many_arguments)]
111    pub fn new(
112        graph_ctx: Arc<GraphExecutionContext>,
113        label_id: u16,
114        label_name: impl Into<String>,
115        variable: impl Into<String>,
116        property: impl Into<String>,
117        query_expr: Expr,
118        k: usize,
119        threshold: Option<f32>,
120        params: HashMap<String, Value>,
121        target_properties: Vec<String>,
122    ) -> Self {
123        let variable = variable.into();
124        let property = property.into();
125        let label_name = label_name.into();
126
127        // Resolve property types from schema
128        let uni_schema = graph_ctx.storage().schema_manager().schema();
129        let label_props = uni_schema.properties.get(label_name.as_str());
130
131        let schema = Self::build_schema(&variable, &target_properties, label_props);
132        let properties = compute_plan_properties(schema.clone());
133
134        Self {
135            graph_ctx,
136            label_id,
137            label_name,
138            variable,
139            property,
140            query_expr,
141            k,
142            threshold,
143            params,
144            target_properties,
145            schema,
146            properties,
147            metrics: ExecutionPlanMetricsSet::new(),
148        }
149    }
150
151    /// Build the output schema.
152    ///
153    /// Schema contains:
154    /// - `{variable}._vid` - Vertex ID
155    /// - `{variable}` - Variable identifier (as string for now)
156    /// - `{variable}._score` - Similarity score
157    /// - `{variable}.{prop}` - Property columns
158    fn build_schema(
159        variable: &str,
160        target_properties: &[String],
161        label_props: Option<
162            &std::collections::HashMap<String, uni_common::core::schema::PropertyMeta>,
163        >,
164    ) -> SchemaRef {
165        let mut fields = vec![
166            Field::new(format!("{}._vid", variable), DataType::UInt64, false),
167            Field::new(variable, DataType::Utf8, false),
168            Field::new(format!("{}._labels", variable), labels_data_type(), true),
169            Field::new(format!("{}._score", variable), DataType::Float32, true),
170        ];
171
172        // Add property columns
173        for prop_name in target_properties {
174            let col_name = format!("{}.{}", variable, prop_name);
175            let arrow_type = resolve_property_type(prop_name, label_props);
176            fields.push(Field::new(&col_name, arrow_type, true));
177        }
178
179        Arc::new(Schema::new(fields))
180    }
181
182    /// Evaluate the query expression to extract the query vector.
183    fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
184        let value = evaluate_simple_expr(&self.query_expr, &self.params)?;
185
186        match value {
187            Value::Vector(vec) => Ok(vec),
188            Value::List(arr) => {
189                let mut vec = Vec::with_capacity(arr.len());
190                for v in arr {
191                    if let Some(f) = v.as_f64() {
192                        vec.push(f as f32);
193                    } else {
194                        return Err(datafusion::error::DataFusionError::Execution(
195                            "Query vector must contain numbers".to_string(),
196                        ));
197                    }
198                }
199                Ok(vec)
200            }
201            _ => Err(datafusion::error::DataFusionError::Execution(
202                "Query vector must be a list or vector".to_string(),
203            )),
204        }
205    }
206}
207
208impl DisplayAs for GraphVectorKnnExec {
209    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        write!(
211            f,
212            "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
213            self.label_name, self.property, self.k, self.variable
214        )
215    }
216}
217
218impl ExecutionPlan for GraphVectorKnnExec {
219    fn name(&self) -> &str {
220        "GraphVectorKnnExec"
221    }
222
223    fn as_any(&self) -> &dyn Any {
224        self
225    }
226
227    fn schema(&self) -> SchemaRef {
228        self.schema.clone()
229    }
230
231    fn properties(&self) -> &PlanProperties {
232        &self.properties
233    }
234
235    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
236        vec![]
237    }
238
239    fn with_new_children(
240        self: Arc<Self>,
241        children: Vec<Arc<dyn ExecutionPlan>>,
242    ) -> DFResult<Arc<dyn ExecutionPlan>> {
243        if !children.is_empty() {
244            return Err(datafusion::error::DataFusionError::Internal(
245                "GraphVectorKnnExec has no children".to_string(),
246            ));
247        }
248        Ok(self)
249    }
250
251    fn execute(
252        &self,
253        partition: usize,
254        _context: Arc<TaskContext>,
255    ) -> DFResult<SendableRecordBatchStream> {
256        let metrics = BaselineMetrics::new(&self.metrics, partition);
257
258        // Evaluate query vector upfront
259        let query_vector = self.evaluate_query_vector()?;
260
261        Ok(Box::pin(VectorKnnStream::new(
262            self.graph_ctx.clone(),
263            self.label_name.clone(),
264            self.variable.clone(),
265            self.property.clone(),
266            query_vector,
267            self.k,
268            self.threshold,
269            self.target_properties.clone(),
270            self.schema.clone(),
271            metrics,
272        )))
273    }
274
275    fn metrics(&self) -> Option<MetricsSet> {
276        Some(self.metrics.clone_inner())
277    }
278}
279
280/// State machine for vector KNN stream.
281enum VectorKnnState {
282    /// Initial state, ready to start search.
283    Init,
284    /// Executing the async search.
285    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
286    /// Stream is done.
287    Done,
288}
289
290/// Stream that executes vector KNN search.
291struct VectorKnnStream {
292    /// Graph execution context.
293    graph_ctx: Arc<GraphExecutionContext>,
294
295    /// Label name to search.
296    label_name: String,
297
298    /// Variable name for results.
299    variable: String,
300
301    /// Property name containing vectors.
302    property: String,
303
304    /// Query vector.
305    query_vector: Vec<f32>,
306
307    /// Number of results.
308    k: usize,
309
310    /// Similarity threshold.
311    threshold: Option<f32>,
312
313    /// Target vertex properties to materialize.
314    target_properties: Vec<String>,
315
316    /// Output schema.
317    schema: SchemaRef,
318
319    /// Stream state.
320    state: VectorKnnState,
321
322    /// Metrics.
323    metrics: BaselineMetrics,
324}
325
326impl VectorKnnStream {
327    #[expect(clippy::too_many_arguments)]
328    fn new(
329        graph_ctx: Arc<GraphExecutionContext>,
330        label_name: String,
331        variable: String,
332        property: String,
333        query_vector: Vec<f32>,
334        k: usize,
335        threshold: Option<f32>,
336        target_properties: Vec<String>,
337        schema: SchemaRef,
338        metrics: BaselineMetrics,
339    ) -> Self {
340        Self {
341            graph_ctx,
342            label_name,
343            variable,
344            property,
345            query_vector,
346            k,
347            threshold,
348            target_properties,
349            schema,
350            state: VectorKnnState::Init,
351            metrics,
352        }
353    }
354}
355
356impl Stream for VectorKnnStream {
357    type Item = DFResult<RecordBatch>;
358
359    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
360        loop {
361            let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
362
363            match state {
364                VectorKnnState::Init => {
365                    // Clone data for async block
366                    let graph_ctx = self.graph_ctx.clone();
367                    let label_name = self.label_name.clone();
368                    let variable = self.variable.clone();
369                    let property = self.property.clone();
370                    let query_vector = self.query_vector.clone();
371                    let k = self.k;
372                    let threshold = self.threshold;
373                    let target_properties = self.target_properties.clone();
374                    let schema = self.schema.clone();
375
376                    let fut = async move {
377                        // Check timeout
378                        graph_ctx.check_timeout().map_err(|e| {
379                            datafusion::error::DataFusionError::Execution(e.to_string())
380                        })?;
381
382                        execute_vector_search(
383                            &graph_ctx,
384                            &label_name,
385                            &variable,
386                            &property,
387                            &query_vector,
388                            k,
389                            threshold,
390                            &target_properties,
391                            &schema,
392                        )
393                        .await
394                    };
395
396                    self.state = VectorKnnState::Executing(Box::pin(fut));
397                    // Continue loop to poll the future
398                }
399                VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
400                    Poll::Ready(Ok(batch)) => {
401                        self.state = VectorKnnState::Done;
402                        self.metrics
403                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
404                        return Poll::Ready(batch.map(Ok));
405                    }
406                    Poll::Ready(Err(e)) => {
407                        self.state = VectorKnnState::Done;
408                        return Poll::Ready(Some(Err(e)));
409                    }
410                    Poll::Pending => {
411                        self.state = VectorKnnState::Executing(fut);
412                        return Poll::Pending;
413                    }
414                },
415                VectorKnnState::Done => {
416                    return Poll::Ready(None);
417                }
418            }
419        }
420    }
421}
422
423impl RecordBatchStream for VectorKnnStream {
424    fn schema(&self) -> SchemaRef {
425        self.schema.clone()
426    }
427}
428
429/// Execute the vector search and build results.
430#[expect(clippy::too_many_arguments)]
431async fn execute_vector_search(
432    graph_ctx: &GraphExecutionContext,
433    label_name: &str,
434    variable: &str,
435    property: &str,
436    query_vector: &[f32],
437    k: usize,
438    threshold: Option<f32>,
439    target_properties: &[String],
440    schema: &SchemaRef,
441) -> DFResult<Option<RecordBatch>> {
442    let storage = graph_ctx.storage();
443    let query_ctx = graph_ctx.query_context();
444
445    // Execute vector search
446    let results = storage
447        .vector_search(
448            label_name,
449            property,
450            query_vector,
451            k,
452            None,
453            Some(&query_ctx),
454        )
455        .await
456        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
457
458    // Filter by threshold and build result
459    let mut vids = Vec::new();
460    let mut scores = Vec::new();
461
462    for (vid, distance) in results {
463        // Convert distance to similarity (assuming Cosine/Dot product)
464        let similarity = 1.0 - distance;
465
466        if let Some(thresh) = threshold
467            && similarity < thresh
468        {
469            continue;
470        }
471
472        vids.push(vid);
473        scores.push(similarity);
474    }
475
476    if vids.is_empty() {
477        return Ok(Some(RecordBatch::new_empty(schema.clone())));
478    }
479
480    // Build the base record batch (VID, variable, score)
481    let batch = build_result_batch(
482        &vids,
483        &scores,
484        variable,
485        target_properties,
486        label_name,
487        graph_ctx,
488        schema,
489    )
490    .await?;
491    Ok(Some(batch))
492}
493
494/// Build a result batch from VIDs and scores, including hydrated properties.
495async fn build_result_batch(
496    vids: &[Vid],
497    scores: &[f32],
498    _variable: &str,
499    target_properties: &[String],
500    label_name: &str,
501    graph_ctx: &GraphExecutionContext,
502    schema: &SchemaRef,
503) -> DFResult<RecordBatch> {
504    let num_rows = vids.len();
505
506    // Build _vid column
507    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
508    for vid in vids {
509        vid_builder.append_value(vid.as_u64());
510    }
511
512    // Build variable column (VID as string for now)
513    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
514    for vid in vids {
515        var_builder.append_value(vid.to_string());
516    }
517
518    // Build _labels column
519    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
520    for _vid in vids {
521        labels_builder.values().append_value(label_name);
522        labels_builder.append(true);
523    }
524
525    // Build score column
526    let mut score_builder = Float32Builder::with_capacity(num_rows);
527    for &score in scores {
528        score_builder.append_value(score);
529    }
530
531    let mut columns: Vec<ArrayRef> = vec![
532        Arc::new(vid_builder.finish()),
533        Arc::new(var_builder.finish()),
534        Arc::new(labels_builder.finish()),
535        Arc::new(score_builder.finish()),
536    ];
537
538    // Hydrate property columns
539    if !target_properties.is_empty() {
540        let property_manager = graph_ctx.property_manager();
541        let query_ctx = graph_ctx.query_context();
542
543        let props_map = property_manager
544            .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
545            .await
546            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
547
548        let uni_schema = graph_ctx.storage().schema_manager().schema();
549        let label_props = uni_schema.properties.get(label_name);
550
551        for prop_name in target_properties {
552            let data_type = resolve_property_type(prop_name, label_props);
553            let column = crate::query::df_graph::scan::build_property_column_static(
554                vids, &props_map, prop_name, &data_type,
555            )?;
556            columns.push(column);
557        }
558    }
559
560    RecordBatch::try_new(schema.clone(), columns)
561        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// With no target properties the schema is exactly the four fixed columns, in order.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(names, ["n._vid", "n", "n._labels", "n._score"]);
    }

    /// A literal list of floats evaluates to a list value of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(
            [0.1, 0.2, 0.3]
                .into_iter()
                .map(|x| Expr::Literal(CypherLiteral::Float(x)))
                .collect(),
        );

        let Value::List(items) = evaluate_simple_expr(&expr, &HashMap::new()).unwrap() else {
            panic!("Expected list");
        };
        assert_eq!(items.len(), 3);
    }

    /// A parameter reference resolves to the bound list value.
    #[test]
    fn test_evaluate_parameter() {
        let params: HashMap<String, Value> = HashMap::from([(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        )]);
        let expr = Expr::Parameter("query".to_string());

        let Value::List(items) = evaluate_simple_expr(&expr, &params).unwrap() else {
            panic!("Expected list");
        };
        assert_eq!(items.len(), 2);
    }
}