1use crate::query::df_graph::GraphExecutionContext;
17use crate::query::df_graph::common::{
18 compute_plan_properties, evaluate_simple_expr, labels_data_type,
19};
20use crate::query::df_graph::scan::resolve_property_type;
21use arrow_array::builder::{Float32Builder, StringBuilder, UInt64Builder};
22use arrow_array::{ArrayRef, RecordBatch};
23use arrow_schema::{DataType, Field, Schema, SchemaRef};
24use datafusion::common::Result as DFResult;
25use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
26use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
27use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
28use futures::Stream;
29use std::any::Any;
30use std::collections::HashMap;
31use std::fmt;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::task::{Context, Poll};
35use uni_common::Value;
36use uni_common::core::id::Vid;
37use uni_cypher::ast::Expr;
38
/// Physical plan node that runs a k-nearest-neighbour vector search over the
/// vertices of one label and emits the hits (vid, variable, labels, score and
/// any requested properties) as a record batch.
pub struct GraphVectorKnnExec {
    /// Shared execution context giving access to storage and the query context.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Numeric id of the vertex label being searched.
    label_id: u16,
    /// Name of the vertex label being searched.
    label_name: String,
    /// Cypher variable bound to the matched vertices; used to name output columns.
    variable: String,
    /// Name of the vector-valued property to search.
    property: String,
    /// Unevaluated expression producing the query vector (literal list or parameter).
    query_expr: Expr,
    /// Maximum number of neighbours to retrieve.
    k: usize,
    /// Optional minimum similarity; hits below it are filtered out downstream.
    threshold: Option<f32>,
    /// Query parameters used when evaluating `query_expr`.
    params: HashMap<String, Value>,
    /// Additional vertex properties to materialize as output columns.
    target_properties: Vec<String>,
    /// Arrow output schema derived from `variable` and `target_properties`.
    schema: SchemaRef,
    /// Cached DataFusion plan properties computed from `schema`.
    properties: PlanProperties,
    /// Metrics set shared with the per-partition baseline metrics.
    metrics: ExecutionPlanMetricsSet,
}
83
84impl fmt::Debug for GraphVectorKnnExec {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 f.debug_struct("GraphVectorKnnExec")
87 .field("label_id", &self.label_id)
88 .field("variable", &self.variable)
89 .field("property", &self.property)
90 .field("k", &self.k)
91 .field("threshold", &self.threshold)
92 .finish()
93 }
94}
95
96impl GraphVectorKnnExec {
97 #[expect(clippy::too_many_arguments)]
111 pub fn new(
112 graph_ctx: Arc<GraphExecutionContext>,
113 label_id: u16,
114 label_name: impl Into<String>,
115 variable: impl Into<String>,
116 property: impl Into<String>,
117 query_expr: Expr,
118 k: usize,
119 threshold: Option<f32>,
120 params: HashMap<String, Value>,
121 target_properties: Vec<String>,
122 ) -> Self {
123 let variable = variable.into();
124 let property = property.into();
125 let label_name = label_name.into();
126
127 let uni_schema = graph_ctx.storage().schema_manager().schema();
129 let label_props = uni_schema.properties.get(label_name.as_str());
130
131 let schema = Self::build_schema(&variable, &target_properties, label_props);
132 let properties = compute_plan_properties(schema.clone());
133
134 Self {
135 graph_ctx,
136 label_id,
137 label_name,
138 variable,
139 property,
140 query_expr,
141 k,
142 threshold,
143 params,
144 target_properties,
145 schema,
146 properties,
147 metrics: ExecutionPlanMetricsSet::new(),
148 }
149 }
150
151 fn build_schema(
159 variable: &str,
160 target_properties: &[String],
161 label_props: Option<
162 &std::collections::HashMap<String, uni_common::core::schema::PropertyMeta>,
163 >,
164 ) -> SchemaRef {
165 let mut fields = vec![
166 Field::new(format!("{}._vid", variable), DataType::UInt64, false),
167 Field::new(variable, DataType::Utf8, false),
168 Field::new(format!("{}._labels", variable), labels_data_type(), true),
169 Field::new(format!("{}._score", variable), DataType::Float32, true),
170 ];
171
172 for prop_name in target_properties {
174 let col_name = format!("{}.{}", variable, prop_name);
175 let arrow_type = resolve_property_type(prop_name, label_props);
176 fields.push(Field::new(&col_name, arrow_type, true));
177 }
178
179 Arc::new(Schema::new(fields))
180 }
181
182 fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
184 let value = evaluate_simple_expr(&self.query_expr, &self.params)?;
185
186 match value {
187 Value::Vector(vec) => Ok(vec),
188 Value::List(arr) => {
189 let mut vec = Vec::with_capacity(arr.len());
190 for v in arr {
191 if let Some(f) = v.as_f64() {
192 vec.push(f as f32);
193 } else {
194 return Err(datafusion::error::DataFusionError::Execution(
195 "Query vector must contain numbers".to_string(),
196 ));
197 }
198 }
199 Ok(vec)
200 }
201 _ => Err(datafusion::error::DataFusionError::Execution(
202 "Query vector must be a list or vector".to_string(),
203 )),
204 }
205 }
206}
207
208impl DisplayAs for GraphVectorKnnExec {
209 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 write!(
211 f,
212 "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
213 self.label_name, self.property, self.k, self.variable
214 )
215 }
216}
217
218impl ExecutionPlan for GraphVectorKnnExec {
219 fn name(&self) -> &str {
220 "GraphVectorKnnExec"
221 }
222
223 fn as_any(&self) -> &dyn Any {
224 self
225 }
226
227 fn schema(&self) -> SchemaRef {
228 self.schema.clone()
229 }
230
231 fn properties(&self) -> &PlanProperties {
232 &self.properties
233 }
234
235 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
236 vec![]
237 }
238
239 fn with_new_children(
240 self: Arc<Self>,
241 children: Vec<Arc<dyn ExecutionPlan>>,
242 ) -> DFResult<Arc<dyn ExecutionPlan>> {
243 if !children.is_empty() {
244 return Err(datafusion::error::DataFusionError::Internal(
245 "GraphVectorKnnExec has no children".to_string(),
246 ));
247 }
248 Ok(self)
249 }
250
251 fn execute(
252 &self,
253 partition: usize,
254 _context: Arc<TaskContext>,
255 ) -> DFResult<SendableRecordBatchStream> {
256 let metrics = BaselineMetrics::new(&self.metrics, partition);
257
258 let query_vector = self.evaluate_query_vector()?;
260
261 Ok(Box::pin(VectorKnnStream::new(
262 self.graph_ctx.clone(),
263 self.label_name.clone(),
264 self.variable.clone(),
265 self.property.clone(),
266 query_vector,
267 self.k,
268 self.threshold,
269 self.target_properties.clone(),
270 self.schema.clone(),
271 metrics,
272 )))
273 }
274
275 fn metrics(&self) -> Option<MetricsSet> {
276 Some(self.metrics.clone_inner())
277 }
278}
279
/// State machine driving `VectorKnnStream`: Init -> Executing -> Done.
enum VectorKnnState {
    /// Search has not started yet; the first poll builds the future.
    Init,
    /// The search future is in flight; it resolves to at most one batch.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    /// The result (or an error) has been emitted; the stream is exhausted.
    Done,
}
289
/// Single-shot record-batch stream: runs the vector search on first poll and
/// yields at most one batch before terminating.
struct VectorKnnStream {
    /// Execution context providing storage access and timeout checks.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Label whose vector index is searched.
    label_name: String,
    /// Cypher variable bound to the matched vertices (column-name prefix).
    variable: String,
    /// Name of the vector-valued property being searched.
    property: String,
    /// Pre-evaluated query vector (computed in `ExecutionPlan::execute`).
    query_vector: Vec<f32>,
    /// Maximum number of neighbours requested.
    k: usize,
    /// Optional minimum similarity; lower-scoring hits are dropped.
    threshold: Option<f32>,
    /// Extra vertex properties to materialize as output columns.
    target_properties: Vec<String>,
    /// Output Arrow schema shared with the plan node.
    schema: SchemaRef,
    /// Current position in the Init -> Executing -> Done state machine.
    state: VectorKnnState,
    /// Baseline metrics; output row count is recorded on completion.
    metrics: BaselineMetrics,
}
325
326impl VectorKnnStream {
327 #[expect(clippy::too_many_arguments)]
328 fn new(
329 graph_ctx: Arc<GraphExecutionContext>,
330 label_name: String,
331 variable: String,
332 property: String,
333 query_vector: Vec<f32>,
334 k: usize,
335 threshold: Option<f32>,
336 target_properties: Vec<String>,
337 schema: SchemaRef,
338 metrics: BaselineMetrics,
339 ) -> Self {
340 Self {
341 graph_ctx,
342 label_name,
343 variable,
344 property,
345 query_vector,
346 k,
347 threshold,
348 target_properties,
349 schema,
350 state: VectorKnnState::Init,
351 metrics,
352 }
353 }
354}
355
impl Stream for VectorKnnStream {
    type Item = DFResult<RecordBatch>;

    /// Drives the Init -> Executing -> Done state machine. Yields at most one
    /// `RecordBatch`, then `None` forever.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Move the state out (leaving `Done` as a placeholder) so the
            // boxed future can be polled by value; every arm below either
            // restores the correct next state or returns with `Done` in place.
            let state = std::mem::replace(&mut self.state, VectorKnnState::Done);

            match state {
                VectorKnnState::Init => {
                    // Clone everything the future captures: the async block is
                    // boxed as `dyn Future + Send` and cannot borrow `self`.
                    let graph_ctx = self.graph_ctx.clone();
                    let label_name = self.label_name.clone();
                    let variable = self.variable.clone();
                    let property = self.property.clone();
                    let query_vector = self.query_vector.clone();
                    let k = self.k;
                    let threshold = self.threshold;
                    let target_properties = self.target_properties.clone();
                    let schema = self.schema.clone();

                    let fut = async move {
                        // Surface a context timeout as an execution error
                        // before starting the (potentially expensive) search.
                        graph_ctx.check_timeout().map_err(|e| {
                            datafusion::error::DataFusionError::Execution(e.to_string())
                        })?;

                        execute_vector_search(
                            &graph_ctx,
                            &label_name,
                            &variable,
                            &property,
                            &query_vector,
                            k,
                            threshold,
                            &target_properties,
                            &schema,
                        )
                        .await
                    };

                    // Store the future and loop around to poll it immediately.
                    self.state = VectorKnnState::Executing(Box::pin(fut));
                }
                VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batch)) => {
                        self.state = VectorKnnState::Done;
                        // Record output rows (0 when the search returned None).
                        self.metrics
                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
                        return Poll::Ready(batch.map(Ok));
                    }
                    Poll::Ready(Err(e)) => {
                        self.state = VectorKnnState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => {
                        // Put the future back so it is polled again later.
                        self.state = VectorKnnState::Executing(fut);
                        return Poll::Pending;
                    }
                },
                VectorKnnState::Done => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
422
423impl RecordBatchStream for VectorKnnStream {
424 fn schema(&self) -> SchemaRef {
425 self.schema.clone()
426 }
427}
428
429#[expect(clippy::too_many_arguments)]
431async fn execute_vector_search(
432 graph_ctx: &GraphExecutionContext,
433 label_name: &str,
434 variable: &str,
435 property: &str,
436 query_vector: &[f32],
437 k: usize,
438 threshold: Option<f32>,
439 target_properties: &[String],
440 schema: &SchemaRef,
441) -> DFResult<Option<RecordBatch>> {
442 let storage = graph_ctx.storage();
443 let query_ctx = graph_ctx.query_context();
444
445 let results = storage
447 .vector_search(
448 label_name,
449 property,
450 query_vector,
451 k,
452 None,
453 Some(&query_ctx),
454 )
455 .await
456 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
457
458 let mut vids = Vec::new();
460 let mut scores = Vec::new();
461
462 for (vid, distance) in results {
463 let similarity = 1.0 - distance;
465
466 if let Some(thresh) = threshold
467 && similarity < thresh
468 {
469 continue;
470 }
471
472 vids.push(vid);
473 scores.push(similarity);
474 }
475
476 if vids.is_empty() {
477 return Ok(Some(RecordBatch::new_empty(schema.clone())));
478 }
479
480 let batch = build_result_batch(
482 &vids,
483 &scores,
484 variable,
485 target_properties,
486 label_name,
487 graph_ctx,
488 schema,
489 )
490 .await?;
491 Ok(Some(batch))
492}
493
494async fn build_result_batch(
496 vids: &[Vid],
497 scores: &[f32],
498 _variable: &str,
499 target_properties: &[String],
500 label_name: &str,
501 graph_ctx: &GraphExecutionContext,
502 schema: &SchemaRef,
503) -> DFResult<RecordBatch> {
504 let num_rows = vids.len();
505
506 let mut vid_builder = UInt64Builder::with_capacity(num_rows);
508 for vid in vids {
509 vid_builder.append_value(vid.as_u64());
510 }
511
512 let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
514 for vid in vids {
515 var_builder.append_value(vid.to_string());
516 }
517
518 let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
520 for _vid in vids {
521 labels_builder.values().append_value(label_name);
522 labels_builder.append(true);
523 }
524
525 let mut score_builder = Float32Builder::with_capacity(num_rows);
527 for &score in scores {
528 score_builder.append_value(score);
529 }
530
531 let mut columns: Vec<ArrayRef> = vec![
532 Arc::new(vid_builder.finish()),
533 Arc::new(var_builder.finish()),
534 Arc::new(labels_builder.finish()),
535 Arc::new(score_builder.finish()),
536 ];
537
538 if !target_properties.is_empty() {
540 let property_manager = graph_ctx.property_manager();
541 let query_ctx = graph_ctx.query_context();
542
543 let props_map = property_manager
544 .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
545 .await
546 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
547
548 let uni_schema = graph_ctx.storage().schema_manager().schema();
549 let label_props = uni_schema.properties.get(label_name);
550
551 for prop_name in target_properties {
552 let data_type = resolve_property_type(prop_name, label_props);
553 let column = crate::query::df_graph::scan::build_property_column_static(
554 vids, &props_map, prop_name, &data_type,
555 )?;
556 columns.push(column);
557 }
558 }
559
560 RecordBatch::try_new(schema.clone(), columns)
561 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// The base schema (no target properties) has exactly the four fixed
    /// columns, named after the variable.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        assert_eq!(schema.fields().len(), 4);
        assert_eq!(schema.field(0).name(), "n._vid");
        assert_eq!(schema.field(1).name(), "n");
        assert_eq!(schema.field(2).name(), "n._labels");
        assert_eq!(schema.field(3).name(), "n._score");
    }

    /// A literal float list evaluates to a `Value::List` of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(vec![
            Expr::Literal(CypherLiteral::Float(0.1)),
            Expr::Literal(CypherLiteral::Float(0.2)),
            Expr::Literal(CypherLiteral::Float(0.3)),
        ]);

        let result = evaluate_simple_expr(&expr, &HashMap::new()).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 3);
            }
            _ => panic!("Expected list"),
        }
    }

    /// A parameter reference resolves against the params map.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let mut params = HashMap::new();
        params.insert(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        );

        // Fix: the source contained the mojibake `¶ms` (HTML entity for
        // `&para` fused with `ms`) where `&params` was intended; the garbled
        // form does not compile.
        let result = evaluate_simple_expr(&expr, &params).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected list"),
        }
    }
}