Skip to main content

scirs2_optimize/combinatorial/
tsp.rs

//! Traveling Salesman Problem (TSP) solvers and heuristics.
//!
//! Provides nearest-neighbor tour construction, 2-opt local search, a single-step
//! 3-opt move evaluator, Or-opt segment relocation, and a minimum-spanning-tree
//! (MST) lower bound on the optimal tour length.
5
6use scirs2_core::ndarray::Array2;
7use std::cmp::Ordering;
8use std::collections::BinaryHeap;
9
10use crate::error::OptimizeError;
11
12/// Result type for TSP operations.
13pub type TspResult<T> = Result<T, OptimizeError>;
14
15// ── Internal priority queue entry for Prim's MST ─────────────────────────────
16
/// Priority-queue entry used by Prim's algorithm in `mst_lower_bound`.
///
/// `BinaryHeap` is a max-heap, so the ordering on `cost` is reversed to make
/// the cheapest entry pop first. Costs are compared with `f64::total_cmp`, so
/// the ordering is a genuine total order even if a NaN cost appears (a positive
/// NaN sorts above `+inf` and therefore pops last). Equal costs are broken by
/// `vertex`, with the larger index popping first.
#[derive(Clone, Debug)]
struct PrimEntry {
    cost: f64,
    vertex: usize,
}

impl PartialEq for PrimEntry {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for PrimEntry {}

impl PartialOrd for PrimEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PrimEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap on cost: compare `other` against `self`.
        other
            .cost
            .total_cmp(&self.cost)
            .then(self.vertex.cmp(&other.vertex))
    }
}
41
42// ── Public helpers ────────────────────────────────────────────────────────────
43
44/// Compute the total length of a tour given a distance matrix.
45///
46/// The tour is assumed to be a permutation of vertices; the last vertex
47/// connects back to the first.
48pub fn tour_length(tour: &[usize], dist: &Array2<f64>) -> f64 {
49    let n = tour.len();
50    if n == 0 {
51        return 0.0;
52    }
53    let mut total = 0.0;
54    for i in 0..n {
55        let from = tour[i];
56        let to = tour[(i + 1) % n];
57        total += dist[[from, to]];
58    }
59    total
60}
61
62/// Greedy nearest-neighbour construction heuristic.
63///
64/// Starting from `start`, repeatedly visit the closest unvisited city.
65/// Returns `(tour, length)`.
66pub fn nearest_neighbor_heuristic(
67    dist: &Array2<f64>,
68    start: usize,
69) -> TspResult<(Vec<usize>, f64)> {
70    let n = dist.nrows();
71    if n == 0 {
72        return Ok((vec![], 0.0));
73    }
74    if start >= n {
75        return Err(OptimizeError::InvalidInput(format!(
76            "start index {start} out of range for {n} cities"
77        )));
78    }
79
80    let mut visited = vec![false; n];
81    let mut tour = Vec::with_capacity(n);
82    let mut current = start;
83    visited[current] = true;
84    tour.push(current);
85
86    for _ in 1..n {
87        let mut best_next = None;
88        let mut best_dist = f64::INFINITY;
89        for j in 0..n {
90            if !visited[j] {
91                let d = dist[[current, j]];
92                if d < best_dist {
93                    best_dist = d;
94                    best_next = Some(j);
95                }
96            }
97        }
98        match best_next {
99            Some(next) => {
100                visited[next] = true;
101                tour.push(next);
102                current = next;
103            }
104            None => break,
105        }
106    }
107
108    let length = tour_length(&tour, dist);
109    Ok((tour, length))
110}
111
112/// 2-opt local search.
113///
114/// Iteratively reverses sub-sequences of the tour whenever doing so reduces
115/// the total length.  Returns the improved tour length.
116pub fn two_opt(tour: &mut Vec<usize>, dist: &Array2<f64>) -> f64 {
117    let n = tour.len();
118    if n < 4 {
119        return tour_length(tour, dist);
120    }
121
122    let mut improved = true;
123    while improved {
124        improved = false;
125        for i in 0..n - 1 {
126            for j in i + 2..n {
127                // Skip the wrap-around edge (n-1, 0) when j == n-1 and i == 0
128                if i == 0 && j == n - 1 {
129                    continue;
130                }
131                let a = tour[i];
132                let b = tour[i + 1];
133                let c = tour[j];
134                let d = tour[(j + 1) % n];
135                let current_cost = dist[[a, b]] + dist[[c, d]];
136                let new_cost = dist[[a, c]] + dist[[b, d]];
137                if new_cost < current_cost - 1e-10 {
138                    // Reverse the segment [i+1 .. j]
139                    tour[i + 1..=j].reverse();
140                    improved = true;
141                }
142            }
143        }
144    }
145
146    tour_length(tour, dist)
147}
148
149/// Evaluate all 3-opt reconnection types for the three edges defined by
150/// positions `i`, `j`, `k` in the tour.
151///
152/// Returns `Some(new_tour)` if a strictly improving reconnection exists,
153/// `None` otherwise.
154///
155/// The three edges removed are:
156///   (tour\[i\], tour\[i+1\]),  (tour\[j\], tour\[j+1\]),  (tour\[k\], tour\[(k+1)%n\])
157pub fn three_opt_move(
158    dist: &Array2<f64>,
159    i: usize,
160    j: usize,
161    k: usize,
162    tour: &[usize],
163) -> Option<Vec<usize>> {
164    let n = tour.len();
165    if n < 6 {
166        return None;
167    }
168    // Validate ordering
169    if !(i < j && j < k && k < n) {
170        return None;
171    }
172
173    let a = tour[i];
174    let b = tour[i + 1];
175    let c = tour[j];
176    let d = tour[j + 1];
177    let e = tour[k];
178    let f = tour[(k + 1) % n];
179
180    let d0 = dist[[a, b]] + dist[[c, d]] + dist[[e, f]];
181
182    // Segment definitions:
183    //   seg1 = tour[0..=i]
184    //   seg2 = tour[i+1..=j]
185    //   seg3 = tour[j+1..=k]
186    //   seg4 = tour[k+1..]
187
188    // We test all 7 non-trivial reconnections (the 8th is the original).
189    let candidates: [(f64, u8); 7] = [
190        // 1: reverse seg2
191        (dist[[a, c]] + dist[[b, d]] + dist[[e, f]], 1),
192        // 2: reverse seg3
193        (dist[[a, b]] + dist[[c, e]] + dist[[d, f]], 2),
194        // 3: reverse seg2 and seg3
195        (dist[[a, c]] + dist[[b, e]] + dist[[d, f]], 3),
196        // 4: move seg3 between seg1 and seg2
197        (dist[[a, d]] + dist[[e, b]] + dist[[c, f]], 4),
198        // 5: move seg2 between seg3 and seg4 (reverse of case 4)
199        (dist[[a, d]] + dist[[e, c]] + dist[[b, f]], 5),
200        // 6: seg1-seg3-seg2-seg4
201        (dist[[a, e]] + dist[[d, b]] + dist[[c, f]], 6),
202        // 7: reverse all three combined
203        (dist[[a, e]] + dist[[d, c]] + dist[[b, f]], 7),
204    ];
205
206    let best = candidates
207        .iter()
208        .min_by(|x, y| x.0.partial_cmp(&y.0).unwrap_or(Ordering::Equal));
209
210    let (best_cost, reconnect_type) = match best {
211        Some(&(c, t)) => (c, t),
212        None => return None,
213    };
214
215    if best_cost >= d0 - 1e-10 {
216        return None;
217    }
218
219    // Build the new tour based on reconnect_type
220    let seg1: Vec<usize> = tour[..=i].to_vec();
221    let seg2: Vec<usize> = tour[i + 1..=j].to_vec();
222    let seg3: Vec<usize> = tour[j + 1..=k].to_vec();
223    let seg4: Vec<usize> = if k + 1 < n {
224        tour[k + 1..].to_vec()
225    } else {
226        vec![]
227    };
228
229    let mut new_tour = seg1;
230    match reconnect_type {
231        1 => {
232            new_tour.extend(seg2.iter().rev());
233            new_tour.extend_from_slice(&seg3);
234        }
235        2 => {
236            new_tour.extend_from_slice(&seg2);
237            new_tour.extend(seg3.iter().rev());
238        }
239        3 => {
240            new_tour.extend(seg2.iter().rev());
241            new_tour.extend(seg3.iter().rev());
242        }
243        4 => {
244            new_tour.extend_from_slice(&seg3);
245            new_tour.extend_from_slice(&seg2);
246        }
247        5 => {
248            new_tour.extend_from_slice(&seg3);
249            new_tour.extend(seg2.iter().rev());
250        }
251        6 => {
252            new_tour.extend(seg3.iter().rev());
253            new_tour.extend_from_slice(&seg2);
254        }
255        7 => {
256            new_tour.extend(seg3.iter().rev());
257            new_tour.extend(seg2.iter().rev());
258        }
259        _ => unreachable!(),
260    }
261    new_tour.extend_from_slice(&seg4);
262    Some(new_tour)
263}
264
265/// Or-opt local search: relocate segments of length 1, 2, or 3.
266///
267/// For each segment of the given length, try inserting it at every other
268/// position in the tour.  Accepts the best improving move found.
269/// Returns the improved tour length after convergence.
270pub fn or_opt(tour: &mut Vec<usize>, dist: &Array2<f64>) -> f64 {
271    let n = tour.len();
272    if n < 4 {
273        return tour_length(tour, dist);
274    }
275
276    let mut improved = true;
277    while improved {
278        improved = false;
279        for seg_len in 1..=3_usize {
280            if n < seg_len + 2 {
281                continue;
282            }
283            'outer: for seg_start in 0..n {
284                let seg_end = (seg_start + seg_len - 1) % n;
285                // Compute cost of removing the segment from its current position
286                let prev = if seg_start == 0 { n - 1 } else { seg_start - 1 };
287                let after = (seg_end + 1) % n;
288                // Skip if wrap-around overlap
289                if prev == seg_end || after == seg_start {
290                    continue;
291                }
292
293                // Removal gain
294                let first_city = tour[seg_start];
295                let last_city = tour[seg_end];
296                let prev_city = tour[prev];
297                let after_city = tour[after];
298
299                let remove_cost = dist[[prev_city, first_city]] + dist[[last_city, after_city]]
300                    - dist[[prev_city, after_city]];
301
302                // Try inserting after position `ins` (not inside the segment)
303                let mut best_gain = 1e-10; // must improve by at least this
304                let mut best_ins = None;
305                let mut best_reverse = false;
306
307                for ins in 0..n {
308                    // Skip positions within or adjacent to segment
309                    let in_seg = if seg_start <= seg_end {
310                        ins >= seg_start && ins <= seg_end
311                    } else {
312                        ins >= seg_start || ins <= seg_end
313                    };
314                    if in_seg || ins == prev {
315                        continue;
316                    }
317                    let ins_next = (ins + 1) % n;
318                    let ins_city = tour[ins];
319                    let ins_next_city = tour[ins_next];
320
321                    // Forward insertion cost delta
322                    let fwd = dist[[ins_city, first_city]] + dist[[last_city, ins_next_city]]
323                        - dist[[ins_city, ins_next_city]];
324                    let gain_fwd = remove_cost - fwd;
325                    if gain_fwd > best_gain {
326                        best_gain = gain_fwd;
327                        best_ins = Some(ins);
328                        best_reverse = false;
329                    }
330
331                    // Reversed insertion
332                    if seg_len > 1 {
333                        let rev = dist[[ins_city, last_city]] + dist[[first_city, ins_next_city]]
334                            - dist[[ins_city, ins_next_city]];
335                        let gain_rev = remove_cost - rev;
336                        if gain_rev > best_gain {
337                            best_gain = gain_rev;
338                            best_ins = Some(ins);
339                            best_reverse = true;
340                        }
341                    }
342                }
343
344                if let Some(ins) = best_ins {
345                    // Extract the segment
346                    let segment: Vec<usize> =
347                        (0..seg_len).map(|k| tour[(seg_start + k) % n]).collect();
348                    let seg_set: std::collections::HashSet<usize> =
349                        segment.iter().cloned().collect();
350
351                    // Build new tour without the segment, then insert
352                    let remaining: Vec<usize> = tour
353                        .iter()
354                        .cloned()
355                        .filter(|v| !seg_set.contains(v))
356                        .collect();
357
358                    // Find insertion position in remaining
359                    let ins_city = tour[ins];
360                    let ins_pos = remaining.iter().position(|&v| v == ins_city).unwrap_or(0);
361
362                    let mut new_tour: Vec<usize> = Vec::with_capacity(n);
363                    new_tour.extend_from_slice(&remaining[..=ins_pos]);
364                    if best_reverse {
365                        new_tour.extend(segment.iter().rev());
366                    } else {
367                        new_tour.extend_from_slice(&segment);
368                    }
369                    if ins_pos + 1 < remaining.len() {
370                        new_tour.extend_from_slice(&remaining[ins_pos + 1..]);
371                    }
372
373                    if new_tour.len() == n {
374                        *tour = new_tour;
375                        improved = true;
376                        break 'outer;
377                    }
378                }
379            }
380        }
381    }
382
383    tour_length(tour, dist)
384}
385
386/// Compute a minimum spanning tree lower bound using Prim's algorithm.
387///
388/// The MST weight is a classical lower bound for TSP on metric instances.
389pub fn mst_lower_bound(dist: &Array2<f64>) -> f64 {
390    let n = dist.nrows();
391    if n == 0 {
392        return 0.0;
393    }
394    if n == 1 {
395        return 0.0;
396    }
397
398    let mut in_mst = vec![false; n];
399    let mut min_edge = vec![f64::INFINITY; n];
400    min_edge[0] = 0.0;
401
402    let mut heap: BinaryHeap<PrimEntry> = BinaryHeap::new();
403    heap.push(PrimEntry {
404        cost: 0.0,
405        vertex: 0,
406    });
407
408    let mut mst_weight = 0.0;
409
410    while let Some(PrimEntry { cost, vertex }) = heap.pop() {
411        if in_mst[vertex] {
412            continue;
413        }
414        in_mst[vertex] = true;
415        mst_weight += cost;
416
417        for j in 0..n {
418            if !in_mst[j] {
419                let d = dist[[vertex, j]];
420                if d < min_edge[j] {
421                    min_edge[j] = d;
422                    heap.push(PrimEntry { cost: d, vertex: j });
423                }
424            }
425        }
426    }
427
428    mst_weight
429}
430
431/// High-level TSP solver that chains NN heuristic → 2-opt → Or-opt.
432pub struct TspSolver {
433    dist: Array2<f64>,
434}
435
436impl TspSolver {
437    /// Create a new solver with the given distance matrix.
438    ///
439    /// # Errors
440    /// Returns an error if the matrix is not square.
441    pub fn new(dist: Array2<f64>) -> TspResult<Self> {
442        if dist.nrows() != dist.ncols() {
443            return Err(OptimizeError::InvalidInput(
444                "Distance matrix must be square".to_string(),
445            ));
446        }
447        Ok(Self { dist })
448    }
449
450    /// Solve using nearest-neighbour construction followed by 2-opt and Or-opt.
451    ///
452    /// Tries every city as a starting vertex for NN and keeps the best result.
453    pub fn solve(&self) -> TspResult<(Vec<usize>, f64)> {
454        let n = self.dist.nrows();
455        if n == 0 {
456            return Ok((vec![], 0.0));
457        }
458
459        let mut best_tour = vec![];
460        let mut best_len = f64::INFINITY;
461
462        for start in 0..n {
463            let (mut tour, _) = nearest_neighbor_heuristic(&self.dist, start)?;
464            two_opt(&mut tour, &self.dist);
465            or_opt(&mut tour, &self.dist);
466            let len = tour_length(&tour, &self.dist);
467            if len < best_len {
468                best_len = len;
469                best_tour = tour;
470            }
471        }
472
473        Ok((best_tour, best_len))
474    }
475
476    /// Return the MST-based lower bound on the optimal tour length.
477    pub fn lower_bound(&self) -> f64 {
478        mst_lower_bound(&self.dist)
479    }
480}
481
482// ─────────────────────────────────────────────────────────────────────────────
483// Tests
484// ─────────────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    fn square_dist() -> Array2<f64> {
        // 4-city square: optimal tour length = 4.0
        array![
            [0.0, 1.0, 1.414, 1.0],
            [1.0, 0.0, 1.0, 1.414],
            [1.414, 1.0, 0.0, 1.0],
            [1.0, 1.414, 1.0, 0.0]
        ]
    }

    /// Cities on a line at x = 0, 1, ..., n-1 with |x_r - x_c| distances.
    fn line_dist(n: usize) -> Array2<f64> {
        let mut dist = Array2::<f64>::zeros((n, n));
        for r in 0..n {
            for c in 0..n {
                dist[[r, c]] = ((r as f64) - (c as f64)).abs();
            }
        }
        dist
    }

    fn is_permutation(tour: &[usize], n: usize) -> bool {
        let mut sorted = tour.to_vec();
        sorted.sort_unstable();
        sorted == (0..n).collect::<Vec<_>>()
    }

    #[test]
    fn test_tour_length() {
        let dist = square_dist();
        let tour = vec![0, 1, 2, 3];
        let len = tour_length(&tour, &dist);
        // 0→1 + 1→2 + 2→3 + 3→0 = 1+1+1+1 = 4
        assert!((len - 4.0).abs() < 1e-6);
        assert_eq!(tour_length(&[], &dist), 0.0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let dist = square_dist();
        let (tour, len) = nearest_neighbor_heuristic(&dist, 0).expect("unexpected None or Err");
        // From 0 the nearest city is 1 (tie with 3 goes to the lower index), etc.
        assert_eq!(tour, vec![0, 1, 2, 3]);
        assert!((len - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_two_opt_improves() {
        let dist = square_dist();
        // A suboptimal tour: 0→2→1→3 has length 1.414+1+1.414+1 = 4.828
        let mut tour = vec![0, 2, 1, 3];
        let original_len = tour_length(&tour, &dist);
        let new_len = two_opt(&mut tour, &dist);
        assert!(new_len <= original_len + 1e-9);
        assert!((new_len - 4.0).abs() < 1e-9);
        assert!(is_permutation(&tour, 4));
    }

    #[test]
    fn test_or_opt() {
        let dist = square_dist();
        // Already optimal: no relocation can improve it.
        let mut tour = vec![0, 1, 2, 3];
        let len = or_opt(&mut tour, &dist);
        assert!((len - 4.0).abs() < 1e-9);
        assert_eq!(tour, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_mst_lower_bound() {
        let dist = square_dist();
        let lb = mst_lower_bound(&dist);
        // MST of the square has weight 3.0 (three unit edges)
        assert!((lb - 3.0).abs() < 1e-9);
        assert!(lb <= 4.0 + 1e-6); // must be ≤ optimal tour length
    }

    #[test]
    fn test_solver_small() {
        let dist = square_dist();
        let solver = TspSolver::new(dist).expect("failed to create solver");
        let (tour, len) = solver.solve().expect("unexpected None or Err");
        assert_eq!(tour.len(), 4);
        assert!(is_permutation(&tour, 4));
        // Optimal is 4.0
        assert!((len - 4.0).abs() < 1e-9);
        assert!(solver.lower_bound() <= len + 1e-9);
    }

    #[test]
    fn test_solver_rejects_non_square() {
        assert!(TspSolver::new(Array2::<f64>::zeros((2, 3))).is_err());
    }

    #[test]
    fn test_three_opt_move() {
        let dist = line_dist(6);

        // The identity tour on a line is already optimal: no improving move.
        let optimal: Vec<usize> = vec![0, 1, 2, 3, 4, 5];
        assert!(three_opt_move(&dist, 0, 2, 4, &optimal).is_none());

        // Swapping 1 and 2 is undone by reversing seg2 = [2, 1].
        let swapped: Vec<usize> = vec![0, 2, 1, 3, 4, 5];
        assert_eq!(
            three_opt_move(&dist, 0, 2, 4, &swapped),
            Some(vec![0, 1, 2, 3, 4, 5])
        );

        // Invalid index ordering and too-small tours are rejected.
        assert!(three_opt_move(&dist, 2, 2, 4, &swapped).is_none());
        assert!(three_opt_move(&square_dist(), 0, 1, 2, &[0, 1, 2, 3]).is_none());
    }

    #[test]
    fn test_invalid_start() {
        let dist = square_dist();
        assert!(nearest_neighbor_heuristic(&dist, 10).is_err());
    }

    #[test]
    fn test_empty_tour() {
        let dist: Array2<f64> = Array2::zeros((0, 0));
        let (tour, len) = nearest_neighbor_heuristic(&dist, 0).expect("unexpected None or Err");
        assert!(tour.is_empty());
        assert_eq!(len, 0.0);
    }
}