rustsat_batsat/
lib.rs

1//! # rustsat-batsat - Interface to the BatSat SAT Solver for RustSAT
2//!
3//! Interface to the [BatSat](https://github.com/c-cube/batsat) incremental SAT-Solver to be used with the [RustSAT](https://github.com/chrjabs/rustsat) library.
4//!
5//! BatSat is fully implemented in Rust which has advantages in restricted compilation scenarios like WebAssembly.
6//!
7//! # BatSat Version
8//!
//! This crate uses BatSat version 0.6.0.
10
11#![warn(clippy::pedantic)]
12#![warn(missing_docs)]
13#![warn(missing_debug_implementations)]
14
15use std::time::Duration;
16
17use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
18use cpu_time::ProcessTime;
19use rustsat::{
20    solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
21    types::{Cl, Clause, Lit, TernaryVal, Var},
22};
23
24/// RustSAT wrapper for [`batsat::BasicSolver`]
25pub type BasicSolver = Solver<batsat::BasicCallbacks>;
26
27/// RustSAT wrapper for a [`batsat::Solver`] Solver from BatSat
28#[derive(Default)]
29pub struct Solver<Cb: Callbacks> {
30    internal: batsat::Solver<Cb>,
31    n_sat: usize,
32    n_unsat: usize,
33    n_terminated: usize,
34    avg_clause_len: f32,
35    cpu_time: Duration,
36}
37
38impl<Cb: Callbacks> std::fmt::Debug for Solver<Cb> {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("Solver")
41            .field("internal", &"omitted")
42            .field("n_sat", &self.n_sat)
43            .field("n_unsat", &self.n_unsat)
44            .field("n_terminated", &self.n_terminated)
45            .field("avg_clause_len", &self.avg_clause_len)
46            .field("cpu_time", &self.cpu_time)
47            .finish()
48    }
49}
50
51impl<Cb: Callbacks> Solver<Cb> {
52    /// Gets a reference to the internal [`BasicSolver`]
53    #[must_use]
54    pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
55        &self.internal
56    }
57
58    /// Gets a mutable reference to the internal [`BasicSolver`]
59    #[must_use]
60    pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
61        &mut self.internal
62    }
63
64    #[allow(clippy::cast_precision_loss)]
65    #[inline]
66    fn update_avg_clause_len(&mut self, clause: &Cl) {
67        self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
68            + clause.len() as f32)
69            / (self.n_clauses() + 1) as f32;
70    }
71
72    fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
73        let a = assumps
74            .iter()
75            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
76            .collect::<Vec<_>>();
77
78        let start = ProcessTime::now();
79        let ret = match self.internal.solve_limited(&a) {
80            x if x == lbool::TRUE => {
81                self.n_sat += 1;
82                SolverResult::Sat
83            }
84            x if x == lbool::FALSE => {
85                self.n_unsat += 1;
86                SolverResult::Unsat
87            }
88            x if x == lbool::UNDEF => {
89                self.n_terminated += 1;
90                SolverResult::Interrupted
91            }
92            _ => unreachable!(),
93        };
94        self.cpu_time += start.elapsed();
95        ret
96    }
97}
98
99impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
100    fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
101        iter.into_iter()
102            .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
103    }
104}
105
106impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
107where
108    C: AsRef<Cl> + ?Sized,
109    Cb: Callbacks,
110{
111    fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
112        iter.into_iter().for_each(|cl| {
113            self.add_clause_ref(cl)
114                .expect("Error adding clause in extend");
115        });
116    }
117}
118
119impl<Cb: Callbacks> Solve for Solver<Cb> {
120    fn signature(&self) -> &'static str {
121        "BatSat 0.6.0"
122    }
123
124    fn solve(&mut self) -> anyhow::Result<SolverResult> {
125        Ok(self.solve_track_stats(&[]))
126    }
127
128    fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
129        let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());
130
131        match self.internal.value_lit(l) {
132            x if x == lbool::TRUE => Ok(TernaryVal::True),
133            x if x == lbool::FALSE => Ok(TernaryVal::False),
134            x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
135            _ => unreachable!(),
136        }
137    }
138
139    fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
140    where
141        C: AsRef<Cl> + ?Sized,
142    {
143        let clause = clause.as_ref();
144        self.update_avg_clause_len(clause);
145
146        let mut c: Vec<_> = clause
147            .iter()
148            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
149            .collect();
150
151        self.internal.add_clause_reuse(&mut c);
152
153        Ok(())
154    }
155}
156
157impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
158    fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
159        Ok(self.solve_track_stats(assumps))
160    }
161
162    fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
163        Ok(self
164            .internal
165            .unsat_core()
166            .iter()
167            .map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
168            .collect::<Vec<_>>())
169    }
170}
171
172impl<Cb: Callbacks> SolveStats for Solver<Cb> {
173    fn stats(&self) -> SolverStats {
174        SolverStats {
175            n_sat: self.n_sat,
176            n_unsat: self.n_unsat,
177            n_terminated: self.n_terminated,
178            n_clauses: self.n_clauses(),
179            max_var: self.max_var(),
180            avg_clause_len: self.avg_clause_len,
181            cpu_solve_time: self.cpu_time,
182        }
183    }
184
185    fn n_sat_solves(&self) -> usize {
186        self.n_sat
187    }
188
189    fn n_unsat_solves(&self) -> usize {
190        self.n_unsat
191    }
192
193    fn n_terminated(&self) -> usize {
194        self.n_terminated
195    }
196
197    fn n_clauses(&self) -> usize {
198        usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
199    }
200
201    fn max_var(&self) -> Option<Var> {
202        let num = self.internal.num_vars();
203        if num > 0 {
204            // BatSat returns a value that is off by one
205            Some(Var::new(num - 2))
206        } else {
207            None
208        }
209    }
210
211    fn avg_clause_len(&self) -> f32 {
212        self.avg_clause_len
213    }
214
215    fn cpu_solve_time(&self) -> Duration {
216        self.cpu_time
217    }
218}
219
#[cfg(test)]
mod test {
    // Instantiates RustSAT's shared basic solver unit tests for the BatSat wrapper.
    // NOTE(review): the meaning of the `false` flag is defined by the
    // `rustsat_solvertests::basic_unittests` macro — confirm there.
    rustsat_solvertests::basic_unittests!(crate::BasicSolver, false);
}