1use std::cmp::Ordering;
35use std::collections::{BinaryHeap, HashMap, HashSet};
36
/// One loopless path produced by [`yen_k_shortest`].
///
/// `PartialEq` is derived so callers can compare results directly; `Eq` is not
/// possible because `cost` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct YenPath<N> {
    /// Nodes visited in order, from the source through the destination inclusive.
    pub nodes: Vec<N>,
    /// Total of the edge costs along `nodes`.
    pub cost: f64,
}
45
/// Priority-queue entry for `dijkstra_path`: a node with its tentative cost.
///
/// The ordering is *reversed* on `cost`, so `BinaryHeap` (a max-heap) pops
/// the cheapest entry first. Costs are compared with `f64::total_cmp`, which
/// is a genuine total order. A NaN cost therefore gets a fixed position
/// instead of comparing "equal" to everything.
///
/// Equality is defined through the same comparison. This keeps `PartialEq`,
/// `Eq`, `PartialOrd` and `Ord` mutually consistent, as their contracts
/// require. `node` takes no part in either.
#[derive(Clone)]
struct DijkEntry<N> {
    node: N,
    cost: f64,
}

impl<N> PartialEq for DijkEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for DijkEntry<N> {}

impl<N> PartialOrd for DijkEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for DijkEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped: a lower cost ranks "greater", so it is popped first.
        other.cost.total_cmp(&self.cost)
    }
}
74
75fn dijkstra_path<N, FN, I>(
78 src: N,
79 dest: N,
80 neighbors: &mut FN,
81 excluded_edges: &HashSet<(N, N)>,
82 excluded_nodes: &HashSet<N>,
83) -> Option<YenPath<N>>
84where
85 N: Clone + Eq + std::hash::Hash,
86 FN: FnMut(&N) -> I,
87 I: IntoIterator<Item = (N, f64)>,
88{
89 let mut g_scores: HashMap<N, f64> = HashMap::new();
90 let mut came_from: HashMap<N, N> = HashMap::new();
91 let mut closed: HashSet<N> = HashSet::new();
92 let mut heap: BinaryHeap<DijkEntry<N>> = BinaryHeap::new();
93
94 g_scores.insert(src.clone(), 0.0);
95 heap.push(DijkEntry {
96 node: src.clone(),
97 cost: 0.0,
98 });
99
100 while let Some(current) = heap.pop() {
101 if current.node == dest {
102 let mut path = Vec::new();
104 let mut cur = dest.clone();
105 loop {
106 path.push(cur.clone());
107 match came_from.get(&cur) {
108 Some(prev) => cur = prev.clone(),
109 None => break,
110 }
111 }
112 path.reverse();
113 return Some(YenPath {
114 nodes: path,
115 cost: current.cost,
116 });
117 }
118
119 if !closed.insert(current.node.clone()) {
120 continue;
121 }
122
123 let g = current.cost;
124 for (nbr, edge_cost) in neighbors(¤t.node) {
125 if closed.contains(&nbr) {
126 continue;
127 }
128 if excluded_nodes.contains(&nbr) {
129 continue;
130 }
131 if excluded_edges.contains(&(current.node.clone(), nbr.clone())) {
132 continue;
133 }
134 let tentative = g + edge_cost;
135 let prev = g_scores.get(&nbr).copied().unwrap_or(f64::INFINITY);
136 if tentative < prev {
137 g_scores.insert(nbr.clone(), tentative);
138 came_from.insert(nbr.clone(), current.node.clone());
139 heap.push(DijkEntry {
140 node: nbr,
141 cost: tentative,
142 });
143 }
144 }
145 }
146
147 None
148}
149
150pub fn yen_k_shortest<N, FN, I>(src: N, dest: N, k: usize, mut neighbors: FN) -> Vec<YenPath<N>>
164where
165 N: Clone + Eq + std::hash::Hash,
166 FN: FnMut(&N) -> I,
167 I: IntoIterator<Item = (N, f64)>,
168{
169 if k == 0 {
170 return Vec::new();
171 }
172
173 let first = dijkstra_path(
175 src.clone(),
176 dest.clone(),
177 &mut neighbors,
178 &HashSet::new(),
179 &HashSet::new(),
180 );
181 let Some(first) = first else {
182 return Vec::new();
183 };
184
185 let mut accepted: Vec<YenPath<N>> = vec![first];
186
187 let mut candidates: BinaryHeap<CandidateEntry<N>> = BinaryHeap::new();
190 let mut candidate_set: HashSet<Vec<N>> = HashSet::new();
191
192 for ki in 1..k {
193 let prev_path = &accepted[ki - 1].nodes;
194
195 for spur_idx in 0..prev_path.len().saturating_sub(1) {
197 let spur_node = prev_path[spur_idx].clone();
198 let root_path: Vec<N> = prev_path[..=spur_idx].to_vec();
199
200 let mut excluded_edges: HashSet<(N, N)> = HashSet::new();
202 for accepted_path in &accepted {
203 if accepted_path.nodes.len() > spur_idx
204 && accepted_path.nodes[..=spur_idx] == root_path[..]
205 {
206 excluded_edges.insert((
207 accepted_path.nodes[spur_idx].clone(),
208 accepted_path.nodes[spur_idx + 1].clone(),
209 ));
210 }
211 }
212
213 let mut excluded_nodes: HashSet<N> = HashSet::new();
215 for node in &root_path[..spur_idx] {
216 excluded_nodes.insert(node.clone());
217 }
218
219 if let Some(spur_path) = dijkstra_path(
220 spur_node,
221 dest.clone(),
222 &mut neighbors,
223 &excluded_edges,
224 &excluded_nodes,
225 ) {
226 let mut full_nodes = root_path.clone();
228 full_nodes.extend_from_slice(&spur_path.nodes[1..]);
229
230 let mut seen = HashSet::new();
232 if full_nodes.iter().any(|n| !seen.insert(n.clone())) {
233 continue;
234 }
235
236 if !candidate_set.contains(&full_nodes) {
237 let cost = path_cost(&full_nodes, &mut neighbors);
239 candidate_set.insert(full_nodes.clone());
240 candidates.push(CandidateEntry {
241 cost,
242 nodes: full_nodes,
243 });
244 }
245 }
246 }
247
248 if let Some(best) = candidates.pop() {
250 accepted.push(YenPath {
251 nodes: best.nodes,
252 cost: best.cost,
253 });
254 } else {
255 break;
256 }
257 }
258
259 accepted
260}
261
/// Sums the edge costs along `nodes`, querying `neighbors` once per hop.
///
/// When `neighbors` reports several parallel edges between the same pair of
/// nodes, the cheapest is used. That is the edge a Dijkstra relaxation would
/// settle on. A hop with no matching edge makes the total `f64::INFINITY`
/// rather than silently counting as free.
///
/// An empty or single-node slice costs `0.0`.
fn path_cost<N, FN, I>(nodes: &[N], neighbors: &mut FN) -> f64
where
    N: Clone + Eq + std::hash::Hash,
    FN: FnMut(&N) -> I,
    I: IntoIterator<Item = (N, f64)>,
{
    nodes
        .windows(2)
        .map(|hop| {
            let (from, to) = (&hop[0], &hop[1]);
            // Cheapest of any parallel `from -> to` edges; INFINITY if none exist.
            neighbors(from)
                .into_iter()
                .filter(|(n, _)| n == to)
                .map(|(_, c)| c)
                .fold(f64::INFINITY, f64::min)
        })
        // Explicit +0.0 seed so a hop-less path costs exactly 0.0.
        .fold(0.0, |total, c| total + c)
}
284
/// A candidate path waiting in Yen's candidate heap.
///
/// The ordering is *reversed* on `cost`, so the max-heap `BinaryHeap` yields
/// the cheapest candidate first. Costs are compared with `f64::total_cmp`, a
/// true total order. Equality is derived from that same comparison, so
/// `PartialEq`, `Eq`, `PartialOrd` and `Ord` agree with one another. `nodes`
/// does not participate; duplicate paths are filtered separately by the caller.
struct CandidateEntry<N> {
    cost: f64,
    nodes: Vec<N>,
}

impl<N> PartialEq for CandidateEntry<N> {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<N> Eq for CandidateEntry<N> {}

impl<N> PartialOrd for CandidateEntry<N> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<N> Ord for CandidateEntry<N> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped: a lower cost ranks "greater", so it is popped first.
        other.cost.total_cmp(&self.cost)
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Two equal-cost routes 0 -> 3: via node 1 (1 + 2) and via node 2 (2 + 1).
    fn diamond_neighbors(node: &usize) -> Vec<(usize, f64)> {
        match *node {
            0 => vec![(1, 1.0), (2, 2.0)],
            1 => vec![(3, 2.0)],
            2 => vec![(3, 1.0)],
            _ => vec![],
        }
    }

    #[test]
    fn finds_two_paths_on_diamond() {
        let paths = yen_k_shortest(0, 3, 3, diamond_neighbors);
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0].nodes, vec![0, 1, 3]);
        assert!((paths[0].cost - 3.0).abs() < 1e-6);
        assert_eq!(paths[1].nodes, vec![0, 2, 3]);
        assert!((paths[1].cost - 3.0).abs() < 1e-6);
    }

    #[test]
    fn k_one_returns_only_shortest() {
        let paths = yen_k_shortest(0, 3, 1, diamond_neighbors);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec![0, 1, 3]);
    }

    #[test]
    fn src_equals_dest_yields_trivial_path() {
        let paths = yen_k_shortest(0, 0, 3, diamond_neighbors);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec![0]);
        assert!(paths[0].cost.abs() < 1e-6);
    }

    #[test]
    fn single_path_on_line() {
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0)],
                1 => vec![(2, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 2, 5, neighbors);
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].nodes, vec![0, 1, 2]);
        assert!((paths[0].cost - 2.0).abs() < 1e-6);
    }

    #[test]
    fn no_path_returns_empty() {
        let neighbors = |_: &usize| -> Vec<(usize, f64)> { vec![] };
        let paths = yen_k_shortest(0, 5, 3, neighbors);
        assert!(paths.is_empty());
    }

    #[test]
    fn k_zero_returns_empty() {
        let paths = yen_k_shortest(0, 3, 0, diamond_neighbors);
        assert!(paths.is_empty());
    }

    #[test]
    fn paths_are_loopless() {
        // Node 1 has an edge back to the source, offering a tempting cycle.
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0), (2, 3.0)],
                1 => vec![(0, 1.0), (2, 1.0)],
                2 => vec![(3, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 3, 5, neighbors);
        // Only two loopless routes exist: 0-1-2-3 (cost 3) and 0-2-3 (cost 4).
        // Checking the count keeps the loop below from passing vacuously.
        assert_eq!(paths.len(), 2);
        for path in &paths {
            let mut seen = HashSet::new();
            assert!(
                path.nodes.iter().all(|n| seen.insert(n)),
                "Path contains loop: {:?}",
                path.nodes
            );
        }
    }

    #[test]
    fn paths_sorted_by_cost() {
        let neighbors = |node: &usize| -> Vec<(usize, f64)> {
            match *node {
                0 => vec![(1, 1.0), (2, 2.0), (3, 5.0)],
                1 => vec![(4, 1.0)],
                2 => vec![(4, 1.0)],
                3 => vec![(4, 1.0)],
                _ => vec![],
            }
        };
        let paths = yen_k_shortest(0, 4, 5, neighbors);
        // Exactly three routes exist, via nodes 1, 2 and 3 (costs 2, 3, 6).
        assert_eq!(paths.len(), 3);
        assert_eq!(paths[0].nodes, vec![0, 1, 4]);
        assert_eq!(paths[1].nodes, vec![0, 2, 4]);
        assert_eq!(paths[2].nodes, vec![0, 3, 4]);
        for i in 1..paths.len() {
            assert!(
                paths[i].cost >= paths[i - 1].cost - 1e-12,
                "Paths not sorted: cost[{}]={} < cost[{}]={}",
                i,
                paths[i].cost,
                i - 1,
                paths[i - 1].cost
            );
        }
    }
}