smt2patch/
lib.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! The `smt2patch` library provides the SMT2 "patching" functionality and
//! configuration used by the binary tool `smt2patch`.
6
7use smt2parser::{
8    concrete::*,
9    visitors::{CommandVisitor, TermVisitor},
10};
11use std::{
12    collections::{HashMap, HashSet},
13    io::Write,
14    path::Path,
15    str::FromStr,
16};
17use structopt::StructOpt;
18
19/// Configuration for the SMT2 rewriting operations.
20#[derive(Debug, Clone, StructOpt)]
21pub struct RewriterConfig {
22    #[structopt(long, parse(from_str = parse_clauses))]
23    keep_only_clauses: Option<HashSet<String>>,
24
25    #[structopt(long)]
26    get_unsat_core: bool,
27
28    #[structopt(long)]
29    tag_quantifiers: bool,
30
31    #[structopt(long, parse(try_from_str = try_parse_weights))]
32    set_weights: Option<HashMap<String, usize>>,
33}
34
/// Parses the argument of `--keep-only-clauses`: a whitespace-separated list of
/// clause names, optionally wrapped in one pair of parentheses, e.g.
/// `"(clause!0 clause!3)"`.
///
/// Runs of whitespace (spaces, tabs, newlines) count as a single separator, so no
/// empty names end up in the result; an empty list yields an empty set.
fn parse_clauses(src: &str) -> HashSet<String> {
    let src = src.trim();
    // Drop the enclosing parentheses only when both are present.
    let src = src
        .strip_prefix('(')
        .and_then(|inner| inner.strip_suffix(')'))
        .unwrap_or(src);
    src.split_whitespace().map(String::from).collect()
}
44
/// Parses the argument of `--set-weights`: a whitespace-separated list of
/// `qid=weight` entries, e.g. `"quant!0=3 my_qid=10"`.
///
/// An entry without `=` maps its name to weight `0`. Runs of whitespace count as
/// a single separator, so no entry with an empty name is produced.
///
/// # Errors
///
/// Returns the [`std::num::ParseIntError`] of the first weight that is not a
/// valid `usize` (including an empty weight such as `"q="`).
fn try_parse_weights(src: &str) -> Result<HashMap<String, usize>, std::num::ParseIntError> {
    src.split_whitespace()
        .map(|entry| {
            let mut parts = entry.splitn(2, '=');
            // `splitn` always yields at least one item, so a name is always present.
            let name = parts.next().unwrap_or_default();
            let weight: usize = match parts.next() {
                Some(text) => text.parse()?,
                None => 0,
            };
            Ok((name.to_string(), weight))
        })
        .collect()
}
56
/// State of the SMT2 rewriter.
#[derive(Debug)]
pub struct Rewriter {
    /// Options controlling which rewrites are applied.
    config: RewriterConfig,
    /// Keywords of `set-option` commands to drop, stored without the leading `:`.
    discarded_options: HashSet<String>,
    /// Visitor that builds the rewritten syntax (returned by `visitor()`).
    builder: SyntaxBuilder,
    /// Index of the next generated `clause!N` assertion name.
    clause_count: usize,
    /// Index of the next generated `quant!N` quantifier name.
    quantifier_count: usize,
}

/// Keyword (without the leading `:`) of the option enabling unsat-core production.
const PRODUCE_UNSAT_CORES: &str = "produce-unsat-cores";
/// Prefix of generated assertion names (`clause!N`).
const CLAUSE: &str = "clause!";
/// Prefix of generated quantifier names (`quant!N`).
const QUANT: &str = "quant!";
70
71impl Rewriter {
72    pub fn new(config: RewriterConfig, discarded_options: HashSet<String>) -> Self {
73        Self {
74            config,
75            discarded_options,
76            builder: SyntaxBuilder::default(),
77            clause_count: 0,
78            quantifier_count: 0,
79        }
80    }
81
82    fn make_clause_name(&mut self, term: &Term) -> Symbol {
83        let mut qid = String::new();
84        if let Term::Forall { term, .. } = term {
85            if let Some(s) = Self::get_quantifier_name(term) {
86                qid = format!("!{}", s.0);
87            }
88        }
89        let s = format!("{}{}{}", CLAUSE, self.clause_count, qid);
90        self.clause_count += 1;
91        Symbol(s)
92    }
93
94    fn make_quantifier_name(&mut self) -> Symbol {
95        let s = format!("{}{}", QUANT, self.quantifier_count);
96        self.quantifier_count += 1;
97        Symbol(s)
98    }
99
100    // Hack: This value is returned when we mean to discard a command.
101    #[inline]
102    fn assert_true() -> Command {
103        Command::Assert {
104            term: Term::QualIdentifier(QualIdentifier::Simple {
105                identifier: Identifier::Simple {
106                    symbol: Symbol("true".to_string()),
107                },
108            }),
109        }
110    }
111
112    fn get_clause_name(term: &Term) -> Option<&Symbol> {
113        if let Some(AttributeValue::Symbol(s)) = Self::get_attribute(term, "named") {
114            return Some(s);
115        }
116        None
117    }
118
119    fn get_quantifier_name(term: &Term) -> Option<&Symbol> {
120        match Self::get_attribute(term, "qid") {
121            Some(AttributeValue::Symbol(s)) => Some(s),
122            _ => None,
123        }
124    }
125
126    fn get_attribute<'a>(term: &'a Term, key: &str) -> Option<&'a AttributeValue> {
127        match term {
128            Term::Attributes { attributes, .. } => {
129                for (k, v) in attributes {
130                    if k.0 == key {
131                        return Some(v);
132                    }
133                }
134                None
135            }
136            _ => None,
137        }
138    }
139
140    fn set_attribute(mut term: Term, key: String, value: AttributeValue) -> Term {
141        match &mut term {
142            Term::Attributes { attributes, .. } => {
143                for (k, v) in attributes.iter_mut() {
144                    if k.0 == key {
145                        *v = value;
146                        return term;
147                    }
148                }
149                attributes.push((Keyword(key), value));
150                term
151            }
152            _ => Term::Attributes {
153                term: Box::new(term),
154                attributes: vec![(Keyword(key), value)],
155            },
156        }
157    }
158}
159
160impl smt2parser::rewriter::Rewriter for Rewriter {
161    type V = SyntaxBuilder;
162
163    fn visitor(&mut self) -> &mut SyntaxBuilder {
164        &mut self.builder
165    }
166
167    fn process_symbol(&mut self, symbol: Symbol) -> Symbol {
168        // Bump clause_count and quantifier_count when needed to avoid
169        // clashes with user-provided symbols.
170        if symbol.0.starts_with(CLAUSE) {
171            if let Ok(i) = usize::from_str(&symbol.0[CLAUSE.len()..]) {
172                self.clause_count = std::cmp::max(self.clause_count, i + 1);
173            }
174        } else if symbol.0.starts_with(QUANT) {
175            if let Ok(i) = usize::from_str(&symbol.0[QUANT.len()..]) {
176                self.clause_count = std::cmp::max(self.clause_count, i + 1);
177            }
178        }
179        symbol
180    }
181
182    fn visit_forall(&mut self, vars: Vec<(Symbol, Sort)>, term: Term) -> Term {
183        let name = Self::get_quantifier_name(&term)
184            .cloned()
185            .unwrap_or_else(|| self.make_quantifier_name());
186        // Add name if needed.
187        let term = if !self.config.tag_quantifiers {
188            term
189        } else {
190            Self::set_attribute(
191                term,
192                "qid".to_string(),
193                AttributeValue::Symbol(name.clone()),
194            )
195        };
196        // Add weight if needed.
197        let term = match &self.config.set_weights {
198            None => term,
199            Some(weights) => {
200                let w = *weights.get(&name.0).unwrap_or(&0);
201                Self::set_attribute(
202                    term,
203                    "weight".to_string(),
204                    AttributeValue::Constant(Constant::Numeral(w.into())),
205                )
206            }
207        };
208        let value = self.visitor().visit_forall(vars, term);
209        self.process_term(value)
210    }
211
212    fn visit_assert(&mut self, term: Term) -> Command {
213        let name = Self::get_clause_name(&term)
214            .cloned()
215            .unwrap_or_else(|| self.make_clause_name(&term));
216        if let Some(list) = &self.config.keep_only_clauses {
217            if !list.contains(&name.0) && !list.contains(&format!("|{}|", &name.0)) {
218                // Discard clause.
219                eprintln!("Discarding {}", name.0);
220                return Self::assert_true();
221            }
222        }
223        let term = if self.config.get_unsat_core {
224            Self::set_attribute(term, "named".to_string(), AttributeValue::Symbol(name))
225        } else {
226            term
227        };
228        let value = self.visitor().visit_assert(term);
229        self.process_command(value)
230    }
231
232    fn visit_set_option(&mut self, keyword: Keyword, value: AttributeValue) -> Command {
233        if self.discarded_options.contains(&keyword.0) {
234            return Self::assert_true();
235        }
236        let value = self.visitor().visit_set_option(keyword, value);
237        self.process_command(value)
238    }
239
240    fn visit_get_unsat_core(&mut self) -> Command {
241        if self.config.get_unsat_core {
242            // Will be re-added in Patcher.
243            return Self::assert_true();
244        }
245        let value = self.visitor().visit_get_unsat_core();
246        self.process_command(value)
247    }
248}
249
// Configuration for the SMT2 patcher. Plain `//` comments are used here on
// purpose: structopt turns `///` doc comments into CLI help/about text.
#[derive(Debug, Clone, StructOpt)]
pub struct PatcherConfig {
    // Rewriting options, flattened so their flags are parsed directly by this config.
    #[structopt(flatten)]
    rewriter_config: RewriterConfig,
}
255
/// State of the SMT2 patcher.
#[derive(Debug)]
pub struct Patcher {
    /// Patching options, including the rewriter configuration.
    config: PatcherConfig,
    /// Rewritten commands accumulated by `read`, in input order, emitted by `write`.
    script: Vec<Command>,
}
262
263impl Patcher {
264    pub fn new(config: PatcherConfig) -> Self {
265        Self {
266            config,
267            script: Vec::new(),
268        }
269    }
270
271    pub fn read(&mut self, path: &Path) -> std::io::Result<()> {
272        let file = std::io::BufReader::new(std::fs::File::open(path)?);
273        let mut discarded_options = HashSet::new();
274        if self.config.rewriter_config.get_unsat_core {
275            discarded_options.insert(PRODUCE_UNSAT_CORES.to_string());
276        }
277        let rewriter = Rewriter::new(self.config.rewriter_config.clone(), discarded_options);
278        let mut stream = smt2parser::CommandStream::new(file, rewriter);
279        let assert_true = Rewriter::assert_true();
280        for result in &mut stream {
281            match result {
282                Ok(command) if command == assert_true => {}
283                Ok(command)
284                    if self.config.rewriter_config.get_unsat_core
285                        && command == Command::CheckSat =>
286                {
287                    self.script.push(command);
288                    self.script.push(Command::GetUnsatCore);
289                }
290                Ok(command) => {
291                    self.script.push(command);
292                }
293                Err(error) => {
294                    panic!("error:\n --> {}", error.location_in_file(path));
295                }
296            }
297        }
298        Ok(())
299    }
300
301    pub fn write(&self, path: &Path) -> std::io::Result<()> {
302        let mut file = std::fs::File::create(path)?;
303        if self.config.rewriter_config.get_unsat_core {
304            // TODO: repeat after resets
305            writeln!(file, "(set-option :{} true)", PRODUCE_UNSAT_CORES)?;
306        }
307        for command in &self.script {
308            writeln!(file, "{}", command)?;
309        }
310        Ok(())
311    }
312}