1use clap::{Parser, ValueEnum};
14use rustsat::{
15 instances::{fio, ManageVars, MultiOptInstance, Objective, OptInstance},
16 types::Clause,
17};
18use std::{
19 collections::BTreeSet,
20 fmt,
21 io::{self, IsTerminal, Write},
22 path::PathBuf,
23};
24use termcolor::{Buffer, BufferWriter, Color, ColorSpec, WriteColor};
25
26struct Cli {
27 in_path: Option<PathBuf>,
28 out_path: Option<PathBuf>,
29 input_format: InputFormat,
30 output_format: OutputFormat,
31 split_alg: SplitAlg,
32 max_combs: usize,
33 always_dump: bool,
34 opb_opts: fio::opb::Options,
35 stdout: BufferWriter,
36 stderr: BufferWriter,
37}
38
39impl Cli {
40 fn init() -> Self {
41 let args = Args::parse();
42 Self {
43 in_path: args.in_path,
44 out_path: args.out_path,
45 input_format: args.input_format,
46 output_format: args.output_format,
47 split_alg: args.split_alg,
48 max_combs: args.max_combs,
49 always_dump: args.always_dump,
50 opb_opts: fio::opb::Options {
51 first_var_idx: args.opb_first_var_idx,
52 ..fio::opb::Options::default()
53 },
54 stdout: BufferWriter::stdout(match args.color.color {
55 concolor_clap::ColorChoice::Always => termcolor::ColorChoice::Always,
56 concolor_clap::ColorChoice::Never => termcolor::ColorChoice::Never,
57 concolor_clap::ColorChoice::Auto => {
58 if io::stdout().is_terminal() {
59 termcolor::ColorChoice::Auto
60 } else {
61 termcolor::ColorChoice::Never
62 }
63 }
64 }),
65 stderr: BufferWriter::stderr(match args.color.color {
66 concolor_clap::ColorChoice::Always => termcolor::ColorChoice::Always,
67 concolor_clap::ColorChoice::Never => termcolor::ColorChoice::Never,
68 concolor_clap::ColorChoice::Auto => {
69 if io::stderr().is_terminal() {
70 termcolor::ColorChoice::Auto
71 } else {
72 termcolor::ColorChoice::Never
73 }
74 }
75 }),
76 }
77 }
78
79 fn warning(&self, msg: &str) {
80 let mut buffer = self.stderr.buffer();
81 buffer
82 .set_color(ColorSpec::new().set_bold(true).set_fg(Some(Color::Yellow)))
83 .unwrap();
84 write!(&mut buffer, "warning").unwrap();
85 buffer.reset().unwrap();
86 buffer.set_color(ColorSpec::new().set_bold(true)).unwrap();
87 write!(&mut buffer, ": ").unwrap();
88 buffer.reset().unwrap();
89 writeln!(&mut buffer, "{msg}").unwrap();
90 self.stdout.print(&buffer).unwrap();
91 }
92
93 fn error(&self, err: &anyhow::Error) {
94 let mut buffer = self.stderr.buffer();
95 buffer
96 .set_color(ColorSpec::new().set_bold(true).set_fg(Some(Color::Red)))
97 .unwrap();
98 write!(&mut buffer, "error").unwrap();
99 buffer.reset().unwrap();
100 buffer.set_color(ColorSpec::new().set_bold(true)).unwrap();
101 write!(&mut buffer, ": ").unwrap();
102 buffer.reset().unwrap();
103 writeln!(&mut buffer, "{err}").unwrap();
104 self.stdout.print(&buffer).unwrap();
105 }
106
107 fn info(&self, msg: &str) {
108 let mut buffer = self.stdout.buffer();
109 buffer
110 .set_color(ColorSpec::new().set_bold(true).set_fg(Some(Color::Blue)))
111 .unwrap();
112 write!(&mut buffer, "info").unwrap();
113 buffer.reset().unwrap();
114 buffer.set_color(ColorSpec::new().set_bold(true)).unwrap();
115 write!(&mut buffer, ": ").unwrap();
116 buffer.reset().unwrap();
117 writeln!(&mut buffer, "{msg}").unwrap();
118 self.stdout.print(&buffer).unwrap();
119 }
120
121 fn print_split_stats(&self, split_stats: SplitStats) {
122 let mut buffer = self.stdout.buffer();
123 Self::start_block(&mut buffer);
124 buffer
125 .set_color(ColorSpec::new().set_bold(true).set_fg(Some(Color::Blue)))
126 .unwrap();
127 write!(&mut buffer, "Split Stats").unwrap();
128 buffer.reset().unwrap();
129 buffer.set_color(ColorSpec::new().set_bold(true)).unwrap();
130 write!(&mut buffer, ": ").unwrap();
131 buffer.reset().unwrap();
132 writeln!(
133 &mut buffer,
134 "split objective into {} separate objectives",
135 split_stats.obj_stats.len()
136 )
137 .unwrap();
138 split_stats
139 .obj_stats
140 .into_iter()
141 .enumerate()
142 .for_each(|(idx, os)| Self::print_obj_stats(&mut buffer, idx + 1, os));
143 Self::end_block(&mut buffer);
144 self.stdout.print(&buffer).unwrap();
145 }
146
147 fn print_obj_stats(buffer: &mut Buffer, idx: usize, stats: ObjStats) {
148 Self::start_block(buffer);
149 buffer
150 .set_color(ColorSpec::new().set_fg(Some(Color::Cyan)))
151 .unwrap();
152 write!(buffer, "Objective").unwrap();
153 buffer.reset().unwrap();
154 writeln!(buffer, " #{idx}").unwrap();
155 Self::print_parameter(buffer, "n-softs", stats.n_softs);
156 Self::print_parameter(buffer, "weight-sum", stats.weight_sum);
157 Self::print_parameter(buffer, "min-weight", stats.min_weight);
158 Self::print_parameter(buffer, "max-weight", stats.max_weight);
159 Self::print_parameter(buffer, "multiplier", stats.multiplier);
160 Self::end_block(buffer);
161 }
162
163 fn print_parameter<V: fmt::Display>(buffer: &mut Buffer, name: &str, val: V) {
164 buffer
165 .set_color(ColorSpec::new().set_fg(Some(Color::Cyan)))
166 .unwrap();
167 write!(buffer, "{name}").unwrap();
168 buffer.reset().unwrap();
169 writeln!(buffer, ": {val}").unwrap();
170 }
171
172 fn start_block(buffer: &mut Buffer) {
173 buffer.set_color(ColorSpec::new().set_dimmed(true)).unwrap();
174 write!(buffer, ">>>>>").unwrap();
175 buffer.reset().unwrap();
176 writeln!(buffer).unwrap();
177 }
178
179 fn end_block(buffer: &mut Buffer) {
180 buffer.set_color(ColorSpec::new().set_dimmed(true)).unwrap();
181 write!(buffer, "<<<<<").unwrap();
182 buffer.reset().unwrap();
183 writeln!(buffer).unwrap();
184 }
185}
186
187#[derive(Parser)]
188#[command(author, version, about, long_about = None)]
189struct Args {
190 in_path: Option<PathBuf>,
192 out_path: Option<PathBuf>,
194 #[arg(long, default_value_t = SplitAlg::default())]
196 split_alg: SplitAlg,
197 #[arg(long, default_value_t = 100000)]
199 max_combs: usize,
200 #[arg(long, default_value_t = InputFormat::default())]
202 input_format: InputFormat,
203 #[arg(long, default_value_t = OutputFormat::default())]
205 output_format: OutputFormat,
206 #[arg(long, short = 'd')]
208 always_dump: bool,
209 #[arg(long, default_value_t = 1)]
211 opb_first_var_idx: u32,
212 #[command(flatten)]
213 color: concolor_clap::Color,
214}
215
216#[derive(Copy, Clone, PartialEq, Eq, ValueEnum, Default)]
217enum InputFormat {
218 #[default]
224 Infer,
225 Wcnf,
227 Opb,
229}
230
231impl fmt::Display for InputFormat {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 match self {
234 InputFormat::Infer => write!(f, "infer"),
235 InputFormat::Wcnf => write!(f, "wcnf"),
236 InputFormat::Opb => write!(f, "opb"),
237 }
238 }
239}
240
241#[derive(Copy, Clone, PartialEq, Eq, ValueEnum, Default)]
242enum OutputFormat {
243 #[default]
245 AsInput,
246 Mcnf,
248 Opb,
250}
251
252impl fmt::Display for OutputFormat {
253 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254 match self {
255 OutputFormat::AsInput => write!(f, "as-input"),
256 OutputFormat::Mcnf => write!(f, "mcnf"),
257 OutputFormat::Opb => write!(f, "opb"),
258 }
259 }
260}
261
262impl OutputFormat {
263 fn infer(self, write_format: WriteFormat) -> WriteFormat {
264 match self {
265 OutputFormat::AsInput => write_format,
266 OutputFormat::Mcnf => WriteFormat::Mcnf,
267 OutputFormat::Opb => WriteFormat::Opb,
268 }
269 }
270}
271
/// Concrete syntax an instance is written in.
///
/// `parse_instance` also uses it to report how the input was read (`Mcnf` for DIMACS/WCNF
/// input, `Opb` for OPB input).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum WriteFormat {
    /// Multi-objective DIMACS (MCNF).
    Mcnf,
    /// OPB (pseudo-Boolean) format.
    Opb,
}
277
278#[derive(ValueEnum, Default, Clone, Copy, PartialEq, Eq)]
279enum SplitAlg {
280 Bmo,
283 Gcd,
287 #[default]
291 Gbmo,
292}
293
294impl fmt::Display for SplitAlg {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 match self {
297 SplitAlg::Bmo => write!(f, "bmo"),
298 SplitAlg::Gcd => write!(f, "gcd"),
299 SplitAlg::Gbmo => write!(f, "gbmo"),
300 }
301 }
302}
303
/// Statistics of a split, with one entry per resulting objective.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SplitStats {
    /// Per-objective statistics, in the order the objectives were produced.
    obj_stats: Vec<ObjStats>,
}

/// Statistics of one objective produced by a split.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ObjStats {
    /// Number of soft clauses in the objective.
    n_softs: usize,
    /// Sum of the objective's weights after dividing by `multiplier`.
    weight_sum: usize,
    /// Smallest weight after dividing by `multiplier`.
    min_weight: usize,
    /// Largest weight after dividing by `multiplier`.
    max_weight: usize,
    /// Factor the original weights were divided by (1 if nothing was scaled).
    multiplier: usize,
}
315
316fn split<VM: ManageVars>(
317 so_inst: OptInstance<VM>,
318 cli: &Cli,
319) -> (MultiOptInstance<VM>, SplitStats) {
320 let (constr, obj) = so_inst.decompose();
321
322 if !obj.weighted() {
323 cli.warning("objective is unweighted, can't split");
324 let obj_stats = ObjStats {
325 n_softs: obj.n_softs(),
326 weight_sum: obj.weight_sum(),
327 min_weight: obj.min_weight(),
328 max_weight: obj.max_weight(),
329 multiplier: 1,
330 };
331 return (
332 MultiOptInstance::compose(constr, vec![obj]),
333 SplitStats {
334 obj_stats: vec![obj_stats],
335 },
336 );
337 }
338
339 let (softs, offset) = obj.into_soft_cls();
340
341 if offset != 0 {
342 cli.warning(&format!(
343 "objective offset is not zero ({offset}), will be added to the lowest ranking objective"
344 ));
345 }
346
347 let mut sorted_clauses: Vec<_> = softs.into_iter().collect();
348 sorted_clauses.sort_by(|wc1, wc2| wc1.1.cmp(&wc2.1));
349
350 let (mut objs, split_stats) = match cli.split_alg {
351 SplitAlg::Bmo => split_bmo(sorted_clauses),
352 SplitAlg::Gcd => split_gbmo(sorted_clauses, cli),
353 SplitAlg::Gbmo => split_gbmo(sorted_clauses, cli),
354 };
355
356 objs.last_mut().unwrap().set_offset(offset);
358
359 (MultiOptInstance::compose(constr, objs), split_stats)
360}
361
362fn perform_split(
363 sorted_clauses: Vec<(Clause, usize)>,
364 split_ends: Vec<usize>,
365) -> (Vec<Objective>, SplitStats) {
366 let mut objs = vec![];
368 let mut split_start = 0;
369 let mut split_stats = SplitStats { obj_stats: vec![] };
370 for split_end in split_ends {
371 let softs = &sorted_clauses[split_start..split_end + 1];
372 let w_gcd = softs
373 .iter()
374 .fold(softs[0].1, |w_gcd, (_, w)| gcd(w_gcd, *w));
375 let obj = Objective::from_iter(softs.iter().cloned().map(|(c, w)| (c, w / w_gcd)));
376 split_stats.obj_stats.push(ObjStats {
377 n_softs: obj.n_softs(),
378 weight_sum: obj.weight_sum(),
379 min_weight: obj.min_weight(),
380 max_weight: obj.max_weight(),
381 multiplier: w_gcd,
382 });
383 objs.push(obj);
384 split_start = split_end + 1;
385 }
386 let softs = &sorted_clauses[split_start..];
387 let w_gcd = softs
388 .iter()
389 .fold(softs[0].1, |w_gcd, (_, w)| gcd(w_gcd, *w));
390 let obj = Objective::from_iter(softs.iter().cloned().map(|(c, w)| (c, w / w_gcd)));
391 split_stats.obj_stats.push(ObjStats {
392 n_softs: obj.n_softs(),
393 weight_sum: obj.weight_sum(),
394 min_weight: obj.min_weight(),
395 max_weight: obj.max_weight(),
396 multiplier: w_gcd,
397 });
398 objs.push(obj);
399 (objs, split_stats)
400}
401
402fn split_bmo(sorted_clauses: Vec<(Clause, usize)>) -> (Vec<Objective>, SplitStats) {
403 let mut multipliers = vec![sorted_clauses.first().unwrap().1];
404 let mut split_ends = vec![];
405 let mut sum = 0;
406 for (idx, (_, w)) in sorted_clauses.iter().enumerate() {
407 if w > multipliers.last().unwrap() {
408 if *w <= sum {
409 split_ends.clear();
411 break;
412 } else {
413 multipliers.push(*w);
414 split_ends.push(idx - 1);
415 }
416 }
417 sum += *w;
418 }
419 perform_split(sorted_clauses, split_ends)
420}
421
/// Greatest common divisor via Euclid's algorithm.
///
/// `gcd(x, 0) == gcd(0, x) == x`; in particular `gcd(0, 0) == 0`.
fn gcd(a: usize, b: usize) -> usize {
    if b == 0 {
        a
    } else {
        gcd(b, a % b)
    }
}
430
431fn get_sums_pot_splits_gcds(
432 sorted_clauses: &[(Clause, usize)],
433) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
434 let mut sums = vec![];
435 let mut pot_split_ends = vec![];
436 let mut sum = 0;
437 for idx in 0..sorted_clauses.len() {
439 sum += sorted_clauses[idx].1;
440 sums.push(sum);
441 if idx < sorted_clauses.len() - 1 && sorted_clauses[idx + 1].1 > sum {
442 pot_split_ends.push(idx);
443 }
444 }
445 let mut gcds = vec![sorted_clauses.last().unwrap().1];
447 for (_, w) in sorted_clauses.iter().rev().skip(1) {
448 gcds.push(gcd(*gcds.last().unwrap(), *w));
449 }
450 (sums, pot_split_ends, gcds.into_iter().rev().collect())
451}
452
453fn check_split_thorough_gbmo(
454 right_partition: &[(Clause, usize)],
455 left_sum: usize,
456 cli: &Cli,
457) -> bool {
458 for idx in 0..right_partition.len() - 1 {
460 let dist = right_partition[idx + 1].1 - right_partition[idx].1;
461 if dist != 0 && dist < left_sum {
462 return false;
463 }
464 }
465 let right_sum = right_partition.iter().fold(0, |s, (_, w)| s + w);
467 let mut all_weight_combs: BTreeSet<usize> = BTreeSet::new();
468 for (_, w) in right_partition {
469 let w = *w;
470 for weight_comb in all_weight_combs.iter().copied().collect::<Vec<_>>() {
473 let new_comb = weight_comb + w;
474 if !all_weight_combs.insert(new_comb) {
475 continue;
477 }
478 let next_lower = all_weight_combs.range(0..new_comb).last().unwrap();
479 if new_comb - *next_lower <= left_sum {
480 return false;
482 }
483 if let Some(next_higher) = all_weight_combs.range(new_comb + 1..right_sum + 1).next() {
484 if next_higher - new_comb <= left_sum {
485 return false;
487 }
488 }
489 if all_weight_combs.len() > cli.max_combs {
490 cli.warning(&format!(
491 "thorough GBMO check terminated after {} checked weight combinations",
492 all_weight_combs.len()
493 ));
494 return false;
495 }
496 }
497 if !all_weight_combs.insert(w) {
498 continue;
500 }
501 if let Some(next_lower) = all_weight_combs.range(0..w).last() {
502 if w - *next_lower <= left_sum {
503 return false;
505 }
506 }
507 if let Some(next_higher) = all_weight_combs.range(w + 1..right_sum).next() {
508 if next_higher - w <= left_sum {
509 return false;
511 }
512 }
513 if all_weight_combs.len() > cli.max_combs {
514 cli.warning(&format!(
515 "thorough GBMO check terminated after {} checked weight combinations",
516 all_weight_combs.len()
517 ));
518 return false;
519 }
520 }
521 true
522}
523
524fn split_gbmo(sorted_clauses: Vec<(Clause, usize)>, cli: &Cli) -> (Vec<Objective>, SplitStats) {
525 let (sums, pot_split_ends, gcds) = get_sums_pot_splits_gcds(&sorted_clauses);
526 let mut split_ends = vec![];
527 for split_end in pot_split_ends {
528 if sums[split_end] < gcds[split_end + 1]
530 || (cli.split_alg == SplitAlg::Gbmo
531 && check_split_thorough_gbmo(
532 &sorted_clauses[split_end + 1..],
533 sums[split_end],
534 cli,
535 ))
536 {
537 split_ends.push(split_end);
538 }
539 }
540 perform_split(sorted_clauses, split_ends)
541}
542
/// Evaluates to `true` if the first expression equals any of the following ones, comparing
/// left to right with short-circuiting.
macro_rules! is_one_of {
    ($a:expr, $($b:expr),*) => {
        false $(|| $a == $b)*
    };
}
548
549fn parse_instance(
550 path: &Option<PathBuf>,
551 file_format: InputFormat,
552 opb_opts: fio::opb::Options,
553) -> anyhow::Result<(OptInstance, WriteFormat)> {
554 match file_format {
555 InputFormat::Infer => {
556 if let Some(path) = path {
557 if let Some(ext) = path.extension() {
558 let path_without_compr = path.with_extension("");
559 let ext = if is_one_of!(ext, "gz", "bz2", "xz") {
560 match path_without_compr.extension() {
562 Some(ext) => ext,
563 None => anyhow::bail!("no file extension after compression extension"),
564 }
565 } else {
566 ext
567 };
568 if is_one_of!(ext, "wcnf") {
569 OptInstance::from_dimacs_path(path).map(|inst| (inst, WriteFormat::Mcnf))
570 } else if is_one_of!(ext, "opb") {
571 OptInstance::from_opb_path(path, opb_opts)
572 .map(|inst| (inst, WriteFormat::Opb))
573 } else {
574 anyhow::bail!("unknown file extension")
575 }
576 } else {
577 anyhow::bail!("no file extension")
578 }
579 } else {
580 anyhow::bail!("cannot infer file format from stdin")
581 }
582 }
583 InputFormat::Wcnf => {
584 if let Some(path) = path {
585 OptInstance::from_dimacs_path(path).map(|inst| (inst, WriteFormat::Mcnf))
586 } else {
587 OptInstance::from_dimacs(&mut io::BufReader::new(io::stdin()))
588 .map(|inst| (inst, WriteFormat::Mcnf))
589 }
590 }
591 InputFormat::Opb => {
592 if let Some(path) = path {
593 OptInstance::from_opb_path(path, opb_opts).map(|inst| (inst, WriteFormat::Opb))
594 } else {
595 OptInstance::from_opb(&mut io::BufReader::new(io::stdin()), opb_opts)
596 .map(|inst| (inst, WriteFormat::Opb))
597 }
598 }
599 }
600}
601
/// Unwraps an `anyhow::Result`. On error, the error is reported through `$cli.error` and then
/// returned from the enclosing function.
macro_rules! handle_error {
    ($res:expr, $cli:expr) => {{
        let result = $res;
        match result {
            Ok(value) => value,
            Err(error) => {
                $cli.error(&error);
                return Err(error.into());
            }
        }
    }};
}
613
614fn main() -> anyhow::Result<()> {
615 let cli = Cli::init();
616
617 if let Some(path) = &cli.in_path {
618 cli.info(&format!("finding splits in {}", path.display()));
619 }
620
621 let (so_inst, write_format) = handle_error!(
622 parse_instance(&cli.in_path, cli.input_format, cli.opb_opts),
623 cli
624 );
625
626 let (mut mo_inst, split_stats) = split(so_inst, &cli);
627
628 if cli.out_path.is_some() {
629 cli.print_split_stats(split_stats);
630 }
631
632 let found_split = mo_inst.n_objectives() > 1;
633
634 let write_format = cli.output_format.infer(write_format);
635
636 if found_split || cli.always_dump {
637 match write_format {
638 WriteFormat::Mcnf => {
639 mo_inst.constraints_mut().convert_to_cnf();
640 if let Some(path) = &cli.out_path {
641 cli.info(&format!("writing mcnf to {}", path.display()));
642 handle_error!(mo_inst.write_dimacs_path(path), cli);
643 } else {
644 handle_error!(
645 mo_inst.write_dimacs(&mut io::BufWriter::new(io::stdout())),
646 cli
647 );
648 }
649 }
650 WriteFormat::Opb => {
651 let (mut constrs, mut objs) = mo_inst.decompose();
652 for obj in &mut objs {
653 obj.convert_to_soft_lits(constrs.var_manager_mut());
654 }
655 let mo_inst = MultiOptInstance::compose(constrs, objs);
656 if let Some(path) = &cli.out_path {
657 cli.info(&format!("writing opb to {}", path.display()));
658 handle_error!(mo_inst.write_opb_path(path, cli.opb_opts), cli);
659 } else {
660 handle_error!(
661 mo_inst.write_opb(&mut io::BufWriter::new(io::stdout()), cli.opb_opts),
662 cli
663 );
664 }
665 }
666 }
667 }
668
669 if found_split {
670 std::process::exit(0);
671 }
672 std::process::exit(1);
673}
674
#[cfg(test)]
mod tests {
    use rustsat::{clause, lit};

    /// Four unit soft clauses with weights `[1, 1, 3, 3]`, already sorted by weight.
    fn sorted_sample() -> Vec<(super::Clause, usize)> {
        vec![
            (clause![lit![0]], 1),
            (clause![lit![1]], 1),
            (clause![lit![2]], 3),
            (clause![lit![3]], 3),
        ]
    }

    #[test]
    fn split_bmo() {
        let (objs, split_stats) = super::split_bmo(sorted_sample());
        assert_eq!(objs.len(), 2);

        // (n_softs, weight_sum, min_weight, max_weight, multiplier) per objective
        let expected = [(2, 2, 1, 1, 1), (2, 2, 1, 1, 3)];
        assert_eq!(split_stats.obj_stats.len(), expected.len());
        for (stats, &(n_softs, weight_sum, min_weight, max_weight, multiplier)) in
            split_stats.obj_stats.iter().zip(expected.iter())
        {
            assert_eq!(stats.n_softs, n_softs);
            assert_eq!(stats.weight_sum, weight_sum);
            assert_eq!(stats.min_weight, min_weight);
            assert_eq!(stats.max_weight, max_weight);
            assert_eq!(stats.multiplier, multiplier);
        }
    }

    #[test]
    fn gbmo_pre_calc() {
        let clauses = sorted_sample();
        let (sums, pot_split_ends, gcds) = super::get_sums_pot_splits_gcds(&clauses);
        assert_eq!(sums, vec![1, 2, 5, 8]);
        assert_eq!(pot_split_ends, vec![1]);
        assert_eq!(gcds, vec![1, 1, 3, 3]);
    }
}
715}