rustsat_batsat/
lib.rs

1//! # rustsat-batsat - Interface to the BatSat SAT Solver for RustSAT
2//!
3//! Interface to the [BatSat](https://github.com/c-cube/batsat) incremental SAT-Solver to be used with the [RustSAT](https://github.com/chrjabs/rustsat) library.
4//!
//! BatSat is fully implemented in Rust, which has advantages in restricted compilation scenarios like WebAssembly.
6//!
7//! # BatSat Version
8//!
9//! The version of BatSat in this crate is Version 0.6.0.
10
11#![warn(clippy::pedantic)]
12#![warn(missing_docs)]
13
14use std::time::Duration;
15
16use batsat::{intmap::AsIndex, lbool, Callbacks, SolverInterface};
17use cpu_time::ProcessTime;
18use rustsat::{
19    solvers::{Solve, SolveIncremental, SolveStats, SolverResult, SolverStats},
20    types::{Cl, Clause, Lit, TernaryVal, Var},
21};
22
23/// RustSAT wrapper for [`batsat::BasicSolver`]
24pub type BasicSolver = Solver<batsat::BasicCallbacks>;
25
26/// RustSAT wrapper for a [`batsat::Solver`] Solver from BatSat
27#[derive(Default)]
28pub struct Solver<Cb: Callbacks> {
29    internal: batsat::Solver<Cb>,
30    n_sat: usize,
31    n_unsat: usize,
32    n_terminated: usize,
33    avg_clause_len: f32,
34    cpu_time: Duration,
35}
36
37impl<Cb: Callbacks> Solver<Cb> {
38    /// Gets a reference to the internal [`BasicSolver`]
39    #[must_use]
40    pub fn batsat_ref(&self) -> &batsat::Solver<Cb> {
41        &self.internal
42    }
43
44    /// Gets a mutable reference to the internal [`BasicSolver`]
45    #[must_use]
46    pub fn batsat_mut(&mut self) -> &mut batsat::Solver<Cb> {
47        &mut self.internal
48    }
49
50    #[allow(clippy::cast_precision_loss)]
51    #[inline]
52    fn update_avg_clause_len(&mut self, clause: &Cl) {
53        self.avg_clause_len = (self.avg_clause_len * ((self.n_clauses()) as f32)
54            + clause.len() as f32)
55            / (self.n_clauses() + 1) as f32;
56    }
57
58    fn solve_track_stats(&mut self, assumps: &[Lit]) -> SolverResult {
59        let a = assumps
60            .iter()
61            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
62            .collect::<Vec<_>>();
63
64        let start = ProcessTime::now();
65        let ret = match self.internal.solve_limited(&a) {
66            x if x == lbool::TRUE => {
67                self.n_sat += 1;
68                SolverResult::Sat
69            }
70            x if x == lbool::FALSE => {
71                self.n_unsat += 1;
72                SolverResult::Unsat
73            }
74            x if x == lbool::UNDEF => {
75                self.n_terminated += 1;
76                SolverResult::Interrupted
77            }
78            _ => unreachable!(),
79        };
80        self.cpu_time += start.elapsed();
81        ret
82    }
83}
84
85impl<Cb: Callbacks> Extend<Clause> for Solver<Cb> {
86    fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
87        iter.into_iter()
88            .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
89    }
90}
91
92impl<'a, C, Cb> Extend<&'a C> for Solver<Cb>
93where
94    C: AsRef<Cl> + ?Sized,
95    Cb: Callbacks,
96{
97    fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
98        iter.into_iter().for_each(|cl| {
99            self.add_clause_ref(cl)
100                .expect("Error adding clause in extend");
101        });
102    }
103}
104
105impl<Cb: Callbacks> Solve for Solver<Cb> {
106    fn signature(&self) -> &'static str {
107        "BatSat 0.6.0"
108    }
109
110    fn solve(&mut self) -> anyhow::Result<SolverResult> {
111        Ok(self.solve_track_stats(&[]))
112    }
113
114    fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
115        let l = batsat::Lit::new(batsat::Var::from_index(lit.vidx() + 1), lit.is_pos());
116
117        match self.internal.value_lit(l) {
118            x if x == lbool::TRUE => Ok(TernaryVal::True),
119            x if x == lbool::FALSE => Ok(TernaryVal::False),
120            x if x == lbool::UNDEF => Ok(TernaryVal::DontCare),
121            _ => unreachable!(),
122        }
123    }
124
125    fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
126    where
127        C: AsRef<Cl> + ?Sized,
128    {
129        let clause = clause.as_ref();
130        self.update_avg_clause_len(clause);
131
132        let mut c: Vec<_> = clause
133            .iter()
134            .map(|l| batsat::Lit::new(self.internal.var_of_int(l.vidx32() + 1), l.is_pos()))
135            .collect();
136
137        self.internal.add_clause_reuse(&mut c);
138
139        Ok(())
140    }
141}
142
143impl<Cb: Callbacks> SolveIncremental for Solver<Cb> {
144    fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
145        Ok(self.solve_track_stats(assumps))
146    }
147
148    fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
149        Ok(self
150            .internal
151            .unsat_core()
152            .iter()
153            .map(|l| Lit::new(l.var().idx() - 1, !l.sign()))
154            .collect::<Vec<_>>())
155    }
156}
157
158impl<Cb: Callbacks> SolveStats for Solver<Cb> {
159    fn stats(&self) -> SolverStats {
160        SolverStats {
161            n_sat: self.n_sat,
162            n_unsat: self.n_unsat,
163            n_terminated: self.n_terminated,
164            n_clauses: self.n_clauses(),
165            max_var: self.max_var(),
166            avg_clause_len: self.avg_clause_len,
167            cpu_solve_time: self.cpu_time,
168        }
169    }
170
171    fn n_sat_solves(&self) -> usize {
172        self.n_sat
173    }
174
175    fn n_unsat_solves(&self) -> usize {
176        self.n_unsat
177    }
178
179    fn n_terminated(&self) -> usize {
180        self.n_terminated
181    }
182
183    fn n_clauses(&self) -> usize {
184        usize::try_from(self.internal.num_clauses()).expect("more than `usize::MAX` clauses")
185    }
186
187    fn max_var(&self) -> Option<Var> {
188        let num = self.internal.num_vars();
189        if num > 0 {
190            // BatSat returns a value that is off by one
191            Some(Var::new(num - 2))
192        } else {
193            None
194        }
195    }
196
197    fn avg_clause_len(&self) -> f32 {
198        self.avg_clause_len
199    }
200
201    fn cpu_solve_time(&self) -> Duration {
202        self.cpu_time
203    }
204}
205
#[cfg(test)]
mod test {
    // Instantiates the `basic_unittests` macro from `rustsat_solvertests` for
    // the `BasicSolver` wrapper.
    // NOTE(review): the meaning of the `false` argument is defined by the
    // macro and not visible here; confirm against `rustsat_solvertests`.
    rustsat_solvertests::basic_unittests!(super::BasicSolver, false);
}