1use std::{
4 io,
5 marker::PhantomData,
6 ops::{Not, Range},
7 sync::{
8 atomic::{AtomicBool, Ordering},
9 Arc,
10 },
11};
12
13#[cfg(feature = "interrupt-oracle")]
14use std::sync::Mutex;
15
16use anyhow::Context;
17use cadical_veripb_tracer::CadicalCertCollector;
18use rustsat::{
19 encodings::{card::Totalizer, pb::GeneralizedTotalizer},
20 instances::{Cnf, ManageVars},
21 solvers::{
22 DefaultInitializer, Initialize, SolveIncremental, SolveStats, SolverResult, SolverStats,
23 },
24 types::{Assignment, Clause, Lit, TernaryVal},
25};
26use scuttle_proc::oracle_bounds;
27
28#[cfg(feature = "sol-tightening")]
29use maxpre::PreproClauses;
30
31use crate::{
32 options::{CoreBoostingOptions, EnumOptions},
33 types::{Instance, NonDomPoint, ObjEncoding, Objective, ParetoFront, VarManager},
34 EncodingStats, KernelOptions, Limits, MaybeTerminated,
35 MaybeTerminatedError::{self, Done, Error, Terminated},
36 Phase, Stats, Termination, WriteSolverLog,
37};
38
39pub mod bioptsat;
40pub mod lowerbounding;
41pub mod pminimal;
42
43mod coreboosting;
44mod coreguided;
45pub(crate) mod proofs;
46pub use proofs::{InitCert, InitCertDefaultBlock};
47
48pub trait Init: Sized {
50 type Oracle: SolveIncremental;
51 type BlockClauseGen: Fn(Assignment) -> Clause;
52
53 fn new<Cls>(
55 clauses: Cls,
56 objs: Vec<Objective>,
57 var_manager: VarManager,
58 opts: KernelOptions,
59 block_clause_gen: Self::BlockClauseGen,
60 ) -> anyhow::Result<Self>
61 where
62 Cls: IntoIterator<Item = Clause>;
63
64 fn from_instance(
66 inst: Instance,
67 opts: KernelOptions,
68 block_clause_gen: Self::BlockClauseGen,
69 ) -> anyhow::Result<Self> {
70 Self::new(
71 inst.clauses.into_iter().map(|(cl, _)| cl),
72 inst.objs,
73 inst.vm,
74 opts,
75 block_clause_gen,
76 )
77 }
78}
79
80pub trait InitDefaultBlock: Init<BlockClauseGen = fn(Assignment) -> Clause> {
81 fn new_default_blocking<Cls>(
83 clauses: Cls,
84 objs: Vec<Objective>,
85 var_manager: VarManager,
86 opts: KernelOptions,
87 ) -> anyhow::Result<Self>
88 where
89 Cls: IntoIterator<Item = Clause>,
90 {
91 Self::new(clauses, objs, var_manager, opts, default_blocking_clause)
92 }
93
94 fn from_instance_default_blocking(inst: Instance, opts: KernelOptions) -> anyhow::Result<Self> {
97 Self::new(
98 inst.clauses.into_iter().map(|(cl, _)| cl),
99 inst.objs,
100 inst.vm,
101 opts,
102 default_blocking_clause,
103 )
104 }
105}
106
107impl<Alg> InitDefaultBlock for Alg where Alg: Init<BlockClauseGen = fn(Assignment) -> Clause> {}
108
/// Main solving interface of a multi-objective algorithm.
pub trait Solve: KernelFunctions {
    /// Runs the algorithm within the given `limits`.
    fn solve(&mut self, limits: Limits) -> MaybeTerminatedError;
    /// Returns the algorithm statistics together with the optional SAT solver
    /// statistics and optional per-encoding statistics.
    fn all_stats(&self) -> (Stats, Option<SolverStats>, Option<Vec<EncodingStats>>);
}
117
/// Algorithms that support core boosting.
pub trait CoreBoost {
    /// Performs core boosting with the given options.
    ///
    /// NOTE(review): the meaning of the returned `bool` is defined by the
    /// implementors; presumably it signals whether solving should continue —
    /// TODO confirm.
    fn core_boost(&mut self, opts: CoreBoostingOptions) -> MaybeTerminatedError<bool>;
}
123
/// Functionality shared by all algorithms built on the solver kernel.
pub trait KernelFunctions {
    /// Gets the Pareto front found by the algorithm.
    fn pareto_front(&self) -> ParetoFront;
    /// Gets the algorithm statistics.
    fn stats(&self) -> Stats;
    /// Attaches a logger to the algorithm.
    fn attach_logger<L: WriteSolverLog + 'static>(&mut self, logger: L);
    /// Detaches the logger and returns it, if one was attached.
    fn detach_logger(&mut self) -> Option<Box<dyn WriteSolverLog>>;
    /// Gets an [`Interrupter`] that can be used to interrupt the algorithm.
    fn interrupter(&mut self) -> Interrupter;
}
137
/// A handle for interrupting a running solver, obtained from
/// [`KernelFunctions::interrupter`].
pub struct Interrupter {
    /// Termination flag shared with the kernel; the kernel checks it to detect
    /// an interruption request.
    term_flag: Arc<AtomicBool>,
    /// Interrupter of the kernel's SAT oracle. It is shared behind a mutex
    /// because the kernel replaces it when it resets its oracle.
    #[cfg(feature = "interrupt-oracle")]
    oracle_interrupter: Arc<Mutex<Box<dyn rustsat::solvers::InterruptSolver + Send>>>,
}
145
#[cfg(feature = "interrupt-oracle")]
impl Interrupter {
    /// Interrupts the solver.
    ///
    /// Raises the termination flag and interrupts the currently installed SAT
    /// oracle.
    pub fn interrupt(&mut self) {
        self.term_flag.store(true, Ordering::Relaxed);
        // Recover from a poisoned lock instead of panicking. An interrupt
        // request must still get through. The boxed interrupter behind the lock
        // is only ever replaced wholesale, so it is never left half-updated.
        self.oracle_interrupter
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .interrupt();
    }
}
154
155#[cfg(not(feature = "interrupt-oracle"))]
156impl Interrupter {
157 pub fn interrupt(&mut self) {
159 self.term_flag.store(true, Ordering::Relaxed);
160 }
161}
162
/// Shared state and functionality of the multi-objective solving algorithms.
///
/// Generic over the SAT oracle `O`, the proof writer `ProofW`, the oracle
/// initializer `OInit`, and the blocking clause generator `BCG`.
struct Kernel<O, ProofW, OInit = DefaultInitializer, BCG = fn(Assignment) -> Clause>
where
    ProofW: io::Write,
{
    /// The SAT oracle
    oracle: O,
    /// The variable manager
    var_manager: VarManager,
    /// For each objective literal, the indices of the objectives it occurs in
    #[cfg(feature = "sol-tightening")]
    obj_lit_data: rustsat::types::RsHashMap<Lit, crate::types::ObjLitData>,
    /// The objectives
    objs: Vec<Objective>,
    /// The original clauses; only stored if `opts.store_cnf` is set, and needed
    /// to reset the oracle
    orig_cnf: Option<Cnf>,
    /// Generator for the clause that blocks a solution during enumeration
    block_clause_gen: BCG,
    /// Configuration options
    opts: KernelOptions,
    /// Running statistics
    stats: Stats,
    /// Remaining limits of the current solve call
    lims: Limits,
    /// Inprocessor used to reconstruct solutions, if any
    #[cfg(feature = "maxpre")]
    inpro: Option<maxpre::MaxPre>,
    /// The attached logger, if any
    logger: Option<Box<dyn WriteSolverLog>>,
    /// Termination flag shared with handed-out [`Interrupter`]s
    term_flag: Arc<AtomicBool>,
    /// Interrupter of the current oracle, shared with handed-out [`Interrupter`]s
    #[cfg(feature = "interrupt-oracle")]
    oracle_interrupter: Arc<Mutex<Box<dyn rustsat::solvers::InterruptSolver + Send>>>,
    /// Proof logging state; `None` if proof logging is disabled
    proof_stuff: Option<proofs::ProofStuff<ProofW>>,
    /// Marker for the oracle initializer type
    _factory: PhantomData<OInit>,
}
209
210#[oracle_bounds]
211impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
212where
213 O: SolveIncremental,
214 ProofW: io::Write,
215 OInit: Initialize<O>,
216 BCG: Fn(Assignment) -> Clause,
217{
218 pub fn new<Cls>(
219 clauses: Cls,
220 objs: Vec<Objective>,
221 var_manager: VarManager,
222 bcg: BCG,
223 opts: KernelOptions,
224 ) -> anyhow::Result<Self>
225 where
226 Cls: IntoIterator<Item = Clause>,
227 {
228 let mut stats = Stats {
229 n_objs: 0,
230 n_real_objs: 0,
231 n_orig_clauses: 0,
232 ..Default::default()
233 };
234 let mut oracle = OInit::init();
235 oracle.reserve(var_manager.max_var().unwrap())?;
236 let orig_cnf = if opts.store_cnf {
237 let cnf: Cnf = clauses.into_iter().collect();
238 stats.n_orig_clauses = cnf.len();
239 oracle.add_cnf_ref(&cnf)?;
240 Some(cnf)
241 } else {
242 for cl in clauses.into_iter() {
243 stats.n_orig_clauses += 1;
244 oracle.add_clause(cl)?;
245 }
246 None
247 };
248 stats.n_objs = objs.len();
249 stats.n_real_objs = objs.iter().fold(0, |cnt, o| {
250 if matches!(o, Objective::Constant { .. }) {
251 cnt
252 } else {
253 cnt + 1
254 }
255 });
256 #[cfg(feature = "sol-tightening")]
258 let obj_lit_data = {
259 use crate::types::ObjLitData;
260 use rustsat::types::RsHashMap;
261 let mut obj_lit_data: RsHashMap<_, ObjLitData> = RsHashMap::default();
262 for (idx, obj) in objs.iter().enumerate() {
263 match obj {
264 Objective::Weighted { lits, .. } => {
265 for &olit in lits.keys() {
266 match obj_lit_data.get_mut(&olit) {
267 Some(data) => data.objs.push(idx),
268 None => {
269 obj_lit_data.insert(olit, ObjLitData { objs: vec![idx] });
270 }
271 }
272 }
273 }
274 Objective::Unweighted { lits, .. } => {
275 for &olit in lits {
276 match obj_lit_data.get_mut(&olit) {
277 Some(data) => data.objs.push(idx),
278 None => {
279 obj_lit_data.insert(olit, ObjLitData { objs: vec![idx] });
280 }
281 }
282 }
283 }
284 Objective::Constant { .. } => (),
285 }
286 }
287 obj_lit_data
288 };
289 #[cfg(feature = "sol-tightening")]
290 for o in &objs {
292 for (l, _) in o.iter() {
293 oracle.freeze_var(l.var())?;
294 }
295 }
296 #[cfg(feature = "interrupt-oracle")]
297 let interrupter = oracle.interrupter();
298 Ok(Self {
299 oracle,
300 var_manager,
301 #[cfg(feature = "sol-tightening")]
302 obj_lit_data,
303 objs,
304 orig_cnf,
305 block_clause_gen: bcg,
306 opts,
307 stats,
308 lims: Limits::none(),
309 #[cfg(feature = "maxpre")]
310 inpro: None,
311 logger: None,
312 term_flag: Arc::new(AtomicBool::new(false)),
313 #[cfg(feature = "interrupt-oracle")]
314 oracle_interrupter: Arc::new(Mutex::new(Box::new(interrupter))),
315 proof_stuff: None,
316 _factory: PhantomData,
317 })
318 }
319}
320
321impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
322where
323 ProofW: io::Write,
324{
325 fn start_solving(&mut self, limits: Limits) {
326 self.stats.n_solve_calls += 1;
327 self.lims = limits;
328 }
329
330 fn attach_logger<L: WriteSolverLog + 'static>(&mut self, logger: L) {
331 self.logger = Some(Box::new(logger));
332 }
333
334 fn detach_logger(&mut self) -> Option<Box<dyn WriteSolverLog>> {
335 self.logger.take()
336 }
337
338 fn externalize_internal_costs(&self, costs: &[usize]) -> Vec<isize> {
342 debug_assert_eq!(costs.len(), self.stats.n_objs);
343 costs
344 .iter()
345 .enumerate()
346 .map(|(idx, &cst)| match self.objs[idx] {
347 Objective::Weighted { offset, .. } => {
348 let signed_cst: isize = cst.try_into().expect("cost exceeds `isize`");
349 signed_cst + offset
350 }
351 Objective::Unweighted {
352 offset,
353 unit_weight,
354 ..
355 } => {
356 let signed_mult_cost: isize = (cst * unit_weight)
357 .try_into()
358 .expect("multiplied cost exceeds `isize`");
359 signed_mult_cost + offset
360 }
361 Objective::Constant { offset, .. } => {
362 debug_assert_eq!(cst, 0);
363 offset
364 }
365 })
366 .collect()
367 }
368
369 fn block_pareto_mcs(&self, sol: Assignment) -> Clause {
371 let mut blocking_clause = Clause::new();
372 self.objs.iter().for_each(|oe| {
373 oe.iter().for_each(|(l, _)| {
374 if sol.lit_value(l) == TernaryVal::True {
375 blocking_clause.add(-l)
376 }
377 })
378 });
379 blocking_clause
380 .normalize()
381 .expect("Tautological blocking clause")
382 }
383
384 fn check_termination(&self) -> MaybeTerminated {
386 if self.term_flag.load(Ordering::Relaxed) {
387 MaybeTerminated::Terminated(Termination::Interrupted)
388 } else {
389 MaybeTerminated::Done(())
390 }
391 }
392
393 fn log_candidate(&mut self, costs: &[usize], phase: Phase) -> MaybeTerminatedError {
395 debug_assert_eq!(costs.len(), self.stats.n_objs);
396 self.stats.n_candidates += 1;
397 if let Some(logger) = &mut self.logger {
399 logger
400 .log_candidate(costs, phase)
401 .context("logger failed")?;
402 }
403 if let Some(candidates) = &mut self.lims.candidates {
405 *candidates -= 1;
406 if *candidates == 0 {
407 return Terminated(Termination::CandidatesLimit);
408 }
409 }
410 Done(())
411 }
412
413 fn log_oracle_call(&mut self, result: SolverResult) -> MaybeTerminatedError {
415 self.stats.n_oracle_calls += 1;
416 if let Some(logger) = &mut self.logger {
418 logger.log_oracle_call(result).context("logger failed")?;
419 }
420 if let Some(oracle_calls) = &mut self.lims.oracle_calls {
422 *oracle_calls -= 1;
423 if *oracle_calls == 0 {
424 return Terminated(Termination::OracleCallsLimit);
425 }
426 }
427 Done(())
428 }
429
430 fn log_solution(&mut self) -> MaybeTerminatedError {
432 self.stats.n_solutions += 1;
433 if let Some(logger) = &mut self.logger {
435 logger.log_solution().context("logger failed")?;
436 }
437 if let Some(solutions) = &mut self.lims.sols {
439 *solutions -= 1;
440 if *solutions == 0 {
441 return Terminated(Termination::SolsLimit);
442 }
443 }
444 Done(())
445 }
446
447 fn log_non_dominated(&mut self, non_dominated: &NonDomPoint) -> MaybeTerminatedError {
449 self.stats.n_non_dominated += 1;
450 if let Some(logger) = &mut self.logger {
452 logger
453 .log_non_dominated(non_dominated)
454 .context("logger failed")?;
455 }
456 if let Some(pps) = &mut self.lims.pps {
458 *pps -= 1;
459 if *pps == 0 {
460 return Terminated(Termination::PPLimit);
461 }
462 }
463 Done(())
464 }
465
466 #[cfg(feature = "sol-tightening")]
467 fn log_heuristic_obj_improvement(
469 &mut self,
470 obj_idx: usize,
471 apparent_cost: usize,
472 improved_cost: usize,
473 ) -> anyhow::Result<()> {
474 if let Some(logger) = &mut self.logger {
476 logger
477 .log_heuristic_obj_improvement(obj_idx, apparent_cost, improved_cost)
478 .context("logger failed")?;
479 }
480 Ok(())
481 }
482
483 fn log_routine_start(&mut self, desc: &'static str) -> anyhow::Result<()> {
485 if let Some(logger) = &mut self.logger {
487 logger.log_routine_start(desc).context("logger failed")?;
488 }
489 Ok(())
490 }
491
492 fn log_routine_end(&mut self) -> anyhow::Result<()> {
494 if let Some(logger) = &mut self.logger {
496 logger.log_routine_end().context("logger failed")?;
497 }
498 Ok(())
499 }
500}
501
#[cfg(feature = "interrupt-oracle")]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: rustsat::solvers::Interrupt,
    ProofW: io::Write,
{
    /// Hands out an [`Interrupter`] that shares this kernel's termination flag
    /// and oracle interrupter.
    fn interrupter(&mut self) -> Interrupter {
        let term_flag = Arc::clone(&self.term_flag);
        let oracle_interrupter = Arc::clone(&self.oracle_interrupter);
        Interrupter {
            term_flag,
            oracle_interrupter,
        }
    }
}
515
516#[cfg(not(feature = "interrupt-oracle"))]
517impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
518where
519 ProofW: io::Write,
520{
521 fn interrupter(&mut self) -> Interrupter {
522 Interrupter {
523 term_flag: self.term_flag.clone(),
524 }
525 }
526}
527
#[cfg(feature = "sol-tightening")]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: SolveIncremental + rustsat::solvers::FlipLit,
    ProofW: io::Write,
{
    /// Computes the cost of `sol` for objective `obj_idx`, optionally lowering
    /// it heuristically by flipping objective literals in the oracle.
    ///
    /// With `tightening`, the method tries to flip each objective literal that
    /// is true in `sol` via `flip_lit`. This is only attempted when the
    /// literal's negation is not itself an objective literal. On success the
    /// literal's weight is not counted, and `sol` is afterwards refreshed from
    /// the oracle.
    ///
    /// If an inprocessor (`self.inpro`) is present, tightening is disabled. The
    /// cost is then computed on the reconstructed solution, and `sol` itself is
    /// left unchanged.
    ///
    /// NOTE(review): `self.inpro` only exists with the `maxpre` feature, so
    /// presumably `sol-tightening` implies `maxpre` — confirm in Cargo.toml.
    ///
    /// # Errors
    ///
    /// Returns an error if flipping a literal, retrieving the solution, or
    /// logging fails.
    pub fn get_cost_with_heuristic_improvements(
        &mut self,
        obj_idx: usize,
        sol: &mut Assignment,
        mut tightening: bool,
    ) -> anyhow::Result<usize> {
        debug_assert!(obj_idx < self.stats.n_objs);
        // Total weight removed by successful flips
        let mut reduction = 0;
        let mut cost = 0;
        // Points at the caller's solution, or at the reconstructed one below
        let mut used_sol = sol;
        let mut rec_sol;
        if let Some(inpro) = &mut self.inpro {
            // No tightening on reconstructed solutions; presumably because oracle
            // flips cannot be mapped through the reconstruction — TODO confirm
            tightening = false;
            rec_sol = inpro.reconstruct(used_sol.clone());
            used_sol = &mut rec_sol;
        }
        for (l, w) in self.objs[obj_idx].iter() {
            let val = used_sol.lit_value(l);
            if val == TernaryVal::True {
                // Only flip if setting `!l` cannot increase any objective
                if tightening && !self.obj_lit_data.contains_key(&!l) {
                    if self.oracle.flip_lit(!l)? {
                        used_sol.assign_lit(!l);
                        reduction += w;
                        continue;
                    }
                }
                cost += w;
            }
        }
        if reduction > 0 {
            debug_assert!(tightening);
            // Refresh the full solution from the oracle after the flips
            *used_sol = self.oracle.solution(used_sol.max_var().unwrap())?;
            debug_assert_eq!(
                self.get_cost_with_heuristic_improvements(obj_idx, used_sol, false)?,
                cost
            );
        }
        if tightening {
            self.log_heuristic_obj_improvement(obj_idx, cost + reduction, cost)?;
        }
        Ok(cost)
    }
}
585
586#[cfg(not(feature = "sol-tightening"))]
587impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
588where
589 O: SolveIncremental,
590 ProofW: io::Write,
591{
592 fn get_cost_with_heuristic_improvements(
595 &mut self,
596 obj_idx: usize,
597 sol: &mut Assignment,
598 _tightening: bool,
599 ) -> anyhow::Result<usize> {
600 debug_assert!(obj_idx < self.stats.n_objs);
601 let mut cost = 0;
602 for (l, w) in self.objs[obj_idx].iter() {
603 let val = sol.lit_value(l);
604 if val == TernaryVal::True {
605 cost += w;
606 }
607 }
608 Ok(cost)
609 }
610}
611
#[cfg(feature = "phasing")]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: rustsat::solvers::PhaseLit,
    ProofW: io::Write,
{
    /// Phases every literal of `solution` in the oracle.
    ///
    /// This only happens when `solution_guided_search` is enabled; otherwise
    /// the method is a no-op.
    fn phase_solution(&mut self, solution: Assignment) -> anyhow::Result<()> {
        if self.opts.solution_guided_search {
            for lit in solution {
                self.oracle.phase_lit(lit)?;
            }
        }
        Ok(())
    }

    /// Removes the phase of every variable up to the manager's maximum
    /// variable.
    ///
    /// This only happens when `solution_guided_search` is enabled; otherwise
    /// the method is a no-op.
    fn unphase_solution(&mut self) -> anyhow::Result<()> {
        if self.opts.solution_guided_search {
            let last_idx = self.var_manager.max_var().unwrap().idx32();
            for idx in 0..=last_idx {
                self.oracle.unphase_var(rustsat::var![idx])?;
            }
        }
        Ok(())
    }
}
642
643#[cfg(not(feature = "phasing"))]
644impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG> {
645 fn phase_solution(&mut self, _solution: Assignment) -> anyhow::Result<()> {
648 Ok(())
649 }
650
651 fn unphase_solution(&mut self) -> anyhow::Result<()> {
654 Ok(())
655 }
656}
657
#[oracle_bounds]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: SolveIncremental,
    ProofW: io::Write,
    BCG: Fn(Assignment) -> Clause,
{
    /// Yields the non-dominated point with internal `costs` into `collector`.
    ///
    /// Further solutions of the point are enumerated as configured by
    /// `self.opts.enumeration`. Starting from `solution`, each solution is
    /// truncated to the original variables and added to the point. It is then
    /// blocked, either with the blocking clause generator or with a Pareto-MCS
    /// blocking clause. The oracle is re-solved under `assumps` until
    /// enumeration is done or no solution remains.
    ///
    /// NOTE(review): `assumps` presumably restrict the oracle to solutions with
    /// exactly `costs`; the debug assertion in the loop checks this for every
    /// enumerated solution — confirm with callers.
    ///
    /// The point is pushed into `collector` in these cases:
    /// - enumeration finishes;
    /// - the oracle reports UNSAT;
    /// - logging a solution hits a limit or fails.
    ///
    /// Other early exits via `?` (termination checks, oracle or clause-adding
    /// errors) return without pushing it.
    fn yield_solutions<Col: Extend<NonDomPoint>>(
        &mut self,
        costs: Vec<usize>,
        assumps: &[Lit],
        mut solution: Assignment,
        collector: &mut Col,
    ) -> MaybeTerminatedError {
        debug_assert_eq!(costs.len(), self.stats.n_objs);
        self.log_routine_start("yield solutions")?;
        self.unphase_solution()?;

        let mut non_dominated = NonDomPoint::new(self.externalize_internal_costs(&costs));

        loop {
            // Every enumerated solution must have exactly the point's costs
            debug_assert_eq!(
                (0..self.stats.n_objs)
                    .map(|idx| {
                        self.get_cost_with_heuristic_improvements(idx, &mut solution, false)
                            .unwrap()
                    })
                    .collect::<Vec<_>>(),
                costs
            );

            // Only keep the original variables in stored solutions
            solution = solution.truncate(self.var_manager.max_orig_var());

            non_dominated.add_sol(solution.clone());
            match self.log_solution() {
                Done(_) => (),
                Terminated(term) => {
                    // Still record the (partially enumerated) point before stopping
                    let nd_term = self.log_non_dominated(&non_dominated);
                    collector.extend([non_dominated]);
                    nd_term?;
                    return Terminated(term);
                }
                Error(err) => {
                    let nd_term = self.log_non_dominated(&non_dominated);
                    collector.extend([non_dominated]);
                    nd_term?;
                    return Error(err);
                }
            }
            // Stop once the configured number of solutions per point is reached
            if match self.opts.enumeration {
                EnumOptions::NoEnum => true,
                EnumOptions::Solutions(Some(limit)) => non_dominated.n_sols() >= limit,
                EnumOptions::PMCSs(Some(limit)) => non_dominated.n_sols() >= limit,
                _unlimited => false,
            } {
                let pp_term = self.log_non_dominated(&non_dominated);
                collector.extend([non_dominated]);
                self.log_routine_end()?;
                return pp_term;
            }
            self.check_termination()?;

            // Block the current solution before searching for the next one
            match self.opts.enumeration {
                EnumOptions::Solutions(_) => {
                    self.oracle.add_clause((self.block_clause_gen)(solution))?
                }
                EnumOptions::PMCSs(_) => self.oracle.add_clause(self.block_pareto_mcs(solution))?,
                EnumOptions::NoEnum => panic!("Should never reach this"),
            }

            let res = self.solve_assumps(assumps)?;
            if res == SolverResult::Unsat {
                // No further solutions for this point
                let pp_term = self.log_non_dominated(&non_dominated);
                collector.extend([non_dominated]);
                self.log_routine_end()?;
                return pp_term;
            }
            self.check_termination()?;
            solution = self.oracle.solution(self.var_manager.max_var().unwrap())?;
        }
    }
}
748
749impl<'learn, 'term, ProofW, OInit, BCG>
750 Kernel<rustsat_cadical::CaDiCaL<'learn, 'term>, ProofW, OInit, BCG>
751where
752 ProofW: io::Write + 'static,
753 BCG: Fn(Assignment) -> Clause,
754{
755 fn linsu_yield<Col>(
758 &mut self,
759 obj_idx: usize,
760 encoding: &mut ObjEncoding<GeneralizedTotalizer, Totalizer>,
761 base_assumps: &[Lit],
762 upper_bound: Option<(usize, Option<Assignment>)>,
763 lower_bound: Option<usize>,
764 collector: &mut Col,
765 ) -> MaybeTerminatedError<Option<(usize, Assignment, Option<pigeons::AbsConstraintId>)>>
766 where
767 Col: Extend<NonDomPoint>,
768 {
769 let Some((cost, mut sol, lb_id)) =
770 self.linsu(obj_idx, encoding, base_assumps, upper_bound, lower_bound)?
771 else {
772 return Done(None);
774 };
775 let costs: Vec<_> = (0..self.stats.n_objs)
776 .map(|idx| {
777 self.get_cost_with_heuristic_improvements(idx, &mut sol, false)
778 .unwrap()
779 })
780 .collect();
781 debug_assert_eq!(costs[obj_idx], cost);
782 let mut assumps = Vec::from(base_assumps);
784 self.extend_encoding(encoding, cost..cost + 1)?;
785 assumps.extend(encoding.enforce_ub(cost).unwrap());
786 self.yield_solutions(costs, &assumps, sol.clone(), collector)?;
787 Done(Some((cost, sol, lb_id)))
788 }
789}
790
791#[oracle_bounds]
792impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
793where
794 O: SolveIncremental + SolveStats,
795 ProofW: io::Write,
796{
797 fn get_solution_and_internal_costs(
801 &mut self,
802 tightening: bool,
803 ) -> anyhow::Result<(Vec<usize>, Assignment)> {
804 let mut sol = self.oracle.solution(self.var_manager.max_var().unwrap())?;
805 let costs = (0..self.objs.len())
806 .map(|idx| self.get_cost_with_heuristic_improvements(idx, &mut sol, tightening))
807 .collect::<Result<Vec<usize>, _>>()?;
808 debug_assert_eq!(costs.len(), self.stats.n_objs);
809 Ok((costs, sol))
810 }
811}
812
impl<'learn, 'term, ProofW, OInit, BCG>
    Kernel<rustsat_cadical::CaDiCaL<'learn, 'term>, ProofW, OInit, BCG>
where
    ProofW: io::Write + 'static,
{
    /// Minimizes objective `obj_idx` under `base_assumps` with linear SAT-UNSAT
    /// search.
    ///
    /// After every improving solution, the upper bound enforced through
    /// `encoding` is tightened.
    ///
    /// `upper_bound`, if given, is a known cost plus an optional solution of
    /// that cost to start from. Otherwise the oracle is first called under
    /// `base_assumps` to obtain an initial solution. The search stops when the
    /// cost reaches `lower_bound` (default `0`) or the oracle reports UNSAT for
    /// the next bound.
    ///
    /// Returns `None` if `base_assumps` are unsatisfiable; this is only checked
    /// when no `upper_bound` is given. Otherwise it returns:
    /// - the final cost;
    /// - a solution of that cost;
    /// - if proof logging is enabled and the search ended on an UNSAT call, the
    ///   proof constraint ID certifying the lower bound.
    fn linsu(
        &mut self,
        obj_idx: usize,
        encoding: &mut ObjEncoding<GeneralizedTotalizer, Totalizer>,
        base_assumps: &[Lit],
        upper_bound: Option<(usize, Option<Assignment>)>,
        lower_bound: Option<usize>,
    ) -> MaybeTerminatedError<Option<(usize, Assignment, Option<pigeons::AbsConstraintId>)>> {
        use rustsat::solvers::Solve;

        self.log_routine_start("linsu")?;

        let mut lb_id = None;

        let lower_bound = lower_bound.unwrap_or(0);

        // Initial bound: either given by the caller or from a first oracle call
        let (mut cost, mut sol) = if let Some(bound) = upper_bound {
            bound
        } else {
            let res = self.solve_assumps(base_assumps)?;
            if res == SolverResult::Unsat {
                self.log_routine_end()?;
                return Done(None);
            }
            let mut sol = self.oracle.solution(self.var_manager.max_var().unwrap())?;
            let cost = self.get_cost_with_heuristic_improvements(obj_idx, &mut sol, true)?;
            (cost, Some(sol))
        };
        let mut assumps = Vec::from(base_assumps);
        #[cfg(feature = "coarse-convergence")]
        let mut coarse = true;
        while cost > lower_bound {
            // Next bound to try: a coarse bound from the encoding first (if
            // enabled), otherwise one below the current cost
            let bound = '_bound: {
                #[cfg(feature = "coarse-convergence")]
                if coarse {
                    break '_bound encoding.coarse_ub(cost - 1);
                }
                cost - 1
            };
            // Drop the previous bound's assumptions, keep the base ones
            assumps.drain(base_assumps.len()..);
            self.extend_encoding(encoding, bound..bound + 1)?;
            assumps.extend(encoding.enforce_ub(bound).unwrap());
            match self.solve_assumps(&assumps)? {
                SolverResult::Sat => {
                    let mut thissol = self.oracle.solution(self.var_manager.max_var().unwrap())?;
                    let new_cost =
                        self.get_cost_with_heuristic_improvements(obj_idx, &mut thissol, false)?;
                    debug_assert!(new_cost < cost);
                    let costs: Vec<_> = (0..self.stats.n_objs)
                        .map(|oidx| {
                            self.get_cost_with_heuristic_improvements(oidx, &mut thissol, false)
                                .unwrap()
                        })
                        .collect();
                    self.log_candidate(&costs, Phase::Linsu)?;
                    sol = Some(thissol);
                    cost = new_cost;
                    if cost <= lower_bound {
                        self.log_routine_end()?;
                        return Done(Some((cost, sol.unwrap(), None)));
                    }
                }
                SolverResult::Unsat => {
                    // The coarse bound may have skipped costs between `bound + 1`
                    // and `cost - 1`; retry those with unit steps
                    #[cfg(feature = "coarse-convergence")]
                    if bound + 1 < cost {
                        coarse = false;
                        continue;
                    }

                    // With proof logging, certify `cost` as the lower bound from
                    // the core of the UNSAT call
                    if let Some(proof_stuff) = &mut self.proof_stuff {
                        lb_id = Some(proofs::linsu_certify_lower_bound(
                            base_assumps,
                            cost,
                            &(self.oracle.core()?),
                            &self.objs[obj_idx],
                            encoding,
                            proof_stuff,
                            &mut self.oracle,
                        )?);
                    }

                    break;
                }
                _ => unreachable!(),
            }
        }

        // Only reached without a solution if the caller gave an upper bound
        // without a solution and no improvement was found. In that case,
        // retrieve a solution at `cost`. This assumes the given bound is
        // achievable, which is checked only by the debug assertion.
        if sol.is_none() {
            self.extend_encoding(encoding, cost..cost + 1)?;
            assumps.drain(base_assumps.len()..);
            assumps.extend(encoding.enforce_ub(cost).unwrap());
            let res = self.solve_assumps(&assumps)?;
            debug_assert_eq!(res, SolverResult::Sat);
            sol = Some(self.oracle.solution(self.var_manager.max_var().unwrap())?);
        }
        self.log_routine_end()?;
        Done(Some((cost, sol.unwrap(), lb_id)))
    }

    /// Extends `encoding` so that it can enforce upper bounds in `range`, adding
    /// the new clauses to the oracle.
    ///
    /// With proof logging, the clauses go through a certifying collector and
    /// are justified in the proof. Otherwise they are added directly.
    fn extend_encoding(
        &mut self,
        encoding: &mut ObjEncoding<GeneralizedTotalizer, Totalizer>,
        range: Range<usize>,
    ) -> anyhow::Result<()> {
        if let Some(proofs::ProofStuff { pt_handle, .. }) = &self.proof_stuff {
            // NOTE(review): a raw pointer to the proof, obtained through the
            // oracle's proof tracer, is used while `collector` below holds
            // `&mut self.oracle`. That gives two simultaneous mutable paths to the
            // proof. Soundness relies on `CadicalCertCollector` not touching the
            // proof through the oracle during `encode_ub_change_cert`.
            // TODO confirm and turn this into a proper `SAFETY:` argument.
            let proof: *mut _ = self.oracle.proof_tracer_mut(pt_handle).proof_mut();
            #[cfg(feature = "verbose-proofs")]
            {
                unsafe { &mut *proof }.comment(&format_args!(
                    "extending encoding to {}..{}",
                    range.start, range.end,
                ))?;
            }
            let mut collector = CadicalCertCollector::new(&mut self.oracle, pt_handle);
            encoding.encode_ub_change_cert(
                range,
                &mut collector,
                &mut self.var_manager,
                unsafe { &mut *proof },
            )?;
        } else {
            encoding.encode_ub_change(range, &mut self.oracle, &mut self.var_manager)?;
        }
        Ok(())
    }
}
946
947impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
948where
949 O: SolveIncremental,
950 ProofW: io::Write,
951{
952 fn solve(&mut self) -> MaybeTerminatedError<SolverResult> {
955 self.log_routine_start("oracle call")?;
956 let res = self.oracle.solve()?;
957 self.log_routine_end()?;
958 self.check_termination()?;
959 self.log_oracle_call(res)?;
960 Done(res)
961 }
962
963 fn solve_assumps(&mut self, assumps: &[Lit]) -> MaybeTerminatedError<SolverResult> {
966 self.log_routine_start("oracle call")?;
967 let res = self.oracle.solve_assumps(assumps)?;
968 self.log_routine_end()?;
969 self.check_termination()?;
970 self.log_oracle_call(res)?;
971 Done(res)
972 }
973}
974
975#[oracle_bounds]
976impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
977where
978 O: SolveIncremental,
979 ProofW: io::Write,
980 OInit: Initialize<O>,
981{
982 fn reset_oracle(&mut self, include_var_manager: bool) -> anyhow::Result<()> {
984 anyhow::ensure!(
985 self.opts.store_cnf,
986 "cannot reset oracle without having stored the CNF"
987 );
988 self.log_routine_start("reset-oracle")?;
989 self.oracle = OInit::init();
990 if include_var_manager {
991 self.oracle.reserve(self.var_manager.max_enc_var())?;
992 } else {
993 self.oracle.reserve(self.var_manager.max_var().unwrap())?;
994 }
995 self.oracle.add_cnf(self.orig_cnf.clone().unwrap())?;
996 #[cfg(feature = "interrupt-oracle")]
997 {
998 *self.oracle_interrupter.lock().unwrap() = Box::new(self.oracle.interrupter());
999 }
1000 #[cfg(feature = "sol-tightening")]
1001 for o in &self.objs {
1003 for (l, _) in o.iter() {
1004 self.oracle.freeze_var(l.var())?;
1005 }
1006 }
1007 if include_var_manager {
1008 self.var_manager
1009 .forget_from(self.var_manager.max_enc_var() + 1);
1010 }
1011 self.log_routine_end()?;
1012 Ok(())
1013 }
1014}
1015
1016pub fn default_blocking_clause(sol: Assignment) -> Clause {
1018 Clause::from_iter(sol.into_iter().map(Lit::not))
1019}