1use num_rational::Ratio;
2use ordered_float::OrderedFloat;
3use priority_queue::PriorityQueue;
4
5use std::cell::Cell;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::collections::HashSet;
9use std::rc::Rc;
10
11const BUMP_FACTOR: f64 = 1. / 0.95;
12
13pub struct Solver {
14 literals: HashMap<i32, Rc<Lit>>,
15 watchers: HashMap<i32, HashSet<usize>>,
16 var_order: PriorityQueue<i32, OrderedFloat<f64>>,
17 clauses: Vec<Vec<Rc<Lit>>>,
18 decisions: HashMap<usize, HashSet<i32>>,
19 i_graph: HashMap<i32, (usize, Vec<i32>)>,
20 var_inc: f64,
21 level: usize,
22 cur_watchers: HashMap<i32, HashSet<usize>>,
23 cur_var_order: PriorityQueue<i32, OrderedFloat<f64>>,
24}
25
26impl Solver {
27 pub fn new(input: &str) -> Solver {
28 let lines = input.split('\n');
29
30 let mut clauses = Vec::new();
31 let mut literals: HashMap<i32, Rc<Lit>> = HashMap::new();
32
33 for line in lines {
34 match line.chars().nth(0) {
35 Some('p') => {
36 let mut p = line.split_whitespace();
37 p.next();
38 p.next();
39 let num_vars: i32 = p.next().unwrap().parse().unwrap();
40 literals = (-num_vars..=num_vars)
41 .filter(|&i| i != 0)
42 .map(|i| (i, Rc::new(Lit::new(i))))
43 .collect();
44 },
45 Some('c') => continue,
46 Some(_) => {
47 let mut clause = HashSet::new();
48 let mut lits = line.split_whitespace();
49
50 loop {
51 match lits.next().unwrap().parse().unwrap() {
52 0 => break,
53 i => {clause.insert(i); continue;},
54 }
55 }
56
57 clauses.push(clause
58 .iter()
59 .map(|l| Rc::clone(&literals[l]))
60 .collect::<Vec<_>>());
61 },
62 None => continue,
63 }
64 }
65
66 let mut watchers: HashMap<i32, HashSet<usize>> = literals
67 .iter()
68 .map(|(&i, _)| (i, HashSet::new()))
69 .collect();
70
71 let var_priorities: HashMap<i32, Cell<f64>> = literals
72 .iter()
73 .map(|(&i, _)| (i, Cell::new(1.)))
74 .collect();
75
76 for (i, clause) in clauses.iter().enumerate() {
77 for lit in clause {
78 var_priorities[&lit.to_int()].replace(
79 var_priorities[&lit.to_int()].get() + 1.
80 );
81 }
82
83 for j in 0..std::cmp::min(2, clause.len()) {
84 let lit = clause[j].to_int();
85 watchers.get_mut(&-lit).unwrap().insert(i);
86 }
87 }
88
89 let mut var_order = PriorityQueue::new();
90 for (lit, p) in var_priorities {
91 var_order.push(lit, OrderedFloat::from(p.take()));
92 }
93
94 Solver {
95 literals,
96 watchers,
97 var_order,
98 clauses,
99 cur_watchers: HashMap::new(),
100 cur_var_order: PriorityQueue::new(),
101 decisions: HashMap::new(),
102 i_graph: HashMap::new(),
103 var_inc: 1.,
104 level: 0,
105 }
106 }
107
108 pub fn solve(mut self) -> String {
109 self.restart();
110
111 loop {
112 let conflict = self.propagate();
113 if let Some(lit) = conflict {
114 if self.level == 0 {
115 return "UNSAT".to_owned();
116 } else {
117 self.analyze(lit);
118 self.restart();
119 }
120 } else {
121 if self.satisfied() {
122 return self.model();
123 } else {
124 self.decide();
125 }
126 }
127 }
128 }
129
130 fn propagate(&mut self) -> Option<i32> {
131 let mut unit_literals = HashSet::new();
133 if self.level == 0 {
136 for clause in &self.clauses {
137 if clause.is_unit() {
138 let lit = clause[0].to_int();
139 unit_literals.insert(lit);
140 self.decisions.get_mut(&self.level).unwrap().insert(lit);
141 self.i_graph.insert(lit, (self.level, Vec::new()));
142 }
143 }
144 } else {
145 let lit = self.decisions[&self.level].iter().next().cloned().unwrap();
146 unit_literals.insert(lit);
147 }
148
149 while !unit_literals.is_empty() {
150 let lit = unit_literals.iter().next().cloned().unwrap();
151 unit_literals.remove(&lit);
152 self.literals.get_mut(&lit).unwrap().set_true();
153 self.literals.get_mut(&-lit).unwrap().set_false();
154
155 let indexes = self.cur_watchers[&lit].clone();
156 for i in indexes {
157 let clause = &self.clauses[i];
158 if clause.is_satisfied() {
159 continue;
161 } else if clause.is_unit() {
162 let unit_lit = clause.get_unset().unwrap().to_int();
164 if self.i_graph.contains_key(&unit_lit) {
165 continue
166 }
167 unit_literals.insert(unit_lit);
168
169 self.decisions.get_mut(&self.level).unwrap().insert(unit_lit);
170 let reason = clause
171 .iter()
172 .map(|l| l.to_int())
173 .filter(|&l| l != unit_lit)
174 .map(|l| -l)
175 .collect::<Vec<_>>();
176 self.i_graph.insert(unit_lit, (self.level, reason));
177 } else if clause.is_conflict() {
178 self.decisions.get_mut(&self.level).unwrap().insert(-lit);
180 let reason = clause
181 .iter()
182 .map(|l| l.to_int())
183 .filter(|&l| l != -lit)
184 .map(|l| -l)
185 .collect::<Vec<_>>();
186 self.i_graph.insert(-lit, (self.level, reason));
187 return Some(lit);
188 } else {
189 let mut clause_iter = clause.iter();
191 let l = loop {
192 let l = clause_iter.next().unwrap();
193 if !l.is_unset() {
194 continue;
195 }
196 let l = l.to_int();
197 if self.cur_watchers[&-l].contains(&i) {
198 continue;
199 }
200 break l;
201 };
202
203 self.cur_watchers.get_mut(&lit).unwrap().remove(&i);
204 self.cur_watchers.get_mut(&-l).unwrap().insert(i);
205 }
206 }
207 }
208
209 None
210 }
211
212 fn satisfied(&self) -> bool {
213 self.clauses.iter().all(Clause::is_satisfied)
214 }
215
216 fn restart(&mut self) {
217 for lit in self.literals.values() {
218 lit.unset();
219 }
220 self.cur_watchers = self.watchers.clone();
221 self.cur_var_order = self.var_order.clone();
222 self.decisions.insert(0, HashSet::new());
223 self.i_graph.clear();
224 self.level = 0;
225 }
226
227 fn model(&self) -> String {
228 (1..=self.literals.len() as i32 / 2)
229 .map(|l| if self.i_graph.contains_key(&-l) { -l } else { l })
230 .map(|l| l.to_string())
231 .collect::<Vec<_>>()
232 .join(" ")
233 }
234
235 fn decide(&mut self) {
236 let lit = loop {
237 let next_lit = self.cur_var_order.pop().unwrap().0;
238 if self.literals[&next_lit].is_unset() {
239 break next_lit;
240 }
241 };
242
243 self.level += 1;
244 self.decisions.insert(self.level, [lit].iter().cloned().collect());
245 self.i_graph.insert(lit, (self.level, Vec::new()));
246 }
247
248 fn analyze(&mut self, lit: i32) {
249 let mut uips = HashSet::new();
251 let weights = self.decisions[&self.level]
252 .iter()
253 .map(|&l| (l, Ratio::<i128>::new(0, 1)))
254 .collect::<HashMap<_, _>>();
255 let weights_ref = Rc::new(RefCell::new(weights));
256
257 fn explore(lit: i32,
258 weight: Ratio<i128>,
259 weights: Rc<RefCell<HashMap<i32, Ratio<i128>>>>,
260 i_graph: &HashMap<i32, (usize, Vec<i32>)>,
261 level: usize) {
262 *weights.borrow_mut().get_mut(&lit).unwrap() += weight;
263 let next_lits = i_graph[&lit].1
264 .iter()
265 .filter(|&l| i_graph[l].0 == level)
266 .collect::<Vec<_>>();
267 for &l in &next_lits {
268 explore(*l,
269 weight / next_lits.len() as i128,
270 Rc::clone(&weights),
271 i_graph,
272 level,
273 );
274 }
275 }
276
277 explore(lit,
278 Ratio::new(1, 1),
279 Rc::clone(&weights_ref),
280 &self.i_graph,
281 self.level);
282
283 for l in weights_ref.borrow().keys() {
284 if weights_ref.borrow()[l] == Ratio::new(1, 1) {
285 uips.insert(*l);
286 }
287 }
288 uips.remove(&lit);
289
290 let weights = self.decisions[&self.level]
291 .iter()
292 .map(|&l| (l, Ratio::new(0, 1)))
293 .collect::<HashMap<_, _>>();
294 let weights_ref = Rc::new(RefCell::new(weights));
295
296 explore(-lit,
297 Ratio::new(1, 1),
298 Rc::clone(&weights_ref),
299 &self.i_graph,
300 self.level);
301
302 for l in weights_ref.borrow().keys() {
303 if weights_ref.borrow()[l] == Ratio::new(1, 1) && uips.contains(l) {
304 continue;
305 }
306
307 uips.remove(l);
308 }
309
310 let mut l = lit;
311 let fuip = loop {
312 for &next_l in &self.i_graph[&l].1 {
313 if self.i_graph[&next_l].0 == self.level {
314 l = next_l;
315 break;
316 }
317 }
318 if uips.contains(&l) {
319 break l;
320 }
321 };
322
323 let new_clause = [-fuip].iter().cloned().collect();
325 let new_clause_ref = Rc::new(RefCell::new(new_clause));
326
327 fn find_cut(lit: i32,
328 new_clause: Rc<RefCell<HashSet<i32>>>,
329 i_graph: &HashMap<i32, (usize, Vec<i32>)>,
330 level: usize,
331 fuip: i32) {
332 if i_graph[&lit].0 != level {
333 new_clause.borrow_mut().insert(-lit);
334 return;
335 }
336 if lit == fuip {
337 return;
338 }
339
340 for &l in &i_graph[&lit].1 {
341 find_cut(
342 l,
343 Rc::clone(&new_clause),
344 i_graph,
345 level,
346 fuip
347 );
348 }
349 }
350
351 find_cut(lit, Rc::clone(&new_clause_ref), &self.i_graph, self.level, fuip);
352 find_cut(-lit, Rc::clone(&new_clause_ref), &self.i_graph, self.level, fuip);
353
354 let clause = new_clause_ref.borrow().clone();
356 self.clauses.push(
357 clause
358 .iter()
359 .map(|&l| Rc::clone(&self.literals[&l]))
360 .collect::<Vec<_>>()
361 );
362 let clause_idx = self.clauses.len() - 1;
363 let mut clause_iter = clause.iter();
364 for _ in 0..std::cmp::min(2, clause.len()) {
365 let lit = -clause_iter.next().unwrap();
366 self.watchers
367 .get_mut(&lit)
368 .unwrap()
369 .insert(clause_idx);
370 }
371
372 self.var_inc *= BUMP_FACTOR;
373 for lit in clause {
374 if let None = self.cur_var_order.get_priority(&lit) {
375 continue;
376 }
377
378 let new_p = OrderedFloat::from(self.cur_var_order
379 .get_priority(&lit)
380 .unwrap()
381 .into_inner()
382 + self.var_inc);
383
384 self.cur_var_order.change_priority(
385 &lit,
386 new_p
387 );
388
389 self.var_inc *= BUMP_FACTOR;
390 if new_p.into_inner() * self.var_inc > 1e100 {
391 self.var_inc *= 1e-100;
392
393 for (_, p) in &mut self.cur_var_order {
394 *p = OrderedFloat::from(p.into_inner() * 1e-100);
395 }
396 }
397 }
398 }
399}
400
401trait Clause {
402 fn is_satisfied(&self) -> bool;
403 fn is_conflict(&self) -> bool;
404 fn is_unit(&self) -> bool;
405 fn get_unset(&self) -> Option<Rc<Lit>>;
406}
407
408impl Clause for Vec<Rc<Lit>> {
409 fn is_satisfied(&self) -> bool {
410 self.iter().any(|lit| lit.is_true())
411 }
412
413 fn is_conflict(&self) -> bool {
414 self.iter().all(|lit| lit.is_false())
415 }
416
417 fn is_unit(&self) -> bool {
418 self.iter().filter(|&lit| lit.is_unset()).count() == 1
419 }
420
421 fn get_unset(&self) -> Option<Rc<Lit>> {
422 for lit in self {
423 if lit.is_unset() {
424 return Some(Rc::clone(lit));
425 }
426 }
427
428 None
429 }
430}
431
/// One signed literal of the formula together with its current assignment.
///
/// The assignment lives in a `Cell`, so it can be changed through the shared
/// `Rc<Lit>` handles that the clauses hold.
struct Lit {
    /// Signed literal: `v` for variable `v`, `-v` for its negation.
    lit: i32,
    /// `None` while unassigned, otherwise the literal's truth value.
    value: Cell<Option<bool>>,
}

impl Lit {
    /// Creates an unassigned literal.
    fn new(lit: i32) -> Lit {
        Lit {
            lit,
            value: Cell::new(None),
        }
    }

    fn set_true(&self) {
        self.value.set(Some(true))
    }

    fn set_false(&self) {
        self.value.set(Some(false))
    }

    /// Clears the assignment.
    fn unset(&self) {
        self.value.set(None)
    }

    fn is_true(&self) -> bool {
        self.value.get() == Some(true)
    }

    fn is_false(&self) -> bool {
        self.value.get() == Some(false)
    }

    fn is_unset(&self) -> bool {
        self.value.get().is_none()
    }

    /// The signed integer this literal stands for.
    fn to_int(&self) -> i32 {
        self.lit
    }
}

impl std::fmt::Debug for Lit {
    /// Prints just the signed literal (e.g. `-3`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified: method-call `.fmt` needs a fmt trait in scope, and
        // naming Debug explicitly also rules out Display/Debug ambiguity.
        std::fmt::Debug::fmt(&self.lit, f)
    }
}