Skip to main content

rustsat/solvers/
simulators.rs

1//! # Solver Simulators
2//!
3//! This module contains generic code to simulate features that solvers might not support.
4
5use crate::{
6    instances::Cnf,
7    types::{Cl, Clause, Lit},
8    utils::Timer,
9};
10
11use super::Solve;
12
13#[derive(Debug, PartialEq, Eq, Default)]
14enum InternalSolverState {
15    #[default]
16    Init,
17    Sat,
18    Unsat(Vec<Lit>),
19    Unknown,
20}
21
22impl InternalSolverState {
23    fn to_external(&self) -> super::SolverState {
24        match self {
25            InternalSolverState::Init | InternalSolverState::Unknown => super::SolverState::Input,
26            InternalSolverState::Sat => super::SolverState::Sat,
27            InternalSolverState::Unsat(_) => super::SolverState::Unsat,
28        }
29    }
30}
31
32/// Simulates an incremental solver based on a non-incremental one
33///
34/// This simply stores the added clauses internally, and re-initializes a new solver whenever solve
35/// has been called. Assumptions are also added as unit clauses and the returned core will always
36/// be the entire set of assumptions.
37///
38/// # Generics
39///
40/// - `S`: the wrapped [`Solve`] type
41/// - `Init`: an initializer for `S`, implementing [`super::Initialize`]
42#[derive(Debug)]
43pub struct Incremental<S, Init = super::DefaultInitializer> {
44    solver: S,
45    init: std::marker::PhantomData<Init>,
46    state: InternalSolverState,
47    clauses: Cnf,
48    stats: super::SolverStats,
49}
50
51impl<S, Init> Default for Incremental<S, Init>
52where
53    Init: super::Initialize<S>,
54{
55    fn default() -> Self {
56        Self {
57            solver: Init::init(),
58            init: std::marker::PhantomData,
59            state: InternalSolverState::default(),
60            clauses: Cnf::default(),
61            stats: super::SolverStats::default(),
62        }
63    }
64}
65
66impl<S, Init> Incremental<S, Init> {
67    #[allow(clippy::cast_precision_loss)]
68    #[inline]
69    fn update_avg_clause_len(&mut self, clause: &Cl) {
70        self.stats.avg_clause_len =
71            (self.stats.avg_clause_len * ((self.stats.n_clauses - 1) as f32) + clause.len() as f32)
72                / self.stats.n_clauses as f32;
73    }
74
75    #[inline]
76    fn update_max_var(&mut self, clause: &Cl) {
77        if self.stats.max_var.is_none() {
78            self.stats.max_var = Some(crate::types::Var::new(0));
79        }
80        let max_var = self.stats.max_var.as_mut().unwrap();
81        for lit in clause {
82            *max_var = std::cmp::max(*max_var, lit.var());
83        }
84    }
85}
86
87impl<S, Init> super::Solve for Incremental<S, Init>
88where
89    S: super::Solve,
90    Init: super::Initialize<S>,
91{
92    fn signature(&self) -> &'static str {
93        self.solver.signature()
94    }
95
96    fn solve(&mut self) -> anyhow::Result<super::SolverResult> {
97        match &self.state {
98            InternalSolverState::Sat => return Ok(super::SolverResult::Sat),
99            InternalSolverState::Unsat(lits) if lits.is_empty() => {
100                return Ok(super::SolverResult::Unsat)
101            }
102            InternalSolverState::Unknown | InternalSolverState::Unsat(_) => {
103                self.solver = Init::init();
104                self.solver.add_cnf_ref(&self.clauses)?;
105            }
106            InternalSolverState::Init => (),
107        }
108        let start = Timer::now();
109        let res = self.solver.solve()?;
110        self.stats.cpu_solve_time += start.elapsed();
111        match res {
112            super::SolverResult::Sat => {
113                self.stats.n_sat += 1;
114                self.state = InternalSolverState::Sat;
115            }
116            super::SolverResult::Unsat => {
117                self.stats.n_unsat += 1;
118                self.state = InternalSolverState::Unsat(vec![]);
119            }
120            super::SolverResult::Interrupted => {
121                self.stats.n_terminated += 1;
122                self.state = InternalSolverState::Unknown;
123            }
124        }
125        Ok(res)
126    }
127
128    fn lit_val(&self, lit: Lit) -> anyhow::Result<crate::types::TernaryVal> {
129        self.solver.lit_val(lit)
130    }
131
132    fn var_val(&self, var: crate::types::Var) -> anyhow::Result<crate::types::TernaryVal> {
133        self.solver.var_val(var)
134    }
135
136    fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
137    where
138        C: AsRef<crate::types::Cl> + ?Sized,
139    {
140        self.stats.n_clauses += 1;
141        self.update_avg_clause_len(clause.as_ref());
142        self.update_max_var(clause.as_ref());
143        if matches!(self.state, InternalSolverState::Init) {
144            self.solver.add_clause_ref(clause)?;
145        } else {
146            self.state = InternalSolverState::Init;
147            self.solver = Init::init();
148            self.solver.add_cnf_ref(&self.clauses)?;
149            self.solver.add_clause_ref(&clause)?;
150        }
151        self.clauses
152            .add_clause(clause.as_ref().iter().copied().collect());
153        Ok(())
154    }
155
156    fn add_clause(&mut self, clause: Clause) -> anyhow::Result<()> {
157        self.stats.n_clauses += 1;
158        self.update_avg_clause_len(&clause);
159        self.update_max_var(&clause);
160        if matches!(self.state, InternalSolverState::Init) {
161            self.solver.add_clause_ref(&clause)?;
162        } else {
163            self.state = InternalSolverState::Init;
164            self.solver = Init::init();
165            self.solver.add_cnf_ref(&self.clauses)?;
166            self.solver.add_clause_ref(&clause)?;
167        }
168        self.clauses.add_clause(clause);
169        Ok(())
170    }
171
172    fn solution(&self, high_var: crate::types::Var) -> anyhow::Result<crate::types::Assignment> {
173        self.solver.solution(high_var)
174    }
175}
176
177impl<S, Init> super::SolveStats for Incremental<S, Init> {
178    fn stats(&self) -> super::SolverStats {
179        self.stats.clone()
180    }
181}
182
183impl<S, Init> super::SolveIncremental for Incremental<S, Init>
184where
185    S: super::Solve,
186    Init: super::Initialize<S>,
187{
188    fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<super::SolverResult> {
189        let start = Timer::now();
190        if !matches!(self.state, InternalSolverState::Init) {
191            self.solver = Init::init();
192            self.solver.add_cnf_ref(&self.clauses)?;
193        }
194        for lit in assumps {
195            self.solver.add_unit(*lit)?;
196        }
197        let res = self.solver.solve()?;
198        self.stats.cpu_solve_time += start.elapsed();
199        match res {
200            super::SolverResult::Sat => {
201                self.stats.n_sat += 1;
202                self.state = InternalSolverState::Sat;
203            }
204            super::SolverResult::Unsat => {
205                self.stats.n_unsat += 1;
206                self.state = InternalSolverState::Unsat(assumps.iter().map(|l| !*l).collect());
207            }
208            super::SolverResult::Interrupted => {
209                self.stats.n_terminated += 1;
210                self.state = InternalSolverState::Unknown;
211            }
212        }
213        Ok(res)
214    }
215
216    fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
217        match &self.state {
218            InternalSolverState::Unsat(core) => Ok(core.clone()),
219            other => Err(super::StateError {
220                required_state: super::SolverState::Unsat,
221                actual_state: other.to_external(),
222            }
223            .into()),
224        }
225    }
226}
227
228impl<S, Init> Extend<Clause> for Incremental<S, Init>
229where
230    S: super::Solve,
231    Init: super::Initialize<S>,
232{
233    fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
234        iter.into_iter()
235            .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
236    }
237}
238
239impl<'a, S, Init, C> Extend<&'a C> for Incremental<S, Init>
240where
241    S: super::Solve,
242    Init: super::Initialize<S>,
243    C: AsRef<Cl> + ?Sized,
244{
245    fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
246        iter.into_iter().for_each(|cl| {
247            self.add_clause_ref(cl)
248                .expect("Error adding clause in extend");
249        });
250    }
251}