rustkernel_graph/
paths.rs

1//! Shortest path kernels.
2//!
3//! This module provides shortest path algorithms:
//! - Single-source shortest path (SSSP) via BFS (unweighted) or Dijkstra (weighted)
5//! - All-pairs shortest path (APSP)
6//! - K-shortest paths (Yen's algorithm)
7
8use crate::types::CsrGraph;
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
10use std::cmp::Ordering;
11use std::collections::{BinaryHeap, VecDeque};
12
13// ============================================================================
14// Shortest Path Results
15// ============================================================================
16
/// Per-node outcome of a single-source shortest path run.
///
/// One entry is produced for every node of the graph; `node_index` identifies
/// which node the entry describes.
#[derive(Debug, Clone)]
pub struct ShortestPathResult {
    /// Index of the node this entry describes.
    pub node_index: usize,
    /// Distance from the source; `f64::INFINITY` marks an unreachable node.
    pub distance: f64,
    /// Previous node on the shortest path from the source, or `-1` when no
    /// predecessor was recorded (the source itself, or an unreachable node).
    pub predecessor: i64,
    /// `true` when some path from the source reaches this node.
    pub is_reachable: bool,
    /// Number of edges on the path that was found (the distance itself for
    /// unweighted graphs).
    pub hop_count: u32,
}
31
/// One concrete path between a source and a target node.
#[derive(Debug, Clone)]
pub struct PathResult {
    /// Node the path starts at.
    pub source: usize,
    /// Node the path ends at.
    pub target: usize,
    /// Total length of the path (sum of its edge lengths).
    pub path_length: f64,
    /// Number of edges traversed.
    pub hop_count: usize,
    /// Node indices visited, in order, from `source` to `target` inclusive.
    pub node_path: Vec<usize>,
}
46
/// All-pairs shortest path result.
///
/// Both matrices are `node_count x node_count` and stored row-major: row `i`
/// holds the single-source result for source node `i`.
#[derive(Debug, Clone)]
pub struct AllPairsResult {
    /// Number of nodes.
    pub node_count: usize,
    /// Distance matrix in row-major order.
    /// `distances[i * node_count + j]` = shortest distance from node `i` to node `j`;
    /// a non-finite value (e.g. `f64::INFINITY`) means `j` is unreachable from `i`.
    pub distances: Vec<f64>,
    /// Predecessor matrix for path reconstruction, same layout as `distances`.
    /// `predecessors[i * node_count + j]` is the node preceding `j` on the shortest
    /// path from `i`, or a negative value when there is none.
    pub predecessors: Vec<i64>,
}

impl AllPairsResult {
    /// Get distance from source to target.
    ///
    /// # Panics
    /// Panics if `source * node_count + target` is out of bounds for `distances`.
    pub fn distance(&self, source: usize, target: usize) -> f64 {
        self.distances[source * self.node_count + target]
    }

    /// Reconstruct path from source to target.
    ///
    /// Returns the node sequence `[source, ..., target]`, or `None` when the
    /// target is unreachable or the predecessor chain is broken.
    ///
    /// The walk is capped at `node_count` nodes and every predecessor is
    /// bounds-checked. A malformed (cyclic or out-of-range) predecessor matrix
    /// therefore yields `None` instead of looping forever or reading another
    /// source's row.
    ///
    /// # Panics
    /// Panics under the same conditions as [`AllPairsResult::distance`].
    pub fn reconstruct_path(&self, source: usize, target: usize) -> Option<Vec<usize>> {
        if !self.distance(source, target).is_finite() {
            return None;
        }

        let row = source * self.node_count;
        let mut path = vec![target];
        let mut current = target;

        while current != source {
            // A simple path visits each node at most once; anything longer is a cycle.
            if path.len() > self.node_count {
                return None;
            }
            let pred = *self.predecessors.get(row + current)?;
            // Negative means "no predecessor"; values >= node_count are corrupt.
            current = usize::try_from(pred)
                .ok()
                .filter(|&p| p < self.node_count)?;
            path.push(current);
        }

        path.reverse();
        Some(path)
    }
}
88
89// ============================================================================
90// Shortest Path Kernel
91// ============================================================================
92
93/// Shortest path kernel using BFS (unweighted) or Delta-Stepping (weighted).
94#[derive(Debug, Clone)]
95pub struct ShortestPath {
96    metadata: KernelMetadata,
97}
98
99impl Default for ShortestPath {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl ShortestPath {
106    /// Create a new shortest path kernel.
107    #[must_use]
108    pub fn new() -> Self {
109        Self {
110            metadata: KernelMetadata::batch("graph/shortest-path", Domain::GraphAnalytics)
111                .with_description("Shortest path via BFS/Delta-Stepping")
112                .with_throughput(50_000)
113                .with_latency_us(80.0),
114        }
115    }
116
117    /// Compute single-source shortest paths using BFS (for unweighted graphs).
118    ///
119    /// # Arguments
120    /// * `graph` - Input graph (CSR format)
121    /// * `source` - Source node index
122    pub fn compute_sssp_bfs(graph: &CsrGraph, source: usize) -> Vec<ShortestPathResult> {
123        let n = graph.num_nodes;
124        let mut distances = vec![f64::INFINITY; n];
125        let mut predecessors = vec![-1i64; n];
126        let mut hop_counts = vec![0u32; n];
127
128        distances[source] = 0.0;
129
130        let mut queue = VecDeque::new();
131        queue.push_back(source);
132
133        while let Some(v) = queue.pop_front() {
134            let current_dist = distances[v];
135
136            for &w in graph.neighbors(v as u64) {
137                let w = w as usize;
138                if distances[w].is_infinite() {
139                    distances[w] = current_dist + 1.0;
140                    predecessors[w] = v as i64;
141                    hop_counts[w] = hop_counts[v] + 1;
142                    queue.push_back(w);
143                }
144            }
145        }
146
147        (0..n)
148            .map(|i| ShortestPathResult {
149                node_index: i,
150                distance: distances[i],
151                predecessor: predecessors[i],
152                is_reachable: distances[i].is_finite(),
153                hop_count: hop_counts[i],
154            })
155            .collect()
156    }
157
158    /// Compute single-source shortest paths using Dijkstra (for weighted graphs).
159    ///
160    /// # Arguments
161    /// * `graph` - Input graph (CSR format)
162    /// * `source` - Source node index
163    /// * `weights` - Edge weights (parallel to graph edges)
164    pub fn compute_sssp_dijkstra(
165        graph: &CsrGraph,
166        source: usize,
167        weights: &[f64],
168    ) -> Vec<ShortestPathResult> {
169        let n = graph.num_nodes;
170        let mut distances = vec![f64::INFINITY; n];
171        let mut predecessors = vec![-1i64; n];
172        let mut hop_counts = vec![0u32; n];
173
174        distances[source] = 0.0;
175
176        // Priority queue: (negative distance, node) - negated for min-heap behavior
177        let mut heap = BinaryHeap::new();
178        heap.push(HeapNode {
179            dist: 0.0,
180            node: source,
181        });
182
183        while let Some(HeapNode { dist, node: v }) = heap.pop() {
184            if dist > distances[v] {
185                continue; // Already processed with shorter distance
186            }
187
188            let neighbors = graph.neighbors(v as u64);
189            let edge_start = if v == 0 {
190                0
191            } else {
192                graph.row_offsets[v] as usize
193            };
194
195            for (i, &w) in neighbors.iter().enumerate() {
196                let w = w as usize;
197                let weight = weights.get(edge_start + i).copied().unwrap_or(1.0);
198                let new_dist = distances[v] + weight;
199
200                if new_dist < distances[w] {
201                    distances[w] = new_dist;
202                    predecessors[w] = v as i64;
203                    hop_counts[w] = hop_counts[v] + 1;
204                    heap.push(HeapNode {
205                        dist: new_dist,
206                        node: w,
207                    });
208                }
209            }
210        }
211
212        (0..n)
213            .map(|i| ShortestPathResult {
214                node_index: i,
215                distance: distances[i],
216                predecessor: predecessors[i],
217                is_reachable: distances[i].is_finite(),
218                hop_count: hop_counts[i],
219            })
220            .collect()
221    }
222
223    /// Compute all-pairs shortest paths.
224    pub fn compute_apsp(graph: &CsrGraph) -> AllPairsResult {
225        let n = graph.num_nodes;
226        let mut distances = vec![f64::INFINITY; n * n];
227        let mut predecessors = vec![-1i64; n * n];
228
229        // Run SSSP from each node
230        for source in 0..n {
231            let sssp = Self::compute_sssp_bfs(graph, source);
232
233            for result in sssp {
234                let idx = source * n + result.node_index;
235                distances[idx] = result.distance;
236                predecessors[idx] = result.predecessor;
237            }
238        }
239
240        AllPairsResult {
241            node_count: n,
242            distances,
243            predecessors,
244        }
245    }
246
247    /// Reconstruct path from source to target.
248    pub fn reconstruct_path(
249        sssp: &[ShortestPathResult],
250        source: usize,
251        target: usize,
252    ) -> Option<Vec<usize>> {
253        if !sssp[target].is_reachable {
254            return None;
255        }
256
257        let mut path = Vec::new();
258        let mut current = target;
259
260        while current != source {
261            path.push(current);
262            let pred = sssp[current].predecessor;
263            if pred < 0 {
264                return None;
265            }
266            current = pred as usize;
267        }
268
269        path.push(source);
270        path.reverse();
271        Some(path)
272    }
273
274    /// Compute shortest path between two nodes.
275    pub fn compute_path(graph: &CsrGraph, source: usize, target: usize) -> Option<PathResult> {
276        let sssp = Self::compute_sssp_bfs(graph, source);
277
278        if !sssp[target].is_reachable {
279            return None;
280        }
281
282        let node_path = Self::reconstruct_path(&sssp, source, target)?;
283
284        Some(PathResult {
285            source,
286            target,
287            path_length: sssp[target].distance,
288            hop_count: node_path.len() - 1,
289            node_path,
290        })
291    }
292
293    /// Find k shortest paths using Yen's algorithm.
294    pub fn compute_k_shortest(
295        graph: &CsrGraph,
296        source: usize,
297        target: usize,
298        k: usize,
299    ) -> Vec<PathResult> {
300        let mut result_paths = Vec::new();
301
302        // First, find the shortest path
303        if let Some(first_path) = Self::compute_path(graph, source, target) {
304            result_paths.push(first_path);
305        } else {
306            return result_paths;
307        }
308
309        // Candidate paths
310        let mut candidates: Vec<PathResult> = Vec::new();
311
312        for _i in 1..k {
313            let prev_path = &result_paths[result_paths.len() - 1];
314
315            // For each deviation point on the previous path
316            for j in 0..(prev_path.node_path.len() - 1) {
317                let spur_node = prev_path.node_path[j];
318                let root_path: Vec<usize> = prev_path.node_path[..=j].to_vec();
319
320                // Create modified graph (remove edges used by previous paths at this deviation)
321                // For simplicity, we'll use a less efficient but correct approach
322                let edges_to_avoid = Self::collect_edges_to_avoid(&result_paths, &root_path);
323
324                // Find path in modified graph
325                if let Some(spur_path) =
326                    Self::compute_path_avoiding(graph, spur_node, target, &edges_to_avoid)
327                {
328                    let mut total_path = root_path.clone();
329                    total_path.extend(spur_path.node_path.into_iter().skip(1));
330
331                    let path_length = (total_path.len() - 1) as f64;
332                    let candidate = PathResult {
333                        source,
334                        target,
335                        path_length,
336                        hop_count: total_path.len() - 1,
337                        node_path: total_path,
338                    };
339
340                    // Add if not already in candidates or results
341                    if !Self::path_exists(&candidates, &candidate.node_path)
342                        && !Self::path_exists_in_results(&result_paths, &candidate.node_path)
343                    {
344                        candidates.push(candidate);
345                    }
346                }
347            }
348
349            if candidates.is_empty() {
350                break;
351            }
352
353            // Sort candidates by path length and take the best one
354            candidates.sort_by(|a, b| {
355                a.path_length
356                    .partial_cmp(&b.path_length)
357                    .unwrap_or(Ordering::Equal)
358            });
359
360            result_paths.push(candidates.remove(0));
361        }
362
363        result_paths
364    }
365
366    /// Compute path avoiding certain edges.
367    fn compute_path_avoiding(
368        graph: &CsrGraph,
369        source: usize,
370        target: usize,
371        avoid_edges: &[(usize, usize)],
372    ) -> Option<PathResult> {
373        let n = graph.num_nodes;
374        let mut distances = vec![f64::INFINITY; n];
375        let mut predecessors = vec![-1i64; n];
376
377        distances[source] = 0.0;
378
379        let mut queue = VecDeque::new();
380        queue.push_back(source);
381
382        while let Some(v) = queue.pop_front() {
383            if v == target {
384                break;
385            }
386
387            let current_dist = distances[v];
388
389            for &w in graph.neighbors(v as u64) {
390                let w = w as usize;
391
392                // Skip avoided edges
393                if avoid_edges.contains(&(v, w)) {
394                    continue;
395                }
396
397                if distances[w].is_infinite() {
398                    distances[w] = current_dist + 1.0;
399                    predecessors[w] = v as i64;
400                    queue.push_back(w);
401                }
402            }
403        }
404
405        if distances[target].is_infinite() {
406            return None;
407        }
408
409        // Reconstruct path
410        let mut path = Vec::new();
411        let mut current = target;
412
413        while current != source {
414            path.push(current);
415            let pred = predecessors[current];
416            if pred < 0 {
417                return None;
418            }
419            current = pred as usize;
420        }
421
422        path.push(source);
423        path.reverse();
424
425        Some(PathResult {
426            source,
427            target,
428            path_length: distances[target],
429            hop_count: path.len() - 1,
430            node_path: path,
431        })
432    }
433
434    fn collect_edges_to_avoid(
435        result_paths: &[PathResult],
436        root_path: &[usize],
437    ) -> Vec<(usize, usize)> {
438        let mut edges = Vec::new();
439
440        for path in result_paths {
441            // Check if this path shares the root
442            if path.node_path.len() >= root_path.len()
443                && path.node_path[..root_path.len()] == *root_path
444            {
445                // Add the edge right after root_path
446                if path.node_path.len() > root_path.len() {
447                    let from = root_path[root_path.len() - 1];
448                    let to = path.node_path[root_path.len()];
449                    edges.push((from, to));
450                }
451            }
452        }
453
454        edges
455    }
456
457    fn path_exists(candidates: &[PathResult], path: &[usize]) -> bool {
458        candidates.iter().any(|c| c.node_path == path)
459    }
460
461    fn path_exists_in_results(results: &[PathResult], path: &[usize]) -> bool {
462        results.iter().any(|r| r.node_path == path)
463    }
464
465    /// Compute eccentricity for each node (max distance to any other node).
466    #[allow(clippy::needless_range_loop)]
467    pub fn compute_eccentricity(graph: &CsrGraph) -> Vec<f64> {
468        let n = graph.num_nodes;
469        let mut eccentricities = vec![0.0; n];
470
471        for source in 0..n {
472            let sssp = Self::compute_sssp_bfs(graph, source);
473            let max_dist = sssp
474                .iter()
475                .filter(|r| r.is_reachable)
476                .map(|r| r.distance)
477                .fold(0.0, f64::max);
478            eccentricities[source] = max_dist;
479        }
480
481        eccentricities
482    }
483
484    /// Compute graph diameter (max eccentricity).
485    pub fn compute_diameter(graph: &CsrGraph) -> f64 {
486        Self::compute_eccentricity(graph)
487            .into_iter()
488            .fold(0.0, f64::max)
489    }
490
491    /// Compute graph radius (min eccentricity).
492    pub fn compute_radius(graph: &CsrGraph) -> f64 {
493        Self::compute_eccentricity(graph)
494            .into_iter()
495            .filter(|&e| e > 0.0)
496            .fold(f64::INFINITY, f64::min)
497    }
498}
499
500impl GpuKernel for ShortestPath {
501    fn metadata(&self) -> &KernelMetadata {
502        &self.metadata
503    }
504}
505
/// Helper struct for Dijkstra's priority queue.
///
/// `BinaryHeap` is a max-heap, so the ordering on `dist` is reversed to make
/// the smallest tentative distance pop first. Ties are broken on `node`
/// (smaller index first) and `f64::total_cmp` is used. Together these make
/// `Ord` a genuine total order that agrees with `PartialEq`/`Eq`, as the `Ord`
/// contract requires; NaN distances no longer compare as `Equal` to everything.
#[derive(Clone, Debug)]
struct HeapNode {
    dist: f64,
    node: usize,
}

impl PartialEq for HeapNode {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality and ordering can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapNode {}

impl Ord for HeapNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse ordering for min-heap
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}

impl PartialOrd for HeapNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
530
#[cfg(test)]
mod tests {
    use super::*;

    fn create_line_graph() -> CsrGraph {
        // Line: 0 - 1 - 2 - 3
        CsrGraph::from_edges(4, &[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)])
    }

    fn create_complete_graph() -> CsrGraph {
        // Complete graph K4
        CsrGraph::from_edges(
            4,
            &[
                (0, 1),
                (0, 2),
                (0, 3),
                (1, 0),
                (1, 2),
                (1, 3),
                (2, 0),
                (2, 1),
                (2, 3),
                (3, 0),
                (3, 1),
                (3, 2),
            ],
        )
    }

    fn create_disconnected_graph() -> CsrGraph {
        // Two disconnected pairs: 0-1 and 2-3
        CsrGraph::from_edges(4, &[(0, 1), (1, 0), (2, 3), (3, 2)])
    }

    #[test]
    fn test_shortest_path_metadata() {
        let kernel = ShortestPath::new();
        assert_eq!(kernel.metadata().id, "graph/shortest-path");
        assert_eq!(kernel.metadata().domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_sssp_bfs_line() {
        let graph = create_line_graph();
        let sssp = ShortestPath::compute_sssp_bfs(&graph, 0);

        assert_eq!(sssp[0].distance, 0.0);
        assert_eq!(sssp[1].distance, 1.0);
        assert_eq!(sssp[2].distance, 2.0);
        assert_eq!(sssp[3].distance, 3.0);
    }

    #[test]
    fn test_sssp_bfs_complete() {
        let graph = create_complete_graph();
        let sssp = ShortestPath::compute_sssp_bfs(&graph, 0);

        // In complete graph, all nodes are distance 1 from any other
        assert_eq!(sssp[0].distance, 0.0);
        assert_eq!(sssp[1].distance, 1.0);
        assert_eq!(sssp[2].distance, 1.0);
        assert_eq!(sssp[3].distance, 1.0);
    }

    #[test]
    fn test_sssp_disconnected() {
        let graph = create_disconnected_graph();
        let sssp = ShortestPath::compute_sssp_bfs(&graph, 0);

        assert!(sssp[0].is_reachable);
        assert!(sssp[1].is_reachable);
        assert!(!sssp[2].is_reachable);
        assert!(!sssp[3].is_reachable);
    }

    #[test]
    fn test_sssp_dijkstra_uniform_weights() {
        let graph = create_line_graph();
        // Every edge weighs 2.0; surplus entries are never read.
        let weights = vec![2.0; 64];
        let sssp = ShortestPath::compute_sssp_dijkstra(&graph, 0, &weights);

        assert_eq!(sssp[0].distance, 0.0);
        assert_eq!(sssp[3].distance, 6.0);
        assert_eq!(sssp[3].hop_count, 3);
        assert_eq!(sssp[3].predecessor, 2);
    }

    #[test]
    fn test_sssp_dijkstra_missing_weights_default_to_one() {
        let graph = create_line_graph();
        let sssp = ShortestPath::compute_sssp_dijkstra(&graph, 0, &[]);

        assert_eq!(sssp[3].distance, 3.0);
        assert!(sssp[3].is_reachable);
    }

    #[test]
    fn test_reconstruct_path() {
        let graph = create_line_graph();
        let sssp = ShortestPath::compute_sssp_bfs(&graph, 0);

        let path = ShortestPath::reconstruct_path(&sssp, 0, 3);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_compute_path() {
        let graph = create_line_graph();
        let path = ShortestPath::compute_path(&graph, 0, 3);

        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.hop_count, 3);
        assert_eq!(path.node_path, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_compute_path_unreachable() {
        let graph = create_disconnected_graph();

        assert!(ShortestPath::compute_path(&graph, 0, 2).is_none());
        assert!(ShortestPath::compute_k_shortest(&graph, 0, 2, 3).is_empty());
    }

    #[test]
    fn test_apsp() {
        let graph = create_line_graph();
        let apsp = ShortestPath::compute_apsp(&graph);

        assert_eq!(apsp.distance(0, 3), 3.0);
        assert_eq!(apsp.distance(1, 2), 1.0);
        assert_eq!(apsp.distance(0, 0), 0.0);
    }

    #[test]
    fn test_apsp_reconstruct_path() {
        let graph = create_line_graph();
        let apsp = ShortestPath::compute_apsp(&graph);

        assert_eq!(apsp.reconstruct_path(0, 3), Some(vec![0, 1, 2, 3]));
        assert_eq!(apsp.reconstruct_path(3, 0), Some(vec![3, 2, 1, 0]));

        let disconnected = ShortestPath::compute_apsp(&create_disconnected_graph());
        assert_eq!(disconnected.reconstruct_path(0, 2), None);
    }

    #[test]
    fn test_diameter() {
        let graph = create_line_graph();
        let diameter = ShortestPath::compute_diameter(&graph);

        assert_eq!(diameter, 3.0);
    }

    #[test]
    fn test_eccentricity_and_radius() {
        let graph = create_line_graph();

        assert_eq!(
            ShortestPath::compute_eccentricity(&graph),
            vec![3.0, 2.0, 2.0, 3.0]
        );
        assert_eq!(ShortestPath::compute_radius(&graph), 2.0);
    }

    #[test]
    fn test_k_shortest() {
        let graph = create_complete_graph();
        let paths = ShortestPath::compute_k_shortest(&graph, 0, 3, 3);

        assert!(!paths.is_empty());
        assert_eq!(paths[0].hop_count, 1); // Direct edge
    }

    #[test]
    fn test_k_shortest_paths_are_distinct_and_ordered() {
        let graph = create_complete_graph();
        let paths = ShortestPath::compute_k_shortest(&graph, 0, 3, 3);

        assert_eq!(paths.len(), 3);
        for (i, path) in paths.iter().enumerate() {
            assert_eq!(path.node_path.first(), Some(&0));
            assert_eq!(path.node_path.last(), Some(&3));
            assert_eq!(path.hop_count, path.node_path.len() - 1);
            for other in &paths[i + 1..] {
                assert_ne!(path.node_path, other.node_path);
                assert!(path.path_length <= other.path_length);
            }
        }
    }
}