1use super::{
20 is_integer_valued,
21 lp_relaxation::{LpRelaxationSolver, LpResult},
22 most_fractional_variable, IntegerKind, IntegerVariableSet, LinearProgram, MipResult,
23};
24use crate::error::{OptimizeError, OptimizeResult};
25use scirs2_core::ndarray::Array1;
26use std::collections::VecDeque;
27
28#[derive(Debug, Clone)]
30pub struct BranchAndBoundOptions {
31 pub max_nodes: usize,
33 pub int_tol: f64,
35 pub opt_tol: f64,
37 pub max_lp_iter: usize,
39 pub node_selection: NodeSelection,
41 pub variable_selection: VariableSelection,
43}
44
45impl Default for BranchAndBoundOptions {
46 fn default() -> Self {
47 BranchAndBoundOptions {
48 max_nodes: 10000,
49 int_tol: 1e-6,
50 opt_tol: 1e-6,
51 max_lp_iter: 1000,
52 node_selection: NodeSelection::BestFirst,
53 variable_selection: VariableSelection::MostFractional,
54 }
55 }
56}
57
/// Strategy used by [`BranchAndBoundSolver`] to pick the next open node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeSelection {
    /// Expand the open node with the smallest inherited objective bound.
    BestFirst,
    /// Expand the most recently created node (last in, first out).
    DepthFirst,
    /// The name suggests a depth-first/best-first hybrid, but the solver
    /// currently treats this variant exactly like [`NodeSelection::DepthFirst`].
    BestOfDepth,
}
68
/// Strategy used by [`BranchAndBoundSolver`] to pick the branching variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariableSelection {
    /// Branch on the variable reported by `most_fractional_variable`.
    MostFractional,
    /// Branch on the lowest-index integer variable that is not integral
    /// within `BranchAndBoundOptions::int_tol`.
    FirstFractional,
    /// Not implemented as a separate rule: the solver currently falls back to
    /// [`VariableSelection::MostFractional`].
    MaxRange,
}
79
/// One pending subproblem of the branch-and-bound tree: the variable bounds
/// that define it plus the objective bound inherited from its parent.
#[derive(Debug, Clone)]
struct BbNode {
    /// Per-variable lower bounds: the base LP bounds tightened by every
    /// branching decision on the path from the root.
    extra_lb: Vec<f64>,
    /// Per-variable upper bounds, tightened the same way as `extra_lb`.
    extra_ub: Vec<f64>,
    /// Objective value of the parent's LP relaxation; used for best-first
    /// ordering and for pruning against the incumbent.
    lb: f64,
    /// Number of branching steps separating this node from the root.
    depth: usize,
}
92
93pub struct BranchAndBoundSolver {
95 pub options: BranchAndBoundOptions,
96}
97
98impl BranchAndBoundSolver {
99 pub fn new() -> Self {
101 BranchAndBoundSolver {
102 options: BranchAndBoundOptions::default(),
103 }
104 }
105
106 pub fn with_options(options: BranchAndBoundOptions) -> Self {
108 BranchAndBoundSolver { options }
109 }
110
111 fn select_variable(&self, x: &[f64], ivs: &IntegerVariableSet) -> Option<usize> {
113 match self.options.variable_selection {
114 VariableSelection::MostFractional => most_fractional_variable(x, ivs),
115 VariableSelection::FirstFractional => {
116 for (i, &xi) in x.iter().enumerate() {
117 if ivs.is_integer(i) && !is_integer_valued(xi, self.options.int_tol) {
118 return Some(i);
119 }
120 }
121 None
122 }
123 VariableSelection::MaxRange => {
124 most_fractional_variable(x, ivs)
126 }
127 }
128 }
129
130 fn is_integer_feasible(&self, x: &[f64], ivs: &IntegerVariableSet) -> bool {
132 for (i, &xi) in x.iter().enumerate() {
133 if ivs.is_integer(i) && !is_integer_valued(xi, self.options.int_tol) {
134 return false;
135 }
136 }
137 true
138 }
139
140 fn round_integer_solution(&self, x: &[f64], ivs: &IntegerVariableSet) -> Vec<f64> {
142 x.iter()
143 .enumerate()
144 .map(|(i, &xi)| if ivs.is_integer(i) { xi.round() } else { xi })
145 .collect()
146 }
147
148 fn eval_obj(lp: &LinearProgram, x: &[f64]) -> f64 {
150 lp.c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum()
151 }
152
153 pub fn solve(&self, lp: &LinearProgram, ivs: &IntegerVariableSet) -> OptimizeResult<MipResult> {
155 let n = lp.n_vars();
156 if n == 0 {
157 return Err(OptimizeError::InvalidInput("Empty problem".to_string()));
158 }
159 if ivs.len() != n {
160 return Err(OptimizeError::InvalidInput(format!(
161 "IntegerVariableSet length {} != LP dimension {}",
162 ivs.len(),
163 n
164 )));
165 }
166
167 let base_lb: Vec<f64> = lp.lower.as_ref().map_or(vec![0.0; n], |l| l.to_vec());
168 let base_ub: Vec<f64> = lp
169 .upper
170 .as_ref()
171 .map_or(vec![f64::INFINITY; n], |u| u.to_vec());
172
173 let mut base_lb = base_lb;
175 let mut base_ub = base_ub;
176 for i in 0..n {
177 if ivs.kinds[i] == IntegerKind::Binary {
178 base_lb[i] = base_lb[i].max(0.0);
179 base_ub[i] = base_ub[i].min(1.0);
180 }
181 }
182
183 let root_result = LpRelaxationSolver::solve(lp, &base_lb, &base_ub)?;
185
186 if !root_result.success {
187 return Ok(MipResult {
188 x: Array1::zeros(n),
189 fun: f64::INFINITY,
190 success: false,
191 message: "LP relaxation infeasible".to_string(),
192 nodes_explored: 1,
193 lp_solves: 1,
194 lower_bound: f64::INFINITY,
195 });
196 }
197
198 let mut incumbent: Option<Vec<f64>> = None;
199 let mut incumbent_obj = f64::INFINITY;
200 let mut nodes_explored = 1usize;
201 let mut lp_solves = 1usize;
202 let mut global_lb = root_result.fun;
203
204 let root_x: Vec<f64> = root_result.x.to_vec();
206 if self.is_integer_feasible(&root_x, ivs) {
207 let obj = Self::eval_obj(lp, &root_x);
208 incumbent = Some(root_x.clone());
209 incumbent_obj = obj;
210 return Ok(MipResult {
211 x: Array1::from_vec(root_x),
212 fun: obj,
213 success: true,
214 message: "LP relaxation gives integer solution".to_string(),
215 nodes_explored,
216 lp_solves,
217 lower_bound: global_lb,
218 });
219 }
220
221 let root_node = BbNode {
223 extra_lb: base_lb.clone(),
224 extra_ub: base_ub.clone(),
225 lb: root_result.fun,
226 depth: 0,
227 };
228
229 let mut queue: VecDeque<BbNode> = VecDeque::new();
230 queue.push_back(root_node);
231
232 while !queue.is_empty() && nodes_explored < self.options.max_nodes {
233 let node = match self.options.node_selection {
235 NodeSelection::BestFirst => {
236 let best_pos = queue
238 .iter()
239 .enumerate()
240 .min_by(|(_, a), (_, b)| {
241 a.lb.partial_cmp(&b.lb).unwrap_or(std::cmp::Ordering::Equal)
242 })
243 .map(|(i, _)| i)
244 .unwrap_or(0);
245 match queue.remove(best_pos) {
246 Some(n) => n,
247 None => match queue.pop_back() {
248 Some(n) => n,
249 None => break,
250 },
251 }
252 }
253 NodeSelection::DepthFirst | NodeSelection::BestOfDepth => {
254 match queue.pop_back() {
256 Some(n) => n,
257 None => break,
258 }
259 }
260 };
261
262 nodes_explored += 1;
263
264 if node.lb >= incumbent_obj - self.options.opt_tol {
266 continue;
267 }
268
269 let lp_result = match LpRelaxationSolver::solve(lp, &node.extra_lb, &node.extra_ub) {
271 Ok(r) => r,
272 Err(_) => continue,
273 };
274 lp_solves += 1;
275
276 if !lp_result.success {
277 continue;
279 }
280
281 let node_obj = lp_result.fun;
282
283 if node_obj > global_lb {
285 global_lb = node_obj;
286 }
287
288 if node_obj >= incumbent_obj - self.options.opt_tol {
290 continue;
291 }
292
293 let node_x: Vec<f64> = lp_result.x.to_vec();
294
295 if self.is_integer_feasible(&node_x, ivs) {
297 let obj = Self::eval_obj(lp, &node_x);
298 if obj < incumbent_obj {
299 incumbent_obj = obj;
300 incumbent = Some(node_x);
301 if queue.is_empty() {
303 global_lb = incumbent_obj;
304 }
305 }
306 continue;
307 }
308
309 let branch_var = match self.select_variable(&node_x, ivs) {
311 Some(v) => v,
312 None => {
313 let rounded = self.round_integer_solution(&node_x, ivs);
315 let obj = Self::eval_obj(lp, &rounded);
316 if obj < incumbent_obj {
317 incumbent_obj = obj;
318 incumbent = Some(rounded);
319 }
320 continue;
321 }
322 };
323
324 let xi = node_x[branch_var];
325 let xi_floor = xi.floor();
326 let xi_ceil = xi.ceil();
327
328 let mut lb_down = node.extra_lb.clone();
330 let mut ub_down = node.extra_ub.clone();
331 ub_down[branch_var] = ub_down[branch_var].min(xi_floor);
332 if lb_down[branch_var] <= ub_down[branch_var] {
333 queue.push_back(BbNode {
334 extra_lb: lb_down,
335 extra_ub: ub_down,
336 lb: node_obj,
337 depth: node.depth + 1,
338 });
339 }
340
341 let mut lb_up = node.extra_lb.clone();
343 let mut ub_up = node.extra_ub.clone();
344 lb_up[branch_var] = lb_up[branch_var].max(xi_ceil);
345 if lb_up[branch_var] <= ub_up[branch_var] {
346 queue.push_back(BbNode {
347 extra_lb: lb_up,
348 extra_ub: ub_up,
349 lb: node_obj,
350 depth: node.depth + 1,
351 });
352 }
353 }
354
355 match incumbent {
356 Some(x) => Ok(MipResult {
357 x: Array1::from_vec(x),
358 fun: incumbent_obj,
359 success: true,
360 message: format!(
361 "Optimal solution found (nodes={}, lp_solves={})",
362 nodes_explored, lp_solves
363 ),
364 nodes_explored,
365 lp_solves,
366 lower_bound: global_lb,
367 }),
368 None => Ok(MipResult {
369 x: Array1::zeros(n),
370 fun: f64::INFINITY,
371 success: false,
372 message: "No integer feasible solution found".to_string(),
373 nodes_explored,
374 lp_solves,
375 lower_bound: global_lb,
376 }),
377 }
378 }
379}
380
381impl Default for BranchAndBoundSolver {
382 fn default() -> Self {
383 BranchAndBoundSolver::new()
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// 0/1 knapsack (values [4, 3, 5, 2, 6], weights [2, 3, 4, 1, 5],
    /// capacity 8), written as minimisation of the negated value.
    /// The unique optimum picks items 0, 3 and 4 for a value of 12.
    #[test]
    fn test_branch_and_bound_binary_knapsack() {
        let weights = vec![2.0, 3.0, 4.0, 1.0, 5.0];
        let capacity = 8.0;
        let n = 5;

        let c = array![-4.0, -3.0, -5.0, -2.0, -6.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, n), weights.clone()).expect("shape"));
        lp.b_ub = Some(array![capacity]);
        lp.lower = Some(array![0.0, 0.0, 0.0, 0.0, 0.0]);
        lp.upper = Some(array![1.0, 1.0, 1.0, 1.0, 1.0]);

        let ivs = IntegerVariableSet::all_binary(n);
        let solver = BranchAndBoundSolver::new();
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success, "B&B should find solution");
        assert_abs_diff_eq!(result.fun, -12.0, epsilon = 1e-4);
        let used: f64 = weights
            .iter()
            .zip(result.x.iter())
            .map(|(w, x)| w * x)
            .sum();
        assert!(used <= capacity + 1e-6, "capacity violated: {}", used);
    }

    /// min x0 + x1 s.t. x0 + x1 >= 3.5 with both integer: optimum 4.
    #[test]
    fn test_branch_and_bound_pure_integer() {
        let c = array![1.0, 1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-3.5]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let ivs = IntegerVariableSet::all_integer(2);
        let solver = BranchAndBoundSolver::new();
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        assert_abs_diff_eq!(result.fun, 4.0, epsilon = 1e-4);
    }

    /// min 2*x0 + x1 s.t. x0 + x1 >= 2.5 with only x0 integer: the optimum
    /// puts all weight on the cheaper continuous x1, giving 2.5.
    #[test]
    fn test_branch_and_bound_mixed_integer() {
        let c = array![2.0, 1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-2.5]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let mut ivs = IntegerVariableSet::new(2);
        ivs.set_kind(0, IntegerKind::Integer);

        let solver = BranchAndBoundSolver::new();
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        assert_abs_diff_eq!(result.fun, 2.5, epsilon = 1e-4);
        assert_abs_diff_eq!(result.x[0], result.x[0].round(), epsilon = 1e-6);
    }

    /// Depth-first node selection must reach the same optimum (3).
    #[test]
    fn test_branch_and_bound_depth_first() {
        let opts = BranchAndBoundOptions {
            node_selection: NodeSelection::DepthFirst,
            ..Default::default()
        };
        let solver = BranchAndBoundSolver::with_options(opts);

        let c = array![1.0, 1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-2.7]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let ivs = IntegerVariableSet::all_integer(2);
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        assert_abs_diff_eq!(result.fun, 3.0, epsilon = 1e-4);
    }

    /// First-fractional branching must reach the same optimum as the default rule.
    #[test]
    fn test_branch_and_bound_first_fractional() {
        let opts = BranchAndBoundOptions {
            variable_selection: VariableSelection::FirstFractional,
            ..Default::default()
        };
        let solver = BranchAndBoundSolver::with_options(opts);

        let c = array![1.0, 1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((1, 2), vec![-1.0, -1.0]).expect("shape"));
        lp.b_ub = Some(array![-3.5]);
        lp.lower = Some(array![0.0, 0.0]);
        lp.upper = Some(array![10.0, 10.0]);

        let ivs = IntegerVariableSet::all_integer(2);
        let result = solver.solve(&lp, &ivs).expect("solve failed");

        assert!(result.success);
        assert_abs_diff_eq!(result.fun, 4.0, epsilon = 1e-4);
    }

    /// 0.2 <= x <= 0.8 is LP-feasible but contains no integer point.
    #[test]
    fn test_branch_and_bound_no_integer_point() {
        let c = array![1.0];
        let mut lp = LinearProgram::new(c);
        lp.a_ub = Some(Array2::from_shape_vec((2, 1), vec![-1.0, 1.0]).expect("shape"));
        lp.b_ub = Some(array![-0.2, 0.8]);
        lp.lower = Some(array![0.0]);
        lp.upper = Some(array![10.0]);

        let ivs = IntegerVariableSet::all_integer(1);
        let result = BranchAndBoundSolver::new()
            .solve(&lp, &ivs)
            .expect("solve failed");

        assert!(!result.success, "no integer point exists, got {:?}", result.x);
    }
}