// trident/verify/smt/mod.rs

//! SMT-LIB2 encoder for Trident constraint systems.
//!
//! Encodes the `ConstraintSystem` from `sym.rs` as SMT-LIB2 queries
//! compatible with Z3, CVC5, and other SMT solvers.
//!
//! Goldilocks field arithmetic is encoded using bitvector operations:
//! - Field elements are 128-bit bitvectors, wide enough that the product of
//!   two field elements cannot overflow before it is reduced
//! - Every arithmetic result is reduced mod p (p = 2^64 - 2^32 + 1)
//! - Equality checks, range constraints, and conditional assertions
//!
//! The encoder produces two kinds of queries:
//! 1. **Safety check**: Is there an assignment that violates any constraint?
//!    (check-sat on the negated conjunction of all constraints)
//! 2. **Witness existence**: For divine inputs, does a valid witness exist?
//!    (check-sat on all constraints)
16
use crate::sym::{Constraint, ConstraintSystem, SymValue, GOLDILOCKS_P};
use std::collections::{BTreeMap, BTreeSet};
19
20/// Generate SMT-LIB2 encoding of a constraint system.
21///
22/// Returns the complete SMT-LIB2 script as a string.
23pub fn encode_system(system: &ConstraintSystem, mode: QueryMode) -> String {
24    let mut encoder = SmtEncoder::new(mode);
25    encoder.encode(system);
26    encoder.output
27}
28
/// Selects which SMT query the encoder generates.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum QueryMode {
    /// Search for a violation: the conjunction of all constraints is negated
    /// before `check-sat`. SAT → a counterexample (bug) exists; UNSAT → safe.
    SafetyCheck,
    /// Search for a witness: every constraint is asserted as-is.
    /// SAT → valid divine values exist; UNSAT → no valid witness.
    WitnessExistence,
}
39
/// Outcome of running an SMT solver on an encoded query.
#[derive(Debug, Clone)]
pub struct SmtResult {
    /// Everything the solver printed, unparsed.
    pub output: String,
    /// The verdict extracted from `output`.
    pub status: SmtStatus,
    /// Model text (variable assignments), present when the verdict is SAT.
    pub model: Option<String>,
}

/// A solver verdict.
///
/// How `Sat`/`Unsat` should be read depends on the `QueryMode` that produced
/// the query: under `SafetyCheck` SAT means a counterexample was found, under
/// `WitnessExistence` it means a witness exists.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SmtStatus {
    /// The asserted formulas are satisfiable.
    Sat,
    /// The asserted formulas are unsatisfiable.
    Unsat,
    /// Neither satisfiability nor unsatisfiability was established.
    Unknown,
    /// Something went wrong; the payload is a human-readable message.
    Error(String),
}
58
59struct SmtEncoder {
60    output: String,
61    mode: QueryMode,
62    declared_vars: BTreeSet<String>,
63}
64
65impl SmtEncoder {
66    fn new(mode: QueryMode) -> Self {
67        Self {
68            output: String::new(),
69            mode,
70            declared_vars: BTreeSet::new(),
71        }
72    }
73
74    fn emit(&mut self, s: &str) {
75        self.output.push_str(s);
76        self.output.push('\n');
77    }
78
79    fn encode(&mut self, system: &ConstraintSystem) {
80        // Header
81        self.emit("; Trident SMT-LIB2 encoding");
82        self.emit("; Generated by trident audit");
83        self.emit(&format!("; Mode: {:?}", self.mode));
84        self.emit(&format!("; Variables: {}", system.num_variables));
85        self.emit(&format!("; Constraints: {}", system.constraints.len()));
86        self.emit("");
87        self.emit("(set-logic QF_BV)");
88        self.emit("");
89
90        // Define the Goldilocks prime as a constant
91        self.emit("; Goldilocks prime: p = 2^64 - 2^32 + 1");
92        self.emit(&format!(
93            "(define-fun GOLDILOCKS_P () (_ BitVec 128) (_ bv{} 128))",
94            GOLDILOCKS_P
95        ));
96        self.emit("");
97
98        // Helper: field_mod(x) = x mod p (for 128-bit intermediate results)
99        self.emit("; Field modular reduction");
100        self.emit("(define-fun field_mod ((x (_ BitVec 128))) (_ BitVec 128)");
101        self.emit("  (bvurem x GOLDILOCKS_P))");
102        self.emit("");
103
104        // Declare all variables
105        self.emit("; Variable declarations");
106        self.declare_variables(system);
107        self.emit("");
108
109        // Field range constraints: all variables < p
110        self.emit("; Field range constraints (all values < p)");
111        for var_name in &self.declared_vars.clone() {
112            self.emit(&format!("(assert (bvult {} GOLDILOCKS_P))", var_name));
113        }
114        self.emit("");
115
116        // Encode constraints
117        match self.mode {
118            QueryMode::SafetyCheck => {
119                // For safety: assert negation of each constraint and check if
120                // any can be violated. We assert all constraints hold, then
121                // ask if this is satisfiable. If UNSAT, no valid input exists
122                // (vacuously safe). If SAT, all constraints hold for that input.
123                //
124                // Actually: we want to find a COUNTEREXAMPLE. So we assert the
125                // negation of at least one constraint.
126                self.emit("; Safety check: can any constraint be violated?");
127                self.emit("; (SAT = counterexample found, UNSAT = all constraints hold)");
128                self.emit("");
129
130                if system.constraints.is_empty() {
131                    self.emit("; No constraints to check");
132                    self.emit("(assert true)");
133                } else {
134                    // Assert: NOT (c1 AND c2 AND ... AND cn)
135                    // Equivalent to: c1_neg OR c2_neg OR ... OR cn_neg
136                    let mut disjuncts = Vec::new();
137                    for (i, constraint) in system.constraints.iter().enumerate() {
138                        let smt = self.encode_constraint(constraint);
139                        self.emit(&format!("; Constraint #{}", i));
140                        disjuncts.push(format!("(not {})", smt));
141                    }
142
143                    if disjuncts.len() == 1 {
144                        self.emit(&format!("(assert {})", disjuncts[0]));
145                    } else {
146                        self.emit(&format!("(assert (or {}))", disjuncts.join(" ")));
147                    }
148                }
149            }
150            QueryMode::WitnessExistence => {
151                // For witness: assert all constraints and check satisfiability.
152                // SAT → valid witness exists for divine inputs.
153                self.emit("; Witness existence: do valid divine values exist?");
154                self.emit("; (SAT = witness found, UNSAT = no valid witness)");
155                self.emit("");
156
157                for (i, constraint) in system.constraints.iter().enumerate() {
158                    let smt = self.encode_constraint(constraint);
159                    self.emit(&format!("; Constraint #{}", i));
160                    self.emit(&format!("(assert {})", smt));
161                }
162            }
163        }
164
165        self.emit("");
166        self.emit("(check-sat)");
167        self.emit("(get-model)");
168        self.emit("(exit)");
169    }
170
171    fn declare_variables(&mut self, system: &ConstraintSystem) {
172        // Collect all variable names from the constraint system
173        let mut var_names: Vec<String> = Vec::new();
174
175        for (name, max_version) in &system.variables {
176            for v in 0..=*max_version {
177                let var_name = if v == 0 {
178                    name.clone()
179                } else {
180                    format!("{}_{}", name, v)
181                };
182                let smt_name = sanitize_smt_name(&var_name);
183                if !self.declared_vars.contains(&smt_name) {
184                    var_names.push(smt_name.clone());
185                    self.declared_vars.insert(smt_name);
186                }
187            }
188        }
189
190        // Also declare pub_input and divine variables
191        for pi in &system.pub_inputs {
192            let smt_name = sanitize_smt_name(&pi.to_string());
193            if !self.declared_vars.contains(&smt_name) {
194                var_names.push(smt_name.clone());
195                self.declared_vars.insert(smt_name);
196            }
197        }
198        for di in &system.divine_inputs {
199            let smt_name = sanitize_smt_name(&di.to_string());
200            if !self.declared_vars.contains(&smt_name) {
201                var_names.push(smt_name.clone());
202                self.declared_vars.insert(smt_name);
203            }
204        }
205
206        for name in &var_names {
207            self.emit(&format!("(declare-fun {} () (_ BitVec 128))", name));
208        }
209    }
210
211    fn encode_constraint(&mut self, constraint: &Constraint) -> String {
212        match constraint {
213            Constraint::Equal(a, b) => {
214                let sa = self.encode_value(a);
215                let sb = self.encode_value(b);
216                format!("(= {} {})", sa, sb)
217            }
218            Constraint::AssertTrue(v) => {
219                let sv = self.encode_value(v);
220                // In Trident, true = 1, false = 0. Assert v != 0.
221                format!("(not (= {} (_ bv0 128)))", sv)
222            }
223            Constraint::Conditional(cond, inner) => {
224                let sc = self.encode_value(cond);
225                let si = self.encode_constraint(inner);
226                // If cond != 0 then inner must hold
227                format!("(=> (not (= {} (_ bv0 128))) {})", sc, si)
228            }
229            Constraint::RangeU32(v) => {
230                let sv = self.encode_value(v);
231                format!("(bvule {} (_ bv{} 128))", sv, u32::MAX)
232            }
233            Constraint::DigestEqual(a, b) => {
234                let mut conjuncts = Vec::new();
235                for (x, y) in a.iter().zip(b.iter()) {
236                    let sx = self.encode_value(x);
237                    let sy = self.encode_value(y);
238                    conjuncts.push(format!("(= {} {})", sx, sy));
239                }
240                if conjuncts.len() == 1 {
241                    conjuncts[0].clone()
242                } else {
243                    format!("(and {})", conjuncts.join(" "))
244                }
245            }
246        }
247    }
248
249    fn encode_value(&mut self, value: &SymValue) -> String {
250        match value {
251            SymValue::Const(c) => {
252                format!("(_ bv{} 128)", c % GOLDILOCKS_P)
253            }
254            SymValue::Var(var) => {
255                let name = sanitize_smt_name(&var.to_string());
256                // Ensure variable is declared
257                if !self.declared_vars.contains(&name) {
258                    self.declared_vars.insert(name.clone());
259                    // This will be emitted out of order, but that's OK for SMT-LIB2
260                    // in incremental mode. For safety, we handle this in declare_variables.
261                }
262                name
263            }
264            SymValue::Add(a, b) => {
265                let sa = self.encode_value(a);
266                let sb = self.encode_value(b);
267                format!("(field_mod (bvadd {} {}))", sa, sb)
268            }
269            SymValue::Mul(a, b) => {
270                let sa = self.encode_value(a);
271                let sb = self.encode_value(b);
272                format!("(field_mod (bvmul {} {}))", sa, sb)
273            }
274            SymValue::Sub(a, b) => {
275                let sa = self.encode_value(a);
276                let sb = self.encode_value(b);
277                // a - b mod p = (a + p - b) mod p
278                format!("(field_mod (bvadd {} (bvsub GOLDILOCKS_P {})))", sa, sb)
279            }
280            SymValue::Neg(a) => {
281                let sa = self.encode_value(a);
282                format!("(field_mod (bvsub GOLDILOCKS_P {}))", sa)
283            }
284            SymValue::Inv(a) => {
285                // Inverse is hard to encode directly in BV. We use an
286                // existential: declare a fresh variable inv_x, assert
287                // inv_x * x == 1 mod p.
288                let _sa = self.encode_value(a);
289                let inv_name = format!("__inv_{}", self.declared_vars.len());
290                self.declared_vars.insert(inv_name.clone());
291                // We can't add declarations mid-stream easily, so just
292                // return the variable name. The caller should handle this.
293                inv_name
294            }
295            SymValue::Eq(a, b) => {
296                let sa = self.encode_value(a);
297                let sb = self.encode_value(b);
298                // Returns 1 if equal, 0 otherwise
299                format!("(ite (= {} {}) (_ bv1 128) (_ bv0 128))", sa, sb)
300            }
301            SymValue::Lt(a, b) => {
302                let sa = self.encode_value(a);
303                let sb = self.encode_value(b);
304                format!("(ite (bvult {} {}) (_ bv1 128) (_ bv0 128))", sa, sb)
305            }
306            SymValue::Hash(inputs, index) => {
307                // Hash is opaque — create uninterpreted function
308                let hash_name = format!("__hash_{}_{}", inputs.len(), index);
309                let name = sanitize_smt_name(&hash_name);
310                if !self.declared_vars.contains(&name) {
311                    self.declared_vars.insert(name.clone());
312                }
313                name
314            }
315            SymValue::Divine(idx) => {
316                let name = format!("divine_{}", idx);
317                let smt_name = sanitize_smt_name(&name);
318                if !self.declared_vars.contains(&smt_name) {
319                    self.declared_vars.insert(smt_name.clone());
320                }
321                smt_name
322            }
323            SymValue::PubInput(idx) => {
324                let name = format!("pub_in_{}", idx);
325                let smt_name = sanitize_smt_name(&name);
326                if !self.declared_vars.contains(&smt_name) {
327                    self.declared_vars.insert(smt_name.clone());
328                }
329                smt_name
330            }
331            SymValue::Ite(cond, then_val, else_val) => {
332                let sc = self.encode_value(cond);
333                let st = self.encode_value(then_val);
334                let se = self.encode_value(else_val);
335                format!("(ite (not (= {} (_ bv0 128))) {} {})", sc, st, se)
336            }
337            SymValue::FieldAccess(inner, field) => {
338                // Field access is opaque — create uninterpreted function
339                let inner_enc = self.encode_value(inner);
340                let name = format!("__field_{}_{}", inner_enc, sanitize_smt_name(field));
341                if !self.declared_vars.contains(&name) {
342                    self.declared_vars.insert(name.clone());
343                }
344                name
345            }
346        }
347    }
348}
349
/// Sanitize a name into a legal SMT-LIB2 simple symbol.
///
/// SMT-LIB2 simple symbols may contain only ASCII letters, digits and a fixed
/// set of punctuation. Every character other than an ASCII letter, ASCII
/// digit or `_` therefore becomes `_`. This includes non-ASCII letters and
/// digits, which `char::is_alphanumeric` would otherwise let through. The
/// result is prefixed with `v_` when it would be empty, start with a digit, or
/// spell one of the following:
/// - a reserved word
/// - a Core-theory builtin
/// - a helper that the generated scripts define themselves
fn sanitize_smt_name(name: &str) -> String {
    // Names a declaration must not reuse. Only entries made of `[A-Za-z0-9_]`
    // are listed, since nothing else can survive the character mapping.
    const RESERVED: &[&str] = &[
        // SMT-LIB2 reserved words, including the alphanumeric command names.
        "_", "as", "let", "exists", "forall", "match", "par", "BINARY", "DECIMAL",
        "HEXADECIMAL", "NUMERAL", "STRING", "assert", "echo", "exit", "pop", "push", "reset",
        // Core theory symbols.
        "true", "false", "not", "and", "or", "xor", "distinct", "ite",
        // Defined at the top of every generated script.
        "GOLDILOCKS_P", "field_mod",
    ];

    let sanitized: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();

    let needs_prefix = sanitized.is_empty()
        || sanitized.starts_with(|c: char| c.is_ascii_digit())
        || RESERVED.contains(&sanitized.as_str());
    if needs_prefix {
        format!("v_{}", sanitized)
    } else {
        sanitized
    }
}
369
// --- Z3 Process Runner ---
// The solver-invocation code lives in the `runner` submodule; only its
// `run_z3` entry point is re-exported from this module.

mod runner;
pub use runner::run_z3;

// Unit tests for this module; compiled only under `cfg(test)`.
#[cfg(test)]
mod tests;