1use smt2parser::{
8 concrete::*,
9 visitors::{CommandVisitor, TermVisitor},
10};
11use std::{
12 collections::{HashMap, HashSet},
13 io::Write,
14 path::Path,
15 str::FromStr,
16};
17use structopt::StructOpt;
18
// Command-line configuration of the rewriting passes applied by `Rewriter`.
//
// NOTE: plain `//` comments are used on purpose here: StructOpt turns `///`
// doc comments into help text, so adding them would change the generated CLI.
#[derive(Debug, Clone, StructOpt)]
pub struct RewriterConfig {
    // When set, an assertion whose name is not in this set (neither bare nor
    // wrapped as `|name|`) is replaced by `(assert true)`.
    // Parsed by `parse_clauses` from a space-separated list, optionally
    // enclosed in parentheses.
    #[structopt(long, parse(from_str = parse_clauses))]
    keep_only_clauses: Option<HashSet<String>>,

    // When true, kept assertions receive a `:named` attribute and the script
    // is adjusted so that unsat cores are requested.
    #[structopt(long)]
    get_unsat_core: bool,

    // When true, the body of every `forall` receives a `:qid` attribute
    // (the existing name when there is one, a generated `quant!N` otherwise).
    #[structopt(long)]
    tag_quantifiers: bool,

    // When set, the body of every `forall` receives a `:weight` attribute
    // looked up by quantifier name; names missing from the map get weight 0.
    // Parsed by `try_parse_weights` from space-separated `name=weight` items.
    #[structopt(long, parse(try_from_str = try_parse_weights))]
    set_weights: Option<HashMap<String, usize>>,
}
34
/// Parses a list of clause names given on the command line.
///
/// The list is whitespace-separated and may optionally be wrapped in a single
/// pair of parentheses (e.g. the core printed by a solver: `(c1 c2 c3)`).
/// Surrounding whitespace is ignored.
///
/// Any run of whitespace (repeated spaces, tabs, newlines) separates two
/// names, and an empty list (`""` or `"()"`) yields an empty set rather than a
/// set containing the empty string.
fn parse_clauses(src: &str) -> HashSet<String> {
    let src = src.trim();
    // Only strip the parentheses when both are present; otherwise keep the
    // input untouched so that a stray `(` or `)` stays part of a name.
    let inner = src
        .strip_prefix('(')
        .and_then(|rest| rest.strip_suffix(')'))
        .unwrap_or(src);
    inner.split_whitespace().map(String::from).collect()
}
44
/// Parses a whitespace-separated list of `name=weight` items into a map.
///
/// An item without `=` gets weight `0`. Only the first `=` of an item splits
/// it, so everything after it must parse as a `usize`. When a name appears
/// several times, the last occurrence wins.
///
/// Any run of whitespace separates two items, and an empty input yields an
/// empty map rather than a spurious `"" -> 0` entry.
///
/// # Errors
///
/// Returns the `ParseIntError` of the first weight that is not a valid `usize`.
fn try_parse_weights(src: &str) -> Result<HashMap<String, usize>, std::num::ParseIntError> {
    src.split_whitespace()
        .map(|item| -> Result<(String, usize), std::num::ParseIntError> {
            let mut parts = item.splitn(2, '=');
            let key = parts
                .next()
                .expect("splitn always yields at least one item");
            let weight = parts.next().unwrap_or("0").parse::<usize>()?;
            Ok((key.to_string(), weight))
        })
        .collect()
}
56
/// State of one rewriting pass over an SMT2 script.
#[derive(Debug)]
pub struct Rewriter {
    /// Which rewriting operations to apply.
    config: RewriterConfig,
    /// Keywords of the `set-option` commands that must be dropped.
    discarded_options: HashSet<String>,
    /// Builder used to construct the rewritten syntax tree.
    builder: SyntaxBuilder,
    /// Index used for the next generated `clause!N` name.
    clause_count: usize,
    /// Index used for the next generated `quant!N` name.
    quantifier_count: usize,
}
66
/// Keyword (without the leading `:`) of the option that enables unsat cores.
const PRODUCE_UNSAT_CORES: &str = "produce-unsat-cores";
/// Prefix of the clause names generated by `Rewriter::make_clause_name`.
const CLAUSE: &str = "clause!";
/// Prefix of the quantifier names generated by `Rewriter::make_quantifier_name`.
const QUANT: &str = "quant!";
70
71impl Rewriter {
72 pub fn new(config: RewriterConfig, discarded_options: HashSet<String>) -> Self {
73 Self {
74 config,
75 discarded_options,
76 builder: SyntaxBuilder::default(),
77 clause_count: 0,
78 quantifier_count: 0,
79 }
80 }
81
82 fn make_clause_name(&mut self, term: &Term) -> Symbol {
83 let mut qid = String::new();
84 if let Term::Forall { term, .. } = term {
85 if let Some(s) = Self::get_quantifier_name(term) {
86 qid = format!("!{}", s.0);
87 }
88 }
89 let s = format!("{}{}{}", CLAUSE, self.clause_count, qid);
90 self.clause_count += 1;
91 Symbol(s)
92 }
93
94 fn make_quantifier_name(&mut self) -> Symbol {
95 let s = format!("{}{}", QUANT, self.quantifier_count);
96 self.quantifier_count += 1;
97 Symbol(s)
98 }
99
100 #[inline]
102 fn assert_true() -> Command {
103 Command::Assert {
104 term: Term::QualIdentifier(QualIdentifier::Simple {
105 identifier: Identifier::Simple {
106 symbol: Symbol("true".to_string()),
107 },
108 }),
109 }
110 }
111
112 fn get_clause_name(term: &Term) -> Option<&Symbol> {
113 if let Some(AttributeValue::Symbol(s)) = Self::get_attribute(term, "named") {
114 return Some(s);
115 }
116 None
117 }
118
119 fn get_quantifier_name(term: &Term) -> Option<&Symbol> {
120 match Self::get_attribute(term, "qid") {
121 Some(AttributeValue::Symbol(s)) => Some(s),
122 _ => None,
123 }
124 }
125
126 fn get_attribute<'a>(term: &'a Term, key: &str) -> Option<&'a AttributeValue> {
127 match term {
128 Term::Attributes { attributes, .. } => {
129 for (k, v) in attributes {
130 if k.0 == key {
131 return Some(v);
132 }
133 }
134 None
135 }
136 _ => None,
137 }
138 }
139
140 fn set_attribute(mut term: Term, key: String, value: AttributeValue) -> Term {
141 match &mut term {
142 Term::Attributes { attributes, .. } => {
143 for (k, v) in attributes.iter_mut() {
144 if k.0 == key {
145 *v = value;
146 return term;
147 }
148 }
149 attributes.push((Keyword(key), value));
150 term
151 }
152 _ => Term::Attributes {
153 term: Box::new(term),
154 attributes: vec![(Keyword(key), value)],
155 },
156 }
157 }
158}
159
160impl smt2parser::rewriter::Rewriter for Rewriter {
161 type V = SyntaxBuilder;
162
163 fn visitor(&mut self) -> &mut SyntaxBuilder {
164 &mut self.builder
165 }
166
167 fn process_symbol(&mut self, symbol: Symbol) -> Symbol {
168 if symbol.0.starts_with(CLAUSE) {
171 if let Ok(i) = usize::from_str(&symbol.0[CLAUSE.len()..]) {
172 self.clause_count = std::cmp::max(self.clause_count, i + 1);
173 }
174 } else if symbol.0.starts_with(QUANT) {
175 if let Ok(i) = usize::from_str(&symbol.0[QUANT.len()..]) {
176 self.clause_count = std::cmp::max(self.clause_count, i + 1);
177 }
178 }
179 symbol
180 }
181
182 fn visit_forall(&mut self, vars: Vec<(Symbol, Sort)>, term: Term) -> Term {
183 let name = Self::get_quantifier_name(&term)
184 .cloned()
185 .unwrap_or_else(|| self.make_quantifier_name());
186 let term = if !self.config.tag_quantifiers {
188 term
189 } else {
190 Self::set_attribute(
191 term,
192 "qid".to_string(),
193 AttributeValue::Symbol(name.clone()),
194 )
195 };
196 let term = match &self.config.set_weights {
198 None => term,
199 Some(weights) => {
200 let w = *weights.get(&name.0).unwrap_or(&0);
201 Self::set_attribute(
202 term,
203 "weight".to_string(),
204 AttributeValue::Constant(Constant::Numeral(w.into())),
205 )
206 }
207 };
208 let value = self.visitor().visit_forall(vars, term);
209 self.process_term(value)
210 }
211
212 fn visit_assert(&mut self, term: Term) -> Command {
213 let name = Self::get_clause_name(&term)
214 .cloned()
215 .unwrap_or_else(|| self.make_clause_name(&term));
216 if let Some(list) = &self.config.keep_only_clauses {
217 if !list.contains(&name.0) && !list.contains(&format!("|{}|", &name.0)) {
218 eprintln!("Discarding {}", name.0);
220 return Self::assert_true();
221 }
222 }
223 let term = if self.config.get_unsat_core {
224 Self::set_attribute(term, "named".to_string(), AttributeValue::Symbol(name))
225 } else {
226 term
227 };
228 let value = self.visitor().visit_assert(term);
229 self.process_command(value)
230 }
231
232 fn visit_set_option(&mut self, keyword: Keyword, value: AttributeValue) -> Command {
233 if self.discarded_options.contains(&keyword.0) {
234 return Self::assert_true();
235 }
236 let value = self.visitor().visit_set_option(keyword, value);
237 self.process_command(value)
238 }
239
240 fn visit_get_unsat_core(&mut self) -> Command {
241 if self.config.get_unsat_core {
242 return Self::assert_true();
244 }
245 let value = self.visitor().visit_get_unsat_core();
246 self.process_command(value)
247 }
248}
249
// Command-line configuration of the `Patcher`.
//
// NOTE: plain `//` comments are used on purpose here: StructOpt turns `///`
// doc comments into help text, so adding them would change the generated CLI.
#[derive(Debug, Clone, StructOpt)]
pub struct PatcherConfig {
    // Options of the rewriting pass, exposed directly as options of this
    // configuration (flattened).
    #[structopt(flatten)]
    rewriter_config: RewriterConfig,
}
255
/// Reads an SMT2 script, rewrites it according to the configuration and
/// writes the result back to a file.
#[derive(Debug)]
pub struct Patcher {
    /// Configuration of the patching operations.
    config: PatcherConfig,
    /// Rewritten commands accumulated by `read`, in input order.
    script: Vec<Command>,
}
262
263impl Patcher {
264 pub fn new(config: PatcherConfig) -> Self {
265 Self {
266 config,
267 script: Vec::new(),
268 }
269 }
270
271 pub fn read(&mut self, path: &Path) -> std::io::Result<()> {
272 let file = std::io::BufReader::new(std::fs::File::open(path)?);
273 let mut discarded_options = HashSet::new();
274 if self.config.rewriter_config.get_unsat_core {
275 discarded_options.insert(PRODUCE_UNSAT_CORES.to_string());
276 }
277 let rewriter = Rewriter::new(self.config.rewriter_config.clone(), discarded_options);
278 let mut stream = smt2parser::CommandStream::new(file, rewriter);
279 let assert_true = Rewriter::assert_true();
280 for result in &mut stream {
281 match result {
282 Ok(command) if command == assert_true => {}
283 Ok(command)
284 if self.config.rewriter_config.get_unsat_core
285 && command == Command::CheckSat =>
286 {
287 self.script.push(command);
288 self.script.push(Command::GetUnsatCore);
289 }
290 Ok(command) => {
291 self.script.push(command);
292 }
293 Err(error) => {
294 panic!("error:\n --> {}", error.location_in_file(path));
295 }
296 }
297 }
298 Ok(())
299 }
300
301 pub fn write(&self, path: &Path) -> std::io::Result<()> {
302 let mut file = std::fs::File::create(path)?;
303 if self.config.rewriter_config.get_unsat_core {
304 writeln!(file, "(set-option :{} true)", PRODUCE_UNSAT_CORES)?;
306 }
307 for command in &self.script {
308 writeln!(file, "{}", command)?;
309 }
310 Ok(())
311 }
312}