scuttle_core/
algs.rs

1//! Core solver functionality shared between different algorithms
2
3use std::{
4    io,
5    marker::PhantomData,
6    ops::{Not, Range},
7    sync::{
8        atomic::{AtomicBool, Ordering},
9        Arc,
10    },
11};
12
13#[cfg(feature = "interrupt-oracle")]
14use std::sync::Mutex;
15
16use anyhow::Context;
17use cadical_veripb_tracer::CadicalCertCollector;
18use rustsat::{
19    encodings::{card::Totalizer, pb::GeneralizedTotalizer},
20    instances::{Cnf, ManageVars},
21    solvers::{
22        DefaultInitializer, Initialize, SolveIncremental, SolveStats, SolverResult, SolverStats,
23    },
24    types::{Assignment, Clause, Lit, TernaryVal},
25};
26use scuttle_proc::oracle_bounds;
27
28#[cfg(feature = "sol-tightening")]
29use maxpre::PreproClauses;
30
31use crate::{
32    options::{CoreBoostingOptions, EnumOptions},
33    types::{Instance, NonDomPoint, ObjEncoding, Objective, ParetoFront, VarManager},
34    EncodingStats, KernelOptions, Limits, MaybeTerminated,
35    MaybeTerminatedError::{self, Done, Error, Terminated},
36    Phase, Stats, Termination, WriteSolverLog,
37};
38
39pub mod bioptsat;
40pub mod lowerbounding;
41pub mod pminimal;
42
43mod coreboosting;
44mod coreguided;
45pub(crate) mod proofs;
46pub use proofs::{InitCert, InitCertDefaultBlock};
47
48/// Trait for initializing algorithms
49pub trait Init: Sized {
50    type Oracle: SolveIncremental;
51    type BlockClauseGen: Fn(Assignment) -> Clause;
52
53    /// Initialization of the algorithm providing all optional input
54    fn new<Cls>(
55        clauses: Cls,
56        objs: Vec<Objective>,
57        var_manager: VarManager,
58        opts: KernelOptions,
59        block_clause_gen: Self::BlockClauseGen,
60    ) -> anyhow::Result<Self>
61    where
62        Cls: IntoIterator<Item = Clause>;
63
64    /// Initialization of the algorithm using an [`Instance`] rather than iterators
65    fn from_instance(
66        inst: Instance,
67        opts: KernelOptions,
68        block_clause_gen: Self::BlockClauseGen,
69    ) -> anyhow::Result<Self> {
70        Self::new(
71            inst.clauses.into_iter().map(|(cl, _)| cl),
72            inst.objs,
73            inst.vm,
74            opts,
75            block_clause_gen,
76        )
77    }
78}
79
80pub trait InitDefaultBlock: Init<BlockClauseGen = fn(Assignment) -> Clause> {
81    /// Initializes the algorithm with the default blocking clause generator
82    fn new_default_blocking<Cls>(
83        clauses: Cls,
84        objs: Vec<Objective>,
85        var_manager: VarManager,
86        opts: KernelOptions,
87    ) -> anyhow::Result<Self>
88    where
89        Cls: IntoIterator<Item = Clause>,
90    {
91        Self::new(clauses, objs, var_manager, opts, default_blocking_clause)
92    }
93
94    /// Initializes the algorithm using an [`Instance`] rather than iterators with the default
95    /// blocking clause generator
96    fn from_instance_default_blocking(inst: Instance, opts: KernelOptions) -> anyhow::Result<Self> {
97        Self::new(
98            inst.clauses.into_iter().map(|(cl, _)| cl),
99            inst.objs,
100            inst.vm,
101            opts,
102            default_blocking_clause,
103        )
104    }
105}
106
107impl<Alg> InitDefaultBlock for Alg where Alg: Init<BlockClauseGen = fn(Assignment) -> Clause> {}
108
/// Solving interface for each algorithm
pub trait Solve: KernelFunctions {
    /// Solves the instance under given limits. If not fully solved, returns an
    /// early termination reason.
    fn solve(&mut self, limits: Limits) -> MaybeTerminatedError;
    /// Gets all statistics from the solver: the kernel statistics, the SAT
    /// oracle statistics and the objective encoding statistics (the latter two
    /// are optional and may be `None`)
    fn all_stats(&self) -> (Stats, Option<SolverStats>, Option<Vec<EncodingStats>>);
}
117
/// Core boosting interface
pub trait CoreBoost {
    /// Performs core boosting with the given options. Returns `false` if the
    /// instance is unsat.
    fn core_boost(&mut self, opts: CoreBoostingOptions) -> MaybeTerminatedError<bool>;
}
123
124/// Shared functionality provided by the [`Kernel`]
125pub trait KernelFunctions {
126    /// Gets the Pareto front discovered so far
127    fn pareto_front(&self) -> ParetoFront;
128    /// Gets tracked statistics from the solver
129    fn stats(&self) -> Stats;
130    /// Attaches a logger to the solver
131    fn attach_logger<L: WriteSolverLog + 'static>(&mut self, logger: L);
132    /// Detaches a logger from the solver
133    fn detach_logger(&mut self) -> Option<Box<dyn WriteSolverLog>>;
134    /// Gets an iterrupter to the solver
135    fn interrupter(&mut self) -> Interrupter;
136}
137
/// Handle for interrupting a running solver asynchronously
///
/// Shares the solver's termination flag (and, with the `interrupt-oracle`
/// feature, the SAT oracle's interrupter) through reference-counted pointers.
pub struct Interrupter {
    /// Termination flag of the solver
    term_flag: Arc<AtomicBool>,
    /// The terminator of the underlying SAT oracle
    #[cfg(feature = "interrupt-oracle")]
    oracle_interrupter: Arc<Mutex<Box<dyn rustsat::solvers::InterruptSolver + Send>>>,
}
145
#[cfg(feature = "interrupt-oracle")]
impl Interrupter {
    /// Interrupts the solver asynchronously: raises the shared termination flag
    /// and forwards the interrupt to the SAT oracle.
    ///
    /// # Panics
    ///
    /// If the mutex guarding the oracle interrupter is poisoned.
    pub fn interrupt(&mut self) {
        let Self {
            term_flag,
            oracle_interrupter,
        } = self;
        term_flag.store(true, Ordering::Relaxed);
        let mut oracle = oracle_interrupter.lock().unwrap();
        oracle.interrupt();
    }
}
154
155#[cfg(not(feature = "interrupt-oracle"))]
156impl Interrupter {
157    /// Interrupts the solver asynchronously
158    pub fn interrupt(&mut self) {
159        self.term_flag.store(true, Ordering::Relaxed);
160    }
161}
162
/// Kernel struct shared between all solvers
///
/// Bundles the SAT oracle, the objectives and the bookkeeping (statistics,
/// limits, logger, termination flag, proof state) the algorithms build on.
///
/// # Generics
///
/// - `O`: the SAT solver oracle
/// - `ProofW`: the proof writer
/// - `OInit`: the oracle initializer
/// - `BCG`: the blocking clause generator
struct Kernel<O, ProofW, OInit = DefaultInitializer, BCG = fn(Assignment) -> Clause>
where
    ProofW: io::Write,
{
    /// The SAT solver backend
    oracle: O,
    /// The variable manager keeping track of variables
    var_manager: VarManager,
    #[cfg(feature = "sol-tightening")]
    /// Objective literal data: for each objective literal, the indices of the
    /// objectives it occurs in
    obj_lit_data: rustsat::types::RsHashMap<Lit, crate::types::ObjLitData>,
    /// The objectives
    objs: Vec<Objective>,
    /// The stored original clauses; only `Some` if `opts.store_cnf` was set at
    /// construction
    orig_cnf: Option<Cnf>,
    /// Generator of blocking clauses
    block_clause_gen: BCG,
    /// Configuration options
    opts: KernelOptions,
    /// Running statistics
    stats: Stats,
    /// Limits for the current solving run; the remaining budgets are counted
    /// down as events are logged
    lims: Limits,
    /// An optional inprocessor that has been run at some stage
    #[cfg(feature = "maxpre")]
    inpro: Option<maxpre::MaxPre>,
    /// Logger to log with
    logger: Option<Box<dyn WriteSolverLog>>,
    /// Termination flag, shared with every handed-out `Interrupter`
    term_flag: Arc<AtomicBool>,
    /// The oracle interrupter, shared with every handed-out `Interrupter`
    #[cfg(feature = "interrupt-oracle")]
    oracle_interrupter: Arc<Mutex<Box<dyn rustsat::solvers::InterruptSolver + Send>>>,
    /// The handle of the proof tracer, when proof logging
    proof_stuff: Option<proofs::ProofStuff<ProofW>>,
    /// Phantom marker for the oracle initializer, which is only used during
    /// construction (`OInit::init`)
    _factory: PhantomData<OInit>,
}
209
210#[oracle_bounds]
211impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
212where
213    O: SolveIncremental,
214    ProofW: io::Write,
215    OInit: Initialize<O>,
216    BCG: Fn(Assignment) -> Clause,
217{
218    pub fn new<Cls>(
219        clauses: Cls,
220        objs: Vec<Objective>,
221        var_manager: VarManager,
222        bcg: BCG,
223        opts: KernelOptions,
224    ) -> anyhow::Result<Self>
225    where
226        Cls: IntoIterator<Item = Clause>,
227    {
228        let mut stats = Stats {
229            n_objs: 0,
230            n_real_objs: 0,
231            n_orig_clauses: 0,
232            ..Default::default()
233        };
234        let mut oracle = OInit::init();
235        oracle.reserve(var_manager.max_var().unwrap())?;
236        let orig_cnf = if opts.store_cnf {
237            let cnf: Cnf = clauses.into_iter().collect();
238            stats.n_orig_clauses = cnf.len();
239            oracle.add_cnf_ref(&cnf)?;
240            Some(cnf)
241        } else {
242            for cl in clauses.into_iter() {
243                stats.n_orig_clauses += 1;
244                oracle.add_clause(cl)?;
245            }
246            None
247        };
248        stats.n_objs = objs.len();
249        stats.n_real_objs = objs.iter().fold(0, |cnt, o| {
250            if matches!(o, Objective::Constant { .. }) {
251                cnt
252            } else {
253                cnt + 1
254            }
255        });
256        // Record objective literal occurrences
257        #[cfg(feature = "sol-tightening")]
258        let obj_lit_data = {
259            use crate::types::ObjLitData;
260            use rustsat::types::RsHashMap;
261            let mut obj_lit_data: RsHashMap<_, ObjLitData> = RsHashMap::default();
262            for (idx, obj) in objs.iter().enumerate() {
263                match obj {
264                    Objective::Weighted { lits, .. } => {
265                        for &olit in lits.keys() {
266                            match obj_lit_data.get_mut(&olit) {
267                                Some(data) => data.objs.push(idx),
268                                None => {
269                                    obj_lit_data.insert(olit, ObjLitData { objs: vec![idx] });
270                                }
271                            }
272                        }
273                    }
274                    Objective::Unweighted { lits, .. } => {
275                        for &olit in lits {
276                            match obj_lit_data.get_mut(&olit) {
277                                Some(data) => data.objs.push(idx),
278                                None => {
279                                    obj_lit_data.insert(olit, ObjLitData { objs: vec![idx] });
280                                }
281                            }
282                        }
283                    }
284                    Objective::Constant { .. } => (),
285                }
286            }
287            obj_lit_data
288        };
289        #[cfg(feature = "sol-tightening")]
290        // Freeze objective variables so that they are not removed
291        for o in &objs {
292            for (l, _) in o.iter() {
293                oracle.freeze_var(l.var())?;
294            }
295        }
296        #[cfg(feature = "interrupt-oracle")]
297        let interrupter = oracle.interrupter();
298        Ok(Self {
299            oracle,
300            var_manager,
301            #[cfg(feature = "sol-tightening")]
302            obj_lit_data,
303            objs,
304            orig_cnf,
305            block_clause_gen: bcg,
306            opts,
307            stats,
308            lims: Limits::none(),
309            #[cfg(feature = "maxpre")]
310            inpro: None,
311            logger: None,
312            term_flag: Arc::new(AtomicBool::new(false)),
313            #[cfg(feature = "interrupt-oracle")]
314            oracle_interrupter: Arc::new(Mutex::new(Box::new(interrupter))),
315            proof_stuff: None,
316            _factory: PhantomData,
317        })
318    }
319}
320
321impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
322where
323    ProofW: io::Write,
324{
325    fn start_solving(&mut self, limits: Limits) {
326        self.stats.n_solve_calls += 1;
327        self.lims = limits;
328    }
329
330    fn attach_logger<L: WriteSolverLog + 'static>(&mut self, logger: L) {
331        self.logger = Some(Box::new(logger));
332    }
333
334    fn detach_logger(&mut self) -> Option<Box<dyn WriteSolverLog>> {
335        self.logger.take()
336    }
337
338    /// Converts an internal cost vector to an external one. Internal cost is
339    /// purely the encoding output while external cost takes an offset and
340    /// multiplier into account.
341    fn externalize_internal_costs(&self, costs: &[usize]) -> Vec<isize> {
342        debug_assert_eq!(costs.len(), self.stats.n_objs);
343        costs
344            .iter()
345            .enumerate()
346            .map(|(idx, &cst)| match self.objs[idx] {
347                Objective::Weighted { offset, .. } => {
348                    let signed_cst: isize = cst.try_into().expect("cost exceeds `isize`");
349                    signed_cst + offset
350                }
351                Objective::Unweighted {
352                    offset,
353                    unit_weight,
354                    ..
355                } => {
356                    let signed_mult_cost: isize = (cst * unit_weight)
357                        .try_into()
358                        .expect("multiplied cost exceeds `isize`");
359                    signed_mult_cost + offset
360                }
361                Objective::Constant { offset, .. } => {
362                    debug_assert_eq!(cst, 0);
363                    offset
364                }
365            })
366            .collect()
367    }
368
369    /// Blocks the current Pareto-MCS by blocking all blocking variables that are set
370    fn block_pareto_mcs(&self, sol: Assignment) -> Clause {
371        let mut blocking_clause = Clause::new();
372        self.objs.iter().for_each(|oe| {
373            oe.iter().for_each(|(l, _)| {
374                if sol.lit_value(l) == TernaryVal::True {
375                    blocking_clause.add(-l)
376                }
377            })
378        });
379        blocking_clause
380            .normalize()
381            .expect("Tautological blocking clause")
382    }
383
384    /// Checks the termination flag and terminates if appropriate
385    fn check_termination(&self) -> MaybeTerminated {
386        if self.term_flag.load(Ordering::Relaxed) {
387            MaybeTerminated::Terminated(Termination::Interrupted)
388        } else {
389            MaybeTerminated::Done(())
390        }
391    }
392
393    /// Logs a cost point candidate. Can error a termination if the candidates limit is reached.
394    fn log_candidate(&mut self, costs: &[usize], phase: Phase) -> MaybeTerminatedError {
395        debug_assert_eq!(costs.len(), self.stats.n_objs);
396        self.stats.n_candidates += 1;
397        // Dispatch to logger
398        if let Some(logger) = &mut self.logger {
399            logger
400                .log_candidate(costs, phase)
401                .context("logger failed")?;
402        }
403        // Update limit and check termination
404        if let Some(candidates) = &mut self.lims.candidates {
405            *candidates -= 1;
406            if *candidates == 0 {
407                return Terminated(Termination::CandidatesLimit);
408            }
409        }
410        Done(())
411    }
412
413    /// Logs an oracle call. Can return a termination if the oracle call limit is reached.
414    fn log_oracle_call(&mut self, result: SolverResult) -> MaybeTerminatedError {
415        self.stats.n_oracle_calls += 1;
416        // Dispatch to logger
417        if let Some(logger) = &mut self.logger {
418            logger.log_oracle_call(result).context("logger failed")?;
419        }
420        // Update limit and check termination
421        if let Some(oracle_calls) = &mut self.lims.oracle_calls {
422            *oracle_calls -= 1;
423            if *oracle_calls == 0 {
424                return Terminated(Termination::OracleCallsLimit);
425            }
426        }
427        Done(())
428    }
429
430    /// Logs a solution. Can return a termination if the solution limit is reached.
431    fn log_solution(&mut self) -> MaybeTerminatedError {
432        self.stats.n_solutions += 1;
433        // Dispatch to logger
434        if let Some(logger) = &mut self.logger {
435            logger.log_solution().context("logger failed")?;
436        }
437        // Update limit and check termination
438        if let Some(solutions) = &mut self.lims.sols {
439            *solutions -= 1;
440            if *solutions == 0 {
441                return Terminated(Termination::SolsLimit);
442            }
443        }
444        Done(())
445    }
446
447    /// Logs a non-dominated point. Can return a termination if the non-dominated point limit is reached.
448    fn log_non_dominated(&mut self, non_dominated: &NonDomPoint) -> MaybeTerminatedError {
449        self.stats.n_non_dominated += 1;
450        // Dispatch to logger
451        if let Some(logger) = &mut self.logger {
452            logger
453                .log_non_dominated(non_dominated)
454                .context("logger failed")?;
455        }
456        // Update limit and check termination
457        if let Some(pps) = &mut self.lims.pps {
458            *pps -= 1;
459            if *pps == 0 {
460                return Terminated(Termination::PPLimit);
461            }
462        }
463        Done(())
464    }
465
466    #[cfg(feature = "sol-tightening")]
467    /// Logs a heuristic objective improvement. Can return a logger error.
468    fn log_heuristic_obj_improvement(
469        &mut self,
470        obj_idx: usize,
471        apparent_cost: usize,
472        improved_cost: usize,
473    ) -> anyhow::Result<()> {
474        // Dispatch to logger
475        if let Some(logger) = &mut self.logger {
476            logger
477                .log_heuristic_obj_improvement(obj_idx, apparent_cost, improved_cost)
478                .context("logger failed")?;
479        }
480        Ok(())
481    }
482
483    /// Logs a routine start
484    fn log_routine_start(&mut self, desc: &'static str) -> anyhow::Result<()> {
485        // Dispatch to logger
486        if let Some(logger) = &mut self.logger {
487            logger.log_routine_start(desc).context("logger failed")?;
488        }
489        Ok(())
490    }
491
492    /// Logs a routine end
493    fn log_routine_end(&mut self) -> anyhow::Result<()> {
494        // Dispatch to logger
495        if let Some(logger) = &mut self.logger {
496            logger.log_routine_end().context("logger failed")?;
497        }
498        Ok(())
499    }
500}
501
#[cfg(feature = "interrupt-oracle")]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: rustsat::solvers::Interrupt,
    ProofW: io::Write,
{
    /// Hands out an [`Interrupter`] sharing this kernel's termination flag and
    /// oracle interrupter
    fn interrupter(&mut self) -> Interrupter {
        let term_flag = Arc::clone(&self.term_flag);
        let oracle_interrupter = Arc::clone(&self.oracle_interrupter);
        Interrupter {
            term_flag,
            oracle_interrupter,
        }
    }
}
515
516#[cfg(not(feature = "interrupt-oracle"))]
517impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
518where
519    ProofW: io::Write,
520{
521    fn interrupter(&mut self) -> Interrupter {
522        Interrupter {
523            term_flag: self.term_flag.clone(),
524        }
525    }
526}
527
#[cfg(feature = "sol-tightening")]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: SolveIncremental + rustsat::solvers::FlipLit,
    ProofW: io::Write,
{
    /// Performs heuristic solution improvement and computes the improved
    /// (internal) cost for one objective
    ///
    /// If `tightening` is set, every objective literal that is true in the
    /// solution and whose negation occurs in no objective is offered to the
    /// oracle for flipping; accepted flips do not count towards the cost, and
    /// the solution is re-fetched from the oracle afterwards. With an
    /// inprocessor present, tightening is disabled and the cost is computed on
    /// a reconstructed copy, leaving `sol` unchanged.
    pub fn get_cost_with_heuristic_improvements(
        &mut self,
        obj_idx: usize,
        sol: &mut Assignment,
        mut tightening: bool,
    ) -> anyhow::Result<usize> {
        debug_assert!(obj_idx < self.stats.n_objs);
        // TODO: iterate over objective literals by weight
        let mut reconstructed;
        let work_sol: &mut Assignment = match &mut self.inpro {
            Some(inpro) => {
                // TODO: don't reconstruct every time
                // since tightening is done in the solver, cannot do this with inprocessing
                tightening = false;
                reconstructed = inpro.reconstruct(sol.clone());
                &mut reconstructed
            }
            None => sol,
        };
        let mut cost = 0;
        let mut reduction = 0;
        for (lit, weight) in self.objs[obj_idx].iter() {
            if work_sol.lit_value(lit) != TernaryVal::True {
                continue;
            }
            // Only literals whose negation appears in no objective may be flipped
            let flippable = tightening && !self.obj_lit_data.contains_key(&!lit);
            if flippable && self.oracle.flip_lit(!lit)? {
                work_sol.assign_lit(!lit);
                reduction += weight;
            } else {
                cost += weight;
            }
        }
        if reduction > 0 {
            debug_assert!(tightening);
            // get assignment from the solver again to trigger reconstruction stack
            *work_sol = self.oracle.solution(work_sol.max_var().unwrap())?;
            debug_assert_eq!(
                self.get_cost_with_heuristic_improvements(obj_idx, work_sol, false)?,
                cost
            );
        }
        if tightening {
            self.log_heuristic_obj_improvement(obj_idx, cost + reduction, cost)?;
        }
        Ok(cost)
    }
}
585
586#[cfg(not(feature = "sol-tightening"))]
587impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
588where
589    O: SolveIncremental,
590    ProofW: io::Write,
591{
592    /// Performs heuristic solution improvement and computes the improved
593    /// (internal) cost for one objective
594    fn get_cost_with_heuristic_improvements(
595        &mut self,
596        obj_idx: usize,
597        sol: &mut Assignment,
598        _tightening: bool,
599    ) -> anyhow::Result<usize> {
600        debug_assert!(obj_idx < self.stats.n_objs);
601        let mut cost = 0;
602        for (l, w) in self.objs[obj_idx].iter() {
603            let val = sol.lit_value(l);
604            if val == TernaryVal::True {
605                cost += w;
606            }
607        }
608        Ok(cost)
609    }
610}
611
#[cfg(feature = "phasing")]
impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
where
    O: rustsat::solvers::PhaseLit,
    ProofW: io::Write,
{
    /// Phases every literal of `solution` in the oracle; does nothing unless
    /// solution-guided search is enabled
    fn phase_solution(&mut self, solution: Assignment) -> anyhow::Result<()> {
        if self.opts.solution_guided_search {
            for lit in solution {
                self.oracle.phase_lit(lit)?;
            }
        }
        Ok(())
    }

    /// Removes the phase of every variable up to the maximum known variable;
    /// does nothing unless solution-guided search is enabled
    fn unphase_solution(&mut self) -> anyhow::Result<()> {
        if self.opts.solution_guided_search {
            let max_idx = self.var_manager.max_var().unwrap().idx32();
            for idx in 0..=max_idx {
                self.oracle.unphase_var(rustsat::var![idx])?;
            }
        }
        Ok(())
    }
}
642
643#[cfg(not(feature = "phasing"))]
644impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG> {
645    /// If solution-guided search is turned on, phases the entire solution in
646    /// the oracle
647    fn phase_solution(&mut self, _solution: Assignment) -> anyhow::Result<()> {
648        Ok(())
649    }
650
651    /// If solution-guided search is turned on, unphases every variable in the
652    /// solver
653    fn unphase_solution(&mut self) -> anyhow::Result<()> {
654        Ok(())
655    }
656}
657
658#[oracle_bounds]
659impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
660where
661    O: SolveIncremental,
662    ProofW: io::Write,
663    BCG: Fn(Assignment) -> Clause,
664{
665    /// Yields Pareto-optimal solutions. The given assumptions must only allow
666    /// for solutions at the non-dominated point with given cost. If the options
667    /// ask for enumeration, will enumerate all solutions at this point.
668    fn yield_solutions<Col: Extend<NonDomPoint>>(
669        &mut self,
670        costs: Vec<usize>,
671        assumps: &[Lit],
672        mut solution: Assignment,
673        collector: &mut Col,
674    ) -> MaybeTerminatedError {
675        debug_assert_eq!(costs.len(), self.stats.n_objs);
676        self.log_routine_start("yield solutions")?;
677        self.unphase_solution()?;
678
679        // Create Pareto point
680        let mut non_dominated = NonDomPoint::new(self.externalize_internal_costs(&costs));
681
682        loop {
683            debug_assert_eq!(
684                (0..self.stats.n_objs)
685                    .map(|idx| {
686                        self.get_cost_with_heuristic_improvements(idx, &mut solution, false)
687                            .unwrap()
688                    })
689                    .collect::<Vec<_>>(),
690                costs
691            );
692
693            // Truncate internal solution to only include instance variables
694            solution = solution.truncate(self.var_manager.max_orig_var());
695
696            non_dominated.add_sol(solution.clone());
697            match self.log_solution() {
698                Done(_) => (),
699                Terminated(term) => {
700                    let nd_term = self.log_non_dominated(&non_dominated);
701                    collector.extend([non_dominated]);
702                    nd_term?;
703                    return Terminated(term);
704                }
705                Error(err) => {
706                    let nd_term = self.log_non_dominated(&non_dominated);
707                    collector.extend([non_dominated]);
708                    nd_term?;
709                    return Error(err);
710                }
711            }
712            if match self.opts.enumeration {
713                EnumOptions::NoEnum => true,
714                EnumOptions::Solutions(Some(limit)) => non_dominated.n_sols() >= limit,
715                EnumOptions::PMCSs(Some(limit)) => non_dominated.n_sols() >= limit,
716                _unlimited => false,
717            } {
718                let pp_term = self.log_non_dominated(&non_dominated);
719                collector.extend([non_dominated]);
720                self.log_routine_end()?;
721                return pp_term;
722            }
723            self.check_termination()?;
724
725            // Block last solution
726            match self.opts.enumeration {
727                EnumOptions::Solutions(_) => {
728                    self.oracle.add_clause((self.block_clause_gen)(solution))?
729                }
730                EnumOptions::PMCSs(_) => self.oracle.add_clause(self.block_pareto_mcs(solution))?,
731                EnumOptions::NoEnum => panic!("Should never reach this"),
732            }
733
734            // Find next solution
735            let res = self.solve_assumps(assumps)?;
736            if res == SolverResult::Unsat {
737                let pp_term = self.log_non_dominated(&non_dominated);
738                // All solutions enumerated
739                collector.extend([non_dominated]);
740                self.log_routine_end()?;
741                return pp_term;
742            }
743            self.check_termination()?;
744            solution = self.oracle.solution(self.var_manager.max_var().unwrap())?;
745        }
746    }
747}
748
749impl<'learn, 'term, ProofW, OInit, BCG>
750    Kernel<rustsat_cadical::CaDiCaL<'learn, 'term>, ProofW, OInit, BCG>
751where
752    ProofW: io::Write + 'static,
753    BCG: Fn(Assignment) -> Clause,
754{
755    /// Performs linear sat-unsat search on a given objective and yields
756    /// solutions found at the optimum.
757    fn linsu_yield<Col>(
758        &mut self,
759        obj_idx: usize,
760        encoding: &mut ObjEncoding<GeneralizedTotalizer, Totalizer>,
761        base_assumps: &[Lit],
762        upper_bound: Option<(usize, Option<Assignment>)>,
763        lower_bound: Option<usize>,
764        collector: &mut Col,
765    ) -> MaybeTerminatedError<Option<(usize, Assignment, Option<pigeons::AbsConstraintId>)>>
766    where
767        Col: Extend<NonDomPoint>,
768    {
769        let Some((cost, mut sol, lb_id)) =
770            self.linsu(obj_idx, encoding, base_assumps, upper_bound, lower_bound)?
771        else {
772            // nothing to yield
773            return Done(None);
774        };
775        let costs: Vec<_> = (0..self.stats.n_objs)
776            .map(|idx| {
777                self.get_cost_with_heuristic_improvements(idx, &mut sol, false)
778                    .unwrap()
779            })
780            .collect();
781        debug_assert_eq!(costs[obj_idx], cost);
782        // bound obj
783        let mut assumps = Vec::from(base_assumps);
784        self.extend_encoding(encoding, cost..cost + 1)?;
785        assumps.extend(encoding.enforce_ub(cost).unwrap());
786        self.yield_solutions(costs, &assumps, sol.clone(), collector)?;
787        Done(Some((cost, sol, lb_id)))
788    }
789}
790
791#[oracle_bounds]
792impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
793where
794    O: SolveIncremental + SolveStats,
795    ProofW: io::Write,
796{
797    /// Gets the current objective costs without offset or multiplier. The phase
798    /// parameter is needed to determine if the solution should be heuristically
799    /// improved.
800    fn get_solution_and_internal_costs(
801        &mut self,
802        tightening: bool,
803    ) -> anyhow::Result<(Vec<usize>, Assignment)> {
804        let mut sol = self.oracle.solution(self.var_manager.max_var().unwrap())?;
805        let costs = (0..self.objs.len())
806            .map(|idx| self.get_cost_with_heuristic_improvements(idx, &mut sol, tightening))
807            .collect::<Result<Vec<usize>, _>>()?;
808        debug_assert_eq!(costs.len(), self.stats.n_objs);
809        Ok((costs, sol))
810    }
811}
812
813impl<'learn, 'term, ProofW, OInit, BCG>
814    Kernel<rustsat_cadical::CaDiCaL<'learn, 'term>, ProofW, OInit, BCG>
815where
816    ProofW: io::Write + 'static,
817{
818    /// Performs linear sat-unsat search on a given objective.
819    fn linsu(
820        &mut self,
821        obj_idx: usize,
822        encoding: &mut ObjEncoding<GeneralizedTotalizer, Totalizer>,
823        base_assumps: &[Lit],
824        upper_bound: Option<(usize, Option<Assignment>)>,
825        lower_bound: Option<usize>,
826    ) -> MaybeTerminatedError<Option<(usize, Assignment, Option<pigeons::AbsConstraintId>)>> {
827        use rustsat::solvers::Solve;
828
829        self.log_routine_start("linsu")?;
830
831        let mut lb_id = None;
832
833        let lower_bound = lower_bound.unwrap_or(0);
834
835        let (mut cost, mut sol) = if let Some(bound) = upper_bound {
836            bound
837        } else {
838            let res = self.solve_assumps(base_assumps)?;
839            if res == SolverResult::Unsat {
840                self.log_routine_end()?;
841                return Done(None);
842            }
843            let mut sol = self.oracle.solution(self.var_manager.max_var().unwrap())?;
844            let cost = self.get_cost_with_heuristic_improvements(obj_idx, &mut sol, true)?;
845            (cost, Some(sol))
846        };
847        let mut assumps = Vec::from(base_assumps);
848        #[cfg(feature = "coarse-convergence")]
849        let mut coarse = true;
850        while cost > lower_bound {
851            let bound = '_bound: {
852                #[cfg(feature = "coarse-convergence")]
853                if coarse {
854                    break '_bound encoding.coarse_ub(cost - 1);
855                }
856                cost - 1
857            };
858            assumps.drain(base_assumps.len()..);
859            self.extend_encoding(encoding, bound..bound + 1)?;
860            assumps.extend(encoding.enforce_ub(bound).unwrap());
861            match self.solve_assumps(&assumps)? {
862                SolverResult::Sat => {
863                    let mut thissol = self.oracle.solution(self.var_manager.max_var().unwrap())?;
864                    let new_cost =
865                        self.get_cost_with_heuristic_improvements(obj_idx, &mut thissol, false)?;
866                    debug_assert!(new_cost < cost);
867                    let costs: Vec<_> = (0..self.stats.n_objs)
868                        .map(|oidx| {
869                            self.get_cost_with_heuristic_improvements(oidx, &mut thissol, false)
870                                .unwrap()
871                        })
872                        .collect();
873                    self.log_candidate(&costs, Phase::Linsu)?;
874                    sol = Some(thissol);
875                    cost = new_cost;
876                    if cost <= lower_bound {
877                        self.log_routine_end()?;
878                        return Done(Some((cost, sol.unwrap(), None)));
879                    }
880                }
881                SolverResult::Unsat => {
882                    #[cfg(feature = "coarse-convergence")]
883                    if bound + 1 < cost {
884                        coarse = false;
885                        continue;
886                    }
887
888                    if let Some(proof_stuff) = &mut self.proof_stuff {
889                        lb_id = Some(proofs::linsu_certify_lower_bound(
890                            base_assumps,
891                            cost,
892                            &(self.oracle.core()?),
893                            &self.objs[obj_idx],
894                            encoding,
895                            proof_stuff,
896                            &mut self.oracle,
897                        )?);
898                    }
899
900                    break;
901                }
902                _ => unreachable!(),
903            }
904        }
905
906        // make sure to have a solution
907        if sol.is_none() {
908            self.extend_encoding(encoding, cost..cost + 1)?;
909            assumps.drain(base_assumps.len()..);
910            assumps.extend(encoding.enforce_ub(cost).unwrap());
911            let res = self.solve_assumps(&assumps)?;
912            debug_assert_eq!(res, SolverResult::Sat);
913            sol = Some(self.oracle.solution(self.var_manager.max_var().unwrap())?);
914        }
915        self.log_routine_end()?;
916        Done(Some((cost, sol.unwrap(), lb_id)))
917    }
918
    /// Extends the objective `encoding` so that it covers the upper-bound values in `range`.
    ///
    /// Within this file, the linear search calls this with `bound..bound + 1` right before
    /// calling `encoding.enforce_ub(bound)`.
    ///
    /// If proof logging is active (`self.proof_stuff` is `Some`), the encoding change is added
    /// through a [`CadicalCertCollector`] via `encode_ub_change_cert`, which also receives the
    /// proof. With the `verbose-proofs` feature, a comment naming the range is written to the
    /// proof first. Without proof logging, `encode_ub_change` adds the clauses to the oracle
    /// directly.
    ///
    /// # Errors
    ///
    /// Propagates errors from writing the proof comment and from the encoding functions.
    fn extend_encoding(
        &mut self,
        encoding: &mut ObjEncoding<GeneralizedTotalizer, Totalizer>,
        range: Range<usize>,
    ) -> anyhow::Result<()> {
        if let Some(proofs::ProofStuff { pt_handle, .. }) = &self.proof_stuff {
            // The proof is reached through the oracle's proof tracer. It is kept as a raw
            // pointer so that it can still be dereferenced mutably after `self.oracle` has been
            // mutably borrowed by `collector` below.
            // NOTE(review): this gives two live mutable access paths to oracle-owned data. The
            // code is only sound if `CadicalCertCollector` never touches the proof while
            // `encode_ub_change_cert` holds `&mut *proof`. Confirm this and document it as a
            // proper `// SAFETY:` justification.
            let proof: *mut _ = self.oracle.proof_tracer_mut(pt_handle).proof_mut();
            #[cfg(feature = "verbose-proofs")]
            {
                unsafe { &mut *proof }.comment(&format_args!(
                    "extending encoding to {}..{}",
                    range.start, range.end,
                ))?;
            }
            // The collector adds clauses to the oracle, presumably with matching proof steps
            // (hence the tracer handle).
            let mut collector = CadicalCertCollector::new(&mut self.oracle, pt_handle);
            encoding.encode_ub_change_cert(
                range,
                &mut collector,
                &mut self.var_manager,
                unsafe { &mut *proof },
            )?;
        } else {
            // No proof logging: add the encoding clauses straight to the oracle.
            encoding.encode_ub_change(range, &mut self.oracle, &mut self.var_manager)?;
        }
        Ok(())
    }
945}
946
947impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
948where
949    O: SolveIncremental,
950    ProofW: io::Write,
951{
952    /// Wrapper around the oracle with call logging and interrupt detection.
953    /// Assumes that the oracle is unlimited.
954    fn solve(&mut self) -> MaybeTerminatedError<SolverResult> {
955        self.log_routine_start("oracle call")?;
956        let res = self.oracle.solve()?;
957        self.log_routine_end()?;
958        self.check_termination()?;
959        self.log_oracle_call(res)?;
960        Done(res)
961    }
962
963    /// Wrapper around the oracle with call logging and interrupt detection.
964    /// Assumes that the oracle is unlimited.
965    fn solve_assumps(&mut self, assumps: &[Lit]) -> MaybeTerminatedError<SolverResult> {
966        self.log_routine_start("oracle call")?;
967        let res = self.oracle.solve_assumps(assumps)?;
968        self.log_routine_end()?;
969        self.check_termination()?;
970        self.log_oracle_call(res)?;
971        Done(res)
972    }
973}
974
975#[oracle_bounds]
976impl<O, ProofW, OInit, BCG> Kernel<O, ProofW, OInit, BCG>
977where
978    O: SolveIncremental,
979    ProofW: io::Write,
980    OInit: Initialize<O>,
981{
982    /// Resets the oracle and returns an error when the original [`Cnf`] was not stored.
983    fn reset_oracle(&mut self, include_var_manager: bool) -> anyhow::Result<()> {
984        anyhow::ensure!(
985            self.opts.store_cnf,
986            "cannot reset oracle without having stored the CNF"
987        );
988        self.log_routine_start("reset-oracle")?;
989        self.oracle = OInit::init();
990        if include_var_manager {
991            self.oracle.reserve(self.var_manager.max_enc_var())?;
992        } else {
993            self.oracle.reserve(self.var_manager.max_var().unwrap())?;
994        }
995        self.oracle.add_cnf(self.orig_cnf.clone().unwrap())?;
996        #[cfg(feature = "interrupt-oracle")]
997        {
998            *self.oracle_interrupter.lock().unwrap() = Box::new(self.oracle.interrupter());
999        }
1000        #[cfg(feature = "sol-tightening")]
1001        // Freeze objective variables so that they are not removed
1002        for o in &self.objs {
1003            for (l, _) in o.iter() {
1004                self.oracle.freeze_var(l.var())?;
1005            }
1006        }
1007        if include_var_manager {
1008            self.var_manager
1009                .forget_from(self.var_manager.max_enc_var() + 1);
1010        }
1011        self.log_routine_end()?;
1012        Ok(())
1013    }
1014}
1015
1016/// The default blocking clause generator
1017pub fn default_blocking_clause(sol: Assignment) -> Clause {
1018    Clause::from_iter(sol.into_iter().map(Lit::not))
1019}