1#![warn(clippy::pedantic)]
39#![warn(missing_docs)]
40#![warn(missing_debug_implementations)]
41
42use core::ffi::{c_int, c_void, CStr};
43
44use ffi::IpasirHandle;
45use rustsat::{
46 solvers::{
47 ControlSignal, Learn, Solve, SolveIncremental, SolveStats, SolverResult, SolverState,
48 SolverStats, StateError, Terminate,
49 },
50 types::{Cl, Clause, Lit, TernaryVal},
51 utils::Timer,
52};
53use thiserror::Error;
54
/// Error returned when an IPASIR C-API call returns a value outside the set of
/// values this wrapper expects from that call.
#[derive(Error, Clone, Copy, PartialEq, Eq, Debug)]
#[error("ipasir c-api returned an invalid value: {api_call} -> {value}")]
pub struct InvalidApiReturn {
    /// Name of the C function that returned the unexpected value.
    api_call: &'static str,
    /// The unexpected return value.
    value: c_int,
}
62
63#[derive(Debug, PartialEq, Eq, Default)]
64#[allow(dead_code)] enum InternalSolverState {
66 #[default]
67 Configuring,
68 Input,
69 Sat,
70 Unsat(Vec<Lit>),
71}
72
73impl InternalSolverState {
74 fn to_external(&self) -> SolverState {
75 match self {
76 InternalSolverState::Configuring => SolverState::Configuring,
77 InternalSolverState::Input => SolverState::Input,
78 InternalSolverState::Sat => SolverState::Sat,
79 InternalSolverState::Unsat(_) => SolverState::Unsat,
80 }
81 }
82}
83
/// Boxed termination callback; the FFI trampoline maps its [`ControlSignal`] to the
/// 0/1 answer the solver expects.
type TermCallbackPtr<'a> = Box<dyn FnMut() -> ControlSignal + 'a>;
/// Boxed learned-clause callback; receives each clause the solver reports.
type LearnCallbackPtr<'a> = Box<dyn FnMut(Clause) + 'a>;
/// Storage for the termination callback.
///
/// The extra outer `Box` gives the fat `Box<dyn FnMut>` a stable heap address, so a
/// thin pointer to it can be passed to C as `void *` and cast back in the trampoline.
type OptTermCallbackStore<'a> = Option<Box<TermCallbackPtr<'a>>>;
/// Storage for the learned-clause callback; double-boxed for the same reason as
/// [`OptTermCallbackStore`].
type OptLearnCallbackStore<'a> = Option<Box<LearnCallbackPtr<'a>>>;
90
/// Incremental SAT solver backed by a C library linked through the IPASIR interface.
///
/// `'term` and `'learn` bound the lifetimes of the attached termination and
/// learned-clause callbacks, respectively.
pub struct IpasirSolver<'term, 'learn> {
    /// Opaque solver handle obtained from `ipasir_init`.
    handle: *mut IpasirHandle,
    /// Rust-side view of the solver state (the declared IPASIR functions offer no state query).
    state: InternalSolverState,
    /// Attached termination callback; owned here because C only holds a raw pointer to it.
    terminate_cb: OptTermCallbackStore<'term>,
    /// Attached learned-clause callback; owned here because C only holds a raw pointer to it.
    learner_cb: OptLearnCallbackStore<'learn>,
    /// Statistics collected by this wrapper.
    stats: SolverStats,
}
99
100impl std::fmt::Debug for IpasirSolver<'_, '_> {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("IpasirSolver")
103 .field("handle", &self.handle)
104 .field("state", &self.state)
105 .field(
106 "terminate_cb",
107 if self.terminate_cb.is_some() {
108 &"some callback"
109 } else {
110 &"no callback"
111 },
112 )
113 .field(
114 "learner_cb",
115 if self.learner_cb.is_some() {
116 &"some callback"
117 } else {
118 &"no callback"
119 },
120 )
121 .field("stats", &self.stats)
122 .finish()
123 }
124}
125
// SAFETY: NOTE(review): this asserts that the raw IPASIR handle may be moved to and
// used from another thread, which relies on the linked solver having no thread-affine
// state — confirm for each supported backend.
// NOTE(review): the stored callbacks are `dyn FnMut` trait objects without a `Send`
// bound, so this impl also lets closures capturing non-`Send` data cross threads —
// confirm this is intended.
unsafe impl Send for IpasirSolver<'_, '_> {}
127
128impl Default for IpasirSolver<'_, '_> {
129 fn default() -> Self {
130 Self {
131 handle: unsafe { ffi::ipasir_init() },
132 state: InternalSolverState::default(),
133 terminate_cb: None,
134 learner_cb: None,
135 stats: SolverStats::default(),
136 }
137 }
138}
139
140impl IpasirSolver<'_, '_> {
141 fn get_core_assumps(&self, assumps: &[Lit]) -> Result<Vec<Lit>, InvalidApiReturn> {
142 let mut core = Vec::with_capacity(assumps.len());
143 core.reserve(assumps.len());
144 for a in assumps {
145 match unsafe { ffi::ipasir_failed(self.handle, a.to_ipasir()) } {
146 0 => (),
147 1 => core.push(!*a),
148 value => {
149 return Err(InvalidApiReturn {
150 api_call: "ipasir_failed",
151 value,
152 })
153 }
154 }
155 }
156 Ok(core)
157 }
158
159 #[allow(clippy::cast_precision_loss)]
160 #[inline]
161 fn update_avg_clause_len(&mut self, clause: &Cl) {
162 self.stats.avg_clause_len =
163 (self.stats.avg_clause_len * ((self.stats.n_clauses - 1) as f32) + clause.len() as f32)
164 / self.stats.n_clauses as f32;
165 }
166}
167
168impl Solve for IpasirSolver<'_, '_> {
169 fn signature(&self) -> &'static str {
170 let c_chars = unsafe { ffi::ipasir_signature() };
171 let c_str = unsafe { CStr::from_ptr(c_chars) };
172 c_str
173 .to_str()
174 .expect("IPASIR signature returned invalid UTF-8.")
175 }
176
177 fn solve(&mut self) -> anyhow::Result<SolverResult> {
178 if let InternalSolverState::Sat = self.state {
180 return Ok(SolverResult::Sat);
181 }
182 if let InternalSolverState::Unsat(core) = &self.state {
183 if core.is_empty() {
184 return Ok(SolverResult::Unsat);
185 }
186 }
187 let start = Timer::now();
188 let res = unsafe { ffi::ipasir_solve(self.handle) };
190 self.stats.cpu_solve_time += start.elapsed();
191 match res {
192 0 => {
193 self.stats.n_terminated += 1;
194 self.state = InternalSolverState::Input;
195 Ok(SolverResult::Interrupted)
196 }
197 10 => {
198 self.stats.n_sat += 1;
199 self.state = InternalSolverState::Sat;
200 Ok(SolverResult::Sat)
201 }
202 20 => {
203 self.stats.n_unsat += 1;
204 self.state = InternalSolverState::Unsat(vec![]);
205 Ok(SolverResult::Unsat)
206 }
207 value => Err(InvalidApiReturn {
208 api_call: "ipasir_solve",
209 value,
210 }
211 .into()),
212 }
213 }
214
215 fn lit_val(&self, lit: Lit) -> anyhow::Result<TernaryVal> {
216 if self.state != InternalSolverState::Sat {
217 return Err(StateError {
218 required_state: SolverState::Sat,
219 actual_state: self.state.to_external(),
220 }
221 .into());
222 }
223 let lit = lit.to_ipasir();
224 match unsafe { ffi::ipasir_val(self.handle, lit) } {
225 0 => Ok(TernaryVal::DontCare),
226 p if p == lit => Ok(TernaryVal::True),
227 n if n == -lit => Ok(TernaryVal::False),
228 value => Err(InvalidApiReturn {
229 api_call: "ipasir_val",
230 value,
231 }
232 .into()),
233 }
234 }
235
236 fn add_clause_ref<C>(&mut self, clause: &C) -> anyhow::Result<()>
237 where
238 C: AsRef<Cl> + ?Sized,
239 {
240 let clause = clause.as_ref();
241 self.stats.n_clauses += 1;
243 clause.iter().for_each(|l| match self.stats.max_var {
244 None => self.stats.max_var = Some(l.var()),
245 Some(var) => {
246 if l.var() > var {
247 self.stats.max_var = Some(l.var());
248 }
249 }
250 });
251 self.update_avg_clause_len(clause);
252 self.state = InternalSolverState::Input;
253 for lit in clause {
255 unsafe { ffi::ipasir_add(self.handle, lit.to_ipasir()) }
256 }
257 unsafe { ffi::ipasir_add(self.handle, 0) };
258 Ok(())
259 }
260}
261
262impl SolveIncremental for IpasirSolver<'_, '_> {
263 fn solve_assumps(&mut self, assumps: &[Lit]) -> anyhow::Result<SolverResult> {
264 let start = Timer::now();
267 for a in assumps {
269 unsafe { ffi::ipasir_assume(self.handle, a.to_ipasir()) }
270 }
271 let res = unsafe { ffi::ipasir_solve(self.handle) };
272 self.stats.cpu_solve_time += start.elapsed();
273 match res {
274 0 => {
275 self.stats.n_terminated += 1;
276 self.state = InternalSolverState::Input;
277 Ok(SolverResult::Interrupted)
278 }
279 10 => {
280 self.stats.n_sat += 1;
281 self.state = InternalSolverState::Sat;
282 Ok(SolverResult::Sat)
283 }
284 20 => {
285 self.stats.n_unsat += 1;
286 self.state = InternalSolverState::Unsat(self.get_core_assumps(assumps)?);
287 Ok(SolverResult::Unsat)
288 }
289 value => Err(InvalidApiReturn {
290 api_call: "ipasir_solve",
291 value,
292 }
293 .into()),
294 }
295 }
296
297 fn core(&mut self) -> anyhow::Result<Vec<Lit>> {
298 match &self.state {
299 InternalSolverState::Unsat(core) => Ok(core.clone()),
300 other => Err(StateError {
301 required_state: SolverState::Unsat,
302 actual_state: other.to_external(),
303 }
304 .into()),
305 }
306 }
307}
308
309impl<'term> Terminate<'term> for IpasirSolver<'term, '_> {
310 fn attach_terminator<CB>(&mut self, cb: CB)
340 where
341 CB: FnMut() -> ControlSignal + 'term,
342 {
343 self.terminate_cb = Some(Box::new(Box::new(cb)));
344 let cb_ptr =
345 std::ptr::from_ref(self.terminate_cb.as_mut().unwrap().as_mut()).cast::<c_void>();
346 unsafe { ffi::ipasir_set_terminate(self.handle, cb_ptr, Some(ffi::ipasir_terminate_cb)) }
347 }
348
349 fn detach_terminator(&mut self) {
350 self.terminate_cb = None;
351 unsafe { ffi::ipasir_set_terminate(self.handle, std::ptr::null(), None) }
352 }
353}
354
355impl<'learn> Learn<'learn> for IpasirSolver<'_, 'learn> {
356 fn attach_learner<CB>(&mut self, cb: CB, max_len: usize)
382 where
383 CB: FnMut(Clause) + 'learn,
384 {
385 self.learner_cb = Some(Box::new(Box::new(cb)));
386 let cb_ptr =
387 std::ptr::from_ref(self.learner_cb.as_mut().unwrap().as_mut()).cast::<c_void>();
388 unsafe {
389 ffi::ipasir_set_learn(
390 self.handle,
391 cb_ptr,
392 max_len.try_into().unwrap(),
393 Some(ffi::ipasir_learn_cb),
394 );
395 }
396 }
397
398 fn detach_learner(&mut self) {
399 self.terminate_cb = None;
400 unsafe { ffi::ipasir_set_learn(self.handle, std::ptr::null(), 0, None) }
401 }
402}
403
404impl SolveStats for IpasirSolver<'_, '_> {
405 fn stats(&self) -> SolverStats {
406 self.stats.clone()
407 }
408}
409
410impl Drop for IpasirSolver<'_, '_> {
411 fn drop(&mut self) {
412 unsafe { ffi::ipasir_release(self.handle) }
413 }
414}
415
416impl Extend<Clause> for IpasirSolver<'_, '_> {
417 fn extend<T: IntoIterator<Item = Clause>>(&mut self, iter: T) {
418 iter.into_iter()
419 .for_each(|cl| self.add_clause(cl).expect("Error adding clause in extend"));
420 }
421}
422
423impl<'a, C> Extend<&'a C> for IpasirSolver<'_, '_>
424where
425 C: AsRef<Cl> + ?Sized,
426{
427 fn extend<T: IntoIterator<Item = &'a C>>(&mut self, iter: T) {
428 iter.into_iter().for_each(|cl| {
429 self.add_clause_ref(cl)
430 .expect("Error adding clause in extend");
431 });
432 }
433}
434
435mod ffi {
436 use super::{LearnCallbackPtr, TermCallbackPtr};
437 use core::ffi::{c_char, c_int, c_void};
438 use rustsat::{solvers::ControlSignal, types::Lit, utils::from_raw_parts_maybe_null};
439
440 #[repr(C)]
441 pub struct IpasirHandle {
442 _private: [u8; 0],
443 }
444
445 extern "C" {
446 pub fn ipasir_signature() -> *const c_char;
448 pub fn ipasir_init() -> *mut IpasirHandle;
449 pub fn ipasir_release(solver: *mut IpasirHandle);
450 pub fn ipasir_add(solver: *mut IpasirHandle, lit_or_zero: c_int);
451 pub fn ipasir_assume(solver: *mut IpasirHandle, lit: c_int);
452 pub fn ipasir_solve(solver: *mut IpasirHandle) -> c_int;
453 pub fn ipasir_val(solver: *mut IpasirHandle, lit: c_int) -> c_int;
454 pub fn ipasir_failed(solver: *mut IpasirHandle, lit: c_int) -> c_int;
455 pub fn ipasir_set_terminate(
456 solver: *mut IpasirHandle,
457 state: *const c_void,
458 terminate: Option<unsafe extern "C" fn(state: *const c_void) -> c_int>,
459 );
460 pub fn ipasir_set_learn(
461 solver: *mut IpasirHandle,
462 state: *const c_void,
463 max_length: c_int,
464 learn: Option<unsafe extern "C" fn(state: *const c_void, clause: *const c_int)>,
465 );
466 }
467
468 pub unsafe extern "C" fn ipasir_terminate_cb(ptr: *const c_void) -> c_int {
470 let cb = &mut *(ptr as *mut TermCallbackPtr<'_>);
471 match cb() {
472 ControlSignal::Continue => 0,
473 ControlSignal::Terminate => 1,
474 }
475 }
476
477 pub unsafe extern "C" fn ipasir_learn_cb(ptr: *const c_void, clause: *const c_int) {
478 let cb = unsafe { &mut *(ptr as *mut LearnCallbackPtr<'_>) };
479
480 let mut cnt: usize = 0;
481 while *clause
482 .offset(isize::try_from(cnt).expect("learned clauses is longer than `isize::MAX`"))
483 != 0
484 {
485 cnt += 1;
486 }
487 let int_slice = from_raw_parts_maybe_null(clause, cnt);
488 let clause = int_slice
489 .iter()
490 .map(|il| {
491 Lit::from_ipasir(*il).expect("Invalid literal in learned clause from IPASIR solver")
492 })
493 .collect();
494 cb(clause);
495 }
496}