Skip to main content

rustsat_tools/encodings/cnf/
clustering.rs

1//! # Constrained Correlation Clustering
2//!
3//! Constrained correlation clustering encodings following \[1\].
4//!
5//! ## References
6//!
7//! - \[1\] Jeremias Berg and Matti Järvisalo: _Cost-optimal constrained
8//!   correlation clustering via weighted partial Maximum Satisfiability_, AIJ
9//!   2017.
10
11use std::io;
12
13use nom::{
14    branch::alt,
15    character::complete::{alphanumeric1, char, multispace1},
16    combinator::{map, recognize},
17    multi::many1,
18    number::complete::double,
19    sequence::{terminated, tuple},
20};
21use rustsat::{
22    clause,
23    instances::{fio::dimacs, ManageVars},
24    types::{RsHashMap, Var},
25    utils,
26};
27
/// Key under which an auxiliary variable of the encoding is registered in the
/// [`VarManager`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum VarId {
    /// `b i k`: binary variable number `k` of data point `i`
    Binary(u32, u32),
    /// `EQ i j k`: constrained to hold exactly when `b i k` and `b j k` take the same value
    Eq(u32, u32, u32),
    /// `S i j`: constrained to hold exactly when all `EQ i j k` hold
    Same(u32, u32),
}
37
38#[derive(Debug, PartialEq, Eq, Clone)]
39struct VarManager {
40    next_var: Var,
41    vars: RsHashMap<VarId, Var>,
42}
43
44impl VarManager {
45    fn id(&mut self, id: VarId) -> Var {
46        if let Some(var) = self.vars.get(&id) {
47            return *var;
48        }
49        let var = self.new_var();
50        self.vars.insert(id, var);
51        var
52    }
53}
54
55impl ManageVars for VarManager {
56    fn new_var(&mut self) -> Var {
57        let v = self.next_var;
58        self.next_var += 1;
59        v
60    }
61
62    fn max_var(&self) -> Option<Var> {
63        if self.next_var == Var::new(0) {
64            None
65        } else {
66            Some(self.next_var - 1)
67        }
68    }
69
70    fn increase_next_free(&mut self, v: Var) -> bool {
71        if v > self.next_var {
72            self.next_var = v;
73            return true;
74        };
75        false
76    }
77
78    fn combine(&mut self, other: Self) {
79        if other.next_var > self.next_var {
80            self.next_var = other.next_var;
81        };
82    }
83
84    fn n_used(&self) -> u32 {
85        self.next_var.idx32()
86    }
87
88    fn forget_from(&mut self, min_var: Var) {
89        if min_var < self.next_var {
90            self.vars.retain(|_, var| *var < min_var);
91            self.next_var = min_var;
92        }
93    }
94}
95
96impl Default for VarManager {
97    fn default() -> Self {
98        Self {
99            next_var: Var::new(0),
100            vars: Default::default(),
101        }
102    }
103}
104
/// Classification of the similarity of one pair of data points.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Similarity {
    /// Soft preference for putting the pair in the same cluster; the payload is the
    /// weight of the corresponding soft clause
    Similar(usize),
    /// Soft preference for separating the pair; the payload is the weight of the
    /// corresponding soft clause
    DisSimilar(usize),
    /// No clauses are generated for the pair
    DontCare,
    /// Hard constraint: both points get identical binary variables
    MustLink,
    /// Hard constraint: the points must differ in at least one binary variable
    CannotLink,
}
113
/// Cursor naming the next clause the [`Encoding`] iterator will try to emit.
///
/// The first two fields of every variant are the indices `i < j` of a pair of data
/// points. Trailing fields select a binary position and/or the clause within the group
/// generated for that pair.
enum Clause {
    /// `EQUALITY i j k c`: clause number `c` (`0..=3`) of the four clauses defining
    /// `EQ i j k`
    Eq(u32, u32, u32, u8),
    /// `SameCluster i j c`: clause number `c` (`0..=a`) of the clauses defining `S i j`
    Same(u32, u32, u32),
    /// `ML i j c`: clause number `c` (`0..2a`, upper bound exclusive) of the must-link
    /// clauses
    MustLink(u32, u32, u32),
    /// `CL i j`: the single cannot-link clause of the pair
    CannotLink(u32, u32),
    /// `S i j`: soft clause for a similar pair
    Similar(u32, u32),
    /// `-S i j`: soft clause for a dissimilar pair
    DisSimilar(u32, u32),
}
128
/// Lazily generated clustering encoding; iterating over it yields the encoding as
/// [`dimacs::McnfLine`]s.
pub struct Encoding {
    /// Similarity of every unordered pair of data points, stored as a flattened strict
    /// upper triangle indexed through `Encoding::sim_idx`
    similarities: Vec<Similarity>,
    /// Number of distinct identifiers (data points) read from the input
    n: u32,
    /// Number of binary variables per data point (`utils::digits` of the number of
    /// identifiers in base 2)
    a: u32,
    /// Allocator for the variables used in the generated clauses
    var_manager: VarManager,
    /// Cursor naming the next clause to generate
    next_clause: Clause,
}
136
137impl Encoding {
138    pub fn new<R: io::BufRead, Map: Fn(f64) -> Similarity>(
139        in_reader: R,
140        sim_map: Map,
141    ) -> anyhow::Result<Self> {
142        let mut ident_map = RsHashMap::default();
143        let mut next_idx: u32 = 0;
144        let process_line =
145            |line: Result<String, io::Error>| -> Option<anyhow::Result<(String, String, f64)>> {
146                let line = line.ok()?;
147                let line = line.trim_start();
148                if line.starts_with('%') {
149                    return None;
150                }
151                let (_, tup) = tuple((
152                    terminated(ident, multispace1),
153                    terminated(ident, multispace1),
154                    double,
155                ))(line)
156                .ok()?;
157                if !ident_map.contains_key(&tup.0) {
158                    ident_map.insert(tup.0.clone(), next_idx);
159                    next_idx += 1;
160                }
161                if !ident_map.contains_key(&tup.1) {
162                    ident_map.insert(tup.1.clone(), next_idx);
163                    next_idx += 1;
164                }
165                Some(Ok(tup))
166            };
167        let pairwise = in_reader
168            .lines()
169            .filter_map(process_line)
170            .collect::<Result<Vec<_>, _>>()?;
171        let n = ident_map.len() as u32;
172
173        let mut similarities = Vec::new();
174        similarities.resize(Self::sim_idx(n - 2, n - 1, n) + 1, Similarity::DontCare);
175        for (ident1, ident2, sim) in pairwise {
176            let mut idx1 = *ident_map.get(&ident1).unwrap();
177            let mut idx2 = *ident_map.get(&ident2).unwrap();
178            if idx1 == idx2 {
179                if sim_map(sim) != Similarity::MustLink {
180                    eprintln!(
181                        "warning: self-similarity for {} is {} (mapped to {:?})",
182                        ident1,
183                        sim,
184                        sim_map(sim)
185                    )
186                }
187                continue;
188            }
189            if idx2 < idx1 {
190                std::mem::swap(&mut idx1, &mut idx2);
191            }
192            similarities[Self::sim_idx(idx1, idx2, n)] = sim_map(sim);
193        }
194
195        let a = utils::digits(ident_map.len(), 2);
196
197        Ok(Self {
198            similarities,
199            n,
200            a,
201            var_manager: Default::default(),
202            next_clause: Clause::Eq(0, 1, 0, 0),
203        })
204    }
205
206    /// Note: the the first index must be smaller than the second
207    #[inline]
208    fn sim_idx(idx1: u32, idx2: u32, n_idents: u32) -> usize {
209        // indexing an upper triangular matrix compactly in a vec
210        // https://stackoverflow.com/questions/27086195/linear-index-upper-triangular-matrix
211        let idx1 = idx1 as usize;
212        let idx2 = idx2 as usize;
213        let n_idents = n_idents as usize;
214        n_idents * (n_idents - 1) / 2 - (n_idents - idx1) * ((n_idents - idx1) - 1) / 2 + idx2
215            - idx1
216            - 1
217    }
218}
219
220impl Iterator for Encoding {
221    type Item = dimacs::McnfLine;
222
223    fn next(&mut self) -> Option<Self::Item> {
224        loop {
225            match self.next_clause {
226                Clause::Eq(idx1, idx2, k, cidx) => {
227                    if idx1 >= self.n - 1 {
228                        self.next_clause = Clause::Same(0, 1, 0);
229                        continue;
230                    }
231                    if idx2 >= self.n {
232                        self.next_clause = Clause::Eq(idx1 + 1, idx1 + 2, 0, 0);
233                        continue;
234                    }
235                    if k >= self.a {
236                        self.next_clause = Clause::Eq(idx1, idx2 + 1, 0, 0);
237                        continue;
238                    }
239                    if cidx >= 4 {
240                        self.next_clause = Clause::Eq(idx1, idx2, k + 1, 0);
241                        continue;
242                    }
243                    if matches!(
244                        self.similarities[Self::sim_idx(idx1, idx2, self.n)],
245                        Similarity::DontCare | Similarity::MustLink
246                    ) {
247                        // Don't need equality vars for don't cares and must links
248                        self.next_clause = Clause::Eq(idx1, idx2 + 1, 0, 0);
249                        continue;
250                    }
251                    self.next_clause = Clause::Eq(idx1, idx2, k, cidx + 1);
252                    let b1 = self.var_manager.id(VarId::Binary(idx1, k)).pos_lit();
253                    let b2 = self.var_manager.id(VarId::Binary(idx2, k)).pos_lit();
254                    let eql = self.var_manager.id(VarId::Eq(idx1, idx2, k)).pos_lit();
255                    return Some(dimacs::McnfLine::Hard(match cidx {
256                        0 => clause![eql, b1, b2],
257                        1 => clause![eql, !b1, !b2],
258                        2 => clause![!eql, !b1, b2],
259                        3 => clause![!eql, b1, !b2],
260                        _ => panic!(),
261                    }));
262                }
263                Clause::Same(idx1, idx2, cidx) => {
264                    if idx1 >= self.n - 1 {
265                        self.next_clause = Clause::MustLink(0, 1, 0);
266                        continue;
267                    }
268                    if idx2 >= self.n {
269                        self.next_clause = Clause::Same(idx1 + 1, idx1 + 2, 0);
270                        continue;
271                    }
272                    if cidx > self.a {
273                        self.next_clause = Clause::Same(idx1, idx2 + 1, 0);
274                        continue;
275                    }
276                    if matches!(
277                        self.similarities[Self::sim_idx(idx1, idx2, self.n)],
278                        Similarity::DontCare | Similarity::MustLink | Similarity::CannotLink
279                    ) {
280                        // Don't need same cluster vars for don't cares, must links, and cannot links
281                        self.next_clause = Clause::Same(idx1, idx2 + 1, 0);
282                        continue;
283                    }
284                    self.next_clause = Clause::Same(idx1, idx2, cidx + 1);
285                    let sl = self.var_manager.id(VarId::Same(idx1, idx2)).pos_lit();
286                    return Some(dimacs::McnfLine::Hard(match cidx {
287                        cidx if cidx < self.a => {
288                            let eql = self.var_manager.id(VarId::Eq(idx1, idx2, cidx)).pos_lit();
289                            clause![!sl, eql]
290                        }
291                        _ => {
292                            let mut cl: rustsat::types::constraints::Clause = (0..self.a)
293                                .map(|k| self.var_manager.id(VarId::Eq(idx1, idx2, k)).neg_lit())
294                                .collect();
295                            cl.add(sl);
296                            cl
297                        }
298                    }));
299                }
300                Clause::MustLink(idx1, idx2, cidx) => {
301                    if idx1 >= self.n - 1 {
302                        self.next_clause = Clause::CannotLink(0, 1);
303                        continue;
304                    }
305                    if idx2 >= self.n {
306                        self.next_clause = Clause::MustLink(idx1 + 1, idx1 + 2, 0);
307                        continue;
308                    }
309                    if cidx >= 2 * self.a {
310                        self.next_clause = Clause::MustLink(idx1, idx2 + 1, 0);
311                        continue;
312                    }
313                    if !matches!(
314                        self.similarities[Self::sim_idx(idx1, idx2, self.n)],
315                        Similarity::MustLink
316                    ) {
317                        // Only need must link clauses for must links
318                        self.next_clause = Clause::MustLink(idx1, idx2 + 1, 0);
319                        continue;
320                    }
321                    self.next_clause = Clause::MustLink(idx1, idx2, cidx + 1);
322                    let b1 = self.var_manager.id(VarId::Binary(idx1, cidx / 2)).pos_lit();
323                    let b2 = self.var_manager.id(VarId::Binary(idx2, cidx / 2)).pos_lit();
324                    return Some(dimacs::McnfLine::Hard(if cidx % 2 == 0 {
325                        clause![!b1, b2]
326                    } else {
327                        clause![b1, !b2]
328                    }));
329                }
330                Clause::CannotLink(idx1, idx2) => {
331                    if idx1 >= self.n - 1 {
332                        self.next_clause = Clause::Similar(0, 1);
333                        continue;
334                    }
335                    if idx2 >= self.n {
336                        self.next_clause = Clause::CannotLink(idx1 + 1, idx1 + 2);
337                        continue;
338                    }
339                    if !matches!(
340                        self.similarities[Self::sim_idx(idx1, idx2, self.n)],
341                        Similarity::CannotLink
342                    ) {
343                        // Only need cannot link clauses for cannot links
344                        self.next_clause = Clause::CannotLink(idx1, idx2 + 1);
345                        continue;
346                    }
347                    self.next_clause = Clause::CannotLink(idx1, idx2 + 1);
348                    return Some(dimacs::McnfLine::Hard(
349                        (0..self.a)
350                            .map(|k| self.var_manager.id(VarId::Eq(idx1, idx2, k)).neg_lit())
351                            .collect(),
352                    ));
353                }
354                Clause::Similar(idx1, idx2) => {
355                    if idx1 >= self.n - 1 {
356                        self.next_clause = Clause::DisSimilar(0, 1);
357                        continue;
358                    }
359                    if idx2 >= self.n {
360                        self.next_clause = Clause::Similar(idx1 + 1, idx1 + 2);
361                        continue;
362                    }
363                    self.next_clause = Clause::Similar(idx1, idx2 + 1);
364                    if let Similarity::Similar(weight) =
365                        self.similarities[Self::sim_idx(idx1, idx2, self.n)]
366                    {
367                        return Some(dimacs::McnfLine::Soft(
368                            clause![self.var_manager.id(VarId::Same(idx1, idx2)).pos_lit()],
369                            weight,
370                            0,
371                        ));
372                    }
373                }
374                Clause::DisSimilar(idx1, idx2) => {
375                    if idx1 >= self.n - 1 {
376                        return None;
377                    }
378                    if idx2 >= self.n {
379                        self.next_clause = Clause::DisSimilar(idx1 + 1, idx1 + 2);
380                        continue;
381                    }
382                    self.next_clause = Clause::DisSimilar(idx1, idx2 + 1);
383                    if let Similarity::DisSimilar(weight) =
384                        self.similarities[Self::sim_idx(idx1, idx2, self.n)]
385                    {
386                        return Some(dimacs::McnfLine::Soft(
387                            clause![!self.var_manager.id(VarId::Same(idx1, idx2)).pos_lit()],
388                            weight,
389                            1,
390                        ));
391                    }
392                }
393            }
394        }
395    }
396}
397
398fn ident(input: &str) -> nom::IResult<&str, String> {
399    map(
400        recognize(many1(alt((alphanumeric1, recognize(char('_')))))),
401        String::from,
402    )(input)
403}
404
/// Multiplies `sim` by `multiplier` and truncates the product towards zero.
///
/// The final conversion is a saturating `as` cast: values beyond the `isize` range are
/// clamped and NaN becomes `0`.
pub fn scaling_map(sim: f64, multiplier: u32) -> isize {
    let scaled = sim * f64::from(multiplier);
    scaled.trunc() as isize
}
408
409pub fn saturating_map(sim: isize, dont_care: usize, hard_threshold: usize) -> Similarity {
410    match sim.unsigned_abs() {
411        asim if asim < dont_care => Similarity::DontCare,
412        asim if asim > hard_threshold => {
413            if sim > 0 {
414                Similarity::MustLink
415            } else {
416                Similarity::CannotLink
417            }
418        }
419        _ => {
420            if sim > 0 {
421                Similarity::Similar(sim as usize)
422            } else {
423                Similarity::DisSimilar(-sim as usize)
424            }
425        }
426    }
427}