1use arrow_array::builder::{Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{ArrayRef, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34
35use crate::query::df_graph::GraphExecutionContext;
36use crate::query::df_graph::common::{
37 arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
38};
39use crate::query::df_graph::scan::resolve_property_type;
40
/// Physical plan node that runs a k-nearest-neighbour vector search over the
/// vertices of a single label and emits the matches as one record batch.
///
/// The output schema is assembled by `build_schema`: `{variable}._vid`,
/// `{variable}`, `{variable}._labels`, `{variable}._score`, followed by one
/// `{variable}.{property}` column for each entry of `target_properties`.
pub struct GraphVectorKnnExec {
    /// Shared execution context; source of the storage layer, property
    /// manager, query context and timeout check used during execution.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Numeric id of the searched label. Within this file it is only stored
    /// and reported through `Debug`; the search itself goes by `label_name`.
    label_id: u16,

    /// Name of the searched label; used for the schema lookup, the vector
    /// search and the property fetch.
    label_name: String,

    /// Query variable the results are bound to; prefixes every output column.
    variable: String,

    /// Name of the vector property the search runs against.
    property: String,

    /// Expression that is evaluated (together with `params`) to obtain the
    /// query vector when the plan is executed.
    query_expr: Expr,

    /// Number of neighbours requested from the vector search.
    k: usize,

    /// Optional minimum score; results whose computed score is below it are
    /// dropped after the search.
    threshold: Option<f32>,

    /// Parameter bindings available while evaluating `query_expr`.
    params: HashMap<String, Value>,

    /// Extra vertex properties to materialise as output columns.
    target_properties: Vec<String>,

    /// Output schema, computed once in `new`.
    schema: SchemaRef,

    /// Plan properties derived from `schema` in `new`.
    properties: PlanProperties,

    /// Metrics registry that per-partition baseline metrics are attached to.
    metrics: ExecutionPlanMetricsSet,
}
85
86impl fmt::Debug for GraphVectorKnnExec {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("GraphVectorKnnExec")
89 .field("label_id", &self.label_id)
90 .field("variable", &self.variable)
91 .field("property", &self.property)
92 .field("k", &self.k)
93 .field("threshold", &self.threshold)
94 .finish()
95 }
96}
97
98impl GraphVectorKnnExec {
99 #[expect(clippy::too_many_arguments)]
113 pub fn new(
114 graph_ctx: Arc<GraphExecutionContext>,
115 label_id: u16,
116 label_name: impl Into<String>,
117 variable: impl Into<String>,
118 property: impl Into<String>,
119 query_expr: Expr,
120 k: usize,
121 threshold: Option<f32>,
122 params: HashMap<String, Value>,
123 target_properties: Vec<String>,
124 ) -> Self {
125 let variable = variable.into();
126 let property = property.into();
127 let label_name = label_name.into();
128
129 let uni_schema = graph_ctx.storage().schema_manager().schema();
131 let label_props = uni_schema.properties.get(label_name.as_str());
132
133 let schema = Self::build_schema(&variable, &target_properties, label_props);
134 let properties = compute_plan_properties(schema.clone());
135
136 Self {
137 graph_ctx,
138 label_id,
139 label_name,
140 variable,
141 property,
142 query_expr,
143 k,
144 threshold,
145 params,
146 target_properties,
147 schema,
148 properties,
149 metrics: ExecutionPlanMetricsSet::new(),
150 }
151 }
152
153 fn build_schema(
161 variable: &str,
162 target_properties: &[String],
163 label_props: Option<&HashMap<String, PropertyMeta>>,
164 ) -> SchemaRef {
165 let mut fields = vec![
166 Field::new(format!("{}._vid", variable), DataType::UInt64, false),
167 Field::new(variable, DataType::Utf8, false),
168 Field::new(format!("{}._labels", variable), labels_data_type(), true),
169 Field::new(format!("{}._score", variable), DataType::Float32, true),
170 ];
171
172 for prop_name in target_properties {
174 let col_name = format!("{}.{}", variable, prop_name);
175 let arrow_type = resolve_property_type(prop_name, label_props);
176 fields.push(Field::new(&col_name, arrow_type, true));
177 }
178
179 Arc::new(Schema::new(fields))
180 }
181
182 fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
184 let value = evaluate_simple_expr(&self.query_expr, &self.params)?;
185
186 match value {
187 Value::Vector(vec) => Ok(vec),
188 Value::List(arr) => {
189 let mut vec = Vec::with_capacity(arr.len());
190 for v in arr {
191 if let Some(f) = v.as_f64() {
192 vec.push(f as f32);
193 } else {
194 return Err(datafusion::error::DataFusionError::Execution(
195 "Query vector must contain numbers".to_string(),
196 ));
197 }
198 }
199 Ok(vec)
200 }
201 _ => Err(datafusion::error::DataFusionError::Execution(
202 "Query vector must be a list or vector".to_string(),
203 )),
204 }
205 }
206}
207
208impl DisplayAs for GraphVectorKnnExec {
209 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 write!(
211 f,
212 "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
213 self.label_name, self.property, self.k, self.variable
214 )
215 }
216}
217
218impl ExecutionPlan for GraphVectorKnnExec {
219 fn name(&self) -> &str {
220 "GraphVectorKnnExec"
221 }
222
223 fn as_any(&self) -> &dyn Any {
224 self
225 }
226
227 fn schema(&self) -> SchemaRef {
228 self.schema.clone()
229 }
230
231 fn properties(&self) -> &PlanProperties {
232 &self.properties
233 }
234
235 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
236 vec![]
237 }
238
239 fn with_new_children(
240 self: Arc<Self>,
241 children: Vec<Arc<dyn ExecutionPlan>>,
242 ) -> DFResult<Arc<dyn ExecutionPlan>> {
243 if !children.is_empty() {
244 return Err(datafusion::error::DataFusionError::Internal(
245 "GraphVectorKnnExec has no children".to_string(),
246 ));
247 }
248 Ok(self)
249 }
250
251 fn execute(
252 &self,
253 partition: usize,
254 _context: Arc<TaskContext>,
255 ) -> DFResult<SendableRecordBatchStream> {
256 let metrics = BaselineMetrics::new(&self.metrics, partition);
257
258 let query_vector = self.evaluate_query_vector()?;
260
261 Ok(Box::pin(VectorKnnStream::new(
262 self.graph_ctx.clone(),
263 self.label_name.clone(),
264 self.variable.clone(),
265 self.property.clone(),
266 query_vector,
267 self.k,
268 self.threshold,
269 self.target_properties.clone(),
270 self.schema.clone(),
271 metrics,
272 )))
273 }
274
275 fn metrics(&self) -> Option<MetricsSet> {
276 Some(self.metrics.clone_inner())
277 }
278}
279
/// Progress of a `VectorKnnStream`.
enum VectorKnnState {
    /// Nothing has been started yet; the first poll creates the search future.
    Init,
    /// The search future is in flight; it resolves to the single result batch
    /// (or `None`), or to an error.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// The result (or error) has been handed out; the stream now yields `None`.
    Done,
}
289
/// Record batch stream behind `GraphVectorKnnExec`.
///
/// It drives one search future to completion and therefore yields at most one
/// item (a batch or an error) before ending.
struct VectorKnnStream {
    /// Shared execution context used for the timeout check and the search.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Label whose vertices are searched.
    label_name: String,

    /// Query variable the results are bound to.
    variable: String,

    /// Vector property the search runs against.
    property: String,

    /// Query vector, already evaluated by the plan node.
    query_vector: Vec<f32>,

    /// Number of neighbours requested from the vector search.
    k: usize,

    /// Optional minimum score applied after the search.
    threshold: Option<f32>,

    /// Extra vertex properties to materialise as output columns.
    target_properties: Vec<String>,

    /// Schema of the emitted batch.
    schema: SchemaRef,

    /// Current position in the Init -> Executing -> Done state machine.
    state: VectorKnnState,

    /// Baseline metrics; the number of emitted rows is recorded here.
    metrics: BaselineMetrics,
}
325
326impl VectorKnnStream {
327 #[expect(clippy::too_many_arguments)]
328 fn new(
329 graph_ctx: Arc<GraphExecutionContext>,
330 label_name: String,
331 variable: String,
332 property: String,
333 query_vector: Vec<f32>,
334 k: usize,
335 threshold: Option<f32>,
336 target_properties: Vec<String>,
337 schema: SchemaRef,
338 metrics: BaselineMetrics,
339 ) -> Self {
340 Self {
341 graph_ctx,
342 label_name,
343 variable,
344 property,
345 query_vector,
346 k,
347 threshold,
348 target_properties,
349 schema,
350 state: VectorKnnState::Init,
351 metrics,
352 }
353 }
354}
355
356impl Stream for VectorKnnStream {
357 type Item = DFResult<RecordBatch>;
358
359 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
360 loop {
361 let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
362
363 match state {
364 VectorKnnState::Init => {
365 let graph_ctx = self.graph_ctx.clone();
367 let label_name = self.label_name.clone();
368 let variable = self.variable.clone();
369 let property = self.property.clone();
370 let query_vector = self.query_vector.clone();
371 let k = self.k;
372 let threshold = self.threshold;
373 let target_properties = self.target_properties.clone();
374 let schema = self.schema.clone();
375
376 let fut = async move {
377 graph_ctx.check_timeout().map_err(|e| {
379 datafusion::error::DataFusionError::Execution(e.to_string())
380 })?;
381
382 execute_vector_search(
383 &graph_ctx,
384 &label_name,
385 &variable,
386 &property,
387 &query_vector,
388 k,
389 threshold,
390 &target_properties,
391 &schema,
392 )
393 .await
394 };
395
396 self.state = VectorKnnState::Executing(Box::pin(fut));
397 }
399 VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
400 Poll::Ready(Ok(batch)) => {
401 self.state = VectorKnnState::Done;
402 self.metrics
403 .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
404 return Poll::Ready(batch.map(Ok));
405 }
406 Poll::Ready(Err(e)) => {
407 self.state = VectorKnnState::Done;
408 return Poll::Ready(Some(Err(e)));
409 }
410 Poll::Pending => {
411 self.state = VectorKnnState::Executing(fut);
412 return Poll::Pending;
413 }
414 },
415 VectorKnnState::Done => {
416 return Poll::Ready(None);
417 }
418 }
419 }
420 }
421}
422
423impl RecordBatchStream for VectorKnnStream {
424 fn schema(&self) -> SchemaRef {
425 self.schema.clone()
426 }
427}
428
429#[expect(clippy::too_many_arguments)]
431async fn execute_vector_search(
432 graph_ctx: &GraphExecutionContext,
433 label_name: &str,
434 variable: &str,
435 property: &str,
436 query_vector: &[f32],
437 k: usize,
438 threshold: Option<f32>,
439 target_properties: &[String],
440 schema: &SchemaRef,
441) -> DFResult<Option<RecordBatch>> {
442 let storage = graph_ctx.storage();
443 let query_ctx = graph_ctx.query_context();
444
445 let results = storage
447 .vector_search(
448 label_name,
449 property,
450 query_vector,
451 k,
452 None,
453 Some(&query_ctx),
454 )
455 .await
456 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
457
458 let metric = storage
461 .schema_manager()
462 .schema()
463 .vector_index_for_property(label_name, property)
464 .map(|cfg| cfg.metric.clone())
465 .unwrap_or(DistanceMetric::L2);
466
467 let mut vids = Vec::new();
469 let mut scores = Vec::new();
470
471 for (vid, distance) in results {
472 let similarity = calculate_score(distance, &metric);
473
474 if let Some(thresh) = threshold
475 && similarity < thresh
476 {
477 continue;
478 }
479
480 vids.push(vid);
481 scores.push(similarity);
482 }
483
484 if vids.is_empty() {
485 return Ok(Some(RecordBatch::new_empty(schema.clone())));
486 }
487
488 let batch = build_result_batch(
490 &vids,
491 &scores,
492 variable,
493 target_properties,
494 label_name,
495 graph_ctx,
496 schema,
497 )
498 .await?;
499 Ok(Some(batch))
500}
501
502async fn build_result_batch(
504 vids: &[Vid],
505 scores: &[f32],
506 _variable: &str,
507 target_properties: &[String],
508 label_name: &str,
509 graph_ctx: &GraphExecutionContext,
510 schema: &SchemaRef,
511) -> DFResult<RecordBatch> {
512 let num_rows = vids.len();
513
514 let mut vid_builder = UInt64Builder::with_capacity(num_rows);
516 for vid in vids {
517 vid_builder.append_value(vid.as_u64());
518 }
519
520 let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
522 for vid in vids {
523 var_builder.append_value(vid.to_string());
524 }
525
526 let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
528 for _vid in vids {
529 labels_builder.values().append_value(label_name);
530 labels_builder.append(true);
531 }
532
533 let mut score_builder = Float32Builder::with_capacity(num_rows);
535 for &score in scores {
536 score_builder.append_value(score);
537 }
538
539 let mut columns: Vec<ArrayRef> = vec![
540 Arc::new(vid_builder.finish()),
541 Arc::new(var_builder.finish()),
542 Arc::new(labels_builder.finish()),
543 Arc::new(score_builder.finish()),
544 ];
545
546 if !target_properties.is_empty() {
548 let property_manager = graph_ctx.property_manager();
549 let query_ctx = graph_ctx.query_context();
550
551 let props_map = property_manager
552 .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
553 .await
554 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
555
556 let uni_schema = graph_ctx.storage().schema_manager().schema();
557 let label_props = uni_schema.properties.get(label_name);
558
559 for prop_name in target_properties {
560 let data_type = resolve_property_type(prop_name, label_props);
561 let column = crate::query::df_graph::scan::build_property_column_static(
562 vids, &props_map, prop_name, &data_type,
563 )?;
564 columns.push(column);
565 }
566 }
567
568 RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// Without target properties the schema consists of exactly the four
    /// fixed columns, in order.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        assert_eq!(schema.fields().len(), 4);
        assert_eq!(schema.field(0).name(), "n._vid");
        assert_eq!(schema.field(1).name(), "n");
        assert_eq!(schema.field(2).name(), "n._labels");
        assert_eq!(schema.field(3).name(), "n._score");
    }

    /// A literal list expression evaluates to a list value of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(vec![
            Expr::Literal(CypherLiteral::Float(0.1)),
            Expr::Literal(CypherLiteral::Float(0.2)),
            Expr::Literal(CypherLiteral::Float(0.3)),
        ]);

        let result = evaluate_simple_expr(&expr, &HashMap::new()).unwrap();
        let Value::List(arr) = result else {
            panic!("Expected list");
        };
        assert_eq!(arr.len(), 3);
    }

    /// A parameter reference resolves to the value bound in the parameter map.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let mut params = HashMap::new();
        params.insert(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        );

        // Pass the parameter map by reference.
        let result = evaluate_simple_expr(&expr, &params).unwrap();
        let Value::List(arr) = result else {
            panic!("Expected list");
        };
        assert_eq!(arr.len(), 2);
    }
}