1use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34use uni_plugin::traits::index::{IndexHandle, IndexKind};
35
36use crate::query::df_graph::GraphExecutionContext;
37use crate::query::df_graph::common::{
38 arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
39};
40use crate::query::df_graph::scan::{property_field, resolve_property_type};
41
/// Selects which backend produces the nearest-neighbour candidates.
///
/// `Clone` is derived because the plan node hands a copy of its source to
/// every stream it creates.
#[derive(Clone)]
pub(crate) enum VectorSource {
    /// Search through the storage layer's own `vector_search` call.
    Native,
    /// Search by probing an index handle supplied by a plugin.
    Plugin {
        /// Kind of the plugin index.
        // In this file the field is only read by the `Debug` impl, hence the
        // lint allowance below.
        #[allow(dead_code)]
        kind: IndexKind,
        /// Handle whose `probe` method is invoked with the query vector.
        handle: Arc<dyn IndexHandle>,
    },
}
73
74impl fmt::Debug for VectorSource {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::Native => f.write_str("Native"),
78 Self::Plugin { kind, .. } => f.debug_struct("Plugin").field("kind", kind).finish(),
79 }
80 }
81}
82
/// Leaf execution plan node that runs a k-nearest-neighbour vector search
/// for one label/property pair and emits the hits as a single record batch.
pub struct GraphVectorKnnExec {
    /// Shared execution context (storage, property manager, query context,
    /// timeout check).
    graph_ctx: Arc<GraphExecutionContext>,

    /// Numeric id of the searched label. In this file it is only stored and
    /// shown in `Debug` output.
    label_id: u16,

    /// Name of the searched label; used for schema lookups, the search call
    /// and the `_labels` output column.
    label_name: String,

    /// Query variable name; used as the prefix of every output column.
    variable: String,

    /// Name of the vector property being searched.
    property: String,

    /// Expression that is evaluated to the query vector when the plan is
    /// executed.
    query_expr: Expr,

    /// Number of neighbours requested from the search backend.
    k: usize,

    /// Optional minimum score; hits scoring below it are dropped.
    threshold: Option<f32>,

    /// Query parameters made available when evaluating `query_expr`.
    params: HashMap<String, Value>,

    /// Extra vertex properties to materialise as additional output columns.
    target_properties: Vec<String>,

    /// Output schema, computed once in the constructor.
    schema: SchemaRef,

    /// Plan properties derived from `schema`.
    properties: Arc<PlanProperties>,

    /// Backend that supplies the candidates (native storage or a plugin).
    source: VectorSource,

    /// Metrics set from which per-partition baseline metrics are created.
    metrics: ExecutionPlanMetricsSet,
}
132
133impl fmt::Debug for GraphVectorKnnExec {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 f.debug_struct("GraphVectorKnnExec")
136 .field("label_id", &self.label_id)
137 .field("variable", &self.variable)
138 .field("property", &self.property)
139 .field("k", &self.k)
140 .field("threshold", &self.threshold)
141 .finish()
142 }
143}
144
145impl GraphVectorKnnExec {
146 #[expect(clippy::too_many_arguments)]
160 pub fn new(
161 graph_ctx: Arc<GraphExecutionContext>,
162 label_id: u16,
163 label_name: impl Into<String>,
164 variable: impl Into<String>,
165 property: impl Into<String>,
166 query_expr: Expr,
167 k: usize,
168 threshold: Option<f32>,
169 params: HashMap<String, Value>,
170 target_properties: Vec<String>,
171 ) -> Self {
172 let variable = variable.into();
173 let property = property.into();
174 let label_name = label_name.into();
175
176 let uni_schema = graph_ctx.storage().schema_manager().schema();
178 let label_props = uni_schema.properties.get(label_name.as_str());
179
180 let schema = Self::build_schema(&variable, &target_properties, label_props);
181 let properties = compute_plan_properties(schema.clone());
182
183 Self {
184 graph_ctx,
185 label_id,
186 label_name,
187 variable,
188 property,
189 query_expr,
190 k,
191 threshold,
192 params,
193 target_properties,
194 schema,
195 properties,
196 source: VectorSource::Native,
197 metrics: ExecutionPlanMetricsSet::new(),
198 }
199 }
200
201 #[expect(clippy::too_many_arguments)]
208 pub fn with_plugin_source(
209 graph_ctx: Arc<GraphExecutionContext>,
210 label_id: u16,
211 label_name: impl Into<String>,
212 variable: impl Into<String>,
213 property: impl Into<String>,
214 query_expr: Expr,
215 k: usize,
216 threshold: Option<f32>,
217 params: HashMap<String, Value>,
218 target_properties: Vec<String>,
219 kind: IndexKind,
220 handle: Arc<dyn IndexHandle>,
221 ) -> Self {
222 let mut exec = Self::new(
223 graph_ctx,
224 label_id,
225 label_name,
226 variable,
227 property,
228 query_expr,
229 k,
230 threshold,
231 params,
232 target_properties,
233 );
234 exec.source = VectorSource::Plugin { kind, handle };
235 exec
236 }
237
238 fn build_schema(
246 variable: &str,
247 target_properties: &[String],
248 label_props: Option<&HashMap<String, PropertyMeta>>,
249 ) -> SchemaRef {
250 let mut fields = vec![
251 Field::new(format!("{}._vid", variable), DataType::UInt64, false),
252 Field::new(variable, DataType::Utf8, false),
253 Field::new(format!("{}._labels", variable), labels_data_type(), true),
254 Field::new(format!("{}._score", variable), DataType::Float32, true),
255 ];
256
257 for prop_name in target_properties {
259 let col_name = format!("{}.{}", variable, prop_name);
260 let arrow_type = resolve_property_type(prop_name, label_props);
261 let uni_type = label_props
262 .and_then(|p| p.get(prop_name))
263 .map(|m| &m.r#type);
264 fields.push(property_field(&col_name, arrow_type, uni_type));
265 }
266
267 Arc::new(Schema::new(fields))
268 }
269
270 fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
272 let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
273
274 match value {
275 Value::Vector(vec) => Ok(vec),
276 Value::List(arr) => {
277 let mut vec = Vec::with_capacity(arr.len());
278 for v in arr {
279 if let Some(f) = v.as_f64() {
280 vec.push(f as f32);
281 } else {
282 return Err(datafusion::error::DataFusionError::Execution(
283 "Query vector must contain numbers".to_string(),
284 ));
285 }
286 }
287 Ok(vec)
288 }
289 _ => Err(datafusion::error::DataFusionError::Execution(
290 "Query vector must be a list or vector".to_string(),
291 )),
292 }
293 }
294}
295
296impl DisplayAs for GraphVectorKnnExec {
297 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 write!(
299 f,
300 "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
301 self.label_name, self.property, self.k, self.variable
302 )
303 }
304}
305
306impl ExecutionPlan for GraphVectorKnnExec {
307 fn name(&self) -> &str {
308 "GraphVectorKnnExec"
309 }
310
311 fn as_any(&self) -> &dyn Any {
312 self
313 }
314
315 fn schema(&self) -> SchemaRef {
316 self.schema.clone()
317 }
318
319 fn properties(&self) -> &Arc<PlanProperties> {
320 &self.properties
321 }
322
323 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
324 vec![]
325 }
326
327 fn with_new_children(
328 self: Arc<Self>,
329 children: Vec<Arc<dyn ExecutionPlan>>,
330 ) -> DFResult<Arc<dyn ExecutionPlan>> {
331 if !children.is_empty() {
332 return Err(datafusion::error::DataFusionError::Internal(
333 "GraphVectorKnnExec has no children".to_string(),
334 ));
335 }
336 Ok(self)
337 }
338
339 fn execute(
340 &self,
341 partition: usize,
342 _context: Arc<TaskContext>,
343 ) -> DFResult<SendableRecordBatchStream> {
344 let metrics = BaselineMetrics::new(&self.metrics, partition);
345
346 let query_vector = self.evaluate_query_vector()?;
348
349 Ok(Box::pin(VectorKnnStream::new(
350 self.graph_ctx.clone(),
351 self.label_name.clone(),
352 self.variable.clone(),
353 self.property.clone(),
354 query_vector,
355 self.k,
356 self.threshold,
357 self.target_properties.clone(),
358 self.schema.clone(),
359 self.source.clone(),
360 metrics,
361 )))
362 }
363
364 fn metrics(&self) -> Option<MetricsSet> {
365 Some(self.metrics.clone_inner())
366 }
367}
368
/// Progress of a [`VectorKnnStream`].
enum VectorKnnState {
    /// The search future has not been created yet.
    Init,
    /// The search future is in flight; it resolves to at most one batch.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// The stream has finished; polling again yields `None`.
    Done,
}
378
/// Record batch stream that runs one vector search and yields its result.
struct VectorKnnStream {
    /// Shared execution context used for the timeout check and the search.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Name of the searched label.
    label_name: String,

    /// Query variable name (column prefix).
    variable: String,

    /// Name of the vector property being searched.
    property: String,

    /// Query vector, already evaluated by the plan node.
    query_vector: Vec<f32>,

    /// Number of neighbours requested from the search backend.
    k: usize,

    /// Optional minimum score; hits scoring below it are dropped.
    threshold: Option<f32>,

    /// Extra vertex properties to load into the output batch.
    target_properties: Vec<String>,

    /// Schema of the emitted batch.
    schema: SchemaRef,

    /// Backend that supplies the candidates.
    source: VectorSource,

    /// Current position in the Init -> Executing -> Done state machine.
    state: VectorKnnState,

    /// Baseline metrics (compute time, output rows) for this partition.
    metrics: BaselineMetrics,
}
417
418impl VectorKnnStream {
419 #[expect(clippy::too_many_arguments)]
420 fn new(
421 graph_ctx: Arc<GraphExecutionContext>,
422 label_name: String,
423 variable: String,
424 property: String,
425 query_vector: Vec<f32>,
426 k: usize,
427 threshold: Option<f32>,
428 target_properties: Vec<String>,
429 schema: SchemaRef,
430 source: VectorSource,
431 metrics: BaselineMetrics,
432 ) -> Self {
433 Self {
434 graph_ctx,
435 label_name,
436 variable,
437 property,
438 query_vector,
439 k,
440 threshold,
441 target_properties,
442 schema,
443 source,
444 state: VectorKnnState::Init,
445 metrics,
446 }
447 }
448}
449
450impl Stream for VectorKnnStream {
451 type Item = DFResult<RecordBatch>;
452
453 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
454 let metrics = self.metrics.clone();
455 let _timer = metrics.elapsed_compute().timer();
456 loop {
457 let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
458
459 match state {
460 VectorKnnState::Init => {
461 let graph_ctx = self.graph_ctx.clone();
463 let label_name = self.label_name.clone();
464 let variable = self.variable.clone();
465 let property = self.property.clone();
466 let query_vector = self.query_vector.clone();
467 let k = self.k;
468 let threshold = self.threshold;
469 let target_properties = self.target_properties.clone();
470 let schema = self.schema.clone();
471 let source = self.source.clone();
472
473 let fut = async move {
474 graph_ctx.check_timeout().map_err(|e| {
476 datafusion::error::DataFusionError::Execution(e.to_string())
477 })?;
478
479 execute_vector_search(
480 &graph_ctx,
481 &label_name,
482 &variable,
483 &property,
484 &query_vector,
485 k,
486 threshold,
487 &target_properties,
488 &schema,
489 &source,
490 )
491 .await
492 };
493
494 self.state = VectorKnnState::Executing(Box::pin(fut));
495 }
497 VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
498 Poll::Ready(Ok(batch)) => {
499 self.state = VectorKnnState::Done;
500 self.metrics
501 .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
502 return Poll::Ready(batch.map(Ok));
503 }
504 Poll::Ready(Err(e)) => {
505 self.state = VectorKnnState::Done;
506 return Poll::Ready(Some(Err(e)));
507 }
508 Poll::Pending => {
509 self.state = VectorKnnState::Executing(fut);
510 return Poll::Pending;
511 }
512 },
513 VectorKnnState::Done => {
514 return Poll::Ready(None);
515 }
516 }
517 }
518 }
519}
520
521impl RecordBatchStream for VectorKnnStream {
522 fn schema(&self) -> SchemaRef {
523 self.schema.clone()
524 }
525}
526
527#[expect(clippy::too_many_arguments)]
529async fn execute_vector_search(
530 graph_ctx: &GraphExecutionContext,
531 label_name: &str,
532 variable: &str,
533 property: &str,
534 query_vector: &[f32],
535 k: usize,
536 threshold: Option<f32>,
537 target_properties: &[String],
538 schema: &SchemaRef,
539 source: &VectorSource,
540) -> DFResult<Option<RecordBatch>> {
541 let storage = graph_ctx.storage();
542
543 let results =
545 retrieve_vid_scores(graph_ctx, label_name, property, query_vector, k, source).await?;
546
547 let metric = storage
550 .schema_manager()
551 .schema()
552 .vector_index_for_property(label_name, property)
553 .map(|cfg| cfg.metric.clone())
554 .unwrap_or(DistanceMetric::L2);
555
556 let mut vids = Vec::new();
558 let mut scores = Vec::new();
559
560 for (vid, distance) in results {
561 let similarity = calculate_score(distance, &metric);
562
563 if let Some(thresh) = threshold
564 && similarity < thresh
565 {
566 continue;
567 }
568
569 vids.push(vid);
570 scores.push(similarity);
571 }
572
573 if vids.is_empty() {
574 return Ok(Some(RecordBatch::new_empty(schema.clone())));
575 }
576
577 let batch = build_result_batch(
579 &vids,
580 &scores,
581 variable,
582 target_properties,
583 label_name,
584 graph_ctx,
585 schema,
586 )
587 .await?;
588 Ok(Some(batch))
589}
590
591async fn retrieve_vid_scores(
603 graph_ctx: &GraphExecutionContext,
604 label_name: &str,
605 property: &str,
606 query_vector: &[f32],
607 k: usize,
608 source: &VectorSource,
609) -> DFResult<Vec<(Vid, f32)>> {
610 match source {
611 VectorSource::Native => {
612 let storage = graph_ctx.storage();
613 let query_ctx = graph_ctx.query_context();
614 storage
615 .vector_search(
616 label_name,
617 property,
618 query_vector,
619 k,
620 None,
621 Some(&query_ctx),
622 )
623 .await
624 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
625 }
626 VectorSource::Plugin { handle, .. } => {
627 let dim = i32::try_from(query_vector.len()).map_err(|_| {
630 datafusion::error::DataFusionError::Execution(
631 "query vector exceeds i32::MAX dimensions".to_string(),
632 )
633 })?;
634 let item_field = Arc::new(Field::new("item", DataType::Float32, true));
635 let mut fsl_builder =
636 FixedSizeListBuilder::new(Float32Builder::with_capacity(query_vector.len()), dim)
637 .with_field(Arc::clone(&item_field));
638 for &v in query_vector {
639 fsl_builder.values().append_value(v);
640 }
641 fsl_builder.append(true);
642 let fsl: FixedSizeListArray = fsl_builder.finish();
643
644 let query_schema = Arc::new(Schema::new(vec![Field::new(
645 "vector",
646 DataType::FixedSizeList(item_field, dim),
647 false,
648 )]));
649 let query_batch =
650 RecordBatch::try_new(query_schema, vec![Arc::new(fsl)]).map_err(arrow_err)?;
651
652 let result = handle.probe(&query_batch, k).map_err(|e| {
653 datafusion::error::DataFusionError::Execution(format!(
654 "IndexHandle::probe failed: {e:?}"
655 ))
656 })?;
657
658 let vid_col = result
661 .column_by_name("vid")
662 .ok_or_else(|| {
663 datafusion::error::DataFusionError::Execution(
664 "IndexHandle::probe result missing `vid` column".to_string(),
665 )
666 })?
667 .as_any()
668 .downcast_ref::<Int64Array>()
669 .ok_or_else(|| {
670 datafusion::error::DataFusionError::Execution(
671 "IndexHandle::probe result `vid` column is not Int64".to_string(),
672 )
673 })?;
674 let dist_col = result
675 .column_by_name("distance")
676 .ok_or_else(|| {
677 datafusion::error::DataFusionError::Execution(
678 "IndexHandle::probe result missing `distance` column".to_string(),
679 )
680 })?
681 .as_any()
682 .downcast_ref::<Float32Array>()
683 .ok_or_else(|| {
684 datafusion::error::DataFusionError::Execution(
685 "IndexHandle::probe result `distance` column is not Float32".to_string(),
686 )
687 })?;
688
689 let mut pairs = Vec::with_capacity(result.num_rows());
690 for i in 0..result.num_rows() {
691 if vid_col.is_null(i) {
692 continue;
693 }
694 let vid_i64 = vid_col.value(i);
695 let dist = if dist_col.is_null(i) {
696 f32::INFINITY
697 } else {
698 dist_col.value(i)
699 };
700 pairs.push((Vid::from(vid_i64 as u64), dist));
701 }
702 Ok(pairs)
703 }
704 }
705}
706
707async fn build_result_batch(
709 vids: &[Vid],
710 scores: &[f32],
711 _variable: &str,
712 target_properties: &[String],
713 label_name: &str,
714 graph_ctx: &GraphExecutionContext,
715 schema: &SchemaRef,
716) -> DFResult<RecordBatch> {
717 let num_rows = vids.len();
718
719 let mut vid_builder = UInt64Builder::with_capacity(num_rows);
721 for vid in vids {
722 vid_builder.append_value(vid.as_u64());
723 }
724
725 let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
727 for vid in vids {
728 var_builder.append_value(vid.to_string());
729 }
730
731 let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
733 for _vid in vids {
734 labels_builder.values().append_value(label_name);
735 labels_builder.append(true);
736 }
737
738 let mut score_builder = Float32Builder::with_capacity(num_rows);
740 for &score in scores {
741 score_builder.append_value(score);
742 }
743
744 let mut columns: Vec<ArrayRef> = vec![
745 Arc::new(vid_builder.finish()),
746 Arc::new(var_builder.finish()),
747 Arc::new(labels_builder.finish()),
748 Arc::new(score_builder.finish()),
749 ];
750
751 if !target_properties.is_empty() {
753 let property_manager = graph_ctx.property_manager();
754 let query_ctx = graph_ctx.query_context();
755
756 let props_map = property_manager
757 .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
758 .await
759 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
760
761 let uni_schema = graph_ctx.storage().schema_manager().schema();
762 let label_props = uni_schema.properties.get(label_name);
763
764 for prop_name in target_properties {
765 let data_type = resolve_property_type(prop_name, label_props);
766 let column = crate::query::df_graph::scan::build_property_column_static(
767 vids, &props_map, prop_name, &data_type,
768 )?;
769 columns.push(column);
770 }
771 }
772
773 RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
774}
775
/// Unit tests for schema construction and query-expression evaluation.
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// With no extra properties the schema holds exactly the four fixed
    /// columns, in order.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        assert_eq!(schema.fields().len(), 4);
        assert_eq!(schema.field(0).name(), "n._vid");
        assert_eq!(schema.field(1).name(), "n");
        assert_eq!(schema.field(2).name(), "n._labels");
        assert_eq!(schema.field(3).name(), "n._score");
    }

    /// A literal list expression evaluates to a `Value::List` of the same
    /// length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(vec![
            Expr::Literal(CypherLiteral::Float(0.1)),
            Expr::Literal(CypherLiteral::Float(0.2)),
            Expr::Literal(CypherLiteral::Float(0.3)),
        ]);

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &HashMap::new()).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 3);
            }
            _ => panic!("Expected list"),
        }
    }

    /// A parameter expression resolves to the value stored in the parameter
    /// map.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let mut params = HashMap::new();
        params.insert(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        );

        // The parameter map must be passed by reference.
        let result = evaluate_simple_expr(&expr, &params, &HashMap::new()).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected list"),
        }
    }

    /// Every requested target property gets its own `<variable>.<property>`
    /// column next to the fixed ones.
    #[test]
    fn test_build_schema_with_extra_properties() {
        let extra_props = vec!["name".to_string(), "embedding".to_string()];
        let schema = GraphVectorKnnExec::build_schema("doc", &extra_props, None);

        assert!(schema.field_with_name("doc._vid").is_ok());
        assert!(schema.field_with_name("doc").is_ok());
        assert!(schema.field_with_name("doc._score").is_ok());
        assert!(
            schema.field_with_name("doc.name").is_ok(),
            "Extra property 'name' should be in schema"
        );
        assert!(
            schema.field_with_name("doc.embedding").is_ok(),
            "Extra property 'embedding' should be in schema"
        );
    }

    /// A variable expression resolves to the value stored in the variable
    /// map.
    #[test]
    fn test_evaluate_variable() {
        let expr = Expr::Variable("x".to_string());
        let mut variables = HashMap::new();
        variables.insert(
            "x".to_string(),
            Value::List(vec![Value::Float(0.5), Value::Float(0.6)]),
        );

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &variables).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected list, got {:?}", result),
        }
    }
}