Skip to main content

rustsat_ipasir/
lib.rs

1//! # rustsat-ipasir - IPASIR Bindings for RustSAT
2//!
3//! [IPASIR](https://github.com/biotomas/ipasir) is a general incremental interface for SAT
4//! solvers. This crate provides bindings for this API to be used with the
5//! [RustSAT](https://github.com/chrjabs/rustsat) library.
6//!
7//! **Note**: This crate only provides bindings to the API; linking to an IPASIR library needs to be
8//! set up by the user. This is intentional to allow easy integration of solvers that we do not
9//! provide a specialized crate for. For a plug-and-play experience see the other RustSAT solver
10//! crates.
11//!
12//! ## Linking
13//!
14//! Linking to an IPASIR library can be done by adding something like the following to your project's
15//! build script (`build.rs`).
16//!
17//! ```
18//! // Link to custom IPASIR solver
19//! // Modify this for linking to your static library
20//! // The name of the library should be _without_ the prefix 'lib' and the suffix '.a'
21//! println!("cargo:rustc-link-lib=static=<name-of-your-static-lib>");
22//! println!("cargo:rustc-link-search=<path-to-your-static-lib>");
23//! // If your IPASIR solver links to the C++ stdlib, the next four lines are required
24//! #[cfg(target_os = "macos")]
25//! println!("cargo:rustc-flags=-l dylib=c++");
26//! #[cfg(not(target_os = "macos"))]
27//! println!("cargo:rustc-flags=-l dylib=stdc++");
28//! ```
29//!
30//! ## Minimum Supported Rust Version (MSRV)
31//!
32//! Currently, the MSRV is 1.76.0, the plan is to always support an MSRV that is at least a year
33//! old.
34//!
35//! Bumps in the MSRV will _not_ be considered breaking changes. If you need a specific MSRV, make
36//! sure to pin a precise version of RustSAT.
37
38#![warn(clippy::pedantic)]
39#![warn(missing_docs)]
40#![warn(missing_debug_implementations)]
41
42use core::ffi::{c_int, c_void, CStr};
43
44use ffi::IpasirHandle;
45use rustsat::{
46    solvers::{
47        ControlSignal, Learn, Solve, SolveIncremental, SolveStats, SolverResult, SolverState,
48        SolverStats, StateError, Terminate,
49    },
50    types::{Cl, Clause, Lit, TernaryVal},
51    utils::Timer,
52};
53use thiserror::Error;
54
55/// Fatal error returned if the IPASIR API returns an invalid value
56#[derive(Error, Clone, Copy, PartialEq, Eq, Debug)]
57#[error("ipasir c-api returned an invalid value: {api_call} -> {value}")]
58pub struct InvalidApiReturn {
59    api_call: &'static str,
60    value: c_int,
61}
62
63#[derive(Debug, PartialEq, Eq, Default)]
64#[allow(dead_code)] // Not all solvers use all states
65enum InternalSolverState {
66    #[default]
67    Configuring,
68    Input,
69    Sat,
70    Unsat(Vec<Lit>),
71}
72
73impl InternalSolverState {
74    fn to_external(&self) -> SolverState {
75        match self {
76            InternalSolverState::Configuring => SolverState::Configuring,
77            InternalSolverState::Input => SolverState::Input,
78            InternalSolverState::Sat => SolverState::Sat,
79            InternalSolverState::Unsat(_) => SolverState::Unsat,
80        }
81    }
82}
83
84type TermCallbackPtr<'a> = Box<dyn FnMut() -> ControlSignal + 'a>;
85type LearnCallbackPtr<'a> = Box<dyn FnMut(Clause) + 'a>;
86/// Double boxing is necessary to get thin pointers for casting
87type OptTermCallbackStore<'a> = Option<Box<TermCallbackPtr<'a>>>;
88/// Double boxing is necessary to get thin pointers for casting
89type OptLearnCallbackStore<'a> = Option<Box<LearnCallbackPtr<'a>>>;
90
91/// Type for an IPASIR solver.
92pub struct IpasirSolver<'term, 'learn> {
93    handle: *mut IpasirHandle,
94    state: InternalSolverState,
95    terminate_cb: OptTermCallbackStore<'term>,
96    learner_cb: OptLearnCallbackStore<'learn>,
97    stats: SolverStats,
98}
99
100impl std::fmt::Debug for IpasirSolver<'_, '_> {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("IpasirSolver")
103            .field("handle", &self.handle)
104            .field("state", &self.state)
105            .field(
106                "terminate_cb",
107                if self.terminate_cb.is_some() {
108                    &"some callback"
109                } else {
110                    &"no callback"
111                },
112            )
113            .field(
114                "learner_cb",
115                if self.learner_cb.is_some() {
116                    &"some callback"
117                } else {
118                    &"no callback"
119                },
120            )
121            .field("stats", &self.stats)
122            .finish()
123    }
124}
125
126unsafe impl Send for IpasirSolver<'_, '_> {}
127
128impl Default for IpasirSolver<'_, '_> {
129    fn default() -> Self {
130        Self {
131            handle: unsafe { ffi::ipasir_init() },
132            state: InternalSolverState::default(),
133            terminate_cb: None,
134            learner_cb: None,
135            stats: SolverStats::default(),
136        }
137    }
138}
139
140impl IpasirSolver<'_, '_> {
141    fn get_core_assumps(&self, assumps: &[Lit]) -> Result<Vec<Lit>, InvalidApiReturn> {
142        let mut core = Vec::with_capacity(assumps.len());
143        core.reserve(assumps.len());
144        for a in assumps {
145            match unsafe { ffi::ipasir_failed(self.handle, a.to_ipasir()) } {
146                0 => (),
147                1 => core.push(!*a),
148                value => {
149                    return Err(InvalidApiReturn {
150                        api_call: "ipasir_failed",
151                        value,
152                    })
153                }
154            }
155        }
156        Ok(core)
157    }
158
159    #[allow(clippy::cast_precision_loss)]
160    #[inline]
161    fn update_avg_clause_len(&mut self, clause: &Cl) {
162        self.stats.avg_clause_len =
163            (self.stats.avg_clause_len * ((self.stats.n_clauses - 1) as f32) + clause.len() as f32)
164                / self.stats.n_clauses as f32;
165    }
166}
167
168impl Solve for IpasirSolver<'_, '_> {
169    fn signature(&self) -> &'static str {
170        let c_chars = unsafe { ffi::ipasir_signature() };
171        let c_str = unsafe { CStr::from_ptr(c_chars) };
172        c_str
173            .to_str()
174            .expect("IPASIR signature returned invalid UTF-8.")
175    }
176
177    fn solve(&mut self) -> anyhow::Result<SolverResult> {
178        // If already solved, return state
179        if let InternalSolverState::Sat = self.state {
180            return Ok(SolverResult::Sat);
181        }
182        if let InternalSolverState::Unsat(core) = &self.state {
183            if core.is_empty() {
184                return Ok(SolverResult::Unsat);
185            }
186        }
187        let start = Timer::now();
188        // Solve with IPASIR backend
189        let res = unsafe { ffi::ipasir_solve(self.handle) };
190        self.stats.cpu_solve_time += start.elapsed();
191        match res {
192            0 => {
193                self.stats.n_terminated += 1;
194                self.state = InternalSolverState::Input;
195                Ok(SolverResult::Interrupted)
196            }
197            10 => {
198                self.stats.n_sat += 1;
199                self.state = InternalSolverState::Sat;
200                Ok(SolverResult::Sat)
201            }
202            20 => {
203                self.stats.n_unsat += 1;
204                self.state = InternalSolverState::Unsat(vec![]);
205                Ok(SolverResult::Unsat)
206            }
207            value => Err(InvalidApiReturn {
208                api_call: "ipasir_solve",
209                value,
210            }
211            .into()),
212        }
213    }
214
215    fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
216        if self.state != InternalSolverState::Sat {
217            return Err(StateError {
218                required_state: SolverState::Sat,
219                actual_state: self.state.to_external(),
220            }
221            .into());
222        }
223        let lit = lit.to_ipasir();
224        match unsafe { ffi::ipasir_val(self.handle, lit) } {
225            0 => Ok(TernaryVal::DontCare),
226            p if p == lit => Ok(TernaryVal::True),
227            n if n == -lit => Ok(TernaryVal::False),
228            value => Err(InvalidApiReturn {
229                api_call: "ipasir_val",
230                value,
231            }
232            .into()),
233        }
234    }
235
236    fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
237    where
238        C: AsRef<Cl> + ?Sized,
239    {
240        let clause = clause.as_ref();
241        // Update wrapper-internal state
242        self.stats.n_clauses += 1;
243        clause.iter().for_each(|l| match self.stats.max_var {
244            None => self.stats.max_var = Some(l.var()),
245            Some(var) => {
246                if l.var() > var {
247                    self.stats.max_var = Some(l.var());
248                }
249            }
250        });
251        self.update_avg_clause_len(clause);
252        self.state = InternalSolverState::Input;
253        // Call IPASIR backend
254        for lit in clause {
255            unsafe { ffi::ipasir_add(self.handle, lit.to_ipasir()) }
256        }
257        unsafe { ffi::ipasir_add(self.handle, 0) };
258        Ok(())
259    }
260}
261
262impl SolveIncremental for IpasirSolver<'_, '_> {
263    fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
264        // If in error state, remain there
265        // If not, need to resolve because assumptions might have changed
266        let start = Timer::now();
267        // Solve with IPASIR backend
268        for a in assumps {
269            unsafe { ffi::ipasir_assume(self.handle, a.to_ipasir()) }
270        }
271        let res = unsafe { ffi::ipasir_solve(self.handle) };
272        self.stats.cpu_solve_time += start.elapsed();
273        match res {
274            0 => {
275                self.stats.n_terminated += 1;
276                self.state = InternalSolverState::Input;
277                Ok(SolverResult::Interrupted)
278            }
279            10 => {
280                self.stats.n_sat += 1;
281                self.state = InternalSolverState::Sat;
282                Ok(SolverResult::Sat)
283            }
284            20 => {
285                self.stats.n_unsat += 1;
286                self.state = InternalSolverState::Unsat(self.get_core_assumps(assumps)?);
287                Ok(SolverResult::Unsat)
288            }
289            value => Err(InvalidApiReturn {
290                api_call: "ipasir_solve",
291                value,
292            }
293            .into()),
294        }
295    }
296
297    fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
298        match &self.state {
299            InternalSolverState::Unsat(core) => Ok(core.clone()),
300            other => Err(StateError {
301                required_state: SolverState::Unsat,
302                actual_state: other.to_external(),
303            }
304            .into()),
305        }
306    }
307}
308
309impl<'term> Terminate<'term> for IpasirSolver<'term, '_> {
310    /// Sets a terminator callback that is regularly called during solving.
311    ///
312    /// # Examples
313    ///
314    /// Terminate solver after 10 callback calls.
315    ///
316    /// ```no_run
317    /// use rustsat::solvers::{ControlSignal, Solve, SolverResult, Terminate};
318    /// use rustsat_ipasir::IpasirSolver;
319    ///
320    /// let mut solver = IpasirSolver::default();
321    ///
322    /// // Load instance
323    ///
324    /// let mut cnt = 1;
325    /// solver.attach_terminator(move || {
326    ///     if cnt > 10 {
327    ///         ControlSignal::Terminate
328    ///     } else {
329    ///         cnt += 1;
330    ///         ControlSignal::Continue
331    ///     }
332    /// });
333    ///
334    /// let ret = solver.solve().unwrap();
335    ///
336    /// // Assuming an instance is actually loaded and runs long enough
337    /// // assert_eq!(ret, SolverResult::Interrupted);
338    /// ```
339    fn attach_terminator<CB>(&mut self, cb: CB)
340    where
341        CB: FnMut() -> ControlSignal + 'term,
342    {
343        self.terminate_cb = Some(Box::new(Box::new(cb)));
344        let cb_ptr =
345            std::ptr::from_ref(self.terminate_cb.as_mut().unwrap().as_mut()).cast::<c_void>();
346        unsafe { ffi::ipasir_set_terminate(self.handle, cb_ptr, Some(ffi::ipasir_terminate_cb)) }
347    }
348
349    fn detach_terminator(&mut self) {
350        self.terminate_cb = None;
351        unsafe { ffi::ipasir_set_terminate(self.handle, std::ptr::null(), None) }
352    }
353}
354
355impl<'learn> Learn<'learn> for IpasirSolver<'_, 'learn> {
356    /// Sets a learner callback that gets passed clauses up to a certain length learned by the solver.
357    ///
358    /// The callback goes out of scope with the solver, afterwards captured variables become accessible.
359    ///
360    /// # Examples
361    ///
362    /// Count number of learned clauses up to length 10.
363    ///
364    /// ```no_run
365    /// use rustsat::solvers::{Solve, SolverResult, Learn};
366    /// use rustsat_ipasir::IpasirSolver;
367    ///
368    /// let mut cnt = 0;
369    ///
370    /// {
371    ///     let mut solver = IpasirSolver::default();
372    ///     // Load instance
373    ///
374    ///     solver.attach_learner(|_| cnt += 1, 10);
375    ///
376    ///     solver.solve().unwrap();
377    /// }
378    ///
379    /// // cnt variable can be accessed from here on
380    /// ```
381    fn attach_learner<CB>(&mut self, cb: CB, max_len: usize)
382    where
383        CB: FnMut(Clause) + 'learn,
384    {
385        self.learner_cb = Some(Box::new(Box::new(cb)));
386        let cb_ptr =
387            std::ptr::from_ref(self.learner_cb.as_mut().unwrap().as_mut()).cast::<c_void>();
388        unsafe {
389            ffi::ipasir_set_learn(
390                self.handle,
391                cb_ptr,
392                max_len.try_into().unwrap(),
393                Some(ffi::ipasir_learn_cb),
394            );
395        }
396    }
397
398    fn detach_learner(&mut self) {
399        self.terminate_cb = None;
400        unsafe { ffi::ipasir_set_learn(self.handle, std::ptr::null(), 0, None) }
401    }
402}
403
404impl SolveStats for IpasirSolver<'_, '_> {
405    fn stats(&self) -> SolverStats {
406        self.stats.clone()
407    }
408}
409
410impl Drop for IpasirSolver<'_, '_> {
411    fn drop(&mut self) {
412        unsafe { ffi::ipasir_release(self.handle) }
413    }
414}
415
416impl Extend<Clause> for IpasirSolver<'_, '_> {
417    fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
418        iter.into_iter()
419            .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
420    }
421}
422
423impl<'a, C> Extend<&'a C> for IpasirSolver<'_, '_>
424where
425    C: AsRef<Cl> + ?Sized,
426{
427    fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
428        iter.into_iter().for_each(|cl| {
429            self.add_clause_ref(cl)
430                .expect("Error adding clause in extend");
431        });
432    }
433}
434
435mod ffi {
436    use super::{LearnCallbackPtr, TermCallbackPtr};
437    use core::ffi::{c_char, c_int, c_void};
438    use rustsat::{solvers::ControlSignal, types::Lit, utils::from_raw_parts_maybe_null};
439
440    #[repr(C)]
441    pub struct IpasirHandle {
442        _private: [u8; 0],
443    }
444
445    extern "C" {
446        // Redefinitions of IPASIR functions
447        pub fn ipasir_signature() -> *const c_char;
448        pub fn ipasir_init() -> *mut IpasirHandle;
449        pub fn ipasir_release(solver: *mut IpasirHandle);
450        pub fn ipasir_add(solver: *mut IpasirHandle, lit_or_zero: c_int);
451        pub fn ipasir_assume(solver: *mut IpasirHandle, lit: c_int);
452        pub fn ipasir_solve(solver: *mut IpasirHandle) -> c_int;
453        pub fn ipasir_val(solver: *mut IpasirHandle, lit: c_int) -> c_int;
454        pub fn ipasir_failed(solver: *mut IpasirHandle, lit: c_int) -> c_int;
455        pub fn ipasir_set_terminate(
456            solver: *mut IpasirHandle,
457            state: *const c_void,
458            terminate: Option<unsafe extern "C" fn(state: *const c_void) -> c_int>,
459        );
460        pub fn ipasir_set_learn(
461            solver: *mut IpasirHandle,
462            state: *const c_void,
463            max_length: c_int,
464            learn: Option<unsafe extern "C" fn(state: *const c_void, clause: *const c_int)>,
465        );
466    }
467
468    // Raw callbacks forwarding to user callbacks
469    pub unsafe extern "C" fn ipasir_terminate_cb(ptr: *const c_void) -> c_int {
470        let cb = &mut *(ptr as *mut TermCallbackPtr<'_>);
471        match cb() {
472            ControlSignal::Continue => 0,
473            ControlSignal::Terminate => 1,
474        }
475    }
476
477    pub unsafe extern "C" fn ipasir_learn_cb(ptr: *const c_void, clause: *const c_int) {
478        let cb = unsafe { &mut *(ptr as *mut LearnCallbackPtr<'_>) };
479
480        let mut cnt: usize = 0;
481        while *clause
482            .offset(isize::try_from(cnt).expect("learned clauses is longer than `isize::MAX`"))
483            != 0
484        {
485            cnt += 1;
486        }
487        let int_slice = from_raw_parts_maybe_null(clause, cnt);
488        let clause = int_slice
489            .iter()
490            .map(|il| {
491                Lit::from_ipasir(*il).expect("Invalid literal in learned clause from IPASIR solver")
492            })
493            .collect();
494        cb(clause);
495    }
496}