splr/cdb/
sls.rs

//! Implementation of Stochastic Local Search
2use {
3    crate::{assign::AssignIF, types::*},
4    std::collections::HashMap,
5};
6
7pub trait StochasticLocalSearchIF {
8    /// returns the decision level of the given assignment and the one of the final assignment.
9    /// Note: the lower level a set of clauses make a conflict at,
10    /// the higher learning rate a solver can keep and the better learnt clauses we will have.
11    /// This would be a better criteria that can be used in CDCL solvers.
12    fn stochastic_local_search(
13        &mut self,
14        asg: &impl AssignIF,
15        start: &mut HashMap<VarId, bool>,
16        limit: usize,
17    ) -> (usize, usize);
18}
19
20impl StochasticLocalSearchIF for ClauseDB {
21    fn stochastic_local_search(
22        &mut self,
23        _asg: &impl AssignIF,
24        assignment: &mut HashMap<VarId, bool>,
25        limit: usize,
26    ) -> (usize, usize) {
27        let mut returns: (usize, usize) = (0, 0);
28        let mut last_flip = self.num_clause;
29        let mut seed = 721_109;
30        for step in 1..=limit {
31            let mut unsat_clauses = 0;
32            // let mut level: DecisionLevel = 0;
33            // CONSIDER: counting only given (permanent) clauses.
34            let mut flip_target: HashMap<VarId, usize> = HashMap::new();
35            let mut target_clause: Option<&Clause> = None;
36            for c in self.clause.iter().skip(1).filter(|c| !c.is_dead()) {
37                // let mut cls_lvl: DecisionLevel = 0;
38                if c.is_falsified(assignment, &mut flip_target) {
39                    unsat_clauses += 1;
40                    // for l in c.lits.iter() {
41                    //     cls_lvl = cls_lvl.max(asg.level(l.vi()));
42                    // }
43                    // level = level.max(cls_lvl);
44                    if target_clause.is_none() || unsat_clauses == step {
45                        target_clause = Some(c);
46                        for l in c.lits.iter() {
47                            flip_target.entry(l.vi()).or_insert(0);
48                        }
49                    }
50                }
51            }
52            if step == 1 {
53                returns.0 = unsat_clauses;
54                // returns.0 = level as usize;
55            }
56            returns.1 = unsat_clauses;
57            // returns.1 = level as usize;
58            if unsat_clauses == 0 || step == limit {
59                break;
60            }
61            seed = ((((!seed & 0x0000_0000_ffff_ffff) * 1_304_003) % 2_003_819)
62                ^ ((!last_flip & 0x0000_0000_ffff_ffff) * seed))
63                % 3_754_873;
64            if let Some(c) = target_clause {
65                let beta: f64 = 3.2 - 2.1 / (1.0 + unsat_clauses as f64).log(2.0);
66                // let beta: f64 = if unsat_clauses <= 3 { 1.0 } else { 3.0 };
67                let factor = |vi| beta.powf(-(*flip_target.get(vi).unwrap() as f64));
68                let vars = c.lits.iter().map(|l| l.vi()).collect::<Vec<_>>();
69                let index = ((seed % 100) as f64 / 100.0) * vars.iter().map(factor).sum::<f64>();
70                let mut sum: f64 = 0.0;
71                for vi in vars.iter() {
72                    sum += factor(vi);
73                    if index <= sum {
74                        assignment.entry(*vi).and_modify(|e| *e = !*e);
75                        last_flip = *vi;
76                        break;
77                    }
78                }
79            } else {
80                break;
81            }
82        }
83        returns
84    }
85}
86
87impl Clause {
88    fn is_falsified(
89        &self,
90        assignment: &HashMap<VarId, bool>,
91        flip_target: &mut HashMap<VarId, usize>,
92    ) -> bool {
93        let mut num_sat = 0;
94        let mut sat_vi = 0;
95        for l in self.iter() {
96            let vi = l.vi();
97            match assignment.get(&vi) {
98                Some(b) if *b == l.as_bool() => {
99                    if num_sat == 1 {
100                        return false;
101                    }
102                    num_sat += 1;
103                    sat_vi = vi;
104                }
105                None => unreachable!(),
106                _ => (),
107            }
108        }
109        if num_sat == 0 {
110            true
111        } else {
112            *flip_target.entry(sat_vi).or_insert(0) += 1;
113            false
114        }
115    }
116}