Skip to main content

scirs2_optimize/integer/
branch_and_bound.rs

1//! Branch-and-Bound algorithm for Mixed-Integer Programming
2//!
3//! The branch-and-bound algorithm systematically explores the space of integer
4//! solutions by:
5//! 1. Solving LP relaxations at each node
6//! 2. Branching on a fractional integer variable
7//! 3. Using bounds to prune subproblems that cannot improve on the incumbent
8//!
9//! # Algorithm
10//! - Start with the LP relaxation of the full problem
11//! - Maintain an incumbent (best integer solution found)
12//! - Branch by adding floor/ceil constraints on a fractional variable
13//! - Prune nodes where LP lower bound >= incumbent objective
14//!
15//! # References
16//! - Land, A.H. & Doig, A.G. (1960). "An automatic method of solving discrete
17//!   programming problems." Econometrica, 28(3), 497-520.
18
19use super::{
20    is_integer_valued,
21    lp_relaxation::{LpRelaxationSolver, LpResult},
22    most_fractional_variable, IntegerKind, IntegerVariableSet, LinearProgram, MipResult,
23};
24use crate::error::{OptimizeError, OptimizeResult};
25use scirs2_core::ndarray::Array1;
26use std::collections::VecDeque;
27
28/// Options for branch-and-bound
29#[derive(Debug, Clone)]
30pub struct BranchAndBoundOptions {
31    /// Maximum number of nodes to explore
32    pub max_nodes: usize,
33    /// Absolute tolerance for integrality
34    pub int_tol: f64,
35    /// Absolute tolerance for optimality gap
36    pub opt_tol: f64,
37    /// Maximum LP iterations per node
38    pub max_lp_iter: usize,
39    /// Node selection strategy
40    pub node_selection: NodeSelection,
41    /// Variable selection strategy
42    pub variable_selection: VariableSelection,
43}
44
45impl Default for BranchAndBoundOptions {
46    fn default() -> Self {
47        BranchAndBoundOptions {
48            max_nodes: 10000,
49            int_tol: 1e-6,
50            opt_tol: 1e-6,
51            max_lp_iter: 1000,
52            node_selection: NodeSelection::BestFirst,
53            variable_selection: VariableSelection::MostFractional,
54        }
55    }
56}
57
/// How the next open node of the branch-and-bound tree is chosen.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NodeSelection {
    /// Process the open node whose stored lower bound is smallest.
    BestFirst,
    /// Process the most recently created node first (depth-first search).
    DepthFirst,
    /// Depth-first search combined with best-bound pruning.
    ///
    /// The solver in this module currently handles this variant exactly like
    /// [`NodeSelection::DepthFirst`].
    BestOfDepth,
}
68
/// How the branching variable is chosen at a node with a fractional solution.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum VariableSelection {
    /// Branch on the integer variable picked by the most-fractional rule.
    MostFractional,
    /// Branch on the lowest-index integer variable whose value is fractional.
    FirstFractional,
    /// Branch on the variable with the largest LP value range.
    ///
    /// The solver in this module currently applies the most-fractional rule
    /// for this variant as well.
    MaxRange,
}
79
/// One open subproblem of the branch-and-bound tree.
#[derive(Clone, Debug)]
struct BbNode {
    /// Per-variable lower bounds for this subproblem: the problem's base
    /// bounds tightened by every "branch up" decision on the path from the root.
    extra_lb: Vec<f64>,
    /// Per-variable upper bounds for this subproblem: the problem's base
    /// bounds tightened by every "branch down" decision on the path from the root.
    extra_ub: Vec<f64>,
    /// Lower bound on this subproblem's objective (the parent's LP objective);
    /// used for node ordering and pruning.
    lb: f64,
    /// Number of branching decisions between the root and this node.
    depth: usize,
}
92
93/// Branch-and-bound solver for mixed-integer programming
94pub struct BranchAndBoundSolver {
95    pub options: BranchAndBoundOptions,
96}
97
98impl BranchAndBoundSolver {
99    /// Create with default options
100    pub fn new() -> Self {
101        BranchAndBoundSolver {
102            options: BranchAndBoundOptions::default(),
103        }
104    }
105
106    /// Create with custom options
107    pub fn with_options(options: BranchAndBoundOptions) -> Self {
108        BranchAndBoundSolver { options }
109    }
110
111    /// Select branching variable
112    fn select_variable(&self, x: &[f64], ivs: &IntegerVariableSet) -> Option<usize> {
113        match self.options.variable_selection {
114            VariableSelection::MostFractional => most_fractional_variable(x, ivs),
115            VariableSelection::FirstFractional => {
116                for (i, &xi) in x.iter().enumerate() {
117                    if ivs.is_integer(i) && !is_integer_valued(xi, self.options.int_tol) {
118                        return Some(i);
119                    }
120                }
121                None
122            }
123            VariableSelection::MaxRange => {
124                // Select the integer variable with largest fractional part
125                most_fractional_variable(x, ivs)
126            }
127        }
128    }
129
130    /// Check if an LP solution satisfies integrality
131    fn is_integer_feasible(&self, x: &[f64], ivs: &IntegerVariableSet) -> bool {
132        for (i, &xi) in x.iter().enumerate() {
133            if ivs.is_integer(i) && !is_integer_valued(xi, self.options.int_tol) {
134                return false;
135            }
136        }
137        true
138    }
139
140    /// Round integer variables in LP solution and verify feasibility
141    fn round_integer_solution(&self, x: &[f64], ivs: &IntegerVariableSet) -> Vec<f64> {
142        x.iter()
143            .enumerate()
144            .map(|(i, &xi)| if ivs.is_integer(i) { xi.round() } else { xi })
145            .collect()
146    }
147
148    /// Evaluate objective value
149    fn eval_obj(lp: &LinearProgram, x: &[f64]) -> f64 {
150        lp.c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum()
151    }
152
153    /// Solve the MIP problem
154    pub fn solve(&self, lp: &LinearProgram, ivs: &IntegerVariableSet) -> OptimizeResult<MipResult> {
155        let n = lp.n_vars();
156        if n == 0 {
157            return Err(OptimizeError::InvalidInput("Empty problem".to_string()));
158        }
159        if ivs.len() != n {
160            return Err(OptimizeError::InvalidInput(format!(
161                "IntegerVariableSet length {} != LP dimension {}",
162                ivs.len(),
163                n
164            )));
165        }
166
167        let base_lb: Vec<f64> = lp.lower.as_ref().map_or(vec![0.0; n], |l| l.to_vec());
168        let base_ub: Vec<f64> = lp
169            .upper
170            .as_ref()
171            .map_or(vec![f64::INFINITY; n], |u| u.to_vec());
172
173        // Apply binary variable bounds
174        let mut base_lb = base_lb;
175        let mut base_ub = base_ub;
176        for i in 0..n {
177            if ivs.kinds[i] == IntegerKind::Binary {
178                base_lb[i] = base_lb[i].max(0.0);
179                base_ub[i] = base_ub[i].min(1.0);
180            }
181        }
182
183        // Solve root LP relaxation
184        let root_result = LpRelaxationSolver::solve(lp, &base_lb, &base_ub)?;
185
186        if !root_result.success {
187            return Ok(MipResult {
188                x: Array1::zeros(n),
189                fun: f64::INFINITY,
190                success: false,
191                message: "LP relaxation infeasible".to_string(),
192                nodes_explored: 1,
193                lp_solves: 1,
194                lower_bound: f64::INFINITY,
195            });
196        }
197
198        let mut incumbent: Option<Vec<f64>> = None;
199        let mut incumbent_obj = f64::INFINITY;
200        let mut nodes_explored = 1usize;
201        let mut lp_solves = 1usize;
202        let mut global_lb = root_result.fun;
203
204        // Check if root LP solution is already integer feasible
205        let root_x: Vec<f64> = root_result.x.to_vec();
206        if self.is_integer_feasible(&root_x, ivs) {
207            let obj = Self::eval_obj(lp, &root_x);
208            incumbent = Some(root_x.clone());
209            incumbent_obj = obj;
210            return Ok(MipResult {
211                x: Array1::from_vec(root_x),
212                fun: obj,
213                success: true,
214                message: "LP relaxation gives integer solution".to_string(),
215                nodes_explored,
216                lp_solves,
217                lower_bound: global_lb,
218            });
219        }
220
221        // Initialize node queue
222        let root_node = BbNode {
223            extra_lb: base_lb.clone(),
224            extra_ub: base_ub.clone(),
225            lb: root_result.fun,
226            depth: 0,
227        };
228
229        let mut queue: VecDeque<BbNode> = VecDeque::new();
230        queue.push_back(root_node);
231
232        while !queue.is_empty() && nodes_explored < self.options.max_nodes {
233            // Select node
234            let node = match self.options.node_selection {
235                NodeSelection::BestFirst => {
236                    // Find node with lowest lb
237                    let best_pos = queue
238                        .iter()
239                        .enumerate()
240                        .min_by(|(_, a), (_, b)| {
241                            a.lb.partial_cmp(&b.lb).unwrap_or(std::cmp::Ordering::Equal)
242                        })
243                        .map(|(i, _)| i)
244                        .unwrap_or(0);
245                    match queue.remove(best_pos) {
246                        Some(n) => n,
247                        None => match queue.pop_back() {
248                            Some(n) => n,
249                            None => break,
250                        },
251                    }
252                }
253                NodeSelection::DepthFirst | NodeSelection::BestOfDepth => {
254                    // DFS: pop from back
255                    match queue.pop_back() {
256                        Some(n) => n,
257                        None => break,
258                    }
259                }
260            };
261
262            nodes_explored += 1;
263
264            // Prune: node lower bound >= incumbent
265            if node.lb >= incumbent_obj - self.options.opt_tol {
266                continue;
267            }
268
269            // Solve LP at this node
270            let lp_result = match LpRelaxationSolver::solve(lp, &node.extra_lb, &node.extra_ub) {
271                Ok(r) => r,
272                Err(_) => continue,
273            };
274            lp_solves += 1;
275
276            if !lp_result.success {
277                // Node is infeasible
278                continue;
279            }
280
281            let node_obj = lp_result.fun;
282
283            // Update global lower bound
284            if node_obj > global_lb {
285                global_lb = node_obj;
286            }
287
288            // Prune: LP objective >= incumbent
289            if node_obj >= incumbent_obj - self.options.opt_tol {
290                continue;
291            }
292
293            let node_x: Vec<f64> = lp_result.x.to_vec();
294
295            // Check integrality
296            if self.is_integer_feasible(&node_x, ivs) {
297                let obj = Self::eval_obj(lp, &node_x);
298                if obj < incumbent_obj {
299                    incumbent_obj = obj;
300                    incumbent = Some(node_x);
301                    // Update global lb
302                    if queue.is_empty() {
303                        global_lb = incumbent_obj;
304                    }
305                }
306                continue;
307            }
308
309            // Select branching variable
310            let branch_var = match self.select_variable(&node_x, ivs) {
311                Some(v) => v,
312                None => {
313                    // All integer vars are integer (shouldn't happen, but handle gracefully)
314                    let rounded = self.round_integer_solution(&node_x, ivs);
315                    let obj = Self::eval_obj(lp, &rounded);
316                    if obj < incumbent_obj {
317                        incumbent_obj = obj;
318                        incumbent = Some(rounded);
319                    }
320                    continue;
321                }
322            };
323
324            let xi = node_x[branch_var];
325            let xi_floor = xi.floor();
326            let xi_ceil = xi.ceil();
327
328            // Branch down: x[branch_var] <= floor(xi)
329            let mut lb_down = node.extra_lb.clone();
330            let mut ub_down = node.extra_ub.clone();
331            ub_down[branch_var] = ub_down[branch_var].min(xi_floor);
332            if lb_down[branch_var] <= ub_down[branch_var] {
333                queue.push_back(BbNode {
334                    extra_lb: lb_down,
335                    extra_ub: ub_down,
336                    lb: node_obj,
337                    depth: node.depth + 1,
338                });
339            }
340
341            // Branch up: x[branch_var] >= ceil(xi)
342            let mut lb_up = node.extra_lb.clone();
343            let mut ub_up = node.extra_ub.clone();
344            lb_up[branch_var] = lb_up[branch_var].max(xi_ceil);
345            if lb_up[branch_var] <= ub_up[branch_var] {
346                queue.push_back(BbNode {
347                    extra_lb: lb_up,
348                    extra_ub: ub_up,
349                    lb: node_obj,
350                    depth: node.depth + 1,
351                });
352            }
353        }
354
355        match incumbent {
356            Some(x) => Ok(MipResult {
357                x: Array1::from_vec(x),
358                fun: incumbent_obj,
359                success: true,
360                message: format!(
361                    "Optimal solution found (nodes={}, lp_solves={})",
362                    nodes_explored, lp_solves
363                ),
364                nodes_explored,
365                lp_solves,
366                lower_bound: global_lb,
367            }),
368            None => Ok(MipResult {
369                x: Array1::zeros(n),
370                fun: f64::INFINITY,
371                success: false,
372                message: "No integer feasible solution found".to_string(),
373                nodes_explored,
374                lp_solves,
375                lower_bound: global_lb,
376            }),
377        }
378    }
379}
380
381impl Default for BranchAndBoundSolver {
382    fn default() -> Self {
383        BranchAndBoundSolver::new()
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// Binary knapsack: maximize value subject to a weight limit, written as
    /// minimize -v^T x  s.t.  w^T x <= capacity,  x in {0,1}^n.
    #[test]
    fn test_branch_and_bound_binary_knapsack() {
        // Items as (value, weight): (4, 2), (3, 3), (5, 4), (2, 1), (6, 5); capacity 8.
        // Best packing is items 0, 3 and 4: weight 2 + 1 + 5 = 8, value 4 + 2 + 6 = 12.
        // (Items 0, 2 and 3 weigh only 7 but are worth 11.)
        let weights = vec![2.0, 3.0, 4.0, 1.0, 5.0];
        let capacity = 8.0;
        let n = weights.len();

        // Objective: the negated item values.
        let c = array![-4.0, -3.0, -5.0, -2.0, -6.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, n), weights).expect("shape"));
        lp.b_ub = Some(array![capacity]);
        lp.lower = Some(array![0.0, 0.0, 0.0, 0.0, 0.0]);
        lp.upper = Some(array![1.0, 1.0, 1.0, 1.0, 1.0]);

        let ivs = IntegerVariableSet::all_binary(n);
        let solver = BranchAndBoundSolver::new();
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success, "B&B should find solution");
        // Pin the optimum from both sides: a value below -12 would mean an
        // infeasible packing was accepted.
        assert_abs_diff_eq!(result.fun, -12.0, epsilon = 1e-4);
    }

    #[test]
    fn test_branch_and_bound_pure_integer() {
        // minimize x[0] + x[1]
        // subject to x[0] + x[1] >= 3.5, 0 <= x <= 10, x integer
        // (so any integer solution has x[0] + x[1] >= 4)
        let c = array![1.0, 1.0];
        let mut lp = LinearProgram::new(c);
        // -x[0] - x[1] <= -3.5  (i.e. x[0] + x[1] >= 3.5)
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-3.5]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let ivs = IntegerVariableSet::all_integer(2);
        let solver = BranchAndBoundSolver::new();
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        // Integer optimum: x[0] + x[1] = 4.
        assert_abs_diff_eq!(result.fun, 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_branch_and_bound_mixed_integer() {
        // minimize 2 x[0] + x[1]
        // subject to x[0] + x[1] >= 2.5, 0 <= x <= 10
        // x[0] integer, x[1] continuous
        let c = array![2.0, 1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-2.5]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let mut ivs = IntegerVariableSet::new(2);
        ivs.set_kind(0, IntegerKind::Integer);

        let solver = BranchAndBoundSolver::new();
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        // x[1] is the cheaper variable, so the optimum is x[0] = 0, x[1] = 2.5
        // with objective 2.5 (x[0] = 1, x[1] = 1.5 would cost 3.5).
        assert_abs_diff_eq!(result.fun, 2.5, epsilon = 1e-4);
    }

    #[test]
    fn test_branch_and_bound_depth_first() {
        let opts = BranchAndBoundOptions {
            node_selection: NodeSelection::DepthFirst,
            ..Default::default()
        };
        let solver = BranchAndBoundSolver::with_options(opts);

        // minimize x[0] + x[1]  s.t.  x[0] + x[1] >= 2.7, 0 <= x <= 10, x integer
        let c = array![1.0, 1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-2.7]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let ivs = IntegerVariableSet::all_integer(2);
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        // Integer optimum: x[0] + x[1] = 3.
        assert_abs_diff_eq!(result.fun, 3.0, epsilon = 1e-4);
    }
}