1#![warn(clippy::pedantic)]
12#![warn(missing_docs)]
13
14use std::time::Duration;
15
16use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
17use cpu_time::ProcessTime;
18use rustsat::{
19 solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
20 types::{Cl, Clause, Lit, TernaryVal, Var},
21};
22
23pub type BasicSolver = Solver<batsat::BasicCallbacks>;
25
26#[derive(Default)]
28pub struct Solver<Cb: Callbacks> {
29 internal: batsat::Solver<Cb>,
30 n_sat: usize,
31 n_unsat: usize,
32 n_terminated: usize,
33 avg_clause_len: f32,
34 cpu_time: Duration,
35}
36
37impl<Cb: Callbacks> Solver<Cb> {
38 #[must_use]
40 pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
41 &self.internal
42 }
43
44 #[must_use]
46 pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
47 &mut self.internal
48 }
49
50 #[allow(clippy::cast_precision_loss)]
51 #[inline]
52 fn update_avg_clause_len(&mut self, clause: &Cl) {
53 self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
54 + clause.len() as f32)
55 / (self.n_clauses() + 1) as f32;
56 }
57
58 fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
59 let a = assumps
60 .iter()
61 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
62 .collect::<Vec<_>>();
63
64 let start = ProcessTime::now();
65 let ret = match self.internal.solve_limited(&a) {
66 x if x == lbool::TRUE => {
67 self.n_sat += 1;
68 SolverResult::Sat
69 }
70 x if x == lbool::FALSE => {
71 self.n_unsat += 1;
72 SolverResult::Unsat
73 }
74 x if x == lbool::UNDEF => {
75 self.n_terminated += 1;
76 SolverResult::Interrupted
77 }
78 _ => unreachable!(),
79 };
80 self.cpu_time += start.elapsed();
81 ret
82 }
83}
84
85impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
86 fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
87 iter.into_iter()
88 .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
89 }
90}
91
92impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
93where
94 C: AsRef<Cl> + ?Sized,
95 Cb: Callbacks,
96{
97 fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
98 iter.into_iter().for_each(|cl| {
99 self.add_clause_ref(cl)
100 .expect("Error adding clause in extend");
101 });
102 }
103}
104
105impl<Cb: Callbacks> Solve for Solver<Cb> {
106 fn signature(&self) -> &'static str {
107 "BatSat 0.6.0"
108 }
109
110 fn solve(&mut self) -> anyhow::Result<SolverResult> {
111 Ok(self.solve_track_stats(&[]))
112 }
113
114 fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
115 let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());
116
117 match self.internal.value_lit(l) {
118 x if x == lbool::TRUE => Ok(TernaryVal::True),
119 x if x == lbool::FALSE => Ok(TernaryVal::False),
120 x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
121 _ => unreachable!(),
122 }
123 }
124
125 fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
126 where
127 C: AsRef<Cl> + ?Sized,
128 {
129 let clause = clause.as_ref();
130 self.update_avg_clause_len(clause);
131
132 let mut c: Vec<_> = clause
133 .iter()
134 .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
135 .collect();
136
137 self.internal.add_clause_reuse(&mut c);
138
139 Ok(())
140 }
141}
142
143impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
144 fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
145 Ok(self.solve_track_stats(assumps))
146 }
147
148 fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
149 Ok(self
150 .internal
151 .unsat_core()
152 .iter()
153 .map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
154 .collect::<Vec<_>>())
155 }
156}
157
158impl<Cb: Callbacks> SolveStats for Solver<Cb> {
159 fn stats(&self) -> SolverStats {
160 SolverStats {
161 n_sat: self.n_sat,
162 n_unsat: self.n_unsat,
163 n_terminated: self.n_terminated,
164 n_clauses: self.n_clauses(),
165 max_var: self.max_var(),
166 avg_clause_len: self.avg_clause_len,
167 cpu_solve_time: self.cpu_time,
168 }
169 }
170
171 fn n_sat_solves(&self) -> usize {
172 self.n_sat
173 }
174
175 fn n_unsat_solves(&self) -> usize {
176 self.n_unsat
177 }
178
179 fn n_terminated(&self) -> usize {
180 self.n_terminated
181 }
182
183 fn n_clauses(&self) -> usize {
184 usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
185 }
186
187 fn max_var(&self) -> Option<Var> {
188 let num = self.internal.num_vars();
189 if num > 0 {
190 Some(Var::new(num - 2))
192 } else {
193 None
194 }
195 }
196
197 fn avg_clause_len(&self) -> f32 {
198 self.avg_clause_len
199 }
200
201 fn cpu_solve_time(&self) -> Duration {
202 self.cpu_time
203 }
204}
205
#[cfg(test)]
mod test {
    // Runs the `basic_unittests!` suite from `rustsat_solvertests` against
    // `BasicSolver`; presumably this expands to the shared RustSAT solver tests.
    // NOTE(review): the meaning of the `false` argument is defined by the macro —
    // confirm in `rustsat_solvertests`.
    rustsat_solvertests::basic_unittests!(super::BasicSolver, false);
}