Skip to main content

uni_query/query/df_graph/
shortest_path.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Shortest path execution plan for DataFusion.
5//!
6//! This module provides [`GraphShortestPathExec`], a DataFusion [`ExecutionPlan`] that
7//! computes shortest paths between source and target vertices using BFS.
8//!
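//! # Algorithm
//!
//! Both modes run a single-direction breadth-first search that starts at the source
//! vertex and follows the configured edge types and traversal direction:
//! 1. `shortestPath` returns the first path that reaches the target
//! 2. `allShortestPaths` expands layer by layer, records every predecessor at the
//!    minimal depth, and enumerates all minimum-length paths
//!
//! Bidirectional BFS is not implemented in this module.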
17
18use crate::query::df_graph::GraphExecutionContext;
19use crate::query::df_graph::common::{
20    column_as_vid_array, compute_plan_properties, edge_struct_fields, new_node_list_builder,
21};
22use arrow::compute::take;
23use arrow_array::builder::{ListBuilder, StructBuilder, UInt64Builder};
24use arrow_array::{Array, ArrayRef, RecordBatch, UInt32Array, UInt64Array};
25use arrow_schema::{DataType, Field, Schema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::{Stream, StreamExt};
31use fxhash::FxHashMap;
32use std::any::Any;
33use std::collections::{HashSet, VecDeque};
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::task::{Context, Poll};
38use uni_common::core::id::Vid;
39use uni_store::runtime::l0_visibility;
40use uni_store::storage::direction::Direction;
41
42/// Shortest path execution plan.
43///
44/// Computes shortest paths between source and target vertices using BFS.
45/// Returns the path as a list of VIDs.
46///
47/// # Example
48///
49/// ```ignore
50/// // Find shortest path from source to target via KNOWS edges
51/// let shortest_path = GraphShortestPathExec::new(
52///     input_plan,
53///     "_source_vid",
54///     "_target_vid",
55///     vec![knows_type_id],
56///     Direction::Both,
57///     "p",
58///     graph_ctx,
59/// );
60///
61/// // Output: input columns + p._path (List<UInt64>)
62/// ```
63pub struct GraphShortestPathExec {
64    /// Input execution plan.
65    input: Arc<dyn ExecutionPlan>,
66
67    /// Column name containing source VIDs.
68    source_column: String,
69
70    /// Column name containing target VIDs.
71    target_column: String,
72
73    /// Edge type IDs to traverse.
74    edge_type_ids: Vec<u32>,
75
76    /// Traversal direction.
77    direction: Direction,
78
79    /// Variable name for the path.
80    path_variable: String,
81
82    /// Whether this is allShortestPaths (true) or shortestPath (false).
83    all_shortest: bool,
84
85    /// Graph execution context.
86    graph_ctx: Arc<GraphExecutionContext>,
87
88    /// Output schema.
89    schema: SchemaRef,
90
91    /// Cached plan properties.
92    properties: PlanProperties,
93
94    /// Execution metrics.
95    metrics: ExecutionPlanMetricsSet,
96}
97
98impl fmt::Debug for GraphShortestPathExec {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        f.debug_struct("GraphShortestPathExec")
101            .field("source_column", &self.source_column)
102            .field("target_column", &self.target_column)
103            .field("edge_type_ids", &self.edge_type_ids)
104            .field("direction", &self.direction)
105            .field("path_variable", &self.path_variable)
106            .field("all_shortest", &self.all_shortest)
107            .finish()
108    }
109}
110
111impl GraphShortestPathExec {
112    /// Create a new shortest path execution plan.
113    ///
114    /// # Arguments
115    ///
116    /// * `input` - Input plan providing source and target vertices
117    /// * `source_column` - Column name containing source VIDs
118    /// * `target_column` - Column name containing target VIDs
119    /// * `edge_type_ids` - Edge types to traverse
120    /// * `direction` - Traversal direction
121    /// * `path_variable` - Variable name for the path
122    /// * `graph_ctx` - Graph execution context
123    #[expect(
124        clippy::too_many_arguments,
125        reason = "Shortest path requires many parameters"
126    )]
127    pub fn new(
128        input: Arc<dyn ExecutionPlan>,
129        source_column: impl Into<String>,
130        target_column: impl Into<String>,
131        edge_type_ids: Vec<u32>,
132        direction: Direction,
133        path_variable: impl Into<String>,
134        graph_ctx: Arc<GraphExecutionContext>,
135        all_shortest: bool,
136    ) -> Self {
137        let source_column = source_column.into();
138        let target_column = target_column.into();
139        let path_variable = path_variable.into();
140
141        let schema = Self::build_schema(input.schema(), &path_variable);
142        let properties = compute_plan_properties(schema.clone());
143
144        Self {
145            input,
146            source_column,
147            target_column,
148            edge_type_ids,
149            direction,
150            path_variable,
151            all_shortest,
152            graph_ctx,
153            schema,
154            properties,
155            metrics: ExecutionPlanMetricsSet::new(),
156        }
157    }
158
159    /// Build output schema.
160    fn build_schema(input_schema: SchemaRef, path_variable: &str) -> SchemaRef {
161        let mut fields: Vec<Field> = input_schema
162            .fields()
163            .iter()
164            .map(|f| f.as_ref().clone())
165            .collect();
166
167        // Add the proper path struct column (nodes + relationships)
168        fields.push(crate::query::df_graph::common::build_path_struct_field(
169            path_variable,
170        ));
171
172        // Add path column (raw VID list for internal use)
173        let path_col_name = format!("{}._path", path_variable);
174        fields.push(Field::new(
175            &path_col_name,
176            DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
177            true, // Nullable - null when no path exists
178        ));
179
180        // Add path length column
181        let len_col_name = format!("{}._length", path_variable);
182        fields.push(Field::new(&len_col_name, DataType::UInt64, true));
183
184        Arc::new(Schema::new(fields))
185    }
186}
187
188impl DisplayAs for GraphShortestPathExec {
189    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        let mode = if self.all_shortest { "all" } else { "any" };
191        write!(
192            f,
193            "GraphShortestPathExec: {} -> {} via {:?} ({})",
194            self.source_column, self.target_column, self.edge_type_ids, mode
195        )
196    }
197}
198
199impl ExecutionPlan for GraphShortestPathExec {
200    fn name(&self) -> &str {
201        "GraphShortestPathExec"
202    }
203
204    fn as_any(&self) -> &dyn Any {
205        self
206    }
207
208    fn schema(&self) -> SchemaRef {
209        Arc::clone(&self.schema)
210    }
211
212    fn properties(&self) -> &PlanProperties {
213        &self.properties
214    }
215
216    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
217        vec![&self.input]
218    }
219
220    fn with_new_children(
221        self: Arc<Self>,
222        children: Vec<Arc<dyn ExecutionPlan>>,
223    ) -> DFResult<Arc<dyn ExecutionPlan>> {
224        if children.len() != 1 {
225            return Err(datafusion::error::DataFusionError::Plan(
226                "GraphShortestPathExec requires exactly one child".to_string(),
227            ));
228        }
229
230        Ok(Arc::new(Self::new(
231            Arc::clone(&children[0]),
232            self.source_column.clone(),
233            self.target_column.clone(),
234            self.edge_type_ids.clone(),
235            self.direction,
236            self.path_variable.clone(),
237            Arc::clone(&self.graph_ctx),
238            self.all_shortest,
239        )))
240    }
241
242    fn execute(
243        &self,
244        partition: usize,
245        context: Arc<TaskContext>,
246    ) -> DFResult<SendableRecordBatchStream> {
247        let input_stream = self.input.execute(partition, context)?;
248
249        let metrics = BaselineMetrics::new(&self.metrics, partition);
250
251        let warm_fut = self
252            .graph_ctx
253            .warming_future(self.edge_type_ids.clone(), self.direction);
254
255        Ok(Box::pin(GraphShortestPathStream {
256            input: input_stream,
257            source_column: self.source_column.clone(),
258            target_column: self.target_column.clone(),
259            edge_type_ids: self.edge_type_ids.clone(),
260            direction: self.direction,
261            all_shortest: self.all_shortest,
262            graph_ctx: Arc::clone(&self.graph_ctx),
263            schema: Arc::clone(&self.schema),
264            state: ShortestPathStreamState::Warming(warm_fut),
265            metrics,
266        }))
267    }
268
269    fn metrics(&self) -> Option<MetricsSet> {
270        Some(self.metrics.clone_inner())
271    }
272}
273
274/// State machine for shortest path stream execution.
275enum ShortestPathStreamState {
276    /// Warming adjacency CSRs before first batch.
277    Warming(Pin<Box<dyn std::future::Future<Output = DFResult<()>> + Send>>),
278    /// Processing input batches.
279    Reading,
280    /// Stream is done.
281    Done,
282}
283
284/// Stream that computes shortest paths.
285struct GraphShortestPathStream {
286    /// Input stream.
287    input: SendableRecordBatchStream,
288
289    /// Column name containing source VIDs.
290    source_column: String,
291
292    /// Column name containing target VIDs.
293    target_column: String,
294
295    /// Edge type IDs to traverse.
296    edge_type_ids: Vec<u32>,
297
298    /// Traversal direction.
299    direction: Direction,
300
301    /// Whether this is allShortestPaths mode.
302    all_shortest: bool,
303
304    /// Graph execution context.
305    graph_ctx: Arc<GraphExecutionContext>,
306
307    /// Output schema.
308    schema: SchemaRef,
309
310    /// Stream state.
311    state: ShortestPathStreamState,
312
313    /// Metrics.
314    metrics: BaselineMetrics,
315}
316
317impl GraphShortestPathStream {
318    /// Compute shortest path between two vertices using BFS.
319    fn compute_shortest_path(&self, source: Vid, target: Vid) -> Option<Vec<Vid>> {
320        if source == target {
321            return Some(vec![source]);
322        }
323
324        let mut visited: HashSet<Vid> = HashSet::new();
325        let mut queue: VecDeque<(Vid, Vec<Vid>)> = VecDeque::new();
326
327        visited.insert(source);
328        queue.push_back((source, vec![source]));
329
330        while let Some((current, path)) = queue.pop_front() {
331            // Get neighbors for all edge types
332            for &edge_type in &self.edge_type_ids {
333                let neighbors = self
334                    .graph_ctx
335                    .get_neighbors(current, edge_type, self.direction);
336
337                for (neighbor, _eid) in neighbors {
338                    if neighbor == target {
339                        // Found the target
340                        let mut result = path.clone();
341                        result.push(target);
342                        return Some(result);
343                    }
344
345                    if !visited.contains(&neighbor) {
346                        visited.insert(neighbor);
347                        let mut new_path = path.clone();
348                        new_path.push(neighbor);
349                        queue.push_back((neighbor, new_path));
350                    }
351                }
352            }
353        }
354
355        None // No path found
356    }
357
358    /// Compute all shortest paths between two vertices using layer-by-layer BFS
359    /// with predecessor tracking.
360    ///
361    /// Returns all paths of minimum length from source to target.
362    fn compute_all_shortest_paths(&self, source: Vid, target: Vid) -> Vec<Vec<Vid>> {
363        if source == target {
364            return vec![vec![source]];
365        }
366
367        // Layer-by-layer BFS recording ALL predecessors at shortest depth
368        let mut depth: FxHashMap<Vid, u32> = FxHashMap::default();
369        let mut predecessors: FxHashMap<Vid, Vec<Vid>> = FxHashMap::default();
370        depth.insert(source, 0);
371
372        let mut current_layer: Vec<Vid> = vec![source];
373        let mut current_depth = 0u32;
374        let mut target_found = false;
375
376        while !current_layer.is_empty() && !target_found {
377            current_depth += 1;
378            let mut next_layer_set: HashSet<Vid> = HashSet::new();
379
380            for &current in &current_layer {
381                for &edge_type in &self.edge_type_ids {
382                    let neighbors =
383                        self.graph_ctx
384                            .get_neighbors(current, edge_type, self.direction);
385
386                    for (neighbor, _eid) in neighbors {
387                        if let Some(&d) = depth.get(&neighbor) {
388                            // Already discovered: only add predecessor if same depth
389                            if d == current_depth {
390                                predecessors.entry(neighbor).or_default().push(current);
391                            }
392                            continue;
393                        }
394
395                        // First time seeing this vertex at current_depth
396                        depth.insert(neighbor, current_depth);
397                        predecessors.entry(neighbor).or_default().push(current);
398
399                        if neighbor == target {
400                            target_found = true;
401                        } else {
402                            next_layer_set.insert(neighbor);
403                        }
404                    }
405                }
406            }
407
408            current_layer = next_layer_set.into_iter().collect();
409        }
410
411        if !target_found {
412            return vec![];
413        }
414
415        // Enumerate all shortest paths via backward DFS from target to source
416        let mut result: Vec<Vec<Vid>> = Vec::new();
417        let mut stack: Vec<(Vid, Vec<Vid>)> = vec![(target, vec![target])];
418
419        while let Some((node, path)) = stack.pop() {
420            if node == source {
421                let mut full_path = path;
422                full_path.reverse();
423                result.push(full_path);
424                continue;
425            }
426            if let Some(preds) = predecessors.get(&node) {
427                for &pred in preds {
428                    let mut new_path = path.clone();
429                    new_path.push(pred);
430                    stack.push((pred, new_path));
431                }
432            }
433        }
434
435        result
436    }
437
438    /// Process a single input batch.
439    fn process_batch(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
440        // Extract source and target VIDs
441        let source_col = batch.column_by_name(&self.source_column).ok_or_else(|| {
442            datafusion::error::DataFusionError::Execution(format!(
443                "Source column '{}' not found",
444                self.source_column
445            ))
446        })?;
447
448        let target_col = batch.column_by_name(&self.target_column).ok_or_else(|| {
449            datafusion::error::DataFusionError::Execution(format!(
450                "Target column '{}' not found",
451                self.target_column
452            ))
453        })?;
454
455        let source_vid_cow = column_as_vid_array(source_col.as_ref())?;
456        let source_vids: &UInt64Array = &source_vid_cow;
457
458        let target_vid_cow = column_as_vid_array(target_col.as_ref())?;
459        let target_vids: &UInt64Array = &target_vid_cow;
460
461        if self.all_shortest {
462            // allShortestPaths: each input row can produce multiple output rows
463            let mut row_indices: Vec<u32> = Vec::new();
464            let mut all_paths: Vec<Option<Vec<Vid>>> = Vec::new();
465
466            for i in 0..batch.num_rows() {
467                if source_vids.is_null(i) || target_vids.is_null(i) {
468                    row_indices.push(i as u32);
469                    all_paths.push(None);
470                } else {
471                    let source = Vid::from(source_vids.value(i));
472                    let target = Vid::from(target_vids.value(i));
473                    let paths = self.compute_all_shortest_paths(source, target);
474                    if paths.is_empty() {
475                        row_indices.push(i as u32);
476                        all_paths.push(None);
477                    } else {
478                        for path in paths {
479                            row_indices.push(i as u32);
480                            all_paths.push(Some(path));
481                        }
482                    }
483                }
484            }
485
486            // Expand input batch rows according to row_indices
487            let indices = UInt32Array::from(row_indices);
488            let expanded_columns: Vec<ArrayRef> = batch
489                .columns()
490                .iter()
491                .map(|col| {
492                    take(col.as_ref(), &indices, None).map_err(|e| {
493                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
494                    })
495                })
496                .collect::<DFResult<Vec<_>>>()?;
497            let expanded_batch = RecordBatch::try_new(batch.schema(), expanded_columns)
498                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
499
500            self.build_output_batch(&expanded_batch, &all_paths)
501        } else {
502            // shortestPath: one path per input row
503            let mut paths: Vec<Option<Vec<Vid>>> = Vec::with_capacity(batch.num_rows());
504
505            for i in 0..batch.num_rows() {
506                let path = if source_vids.is_null(i) || target_vids.is_null(i) {
507                    None
508                } else {
509                    let source = Vid::from(source_vids.value(i));
510                    let target = Vid::from(target_vids.value(i));
511                    self.compute_shortest_path(source, target)
512                };
513                paths.push(path);
514            }
515
516            self.build_output_batch(&batch, &paths)
517        }
518    }
519
520    /// Build output batch with path columns.
521    fn build_output_batch(
522        &self,
523        input: &RecordBatch,
524        paths: &[Option<Vec<Vid>>],
525    ) -> DFResult<RecordBatch> {
526        let num_rows = paths.len();
527        let query_ctx = self.graph_ctx.query_context();
528
529        // Copy input columns
530        let mut columns: Vec<ArrayRef> = input.columns().to_vec();
531
532        // Build the path struct column (nodes + relationships)
533        let mut nodes_builder = new_node_list_builder();
534        let mut rels_builder =
535            ListBuilder::new(StructBuilder::from_fields(edge_struct_fields(), num_rows));
536        let mut path_validity = Vec::with_capacity(num_rows);
537
538        for path in paths {
539            match path {
540                Some(vids) => {
541                    // Add all nodes
542                    for &vid in vids {
543                        super::common::append_node_to_struct(
544                            nodes_builder.values(),
545                            vid,
546                            &query_ctx,
547                        );
548                    }
549                    nodes_builder.append(true);
550
551                    // Add edges between consecutive nodes
552                    // BFS returns node VIDs; edges are between consecutive pairs
553                    for window in vids.windows(2) {
554                        let src = window[0];
555                        let dst = window[1];
556                        let (eid, type_name) = self.find_edge(src, dst);
557                        super::common::append_edge_to_struct(
558                            rels_builder.values(),
559                            eid,
560                            &type_name,
561                            src.as_u64(),
562                            dst.as_u64(),
563                            &query_ctx,
564                        );
565                    }
566                    rels_builder.append(true);
567                    path_validity.push(true);
568                }
569                None => {
570                    // Null path
571                    nodes_builder.append(false);
572                    rels_builder.append(false);
573                    path_validity.push(false);
574                }
575            }
576        }
577
578        let nodes_array = Arc::new(nodes_builder.finish()) as ArrayRef;
579        let rels_array = Arc::new(rels_builder.finish()) as ArrayRef;
580
581        let path_struct =
582            super::common::build_path_struct_array(nodes_array, rels_array, path_validity)?;
583        columns.push(Arc::new(path_struct));
584
585        // Build raw path list column (VID list for internal use)
586        let mut list_builder = ListBuilder::new(UInt64Builder::new());
587        for path in paths {
588            match path {
589                Some(p) => {
590                    let values: Vec<u64> = p.iter().map(|v| v.as_u64()).collect();
591                    list_builder.values().append_slice(&values);
592                    list_builder.append(true);
593                }
594                None => {
595                    list_builder.append(false); // Null for no path
596                }
597            }
598        }
599        columns.push(Arc::new(list_builder.finish()));
600
601        // Build path length column
602        let lengths: Vec<Option<u64>> = paths
603            .iter()
604            .map(|p| p.as_ref().map(|path| (path.len() - 1) as u64))
605            .collect();
606        columns.push(Arc::new(UInt64Array::from(lengths)));
607
608        self.metrics.record_output(num_rows);
609
610        RecordBatch::try_new(Arc::clone(&self.schema), columns)
611            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
612    }
613
614    /// Find an edge connecting src to dst.
615    /// Returns (eid, type_name). Property lookup is handled by `append_edge_to_struct`.
616    fn find_edge(&self, src: Vid, dst: Vid) -> (uni_common::core::id::Eid, String) {
617        let query_ctx = self.graph_ctx.query_context();
618        for &edge_type in &self.edge_type_ids {
619            let neighbors = self.graph_ctx.get_neighbors(src, edge_type, self.direction);
620            for (neighbor, eid) in neighbors {
621                if neighbor == dst {
622                    let type_name =
623                        l0_visibility::get_edge_type(eid, &query_ctx).unwrap_or_default();
624                    return (eid, type_name);
625                }
626            }
627        }
628        (uni_common::core::id::Eid::from(0u64), String::new())
629    }
630}
631
632impl Stream for GraphShortestPathStream {
633    type Item = DFResult<RecordBatch>;
634
635    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
636        loop {
637            let state = std::mem::replace(&mut self.state, ShortestPathStreamState::Done);
638
639            match state {
640                ShortestPathStreamState::Warming(mut fut) => match fut.as_mut().poll(cx) {
641                    Poll::Ready(Ok(())) => {
642                        self.state = ShortestPathStreamState::Reading;
643                        // Continue loop to start reading
644                    }
645                    Poll::Ready(Err(e)) => {
646                        self.state = ShortestPathStreamState::Done;
647                        return Poll::Ready(Some(Err(e)));
648                    }
649                    Poll::Pending => {
650                        self.state = ShortestPathStreamState::Warming(fut);
651                        return Poll::Pending;
652                    }
653                },
654                ShortestPathStreamState::Reading => {
655                    // Check timeout
656                    if let Err(e) = self.graph_ctx.check_timeout() {
657                        return Poll::Ready(Some(Err(
658                            datafusion::error::DataFusionError::Execution(e.to_string()),
659                        )));
660                    }
661
662                    match self.input.poll_next_unpin(cx) {
663                        Poll::Ready(Some(Ok(batch))) => {
664                            let result = self.process_batch(batch);
665                            self.state = ShortestPathStreamState::Reading;
666                            return Poll::Ready(Some(result));
667                        }
668                        Poll::Ready(Some(Err(e))) => {
669                            self.state = ShortestPathStreamState::Done;
670                            return Poll::Ready(Some(Err(e)));
671                        }
672                        Poll::Ready(None) => {
673                            self.state = ShortestPathStreamState::Done;
674                            return Poll::Ready(None);
675                        }
676                        Poll::Pending => {
677                            self.state = ShortestPathStreamState::Reading;
678                            return Poll::Pending;
679                        }
680                    }
681                }
682                ShortestPathStreamState::Done => {
683                    return Poll::Ready(None);
684                }
685            }
686        }
687    }
688}
689
690impl RecordBatchStream for GraphShortestPathStream {
691    fn schema(&self) -> SchemaRef {
692        Arc::clone(&self.schema)
693    }
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// The output schema keeps the input columns and appends `p`, `p._path`
    /// and `p._length`, in that order.
    #[test]
    fn test_shortest_path_schema() {
        let input_fields = vec![
            Field::new("_source_vid", DataType::UInt64, false),
            Field::new("_target_vid", DataType::UInt64, false),
        ];
        let input_schema = Arc::new(Schema::new(input_fields));

        let output_schema = GraphShortestPathExec::build_schema(input_schema, "p");

        let expected = ["_source_vid", "_target_vid", "p", "p._path", "p._length"];
        assert_eq!(output_schema.fields().len(), 5);
        for (idx, name) in expected.iter().enumerate() {
            assert_eq!(output_schema.field(idx).name(), *name);
        }
    }
}