1use super::lift_project::{LiftProjectConfig, LiftProjectCut, LiftProjectGenerator};
25use crate::error::{OptimizeError, OptimizeResult};
26
27pub struct LiftProjectMipSolver {
34 config: LiftProjectConfig,
35 generator: LiftProjectGenerator,
36 cut_pool: Vec<LiftProjectCut>,
37 iterations: usize,
39 total_cuts_generated: usize,
41}
42
43impl LiftProjectMipSolver {
44 pub fn new(config: LiftProjectConfig) -> Self {
46 let generator = LiftProjectGenerator::new(config.clone());
47 LiftProjectMipSolver {
48 config,
49 generator,
50 cut_pool: Vec::new(),
51 iterations: 0,
52 total_cuts_generated: 0,
53 }
54 }
55
56 pub fn default_solver() -> Self {
58 LiftProjectMipSolver::new(LiftProjectConfig::default())
59 }
60
61 pub fn add_cuts_to_lp(
79 &mut self,
80 a: &[Vec<f64>],
81 b: &[f64],
82 x_bar: &[f64],
83 integer_vars: &[usize],
84 ) -> OptimizeResult<Vec<LiftProjectCut>> {
85 self.iterations += 1;
86
87 let new_cuts = self.generator.generate_cuts(a, b, x_bar, integer_vars)?;
88
89 if new_cuts.is_empty() {
90 return Ok(Vec::new());
91 }
92
93 let violated: Vec<LiftProjectCut> = new_cuts
96 .into_iter()
97 .filter(|c| {
98 let v = self.generator.cut_violation(c, x_bar);
99 v > self.config.cut_violation_tol
100 })
101 .collect();
102
103 self.total_cuts_generated += violated.len();
104 self.cut_pool.extend(violated.clone());
105
106 Ok(violated)
107 }
108
109 pub fn cut_pool_size(&self) -> usize {
113 self.cut_pool.len()
114 }
115
116 pub fn cut_pool(&self) -> &[LiftProjectCut] {
118 &self.cut_pool
119 }
120
121 pub fn clear_cut_pool(&mut self) {
123 self.cut_pool.clear();
124 }
125
126 pub fn purge_non_violated_cuts(&mut self, x_new: &[f64]) {
132 self.cut_pool.retain(|c| {
133 let v = self.generator.cut_violation(c, x_new);
134 v > self.config.cut_violation_tol
135 });
136 }
137
138 pub fn iterations(&self) -> usize {
142 self.iterations
143 }
144
145 pub fn total_cuts_generated(&self) -> usize {
147 self.total_cuts_generated
148 }
149
150 pub fn config(&self) -> &LiftProjectConfig {
152 &self.config
153 }
154
155 pub fn cut_violation(&self, cut: &LiftProjectCut, x_bar: &[f64]) -> f64 {
159 self.generator.cut_violation(cut, x_bar)
160 }
161
162 pub fn build_augmented_system(
174 &self,
175 a: &[Vec<f64>],
176 b: &[f64],
177 ) -> OptimizeResult<(Vec<Vec<f64>>, Vec<f64>)> {
178 if a.len() != b.len() {
179 return Err(OptimizeError::InvalidInput(format!(
180 "Constraint matrix has {} rows but b has {} entries",
181 a.len(),
182 b.len()
183 )));
184 }
185
186 let n = if a.is_empty() {
187 self.cut_pool.first().map_or(0, |c| c.pi.len())
189 } else {
190 a[0].len()
191 };
192
193 let mut a_aug: Vec<Vec<f64>> = a.to_vec();
194 let mut b_aug: Vec<f64> = b.to_vec();
195
196 for cut in &self.cut_pool {
197 if cut.pi.len() != n {
198 return Err(OptimizeError::InvalidInput(format!(
199 "Cut has {} coefficients but constraint matrix has {} columns",
200 cut.pi.len(),
201 n
202 )));
203 }
204 let neg_pi: Vec<f64> = cut.pi.iter().map(|&p| -p).collect();
206 a_aug.push(neg_pi);
207 b_aug.push(-cut.pi0);
208 }
209
210 Ok((a_aug, b_aug))
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// One constraint row with coefficients [1, 1] and rhs 1, the fractional
    /// point x_bar = (0.4, 0.6), and both variables marked integer.
    fn make_fractional_lp() -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>, Vec<usize>) {
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.0];
        let x_bar = vec![0.4, 0.6];
        let ivars = vec![0, 1];
        (a, b, x_bar, ivars)
    }

    #[test]
    fn test_add_cuts_increases_pool_size() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        assert_eq!(solver.cut_pool_size(), 0);
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.cut_pool_size(), cuts.len());
        assert!(solver.cut_pool_size() > 0, "Expected cuts to be generated");
    }

    #[test]
    fn test_add_cuts_returns_violated_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        for cut in &cuts {
            let v = solver.cut_violation(cut, &x_bar);
            assert!(
                v > solver.config().cut_violation_tol,
                "Returned cut should be violated at x_bar, got v={}",
                v
            );
        }
    }

    #[test]
    fn test_add_cuts_empty_for_integer_solution() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, _, ivars) = make_fractional_lp();
        let x_int = vec![1.0, 0.0];
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_int, &ivars).unwrap();
        assert!(cuts.is_empty());
        assert_eq!(solver.cut_pool_size(), 0);
    }

    #[test]
    fn test_clear_cut_pool_resets_size() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert!(solver.cut_pool_size() > 0);
        solver.clear_cut_pool();
        assert_eq!(solver.cut_pool_size(), 0);
    }

    #[test]
    fn test_iterations_counter_increments() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        assert_eq!(solver.iterations(), 0);
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.iterations(), 1);
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.iterations(), 2);
    }

    #[test]
    fn test_total_cuts_generated_accumulates() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let after_first = solver.total_cuts_generated();
        // Every accepted cut is both counted and pooled, so the two agree
        // while nothing has been cleared or purged.
        assert_eq!(after_first, solver.cut_pool_size());
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let after_second = solver.total_cuts_generated();
        assert!(after_second >= after_first);
        assert_eq!(after_second, solver.cut_pool_size());
    }

    #[test]
    fn test_pool_accumulates_across_calls() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_after_first = solver.cut_pool_size();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_after_second = solver.cut_pool_size();
        assert!(size_after_second >= size_after_first);
    }

    #[test]
    fn test_purge_non_violated_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_before = solver.cut_pool_size();
        let x_int = vec![1.0, 0.0];
        solver.purge_non_violated_cuts(&x_int);
        let size_after = solver.cut_pool_size();
        assert!(
            size_after <= size_before,
            "Pool should not grow after purge"
        );
    }

    #[test]
    fn test_purge_keeps_cuts_violated_at_generation_point() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let size_before = solver.cut_pool_size();
        // Every pooled cut passed the same violation test at x_bar when it
        // was added, so purging at x_bar must keep all of them.
        solver.purge_non_violated_cuts(&x_bar);
        assert_eq!(solver.cut_pool_size(), size_before);
    }

    #[test]
    fn test_build_augmented_system_appends_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        let n_original = a.len();
        let n_cuts = solver.cut_pool_size();
        let (a_aug, b_aug) = solver.build_augmented_system(&a, &b).unwrap();
        assert_eq!(a_aug.len(), n_original + n_cuts);
        assert_eq!(b_aug.len(), n_original + n_cuts);
    }

    #[test]
    fn test_build_augmented_system_negates_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        // Without cuts the loop below would check nothing.
        assert!(!solver.cut_pool().is_empty(), "Expected cuts to be generated");
        let (a_aug, b_aug) = solver.build_augmented_system(&a, &b).unwrap();
        let n_orig = a.len();
        for (k, cut) in solver.cut_pool().iter().enumerate() {
            let row = &a_aug[n_orig + k];
            let rhs = b_aug[n_orig + k];
            // `zip` below stops at the shorter side, so check widths first.
            assert_eq!(row.len(), cut.pi.len(), "Augmented row {} has wrong width", k);
            for (j, (&aug_coeff, &pi_k)) in row.iter().zip(cut.pi.iter()).enumerate() {
                assert!(
                    (aug_coeff - (-pi_k)).abs() < 1e-12,
                    "Augmented row coeff [{}][{}] = {} but expected {}",
                    k,
                    j,
                    aug_coeff,
                    -pi_k
                );
            }
            assert!(
                (rhs - (-cut.pi0)).abs() < 1e-12,
                "Augmented RHS = {} but expected {}",
                rhs,
                -cut.pi0
            );
        }
    }

    #[test]
    fn test_build_augmented_system_error_on_mismatched_a_b() {
        let solver = LiftProjectMipSolver::default_solver();
        let a = vec![vec![1.0, 1.0], vec![0.0, 1.0]];
        let b = vec![1.0];
        let result = solver.build_augmented_system(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_build_augmented_system_error_on_cut_width_mismatch() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert!(solver.cut_pool_size() > 0);
        // The pooled cuts have two coefficients; this matrix has three columns.
        let a_wide = vec![vec![1.0, 1.0, 1.0]];
        let b_wide = vec![1.0];
        assert!(solver.build_augmented_system(&a_wide, &b_wide).is_err());
    }

    #[test]
    fn test_build_augmented_system_without_cuts_returns_copy() {
        let solver = LiftProjectMipSolver::default_solver();
        let (a, b, _, _) = make_fractional_lp();
        let (a_aug, b_aug) = solver.build_augmented_system(&a, &b).unwrap();
        assert_eq!(a_aug, a);
        assert_eq!(b_aug, b);
    }

    #[test]
    fn test_build_augmented_system_empty_inputs() {
        let solver = LiftProjectMipSolver::default_solver();
        let (a_aug, b_aug) = solver.build_augmented_system(&[], &[]).unwrap();
        assert!(a_aug.is_empty());
        assert!(b_aug.is_empty());
    }

    #[test]
    fn test_cut_pool_accessor_matches_pool_size() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let (a, b, x_bar, ivars) = make_fractional_lp();
        solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert_eq!(solver.cut_pool().len(), solver.cut_pool_size());
    }

    #[test]
    fn test_config_accessor() {
        let config = LiftProjectConfig {
            max_cuts: 7,
            cut_violation_tol: 1e-5,
            ..Default::default()
        };
        let solver = LiftProjectMipSolver::new(config);
        assert_eq!(solver.config().max_cuts, 7);
        assert!((solver.config().cut_violation_tol - 1e-5).abs() < 1e-12);
    }

    #[test]
    fn test_multiple_constraint_rows_generate_more_cuts() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let a = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let b = vec![0.8, 0.8, 1.2];
        let x_bar = vec![0.4, 0.5];
        let ivars = vec![0, 1];
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &ivars).unwrap();
        assert!(!cuts.is_empty());
    }

    #[test]
    fn test_solver_handles_no_integer_vars_gracefully() {
        let mut solver = LiftProjectMipSolver::default_solver();
        let a = vec![vec![1.0, 1.0]];
        let b = vec![1.0];
        let x_bar = vec![0.4, 0.6];
        let cuts = solver.add_cuts_to_lp(&a, &b, &x_bar, &[]).unwrap();
        assert!(cuts.is_empty());
    }
}
419}