1use std::io;
12
13use nom::{
14 branch::alt,
15 character::complete::{alphanumeric1, char, multispace1},
16 combinator::{map, recognize},
17 multi::many1,
18 number::complete::double,
19 sequence::{terminated, tuple},
20};
21use rustsat::{
22 clause,
23 instances::{fio::dimacs, ManageVars},
24 types::{RsHashMap, Var},
25 utils,
26};
27
/// Key identifying an auxiliary variable of the encoding.
///
/// Keys are mapped to solver variables lazily by `VarManager::id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum VarId {
    /// `Binary(idx, k)`: the `k`-th binary variable attached to entity `idx`.
    Binary(u32, u32),
    /// `Eq(idx1, idx2, k)`: constrained to be true exactly when the `k`-th
    /// binary variables of `idx1` and `idx2` take the same value.
    Eq(u32, u32, u32),
    /// `Same(idx1, idx2)`: constrained to be true exactly when all `Eq`
    /// variables of the pair `(idx1, idx2)` are true.
    Same(u32, u32),
}
37
38#[derive(Debug, PartialEq, Eq, Clone)]
39struct VarManager {
40 next_var: Var,
41 vars: RsHashMap<VarId, Var>,
42}
43
44impl VarManager {
45 fn id(&mut self, id: VarId) -> Var {
46 if let Some(var) = self.vars.get(&id) {
47 return *var;
48 }
49 let var = self.new_var();
50 self.vars.insert(id, var);
51 var
52 }
53}
54
55impl ManageVars for VarManager {
56 fn new_var(&mut self) -> Var {
57 let v = self.next_var;
58 self.next_var += 1;
59 v
60 }
61
62 fn max_var(&self) -> Option<Var> {
63 if self.next_var == Var::new(0) {
64 None
65 } else {
66 Some(self.next_var - 1)
67 }
68 }
69
70 fn increase_next_free(&mut self, v: Var) -> bool {
71 if v > self.next_var {
72 self.next_var = v;
73 return true;
74 };
75 false
76 }
77
78 fn combine(&mut self, other: Self) {
79 if other.next_var > self.next_var {
80 self.next_var = other.next_var;
81 };
82 }
83
84 fn n_used(&self) -> u32 {
85 self.next_var.idx32()
86 }
87
88 fn forget_from(&mut self, min_var: Var) {
89 if min_var < self.next_var {
90 self.vars.retain(|_, var| *var < min_var);
91 self.next_var = min_var;
92 }
93 }
94}
95
96impl Default for VarManager {
97 fn default() -> Self {
98 Self {
99 next_var: Var::new(0),
100 vars: Default::default(),
101 }
102 }
103}
104
/// Relationship between two entities after a raw similarity value has been
/// mapped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Similarity {
    /// The pair is similar; the payload becomes the weight of a soft clause
    /// asking the pair to be the same.
    Similar(usize),
    /// The pair is dissimilar; the payload becomes the weight of a soft
    /// clause asking the pair not to be the same.
    DisSimilar(usize),
    /// No clauses are generated for the pair.
    DontCare,
    /// Hard constraint: the binary variables of both entities must agree.
    MustLink,
    /// Hard constraint: the binary variables of both entities must differ in
    /// at least one position.
    CannotLink,
}
113
/// Cursor describing which clause the [`Encoding`] iterator emits next.
///
/// The variants are visited in declaration order; within a variant the first
/// two fields are the indices `(idx1, idx2)` of the entity pair currently
/// being processed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Clause {
    /// `(idx1, idx2, k, cidx)`: clause `cidx` (0..4) of the definition of the
    /// `k`-th `Eq` variable of the pair.
    Eq(u32, u32, u32, u8),
    /// `(idx1, idx2, cidx)`: clause `cidx` of the definition of the pair's
    /// `Same` variable.
    Same(u32, u32, u32),
    /// `(idx1, idx2, cidx)`: clause `cidx` of the hard must-link constraint;
    /// `cidx / 2` selects the binary variable, `cidx % 2` the direction of
    /// the implication.
    MustLink(u32, u32, u32),
    /// `(idx1, idx2)`: the single hard cannot-link clause of the pair.
    CannotLink(u32, u32),
    /// `(idx1, idx2)`: the soft clause for a similar pair.
    Similar(u32, u32),
    /// `(idx1, idx2)`: the soft clause for a dissimilar pair.
    DisSimilar(u32, u32),
}
128
129pub struct Encoding {
130 similarities: Vec<Similarity>,
131 n: u32,
132 a: u32,
133 var_manager: VarManager,
134 next_clause: Clause,
135}
136
137impl Encoding {
138 pub fn new<R: io::BufRead, Map: Fn(f64) -> Similarity>(
139 in_reader: R,
140 sim_map: Map,
141 ) -> anyhow::Result<Self> {
142 let mut ident_map = RsHashMap::default();
143 let mut next_idx: u32 = 0;
144 let process_line =
145 |line: Result<String, io::Error>| -> Option<anyhow::Result<(String, String, f64)>> {
146 let line = line.ok()?;
147 let line = line.trim_start();
148 if line.starts_with('%') {
149 return None;
150 }
151 let (_, tup) = tuple((
152 terminated(ident, multispace1),
153 terminated(ident, multispace1),
154 double,
155 ))(line)
156 .ok()?;
157 if !ident_map.contains_key(&tup.0) {
158 ident_map.insert(tup.0.clone(), next_idx);
159 next_idx += 1;
160 }
161 if !ident_map.contains_key(&tup.1) {
162 ident_map.insert(tup.1.clone(), next_idx);
163 next_idx += 1;
164 }
165 Some(Ok(tup))
166 };
167 let pairwise = in_reader
168 .lines()
169 .filter_map(process_line)
170 .collect::<Result<Vec<_>, _>>()?;
171 let n = ident_map.len() as u32;
172
173 let mut similarities = Vec::new();
174 similarities.resize(Self::sim_idx(n - 2, n - 1, n) + 1, Similarity::DontCare);
175 for (ident1, ident2, sim) in pairwise {
176 let mut idx1 = *ident_map.get(&ident1).unwrap();
177 let mut idx2 = *ident_map.get(&ident2).unwrap();
178 if idx1 == idx2 {
179 if sim_map(sim) != Similarity::MustLink {
180 eprintln!(
181 "warning: self-similarity for {} is {} (mapped to {:?})",
182 ident1,
183 sim,
184 sim_map(sim)
185 )
186 }
187 continue;
188 }
189 if idx2 < idx1 {
190 std::mem::swap(&mut idx1, &mut idx2);
191 }
192 similarities[Self::sim_idx(idx1, idx2, n)] = sim_map(sim);
193 }
194
195 let a = utils::digits(ident_map.len(), 2);
196
197 Ok(Self {
198 similarities,
199 n,
200 a,
201 var_manager: Default::default(),
202 next_clause: Clause::Eq(0, 1, 0, 0),
203 })
204 }
205
206 #[inline]
208 fn sim_idx(idx1: u32, idx2: u32, n_idents: u32) -> usize {
209 let idx1 = idx1 as usize;
212 let idx2 = idx2 as usize;
213 let n_idents = n_idents as usize;
214 n_idents * (n_idents - 1) / 2 - (n_idents - idx1) * ((n_idents - idx1) - 1) / 2 + idx2
215 - idx1
216 - 1
217 }
218}
219
220impl Iterator for Encoding {
221 type Item = dimacs::McnfLine;
222
223 fn next(&mut self) -> Option<Self::Item> {
224 loop {
225 match self.next_clause {
226 Clause::Eq(idx1, idx2, k, cidx) => {
227 if idx1 >= self.n - 1 {
228 self.next_clause = Clause::Same(0, 1, 0);
229 continue;
230 }
231 if idx2 >= self.n {
232 self.next_clause = Clause::Eq(idx1 + 1, idx1 + 2, 0, 0);
233 continue;
234 }
235 if k >= self.a {
236 self.next_clause = Clause::Eq(idx1, idx2 + 1, 0, 0);
237 continue;
238 }
239 if cidx >= 4 {
240 self.next_clause = Clause::Eq(idx1, idx2, k + 1, 0);
241 continue;
242 }
243 if matches!(
244 self.similarities[Self::sim_idx(idx1, idx2, self.n)],
245 Similarity::DontCare | Similarity::MustLink
246 ) {
247 self.next_clause = Clause::Eq(idx1, idx2 + 1, 0, 0);
249 continue;
250 }
251 self.next_clause = Clause::Eq(idx1, idx2, k, cidx + 1);
252 let b1 = self.var_manager.id(VarId::Binary(idx1, k)).pos_lit();
253 let b2 = self.var_manager.id(VarId::Binary(idx2, k)).pos_lit();
254 let eql = self.var_manager.id(VarId::Eq(idx1, idx2, k)).pos_lit();
255 return Some(dimacs::McnfLine::Hard(match cidx {
256 0 => clause![eql, b1, b2],
257 1 => clause![eql, !b1, !b2],
258 2 => clause![!eql, !b1, b2],
259 3 => clause![!eql, b1, !b2],
260 _ => panic!(),
261 }));
262 }
263 Clause::Same(idx1, idx2, cidx) => {
264 if idx1 >= self.n - 1 {
265 self.next_clause = Clause::MustLink(0, 1, 0);
266 continue;
267 }
268 if idx2 >= self.n {
269 self.next_clause = Clause::Same(idx1 + 1, idx1 + 2, 0);
270 continue;
271 }
272 if cidx > self.a {
273 self.next_clause = Clause::Same(idx1, idx2 + 1, 0);
274 continue;
275 }
276 if matches!(
277 self.similarities[Self::sim_idx(idx1, idx2, self.n)],
278 Similarity::DontCare | Similarity::MustLink | Similarity::CannotLink
279 ) {
280 self.next_clause = Clause::Same(idx1, idx2 + 1, 0);
282 continue;
283 }
284 self.next_clause = Clause::Same(idx1, idx2, cidx + 1);
285 let sl = self.var_manager.id(VarId::Same(idx1, idx2)).pos_lit();
286 return Some(dimacs::McnfLine::Hard(match cidx {
287 cidx if cidx < self.a => {
288 let eql = self.var_manager.id(VarId::Eq(idx1, idx2, cidx)).pos_lit();
289 clause![!sl, eql]
290 }
291 _ => {
292 let mut cl: rustsat::types::constraints::Clause = (0..self.a)
293 .map(|k| self.var_manager.id(VarId::Eq(idx1, idx2, k)).neg_lit())
294 .collect();
295 cl.add(sl);
296 cl
297 }
298 }));
299 }
300 Clause::MustLink(idx1, idx2, cidx) => {
301 if idx1 >= self.n - 1 {
302 self.next_clause = Clause::CannotLink(0, 1);
303 continue;
304 }
305 if idx2 >= self.n {
306 self.next_clause = Clause::MustLink(idx1 + 1, idx1 + 2, 0);
307 continue;
308 }
309 if cidx >= 2 * self.a {
310 self.next_clause = Clause::MustLink(idx1, idx2 + 1, 0);
311 continue;
312 }
313 if !matches!(
314 self.similarities[Self::sim_idx(idx1, idx2, self.n)],
315 Similarity::MustLink
316 ) {
317 self.next_clause = Clause::MustLink(idx1, idx2 + 1, 0);
319 continue;
320 }
321 self.next_clause = Clause::MustLink(idx1, idx2, cidx + 1);
322 let b1 = self.var_manager.id(VarId::Binary(idx1, cidx / 2)).pos_lit();
323 let b2 = self.var_manager.id(VarId::Binary(idx2, cidx / 2)).pos_lit();
324 return Some(dimacs::McnfLine::Hard(if cidx % 2 == 0 {
325 clause![!b1, b2]
326 } else {
327 clause![b1, !b2]
328 }));
329 }
330 Clause::CannotLink(idx1, idx2) => {
331 if idx1 >= self.n - 1 {
332 self.next_clause = Clause::Similar(0, 1);
333 continue;
334 }
335 if idx2 >= self.n {
336 self.next_clause = Clause::CannotLink(idx1 + 1, idx1 + 2);
337 continue;
338 }
339 if !matches!(
340 self.similarities[Self::sim_idx(idx1, idx2, self.n)],
341 Similarity::CannotLink
342 ) {
343 self.next_clause = Clause::CannotLink(idx1, idx2 + 1);
345 continue;
346 }
347 self.next_clause = Clause::CannotLink(idx1, idx2 + 1);
348 return Some(dimacs::McnfLine::Hard(
349 (0..self.a)
350 .map(|k| self.var_manager.id(VarId::Eq(idx1, idx2, k)).neg_lit())
351 .collect(),
352 ));
353 }
354 Clause::Similar(idx1, idx2) => {
355 if idx1 >= self.n - 1 {
356 self.next_clause = Clause::DisSimilar(0, 1);
357 continue;
358 }
359 if idx2 >= self.n {
360 self.next_clause = Clause::Similar(idx1 + 1, idx1 + 2);
361 continue;
362 }
363 self.next_clause = Clause::Similar(idx1, idx2 + 1);
364 if let Similarity::Similar(weight) =
365 self.similarities[Self::sim_idx(idx1, idx2, self.n)]
366 {
367 return Some(dimacs::McnfLine::Soft(
368 clause![self.var_manager.id(VarId::Same(idx1, idx2)).pos_lit()],
369 weight,
370 0,
371 ));
372 }
373 }
374 Clause::DisSimilar(idx1, idx2) => {
375 if idx1 >= self.n - 1 {
376 return None;
377 }
378 if idx2 >= self.n {
379 self.next_clause = Clause::DisSimilar(idx1 + 1, idx1 + 2);
380 continue;
381 }
382 self.next_clause = Clause::DisSimilar(idx1, idx2 + 1);
383 if let Similarity::DisSimilar(weight) =
384 self.similarities[Self::sim_idx(idx1, idx2, self.n)]
385 {
386 return Some(dimacs::McnfLine::Soft(
387 clause![!self.var_manager.id(VarId::Same(idx1, idx2)).pos_lit()],
388 weight,
389 1,
390 ));
391 }
392 }
393 }
394 }
395 }
396}
397
398fn ident(input: &str) -> nom::IResult<&str, String> {
399 map(
400 recognize(many1(alt((alphanumeric1, recognize(char('_')))))),
401 String::from,
402 )(input)
403}
404
/// Scales `sim` by `multiplier` and converts the result to an integer,
/// discarding the fractional part (rounding toward zero).
pub fn scaling_map(sim: f64, multiplier: u32) -> isize {
    let scaled = sim * f64::from(multiplier);
    scaled.trunc() as isize
}
408
409pub fn saturating_map(sim: isize, dont_care: usize, hard_threshold: usize) -> Similarity {
410 match sim.unsigned_abs() {
411 asim if asim < dont_care => Similarity::DontCare,
412 asim if asim > hard_threshold => {
413 if sim > 0 {
414 Similarity::MustLink
415 } else {
416 Similarity::CannotLink
417 }
418 }
419 _ => {
420 if sim > 0 {
421 Similarity::Similar(sim as usize)
422 } else {
423 Similarity::DisSimilar(-sim as usize)
424 }
425 }
426 }
427}