1use crate::query::df_graph::GraphExecutionContext;
19use crate::query::df_graph::common::{
20 arrow_err, column_as_vid_array, compute_plan_properties, edge_struct_fields,
21 new_node_list_builder,
22};
23use arrow::compute::take;
24use arrow_array::builder::{ListBuilder, StructBuilder, UInt64Builder};
25use arrow_array::{Array, ArrayRef, RecordBatch, UInt32Array, UInt64Array};
26use arrow_schema::{DataType, Field, Schema, SchemaRef};
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
30use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
31use futures::{Stream, StreamExt};
32use fxhash::FxHashMap;
33use std::any::Any;
34use std::collections::{HashSet, VecDeque};
35use std::fmt;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::task::{Context, Poll};
39use uni_common::core::id::Vid;
40use uni_store::runtime::l0_visibility;
41use uni_store::storage::direction::Direction;
42
43pub struct GraphShortestPathExec {
65 input: Arc<dyn ExecutionPlan>,
67
68 source_column: String,
70
71 target_column: String,
73
74 edge_type_ids: Vec<u32>,
76
77 direction: Direction,
79
80 path_variable: String,
82
83 all_shortest: bool,
85
86 graph_ctx: Arc<GraphExecutionContext>,
88
89 schema: SchemaRef,
91
92 properties: PlanProperties,
94
95 metrics: ExecutionPlanMetricsSet,
97}
98
99impl fmt::Debug for GraphShortestPathExec {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 f.debug_struct("GraphShortestPathExec")
102 .field("source_column", &self.source_column)
103 .field("target_column", &self.target_column)
104 .field("edge_type_ids", &self.edge_type_ids)
105 .field("direction", &self.direction)
106 .field("path_variable", &self.path_variable)
107 .field("all_shortest", &self.all_shortest)
108 .finish()
109 }
110}
111
112impl GraphShortestPathExec {
113 #[expect(
125 clippy::too_many_arguments,
126 reason = "Shortest path requires many parameters"
127 )]
128 pub fn new(
129 input: Arc<dyn ExecutionPlan>,
130 source_column: impl Into<String>,
131 target_column: impl Into<String>,
132 edge_type_ids: Vec<u32>,
133 direction: Direction,
134 path_variable: impl Into<String>,
135 graph_ctx: Arc<GraphExecutionContext>,
136 all_shortest: bool,
137 ) -> Self {
138 let source_column = source_column.into();
139 let target_column = target_column.into();
140 let path_variable = path_variable.into();
141
142 let schema = Self::build_schema(input.schema(), &path_variable);
143 let properties = compute_plan_properties(schema.clone());
144
145 Self {
146 input,
147 source_column,
148 target_column,
149 edge_type_ids,
150 direction,
151 path_variable,
152 all_shortest,
153 graph_ctx,
154 schema,
155 properties,
156 metrics: ExecutionPlanMetricsSet::new(),
157 }
158 }
159
160 fn build_schema(input_schema: SchemaRef, path_variable: &str) -> SchemaRef {
162 let mut fields: Vec<Field> = input_schema
163 .fields()
164 .iter()
165 .map(|f| f.as_ref().clone())
166 .collect();
167
168 fields.push(crate::query::df_graph::common::build_path_struct_field(
170 path_variable,
171 ));
172
173 let path_col_name = format!("{}._path", path_variable);
175 fields.push(Field::new(
176 &path_col_name,
177 DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
178 true, ));
180
181 let len_col_name = format!("{}._length", path_variable);
183 fields.push(Field::new(&len_col_name, DataType::UInt64, true));
184
185 Arc::new(Schema::new(fields))
186 }
187}
188
189impl DisplayAs for GraphShortestPathExec {
190 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 let mode = if self.all_shortest { "all" } else { "any" };
192 write!(
193 f,
194 "GraphShortestPathExec: {} -> {} via {:?} ({})",
195 self.source_column, self.target_column, self.edge_type_ids, mode
196 )
197 }
198}
199
200impl ExecutionPlan for GraphShortestPathExec {
201 fn name(&self) -> &str {
202 "GraphShortestPathExec"
203 }
204
205 fn as_any(&self) -> &dyn Any {
206 self
207 }
208
209 fn schema(&self) -> SchemaRef {
210 Arc::clone(&self.schema)
211 }
212
213 fn properties(&self) -> &PlanProperties {
214 &self.properties
215 }
216
217 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
218 vec![&self.input]
219 }
220
221 fn with_new_children(
222 self: Arc<Self>,
223 children: Vec<Arc<dyn ExecutionPlan>>,
224 ) -> DFResult<Arc<dyn ExecutionPlan>> {
225 if children.len() != 1 {
226 return Err(datafusion::error::DataFusionError::Plan(
227 "GraphShortestPathExec requires exactly one child".to_string(),
228 ));
229 }
230
231 Ok(Arc::new(Self::new(
232 Arc::clone(&children[0]),
233 self.source_column.clone(),
234 self.target_column.clone(),
235 self.edge_type_ids.clone(),
236 self.direction,
237 self.path_variable.clone(),
238 Arc::clone(&self.graph_ctx),
239 self.all_shortest,
240 )))
241 }
242
243 fn execute(
244 &self,
245 partition: usize,
246 context: Arc<TaskContext>,
247 ) -> DFResult<SendableRecordBatchStream> {
248 let input_stream = self.input.execute(partition, context)?;
249
250 let metrics = BaselineMetrics::new(&self.metrics, partition);
251
252 let warm_fut = self
253 .graph_ctx
254 .warming_future(self.edge_type_ids.clone(), self.direction);
255
256 Ok(Box::pin(GraphShortestPathStream {
257 input: input_stream,
258 source_column: self.source_column.clone(),
259 target_column: self.target_column.clone(),
260 edge_type_ids: self.edge_type_ids.clone(),
261 direction: self.direction,
262 all_shortest: self.all_shortest,
263 graph_ctx: Arc::clone(&self.graph_ctx),
264 schema: Arc::clone(&self.schema),
265 state: ShortestPathStreamState::Warming(warm_fut),
266 metrics,
267 }))
268 }
269
270 fn metrics(&self) -> Option<MetricsSet> {
271 Some(self.metrics.clone_inner())
272 }
273}
274
275enum ShortestPathStreamState {
277 Warming(Pin<Box<dyn std::future::Future<Output = DFResult<()>> + Send>>),
279 Reading,
281 Done,
283}
284
285struct GraphShortestPathStream {
287 input: SendableRecordBatchStream,
289
290 source_column: String,
292
293 target_column: String,
295
296 edge_type_ids: Vec<u32>,
298
299 direction: Direction,
301
302 all_shortest: bool,
304
305 graph_ctx: Arc<GraphExecutionContext>,
307
308 schema: SchemaRef,
310
311 state: ShortestPathStreamState,
313
314 metrics: BaselineMetrics,
316}
317
318impl GraphShortestPathStream {
319 fn compute_shortest_path(&self, source: Vid, target: Vid) -> Option<Vec<Vid>> {
321 if source == target {
322 return Some(vec![source]);
323 }
324
325 let mut visited: HashSet<Vid> = HashSet::new();
326 let mut queue: VecDeque<(Vid, Vec<Vid>)> = VecDeque::new();
327
328 visited.insert(source);
329 queue.push_back((source, vec![source]));
330
331 while let Some((current, path)) = queue.pop_front() {
332 for &edge_type in &self.edge_type_ids {
334 let neighbors = self
335 .graph_ctx
336 .get_neighbors(current, edge_type, self.direction);
337
338 for (neighbor, _eid) in neighbors {
339 if neighbor == target {
340 let mut result = path.clone();
342 result.push(target);
343 return Some(result);
344 }
345
346 if !visited.contains(&neighbor) {
347 visited.insert(neighbor);
348 let mut new_path = path.clone();
349 new_path.push(neighbor);
350 queue.push_back((neighbor, new_path));
351 }
352 }
353 }
354 }
355
356 None }
358
359 fn compute_all_shortest_paths(&self, source: Vid, target: Vid) -> Vec<Vec<Vid>> {
364 if source == target {
365 return vec![vec![source]];
366 }
367
368 let mut depth: FxHashMap<Vid, u32> = FxHashMap::default();
370 let mut predecessors: FxHashMap<Vid, Vec<Vid>> = FxHashMap::default();
371 depth.insert(source, 0);
372
373 let mut current_layer: Vec<Vid> = vec![source];
374 let mut current_depth = 0u32;
375 let mut target_found = false;
376
377 while !current_layer.is_empty() && !target_found {
378 current_depth += 1;
379 let mut next_layer_set: HashSet<Vid> = HashSet::new();
380
381 for ¤t in ¤t_layer {
382 for &edge_type in &self.edge_type_ids {
383 let neighbors =
384 self.graph_ctx
385 .get_neighbors(current, edge_type, self.direction);
386
387 for (neighbor, _eid) in neighbors {
388 if let Some(&d) = depth.get(&neighbor) {
389 if d == current_depth {
391 predecessors.entry(neighbor).or_default().push(current);
392 }
393 continue;
394 }
395
396 depth.insert(neighbor, current_depth);
398 predecessors.entry(neighbor).or_default().push(current);
399
400 if neighbor == target {
401 target_found = true;
402 } else {
403 next_layer_set.insert(neighbor);
404 }
405 }
406 }
407 }
408
409 current_layer = next_layer_set.into_iter().collect();
410 }
411
412 if !target_found {
413 return vec![];
414 }
415
416 let mut result: Vec<Vec<Vid>> = Vec::new();
418 let mut stack: Vec<(Vid, Vec<Vid>)> = vec![(target, vec![target])];
419
420 while let Some((node, path)) = stack.pop() {
421 if node == source {
422 let mut full_path = path;
423 full_path.reverse();
424 result.push(full_path);
425 continue;
426 }
427 if let Some(preds) = predecessors.get(&node) {
428 for &pred in preds {
429 let mut new_path = path.clone();
430 new_path.push(pred);
431 stack.push((pred, new_path));
432 }
433 }
434 }
435
436 result
437 }
438
439 fn process_batch(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
441 let source_col = batch.column_by_name(&self.source_column).ok_or_else(|| {
443 datafusion::error::DataFusionError::Execution(format!(
444 "Source column '{}' not found",
445 self.source_column
446 ))
447 })?;
448
449 let target_col = batch.column_by_name(&self.target_column).ok_or_else(|| {
450 datafusion::error::DataFusionError::Execution(format!(
451 "Target column '{}' not found",
452 self.target_column
453 ))
454 })?;
455
456 let source_vid_cow = column_as_vid_array(source_col.as_ref())?;
457 let source_vids: &UInt64Array = &source_vid_cow;
458
459 let target_vid_cow = column_as_vid_array(target_col.as_ref())?;
460 let target_vids: &UInt64Array = &target_vid_cow;
461
462 if self.all_shortest {
463 let mut row_indices: Vec<u32> = Vec::new();
465 let mut all_paths: Vec<Option<Vec<Vid>>> = Vec::new();
466
467 for i in 0..batch.num_rows() {
468 if source_vids.is_null(i) || target_vids.is_null(i) {
469 row_indices.push(i as u32);
470 all_paths.push(None);
471 } else {
472 let source = Vid::from(source_vids.value(i));
473 let target = Vid::from(target_vids.value(i));
474 let paths = self.compute_all_shortest_paths(source, target);
475 if paths.is_empty() {
476 row_indices.push(i as u32);
477 all_paths.push(None);
478 } else {
479 for path in paths {
480 row_indices.push(i as u32);
481 all_paths.push(Some(path));
482 }
483 }
484 }
485 }
486
487 let indices = UInt32Array::from(row_indices);
489 let expanded_columns: Vec<ArrayRef> = batch
490 .columns()
491 .iter()
492 .map(|col| {
493 take(col.as_ref(), &indices, None).map_err(|e| {
494 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
495 })
496 })
497 .collect::<DFResult<Vec<_>>>()?;
498 let expanded_batch =
499 RecordBatch::try_new(batch.schema(), expanded_columns).map_err(arrow_err)?;
500
501 self.build_output_batch(&expanded_batch, &all_paths)
502 } else {
503 let mut paths: Vec<Option<Vec<Vid>>> = Vec::with_capacity(batch.num_rows());
505
506 for i in 0..batch.num_rows() {
507 let path = if source_vids.is_null(i) || target_vids.is_null(i) {
508 None
509 } else {
510 let source = Vid::from(source_vids.value(i));
511 let target = Vid::from(target_vids.value(i));
512 self.compute_shortest_path(source, target)
513 };
514 paths.push(path);
515 }
516
517 self.build_output_batch(&batch, &paths)
518 }
519 }
520
521 fn build_output_batch(
523 &self,
524 input: &RecordBatch,
525 paths: &[Option<Vec<Vid>>],
526 ) -> DFResult<RecordBatch> {
527 let num_rows = paths.len();
528 let query_ctx = self.graph_ctx.query_context();
529
530 let mut columns: Vec<ArrayRef> = input.columns().to_vec();
532
533 let mut nodes_builder = new_node_list_builder();
535 let mut rels_builder =
536 ListBuilder::new(StructBuilder::from_fields(edge_struct_fields(), num_rows));
537 let mut path_validity = Vec::with_capacity(num_rows);
538
539 for path in paths {
540 match path {
541 Some(vids) => {
542 for &vid in vids {
544 super::common::append_node_to_struct(
545 nodes_builder.values(),
546 vid,
547 &query_ctx,
548 );
549 }
550 nodes_builder.append(true);
551
552 for window in vids.windows(2) {
555 let src = window[0];
556 let dst = window[1];
557 let (eid, type_name) = self.find_edge(src, dst);
558 super::common::append_edge_to_struct(
559 rels_builder.values(),
560 eid,
561 &type_name,
562 src.as_u64(),
563 dst.as_u64(),
564 &query_ctx,
565 );
566 }
567 rels_builder.append(true);
568 path_validity.push(true);
569 }
570 None => {
571 nodes_builder.append(false);
573 rels_builder.append(false);
574 path_validity.push(false);
575 }
576 }
577 }
578
579 let nodes_array = Arc::new(nodes_builder.finish()) as ArrayRef;
580 let rels_array = Arc::new(rels_builder.finish()) as ArrayRef;
581
582 let path_struct =
583 super::common::build_path_struct_array(nodes_array, rels_array, path_validity)?;
584 columns.push(Arc::new(path_struct));
585
586 let mut list_builder = ListBuilder::new(UInt64Builder::new());
588 for path in paths {
589 match path {
590 Some(p) => {
591 let values: Vec<u64> = p.iter().map(|v| v.as_u64()).collect();
592 list_builder.values().append_slice(&values);
593 list_builder.append(true);
594 }
595 None => {
596 list_builder.append(false); }
598 }
599 }
600 columns.push(Arc::new(list_builder.finish()));
601
602 let lengths: Vec<Option<u64>> = paths
604 .iter()
605 .map(|p| p.as_ref().map(|path| (path.len() - 1) as u64))
606 .collect();
607 columns.push(Arc::new(UInt64Array::from(lengths)));
608
609 self.metrics.record_output(num_rows);
610
611 RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(arrow_err)
612 }
613
614 fn find_edge(&self, src: Vid, dst: Vid) -> (uni_common::core::id::Eid, String) {
617 let query_ctx = self.graph_ctx.query_context();
618 for &edge_type in &self.edge_type_ids {
619 let neighbors = self.graph_ctx.get_neighbors(src, edge_type, self.direction);
620 for (neighbor, eid) in neighbors {
621 if neighbor == dst {
622 let type_name =
623 l0_visibility::get_edge_type(eid, &query_ctx).unwrap_or_default();
624 return (eid, type_name);
625 }
626 }
627 }
628 (uni_common::core::id::Eid::from(0u64), String::new())
629 }
630}
631
632impl Stream for GraphShortestPathStream {
633 type Item = DFResult<RecordBatch>;
634
635 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
636 loop {
637 let state = std::mem::replace(&mut self.state, ShortestPathStreamState::Done);
638
639 match state {
640 ShortestPathStreamState::Warming(mut fut) => match fut.as_mut().poll(cx) {
641 Poll::Ready(Ok(())) => {
642 self.state = ShortestPathStreamState::Reading;
643 }
645 Poll::Ready(Err(e)) => {
646 self.state = ShortestPathStreamState::Done;
647 return Poll::Ready(Some(Err(e)));
648 }
649 Poll::Pending => {
650 self.state = ShortestPathStreamState::Warming(fut);
651 return Poll::Pending;
652 }
653 },
654 ShortestPathStreamState::Reading => {
655 if let Err(e) = self.graph_ctx.check_timeout() {
657 return Poll::Ready(Some(Err(
658 datafusion::error::DataFusionError::Execution(e.to_string()),
659 )));
660 }
661
662 match self.input.poll_next_unpin(cx) {
663 Poll::Ready(Some(Ok(batch))) => {
664 let result = self.process_batch(batch);
665 self.state = ShortestPathStreamState::Reading;
666 return Poll::Ready(Some(result));
667 }
668 Poll::Ready(Some(Err(e))) => {
669 self.state = ShortestPathStreamState::Done;
670 return Poll::Ready(Some(Err(e)));
671 }
672 Poll::Ready(None) => {
673 self.state = ShortestPathStreamState::Done;
674 return Poll::Ready(None);
675 }
676 Poll::Pending => {
677 self.state = ShortestPathStreamState::Reading;
678 return Poll::Pending;
679 }
680 }
681 }
682 ShortestPathStreamState::Done => {
683 return Poll::Ready(None);
684 }
685 }
686 }
687 }
688}
689
690impl RecordBatchStream for GraphShortestPathStream {
691 fn schema(&self) -> SchemaRef {
692 Arc::clone(&self.schema)
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// The output schema keeps the input fields in order and appends the path
    /// struct, the vertex-id list and the length column.
    #[test]
    fn test_shortest_path_schema() {
        let input_schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("_source_vid", DataType::UInt64, false),
            Field::new("_target_vid", DataType::UInt64, false),
        ]));

        let output_schema = GraphShortestPathExec::build_schema(input_schema, "p");

        let names: Vec<&str> = output_schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(
            names,
            ["_source_vid", "_target_vid", "p", "p._path", "p._length"]
        );
    }
}