1use crate::query::df_graph::GraphExecutionContext;
19use crate::query::df_graph::common::{
20 column_as_vid_array, compute_plan_properties, edge_struct_fields, new_node_list_builder,
21};
22use arrow::compute::take;
23use arrow_array::builder::{ListBuilder, StructBuilder, UInt64Builder};
24use arrow_array::{Array, ArrayRef, RecordBatch, UInt32Array, UInt64Array};
25use arrow_schema::{DataType, Field, Schema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::{Stream, StreamExt};
31use fxhash::FxHashMap;
32use std::any::Any;
33use std::collections::{HashSet, VecDeque};
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::task::{Context, Poll};
38use uni_common::core::id::Vid;
39use uni_store::runtime::l0_visibility;
40use uni_store::storage::direction::Direction;
41
42pub struct GraphShortestPathExec {
64 input: Arc<dyn ExecutionPlan>,
66
67 source_column: String,
69
70 target_column: String,
72
73 edge_type_ids: Vec<u32>,
75
76 direction: Direction,
78
79 path_variable: String,
81
82 all_shortest: bool,
84
85 graph_ctx: Arc<GraphExecutionContext>,
87
88 schema: SchemaRef,
90
91 properties: PlanProperties,
93
94 metrics: ExecutionPlanMetricsSet,
96}
97
98impl fmt::Debug for GraphShortestPathExec {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("GraphShortestPathExec")
101 .field("source_column", &self.source_column)
102 .field("target_column", &self.target_column)
103 .field("edge_type_ids", &self.edge_type_ids)
104 .field("direction", &self.direction)
105 .field("path_variable", &self.path_variable)
106 .field("all_shortest", &self.all_shortest)
107 .finish()
108 }
109}
110
111impl GraphShortestPathExec {
112 #[expect(
124 clippy::too_many_arguments,
125 reason = "Shortest path requires many parameters"
126 )]
127 pub fn new(
128 input: Arc<dyn ExecutionPlan>,
129 source_column: impl Into<String>,
130 target_column: impl Into<String>,
131 edge_type_ids: Vec<u32>,
132 direction: Direction,
133 path_variable: impl Into<String>,
134 graph_ctx: Arc<GraphExecutionContext>,
135 all_shortest: bool,
136 ) -> Self {
137 let source_column = source_column.into();
138 let target_column = target_column.into();
139 let path_variable = path_variable.into();
140
141 let schema = Self::build_schema(input.schema(), &path_variable);
142 let properties = compute_plan_properties(schema.clone());
143
144 Self {
145 input,
146 source_column,
147 target_column,
148 edge_type_ids,
149 direction,
150 path_variable,
151 all_shortest,
152 graph_ctx,
153 schema,
154 properties,
155 metrics: ExecutionPlanMetricsSet::new(),
156 }
157 }
158
159 fn build_schema(input_schema: SchemaRef, path_variable: &str) -> SchemaRef {
161 let mut fields: Vec<Field> = input_schema
162 .fields()
163 .iter()
164 .map(|f| f.as_ref().clone())
165 .collect();
166
167 fields.push(crate::query::df_graph::common::build_path_struct_field(
169 path_variable,
170 ));
171
172 let path_col_name = format!("{}._path", path_variable);
174 fields.push(Field::new(
175 &path_col_name,
176 DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
177 true, ));
179
180 let len_col_name = format!("{}._length", path_variable);
182 fields.push(Field::new(&len_col_name, DataType::UInt64, true));
183
184 Arc::new(Schema::new(fields))
185 }
186}
187
188impl DisplayAs for GraphShortestPathExec {
189 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 let mode = if self.all_shortest { "all" } else { "any" };
191 write!(
192 f,
193 "GraphShortestPathExec: {} -> {} via {:?} ({})",
194 self.source_column, self.target_column, self.edge_type_ids, mode
195 )
196 }
197}
198
199impl ExecutionPlan for GraphShortestPathExec {
200 fn name(&self) -> &str {
201 "GraphShortestPathExec"
202 }
203
204 fn as_any(&self) -> &dyn Any {
205 self
206 }
207
208 fn schema(&self) -> SchemaRef {
209 Arc::clone(&self.schema)
210 }
211
212 fn properties(&self) -> &PlanProperties {
213 &self.properties
214 }
215
216 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
217 vec![&self.input]
218 }
219
220 fn with_new_children(
221 self: Arc<Self>,
222 children: Vec<Arc<dyn ExecutionPlan>>,
223 ) -> DFResult<Arc<dyn ExecutionPlan>> {
224 if children.len() != 1 {
225 return Err(datafusion::error::DataFusionError::Plan(
226 "GraphShortestPathExec requires exactly one child".to_string(),
227 ));
228 }
229
230 Ok(Arc::new(Self::new(
231 Arc::clone(&children[0]),
232 self.source_column.clone(),
233 self.target_column.clone(),
234 self.edge_type_ids.clone(),
235 self.direction,
236 self.path_variable.clone(),
237 Arc::clone(&self.graph_ctx),
238 self.all_shortest,
239 )))
240 }
241
242 fn execute(
243 &self,
244 partition: usize,
245 context: Arc<TaskContext>,
246 ) -> DFResult<SendableRecordBatchStream> {
247 let input_stream = self.input.execute(partition, context)?;
248
249 let metrics = BaselineMetrics::new(&self.metrics, partition);
250
251 let warm_fut = self
252 .graph_ctx
253 .warming_future(self.edge_type_ids.clone(), self.direction);
254
255 Ok(Box::pin(GraphShortestPathStream {
256 input: input_stream,
257 source_column: self.source_column.clone(),
258 target_column: self.target_column.clone(),
259 edge_type_ids: self.edge_type_ids.clone(),
260 direction: self.direction,
261 all_shortest: self.all_shortest,
262 graph_ctx: Arc::clone(&self.graph_ctx),
263 schema: Arc::clone(&self.schema),
264 state: ShortestPathStreamState::Warming(warm_fut),
265 metrics,
266 }))
267 }
268
269 fn metrics(&self) -> Option<MetricsSet> {
270 Some(self.metrics.clone_inner())
271 }
272}
273
274enum ShortestPathStreamState {
276 Warming(Pin<Box<dyn std::future::Future<Output = DFResult<()>> + Send>>),
278 Reading,
280 Done,
282}
283
284struct GraphShortestPathStream {
286 input: SendableRecordBatchStream,
288
289 source_column: String,
291
292 target_column: String,
294
295 edge_type_ids: Vec<u32>,
297
298 direction: Direction,
300
301 all_shortest: bool,
303
304 graph_ctx: Arc<GraphExecutionContext>,
306
307 schema: SchemaRef,
309
310 state: ShortestPathStreamState,
312
313 metrics: BaselineMetrics,
315}
316
317impl GraphShortestPathStream {
318 fn compute_shortest_path(&self, source: Vid, target: Vid) -> Option<Vec<Vid>> {
320 if source == target {
321 return Some(vec![source]);
322 }
323
324 let mut visited: HashSet<Vid> = HashSet::new();
325 let mut queue: VecDeque<(Vid, Vec<Vid>)> = VecDeque::new();
326
327 visited.insert(source);
328 queue.push_back((source, vec![source]));
329
330 while let Some((current, path)) = queue.pop_front() {
331 for &edge_type in &self.edge_type_ids {
333 let neighbors = self
334 .graph_ctx
335 .get_neighbors(current, edge_type, self.direction);
336
337 for (neighbor, _eid) in neighbors {
338 if neighbor == target {
339 let mut result = path.clone();
341 result.push(target);
342 return Some(result);
343 }
344
345 if !visited.contains(&neighbor) {
346 visited.insert(neighbor);
347 let mut new_path = path.clone();
348 new_path.push(neighbor);
349 queue.push_back((neighbor, new_path));
350 }
351 }
352 }
353 }
354
355 None }
357
358 fn compute_all_shortest_paths(&self, source: Vid, target: Vid) -> Vec<Vec<Vid>> {
363 if source == target {
364 return vec![vec![source]];
365 }
366
367 let mut depth: FxHashMap<Vid, u32> = FxHashMap::default();
369 let mut predecessors: FxHashMap<Vid, Vec<Vid>> = FxHashMap::default();
370 depth.insert(source, 0);
371
372 let mut current_layer: Vec<Vid> = vec![source];
373 let mut current_depth = 0u32;
374 let mut target_found = false;
375
376 while !current_layer.is_empty() && !target_found {
377 current_depth += 1;
378 let mut next_layer_set: HashSet<Vid> = HashSet::new();
379
380 for ¤t in ¤t_layer {
381 for &edge_type in &self.edge_type_ids {
382 let neighbors =
383 self.graph_ctx
384 .get_neighbors(current, edge_type, self.direction);
385
386 for (neighbor, _eid) in neighbors {
387 if let Some(&d) = depth.get(&neighbor) {
388 if d == current_depth {
390 predecessors.entry(neighbor).or_default().push(current);
391 }
392 continue;
393 }
394
395 depth.insert(neighbor, current_depth);
397 predecessors.entry(neighbor).or_default().push(current);
398
399 if neighbor == target {
400 target_found = true;
401 } else {
402 next_layer_set.insert(neighbor);
403 }
404 }
405 }
406 }
407
408 current_layer = next_layer_set.into_iter().collect();
409 }
410
411 if !target_found {
412 return vec![];
413 }
414
415 let mut result: Vec<Vec<Vid>> = Vec::new();
417 let mut stack: Vec<(Vid, Vec<Vid>)> = vec![(target, vec![target])];
418
419 while let Some((node, path)) = stack.pop() {
420 if node == source {
421 let mut full_path = path;
422 full_path.reverse();
423 result.push(full_path);
424 continue;
425 }
426 if let Some(preds) = predecessors.get(&node) {
427 for &pred in preds {
428 let mut new_path = path.clone();
429 new_path.push(pred);
430 stack.push((pred, new_path));
431 }
432 }
433 }
434
435 result
436 }
437
438 fn process_batch(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
440 let source_col = batch.column_by_name(&self.source_column).ok_or_else(|| {
442 datafusion::error::DataFusionError::Execution(format!(
443 "Source column '{}' not found",
444 self.source_column
445 ))
446 })?;
447
448 let target_col = batch.column_by_name(&self.target_column).ok_or_else(|| {
449 datafusion::error::DataFusionError::Execution(format!(
450 "Target column '{}' not found",
451 self.target_column
452 ))
453 })?;
454
455 let source_vid_cow = column_as_vid_array(source_col.as_ref())?;
456 let source_vids: &UInt64Array = &source_vid_cow;
457
458 let target_vid_cow = column_as_vid_array(target_col.as_ref())?;
459 let target_vids: &UInt64Array = &target_vid_cow;
460
461 if self.all_shortest {
462 let mut row_indices: Vec<u32> = Vec::new();
464 let mut all_paths: Vec<Option<Vec<Vid>>> = Vec::new();
465
466 for i in 0..batch.num_rows() {
467 if source_vids.is_null(i) || target_vids.is_null(i) {
468 row_indices.push(i as u32);
469 all_paths.push(None);
470 } else {
471 let source = Vid::from(source_vids.value(i));
472 let target = Vid::from(target_vids.value(i));
473 let paths = self.compute_all_shortest_paths(source, target);
474 if paths.is_empty() {
475 row_indices.push(i as u32);
476 all_paths.push(None);
477 } else {
478 for path in paths {
479 row_indices.push(i as u32);
480 all_paths.push(Some(path));
481 }
482 }
483 }
484 }
485
486 let indices = UInt32Array::from(row_indices);
488 let expanded_columns: Vec<ArrayRef> = batch
489 .columns()
490 .iter()
491 .map(|col| {
492 take(col.as_ref(), &indices, None).map_err(|e| {
493 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
494 })
495 })
496 .collect::<DFResult<Vec<_>>>()?;
497 let expanded_batch = RecordBatch::try_new(batch.schema(), expanded_columns)
498 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
499
500 self.build_output_batch(&expanded_batch, &all_paths)
501 } else {
502 let mut paths: Vec<Option<Vec<Vid>>> = Vec::with_capacity(batch.num_rows());
504
505 for i in 0..batch.num_rows() {
506 let path = if source_vids.is_null(i) || target_vids.is_null(i) {
507 None
508 } else {
509 let source = Vid::from(source_vids.value(i));
510 let target = Vid::from(target_vids.value(i));
511 self.compute_shortest_path(source, target)
512 };
513 paths.push(path);
514 }
515
516 self.build_output_batch(&batch, &paths)
517 }
518 }
519
520 fn build_output_batch(
522 &self,
523 input: &RecordBatch,
524 paths: &[Option<Vec<Vid>>],
525 ) -> DFResult<RecordBatch> {
526 let num_rows = paths.len();
527 let query_ctx = self.graph_ctx.query_context();
528
529 let mut columns: Vec<ArrayRef> = input.columns().to_vec();
531
532 let mut nodes_builder = new_node_list_builder();
534 let mut rels_builder =
535 ListBuilder::new(StructBuilder::from_fields(edge_struct_fields(), num_rows));
536 let mut path_validity = Vec::with_capacity(num_rows);
537
538 for path in paths {
539 match path {
540 Some(vids) => {
541 for &vid in vids {
543 super::common::append_node_to_struct(
544 nodes_builder.values(),
545 vid,
546 &query_ctx,
547 );
548 }
549 nodes_builder.append(true);
550
551 for window in vids.windows(2) {
554 let src = window[0];
555 let dst = window[1];
556 let (eid, type_name) = self.find_edge(src, dst);
557 super::common::append_edge_to_struct(
558 rels_builder.values(),
559 eid,
560 &type_name,
561 src.as_u64(),
562 dst.as_u64(),
563 &query_ctx,
564 );
565 }
566 rels_builder.append(true);
567 path_validity.push(true);
568 }
569 None => {
570 nodes_builder.append(false);
572 rels_builder.append(false);
573 path_validity.push(false);
574 }
575 }
576 }
577
578 let nodes_array = Arc::new(nodes_builder.finish()) as ArrayRef;
579 let rels_array = Arc::new(rels_builder.finish()) as ArrayRef;
580
581 let path_struct =
582 super::common::build_path_struct_array(nodes_array, rels_array, path_validity)?;
583 columns.push(Arc::new(path_struct));
584
585 let mut list_builder = ListBuilder::new(UInt64Builder::new());
587 for path in paths {
588 match path {
589 Some(p) => {
590 let values: Vec<u64> = p.iter().map(|v| v.as_u64()).collect();
591 list_builder.values().append_slice(&values);
592 list_builder.append(true);
593 }
594 None => {
595 list_builder.append(false); }
597 }
598 }
599 columns.push(Arc::new(list_builder.finish()));
600
601 let lengths: Vec<Option<u64>> = paths
603 .iter()
604 .map(|p| p.as_ref().map(|path| (path.len() - 1) as u64))
605 .collect();
606 columns.push(Arc::new(UInt64Array::from(lengths)));
607
608 self.metrics.record_output(num_rows);
609
610 RecordBatch::try_new(Arc::clone(&self.schema), columns)
611 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
612 }
613
614 fn find_edge(&self, src: Vid, dst: Vid) -> (uni_common::core::id::Eid, String) {
617 let query_ctx = self.graph_ctx.query_context();
618 for &edge_type in &self.edge_type_ids {
619 let neighbors = self.graph_ctx.get_neighbors(src, edge_type, self.direction);
620 for (neighbor, eid) in neighbors {
621 if neighbor == dst {
622 let type_name =
623 l0_visibility::get_edge_type(eid, &query_ctx).unwrap_or_default();
624 return (eid, type_name);
625 }
626 }
627 }
628 (uni_common::core::id::Eid::from(0u64), String::new())
629 }
630}
631
632impl Stream for GraphShortestPathStream {
633 type Item = DFResult<RecordBatch>;
634
635 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
636 loop {
637 let state = std::mem::replace(&mut self.state, ShortestPathStreamState::Done);
638
639 match state {
640 ShortestPathStreamState::Warming(mut fut) => match fut.as_mut().poll(cx) {
641 Poll::Ready(Ok(())) => {
642 self.state = ShortestPathStreamState::Reading;
643 }
645 Poll::Ready(Err(e)) => {
646 self.state = ShortestPathStreamState::Done;
647 return Poll::Ready(Some(Err(e)));
648 }
649 Poll::Pending => {
650 self.state = ShortestPathStreamState::Warming(fut);
651 return Poll::Pending;
652 }
653 },
654 ShortestPathStreamState::Reading => {
655 if let Err(e) = self.graph_ctx.check_timeout() {
657 return Poll::Ready(Some(Err(
658 datafusion::error::DataFusionError::Execution(e.to_string()),
659 )));
660 }
661
662 match self.input.poll_next_unpin(cx) {
663 Poll::Ready(Some(Ok(batch))) => {
664 let result = self.process_batch(batch);
665 self.state = ShortestPathStreamState::Reading;
666 return Poll::Ready(Some(result));
667 }
668 Poll::Ready(Some(Err(e))) => {
669 self.state = ShortestPathStreamState::Done;
670 return Poll::Ready(Some(Err(e)));
671 }
672 Poll::Ready(None) => {
673 self.state = ShortestPathStreamState::Done;
674 return Poll::Ready(None);
675 }
676 Poll::Pending => {
677 self.state = ShortestPathStreamState::Reading;
678 return Poll::Pending;
679 }
680 }
681 }
682 ShortestPathStreamState::Done => {
683 return Poll::Ready(None);
684 }
685 }
686 }
687 }
688}
689
690impl RecordBatchStream for GraphShortestPathStream {
691 fn schema(&self) -> SchemaRef {
692 Arc::clone(&self.schema)
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// The output schema keeps the input columns first, then appends the path struct,
    /// `_path` and `_length` columns named after the path variable.
    #[test]
    fn test_shortest_path_schema() {
        let input_schema = Arc::new(Schema::new(vec![
            Field::new("_source_vid", DataType::UInt64, false),
            Field::new("_target_vid", DataType::UInt64, false),
        ]));

        let output_schema = GraphShortestPathExec::build_schema(input_schema, "p");

        let names: Vec<&str> = output_schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(
            names,
            ["_source_vid", "_target_vid", "p", "p._path", "p._length"]
        );
    }
}