sat_solver/sat/
cnf.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
2//! Defines the Conjunctive Normal Form (CNF) representation for SAT formulas.
3//!
4//! A CNF formula is a conjunction (AND) of clauses, where each clause is a
5//! disjunction (OR) of literals. This is the standard input format for most
6//! SAT solvers.
7//!
8//! This module provides:
9//! - The `Cnf` struct to store a list of clauses and related metadata.
10//! - Methods for constructing `Cnf` from various sources (e.g. iterators of DIMACS literals, `Expr`).
11//! - Utilities for interacting with the `Cnf` formula, such as adding clauses, iterating,
12//!   and verifying solutions.
13//! - Conversion functions to and from a more general `Expr` (expression tree) representation.
14#![allow(unsafe_code, clippy::cast_possible_truncation, clippy::cast_sign_loss)]
15
16use super::clause::Clause;
17use super::expr::{Expr, apply_laws};
18use crate::sat::clause_storage::LiteralStorage;
19use crate::sat::literal;
20use crate::sat::literal::{Literal, PackedLiteral, Variable};
21use crate::sat::solver::Solutions;
22use itertools::Itertools;
23use smallvec::SmallVec;
24use std::fmt::Display;
25use std::num::NonZeroI32;
26use std::ops::{Index, IndexMut};
27
/// Represents a decision level in the SAT solver's search process.
///
/// This is a plain `usize` alias. Nothing in this module uses it directly; it is
/// exported for other solver components.
/// NOTE(review): the numbering convention (e.g. whether level 0 is the root) is
/// defined by the solver code, not here — confirm there.
pub type DecisionLevel = usize;
30
/// Represents a boolean formula in Conjunctive Normal Form (CNF).
///
/// A CNF formula is a collection of clauses. The formula is satisfied if and only if
/// all its clauses are satisfied. The clause list is split into two regions by
/// `non_learnt_idx`: original problem clauses first, then clauses learnt while solving.
///
/// `PartialOrd`/`Ord`/`Hash` are derived, so comparisons and hashing look at the
/// fields in declaration order; reordering the fields changes that ordering.
///
/// # Type Parameters
///
/// * `L`: The type of `Literal` used in the clauses. Defaults to `PackedLiteral`.
/// * `S`: The `LiteralStorage` type used within each `Clause`. Defaults to `SmallVec<[L; 8]>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Cnf<L: Literal = PackedLiteral, S: LiteralStorage<L> = SmallVec<[L; 8]>> {
    /// The list of clauses that make up the CNF formula (original clauses first,
    /// then learnt clauses).
    pub clauses: Vec<Clause<L, S>>,
    /// One more than the highest variable identifier encountered in the formula.
    ///
    /// With 1-indexed variables `v1, ..., vn` this is `n + 1`; with 0-indexed
    /// variables `v0, ..., v(n-1)` it is `n`. `Cnf::new` and the `FromIterator` impl
    /// set it to `1` for a formula without literals, whereas the derived `Default`
    /// leaves it at `0`. `add_clause` only ever raises it and `remove` never lowers it.
    pub num_vars: usize,
    /// A flat list of the variable of every literal, in clause order. May contain
    /// duplicates. It is not pruned when a clause is removed.
    pub vars: Vec<Variable>,
    /// A flat list of all literals present across all clauses, in clause order. May
    /// contain duplicates. It is not pruned when a clause is removed.
    pub lits: Vec<L>,
    /// The index in `clauses` that separates original problem clauses from learnt clauses.
    ///
    /// Clauses `0..non_learnt_idx` are original; clauses from `non_learnt_idx` onwards
    /// are learnt during solving. The constructors set it to `clauses.len()`, and
    /// `remove` decrements it when an original clause is removed. `add_clause` does
    /// not touch it, so clauses appended that way fall into the learnt region.
    pub non_learnt_idx: usize,
}
59
60impl<T: Literal, S: LiteralStorage<T>> Index<usize> for Cnf<T, S> {
61    type Output = Clause<T, S>;
62
63    /// Accesses the clause at the given `index`.
64    ///
65    /// # Panics
66    ///
67    /// Panics if `index` is out of bounds.
68    ///
69    /// # Safety
70    ///
71    /// This implementation uses `get_unchecked` for performance
72    fn index(&self, index: usize) -> &Self::Output {
73        // Safety: Caller must ensure `index` is within bounds `[0, self.clauses.len())`.
74        // This should be fine if used correctly.
75        unsafe { self.clauses.get_unchecked(index) }
76    }
77}
78
79impl<T: Literal, S: LiteralStorage<T>> IndexMut<usize> for Cnf<T, S> {
80    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
81        // Safety: Caller must ensure `index` is within bounds `[0, self.clauses.len())`.
82        // This should be fine if used correctly.
83        unsafe { self.clauses.get_unchecked_mut(index) }
84    }
85}
86
87impl<T: Literal, S: LiteralStorage<T>> Cnf<T, S> {
88    /// Creates a new `Cnf` instance from an iterator of clauses, where each clause
89    /// is itself an iterator of `i32` (DIMACS literals).
90    ///
91    /// Example: `Cnf::new(vec![vec![1, -2], vec![2, 3]])` creates a CNF for
92    /// `(x1 OR !x2) AND (x2 OR x3)`.
93    ///
94    /// During construction, it determines `num_vars` based on the maximum variable ID
95    /// encountered. `vars` and `lits` collect all variables and literals.
96    /// `non_learnt_idx` is set to the total number of initial clauses.
97    ///
98    /// # Type Parameters for Arguments
99    ///
100    /// * `J`: An iterator yielding `i32` (DIMACS literals for a single clause).
101    /// * `I`: An iterator yielding `J` (an iterator of clauses).
102    pub fn new<J: IntoIterator<Item = i32>, I: IntoIterator<Item = J>>(clauses_iter: I) -> Self {
103        let (clauses_vec, max_var_id, vars_vec, lits_vec, num_initial_clauses) = clauses_iter
104            .into_iter()
105            .map(|clause_dimacs| clause_dimacs.into_iter().collect::<Clause<_, _>>())
106            .fold(
107                (Vec::new(), u32::default(), Vec::new(), Vec::new(), 0),
108                |(
109                    mut acc_clauses,
110                    mut current_max_var,
111                    mut acc_vars,
112                    mut acc_lits,
113                    clause_count,
114                ),
115                 clause| {
116                    if clause.is_empty() || clause.is_tautology() {
117                        return (
118                            acc_clauses,
119                            current_max_var,
120                            acc_vars,
121                            acc_lits,
122                            clause_count,
123                        );
124                    }
125
126                    let clause_max_var = clause
127                        .iter()
128                        .map(|l: &T| l.variable())
129                        .max()
130                        .unwrap_or_default();
131
132                    current_max_var = current_max_var.max(clause_max_var);
133
134                    acc_lits.extend(clause.iter().copied());
135                    acc_vars.extend(clause.iter().map(|l| l.variable()));
136
137                    acc_clauses.push(clause);
138
139                    (
140                        acc_clauses,
141                        current_max_var,
142                        acc_vars,
143                        acc_lits,
144                        clause_count + 1,
145                    )
146                },
147            );
148
149        Self {
150            clauses: clauses_vec,
151            num_vars: (max_var_id as usize).wrapping_add(1),
152            vars: vars_vec,
153            lits: lits_vec,
154            non_learnt_idx: num_initial_clauses,
155        }
156    }
157
158    /// Removes a clause at the specified index.
159    ///
160    /// # Arguments
161    ///
162    /// * `idx`: The index of the clause to remove.
163    ///
164    /// # Panics
165    ///
166    /// Panics if `idx` is out of bounds.
167    pub fn remove(&mut self, idx: usize) {
168        self.clauses.remove(idx);
169        if idx < self.non_learnt_idx {
170            self.non_learnt_idx = self.non_learnt_idx.saturating_sub(1);
171        }
172    }
173
174    /// Returns an iterator over the clauses in the CNF.
175    pub fn iter(&self) -> impl Iterator<Item = &Clause<T, S>> {
176        self.clauses.iter()
177    }
178
179    /// Returns a mutable iterator over the clauses in the CNF.
180    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Clause<T, S>> {
181        self.clauses.iter_mut()
182    }
183
184    /// Adds a new clause to the CNF.
185    ///
186    /// The clause is added to the end of the `clauses` list. If it's considered a learnt
187    /// clause, it should be added after `non_learnt_idx`. This function implicitly adds
188    /// it as if it's a new problem clause if `non_learnt_idx` isn't managed externally
189    /// before calling this for learnt clauses.
190    ///
191    /// Updates `num_vars`, `vars`, and `lits` based on the new clause.
192    ///
193    /// # Arguments
194    ///
195    /// * `clause`: The `Clause<T, S>` to add.
196    pub fn add_clause(&mut self, clause: Clause<T, S>) {
197        let clause_max_var_id = clause
198            .iter()
199            .map(|l| l.variable())
200            .max()
201            .unwrap_or_default();
202        let clause_max_var_usize = clause_max_var_id as usize;
203
204        let clause_vars = clause.iter().map(|l| l.variable()).collect_vec();
205
206        self.vars.extend(clause_vars);
207        self.lits.extend(clause.iter());
208        self.clauses.push(clause);
209
210        let required_num_vars = clause_max_var_usize.wrapping_add(1);
211        self.num_vars = self.num_vars.max(required_num_vars);
212    }
213
214    /// Adds a new clause specified as a `Vec<i32>` (DIMACS literals).
215    ///
216    /// Converts `clause_dimacs` to a `Clause<T, S>` and then calls `add_clause`.
217    ///
218    /// # Arguments
219    ///
220    /// * `clause_dimacs`: A vector of `i32` representing the clause.
221    pub fn add_clause_vec(&mut self, clause_dimacs: Vec<i32>) {
222        self.add_clause(Clause::from(clause_dimacs));
223    }
224
225    /// Returns the total number of clauses in the CNF (both original and learnt).
226    #[must_use]
227    pub const fn len(&self) -> usize {
228        self.clauses.len()
229    }
230
231    /// Returns `true` if the CNF contains no clauses.
232    #[must_use]
233    pub const fn is_empty(&self) -> bool {
234        self.clauses.is_empty()
235    }
236
237    /// Verifies if a given set of solutions satisfies the CNF formula.
238    ///
239    /// A CNF is satisfied if every clause in it is satisfied. A clause is satisfied
240    /// if at least one of its literals is true under the given assignment.
241    ///
242    /// # Arguments
243    ///
244    /// * `solutions`: A `Solutions` object providing the truth assignment for variables.
245    ///   `Solutions` handles DIMACS-style variable IDs (1-indexed, signed).
246    ///
247    /// # Returns
248    ///
249    /// `true` if `solutions` satisfies all clauses in the CNF, `false` otherwise.
250    #[must_use]
251    pub fn verify(&self, solutions: &Solutions) -> bool {
252        self.iter().all(|clause| {
253            clause.iter().any(|&lit| {
254                let lit_i32 = lit.to_i32();
255                NonZeroI32::new(lit_i32).is_some_and(|nonzero_var| solutions.check(nonzero_var))
256            })
257        })
258    }
259
260    /// Converts this `Cnf<T, S>` into a `Cnf<L, U>` with different literal or storage types.
261    ///
262    /// Each clause is converted using `Clause::convert`. Metadata like `num_vars`,
263    /// `vars`, `lits`, and `non_learnt_idx` are also transformed or cloned.
264    ///
265    /// # Type Parameters
266    ///
267    /// * `L`: The target `Literal` type for the new CNF.
268    /// * `U`: The target `LiteralStorage<L>` type for clauses in the new CNF.
269    ///
270    /// # Returns
271    ///
272    /// A new `Cnf<L, U>` instance.
273    pub fn convert<TargetL: Literal, TargetS: LiteralStorage<TargetL>>(
274        &self,
275    ) -> Cnf<TargetL, TargetS> {
276        let clauses_converted = self.clauses.iter().map(Clause::convert).collect_vec();
277
278        let vars_converted = self.vars.clone();
279
280        let lits_converted = self.lits.iter().map(|l| literal::convert(l)).collect_vec();
281
282        Cnf {
283            clauses: clauses_converted,
284            num_vars: self.num_vars,
285            vars: vars_converted,
286            lits: lits_converted,
287            non_learnt_idx: self.non_learnt_idx,
288        }
289    }
290}
291
292impl<L: Literal, S: LiteralStorage<L>> Display for Cnf<L, S> {
293    /// Formats the CNF into DIMACS CNF format.
294    ///
295    /// Example output:
296    /// ```
297    /// c Generated by CNF
298    /// p cnf <num_vars> <num_clauses>
299    /// 1 -2 0
300    /// 2 3 0
301    /// ```
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        let dimacs_num_vars = if self.num_vars > 0 {
304            self.num_vars - 1
305        } else {
306            0
307        };
308        let dimacs_num_clauses = self.non_learnt_idx;
309
310        writeln!(f, "c Generated by CNF module")?;
311        writeln!(f, "p cnf {dimacs_num_vars} {dimacs_num_clauses}")?;
312
313        for clause_idx in 0..self.non_learnt_idx {
314            let clause = &self.clauses[clause_idx];
315            for &lit in clause.iter() {
316                write!(f, "{} ", lit.to_i32())?;
317            }
318            writeln!(f, "0")?;
319        }
320        Ok(())
321    }
322}
323
324impl<L: Literal, S: LiteralStorage<L>> FromIterator<Clause<L, S>> for Cnf<L, S> {
325    /// Creates a `Cnf` from an iterator of `Clause<L, S>`.
326    ///
327    /// Each clause from the iterator is added using `self.add_clause`.
328    /// This initialises a default `Cnf` and then populates it.
329    /// `non_learnt_idx` will be implicitly managed by `add_clause` if it updates it,
330    /// or will remain 0 if `add_clause` only appends to `clauses`.
331    /// A more robust way would be to collect clauses, then initialise `Cnf` fields properly.
332    fn from_iter<IterClauses: IntoIterator<Item = Clause<L, S>>>(iter: IterClauses) -> Self {
333        let mut cnf = Self::default();
334        let mut max_var_id = u32::default();
335        let mut clause_count = 0;
336
337        for clause in iter {
338            if let Some(clause_max_v) = clause.iter().map(|l| l.variable()).max() {
339                max_var_id = max_var_id.max(clause_max_v);
340            }
341            cnf.vars.extend(clause.iter().map(|l| l.variable()));
342            cnf.lits.extend(clause.iter().copied());
343            cnf.clauses.push(clause);
344            clause_count += 1;
345        }
346        cnf.num_vars = (max_var_id as usize).wrapping_add(1);
347        cnf.non_learnt_idx = clause_count;
348        cnf
349    }
350}
351
352/// Converts a general boolean expression (`Expr`) into CNF.
353///
354/// The conversion involves:
355/// 1. Applying logical laws (`apply_laws`) to transform the expression into a
356///    structure that is easier to convert to CNF (e.g. NNF, pushing negations inwards).
357/// 2. Recursively converting the transformed expression into a list of clauses (`to_clauses`).
358/// 3. Constructing a `Cnf` object from this list of clauses.
359///
360/// Note: This is a standard, potentially exponential, conversion for arbitrary expressions.
361/// For more efficient CNF conversion (e.g. Tseytin transformation), specialised algorithms are needed.
362#[must_use]
363pub fn to_cnf<T: Literal, S: LiteralStorage<T>>(expr: &Expr) -> Cnf<T, S> {
364    let cnf_expr = apply_laws(expr);
365    let clauses_vec = to_clauses_recursive(&cnf_expr);
366    Cnf::from_iter(clauses_vec)
367}
368
369/// Helper function to recursively convert an `Expr` (assumed to be in a CNF-friendly form)
370/// into a list of `Clause`s.
371fn to_clauses_recursive<T: Literal, S: LiteralStorage<T>>(expr: &Expr) -> Vec<Clause<T, S>> {
372    match expr {
373        Expr::And(e1, e2) => {
374            let mut c1_clauses = to_clauses_recursive(e1);
375            let c2_clauses = to_clauses_recursive(e2);
376            c1_clauses.extend(c2_clauses);
377            c1_clauses
378        }
379        _ => vec![to_clause_recursive(expr)],
380    }
381}
382
383/// Helper function to convert an `Expr` representing a disjunction or a literal
384/// into a single `Clause`.
385fn to_clause_recursive<T: Literal, S: LiteralStorage<T>>(expr: &Expr) -> Clause<T, S> {
386    match expr {
387        Expr::Or(e1, e2) => {
388            let clause1: Clause<T, S> = to_clause_recursive(e1);
389            let clause2: Clause<T, S> = to_clause_recursive(e2);
390            let mut combined_lits: Vec<T> = Vec::from(&clause1);
391            combined_lits.extend(Vec::from(&clause2));
392            Clause::<T, S>::new(&combined_lits)
393        }
394        _ => Clause::<T, S>::new(&[expr_to_literal(expr)]),
395    }
396}
397
398/// Helper function to convert an `Expr` representing a literal into a `Literal` type `T`.
399/// Panics if the expression is not a literal form (Var or Not(Var)).
400fn expr_to_literal<T: Literal>(expr: &Expr) -> T {
401    match expr {
402        Expr::Var(i) => T::new(*i, true),
403        Expr::Not(e) => {
404            if let Expr::Var(i) = **e {
405                T::new(i, false)
406            } else {
407                panic!("Expression Not(Non-Variable) encountered where literal expected.");
408            }
409        }
410        _ => panic!("Expression is not a literal (Var or Not(Var))."),
411    }
412}
413
414/// Converts a `Clause` back into an `Expr` (a disjunction of literal expressions).
415fn clause_to_expr<T: Literal, S: LiteralStorage<T>>(clause: &Clause<T, S>) -> Expr {
416    let mut iter = clause.iter();
417    let first_lit_expr =
418        literal_to_expr_node(*iter.next().expect("Cannot convert empty clause to Expr"));
419    iter.fold(first_lit_expr, |acc_expr, &literal| {
420        Expr::Or(Box::new(acc_expr), Box::new(literal_to_expr_node(literal)))
421    })
422}
423
424/// Converts a `Literal` `T` into an `Expr` node (Var or Not(Var)).
425fn literal_to_expr_node<T: Literal>(literal: T) -> Expr {
426    if literal.polarity() {
427        Expr::Var(literal.variable())
428    } else {
429        Expr::Not(Box::new(Expr::Var(literal.variable())))
430    }
431}
432
433impl<T: Literal, S: LiteralStorage<T>> From<Expr> for Cnf<T, S> {
434    /// Converts an `Expr` into a `Cnf<T, S>`.
435    /// This is a convenience wrapper around `to_cnf`.
436    fn from(expr: Expr) -> Self {
437        to_cnf(&expr)
438    }
439}
440
441impl<L: Literal, S: LiteralStorage<L>> From<Vec<Clause<L, S>>> for Cnf<L, S> {
442    /// Converts a `Vec<Clause<L, S>>` directly into a `Cnf<L, S>`.
443    /// Uses `from_iter` for consistent initialisation.
444    fn from(clauses: Vec<Clause<L, S>>) -> Self {
445        Self::from_iter(clauses)
446    }
447}
448
449impl<L: Literal, S: LiteralStorage<L>> From<Vec<Vec<i32>>> for Cnf<L, S> {
450    /// Converts `Vec<Vec<i32>>` (DIMACS clauses) into a `Cnf<L, S>`.
451    /// Uses `Cnf::new` for construction.
452    fn from(value: Vec<Vec<i32>>) -> Self {
453        Self::new(value)
454    }
455}
456
457impl<T: Literal, S: LiteralStorage<T>> TryFrom<Cnf<T, S>> for Expr {
458    type Error = &'static str;
459
460    /// Attempts to convert a `Cnf<T, S>` back into an `Expr`.
461    /// The resulting `Expr` will be a conjunction of disjunctions.
462    /// Returns an error if the CNF is empty (contains no clauses).
463    fn try_from(cnf: Cnf<T, S>) -> Result<Self, Self::Error> {
464        let mut iter = cnf.iter();
465        let first_clause_expr =
466            clause_to_expr(iter.next().ok_or("Cannot convert empty CNF to Expr")?);
467
468        iter.try_fold(first_clause_expr, |acc_expr, clause| {
469            Ok(Self::And(
470                Box::new(acc_expr),
471                Box::new(clause_to_expr(clause)),
472            ))
473        })
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sat::literal::PackedLiteral;

    /// Returns `true` if `clause` contains a literal on `var` with the given `polarity`.
    fn has_lit<L: Literal, S: LiteralStorage<L>>(
        clause: &Clause<L, S>,
        var: Variable,
        polarity: bool,
    ) -> bool {
        clause
            .iter()
            .any(|l| l.variable() == var && l.polarity() == polarity)
    }

    #[test]
    fn test_cnf_new_from_dimacs() {
        let cnf: Cnf<PackedLiteral> = Cnf::new(vec![vec![1, -2], vec![-1, 2, 3]]);

        assert_eq!(cnf.clauses.len(), 2);
        // Highest variable is 3, so `num_vars` is 3 + 1.
        assert_eq!(cnf.num_vars, 4);
        assert_eq!(cnf.non_learnt_idx, 2);

        let head = &cnf.clauses[0];
        assert_eq!(head.len(), 2);
        assert!(has_lit(head, 1, true));
        assert!(has_lit(head, 2, false));
    }

    #[test]
    fn test_cnf_add_clause() {
        let mut cnf: Cnf<PackedLiteral> = Cnf::new(Vec::<Vec<i32>>::new());
        assert_eq!(cnf.num_vars, 1);

        cnf.add_clause_vec(vec![1, -2]);
        assert_eq!(cnf.clauses.len(), 1);
        assert_eq!(cnf.num_vars, 3);

        cnf.add_clause(Clause::from(vec![-2, 3, 4]));
        assert_eq!(cnf.clauses.len(), 2);
        assert_eq!(cnf.num_vars, 5);
    }

    #[test]
    fn test_cnf_display_dimacs() {
        let cnf: Cnf<PackedLiteral> = Cnf::new(vec![vec![1, -2], vec![2, 3]]);
        let rendered = cnf.to_string();

        for (needle, what) in [
            ("p cnf 3 2", "DIMACS header mismatch"),
            ("1 -2 0", "Clause 1 mismatch"),
            ("2 3 0", "Clause 2 mismatch"),
        ] {
            assert!(rendered.contains(needle), "{what}");
        }
    }

    #[test]
    fn test_cnf_from_expr() {
        let var = |i: u32| Box::new(Expr::Var(i));
        let or = |a: Box<Expr>, b: Box<Expr>| Box::new(Expr::Or(a, b));

        // (x1 OR !x2) AND (x2 OR x3)
        let expr = Expr::And(
            or(var(1), Box::new(Expr::Not(var(2)))),
            or(var(2), var(3)),
        );

        let cnf: Cnf<PackedLiteral> = Cnf::from(expr);
        assert_eq!(cnf.clauses.len(), 2);
        assert_eq!(cnf.num_vars, 4);

        let has_binary = |(v1, p1): (u32, bool), (v2, p2): (u32, bool)| {
            cnf.clauses
                .iter()
                .any(|c| c.len() == 2 && has_lit(c, v1, p1) && has_lit(c, v2, p2))
        };
        assert!(has_binary((1, true), (2, false)));
        assert!(has_binary((2, true), (3, true)));
    }

    #[test]
    fn test_cnf_verify_solution() {
        let cnf: Cnf<PackedLiteral> = Cnf::new(vec![vec![1, -2], vec![-1, 2, 3]]);

        let assign = |lits: [i32; 3]| {
            let mut solutions = Solutions::default();
            for lit in lits {
                solutions.add(lit.try_into().unwrap());
            }
            solutions
        };

        assert!(cnf.verify(&assign([1, -2, 3])));
        assert!(!cnf.verify(&assign([-1, 2, -3])));
    }

    #[test]
    fn test_cnf_new_empty_input() {
        let cnf: Cnf<PackedLiteral> = Cnf::new(Vec::<Vec<i32>>::new());
        assert!(cnf.is_empty());
        assert_eq!(cnf.num_vars, 1);
        assert_eq!(cnf.non_learnt_idx, 0);
    }

    #[test]
    fn test_cnf_new_with_empty_clause() {
        // `Cnf::new` drops empty clauses entirely.
        let cnf: Cnf<PackedLiteral> = Cnf::new(vec![Vec::<i32>::new()]);
        assert!(cnf.clauses.is_empty());
    }
}