// sat_solver/lib.rs

use num_rational::Ratio;
use ordered_float::OrderedFloat;
use priority_queue::PriorityQueue;

use std::cell::Cell;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::rc::Rc;

/// Multiplier applied to `Solver::var_inc` whenever the bump amount grows.
///
/// Growing the increment by 1 / 0.95 is equivalent to decaying every existing
/// priority by a factor of 0.95, so recent conflicts weigh more.
const BUMP_FACTOR: f64 = 1.0 / 0.95;
12
13pub struct Solver {
14    literals: HashMap<i32, Rc<Lit>>,
15    watchers: HashMap<i32, HashSet<usize>>,
16    var_order: PriorityQueue<i32, OrderedFloat<f64>>,
17    clauses: Vec<Vec<Rc<Lit>>>,
18    decisions: HashMap<usize, HashSet<i32>>,
19    i_graph: HashMap<i32, (usize, Vec<i32>)>,
20    var_inc: f64,
21    level: usize,
22    cur_watchers: HashMap<i32, HashSet<usize>>,
23    cur_var_order: PriorityQueue<i32, OrderedFloat<f64>>,
24}
25
26impl Solver {
27    pub fn new(input: &str) -> Solver {
28        let lines = input.split('\n');
29
30        let mut clauses = Vec::new();
31        let mut literals: HashMap<i32, Rc<Lit>> = HashMap::new();
32
33        for line in lines {
34            match line.chars().nth(0) {
35                Some('p') => {
36                    let mut p = line.split_whitespace();
37                    p.next();
38                    p.next();
39                    let num_vars: i32 = p.next().unwrap().parse().unwrap();
40                    literals = (-num_vars..=num_vars)
41                        .filter(|&i| i != 0)
42                        .map(|i| (i, Rc::new(Lit::new(i))))
43                        .collect();
44                },
45                Some('c') => continue,
46                Some(_) => {
47                    let mut clause = HashSet::new();
48                    let mut lits = line.split_whitespace();
49
50                    loop {
51                        match lits.next().unwrap().parse().unwrap() {
52                            0 => break,
53                            i => {clause.insert(i); continue;},
54                        }
55                    }
56
57                    clauses.push(clause
58                                 .iter()
59                                 .map(|l| Rc::clone(&literals[l]))
60                                 .collect::<Vec<_>>());
61                },
62                None => continue,
63            }
64        }
65
66        let mut watchers: HashMap<i32, HashSet<usize>> = literals
67            .iter()
68            .map(|(&i, _)| (i, HashSet::new()))
69            .collect();
70
71        let var_priorities: HashMap<i32, Cell<f64>> = literals
72            .iter()
73            .map(|(&i, _)| (i, Cell::new(1.)))
74            .collect();
75
76        for (i, clause) in clauses.iter().enumerate() {
77            for lit in clause {
78                var_priorities[&lit.to_int()].replace(
79                    var_priorities[&lit.to_int()].get() + 1.
80                );
81            }
82
83            for j in 0..std::cmp::min(2, clause.len()) {
84                let lit = clause[j].to_int();
85                watchers.get_mut(&-lit).unwrap().insert(i);
86            }
87        }
88
89        let mut var_order = PriorityQueue::new();
90        for (lit, p) in var_priorities {
91            var_order.push(lit, OrderedFloat::from(p.take()));
92        }
93
94        Solver {
95            literals,
96            watchers,
97            var_order,
98            clauses,
99            cur_watchers: HashMap::new(),
100            cur_var_order: PriorityQueue::new(),
101            decisions: HashMap::new(),
102            i_graph: HashMap::new(),
103            var_inc: 1.,
104            level: 0,
105        }
106    }
107
108    pub fn solve(mut self) -> String {
109        self.restart();
110
111        loop {
112            let conflict = self.propagate();
113            if let Some(lit) = conflict {
114                if self.level == 0 {
115                    return "UNSAT".to_owned();
116                } else {
117                    self.analyze(lit);
118                    self.restart();
119                }
120            } else {
121                if self.satisfied() {
122                    return self.model();
123                } else {
124                    self.decide();
125                }
126            }
127        }
128    }
129
130    fn propagate(&mut self) -> Option<i32> {
131        // TODO: change to Vec
132        let mut unit_literals = HashSet::new();
133        // at the start, the only possibly unit clause is the newly appended
134        // clause, unless we are at the first iteration (0th level)
135        if self.level == 0 {
136            for clause in &self.clauses {
137                if clause.is_unit() {
138                    let lit = clause[0].to_int();
139                    unit_literals.insert(lit);
140                    self.decisions.get_mut(&self.level).unwrap().insert(lit);
141                    self.i_graph.insert(lit, (self.level, Vec::new()));
142                }
143            }
144        } else {
145            let lit = self.decisions[&self.level].iter().next().cloned().unwrap();
146            unit_literals.insert(lit);
147        }
148
149        while !unit_literals.is_empty() {
150            let lit = unit_literals.iter().next().cloned().unwrap();
151            unit_literals.remove(&lit);
152            self.literals.get_mut(&lit).unwrap().set_true();
153            self.literals.get_mut(&-lit).unwrap().set_false();
154
155            let indexes = self.cur_watchers[&lit].clone();
156            for i in indexes {
157                let clause = &self.clauses[i];
158                if clause.is_satisfied() {
159                    // nothing to do
160                    continue;
161                } else if clause.is_unit() {
162                    // new unit clause found
163                    let unit_lit = clause.get_unset().unwrap().to_int();
164                    if self.i_graph.contains_key(&unit_lit) {
165                        continue
166                    }
167                    unit_literals.insert(unit_lit);
168
169                    self.decisions.get_mut(&self.level).unwrap().insert(unit_lit);
170                    let reason = clause
171                        .iter()
172                        .map(|l| l.to_int())
173                        .filter(|&l| l != unit_lit)
174                        .map(|l| -l)
175                        .collect::<Vec<_>>();
176                    self.i_graph.insert(unit_lit, (self.level, reason));
177                } else if clause.is_conflict() {
178                    // conflict found
179                    self.decisions.get_mut(&self.level).unwrap().insert(-lit);
180                    let reason = clause
181                        .iter()
182                        .map(|l| l.to_int())
183                        .filter(|&l| l != -lit)
184                        .map(|l| -l)
185                        .collect::<Vec<_>>();
186                    self.i_graph.insert(-lit, (self.level, reason));
187                    return Some(lit);
188                } else {
189                    // clause not satisfied, modify watchers
190                    let mut clause_iter = clause.iter();
191                    let l = loop {
192                        let l = clause_iter.next().unwrap();
193                        if !l.is_unset() {
194                            continue;
195                        }
196                        let l = l.to_int();
197                        if self.cur_watchers[&-l].contains(&i) {
198                            continue;
199                        }
200                        break l;
201                    };
202
203                    self.cur_watchers.get_mut(&lit).unwrap().remove(&i);
204                    self.cur_watchers.get_mut(&-l).unwrap().insert(i);
205                }
206            }
207        }
208
209        None
210    }
211
212    fn satisfied(&self) -> bool {
213        self.clauses.iter().all(Clause::is_satisfied)
214    }
215
216    fn restart(&mut self) {
217        for lit in self.literals.values() {
218            lit.unset();
219        }
220        self.cur_watchers = self.watchers.clone();
221        self.cur_var_order = self.var_order.clone();
222        self.decisions.insert(0, HashSet::new());
223        self.i_graph.clear();
224        self.level = 0;
225    }
226
227    fn model(&self) -> String {
228        (1..=self.literals.len() as i32 / 2)
229            .map(|l| if self.i_graph.contains_key(&-l) { -l } else { l })
230            .map(|l| l.to_string())
231            .collect::<Vec<_>>()
232            .join(" ")
233    }
234
235    fn decide(&mut self) {
236        let lit = loop {
237            let next_lit = self.cur_var_order.pop().unwrap().0;
238            if self.literals[&next_lit].is_unset() {
239                break next_lit;
240            }
241        };
242
243        self.level += 1;
244        self.decisions.insert(self.level, [lit].iter().cloned().collect());
245        self.i_graph.insert(lit, (self.level, Vec::new()));
246    }
247
248    fn analyze(&mut self, lit: i32) {
249        // find first unique implication point (1-UIP)
250        let mut uips = HashSet::new();
251        let weights = self.decisions[&self.level]
252            .iter()
253            .map(|&l| (l, Ratio::<i128>::new(0, 1)))
254            .collect::<HashMap<_, _>>();
255        let weights_ref = Rc::new(RefCell::new(weights));
256
257        fn explore(lit: i32,
258                   weight: Ratio<i128>,
259                   weights: Rc<RefCell<HashMap<i32, Ratio<i128>>>>,
260                   i_graph: &HashMap<i32, (usize, Vec<i32>)>,
261                   level: usize) {
262            *weights.borrow_mut().get_mut(&lit).unwrap() += weight;
263            let next_lits = i_graph[&lit].1
264                                         .iter()
265                                         .filter(|&l| i_graph[l].0 == level)
266                                         .collect::<Vec<_>>();
267            for &l in &next_lits {
268                explore(*l,
269                        weight / next_lits.len() as i128,
270                        Rc::clone(&weights),
271                        i_graph,
272                        level,
273                );
274            }
275        }
276
277        explore(lit,
278                Ratio::new(1, 1),
279                Rc::clone(&weights_ref),
280                &self.i_graph,
281                self.level);
282
283        for l in weights_ref.borrow().keys() {
284            if weights_ref.borrow()[l] == Ratio::new(1, 1) {
285                uips.insert(*l);
286            }
287        }
288        uips.remove(&lit);
289
290        let weights = self.decisions[&self.level]
291            .iter()
292            .map(|&l| (l, Ratio::new(0, 1)))
293            .collect::<HashMap<_, _>>();
294        let weights_ref = Rc::new(RefCell::new(weights));
295
296        explore(-lit,
297                Ratio::new(1, 1),
298                Rc::clone(&weights_ref),
299                &self.i_graph,
300                self.level);
301
302        for l in weights_ref.borrow().keys() {
303            if weights_ref.borrow()[l] == Ratio::new(1, 1) && uips.contains(l) {
304                continue;
305            }
306
307            uips.remove(l);
308        }
309
310        let mut l = lit;
311        let fuip = loop {
312            for &next_l in &self.i_graph[&l].1 {
313                if self.i_graph[&next_l].0 == self.level {
314                    l = next_l;
315                    break;
316                }
317            }
318            if uips.contains(&l) {
319                break l;
320            }
321        };
322
323        // find cut
324        let new_clause = [-fuip].iter().cloned().collect();
325        let new_clause_ref = Rc::new(RefCell::new(new_clause));
326
327        fn find_cut(lit: i32,
328                    new_clause: Rc<RefCell<HashSet<i32>>>,
329                    i_graph: &HashMap<i32, (usize, Vec<i32>)>,
330                    level: usize,
331                    fuip: i32) {
332            if i_graph[&lit].0 != level {
333                new_clause.borrow_mut().insert(-lit);
334                return;
335            }
336            if lit == fuip {
337                return;
338            }
339
340            for &l in &i_graph[&lit].1 {
341                find_cut(
342                    l,
343                    Rc::clone(&new_clause),
344                    i_graph,
345                    level,
346                    fuip
347                );
348            }
349        }
350
351        find_cut(lit, Rc::clone(&new_clause_ref), &self.i_graph, self.level, fuip);
352        find_cut(-lit, Rc::clone(&new_clause_ref), &self.i_graph, self.level, fuip);
353
354        // add clause
355        let clause = new_clause_ref.borrow().clone();
356        self.clauses.push(
357            clause
358                .iter()
359                .map(|&l| Rc::clone(&self.literals[&l]))
360                .collect::<Vec<_>>()
361        );
362        let clause_idx = self.clauses.len() - 1;
363        let mut clause_iter = clause.iter();
364        for _ in 0..std::cmp::min(2, clause.len()) {
365            let lit = -clause_iter.next().unwrap();
366            self.watchers
367                .get_mut(&lit)
368                .unwrap()
369                .insert(clause_idx);
370        }
371
372        self.var_inc *= BUMP_FACTOR;
373        for lit in clause {
374            if let None = self.cur_var_order.get_priority(&lit) {
375                continue;
376            }
377
378            let new_p = OrderedFloat::from(self.cur_var_order
379                                           .get_priority(&lit)
380                                           .unwrap()
381                                           .into_inner()
382                                           + self.var_inc);
383
384            self.cur_var_order.change_priority(
385                &lit,
386                new_p
387            );
388
389            self.var_inc *= BUMP_FACTOR;
390            if new_p.into_inner() * self.var_inc > 1e100 {
391                self.var_inc *= 1e-100;
392
393                for (_, p) in &mut self.cur_var_order {
394                    *p = OrderedFloat::from(p.into_inner() * 1e-100);
395                }
396            }
397        }
398    }
399}
400
401trait Clause {
402    fn is_satisfied(&self) -> bool;
403    fn is_conflict(&self) -> bool;
404    fn is_unit(&self) -> bool;
405    fn get_unset(&self) -> Option<Rc<Lit>>;
406}
407
408impl Clause for Vec<Rc<Lit>> {
409    fn is_satisfied(&self) -> bool {
410        self.iter().any(|lit| lit.is_true())
411    }
412
413    fn is_conflict(&self) -> bool {
414        self.iter().all(|lit| lit.is_false())
415    }
416
417    fn is_unit(&self) -> bool {
418        self.iter().filter(|&lit| lit.is_unset()).count() == 1
419    }
420
421    fn get_unset(&self) -> Option<Rc<Lit>> {
422        for lit in self {
423            if lit.is_unset() {
424                return Some(Rc::clone(lit));
425            }
426        }
427
428        None
429    }
430}

/// A single literal together with its current assignment.
struct Lit {
    /// Assignment state: `1` true, `-1` false, `0` unassigned.
    value: Cell<i8>,
    /// Signed literal as written in the input (`v` or `-v`).
    lit: i32,
}
436
437impl Lit {
438    fn new(lit: i32) -> Lit {
439        Lit { lit, value: Cell::new(0) }
440    }
441
442    fn set_true(&self) {
443        self.value.set(1)
444    }
445
446    fn set_false(&self) {
447        self.value.set(-1)
448    }
449
450    fn unset(&self) {
451        self.value.set(0)
452    }
453
454    fn is_true(&self) -> bool {
455        self.value.get() == 1
456    }
457
458    fn is_false(&self) -> bool {
459        self.value.get() == -1
460    }
461
462    fn is_unset(&self) -> bool {
463        self.value.get() == 0
464    }
465
466    fn to_int(&self) -> i32 {
467        self.lit
468    }
469}
470
471impl std::fmt::Debug for Lit {
472    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473        self.lit.fmt(f)
474    }
475}