scirs2_optimize/combinatorial/
tsp.rs1use scirs2_core::ndarray::Array2;
7use std::cmp::Ordering;
8use std::collections::BinaryHeap;
9
10use crate::error::OptimizeError;
11
12pub type TspResult<T> = Result<T, OptimizeError>;
14
/// Priority-queue entry: a candidate edge of weight `cost` leading to `vertex`.
///
/// The comparison on `cost` is inverted so that `BinaryHeap`, which is a
/// max-heap, yields the entry with the smallest cost first. Entries with equal
/// (or incomparable, i.e. NaN) costs are ordered by `vertex`.
#[derive(Clone, PartialEq)]
struct PrimEntry {
    cost: f64,
    vertex: usize,
}

// `f64` is only `PartialEq`; `Eq` is asserted by hand so the type can be `Ord`.
impl Eq for PrimEntry {}

impl PartialOrd for PrimEntry {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}

impl Ord for PrimEntry {
    fn cmp(&self, rhs: &Self) -> Ordering {
        // Operands are swapped on purpose: a lower cost must rank higher.
        match rhs.cost.partial_cmp(&self.cost) {
            Some(Ordering::Less) => Ordering::Less,
            Some(Ordering::Greater) => Ordering::Greater,
            // Equal costs, or a NaN on either side: fall back to the vertex id.
            Some(Ordering::Equal) | None => self.vertex.cmp(&rhs.vertex),
        }
    }
}
41
42pub fn tour_length(tour: &[usize], dist: &Array2<f64>) -> f64 {
49 let n = tour.len();
50 if n == 0 {
51 return 0.0;
52 }
53 let mut total = 0.0;
54 for i in 0..n {
55 let from = tour[i];
56 let to = tour[(i + 1) % n];
57 total += dist[[from, to]];
58 }
59 total
60}
61
62pub fn nearest_neighbor_heuristic(
67 dist: &Array2<f64>,
68 start: usize,
69) -> TspResult<(Vec<usize>, f64)> {
70 let n = dist.nrows();
71 if n == 0 {
72 return Ok((vec![], 0.0));
73 }
74 if start >= n {
75 return Err(OptimizeError::InvalidInput(format!(
76 "start index {start} out of range for {n} cities"
77 )));
78 }
79
80 let mut visited = vec![false; n];
81 let mut tour = Vec::with_capacity(n);
82 let mut current = start;
83 visited[current] = true;
84 tour.push(current);
85
86 for _ in 1..n {
87 let mut best_next = None;
88 let mut best_dist = f64::INFINITY;
89 for j in 0..n {
90 if !visited[j] {
91 let d = dist[[current, j]];
92 if d < best_dist {
93 best_dist = d;
94 best_next = Some(j);
95 }
96 }
97 }
98 match best_next {
99 Some(next) => {
100 visited[next] = true;
101 tour.push(next);
102 current = next;
103 }
104 None => break,
105 }
106 }
107
108 let length = tour_length(&tour, dist);
109 Ok((tour, length))
110}
111
112pub fn two_opt(tour: &mut Vec<usize>, dist: &Array2<f64>) -> f64 {
117 let n = tour.len();
118 if n < 4 {
119 return tour_length(tour, dist);
120 }
121
122 let mut improved = true;
123 while improved {
124 improved = false;
125 for i in 0..n - 1 {
126 for j in i + 2..n {
127 if i == 0 && j == n - 1 {
129 continue;
130 }
131 let a = tour[i];
132 let b = tour[i + 1];
133 let c = tour[j];
134 let d = tour[(j + 1) % n];
135 let current_cost = dist[[a, b]] + dist[[c, d]];
136 let new_cost = dist[[a, c]] + dist[[b, d]];
137 if new_cost < current_cost - 1e-10 {
138 tour[i + 1..=j].reverse();
140 improved = true;
141 }
142 }
143 }
144 }
145
146 tour_length(tour, dist)
147}
148
149pub fn three_opt_move(
158 dist: &Array2<f64>,
159 i: usize,
160 j: usize,
161 k: usize,
162 tour: &[usize],
163) -> Option<Vec<usize>> {
164 let n = tour.len();
165 if n < 6 {
166 return None;
167 }
168 if !(i < j && j < k && k < n) {
170 return None;
171 }
172
173 let a = tour[i];
174 let b = tour[i + 1];
175 let c = tour[j];
176 let d = tour[j + 1];
177 let e = tour[k];
178 let f = tour[(k + 1) % n];
179
180 let d0 = dist[[a, b]] + dist[[c, d]] + dist[[e, f]];
181
182 let candidates: [(f64, u8); 7] = [
190 (dist[[a, c]] + dist[[b, d]] + dist[[e, f]], 1),
192 (dist[[a, b]] + dist[[c, e]] + dist[[d, f]], 2),
194 (dist[[a, c]] + dist[[b, e]] + dist[[d, f]], 3),
196 (dist[[a, d]] + dist[[e, b]] + dist[[c, f]], 4),
198 (dist[[a, d]] + dist[[e, c]] + dist[[b, f]], 5),
200 (dist[[a, e]] + dist[[d, b]] + dist[[c, f]], 6),
202 (dist[[a, e]] + dist[[d, c]] + dist[[b, f]], 7),
204 ];
205
206 let best = candidates
207 .iter()
208 .min_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(Ordering::Equal));
209
210 let (best_cost, reconnect_type) = match best {
211 Some(&(c, t)) => (c, t),
212 None => return None,
213 };
214
215 if best_cost >= d0 - 1e-10 {
216 return None;
217 }
218
219 let seg1: Vec<usize> = tour[..=i].to_vec();
221 let seg2: Vec<usize> = tour[i + 1..=j].to_vec();
222 let seg3: Vec<usize> = tour[j + 1..=k].to_vec();
223 let seg4: Vec<usize> = if k + 1 < n {
224 tour[k + 1..].to_vec()
225 } else {
226 vec![]
227 };
228
229 let mut new_tour = seg1;
230 match reconnect_type {
231 1 => {
232 new_tour.extend(seg2.iter().rev());
233 new_tour.extend_from_slice(&seg3);
234 }
235 2 => {
236 new_tour.extend_from_slice(&seg2);
237 new_tour.extend(seg3.iter().rev());
238 }
239 3 => {
240 new_tour.extend(seg2.iter().rev());
241 new_tour.extend(seg3.iter().rev());
242 }
243 4 => {
244 new_tour.extend_from_slice(&seg3);
245 new_tour.extend_from_slice(&seg2);
246 }
247 5 => {
248 new_tour.extend_from_slice(&seg3);
249 new_tour.extend(seg2.iter().rev());
250 }
251 6 => {
252 new_tour.extend(seg3.iter().rev());
253 new_tour.extend_from_slice(&seg2);
254 }
255 7 => {
256 new_tour.extend(seg3.iter().rev());
257 new_tour.extend(seg2.iter().rev());
258 }
259 _ => unreachable!(),
260 }
261 new_tour.extend_from_slice(&seg4);
262 Some(new_tour)
263}
264
265pub fn or_opt(tour: &mut Vec<usize>, dist: &Array2<f64>) -> f64 {
271 let n = tour.len();
272 if n < 4 {
273 return tour_length(tour, dist);
274 }
275
276 let mut improved = true;
277 while improved {
278 improved = false;
279 for seg_len in 1..=3_usize {
280 if n < seg_len + 2 {
281 continue;
282 }
283 'outer: for seg_start in 0..n {
284 let seg_end = (seg_start + seg_len - 1) % n;
285 let prev = if seg_start == 0 { n - 1 } else { seg_start - 1 };
287 let after = (seg_end + 1) % n;
288 if prev == seg_end || after == seg_start {
290 continue;
291 }
292
293 let first_city = tour[seg_start];
295 let last_city = tour[seg_end];
296 let prev_city = tour[prev];
297 let after_city = tour[after];
298
299 let remove_cost = dist[[prev_city, first_city]] + dist[[last_city, after_city]]
300 - dist[[prev_city, after_city]];
301
302 let mut best_gain = 1e-10; let mut best_ins = None;
305 let mut best_reverse = false;
306
307 for ins in 0..n {
308 let in_seg = if seg_start <= seg_end {
310 ins >= seg_start && ins <= seg_end
311 } else {
312 ins >= seg_start || ins <= seg_end
313 };
314 if in_seg || ins == prev {
315 continue;
316 }
317 let ins_next = (ins + 1) % n;
318 let ins_city = tour[ins];
319 let ins_next_city = tour[ins_next];
320
321 let fwd = dist[[ins_city, first_city]] + dist[[last_city, ins_next_city]]
323 - dist[[ins_city, ins_next_city]];
324 let gain_fwd = remove_cost - fwd;
325 if gain_fwd > best_gain {
326 best_gain = gain_fwd;
327 best_ins = Some(ins);
328 best_reverse = false;
329 }
330
331 if seg_len > 1 {
333 let rev = dist[[ins_city, last_city]] + dist[[first_city, ins_next_city]]
334 - dist[[ins_city, ins_next_city]];
335 let gain_rev = remove_cost - rev;
336 if gain_rev > best_gain {
337 best_gain = gain_rev;
338 best_ins = Some(ins);
339 best_reverse = true;
340 }
341 }
342 }
343
344 if let Some(ins) = best_ins {
345 let segment: Vec<usize> =
347 (0..seg_len).map(|k| tour[(seg_start + k) % n]).collect();
348 let seg_set: std::collections::HashSet<usize> =
349 segment.iter().cloned().collect();
350
351 let remaining: Vec<usize> = tour
353 .iter()
354 .cloned()
355 .filter(|v| !seg_set.contains(v))
356 .collect();
357
358 let ins_city = tour[ins];
360 let ins_pos = remaining.iter().position(|&v| v == ins_city).unwrap_or(0);
361
362 let mut new_tour: Vec<usize> = Vec::with_capacity(n);
363 new_tour.extend_from_slice(&remaining[..=ins_pos]);
364 if best_reverse {
365 new_tour.extend(segment.iter().rev());
366 } else {
367 new_tour.extend_from_slice(&segment);
368 }
369 if ins_pos + 1 < remaining.len() {
370 new_tour.extend_from_slice(&remaining[ins_pos + 1..]);
371 }
372
373 if new_tour.len() == n {
374 *tour = new_tour;
375 improved = true;
376 break 'outer;
377 }
378 }
379 }
380 }
381 }
382
383 tour_length(tour, dist)
384}
385
386pub fn mst_lower_bound(dist: &Array2<f64>) -> f64 {
390 let n = dist.nrows();
391 if n == 0 {
392 return 0.0;
393 }
394 if n == 1 {
395 return 0.0;
396 }
397
398 let mut in_mst = vec![false; n];
399 let mut min_edge = vec![f64::INFINITY; n];
400 min_edge[0] = 0.0;
401
402 let mut heap: BinaryHeap<PrimEntry> = BinaryHeap::new();
403 heap.push(PrimEntry {
404 cost: 0.0,
405 vertex: 0,
406 });
407
408 let mut mst_weight = 0.0;
409
410 while let Some(PrimEntry { cost, vertex }) = heap.pop() {
411 if in_mst[vertex] {
412 continue;
413 }
414 in_mst[vertex] = true;
415 mst_weight += cost;
416
417 for j in 0..n {
418 if !in_mst[j] {
419 let d = dist[[vertex, j]];
420 if d < min_edge[j] {
421 min_edge[j] = d;
422 heap.push(PrimEntry { cost: d, vertex: j });
423 }
424 }
425 }
426 }
427
428 mst_weight
429}
430
431pub struct TspSolver {
433 dist: Array2<f64>,
434}
435
436impl TspSolver {
437 pub fn new(dist: Array2<f64>) -> TspResult<Self> {
442 if dist.nrows() != dist.ncols() {
443 return Err(OptimizeError::InvalidInput(
444 "Distance matrix must be square".to_string(),
445 ));
446 }
447 Ok(Self { dist })
448 }
449
450 pub fn solve(&self) -> TspResult<(Vec<usize>, f64)> {
454 let n = self.dist.nrows();
455 if n == 0 {
456 return Ok((vec![], 0.0));
457 }
458
459 let mut best_tour = vec![];
460 let mut best_len = f64::INFINITY;
461
462 for start in 0..n {
463 let (mut tour, _) = nearest_neighbor_heuristic(&self.dist, start)?;
464 two_opt(&mut tour, &self.dist);
465 or_opt(&mut tour, &self.dist);
466 let len = tour_length(&tour, &self.dist);
467 if len < best_len {
468 best_len = len;
469 best_tour = tour;
470 }
471 }
472
473 Ok((best_tour, best_len))
474 }
475
476 pub fn lower_bound(&self) -> f64 {
478 mst_lower_bound(&self.dist)
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Four cities on the corners of a unit square (diagonal ~ 1.414).
    /// The optimal tour walks the perimeter and has length 4.
    fn square_dist() -> Array2<f64> {
        array![
            [0.0, 1.0, 1.414, 1.0],
            [1.0, 0.0, 1.0, 1.414],
            [1.414, 1.0, 0.0, 1.0],
            [1.0, 1.414, 1.0, 0.0]
        ]
    }

    /// `n` cities on a line at integer coordinates: `dist[[r, c]] = |r - c|`.
    fn line_dist(n: usize) -> Array2<f64> {
        let mut dist = Array2::<f64>::zeros((n, n));
        for r in 0..n {
            for c in 0..n {
                if r != c {
                    let dx = (r as f64) - (c as f64);
                    dist[[r, c]] = dx.abs();
                }
            }
        }
        dist
    }

    /// True when `tour` visits each of `0..n` exactly once.
    fn is_permutation(tour: &[usize], n: usize) -> bool {
        let mut sorted = tour.to_vec();
        sorted.sort_unstable();
        sorted == (0..n).collect::<Vec<usize>>()
    }

    #[test]
    fn test_tour_length() {
        let dist = square_dist();
        let tour = vec![0, 1, 2, 3];
        let len = tour_length(&tour, &dist);
        assert!((len - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_nearest_neighbor() {
        let dist = square_dist();
        let (tour, len) = nearest_neighbor_heuristic(&dist, 0).expect("unexpected None or Err");
        assert_eq!(tour.len(), 4);
        assert!(is_permutation(&tour, 4));
        assert!(len > 0.0);
    }

    #[test]
    fn test_two_opt_improves() {
        let dist = square_dist();
        // Both diagonals are used: length 2 + 2 * 1.414.
        let mut tour = vec![0, 2, 1, 3];
        let original_len = tour_length(&tour, &dist);
        let new_len = two_opt(&mut tour, &dist);
        assert!(new_len <= original_len + 1e-9);
        // Uncrossing the diagonals yields the perimeter tour.
        assert!((new_len - 4.0).abs() < 1e-6);
        assert!(is_permutation(&tour, 4));
    }

    #[test]
    fn test_or_opt() {
        let dist = square_dist();
        let mut tour = vec![0, 1, 2, 3];
        let len = or_opt(&mut tour, &dist);
        assert!(len > 0.0);
        assert_eq!(tour.len(), 4);
        // The input is already optimal, so no relocation can shorten it.
        assert!((len - 4.0).abs() < 1e-6);
        assert!(is_permutation(&tour, 4));
    }

    #[test]
    fn test_mst_lower_bound() {
        let dist = square_dist();
        let lb = mst_lower_bound(&dist);
        assert!(lb > 0.0);
        assert!(lb <= 4.0 + 1e-6);
        // Three unit-length sides span the square.
        assert!((lb - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_solver_small() {
        let dist = square_dist();
        let solver = TspSolver::new(dist).expect("failed to create solver");
        let (tour, len) = solver.solve().expect("unexpected None or Err");
        assert_eq!(tour.len(), 4);
        assert!(len <= 4.5);
    }

    #[test]
    fn test_three_opt_move() {
        let n = 6;
        let big_dist = line_dist(n);

        // The identity tour is optimal on a line: every reconnection of the
        // three unit edges costs at least 5 > 3, so no move is offered.
        let tour: Vec<usize> = vec![0, 1, 2, 3, 4, 5];
        assert!(three_opt_move(&big_dist, 0, 2, 4, &tour).is_none());

        // Positions must satisfy i < j < k < n.
        assert!(three_opt_move(&big_dist, 2, 1, 4, &tour).is_none());
        assert!(three_opt_move(&big_dist, 0, 2, 6, &tour).is_none());

        // Tours shorter than six cities are rejected outright.
        let short: Vec<usize> = vec![0, 1, 2, 3, 4];
        assert!(three_opt_move(&big_dist, 0, 1, 2, &short).is_none());

        // Cities 1 and 2 are swapped; removing the edges after positions
        // 0, 1 and 2 (cost 5) and reconnecting (cost 3) restores the order.
        let swapped: Vec<usize> = vec![0, 2, 1, 3, 4, 5];
        let fixed = three_opt_move(&big_dist, 0, 1, 2, &swapped);
        assert_eq!(fixed, Some(vec![0, 1, 2, 3, 4, 5]));
    }

    #[test]
    fn test_invalid_start() {
        let dist = square_dist();
        assert!(nearest_neighbor_heuristic(&dist, 10).is_err());
    }

    #[test]
    fn test_empty_tour() {
        let dist: Array2<f64> = Array2::zeros((0, 0));
        let (tour, len) = nearest_neighbor_heuristic(&dist, 0).expect("unexpected None or Err");
        assert!(tour.is_empty());
        assert_eq!(len, 0.0);
    }
}