Skip to main content

scirs2_optimize/integer/
milp_branch_and_bound.rs

//! Mixed-Integer Linear Programming via Branch-and-Bound
//!
//! This module provides the top-level [`MilpProblem`] struct that encapsulates
//! the MILP formulation and the [`branch_and_bound`] function that solves it.
//!
//! # Problem formulation
//!
//! ```text
//! minimize    c^T x
//! subject to  A x <= b               (linear inequalities)
//!             lb <= x <= ub          (variable bounds)
//!             x_i ∈ Z  for i in I    (integrality constraints)
//! ```
//!
//! # Algorithm
//!
//! - **LP relaxation** at each node via a dense-tableau simplex method; rows
//!   whose right-hand side is negative after the bound shift are handled with
//!   big-M artificial variables.
//! - **Variable selection**: most-fractional, first-fractional, or strong branching.
//! - **Node selection**: best-first (lowest LP lower bound).
//! - **Pruning**: infeasible nodes, integral nodes, bound-based cutoff.
//!
//! # References
//! - Land, A.H. & Doig, A.G. (1960). "An automatic method of solving discrete
//!   programming problems." Econometrica, 28(3), 497–520.
//! - Wolsey, L.A. (1998). *Integer Programming*. Wiley.

27use crate::error::{OptimizeError, OptimizeResult};
28use scirs2_core::ndarray::{Array1, Array2};
29use std::cmp::Ordering;
30use std::collections::BinaryHeap;
31
32// ─────────────────────────────────────────────────────────────────────────────
33// MILP Problem Definition
34// ─────────────────────────────────────────────────────────────────────────────
35
36/// Mixed-Integer Linear Programming problem.
37///
38/// Represents:
39/// ```text
40/// minimize    c^T x
41/// subject to  A x <= b
42///             lb <= x <= ub
43///             x_i ∈ Z  for all i in integer_vars
44/// ```
45#[derive(Debug, Clone)]
46pub struct MilpProblem {
47    /// Objective coefficients (length n).
48    pub c: Array1<f64>,
49    /// Constraint matrix (m × n); represents `A x <= b`.
50    pub a: Array2<f64>,
51    /// Constraint RHS (length m).
52    pub b: Array1<f64>,
53    /// Lower bounds for each variable (length n; default 0).
54    pub lb: Array1<f64>,
55    /// Upper bounds for each variable (length n; default +∞).
56    pub ub: Array1<f64>,
57    /// Indices of variables that must be integer.
58    pub integer_vars: Vec<usize>,
59}
60
61impl MilpProblem {
62    /// Construct a new MILP problem.
63    ///
64    /// # Panics
65    /// Does not panic; returns [`OptimizeError::InvalidInput`] on invalid dimensions.
66    pub fn new(
67        c: Array1<f64>,
68        a: Array2<f64>,
69        b: Array1<f64>,
70        lb: Array1<f64>,
71        ub: Array1<f64>,
72        integer_vars: Vec<usize>,
73    ) -> OptimizeResult<Self> {
74        let n = c.len();
75        let (m, ncols) = a.dim();
76        if ncols != n {
77            return Err(OptimizeError::InvalidInput(format!(
78                "A has {} columns but c has {} entries",
79                ncols, n
80            )));
81        }
82        if b.len() != m {
83            return Err(OptimizeError::InvalidInput(format!(
84                "b has length {} but A has {} rows",
85                b.len(),
86                m
87            )));
88        }
89        if lb.len() != n || ub.len() != n {
90            return Err(OptimizeError::InvalidInput(
91                "lb and ub must have the same length as c".to_string(),
92            ));
93        }
94        for &idx in &integer_vars {
95            if idx >= n {
96                return Err(OptimizeError::InvalidInput(format!(
97                    "integer_vars contains out-of-range index {}",
98                    idx
99                )));
100            }
101        }
102        Ok(MilpProblem {
103            c,
104            a,
105            b,
106            lb,
107            ub,
108            integer_vars,
109        })
110    }
111
112    /// Number of variables.
113    #[inline]
114    pub fn n_vars(&self) -> usize {
115        self.c.len()
116    }
117
118    /// Number of constraints.
119    #[inline]
120    pub fn n_constraints(&self) -> usize {
121        self.b.len()
122    }
123}
124
125// ─────────────────────────────────────────────────────────────────────────────
126// Configuration
127// ─────────────────────────────────────────────────────────────────────────────
128
/// Rule used to pick the variable to branch on at a B&B node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BranchingStrategy {
    /// Pick the integer variable whose LP value is farthest from an integer,
    /// i.e. whose fractional part is closest to 0.5.
    MostFractional,
    /// Pick the first integer variable (in `integer_vars` order) whose LP
    /// value is not integral.
    FirstFractional,
    /// Tentatively solve both child LPs for a limited number of candidate
    /// variables and keep the one yielding the best bound improvement.
    StrongBranching,
}
140
141/// Configuration for the branch-and-bound solver.
142#[derive(Debug, Clone)]
143pub struct BnbConfig {
144    /// Maximum number of B&B nodes to explore.
145    pub max_nodes: usize,
146    /// Wall-clock time limit in seconds (0 = no limit).
147    pub time_limit_secs: f64,
148    /// Absolute gap tolerance: stop when `incumbent − lower_bound ≤ abs_gap`.
149    pub abs_gap: f64,
150    /// Relative gap tolerance: stop when `gap / |incumbent| ≤ rel_gap`.
151    pub rel_gap: f64,
152    /// Integrality tolerance.
153    pub int_tol: f64,
154    /// Variable selection strategy.
155    pub branching: BranchingStrategy,
156    /// Number of candidate variables to evaluate in strong branching.
157    pub strong_branching_candidates: usize,
158}
159
160impl Default for BnbConfig {
161    fn default() -> Self {
162        BnbConfig {
163            max_nodes: 50_000,
164            time_limit_secs: 0.0,
165            abs_gap: 1e-6,
166            rel_gap: 1e-6,
167            int_tol: 1e-6,
168            branching: BranchingStrategy::MostFractional,
169            strong_branching_candidates: 5,
170        }
171    }
172}
173
174// ─────────────────────────────────────────────────────────────────────────────
175// Result
176// ─────────────────────────────────────────────────────────────────────────────
177
178/// Result returned by [`branch_and_bound`].
179#[derive(Debug, Clone)]
180pub struct MilpResult {
181    /// Optimal solution vector (length n).
182    pub x: Array1<f64>,
183    /// Optimal objective value (`c^T x`).
184    pub obj: f64,
185    /// Whether an optimal solution was found.
186    pub success: bool,
187    /// Solver status message.
188    pub message: String,
189    /// Number of B&B nodes explored.
190    pub nodes_explored: usize,
191    /// Lower bound on the optimal objective at termination.
192    pub lower_bound: f64,
193}
194
195// ─────────────────────────────────────────────────────────────────────────────
196// LP Solver (Revised Simplex with Phase I / Phase II)
197// ─────────────────────────────────────────────────────────────────────────────
198
/// Termination status of an LP relaxation solve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LpStatus {
    /// A basic solution was produced and accepted as optimal.
    Optimal,
    /// No feasible point exists, or none could be found.
    Infeasible,
    /// The objective can be decreased without limit.
    Unbounded,
}
206
207/// LP solution
208struct LpSolution {
209    x: Vec<f64>,
210    obj: f64,
211    status: LpStatus,
212}
213
214/// Revised-simplex LP solver.
215///
216/// Solves `min c^T x` subject to:
217///   `A_ub x <= b_ub`
218///   `lb <= x <= ub`
219///
220/// We convert to standard form by:
221/// - Adding slack variables `s_i >= 0` for each inequality: `A x + s = b_ub`
222/// - Shifting `x_j <- x_j - lb[j]` so all variables are non-negative
223///   (after adding `ub_shift = ub[j] - lb[j]` as an upper bound)
224/// - Adding artificial variables for Phase I (big-M method) to handle RHS < 0
225fn solve_lp(
226    c: &[f64],         // n
227    a_ub: &[Vec<f64>], // m × n
228    b_ub: &[f64],      // m
229    lb: &[f64],        // n
230    ub: &[f64],        // n
231    max_iter: usize,
232) -> LpSolution {
233    let n = c.len();
234    let m = a_ub.len();
235
236    if n == 0 {
237        return LpSolution {
238            x: Vec::new(),
239            obj: 0.0,
240            status: LpStatus::Optimal,
241        };
242    }
243
244    // --- Bound feasibility check ----------------------------------------
245    for i in 0..n {
246        if lb[i] > ub[i] + 1e-10 {
247            return LpSolution {
248                x: vec![0.0; n],
249                obj: f64::INFINITY,
250                status: LpStatus::Infeasible,
251            };
252        }
253    }
254
255    // --- Shift x <- x - lb so that all vars start non-negative ----------
256    // New variable y = x - lb; bounds: 0 <= y <= ub - lb
257    let ub_shifted: Vec<f64> = (0..n).map(|i| ub[i] - lb[i]).collect();
258    // Shift b_ub: A(y + lb) <= b => Ay <= b - A*lb
259    let b_shifted: Vec<f64> = (0..m)
260        .map(|i| {
261            b_ub[i]
262                - a_ub[i]
263                    .iter()
264                    .zip(lb.iter())
265                    .map(|(&a, &l)| a * l)
266                    .sum::<f64>()
267        })
268        .collect();
269
270    // --- Add slacks: A y + s = b_shifted (s >= 0) -------------------------
271    // Handle finite upper bounds on y with surplus variables:
272    //   y_j <= ub_j  =>  y_j + t_j = ub_j  (t_j >= 0)
273    let n_ub_constrained: usize = ub_shifted.iter().filter(|&&u| u.is_finite()).count();
274    let n_total = n + m + n_ub_constrained; // structural + slacks + UB slacks
275
276    // Map structural variable index -> UB slack column index
277    let mut ub_slack_for: Vec<Option<usize>> = vec![None; n];
278    let mut ub_slack_idx = n + m; // start after structural + inequality slacks
279    for j in 0..n {
280        if ub_shifted[j].is_finite() {
281            ub_slack_for[j] = Some(ub_slack_idx);
282            ub_slack_idx += 1;
283        }
284    }
285
286    // Build full constraint matrix [A | I_m | (UB slack columns)]
287    // rows: m (ineq slacks) + n_ub_constrained (UB bound rows)
288    let total_rows = m + n_ub_constrained;
289    let mut full_a: Vec<Vec<f64>> = vec![vec![0.0; n_total]; total_rows];
290    let mut full_b: Vec<f64> = vec![0.0; total_rows];
291
292    // Fill inequality rows
293    for i in 0..m {
294        for j in 0..n {
295            full_a[i][j] = a_ub[i][j];
296        }
297        // Slack variable for row i is at column n + i
298        full_a[i][n + i] = 1.0;
299        full_b[i] = b_shifted[i];
300    }
301
302    // Fill upper bound rows: y_j + t_j = ub_shifted[j]
303    let mut ub_row = m;
304    for j in 0..n {
305        if let Some(sk) = ub_slack_for[j] {
306            full_a[ub_row][j] = 1.0;
307            full_a[ub_row][sk] = 1.0;
308            full_b[ub_row] = ub_shifted[j];
309            ub_row += 1;
310        }
311    }
312
313    // --- Phase I: find BFS using big-M / artificial variables ------------
314    // For rows with b >= 0, the slack/UB-slack is a valid basic variable.
315    // For rows with b < 0, we negate the row (so RHS becomes positive) and
316    // add an artificial variable (because the negated slack now has coeff -1).
317
318    let mut a_work = full_a.clone();
319    let mut b_work = full_b.clone();
320    let mut needs_artif = vec![false; total_rows];
321    for i in 0..total_rows {
322        if b_work[i] < -1e-12 {
323            for v in a_work[i].iter_mut() {
324                *v = -*v;
325            }
326            b_work[i] = -b_work[i];
327            needs_artif[i] = true;
328        }
329    }
330
331    let n_artif: usize = needs_artif.iter().filter(|&&v| v).count();
332    let n_total_ext = n_total + n_artif; // extended with artificial columns
333
334    // Extend rows to include artificial columns
335    let mut artif_col_idx = n_total;
336    let mut artif_map: Vec<Option<usize>> = vec![None; total_rows]; // row -> artificial col
337    for i in 0..total_rows {
338        if needs_artif[i] {
339            artif_map[i] = Some(artif_col_idx);
340            artif_col_idx += 1;
341        }
342    }
343
344    // Extend all rows with zero columns for artificials
345    for row in a_work.iter_mut() {
346        row.resize(n_total_ext, 0.0);
347    }
348    // Set artificial variable coefficients to +1 for their respective rows
349    for i in 0..total_rows {
350        if let Some(acol) = artif_map[i] {
351            a_work[i][acol] = 1.0;
352        }
353    }
354
355    // Objective: original c for structural vars, 0 for slacks, big-M for artificials
356    let big_m = 1e7_f64;
357    let mut big_m_c: Vec<f64> = vec![0.0; n_total_ext];
358    for j in 0..n {
359        big_m_c[j] = c[j];
360    }
361    for i in 0..total_rows {
362        if let Some(acol) = artif_map[i] {
363            big_m_c[acol] = big_m;
364        }
365    }
366
367    // Initial basis: for rows not needing artificials, use slack/UB-slack;
368    // for rows needing artificials, use the artificial variable.
369    let mut basis: Vec<usize> = Vec::with_capacity(total_rows);
370    let mut ub_row_counter = 0usize;
371    for i in 0..m {
372        if needs_artif[i] {
373            basis.push(artif_map[i].unwrap_or(0));
374        } else {
375            basis.push(n + i); // slack for ineq row i
376        }
377    }
378    for j in 0..n {
379        if let Some(sk) = ub_slack_for[j] {
380            let row_idx = m + ub_row_counter;
381            if needs_artif[row_idx] {
382                basis.push(artif_map[row_idx].unwrap_or(0));
383            } else {
384                basis.push(sk);
385            }
386            ub_row_counter += 1;
387        }
388    }
389
390    // Run revised simplex
391    let sol = revised_simplex(&mut a_work, &mut b_work, &big_m_c, &mut basis, max_iter);
392
393    if sol.status == LpStatus::Infeasible || sol.status == LpStatus::Unbounded {
394        return LpSolution {
395            x: vec![0.0; n],
396            obj: f64::INFINITY,
397            status: LpStatus::Infeasible,
398        };
399    }
400
401    // Check that no artificial is in the basis with positive value
402    for (i, &bv) in basis.iter().enumerate() {
403        if bv >= n_total && b_work[i] > 1e-6 {
404            // Artificial variable still in basis -> LP is infeasible
405            return LpSolution {
406                x: vec![0.0; n],
407                obj: f64::INFINITY,
408                status: LpStatus::Infeasible,
409            };
410        }
411    }
412
413    // Extract solution: y (shifted) from sol.x (only first n_total variables)
414    let y = &sol.x;
415    let x: Vec<f64> = (0..n)
416        .map(|j| {
417            let yj = if j < y.len() { y[j] } else { 0.0 };
418            (lb[j] + yj).max(lb[j]).min(ub[j])
419        })
420        .collect();
421    let obj = c
422        .iter()
423        .zip(x.iter())
424        .map(|(&ci, &xi)| ci * xi)
425        .sum::<f64>();
426
427    LpSolution {
428        x,
429        obj,
430        status: LpStatus::Optimal,
431    }
432}
433
434/// Simplified revised simplex method (tableau form).
435///
436/// Operates on the system `A x = b` with `x >= 0`.
437/// `basis` is the initial basic feasible basis (indices of basic variables).
438fn revised_simplex(
439    a: &mut Vec<Vec<f64>>,
440    b: &mut Vec<f64>,
441    c: &[f64],
442    basis: &mut Vec<usize>,
443    max_iter: usize,
444) -> LpSolution {
445    let m = a.len();
446    if m == 0 {
447        let n = c.len();
448        let mut x = vec![0.0_f64; n];
449        // Minimise: set vars to lb (which is 0 after shift)
450        return LpSolution {
451            x,
452            obj: 0.0,
453            status: LpStatus::Optimal,
454        };
455    }
456    let n_total = if m > 0 { a[0].len() } else { 0 };
457
458    // Build basis inverse (B^{-1}) as an m×m identity initially,
459    // then update with pivot operations.  We work with a full tableau for simplicity.
460    // Full tableau: [A | I] -> after entering each column we update.
461
462    // We use the simplex tableau directly:
463    // tableau[i][j] = B^{-1} A_j  for structural column j
464    // tableau[i][m] = B^{-1} b
465    // Reduced costs: c_bar[j] = c[j] - c_B B^{-1} A_j
466
467    let mut tableau: Vec<Vec<f64>> = vec![vec![0.0; n_total + 1]; m];
468    for i in 0..m {
469        for j in 0..n_total {
470            tableau[i][j] = a[i][j];
471        }
472        tableau[i][n_total] = b[i];
473    }
474
475    // Make basis columns identity via row operations (initial BFS)
476    for col in 0..m {
477        let basic = basis[col];
478        // Find the row where this column has 1.0 (it should, by construction)
479        // Actually the initial basis columns may not be unit vectors if rows were flipped.
480        // Pivot to make basis[col] a unit vector in column col.
481        let pivot_row = col; // assume basis variable for row col is in position col
482                             // Find pivot in this column among rows
483        let pivot_val = tableau[pivot_row][basic];
484        if pivot_val.abs() < 1e-12 {
485            // Try to find a different row for this basis element
486            let mut found = false;
487            for i in 0..m {
488                if i != pivot_row && tableau[i][basic].abs() > 1e-10 {
489                    tableau.swap(pivot_row, i);
490                    basis.swap(pivot_row, i);
491                    found = true;
492                    break;
493                }
494            }
495            if !found {
496                continue;
497            }
498        }
499        let pv = tableau[pivot_row][basic];
500        if pv.abs() < 1e-12 {
501            continue;
502        }
503        for j in 0..=n_total {
504            tableau[pivot_row][j] /= pv;
505        }
506        for i in 0..m {
507            if i == pivot_row {
508                continue;
509            }
510            let factor = tableau[i][basic];
511            if factor.abs() < 1e-15 {
512                continue;
513            }
514            for j in 0..=n_total {
515                let delta = factor * tableau[pivot_row][j];
516                tableau[i][j] -= delta;
517            }
518        }
519    }
520
521    // Simplex iterations
522    for _iter in 0..max_iter {
523        // Compute reduced costs
524        let c_b: Vec<f64> = basis
525            .iter()
526            .map(|&b| c.get(b).copied().unwrap_or(0.0))
527            .collect();
528
529        let mut enter = None;
530        let mut min_rc = -1e-8_f64;
531        for j in 0..n_total {
532            // rc = c[j] - c_B^T B^{-1} A_j = c[j] - c_B^T tableau_col
533            let rc = c.get(j).copied().unwrap_or(0.0)
534                - c_b
535                    .iter()
536                    .zip(tableau.iter())
537                    .map(|(&cb, row)| cb * row[j])
538                    .sum::<f64>();
539            if rc < min_rc {
540                min_rc = rc;
541                enter = Some(j);
542            }
543        }
544
545        let enter_col = match enter {
546            None => break, // optimal
547            Some(j) => j,
548        };
549
550        // Minimum ratio test
551        let mut leave_row = None;
552        let mut min_ratio = f64::INFINITY;
553        for i in 0..m {
554            let coef = tableau[i][enter_col];
555            if coef > 1e-10 {
556                let ratio = tableau[i][n_total] / coef;
557                if ratio < min_ratio {
558                    min_ratio = ratio;
559                    leave_row = Some(i);
560                }
561            }
562        }
563
564        let pivot_row = match leave_row {
565            None => {
566                // Unbounded
567                let mut x = vec![0.0; n_total];
568                for (i, &b) in basis.iter().enumerate() {
569                    if b < n_total {
570                        x[b] = tableau[i][n_total].max(0.0);
571                    }
572                }
573                return LpSolution {
574                    x,
575                    obj: f64::NEG_INFINITY,
576                    status: LpStatus::Unbounded,
577                };
578            }
579            Some(r) => r,
580        };
581
582        // Pivot
583        let pv = tableau[pivot_row][enter_col];
584        for j in 0..=n_total {
585            tableau[pivot_row][j] /= pv;
586        }
587        for i in 0..m {
588            if i == pivot_row {
589                continue;
590            }
591            let factor = tableau[i][enter_col];
592            if factor.abs() < 1e-15 {
593                continue;
594            }
595            for j in 0..=n_total {
596                let delta = factor * tableau[pivot_row][j];
597                tableau[i][j] -= delta;
598            }
599        }
600        basis[pivot_row] = enter_col;
601    }
602
603    // Extract solution
604    let mut x = vec![0.0_f64; n_total];
605    for (i, &b) in basis.iter().enumerate() {
606        if b < n_total {
607            x[b] = tableau[i][n_total].max(0.0);
608        }
609    }
610
611    // Update b (for external use)
612    for i in 0..m {
613        b[i] = tableau[i][n_total];
614    }
615
616    let obj: f64 = c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum();
617    LpSolution {
618        x,
619        obj,
620        status: LpStatus::Optimal,
621    }
622}
623
624// ─────────────────────────────────────────────────────────────────────────────
625// B&B Node
626// ─────────────────────────────────────────────────────────────────────────────
627
/// One subproblem of the branch-and-bound search.
#[derive(Debug, Clone)]
struct BbNode {
    /// Per-variable lower bounds (length n) after applying this node's
    /// branching decisions.
    lb: Vec<f64>,
    /// Per-variable upper bounds (length n) after applying this node's
    /// branching decisions.
    ub: Vec<f64>,
    /// LP relaxation bound of the node; the key for best-first ordering.
    lp_lb: f64,
    /// Depth of the node in the search tree (the root is 0); intended as a
    /// tie-breaker.
    depth: usize,
}
640
641/// Wrapper for priority queue (max-heap, negated to get min-heap by lp_lb).
642struct PqEntry {
643    neg_lb: f64,
644    node: BbNode,
645}
646
647impl PartialEq for PqEntry {
648    fn eq(&self, other: &Self) -> bool {
649        self.neg_lb == other.neg_lb
650    }
651}
652
653impl Eq for PqEntry {}
654
655impl PartialOrd for PqEntry {
656    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
657        Some(self.cmp(other))
658    }
659}
660
661impl Ord for PqEntry {
662    fn cmp(&self, other: &Self) -> Ordering {
663        self.neg_lb
664            .partial_cmp(&other.neg_lb)
665            .unwrap_or(Ordering::Equal)
666    }
667}
668
669// ─────────────────────────────────────────────────────────────────────────────
670// Helper: check/evaluate integrality
671// ─────────────────────────────────────────────────────────────────────────────
672
/// Return `true` when `v` lies within `tol` of its nearest integer.
///
/// NaN or infinite inputs yield `false`, because the distance they produce is NaN.
fn is_integer_valued_local(v: f64, tol: f64) -> bool {
    let nearest = v.round();
    let distance = (nearest - v).abs();
    distance <= tol
}
676
/// Dot product `c · x`; if the slices differ in length, the longer one is
/// truncated to the shorter.
fn eval_obj_vec(c: &[f64], x: &[f64]) -> f64 {
    let terms = c.iter().zip(x).map(|(ci, xi)| ci * xi);
    terms.sum()
}
680
681// ─────────────────────────────────────────────────────────────────────────────
682// Variable selection
683// ─────────────────────────────────────────────────────────────────────────────
684
/// Pick the integer variable whose value in `x` is farthest from an integer.
///
/// Only variables whose distance to the nearest integer exceeds `tol` qualify.
/// When several are equally fractional, the earliest in `int_vars` wins.
/// Returns `None` if no variable qualifies.
fn select_most_fractional(x: &[f64], int_vars: &[usize], tol: f64) -> Option<usize> {
    int_vars
        .iter()
        .map(|&j| {
            let v = x[j];
            (j, (v - v.floor()).min(v.ceil() - v))
        })
        .filter(|&(_, frac)| frac > tol)
        .fold(None, |best: Option<(usize, f64)>, (j, frac)| match best {
            // Keep the incumbent unless the newcomer is strictly more fractional.
            Some((_, best_frac)) if best_frac >= frac => best,
            _ => Some((j, frac)),
        })
        .map(|(j, _)| j)
}
698
699fn select_first_fractional(x: &[f64], int_vars: &[usize], tol: f64) -> Option<usize> {
700    for &j in int_vars {
701        if !is_integer_valued_local(x[j], tol) {
702            return Some(j);
703        }
704    }
705    None
706}
707
708/// Strong branching: solve 2 mini-LPs for each candidate and pick the one
709/// with the best min(down_lb, up_lb).
710fn select_strong_branching(
711    x: &[f64],
712    int_vars: &[usize],
713    tol: f64,
714    config: &BnbConfig,
715    problem: &MilpProblem,
716    node_lb: &[f64],
717    node_ub: &[f64],
718) -> Option<usize> {
719    let fractional: Vec<usize> = int_vars
720        .iter()
721        .copied()
722        .filter(|&j| !is_integer_valued_local(x[j], tol))
723        .collect();
724    if fractional.is_empty() {
725        return None;
726    }
727
728    // Sort candidates by most-fractional order to limit probing
729    let mut candidates: Vec<(usize, f64)> = fractional
730        .iter()
731        .copied()
732        .map(|j| {
733            let frac = (x[j] - x[j].floor()).min(x[j].ceil() - x[j]);
734            (j, frac)
735        })
736        .collect();
737    candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
738    let k = candidates.len().min(config.strong_branching_candidates);
739    candidates.truncate(k);
740
741    let n = problem.n_vars();
742    let m = problem.n_constraints();
743    let a_rows: Vec<Vec<f64>> = (0..m)
744        .map(|i| (0..n).map(|j| problem.a[[i, j]]).collect())
745        .collect();
746    let b_vec: Vec<f64> = problem.b.to_vec();
747    let c_vec: Vec<f64> = problem.c.to_vec();
748
749    let mut best_var = None;
750    let mut best_score = f64::NEG_INFINITY;
751
752    for &(var, _) in &candidates {
753        let xi = x[var];
754        let xi_floor = xi.floor();
755        let xi_ceil = xi.ceil();
756
757        // Down branch: ub[var] = floor
758        let mut lb_d = node_lb.to_vec();
759        let mut ub_d = node_ub.to_vec();
760        ub_d[var] = ub_d[var].min(xi_floor);
761        let down = solve_lp(&c_vec, &a_rows, &b_vec, &lb_d, &ub_d, 500);
762
763        // Up branch: lb[var] = ceil
764        let mut lb_u = node_lb.to_vec();
765        let mut ub_u = node_ub.to_vec();
766        lb_u[var] = lb_u[var].max(xi_ceil);
767        let up = solve_lp(&c_vec, &a_rows, &b_vec, &lb_u, &ub_u, 500);
768
769        let down_bound = if down.status == LpStatus::Infeasible {
770            f64::INFINITY
771        } else {
772            down.obj
773        };
774        let up_bound = if up.status == LpStatus::Infeasible {
775            f64::INFINITY
776        } else {
777            up.obj
778        };
779
780        // Score: max of the two bounds (prefer the branch with higher minimum bound)
781        let score = down_bound.min(up_bound);
782        if score > best_score {
783            best_score = score;
784            best_var = Some(var);
785        }
786    }
787
788    best_var
789}
790
791// ─────────────────────────────────────────────────────────────────────────────
792// Main branch-and-bound entry point
793// ─────────────────────────────────────────────────────────────────────────────
794
795/// Solve a Mixed-Integer Linear Program via branch-and-bound.
796///
797/// # Arguments
798/// * `problem` – the [`MilpProblem`] to solve
799/// * `config`  – [`BnbConfig`] controlling solver behaviour
800///
801/// # Returns
802/// A [`MilpResult`] containing the solution, objective value, and statistics.
803///
804/// # Errors
805/// Returns [`OptimizeError::InvalidInput`] if `problem` has inconsistent dimensions.
806///
807/// # Example
808/// ```rust,no_run
809/// use scirs2_optimize::integer::milp_branch_and_bound::{
810///     MilpProblem, BnbConfig, branch_and_bound,
811/// };
812/// use scirs2_core::ndarray::{array, Array2};
813///
814/// // maximize 4x0 + 3x1 s.t. 2x0 + 3x1 <= 6, x in {0,1}^2
815/// // (as minimization: minimize -4x0 - 3x1)
816/// let c  = array![-4.0, -3.0];
817/// let a  = Array2::from_shape_vec((1, 2), vec![2.0, 3.0]).expect("valid input");
818/// let b  = array![6.0];
819/// let lb = array![0.0, 0.0];
820/// let ub = array![1.0, 1.0];
821/// let prob = MilpProblem::new(c, a, b, lb, ub, vec![0, 1]).expect("valid input");
822/// let cfg = BnbConfig::default();
823/// let res = branch_and_bound(&prob, &cfg).expect("valid input");
824/// assert!(res.success);
825/// ```
826pub fn branch_and_bound(problem: &MilpProblem, config: &BnbConfig) -> OptimizeResult<MilpResult> {
827    let n = problem.n_vars();
828    let m = problem.n_constraints();
829
830    if n == 0 {
831        return Err(OptimizeError::InvalidInput(
832            "Problem has no variables".to_string(),
833        ));
834    }
835
836    let start_time = std::time::Instant::now();
837
838    let c_vec: Vec<f64> = problem.c.to_vec();
839    let b_vec: Vec<f64> = problem.b.to_vec();
840    let a_rows: Vec<Vec<f64>> = (0..m)
841        .map(|i| (0..n).map(|j| problem.a[[i, j]]).collect())
842        .collect();
843
844    let base_lb: Vec<f64> = problem.lb.to_vec();
845    let base_ub: Vec<f64> = problem.ub.to_vec();
846
847    // Apply binary variable bounds (treat vars with ub=1 and lb=0 as binary)
848    let mut root_lb = base_lb.clone();
849    let mut root_ub = base_ub.clone();
850    for &j in &problem.integer_vars {
851        // Clamp to integer bounds
852        root_lb[j] = root_lb[j].ceil();
853        root_ub[j] = root_ub[j].floor();
854        if root_lb[j] > root_ub[j] {
855            // Infeasible by bounds alone
856            return Ok(MilpResult {
857                x: Array1::zeros(n),
858                obj: f64::INFINITY,
859                success: false,
860                message: format!("Variable {} has empty integer domain", j),
861                nodes_explored: 0,
862                lower_bound: f64::INFINITY,
863            });
864        }
865    }
866
867    // Solve root LP relaxation
868    let root_lp = solve_lp(&c_vec, &a_rows, &b_vec, &root_lb, &root_ub, 5000);
869
870    if root_lp.status == LpStatus::Infeasible {
871        return Ok(MilpResult {
872            x: Array1::zeros(n),
873            obj: f64::INFINITY,
874            success: false,
875            message: "Root LP relaxation is infeasible".to_string(),
876            nodes_explored: 1,
877            lower_bound: f64::INFINITY,
878        });
879    }
880    if root_lp.status == LpStatus::Unbounded {
881        return Ok(MilpResult {
882            x: Array1::zeros(n),
883            obj: f64::NEG_INFINITY,
884            success: false,
885            message: "Root LP relaxation is unbounded".to_string(),
886            nodes_explored: 1,
887            lower_bound: f64::NEG_INFINITY,
888        });
889    }
890
891    let mut incumbent: Option<Vec<f64>> = None;
892    let mut incumbent_obj = f64::INFINITY;
893    let mut nodes_explored = 1usize;
894    let mut global_lb = root_lp.obj;
895
896    // Check if root LP is already integer feasible
897    let root_x = root_lp.x;
898    let all_int = problem
899        .integer_vars
900        .iter()
901        .all(|&j| is_integer_valued_local(root_x[j], config.int_tol));
902    if all_int {
903        let obj = eval_obj_vec(&c_vec, &root_x);
904        return Ok(MilpResult {
905            x: Array1::from_vec(root_x),
906            obj,
907            success: true,
908            message: "Root LP relaxation is integer feasible".to_string(),
909            nodes_explored,
910            lower_bound: obj,
911        });
912    }
913
914    // Initialize priority queue (best-first by LP lower bound)
915    let root_node = BbNode {
916        lb: root_lb,
917        ub: root_ub,
918        lp_lb: root_lp.obj,
919        depth: 0,
920    };
921
922    let mut pq: BinaryHeap<PqEntry> = BinaryHeap::new();
923    pq.push(PqEntry {
924        neg_lb: -root_node.lp_lb,
925        node: root_node,
926    });
927
928    while let Some(PqEntry { node, .. }) = pq.pop() {
929        nodes_explored += 1;
930
931        // Check termination conditions
932        if nodes_explored > config.max_nodes {
933            break;
934        }
935        if config.time_limit_secs > 0.0 {
936            let elapsed = start_time.elapsed().as_secs_f64();
937            if elapsed >= config.time_limit_secs {
938                break;
939            }
940        }
941
942        // Prune: node lower bound >= incumbent
943        if node.lp_lb >= incumbent_obj - config.abs_gap {
944            continue;
945        }
946        if incumbent_obj.abs() > 1e-10 {
947            let gap = (incumbent_obj - node.lp_lb) / incumbent_obj.abs();
948            if gap <= config.rel_gap {
949                continue;
950            }
951        }
952
953        // Solve LP at this node
954        let lp = solve_lp(&c_vec, &a_rows, &b_vec, &node.lb, &node.ub, 3000);
955
956        if lp.status == LpStatus::Infeasible {
957            continue;
958        }
959        if lp.status == LpStatus::Unbounded {
960            continue;
961        }
962
963        let lp_obj = lp.obj;
964
965        // Update global lower bound
966        if lp_obj > global_lb {
967            global_lb = lp_obj;
968        }
969
970        // Prune: LP objective >= incumbent
971        if lp_obj >= incumbent_obj - config.abs_gap {
972            continue;
973        }
974
975        let lp_x = lp.x;
976
977        // Check integrality
978        let int_feasible = problem
979            .integer_vars
980            .iter()
981            .all(|&j| is_integer_valued_local(lp_x[j], config.int_tol));
982
983        if int_feasible {
984            let obj = eval_obj_vec(&c_vec, &lp_x);
985            if obj < incumbent_obj {
986                incumbent_obj = obj;
987                incumbent = Some(lp_x);
988                global_lb = global_lb.min(incumbent_obj);
989            }
990            continue;
991        }
992
993        // Select branching variable
994        let branch_var = match config.branching {
995            BranchingStrategy::MostFractional => {
996                select_most_fractional(&lp_x, &problem.integer_vars, config.int_tol)
997            }
998            BranchingStrategy::FirstFractional => {
999                select_first_fractional(&lp_x, &problem.integer_vars, config.int_tol)
1000            }
1001            BranchingStrategy::StrongBranching => select_strong_branching(
1002                &lp_x,
1003                &problem.integer_vars,
1004                config.int_tol,
1005                config,
1006                problem,
1007                &node.lb,
1008                &node.ub,
1009            ),
1010        };
1011
1012        let branch_var = match branch_var {
1013            Some(v) => v,
1014            None => {
1015                // All integer variables are integral (shouldn't happen but handle it)
1016                let obj = eval_obj_vec(&c_vec, &lp_x);
1017                if obj < incumbent_obj {
1018                    incumbent_obj = obj;
1019                    incumbent = Some(lp_x);
1020                }
1021                continue;
1022            }
1023        };
1024
1025        let xi = lp_x[branch_var];
1026        let xi_floor = xi.floor();
1027        let xi_ceil = xi.ceil();
1028
1029        // Down branch: x[branch_var] <= floor(xi)
1030        {
1031            let mut lb_d = node.lb.clone();
1032            let mut ub_d = node.ub.clone();
1033            ub_d[branch_var] = ub_d[branch_var].min(xi_floor);
1034            if lb_d[branch_var] <= ub_d[branch_var] + 1e-10 {
1035                pq.push(PqEntry {
1036                    neg_lb: -lp_obj,
1037                    node: BbNode {
1038                        lb: lb_d,
1039                        ub: ub_d,
1040                        lp_lb: lp_obj,
1041                        depth: node.depth + 1,
1042                    },
1043                });
1044            }
1045        }
1046
1047        // Up branch: x[branch_var] >= ceil(xi)
1048        {
1049            let mut lb_u = node.lb.clone();
1050            let mut ub_u = node.ub.clone();
1051            lb_u[branch_var] = lb_u[branch_var].max(xi_ceil);
1052            if lb_u[branch_var] <= ub_u[branch_var] + 1e-10 {
1053                pq.push(PqEntry {
1054                    neg_lb: -lp_obj,
1055                    node: BbNode {
1056                        lb: lb_u,
1057                        ub: ub_u,
1058                        lp_lb: lp_obj,
1059                        depth: node.depth + 1,
1060                    },
1061                });
1062            }
1063        }
1064    }
1065
1066    // Update global lb from remaining queue
1067    if let Some(PqEntry { node, .. }) = pq.peek() {
1068        if node.lp_lb > global_lb {
1069            global_lb = node.lp_lb;
1070        }
1071    }
1072
1073    match incumbent {
1074        Some(x) => Ok(MilpResult {
1075            x: Array1::from_vec(x),
1076            obj: incumbent_obj,
1077            success: true,
1078            message: format!(
1079                "Optimal solution found (nodes={}, gap={:.2e})",
1080                nodes_explored,
1081                (incumbent_obj - global_lb).abs()
1082            ),
1083            nodes_explored,
1084            lower_bound: global_lb,
1085        }),
1086        None => Ok(MilpResult {
1087            x: Array1::zeros(n),
1088            obj: f64::INFINITY,
1089            success: false,
1090            message: format!(
1091                "No integer feasible solution found in {} nodes",
1092                nodes_explored
1093            ),
1094            nodes_explored,
1095            lower_bound: global_lb,
1096        }),
1097    }
1098}
1099
1100// ─────────────────────────────────────────────────────────────────────────────
1101// Tests
1102// ─────────────────────────────────────────────────────────────────────────────
1103
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// Absolute slack used when checking that returned solutions satisfy
    /// `A x <= b` and integrality. It is looser than simplex round-off, so the
    /// checks only fire on genuine violations.
    const SOL_TOL: f64 = 1e-5;

    // Helper: create a simple binary knapsack MILP
    // maximize sum(v*x) s.t. sum(w*x) <= cap, x in {0,1}
    // (encoded as minimizing -sum(v*x), so the optimal `obj` is the negated value)
    fn make_knapsack(values: &[f64], weights: &[f64], cap: f64) -> MilpProblem {
        let n = values.len();
        let c = Array1::from_vec(values.iter().map(|&v| -v).collect());
        let a = Array2::from_shape_vec((1, n), weights.to_vec()).expect("shape");
        let b = array![cap];
        let lb = Array1::zeros(n);
        let ub = Array1::ones(n);
        let int_vars = (0..n).collect();
        MilpProblem::new(c, a, b, lb, ub, int_vars).expect("valid problem")
    }

    /// Asserts that `x` satisfies every row of `A x <= b` (within `SOL_TOL`).
    /// It also asserts that each variable listed in `prob.integer_vars` is
    /// integral, within the larger of `SOL_TOL` and the solver's `int_tol`.
    ///
    /// The objective assertions alone cannot detect a "correct" objective value
    /// reached by an infeasible or fractional point; this helper closes that gap.
    fn assert_feasible(prob: &MilpProblem, x: &Array1<f64>, int_tol: f64) {
        let ax = prob.a.dot(x);
        for (i, (&lhs, &rhs)) in ax.iter().zip(prob.b.iter()).enumerate() {
            assert!(
                lhs <= rhs + SOL_TOL,
                "constraint {} violated: (A x)_i = {} > b_i = {}",
                i,
                lhs,
                rhs
            );
        }
        let tol = int_tol.max(SOL_TOL);
        for &j in prob.integer_vars.iter() {
            let xj = x[j];
            assert!(
                (xj - xj.round()).abs() <= tol,
                "integer variable x[{}] = {} is not integral",
                j,
                xj
            );
        }
    }

    #[test]
    fn test_milp_binary_knapsack() {
        let values = vec![4.0, 3.0, 5.0, 2.0, 6.0];
        let weights = vec![2.0, 3.0, 4.0, 1.0, 5.0];
        let prob = make_knapsack(&values, &weights, 8.0);
        let cfg = BnbConfig::default();
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(res.success, "B&B should find solution");
        // Unique optimum: items 0, 3, 4 -> value 4 + 2 + 6 = 12, weight 2 + 1 + 5 = 8.
        assert_abs_diff_eq!(res.obj, -12.0, epsilon = 1e-4);
        assert_feasible(&prob, &res.x, cfg.int_tol);
    }

    #[test]
    fn test_milp_pure_integer() {
        // min x + y, x+y >= 3.5, x,y >= 0 integer
        // optimal: x+y = 4, obj = 4
        let c = array![1.0, 1.0];
        let a = Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("failed to create a");
        let b = array![-3.5];
        let lb = array![0.0, 0.0];
        let ub = array![10.0, 10.0];
        let prob = MilpProblem::new(c, a, b, lb, ub, vec![0, 1]).expect("failed to create prob");
        let cfg = BnbConfig::default();
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(res.success);
        assert_abs_diff_eq!(res.obj, 4.0, epsilon = 1e-4);
        assert_feasible(&prob, &res.x, cfg.int_tol);
    }

    #[test]
    fn test_milp_lp_optimal_already_integer() {
        // LP relaxation already gives integer solution
        let c = array![1.0, 2.0];
        let a = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).expect("failed to create a");
        let b = array![10.0];
        let lb = array![2.0, 3.0];
        let ub = array![5.0, 6.0];
        let prob = MilpProblem::new(c, a, b, lb, ub, vec![0, 1]).expect("failed to create prob");
        let cfg = BnbConfig::default();
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(res.success);
        // Both costs are positive, so the unique minimum sits at the lower bounds:
        // x = (2, 3) -> obj = 2 + 6 = 8.
        assert_abs_diff_eq!(res.obj, 8.0, epsilon = 1e-3);
        assert_abs_diff_eq!(res.x[0], 2.0, epsilon = 1e-4);
        assert_abs_diff_eq!(res.x[1], 3.0, epsilon = 1e-4);
    }

    #[test]
    fn test_milp_strong_branching() {
        let values = vec![6.0, 5.0, 4.0, 3.0];
        let weights = vec![3.0, 3.0, 2.0, 1.0];
        let prob = make_knapsack(&values, &weights, 6.0);
        let cfg = BnbConfig {
            branching: BranchingStrategy::StrongBranching,
            ..Default::default()
        };
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(res.success);
        // Unique optimum: items 0, 2, 3 -> value 6 + 4 + 3 = 13, weight 3 + 2 + 1 = 6.
        // (Items 0, 1, 3 would give 14 but weigh 7 > 6; items 1, 2, 3 give only 12.)
        assert_abs_diff_eq!(res.obj, -13.0, epsilon = 1e-4);
        assert_feasible(&prob, &res.x, cfg.int_tol);
    }

    #[test]
    fn test_milp_first_fractional() {
        // Same instance as `test_milp_binary_knapsack`, exercising the
        // first-fractional variable selection rule instead of the default.
        let values = vec![4.0, 3.0, 5.0, 2.0, 6.0];
        let weights = vec![2.0, 3.0, 4.0, 1.0, 5.0];
        let prob = make_knapsack(&values, &weights, 8.0);
        let cfg = BnbConfig {
            branching: BranchingStrategy::FirstFractional,
            ..Default::default()
        };
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(res.success);
        assert_abs_diff_eq!(res.obj, -12.0, epsilon = 1e-4);
        assert_feasible(&prob, &res.x, cfg.int_tol);
    }

    #[test]
    fn test_milp_infeasible() {
        // x in {0,1}, x >= 2 -> infeasible
        let c = array![1.0];
        let a = Array2::from_shape_vec((1, 1), vec![-1.0]).expect("failed to create a");
        let b = array![-2.0];
        let lb = array![0.0];
        let ub = array![1.0];
        let prob = MilpProblem::new(c, a, b, lb, ub, vec![0]).expect("failed to create prob");
        let cfg = BnbConfig::default();
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(!res.success, "should be infeasible");
    }

    #[test]
    fn test_milp_mixed_integer() {
        // min 2x + y; x integer, y continuous; x+y >= 2.5; x,y >= 0
        // y is the cheaper variable, so the unique optimum is x = 0, y = 2.5 -> obj = 2.5
        let c = array![2.0, 1.0];
        let a = Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("failed to create a");
        let b = array![-2.5];
        let lb = array![0.0, 0.0];
        let ub = array![10.0, 10.0];
        let prob = MilpProblem::new(c, a, b, lb, ub, vec![0]).expect("failed to create prob");
        let cfg = BnbConfig::default();
        let res = branch_and_bound(&prob, &cfg).expect("branch_and_bound should not error");
        assert!(res.success);
        assert_abs_diff_eq!(res.obj, 2.5, epsilon = 1e-4);
        assert_feasible(&prob, &res.x, cfg.int_tol);
    }

    #[test]
    fn test_bnb_config_default() {
        let cfg = BnbConfig::default();
        assert_eq!(cfg.max_nodes, 50_000);
        assert_eq!(cfg.branching, BranchingStrategy::MostFractional);
    }

    #[test]
    fn test_milp_problem_new_error_dim() {
        let c = array![1.0, 2.0];
        let a = Array2::from_shape_vec((1, 3), vec![1.0, 1.0, 1.0]).expect("failed to create a"); // 3 cols != 2
        let b = array![5.0];
        let lb = array![0.0, 0.0];
        let ub = array![1.0, 1.0];
        let res = MilpProblem::new(c, a, b, lb, ub, vec![]);
        assert!(res.is_err());
    }
}