rustsat_batsat/
lib.rs

1//! # rustsat-batsat - Interface to the BatSat SAT Solver for RustSAT
2//!
//! Interface to the [BatSat](https://github.com/c-cube/batsat) incremental SAT solver for use with the [RustSAT](https://github.com/chrjabs/rustsat) library.
4//!
//! BatSat is implemented entirely in Rust, which is advantageous in restricted compilation scenarios such as WebAssembly.
6//!
7//! # BatSat Version
8//!
9//! The version of BatSat in this crate is Version 0.6.0.
10//!
11//! ## Minimum Supported Rust Version (MSRV)
12//!
//! Currently, the MSRV is 1.76.0. The plan is to always support an MSRV that is at least a
//! year old.
15//!
16//! Bumps in the MSRV will _not_ be considered breaking changes. If you need a specific MSRV, make
17//! sure to pin a precise version of RustSAT.
18
19#![warn(clippy::pedantic)]
20#![warn(missing_docs)]
21#![warn(missing_debug_implementations)]
22
23use std::time::Duration;
24
25use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
26use cpu_time::ProcessTime;
27use rustsat::{
28    solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
29    types::{Cl, Clause, Lit, TernaryVal, Var},
30};
31
/// RustSAT wrapper for [`batsat::BasicSolver`]
///
/// This is [`Solver`] instantiated with BatSat's [`batsat::BasicCallbacks`].
pub type BasicSolver = Solver<batsat::BasicCallbacks>;
34
35/// RustSAT wrapper for a [`batsat::Solver`] Solver from BatSat
36#[derive(Default)]
37pub struct Solver<Cb: Callbacks> {
38    internal: batsat::Solver<Cb>,
39    n_sat: usize,
40    n_unsat: usize,
41    n_terminated: usize,
42    avg_clause_len: f32,
43    cpu_time: Duration,
44}
45
46impl<Cb: Callbacks> std::fmt::Debug for Solver<Cb> {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("Solver")
49            .field("internal", &"omitted")
50            .field("n_sat", &self.n_sat)
51            .field("n_unsat", &self.n_unsat)
52            .field("n_terminated", &self.n_terminated)
53            .field("avg_clause_len", &self.avg_clause_len)
54            .field("cpu_time", &self.cpu_time)
55            .finish()
56    }
57}
58
59impl<Cb: Callbacks> Solver<Cb> {
60    /// Gets a reference to the internal [`BasicSolver`]
61    #[must_use]
62    pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
63        &self.internal
64    }
65
66    /// Gets a mutable reference to the internal [`BasicSolver`]
67    #[must_use]
68    pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
69        &mut self.internal
70    }
71
72    #[allow(clippy::cast_precision_loss)]
73    #[inline]
74    fn update_avg_clause_len(&mut self, clause: &Cl) {
75        self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
76            + clause.len() as f32)
77            / (self.n_clauses() + 1) as f32;
78    }
79
80    fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
81        let a = assumps
82            .iter()
83            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
84            .collect::<Vec<_>>();
85
86        let start = ProcessTime::now();
87        let ret = match self.internal.solve_limited(&a) {
88            x if x == lbool::TRUE => {
89                self.n_sat += 1;
90                SolverResult::Sat
91            }
92            x if x == lbool::FALSE => {
93                self.n_unsat += 1;
94                SolverResult::Unsat
95            }
96            x if x == lbool::UNDEF => {
97                self.n_terminated += 1;
98                SolverResult::Interrupted
99            }
100            _ => unreachable!(),
101        };
102        self.cpu_time += start.elapsed();
103        ret
104    }
105}
106
107impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
108    fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
109        iter.into_iter()
110            .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
111    }
112}
113
114impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
115where
116    C: AsRef<Cl> + ?Sized,
117    Cb: Callbacks,
118{
119    fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
120        iter.into_iter().for_each(|cl| {
121            self.add_clause_ref(cl)
122                .expect("Error adding clause in extend");
123        });
124    }
125}
126
127impl<Cb: Callbacks> Solve for Solver<Cb> {
128    fn signature(&self) -> &'static str {
129        "BatSat 0.6.0"
130    }
131
132    fn solve(&mut self) -> anyhow::Result<SolverResult> {
133        Ok(self.solve_track_stats(&[]))
134    }
135
136    fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
137        let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());
138
139        match self.internal.value_lit(l) {
140            x if x == lbool::TRUE => Ok(TernaryVal::True),
141            x if x == lbool::FALSE => Ok(TernaryVal::False),
142            x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
143            _ => unreachable!(),
144        }
145    }
146
147    fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
148    where
149        C: AsRef<Cl> + ?Sized,
150    {
151        let clause = clause.as_ref();
152        self.update_avg_clause_len(clause);
153
154        let mut c: Vec<_> = clause
155            .iter()
156            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
157            .collect();
158
159        self.internal.add_clause_reuse(&mut c);
160
161        Ok(())
162    }
163}
164
165impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
166    fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
167        Ok(self.solve_track_stats(assumps))
168    }
169
170    fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
171        Ok(self
172            .internal
173            .unsat_core()
174            .iter()
175            .map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
176            .collect::<Vec<_>>())
177    }
178}
179
180impl<Cb: Callbacks> SolveStats for Solver<Cb> {
181    fn stats(&self) -> SolverStats {
182        SolverStats {
183            n_sat: self.n_sat,
184            n_unsat: self.n_unsat,
185            n_terminated: self.n_terminated,
186            n_clauses: self.n_clauses(),
187            max_var: self.max_var(),
188            avg_clause_len: self.avg_clause_len,
189            cpu_solve_time: self.cpu_time,
190        }
191    }
192
193    fn n_sat_solves(&self) -> usize {
194        self.n_sat
195    }
196
197    fn n_unsat_solves(&self) -> usize {
198        self.n_unsat
199    }
200
201    fn n_terminated(&self) -> usize {
202        self.n_terminated
203    }
204
205    fn n_clauses(&self) -> usize {
206        usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
207    }
208
209    fn max_var(&self) -> Option<Var> {
210        let num = self.internal.num_vars();
211        if num > 0 {
212            // BatSat returns a value that is off by one
213            Some(Var::new(num - 2))
214        } else {
215            None
216        }
217    }
218
219    fn avg_clause_len(&self) -> f32 {
220        self.avg_clause_len
221    }
222
223    fn cpu_solve_time(&self) -> Duration {
224        self.cpu_time
225    }
226}
227
#[cfg(test)]
mod test {
    // Presumably expands to the shared RustSAT basic solver unit tests for `BasicSolver`.
    // NOTE(review): the meaning of the trailing `false` argument is defined by
    // `rustsat_solvertests::basic_unittests!`; confirm it there.
    rustsat_solvertests::basic_unittests!(super::BasicSolver, false);
}