Skip to main content

uni_algo/algo/algorithms/
k_shortest_paths.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! K-Shortest Paths Algorithm (Yen's Algorithm).
5//!
6//! Finds the K shortest loop-less paths from source to target.
7
8use crate::algo::GraphProjection;
9use crate::algo::algorithms::Algorithm;
10use std::cmp::Reverse;
11use std::collections::{BinaryHeap, HashSet};
12use uni_common::core::id::Vid;
13
/// Marker type that implements [`Algorithm`] for Yen's K-shortest loop-less paths.
///
/// The type carries no state; all inputs are supplied through
/// [`KShortestPathsConfig`] when the algorithm is run.
#[derive(Debug, Clone, Copy, Default)]
pub struct KShortestPaths;
15
16#[derive(Debug, Clone)]
17pub struct KShortestPathsConfig {
18    pub source: Vid,
19    pub target: Vid,
20    pub k: usize,
21}
22
23impl Default for KShortestPathsConfig {
24    fn default() -> Self {
25        Self {
26            source: Vid::from(0),
27            target: Vid::from(0),
28            k: 1,
29        }
30    }
31}
32
33pub struct KShortestPathsResult {
34    pub paths: Vec<(Vec<Vid>, f64)>, // (path, cost)
35}
36
37impl Algorithm for KShortestPaths {
38    type Config = KShortestPathsConfig;
39    type Result = KShortestPathsResult;
40
41    fn name() -> &'static str {
42        "k_shortest_paths"
43    }
44
45    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
46        let source_slot = match graph.to_slot(config.source) {
47            Some(s) => s,
48            None => return KShortestPathsResult { paths: Vec::new() },
49        };
50        let target_slot = match graph.to_slot(config.target) {
51            Some(s) => s,
52            None => return KShortestPathsResult { paths: Vec::new() },
53        };
54
55        if config.k == 0 {
56            return KShortestPathsResult { paths: Vec::new() };
57        }
58
59        let mut a: Vec<(Vec<u32>, f64)> = Vec::new();
60
61        // 1. First shortest path
62        let (path0, cost0) = match run_dijkstra(graph, source_slot, target_slot, &HashSet::new()) {
63            Some(res) => res,
64            None => return KShortestPathsResult { paths: Vec::new() },
65        };
66        a.push((path0, cost0));
67
68        let mut b: BinaryHeap<Reverse<(u64, Vec<u32>)>> = BinaryHeap::new();
69
70        // 2. Iterate for k
71        for k in 1..config.k {
72            let prev_path = &a[k - 1].0;
73
74            // The spur node ranges from the first node to the next-to-last node in the previous k-shortest path.
75            for i in 0..prev_path.len() - 1 {
76                let spur_node = prev_path[i];
77                let root_path = &prev_path[0..=i];
78                let root_path_cost = calculate_path_cost(graph, root_path);
79
80                let mut forbidden_edges = HashSet::new();
81
82                for (p_path, _) in &a {
83                    if p_path.len() > i && &p_path[0..=i] == root_path {
84                        forbidden_edges.insert((p_path[i], p_path[i + 1]));
85                    }
86                }
87
88                // Remove root path nodes from graph (except spur node) to ensure loopless
89                // We simulate this by checking if neighbor is in root_path (excluding spur)
90                // Actually Yen's usually ensures loopless.
91                // Standard implementation: disable nodes in root path.
92
93                // Run Dijkstra from spur node to target
94                if let Some((spur_path, spur_cost)) = run_dijkstra_with_constraints(
95                    graph,
96                    spur_node,
97                    target_slot,
98                    &forbidden_edges,
99                    root_path, // forbidden nodes are root_path[0..i] (excluding spur which is root_path[i])
100                ) {
101                    let mut total_path = root_path[0..i].to_vec();
102                    total_path.extend(spur_path);
103                    let total_cost = root_path_cost + spur_cost; // root_path excludes last edge to spur? No, root_path includes spur.
104                    // Wait, root_path includes spur. path cost calculation logic needs to be precise.
105
106                    // Logic check:
107                    // root_path = [s, ..., spur]
108                    // spur_path = [spur, ..., t]
109                    // total = [s, ..., spur, ..., t]
110                    // cost = cost(s..spur) + cost(spur..t)
111
112                    // Using bits for f64 ordering in heap
113                    let entry = Reverse((total_cost.to_bits(), total_path));
114
115                    // Ideally verify uniqueness before pushing to heap, but B handles sorting.
116                    // Duplicate paths might be generated.
117                    // We should check if path is already in B?
118                    // BinaryHeap doesn't support contains.
119                    // Typically B is a set or we push and then dedup when popping.
120
121                    // Optimization: check if already in B? Too slow.
122                    // Just push.
123                    b.push(entry);
124                }
125            }
126
127            if b.is_empty() {
128                break;
129            }
130
131            // Extract best from B
132            // Need to handle duplicates
133            let mut best_path = None;
134
135            while let Some(Reverse((cost_bits, path))) = b.pop() {
136                let cost = f64::from_bits(cost_bits);
137                // Check if path is already in A
138                let exists = a.iter().any(|(p, _)| p == &path);
139                if !exists {
140                    best_path = Some((path, cost));
141                    break;
142                }
143            }
144
145            if let Some(bp) = best_path {
146                a.push(bp);
147            } else {
148                break;
149            }
150        }
151
152        let mapped_paths = a
153            .into_iter()
154            .map(|(path, cost)| {
155                let vids = path.iter().map(|&slot| graph.to_vid(slot)).collect();
156                (vids, cost)
157            })
158            .collect();
159
160        KShortestPathsResult {
161            paths: mapped_paths,
162        }
163    }
164}
165
166fn calculate_path_cost(graph: &GraphProjection, path: &[u32]) -> f64 {
167    let mut cost = 0.0;
168    for i in 0..path.len() - 1 {
169        let u = path[i];
170        let v = path[i + 1];
171        // Find edge weight
172        // Linear scan of neighbors
173        let neighbors = graph.out_neighbors(u);
174        let mut weight = 1.0;
175        if graph.has_weights() {
176            for (idx, &n) in neighbors.iter().enumerate() {
177                if n == v {
178                    weight = graph.out_weight(u, idx);
179                    break;
180                }
181            }
182        }
183        cost += weight;
184    }
185    cost
186}
187
188fn run_dijkstra(
189    graph: &GraphProjection,
190    source: u32,
191    target: u32,
192    forbidden_edges: &HashSet<(u32, u32)>,
193) -> Option<(Vec<u32>, f64)> {
194    run_dijkstra_with_constraints(graph, source, target, forbidden_edges, &[])
195}
196
197fn run_dijkstra_with_constraints(
198    graph: &GraphProjection,
199    source: u32,
200    target: u32,
201    forbidden_edges: &HashSet<(u32, u32)>,
202    forbidden_nodes: &[u32],
203) -> Option<(Vec<u32>, f64)> {
204    let n = graph.vertex_count();
205    let mut dist = vec![f64::INFINITY; n];
206    let mut prev = vec![None; n];
207    let mut heap = BinaryHeap::new();
208
209    dist[source as usize] = 0.0;
210    heap.push(Reverse((0.0f64.to_bits(), source)));
211
212    let forbidden_nodes_set: HashSet<u32> = forbidden_nodes.iter().cloned().collect();
213
214    while let Some(Reverse((d_bits, u))) = heap.pop() {
215        let d = f64::from_bits(d_bits);
216        if d > dist[u as usize] {
217            continue;
218        }
219        if u == target {
220            break;
221        }
222
223        for (i, &v) in graph.out_neighbors(u).iter().enumerate() {
224            if forbidden_nodes_set.contains(&v) {
225                continue;
226            }
227            if forbidden_edges.contains(&(u, v)) {
228                continue;
229            }
230
231            let weight = if graph.has_weights() {
232                graph.out_weight(u, i)
233            } else {
234                1.0
235            };
236            let new_dist = d + weight;
237
238            if new_dist < dist[v as usize] {
239                dist[v as usize] = new_dist;
240                prev[v as usize] = Some(u);
241                heap.push(Reverse((new_dist.to_bits(), v)));
242            }
243        }
244    }
245
246    if dist[target as usize] == f64::INFINITY {
247        return None;
248    }
249
250    let mut path = Vec::new();
251    let mut curr = Some(target);
252    while let Some(slot) = curr {
253        path.push(slot);
254        if slot == source {
255            break;
256        }
257        curr = prev[slot as usize];
258    }
259    path.reverse();
260    Some((path, dist[target as usize]))
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    /// Diamond graph with two equally long routes from 0 to 3:
    ///
    /// ```text
    /// 0 -> 1 -> 3
    /// 0 -> 2 -> 3
    /// ```
    ///
    /// The test expects each route to cost 2.0, i.e. one unit per hop
    /// (presumably `build_test_graph` produces an unweighted projection).
    #[test]
    fn test_ksp_simple() {
        let nodes = vec![Vid::from(0), Vid::from(1), Vid::from(2), Vid::from(3)];
        let links = vec![
            (Vid::from(0), Vid::from(1)),
            (Vid::from(1), Vid::from(3)),
            (Vid::from(0), Vid::from(2)),
            (Vid::from(2), Vid::from(3)),
        ];
        let graph = build_test_graph(nodes, links);

        let result = KShortestPaths::run(
            &graph,
            KShortestPathsConfig {
                source: Vid::from(0),
                target: Vid::from(3),
                k: 2,
            },
        );

        // Exactly two paths, each two hops long.
        let costs: Vec<f64> = result.paths.iter().map(|(_, cost)| *cost).collect();
        assert_eq!(costs, vec![2.0, 2.0]);
    }
}