Skip to main content

scirs2_optimize/integer/
lp_relaxation.rs

1//! LP relaxation solver for use in branch-and-bound
2//!
3//! Solves linear programs using a simplex method with big-M phase I.
4//! This is an internal solver used by the branch-and-bound and cutting plane methods.
5
6use super::LinearProgram;
7use crate::error::{OptimizeError, OptimizeResult};
8use scirs2_core::ndarray::Array1;
9
10/// Result from LP solve
11#[derive(Debug, Clone)]
12pub struct LpResult {
13    /// Optimal solution
14    pub x: Array1<f64>,
15    /// Optimal objective value
16    pub fun: f64,
17    /// Whether LP is feasible and bounded
18    pub success: bool,
19    /// Status: 0=optimal, 1=infeasible, 2=unbounded
20    pub status: i32,
21}
22
/// Stateless front end for solving LP relaxations; all work is done by
/// [`LpRelaxationSolver::solve`].
pub struct LpRelaxationSolver;
25
26impl LpRelaxationSolver {
27    /// Solve a linear program using a simplex-based method.
28    ///
29    /// Handles:
30    ///   min c^T x
31    ///   s.t. A_ub x <= b_ub  (inequality constraints)
32    ///        A_eq x  = b_eq  (equality constraints)
33    ///        lb <= x <= ub
34    pub fn solve(
35        lp: &LinearProgram,
36        extra_lb: &[f64],
37        extra_ub: &[f64],
38    ) -> OptimizeResult<LpResult> {
39        let n = lp.n_vars();
40        if n == 0 {
41            return Err(OptimizeError::InvalidInput("Empty LP".to_string()));
42        }
43
44        // Combine bounds
45        let lb: Vec<f64> = (0..n)
46            .map(|i| {
47                let base = lp.lower.as_ref().map_or(0.0, |l| l[i]);
48                if i < extra_lb.len() {
49                    base.max(extra_lb[i])
50                } else {
51                    base
52                }
53            })
54            .collect();
55
56        let ub: Vec<f64> = (0..n)
57            .map(|i| {
58                let base = lp.upper.as_ref().map_or(f64::INFINITY, |u| u[i]);
59                if i < extra_ub.len() {
60                    base.min(extra_ub[i])
61                } else {
62                    base
63                }
64            })
65            .collect();
66
67        // Check bound feasibility
68        for i in 0..n {
69            if lb[i] > ub[i] + 1e-10 {
70                return Ok(LpResult {
71                    x: Array1::zeros(n),
72                    fun: f64::INFINITY,
73                    success: false,
74                    status: 1,
75                });
76            }
77        }
78
79        // Gather inequality constraints
80        let (a_ub_rows, b_ub_vec) = match (&lp.a_ub, &lp.b_ub) {
81            (Some(a), Some(b)) => {
82                let m = a.nrows();
83                let rows: Vec<Vec<f64>> = (0..m)
84                    .map(|i| (0..n).map(|j| a[[i, j]]).collect())
85                    .collect();
86                let bv: Vec<f64> = b.to_vec();
87                (rows, bv)
88            }
89            _ => (vec![], vec![]),
90        };
91
92        // Gather equality constraints
93        let (a_eq_rows, b_eq_vec) = match (&lp.a_eq, &lp.b_eq) {
94            (Some(a), Some(b)) => {
95                let m = a.nrows();
96                let rows: Vec<Vec<f64>> = (0..m)
97                    .map(|i| (0..n).map(|j| a[[i, j]]).collect())
98                    .collect();
99                let bv: Vec<f64> = b.to_vec();
100                (rows, bv)
101            }
102            _ => (vec![], vec![]),
103        };
104
105        let c_vec: Vec<f64> = lp.c.to_vec();
106
107        let sol = solve_lp_simplex(
108            &c_vec, &a_ub_rows, &b_ub_vec, &a_eq_rows, &b_eq_vec, &lb, &ub,
109        );
110
111        Ok(sol)
112    }
113}
114
115/// Internal LP solver using simplex with big-M method.
116///
117/// min c^T x  s.t.  A_ub x <= b_ub, A_eq x = b_eq, lb <= x <= ub
118fn solve_lp_simplex(
119    c: &[f64],
120    a_ub: &[Vec<f64>],
121    b_ub: &[f64],
122    a_eq: &[Vec<f64>],
123    b_eq: &[f64],
124    lb: &[f64],
125    ub: &[f64],
126) -> LpResult {
127    let n = c.len();
128    let m_ub = a_ub.len();
129    let m_eq = a_eq.len();
130
131    if n == 0 {
132        return LpResult {
133            x: Array1::zeros(0),
134            fun: 0.0,
135            success: true,
136            status: 0,
137        };
138    }
139
140    // --- Variable shift: y = x - lb, so y >= 0, y <= ub - lb -----------
141    let ub_shifted: Vec<f64> = (0..n)
142        .map(|i| {
143            if ub[i].is_finite() {
144                ub[i] - lb[i]
145            } else {
146                f64::INFINITY
147            }
148        })
149        .collect();
150
151    // Shift RHS of inequality constraints: A(y+lb) <= b => Ay <= b - A*lb
152    let b_ub_shifted: Vec<f64> = (0..m_ub)
153        .map(|i| {
154            b_ub[i]
155                - a_ub[i]
156                    .iter()
157                    .zip(lb.iter())
158                    .map(|(&a, &l)| a * l)
159                    .sum::<f64>()
160        })
161        .collect();
162
163    // Shift RHS of equality constraints: A_eq(y+lb) = b_eq => A_eq y = b_eq - A_eq*lb
164    let b_eq_shifted: Vec<f64> = (0..m_eq)
165        .map(|i| {
166            b_eq[i]
167                - a_eq[i]
168                    .iter()
169                    .zip(lb.iter())
170                    .map(|(&a, &l)| a * l)
171                    .sum::<f64>()
172        })
173        .collect();
174
175    // --- Count finite upper bounds for UB slack variables ----------------
176    let n_ub_constrained: usize = ub_shifted.iter().filter(|&&u| u.is_finite()).count();
177
178    // Variables: y_0..y_{n-1}, ineq_slack s_0..s_{m_ub-1}, ub_slack t_0..t_{n_ub-1}
179    let n_struct = n;
180    let n_ineq_slack = m_ub;
181    let n_ub_slack = n_ub_constrained;
182    let n_total = n_struct + n_ineq_slack + n_ub_slack;
183
184    // Map: structural variable j -> UB slack column index
185    let mut ub_slack_col: Vec<Option<usize>> = vec![None; n];
186    let mut ub_col_idx = n_struct + n_ineq_slack;
187    for j in 0..n {
188        if ub_shifted[j].is_finite() {
189            ub_slack_col[j] = Some(ub_col_idx);
190            ub_col_idx += 1;
191        }
192    }
193
194    // Total constraint rows: m_ub (inequality) + n_ub (UB) + m_eq (equality)
195    let total_rows = m_ub + n_ub_constrained + m_eq;
196
197    // Build full constraint matrix and RHS
198    let mut full_a: Vec<Vec<f64>> = vec![vec![0.0; n_total]; total_rows];
199    let mut full_b: Vec<f64> = vec![0.0; total_rows];
200
201    // Fill inequality rows: A y + s = b_shifted  (s >= 0)
202    for i in 0..m_ub {
203        for j in 0..n {
204            full_a[i][j] = a_ub[i][j];
205        }
206        full_a[i][n_struct + i] = 1.0; // slack
207        full_b[i] = b_ub_shifted[i];
208    }
209
210    // Fill UB bound rows: y_j + t_j = ub_shifted[j]  (t_j >= 0)
211    let mut ub_row = m_ub;
212    for j in 0..n {
213        if let Some(sk_col) = ub_slack_col[j] {
214            full_a[ub_row][j] = 1.0;
215            full_a[ub_row][sk_col] = 1.0;
216            full_b[ub_row] = ub_shifted[j];
217            ub_row += 1;
218        }
219    }
220
221    // Fill equality rows: A_eq y = b_eq_shifted (no slacks, need artificials)
222    for i in 0..m_eq {
223        let row_idx = m_ub + n_ub_constrained + i;
224        for j in 0..n {
225            full_a[row_idx][j] = a_eq[i][j];
226        }
227        full_b[row_idx] = b_eq_shifted[i];
228    }
229
230    // --- Handle negative RHS by flipping rows + adding artificials -------
231    let mut needs_artif = vec![false; total_rows];
232    for i in 0..total_rows {
233        if full_b[i] < -1e-12 {
234            // Negate the row
235            for v in full_a[i].iter_mut() {
236                *v = -*v;
237            }
238            full_b[i] = -full_b[i];
239            needs_artif[i] = true;
240        }
241    }
242    // Equality constraint rows always need artificials (no slack in basis)
243    for i in 0..m_eq {
244        let row_idx = m_ub + n_ub_constrained + i;
245        needs_artif[row_idx] = true;
246    }
247
248    let n_artif: usize = needs_artif.iter().filter(|&&v| v).count();
249    let n_total_ext = n_total + n_artif;
250
251    // Build artificial column map: row -> artificial column index
252    let mut artif_col_map: Vec<Option<usize>> = vec![None; total_rows];
253    let mut acol = n_total;
254    for i in 0..total_rows {
255        if needs_artif[i] {
256            artif_col_map[i] = Some(acol);
257            acol += 1;
258        }
259    }
260
261    // Extend rows with artificial columns
262    for row in full_a.iter_mut() {
263        row.resize(n_total_ext, 0.0);
264    }
265    for i in 0..total_rows {
266        if let Some(ac) = artif_col_map[i] {
267            full_a[i][ac] = 1.0;
268        }
269    }
270
271    // Objective: original c for structural vars, 0 for slacks, big-M for artificials
272    let big_m = 1e7_f64;
273    let mut obj_c: Vec<f64> = vec![0.0; n_total_ext];
274    for j in 0..n {
275        obj_c[j] = c[j];
276    }
277    for i in 0..total_rows {
278        if let Some(ac) = artif_col_map[i] {
279            obj_c[ac] = big_m;
280        }
281    }
282
283    // Initial basis
284    let mut basis: Vec<usize> = Vec::with_capacity(total_rows);
285    // Inequality rows: use slack (or artificial if flipped)
286    for i in 0..m_ub {
287        if needs_artif[i] {
288            basis.push(artif_col_map[i].unwrap_or(0));
289        } else {
290            basis.push(n_struct + i); // ineq slack
291        }
292    }
293    // UB rows: use UB slack (or artificial if flipped -- rare for UB rows)
294    let mut ub_row_counter = 0usize;
295    for j in 0..n {
296        if let Some(sk_col) = ub_slack_col[j] {
297            let row_idx = m_ub + ub_row_counter;
298            if needs_artif[row_idx] {
299                basis.push(artif_col_map[row_idx].unwrap_or(0));
300            } else {
301                basis.push(sk_col);
302            }
303            ub_row_counter += 1;
304        }
305    }
306    // Equality rows: always use artificial
307    for i in 0..m_eq {
308        let row_idx = m_ub + n_ub_constrained + i;
309        basis.push(artif_col_map[row_idx].unwrap_or(0));
310    }
311
312    // Run simplex
313    let simplex_result = run_simplex(&mut full_a, &mut full_b, &obj_c, &mut basis, 20_000);
314
315    if simplex_result == SimplexStatus::Unbounded {
316        return LpResult {
317            x: Array1::zeros(n),
318            fun: f64::NEG_INFINITY,
319            success: false,
320            status: 2,
321        };
322    }
323
324    // Check that no artificial is in the basis with positive value
325    for (i, &bv) in basis.iter().enumerate() {
326        if bv >= n_total && full_b[i] > 1e-6 {
327            return LpResult {
328                x: Array1::zeros(n),
329                fun: f64::INFINITY,
330                success: false,
331                status: 1,
332            };
333        }
334    }
335
336    // Extract y values (shifted structural variables)
337    let mut y = vec![0.0_f64; n];
338    for (i, &bv) in basis.iter().enumerate() {
339        if bv < n {
340            y[bv] = full_b[i].max(0.0);
341        }
342    }
343
344    // Shift back: x = y + lb, clamped to [lb, ub]
345    let x: Vec<f64> = (0..n)
346        .map(|j| (lb[j] + y[j]).max(lb[j]).min(ub[j]))
347        .collect();
348
349    let fun: f64 = c.iter().zip(x.iter()).map(|(&ci, &xi)| ci * xi).sum();
350
351    LpResult {
352        x: Array1::from_vec(x),
353        fun,
354        success: true,
355        status: 0,
356    }
357}
358
/// Termination reason reported by [`run_simplex`].
#[derive(PartialEq, Debug)]
enum SimplexStatus {
    /// No improving column remains: the current basis is optimal.
    Optimal,
    /// An improving column has no positive entry, so the objective is unbounded below.
    Unbounded,
    /// The iteration budget ran out before either of the above was established.
    MaxIter,
}
365
366/// Run the simplex method on the tableau.
367///
368/// The system is `A x = b`, `x >= 0`, minimize `c^T x`.
369/// `basis` holds the indices of the initial basic variables.
370/// On return, `a` and `b` are updated in place (tableau form), and `basis` holds the final basis.
371fn run_simplex(
372    a: &mut Vec<Vec<f64>>,
373    b: &mut Vec<f64>,
374    c: &[f64],
375    basis: &mut Vec<usize>,
376    max_iter: usize,
377) -> SimplexStatus {
378    let m = a.len();
379    if m == 0 {
380        return SimplexStatus::Optimal;
381    }
382    let n_total = a[0].len();
383
384    // Build full tableau: [A | b]
385    let mut tableau: Vec<Vec<f64>> = vec![vec![0.0; n_total + 1]; m];
386    for i in 0..m {
387        for j in 0..n_total {
388            tableau[i][j] = a[i][j];
389        }
390        tableau[i][n_total] = b[i];
391    }
392
393    // Make basis columns identity via row operations
394    for col in 0..m {
395        let basic = basis[col];
396        let pivot_val = tableau[col][basic];
397        if pivot_val.abs() < 1e-12 {
398            // Try to find a different row for this basis element
399            let mut found = false;
400            for i in 0..m {
401                if i != col && tableau[i][basic].abs() > 1e-10 {
402                    tableau.swap(col, i);
403                    basis.swap(col, i);
404                    found = true;
405                    break;
406                }
407            }
408            if !found {
409                continue;
410            }
411        }
412        let pv = tableau[col][basic];
413        if pv.abs() < 1e-12 {
414            continue;
415        }
416        for j in 0..=n_total {
417            tableau[col][j] /= pv;
418        }
419        for i in 0..m {
420            if i == col {
421                continue;
422            }
423            let factor = tableau[i][basic];
424            if factor.abs() < 1e-15 {
425                continue;
426            }
427            for j in 0..=n_total {
428                let delta = factor * tableau[col][j];
429                tableau[i][j] -= delta;
430            }
431        }
432    }
433
434    // Simplex iterations
435    let mut status = SimplexStatus::MaxIter;
436    for _iter in 0..max_iter {
437        // Compute reduced costs using current basis
438        let c_b: Vec<f64> = basis
439            .iter()
440            .map(|&bv| c.get(bv).copied().unwrap_or(0.0))
441            .collect();
442
443        // Find entering variable (most negative reduced cost)
444        let mut enter = None;
445        let mut min_rc = -1e-8_f64;
446        for j in 0..n_total {
447            let rc = c.get(j).copied().unwrap_or(0.0)
448                - c_b
449                    .iter()
450                    .zip(tableau.iter())
451                    .map(|(&cb, row)| cb * row[j])
452                    .sum::<f64>();
453            if rc < min_rc {
454                min_rc = rc;
455                enter = Some(j);
456            }
457        }
458
459        let enter_col = match enter {
460            None => {
461                status = SimplexStatus::Optimal;
462                break;
463            }
464            Some(j) => j,
465        };
466
467        // Minimum ratio test
468        let mut leave_row = None;
469        let mut min_ratio = f64::INFINITY;
470        for i in 0..m {
471            let coef = tableau[i][enter_col];
472            if coef > 1e-10 {
473                let ratio = tableau[i][n_total] / coef;
474                if ratio < min_ratio - 1e-12 {
475                    min_ratio = ratio;
476                    leave_row = Some(i);
477                } else if (ratio - min_ratio).abs() < 1e-12 {
478                    // Bland's rule: prefer smaller index
479                    if let Some(prev) = leave_row {
480                        if basis[i] < basis[prev] {
481                            leave_row = Some(i);
482                        }
483                    }
484                }
485            }
486        }
487
488        let pivot_row = match leave_row {
489            None => {
490                status = SimplexStatus::Unbounded;
491                break;
492            }
493            Some(r) => r,
494        };
495
496        // Pivot
497        let pv = tableau[pivot_row][enter_col];
498        for j in 0..=n_total {
499            tableau[pivot_row][j] /= pv;
500        }
501        for i in 0..m {
502            if i == pivot_row {
503                continue;
504            }
505            let factor = tableau[i][enter_col];
506            if factor.abs() < 1e-15 {
507                continue;
508            }
509            for j in 0..=n_total {
510                let delta = factor * tableau[pivot_row][j];
511                tableau[i][j] -= delta;
512            }
513        }
514        basis[pivot_row] = enter_col;
515    }
516
517    // Write back modified b and a
518    for i in 0..m {
519        b[i] = tableau[i][n_total];
520        for j in 0..n_total {
521            a[i][j] = tableau[i][j];
522        }
523    }
524
525    status
526}