Skip to main content

uni_query/query/df_graph/
shortest_path.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Shortest path execution plan for DataFusion.
5//!
6//! This module provides [`GraphShortestPathExec`], a DataFusion [`ExecutionPlan`] that
7//! computes shortest paths between source and target vertices using BFS.
8//!
//! # Algorithm
//!
//! Paths are found with a single-direction breadth-first search that starts
//! at the source vertex:
//! 1. `shortestPath` mode stops at the first path that reaches the target.
//! 2. `allShortestPaths` mode expands layer by layer, records every
//!    predecessor at the shortest depth, and enumerates all minimum-length paths.
//!
//! Bidirectional BFS is not implemented in this module.
17
18use crate::query::df_graph::GraphExecutionContext;
19use crate::query::df_graph::common::{
20    arrow_err, column_as_vid_array, compute_plan_properties, edge_struct_fields,
21    new_node_list_builder,
22};
23use arrow::compute::take;
24use arrow_array::builder::{ListBuilder, StructBuilder, UInt64Builder};
25use arrow_array::{Array, ArrayRef, RecordBatch, UInt32Array, UInt64Array};
26use arrow_schema::{DataType, Field, Schema, SchemaRef};
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
30use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
31use futures::{Stream, StreamExt};
32use fxhash::FxHashMap;
33use std::any::Any;
34use std::collections::{HashSet, VecDeque};
35use std::fmt;
36use std::pin::Pin;
37use std::sync::Arc;
38use std::task::{Context, Poll};
39use uni_common::core::id::Vid;
40use uni_store::runtime::l0_visibility;
41use uni_store::storage::direction::Direction;
42
/// Shortest path execution plan.
///
/// Computes shortest paths between source and target vertices using BFS.
/// For every input row the operator appends three columns named after the
/// path variable: the path struct (nodes + relationships), `<var>._path`
/// (raw VID list) and `<var>._length` (number of hops). All three are null
/// when no path exists.
///
/// # Example
///
/// ```ignore
/// // Find shortest path from source to target via KNOWS edges
/// let shortest_path = GraphShortestPathExec::new(
///     input_plan,
///     "_source_vid",
///     "_target_vid",
///     vec![knows_type_id],
///     Direction::Both,
///     "p",
///     graph_ctx,
///     false, // all_shortest: one shortest path per input row
/// );
///
/// // Output: input columns + p (path struct) + p._path (List<UInt64>) + p._length (UInt64)
/// ```
pub struct GraphShortestPathExec {
    /// Input execution plan.
    input: Arc<dyn ExecutionPlan>,

    /// Column name containing source VIDs.
    source_column: String,

    /// Column name containing target VIDs.
    target_column: String,

    /// Edge type IDs to traverse.
    edge_type_ids: Vec<u32>,

    /// Traversal direction.
    direction: Direction,

    /// Variable name for the path; used as the prefix of the output columns.
    path_variable: String,

    /// Whether this is allShortestPaths (true) or shortestPath (false).
    all_shortest: bool,

    /// Graph execution context.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Output schema: the input fields followed by the path columns.
    schema: SchemaRef,

    /// Cached plan properties.
    properties: PlanProperties,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
98
99impl fmt::Debug for GraphShortestPathExec {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        f.debug_struct("GraphShortestPathExec")
102            .field("source_column", &self.source_column)
103            .field("target_column", &self.target_column)
104            .field("edge_type_ids", &self.edge_type_ids)
105            .field("direction", &self.direction)
106            .field("path_variable", &self.path_variable)
107            .field("all_shortest", &self.all_shortest)
108            .finish()
109    }
110}
111
112impl GraphShortestPathExec {
113    /// Create a new shortest path execution plan.
114    ///
115    /// # Arguments
116    ///
117    /// * `input` - Input plan providing source and target vertices
118    /// * `source_column` - Column name containing source VIDs
119    /// * `target_column` - Column name containing target VIDs
120    /// * `edge_type_ids` - Edge types to traverse
121    /// * `direction` - Traversal direction
122    /// * `path_variable` - Variable name for the path
123    /// * `graph_ctx` - Graph execution context
124    #[expect(
125        clippy::too_many_arguments,
126        reason = "Shortest path requires many parameters"
127    )]
128    pub fn new(
129        input: Arc<dyn ExecutionPlan>,
130        source_column: impl Into<String>,
131        target_column: impl Into<String>,
132        edge_type_ids: Vec<u32>,
133        direction: Direction,
134        path_variable: impl Into<String>,
135        graph_ctx: Arc<GraphExecutionContext>,
136        all_shortest: bool,
137    ) -> Self {
138        let source_column = source_column.into();
139        let target_column = target_column.into();
140        let path_variable = path_variable.into();
141
142        let schema = Self::build_schema(input.schema(), &path_variable);
143        let properties = compute_plan_properties(schema.clone());
144
145        Self {
146            input,
147            source_column,
148            target_column,
149            edge_type_ids,
150            direction,
151            path_variable,
152            all_shortest,
153            graph_ctx,
154            schema,
155            properties,
156            metrics: ExecutionPlanMetricsSet::new(),
157        }
158    }
159
160    /// Build output schema.
161    fn build_schema(input_schema: SchemaRef, path_variable: &str) -> SchemaRef {
162        let mut fields: Vec<Field> = input_schema
163            .fields()
164            .iter()
165            .map(|f| f.as_ref().clone())
166            .collect();
167
168        // Add the proper path struct column (nodes + relationships)
169        fields.push(crate::query::df_graph::common::build_path_struct_field(
170            path_variable,
171        ));
172
173        // Add path column (raw VID list for internal use)
174        let path_col_name = format!("{}._path", path_variable);
175        fields.push(Field::new(
176            &path_col_name,
177            DataType::List(Arc::new(Field::new("item", DataType::UInt64, true))),
178            true, // Nullable - null when no path exists
179        ));
180
181        // Add path length column
182        let len_col_name = format!("{}._length", path_variable);
183        fields.push(Field::new(&len_col_name, DataType::UInt64, true));
184
185        Arc::new(Schema::new(fields))
186    }
187}
188
189impl DisplayAs for GraphShortestPathExec {
190    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        let mode = if self.all_shortest { "all" } else { "any" };
192        write!(
193            f,
194            "GraphShortestPathExec: {} -> {} via {:?} ({})",
195            self.source_column, self.target_column, self.edge_type_ids, mode
196        )
197    }
198}
199
200impl ExecutionPlan for GraphShortestPathExec {
201    fn name(&self) -> &str {
202        "GraphShortestPathExec"
203    }
204
205    fn as_any(&self) -> &dyn Any {
206        self
207    }
208
209    fn schema(&self) -> SchemaRef {
210        Arc::clone(&self.schema)
211    }
212
213    fn properties(&self) -> &PlanProperties {
214        &self.properties
215    }
216
217    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
218        vec![&self.input]
219    }
220
221    fn with_new_children(
222        self: Arc<Self>,
223        children: Vec<Arc<dyn ExecutionPlan>>,
224    ) -> DFResult<Arc<dyn ExecutionPlan>> {
225        if children.len() != 1 {
226            return Err(datafusion::error::DataFusionError::Plan(
227                "GraphShortestPathExec requires exactly one child".to_string(),
228            ));
229        }
230
231        Ok(Arc::new(Self::new(
232            Arc::clone(&children[0]),
233            self.source_column.clone(),
234            self.target_column.clone(),
235            self.edge_type_ids.clone(),
236            self.direction,
237            self.path_variable.clone(),
238            Arc::clone(&self.graph_ctx),
239            self.all_shortest,
240        )))
241    }
242
243    fn execute(
244        &self,
245        partition: usize,
246        context: Arc<TaskContext>,
247    ) -> DFResult<SendableRecordBatchStream> {
248        let input_stream = self.input.execute(partition, context)?;
249
250        let metrics = BaselineMetrics::new(&self.metrics, partition);
251
252        let warm_fut = self
253            .graph_ctx
254            .warming_future(self.edge_type_ids.clone(), self.direction);
255
256        Ok(Box::pin(GraphShortestPathStream {
257            input: input_stream,
258            source_column: self.source_column.clone(),
259            target_column: self.target_column.clone(),
260            edge_type_ids: self.edge_type_ids.clone(),
261            direction: self.direction,
262            all_shortest: self.all_shortest,
263            graph_ctx: Arc::clone(&self.graph_ctx),
264            schema: Arc::clone(&self.schema),
265            state: ShortestPathStreamState::Warming(warm_fut),
266            metrics,
267        }))
268    }
269
270    fn metrics(&self) -> Option<MetricsSet> {
271        Some(self.metrics.clone_inner())
272    }
273}
274
/// State machine for shortest path stream execution.
///
/// A stream starts in `Warming`, moves to `Reading` once the warming future
/// resolves successfully, and ends in `Done` after the input is exhausted or
/// an input/warming error has been returned.
enum ShortestPathStreamState {
    /// Warming adjacency CSRs before first batch.
    /// Holds the pending warming future that is polled until it completes.
    Warming(Pin<Box<dyn std::future::Future<Output = DFResult<()>> + Send>>),
    /// Processing input batches.
    Reading,
    /// Stream is done.
    Done,
}
284
/// Stream that computes shortest paths.
///
/// Each input batch is turned into one output batch carrying the input
/// columns plus the path struct, raw VID list and path length columns.
struct GraphShortestPathStream {
    /// Input stream.
    input: SendableRecordBatchStream,

    /// Column name containing source VIDs.
    source_column: String,

    /// Column name containing target VIDs.
    target_column: String,

    /// Edge type IDs to traverse.
    edge_type_ids: Vec<u32>,

    /// Traversal direction.
    direction: Direction,

    /// Whether this is allShortestPaths mode.
    /// In that mode one input row may expand into several output rows.
    all_shortest: bool,

    /// Graph execution context.
    /// Used for neighbor lookups, the query context and timeout checks.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Output schema.
    schema: SchemaRef,

    /// Stream state.
    state: ShortestPathStreamState,

    /// Metrics.
    /// The number of output rows is recorded for every batch built.
    metrics: BaselineMetrics,
}
317
318impl GraphShortestPathStream {
319    /// Compute shortest path between two vertices using BFS.
320    fn compute_shortest_path(&self, source: Vid, target: Vid) -> Option<Vec<Vid>> {
321        if source == target {
322            return Some(vec![source]);
323        }
324
325        let mut visited: HashSet<Vid> = HashSet::new();
326        let mut queue: VecDeque<(Vid, Vec<Vid>)> = VecDeque::new();
327
328        visited.insert(source);
329        queue.push_back((source, vec![source]));
330
331        while let Some((current, path)) = queue.pop_front() {
332            // Get neighbors for all edge types
333            for &edge_type in &self.edge_type_ids {
334                let neighbors = self
335                    .graph_ctx
336                    .get_neighbors(current, edge_type, self.direction);
337
338                for (neighbor, _eid) in neighbors {
339                    if neighbor == target {
340                        // Found the target
341                        let mut result = path.clone();
342                        result.push(target);
343                        return Some(result);
344                    }
345
346                    if !visited.contains(&neighbor) {
347                        visited.insert(neighbor);
348                        let mut new_path = path.clone();
349                        new_path.push(neighbor);
350                        queue.push_back((neighbor, new_path));
351                    }
352                }
353            }
354        }
355
356        None // No path found
357    }
358
359    /// Compute all shortest paths between two vertices using layer-by-layer BFS
360    /// with predecessor tracking.
361    ///
362    /// Returns all paths of minimum length from source to target.
363    fn compute_all_shortest_paths(&self, source: Vid, target: Vid) -> Vec<Vec<Vid>> {
364        if source == target {
365            return vec![vec![source]];
366        }
367
368        // Layer-by-layer BFS recording ALL predecessors at shortest depth
369        let mut depth: FxHashMap<Vid, u32> = FxHashMap::default();
370        let mut predecessors: FxHashMap<Vid, Vec<Vid>> = FxHashMap::default();
371        depth.insert(source, 0);
372
373        let mut current_layer: Vec<Vid> = vec![source];
374        let mut current_depth = 0u32;
375        let mut target_found = false;
376
377        while !current_layer.is_empty() && !target_found {
378            current_depth += 1;
379            let mut next_layer_set: HashSet<Vid> = HashSet::new();
380
381            for &current in &current_layer {
382                for &edge_type in &self.edge_type_ids {
383                    let neighbors =
384                        self.graph_ctx
385                            .get_neighbors(current, edge_type, self.direction);
386
387                    for (neighbor, _eid) in neighbors {
388                        if let Some(&d) = depth.get(&neighbor) {
389                            // Already discovered: only add predecessor if same depth
390                            if d == current_depth {
391                                predecessors.entry(neighbor).or_default().push(current);
392                            }
393                            continue;
394                        }
395
396                        // First time seeing this vertex at current_depth
397                        depth.insert(neighbor, current_depth);
398                        predecessors.entry(neighbor).or_default().push(current);
399
400                        if neighbor == target {
401                            target_found = true;
402                        } else {
403                            next_layer_set.insert(neighbor);
404                        }
405                    }
406                }
407            }
408
409            current_layer = next_layer_set.into_iter().collect();
410        }
411
412        if !target_found {
413            return vec![];
414        }
415
416        // Enumerate all shortest paths via backward DFS from target to source
417        let mut result: Vec<Vec<Vid>> = Vec::new();
418        let mut stack: Vec<(Vid, Vec<Vid>)> = vec![(target, vec![target])];
419
420        while let Some((node, path)) = stack.pop() {
421            if node == source {
422                let mut full_path = path;
423                full_path.reverse();
424                result.push(full_path);
425                continue;
426            }
427            if let Some(preds) = predecessors.get(&node) {
428                for &pred in preds {
429                    let mut new_path = path.clone();
430                    new_path.push(pred);
431                    stack.push((pred, new_path));
432                }
433            }
434        }
435
436        result
437    }
438
439    /// Process a single input batch.
440    fn process_batch(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
441        // Extract source and target VIDs
442        let source_col = batch.column_by_name(&self.source_column).ok_or_else(|| {
443            datafusion::error::DataFusionError::Execution(format!(
444                "Source column '{}' not found",
445                self.source_column
446            ))
447        })?;
448
449        let target_col = batch.column_by_name(&self.target_column).ok_or_else(|| {
450            datafusion::error::DataFusionError::Execution(format!(
451                "Target column '{}' not found",
452                self.target_column
453            ))
454        })?;
455
456        let source_vid_cow = column_as_vid_array(source_col.as_ref())?;
457        let source_vids: &UInt64Array = &source_vid_cow;
458
459        let target_vid_cow = column_as_vid_array(target_col.as_ref())?;
460        let target_vids: &UInt64Array = &target_vid_cow;
461
462        if self.all_shortest {
463            // allShortestPaths: each input row can produce multiple output rows
464            let mut row_indices: Vec<u32> = Vec::new();
465            let mut all_paths: Vec<Option<Vec<Vid>>> = Vec::new();
466
467            for i in 0..batch.num_rows() {
468                if source_vids.is_null(i) || target_vids.is_null(i) {
469                    row_indices.push(i as u32);
470                    all_paths.push(None);
471                } else {
472                    let source = Vid::from(source_vids.value(i));
473                    let target = Vid::from(target_vids.value(i));
474                    let paths = self.compute_all_shortest_paths(source, target);
475                    if paths.is_empty() {
476                        row_indices.push(i as u32);
477                        all_paths.push(None);
478                    } else {
479                        for path in paths {
480                            row_indices.push(i as u32);
481                            all_paths.push(Some(path));
482                        }
483                    }
484                }
485            }
486
487            // Expand input batch rows according to row_indices
488            let indices = UInt32Array::from(row_indices);
489            let expanded_columns: Vec<ArrayRef> = batch
490                .columns()
491                .iter()
492                .map(|col| {
493                    take(col.as_ref(), &indices, None).map_err(|e| {
494                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
495                    })
496                })
497                .collect::<DFResult<Vec<_>>>()?;
498            let expanded_batch =
499                RecordBatch::try_new(batch.schema(), expanded_columns).map_err(arrow_err)?;
500
501            self.build_output_batch(&expanded_batch, &all_paths)
502        } else {
503            // shortestPath: one path per input row
504            let mut paths: Vec<Option<Vec<Vid>>> = Vec::with_capacity(batch.num_rows());
505
506            for i in 0..batch.num_rows() {
507                let path = if source_vids.is_null(i) || target_vids.is_null(i) {
508                    None
509                } else {
510                    let source = Vid::from(source_vids.value(i));
511                    let target = Vid::from(target_vids.value(i));
512                    self.compute_shortest_path(source, target)
513                };
514                paths.push(path);
515            }
516
517            self.build_output_batch(&batch, &paths)
518        }
519    }
520
521    /// Build output batch with path columns.
522    fn build_output_batch(
523        &self,
524        input: &RecordBatch,
525        paths: &[Option<Vec<Vid>>],
526    ) -> DFResult<RecordBatch> {
527        let num_rows = paths.len();
528        let query_ctx = self.graph_ctx.query_context();
529
530        // Copy input columns
531        let mut columns: Vec<ArrayRef> = input.columns().to_vec();
532
533        // Build the path struct column (nodes + relationships)
534        let mut nodes_builder = new_node_list_builder();
535        let mut rels_builder =
536            ListBuilder::new(StructBuilder::from_fields(edge_struct_fields(), num_rows));
537        let mut path_validity = Vec::with_capacity(num_rows);
538
539        for path in paths {
540            match path {
541                Some(vids) => {
542                    // Add all nodes
543                    for &vid in vids {
544                        super::common::append_node_to_struct(
545                            nodes_builder.values(),
546                            vid,
547                            &query_ctx,
548                        );
549                    }
550                    nodes_builder.append(true);
551
552                    // Add edges between consecutive nodes
553                    // BFS returns node VIDs; edges are between consecutive pairs
554                    for window in vids.windows(2) {
555                        let src = window[0];
556                        let dst = window[1];
557                        let (eid, type_name) = self.find_edge(src, dst);
558                        super::common::append_edge_to_struct(
559                            rels_builder.values(),
560                            eid,
561                            &type_name,
562                            src.as_u64(),
563                            dst.as_u64(),
564                            &query_ctx,
565                        );
566                    }
567                    rels_builder.append(true);
568                    path_validity.push(true);
569                }
570                None => {
571                    // Null path
572                    nodes_builder.append(false);
573                    rels_builder.append(false);
574                    path_validity.push(false);
575                }
576            }
577        }
578
579        let nodes_array = Arc::new(nodes_builder.finish()) as ArrayRef;
580        let rels_array = Arc::new(rels_builder.finish()) as ArrayRef;
581
582        let path_struct =
583            super::common::build_path_struct_array(nodes_array, rels_array, path_validity)?;
584        columns.push(Arc::new(path_struct));
585
586        // Build raw path list column (VID list for internal use)
587        let mut list_builder = ListBuilder::new(UInt64Builder::new());
588        for path in paths {
589            match path {
590                Some(p) => {
591                    let values: Vec<u64> = p.iter().map(|v| v.as_u64()).collect();
592                    list_builder.values().append_slice(&values);
593                    list_builder.append(true);
594                }
595                None => {
596                    list_builder.append(false); // Null for no path
597                }
598            }
599        }
600        columns.push(Arc::new(list_builder.finish()));
601
602        // Build path length column
603        let lengths: Vec<Option<u64>> = paths
604            .iter()
605            .map(|p| p.as_ref().map(|path| (path.len() - 1) as u64))
606            .collect();
607        columns.push(Arc::new(UInt64Array::from(lengths)));
608
609        self.metrics.record_output(num_rows);
610
611        RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(arrow_err)
612    }
613
614    /// Find an edge connecting src to dst.
615    /// Returns (eid, type_name). Property lookup is handled by `append_edge_to_struct`.
616    fn find_edge(&self, src: Vid, dst: Vid) -> (uni_common::core::id::Eid, String) {
617        let query_ctx = self.graph_ctx.query_context();
618        for &edge_type in &self.edge_type_ids {
619            let neighbors = self.graph_ctx.get_neighbors(src, edge_type, self.direction);
620            for (neighbor, eid) in neighbors {
621                if neighbor == dst {
622                    let type_name =
623                        l0_visibility::get_edge_type(eid, &query_ctx).unwrap_or_default();
624                    return (eid, type_name);
625                }
626            }
627        }
628        (uni_common::core::id::Eid::from(0u64), String::new())
629    }
630}
631
632impl Stream for GraphShortestPathStream {
633    type Item = DFResult<RecordBatch>;
634
635    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
636        loop {
637            let state = std::mem::replace(&mut self.state, ShortestPathStreamState::Done);
638
639            match state {
640                ShortestPathStreamState::Warming(mut fut) => match fut.as_mut().poll(cx) {
641                    Poll::Ready(Ok(())) => {
642                        self.state = ShortestPathStreamState::Reading;
643                        // Continue loop to start reading
644                    }
645                    Poll::Ready(Err(e)) => {
646                        self.state = ShortestPathStreamState::Done;
647                        return Poll::Ready(Some(Err(e)));
648                    }
649                    Poll::Pending => {
650                        self.state = ShortestPathStreamState::Warming(fut);
651                        return Poll::Pending;
652                    }
653                },
654                ShortestPathStreamState::Reading => {
655                    // Check timeout
656                    if let Err(e) = self.graph_ctx.check_timeout() {
657                        return Poll::Ready(Some(Err(
658                            datafusion::error::DataFusionError::Execution(e.to_string()),
659                        )));
660                    }
661
662                    match self.input.poll_next_unpin(cx) {
663                        Poll::Ready(Some(Ok(batch))) => {
664                            let result = self.process_batch(batch);
665                            self.state = ShortestPathStreamState::Reading;
666                            return Poll::Ready(Some(result));
667                        }
668                        Poll::Ready(Some(Err(e))) => {
669                            self.state = ShortestPathStreamState::Done;
670                            return Poll::Ready(Some(Err(e)));
671                        }
672                        Poll::Ready(None) => {
673                            self.state = ShortestPathStreamState::Done;
674                            return Poll::Ready(None);
675                        }
676                        Poll::Pending => {
677                            self.state = ShortestPathStreamState::Reading;
678                            return Poll::Pending;
679                        }
680                    }
681                }
682                ShortestPathStreamState::Done => {
683                    return Poll::Ready(None);
684                }
685            }
686        }
687    }
688}
689
690impl RecordBatchStream for GraphShortestPathStream {
691    fn schema(&self) -> SchemaRef {
692        Arc::clone(&self.schema)
693    }
694}
695
#[cfg(test)]
mod tests {
    use super::*;

    /// The output schema keeps the input fields and appends the path struct,
    /// raw path list and path length columns, in that order.
    #[test]
    fn test_shortest_path_schema() {
        let input_fields = vec![
            Field::new("_source_vid", DataType::UInt64, false),
            Field::new("_target_vid", DataType::UInt64, false),
        ];
        let input_schema = Arc::new(Schema::new(input_fields));

        let output_schema = GraphShortestPathExec::build_schema(input_schema, "p");

        let names: Vec<&str> = output_schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(
            names,
            ["_source_vid", "_target_vid", "p", "p._path", "p._length"]
        );
    }
}