1use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34use uni_plugin::traits::index::{IndexHandle, IndexKind};
35
36use crate::query::df_graph::GraphExecutionContext;
37use crate::query::df_graph::common::{
38 arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
39};
40use crate::query::df_graph::scan::resolve_property_type;
41
42#[derive(Clone)]
60pub(crate) enum VectorSource {
61 Native,
63 Plugin {
65 #[allow(dead_code)]
68 kind: IndexKind,
69 handle: Arc<dyn IndexHandle>,
71 },
72}
73
74impl fmt::Debug for VectorSource {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::Native => f.write_str("Native"),
78 Self::Plugin { kind, .. } => f.debug_struct("Plugin").field("kind", kind).finish(),
79 }
80 }
81}
82
83pub struct GraphVectorKnnExec {
88 graph_ctx: Arc<GraphExecutionContext>,
90
91 label_id: u16,
93
94 label_name: String,
96
97 variable: String,
99
100 property: String,
102
103 query_expr: Expr,
105
106 k: usize,
108
109 threshold: Option<f32>,
111
112 params: HashMap<String, Value>,
114
115 target_properties: Vec<String>,
117
118 schema: SchemaRef,
120
121 properties: Arc<PlanProperties>,
123
124 source: VectorSource,
128
129 metrics: ExecutionPlanMetricsSet,
131}
132
133impl fmt::Debug for GraphVectorKnnExec {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 f.debug_struct("GraphVectorKnnExec")
136 .field("label_id", &self.label_id)
137 .field("variable", &self.variable)
138 .field("property", &self.property)
139 .field("k", &self.k)
140 .field("threshold", &self.threshold)
141 .finish()
142 }
143}
144
145impl GraphVectorKnnExec {
146 #[expect(clippy::too_many_arguments)]
160 pub fn new(
161 graph_ctx: Arc<GraphExecutionContext>,
162 label_id: u16,
163 label_name: impl Into<String>,
164 variable: impl Into<String>,
165 property: impl Into<String>,
166 query_expr: Expr,
167 k: usize,
168 threshold: Option<f32>,
169 params: HashMap<String, Value>,
170 target_properties: Vec<String>,
171 ) -> Self {
172 let variable = variable.into();
173 let property = property.into();
174 let label_name = label_name.into();
175
176 let uni_schema = graph_ctx.storage().schema_manager().schema();
178 let label_props = uni_schema.properties.get(label_name.as_str());
179
180 let schema = Self::build_schema(&variable, &target_properties, label_props);
181 let properties = compute_plan_properties(schema.clone());
182
183 Self {
184 graph_ctx,
185 label_id,
186 label_name,
187 variable,
188 property,
189 query_expr,
190 k,
191 threshold,
192 params,
193 target_properties,
194 schema,
195 properties,
196 source: VectorSource::Native,
197 metrics: ExecutionPlanMetricsSet::new(),
198 }
199 }
200
201 #[expect(clippy::too_many_arguments)]
208 pub fn with_plugin_source(
209 graph_ctx: Arc<GraphExecutionContext>,
210 label_id: u16,
211 label_name: impl Into<String>,
212 variable: impl Into<String>,
213 property: impl Into<String>,
214 query_expr: Expr,
215 k: usize,
216 threshold: Option<f32>,
217 params: HashMap<String, Value>,
218 target_properties: Vec<String>,
219 kind: IndexKind,
220 handle: Arc<dyn IndexHandle>,
221 ) -> Self {
222 let mut exec = Self::new(
223 graph_ctx,
224 label_id,
225 label_name,
226 variable,
227 property,
228 query_expr,
229 k,
230 threshold,
231 params,
232 target_properties,
233 );
234 exec.source = VectorSource::Plugin { kind, handle };
235 exec
236 }
237
238 fn build_schema(
246 variable: &str,
247 target_properties: &[String],
248 label_props: Option<&HashMap<String, PropertyMeta>>,
249 ) -> SchemaRef {
250 let mut fields = vec![
251 Field::new(format!("{}._vid", variable), DataType::UInt64, false),
252 Field::new(variable, DataType::Utf8, false),
253 Field::new(format!("{}._labels", variable), labels_data_type(), true),
254 Field::new(format!("{}._score", variable), DataType::Float32, true),
255 ];
256
257 for prop_name in target_properties {
259 let col_name = format!("{}.{}", variable, prop_name);
260 let arrow_type = resolve_property_type(prop_name, label_props);
261 fields.push(Field::new(&col_name, arrow_type, true));
262 }
263
264 Arc::new(Schema::new(fields))
265 }
266
267 fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
269 let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
270
271 match value {
272 Value::Vector(vec) => Ok(vec),
273 Value::List(arr) => {
274 let mut vec = Vec::with_capacity(arr.len());
275 for v in arr {
276 if let Some(f) = v.as_f64() {
277 vec.push(f as f32);
278 } else {
279 return Err(datafusion::error::DataFusionError::Execution(
280 "Query vector must contain numbers".to_string(),
281 ));
282 }
283 }
284 Ok(vec)
285 }
286 _ => Err(datafusion::error::DataFusionError::Execution(
287 "Query vector must be a list or vector".to_string(),
288 )),
289 }
290 }
291}
292
293impl DisplayAs for GraphVectorKnnExec {
294 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295 write!(
296 f,
297 "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
298 self.label_name, self.property, self.k, self.variable
299 )
300 }
301}
302
303impl ExecutionPlan for GraphVectorKnnExec {
304 fn name(&self) -> &str {
305 "GraphVectorKnnExec"
306 }
307
308 fn as_any(&self) -> &dyn Any {
309 self
310 }
311
312 fn schema(&self) -> SchemaRef {
313 self.schema.clone()
314 }
315
316 fn properties(&self) -> &Arc<PlanProperties> {
317 &self.properties
318 }
319
320 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
321 vec![]
322 }
323
324 fn with_new_children(
325 self: Arc<Self>,
326 children: Vec<Arc<dyn ExecutionPlan>>,
327 ) -> DFResult<Arc<dyn ExecutionPlan>> {
328 if !children.is_empty() {
329 return Err(datafusion::error::DataFusionError::Internal(
330 "GraphVectorKnnExec has no children".to_string(),
331 ));
332 }
333 Ok(self)
334 }
335
336 fn execute(
337 &self,
338 partition: usize,
339 _context: Arc<TaskContext>,
340 ) -> DFResult<SendableRecordBatchStream> {
341 let metrics = BaselineMetrics::new(&self.metrics, partition);
342
343 let query_vector = self.evaluate_query_vector()?;
345
346 Ok(Box::pin(VectorKnnStream::new(
347 self.graph_ctx.clone(),
348 self.label_name.clone(),
349 self.variable.clone(),
350 self.property.clone(),
351 query_vector,
352 self.k,
353 self.threshold,
354 self.target_properties.clone(),
355 self.schema.clone(),
356 self.source.clone(),
357 metrics,
358 )))
359 }
360
361 fn metrics(&self) -> Option<MetricsSet> {
362 Some(self.metrics.clone_inner())
363 }
364}
365
366enum VectorKnnState {
368 Init,
370 Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
372 Done,
374}
375
376struct VectorKnnStream {
378 graph_ctx: Arc<GraphExecutionContext>,
380
381 label_name: String,
383
384 variable: String,
386
387 property: String,
389
390 query_vector: Vec<f32>,
392
393 k: usize,
395
396 threshold: Option<f32>,
398
399 target_properties: Vec<String>,
401
402 schema: SchemaRef,
404
405 source: VectorSource,
407
408 state: VectorKnnState,
410
411 metrics: BaselineMetrics,
413}
414
415impl VectorKnnStream {
416 #[expect(clippy::too_many_arguments)]
417 fn new(
418 graph_ctx: Arc<GraphExecutionContext>,
419 label_name: String,
420 variable: String,
421 property: String,
422 query_vector: Vec<f32>,
423 k: usize,
424 threshold: Option<f32>,
425 target_properties: Vec<String>,
426 schema: SchemaRef,
427 source: VectorSource,
428 metrics: BaselineMetrics,
429 ) -> Self {
430 Self {
431 graph_ctx,
432 label_name,
433 variable,
434 property,
435 query_vector,
436 k,
437 threshold,
438 target_properties,
439 schema,
440 source,
441 state: VectorKnnState::Init,
442 metrics,
443 }
444 }
445}
446
447impl Stream for VectorKnnStream {
448 type Item = DFResult<RecordBatch>;
449
450 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
451 let metrics = self.metrics.clone();
452 let _timer = metrics.elapsed_compute().timer();
453 loop {
454 let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
455
456 match state {
457 VectorKnnState::Init => {
458 let graph_ctx = self.graph_ctx.clone();
460 let label_name = self.label_name.clone();
461 let variable = self.variable.clone();
462 let property = self.property.clone();
463 let query_vector = self.query_vector.clone();
464 let k = self.k;
465 let threshold = self.threshold;
466 let target_properties = self.target_properties.clone();
467 let schema = self.schema.clone();
468 let source = self.source.clone();
469
470 let fut = async move {
471 graph_ctx.check_timeout().map_err(|e| {
473 datafusion::error::DataFusionError::Execution(e.to_string())
474 })?;
475
476 execute_vector_search(
477 &graph_ctx,
478 &label_name,
479 &variable,
480 &property,
481 &query_vector,
482 k,
483 threshold,
484 &target_properties,
485 &schema,
486 &source,
487 )
488 .await
489 };
490
491 self.state = VectorKnnState::Executing(Box::pin(fut));
492 }
494 VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
495 Poll::Ready(Ok(batch)) => {
496 self.state = VectorKnnState::Done;
497 self.metrics
498 .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
499 return Poll::Ready(batch.map(Ok));
500 }
501 Poll::Ready(Err(e)) => {
502 self.state = VectorKnnState::Done;
503 return Poll::Ready(Some(Err(e)));
504 }
505 Poll::Pending => {
506 self.state = VectorKnnState::Executing(fut);
507 return Poll::Pending;
508 }
509 },
510 VectorKnnState::Done => {
511 return Poll::Ready(None);
512 }
513 }
514 }
515 }
516}
517
518impl RecordBatchStream for VectorKnnStream {
519 fn schema(&self) -> SchemaRef {
520 self.schema.clone()
521 }
522}
523
524#[expect(clippy::too_many_arguments)]
526async fn execute_vector_search(
527 graph_ctx: &GraphExecutionContext,
528 label_name: &str,
529 variable: &str,
530 property: &str,
531 query_vector: &[f32],
532 k: usize,
533 threshold: Option<f32>,
534 target_properties: &[String],
535 schema: &SchemaRef,
536 source: &VectorSource,
537) -> DFResult<Option<RecordBatch>> {
538 let storage = graph_ctx.storage();
539
540 let results =
542 retrieve_vid_scores(graph_ctx, label_name, property, query_vector, k, source).await?;
543
544 let metric = storage
547 .schema_manager()
548 .schema()
549 .vector_index_for_property(label_name, property)
550 .map(|cfg| cfg.metric.clone())
551 .unwrap_or(DistanceMetric::L2);
552
553 let mut vids = Vec::new();
555 let mut scores = Vec::new();
556
557 for (vid, distance) in results {
558 let similarity = calculate_score(distance, &metric);
559
560 if let Some(thresh) = threshold
561 && similarity < thresh
562 {
563 continue;
564 }
565
566 vids.push(vid);
567 scores.push(similarity);
568 }
569
570 if vids.is_empty() {
571 return Ok(Some(RecordBatch::new_empty(schema.clone())));
572 }
573
574 let batch = build_result_batch(
576 &vids,
577 &scores,
578 variable,
579 target_properties,
580 label_name,
581 graph_ctx,
582 schema,
583 )
584 .await?;
585 Ok(Some(batch))
586}
587
588async fn retrieve_vid_scores(
600 graph_ctx: &GraphExecutionContext,
601 label_name: &str,
602 property: &str,
603 query_vector: &[f32],
604 k: usize,
605 source: &VectorSource,
606) -> DFResult<Vec<(Vid, f32)>> {
607 match source {
608 VectorSource::Native => {
609 let storage = graph_ctx.storage();
610 let query_ctx = graph_ctx.query_context();
611 storage
612 .vector_search(
613 label_name,
614 property,
615 query_vector,
616 k,
617 None,
618 Some(&query_ctx),
619 )
620 .await
621 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
622 }
623 VectorSource::Plugin { handle, .. } => {
624 let dim = i32::try_from(query_vector.len()).map_err(|_| {
627 datafusion::error::DataFusionError::Execution(
628 "query vector exceeds i32::MAX dimensions".to_string(),
629 )
630 })?;
631 let item_field = Arc::new(Field::new("item", DataType::Float32, true));
632 let mut fsl_builder =
633 FixedSizeListBuilder::new(Float32Builder::with_capacity(query_vector.len()), dim)
634 .with_field(Arc::clone(&item_field));
635 for &v in query_vector {
636 fsl_builder.values().append_value(v);
637 }
638 fsl_builder.append(true);
639 let fsl: FixedSizeListArray = fsl_builder.finish();
640
641 let query_schema = Arc::new(Schema::new(vec![Field::new(
642 "vector",
643 DataType::FixedSizeList(item_field, dim),
644 false,
645 )]));
646 let query_batch =
647 RecordBatch::try_new(query_schema, vec![Arc::new(fsl)]).map_err(arrow_err)?;
648
649 let result = handle.probe(&query_batch, k).map_err(|e| {
650 datafusion::error::DataFusionError::Execution(format!(
651 "IndexHandle::probe failed: {e:?}"
652 ))
653 })?;
654
655 let vid_col = result
658 .column_by_name("vid")
659 .ok_or_else(|| {
660 datafusion::error::DataFusionError::Execution(
661 "IndexHandle::probe result missing `vid` column".to_string(),
662 )
663 })?
664 .as_any()
665 .downcast_ref::<Int64Array>()
666 .ok_or_else(|| {
667 datafusion::error::DataFusionError::Execution(
668 "IndexHandle::probe result `vid` column is not Int64".to_string(),
669 )
670 })?;
671 let dist_col = result
672 .column_by_name("distance")
673 .ok_or_else(|| {
674 datafusion::error::DataFusionError::Execution(
675 "IndexHandle::probe result missing `distance` column".to_string(),
676 )
677 })?
678 .as_any()
679 .downcast_ref::<Float32Array>()
680 .ok_or_else(|| {
681 datafusion::error::DataFusionError::Execution(
682 "IndexHandle::probe result `distance` column is not Float32".to_string(),
683 )
684 })?;
685
686 let mut pairs = Vec::with_capacity(result.num_rows());
687 for i in 0..result.num_rows() {
688 if vid_col.is_null(i) {
689 continue;
690 }
691 let vid_i64 = vid_col.value(i);
692 let dist = if dist_col.is_null(i) {
693 f32::INFINITY
694 } else {
695 dist_col.value(i)
696 };
697 pairs.push((Vid::from(vid_i64 as u64), dist));
698 }
699 Ok(pairs)
700 }
701 }
702}
703
704async fn build_result_batch(
706 vids: &[Vid],
707 scores: &[f32],
708 _variable: &str,
709 target_properties: &[String],
710 label_name: &str,
711 graph_ctx: &GraphExecutionContext,
712 schema: &SchemaRef,
713) -> DFResult<RecordBatch> {
714 let num_rows = vids.len();
715
716 let mut vid_builder = UInt64Builder::with_capacity(num_rows);
718 for vid in vids {
719 vid_builder.append_value(vid.as_u64());
720 }
721
722 let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
724 for vid in vids {
725 var_builder.append_value(vid.to_string());
726 }
727
728 let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
730 for _vid in vids {
731 labels_builder.values().append_value(label_name);
732 labels_builder.append(true);
733 }
734
735 let mut score_builder = Float32Builder::with_capacity(num_rows);
737 for &score in scores {
738 score_builder.append_value(score);
739 }
740
741 let mut columns: Vec<ArrayRef> = vec![
742 Arc::new(vid_builder.finish()),
743 Arc::new(var_builder.finish()),
744 Arc::new(labels_builder.finish()),
745 Arc::new(score_builder.finish()),
746 ];
747
748 if !target_properties.is_empty() {
750 let property_manager = graph_ctx.property_manager();
751 let query_ctx = graph_ctx.query_context();
752
753 let props_map = property_manager
754 .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
755 .await
756 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
757
758 let uni_schema = graph_ctx.storage().schema_manager().schema();
759 let label_props = uni_schema.properties.get(label_name);
760
761 for prop_name in target_properties {
762 let data_type = resolve_property_type(prop_name, label_props);
763 let column = crate::query::df_graph::scan::build_property_column_static(
764 vids, &props_map, prop_name, &data_type,
765 )?;
766 columns.push(column);
767 }
768 }
769
770 RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// Without target properties the schema holds exactly the four fixed
    /// columns, in order.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        assert_eq!(schema.fields().len(), 4);
        assert_eq!(schema.field(0).name(), "n._vid");
        assert_eq!(schema.field(1).name(), "n");
        assert_eq!(schema.field(2).name(), "n._labels");
        assert_eq!(schema.field(3).name(), "n._score");
    }

    /// A literal list expression evaluates to a `Value::List` of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(vec![
            Expr::Literal(CypherLiteral::Float(0.1)),
            Expr::Literal(CypherLiteral::Float(0.2)),
            Expr::Literal(CypherLiteral::Float(0.3)),
        ]);

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &HashMap::new()).unwrap();
        let Value::List(items) = &result else {
            panic!("Expected list");
        };
        assert_eq!(items.len(), 3);
    }

    /// A parameter reference resolves to the value stored in the parameter map.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let mut params = HashMap::new();
        params.insert(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        );

        let result = evaluate_simple_expr(&expr, &params, &HashMap::new()).unwrap();
        let Value::List(items) = &result else {
            panic!("Expected list");
        };
        assert_eq!(items.len(), 2);
    }

    /// Each target property adds a `{variable}.{property}` column after the
    /// fixed ones.
    #[test]
    fn test_build_schema_with_extra_properties() {
        let extra_props = vec!["name".to_string(), "embedding".to_string()];
        let schema = GraphVectorKnnExec::build_schema("doc", &extra_props, None);

        assert_eq!(schema.fields().len(), 6);
        assert!(schema.field_with_name("doc._vid").is_ok());
        assert!(schema.field_with_name("doc").is_ok());
        assert!(schema.field_with_name("doc._labels").is_ok());
        assert!(schema.field_with_name("doc._score").is_ok());
        assert!(
            schema.field_with_name("doc.name").is_ok(),
            "Extra property 'name' should be in schema"
        );
        assert!(
            schema.field_with_name("doc.embedding").is_ok(),
            "Extra property 'embedding' should be in schema"
        );
    }

    /// A variable reference resolves to the value stored in the variable map.
    #[test]
    fn test_evaluate_variable() {
        let expr = Expr::Variable("x".to_string());
        let mut variables = HashMap::new();
        variables.insert(
            "x".to_string(),
            Value::List(vec![Value::Float(0.5), Value::Float(0.6)]),
        );

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &variables).unwrap();
        let Value::List(items) = &result else {
            panic!("Expected list, got {:?}", result);
        };
        assert_eq!(items.len(), 2);
    }
}