Skip to main content

trident/verify/sym/
mod.rs

1//! Symbolic execution engine for Trident programs.
2//!
3//! Transforms the AST into a symbolic constraint system suitable for
4//! algebraic verification, bounded model checking, and SMT solving.
5//!
6//! Since Trident programs have no heap, no recursion, bounded loops,
7//! and operate over a finite field (Goldilocks: p = 2^64 - 2^32 + 1),
8//! every program produces a finite, decidable constraint system.
9//!
10//! The symbolic engine:
11//! 1. Assigns a symbolic variable to each `let` binding
12//! 2. Tracks constraints from `assert`, `assert_eq`, `assert_digest`
13//! 3. Encodes `if/else` as path conditions
14//! 4. Unrolls bounded `for` loops up to their bound
15//! 5. Inlines function calls (no recursion → always terminates)
16//! 6. Produces a `ConstraintSystem` that can be checked by:
17//!    - The algebraic solver (polynomial identity testing)
18//!    - A bounded model checker (enumerate concrete values)
19//!    - An SMT solver (Z3/CVC5 via SMT-LIB encoding)
20
21use std::collections::BTreeMap;
22
23use crate::ast::*;
24use crate::span::Spanned;
25
26/// The prime modulus for the Goldilocks field.
27pub const GOLDILOCKS_P: u64 = crate::field::goldilocks::MODULUS;
28
29mod executor;
30mod expr;
31#[cfg(test)]
32mod tests;
33
34pub use executor::*;
35
36// ─── Symbolic Values ───────────────────────────────────────────────
37
38/// A symbolic value in the constraint system.
39#[derive(Clone, Debug, PartialEq, Eq, Hash)]
40pub enum SymValue {
41    /// A concrete constant.
42    Const(u64),
43    /// A named symbolic variable (from `let`, `divine`, `pub_read`, etc.).
44    Var(SymVar),
45    /// Addition: a + b (mod p).
46    Add(Box<SymValue>, Box<SymValue>),
47    /// Multiplication: a * b (mod p).
48    Mul(Box<SymValue>, Box<SymValue>),
49    /// Subtraction: a - b (mod p).
50    Sub(Box<SymValue>, Box<SymValue>),
51    /// Negation: -a (mod p).
52    Neg(Box<SymValue>),
53    /// Multiplicative inverse: 1/a (mod p). Undefined for a = 0.
54    Inv(Box<SymValue>),
55    /// Equality test: 1 if a == b, else 0.
56    Eq(Box<SymValue>, Box<SymValue>),
57    /// Less-than test: 1 if a < b, else 0 (on canonical representatives).
58    Lt(Box<SymValue>, Box<SymValue>),
59    /// Hash output: hash(inputs)[index]. Treated as opaque.
60    Hash(Vec<SymValue>, usize),
61    /// A divine (nondeterministic) input. Each occurrence is unique.
62    Divine(u32),
63    /// Struct field access: value.field_name.
64    FieldAccess(Box<SymValue>, String),
65    /// Public input. Sequential read index.
66    PubInput(u32),
67    /// If-then-else: if cond then a else b.
68    Ite(Box<SymValue>, Box<SymValue>, Box<SymValue>),
69}
70
71impl SymValue {
72    pub fn is_const(&self) -> bool {
73        matches!(self, SymValue::Const(_))
74    }
75
76    pub fn as_const(&self) -> Option<u64> {
77        match self {
78            SymValue::Const(v) => Some(*v),
79            _ => None,
80        }
81    }
82
83    /// Check if this value contains a Hash node or opaque intrinsic output.
84    ///
85    /// Opaque values (hashes, intrinsic calls, tuple projections) cannot be
86    /// meaningfully evaluated by random testing since the solver assigns
87    /// arbitrary values that don't reflect the actual computation.
88    pub fn contains_opaque(&self) -> bool {
89        match self {
90            SymValue::Hash(_, _) => true,
91            SymValue::Var(var) => {
92                var.name.starts_with("__proj_")
93                    || var.name.starts_with("__hash")
94                    || var.name.starts_with("__divine")
95            }
96            SymValue::Add(a, b)
97            | SymValue::Mul(a, b)
98            | SymValue::Sub(a, b)
99            | SymValue::Eq(a, b)
100            | SymValue::Lt(a, b) => a.contains_opaque() || b.contains_opaque(),
101            SymValue::Neg(a) | SymValue::Inv(a) => a.contains_opaque(),
102            SymValue::Ite(c, t, e) => {
103                c.contains_opaque() || t.contains_opaque() || e.contains_opaque()
104            }
105            SymValue::FieldAccess(inner, _) => inner.contains_opaque(),
106            SymValue::Const(_) | SymValue::Divine(_) | SymValue::PubInput(_) => false,
107        }
108    }
109
110    /// Check if this value is an external input (pub_read or divine).
111    /// Range checks on external inputs are input preconditions, not bugs.
112    pub fn is_external_input(&self) -> bool {
113        match self {
114            SymValue::Var(var) => {
115                var.name.starts_with("pub_in_") || var.name.starts_with("divine_")
116            }
117            SymValue::PubInput(_) | SymValue::Divine(_) => true,
118            _ => false,
119        }
120    }
121
122    /// Simplify obvious identities: x + 0 = x, x * 1 = x, etc.
123    pub fn simplify(&self) -> SymValue {
124        match self {
125            SymValue::Add(a, b) => {
126                let a = a.simplify();
127                let b = b.simplify();
128                match (&a, &b) {
129                    (SymValue::Const(0), _) => b,
130                    (_, SymValue::Const(0)) => a,
131                    (SymValue::Const(x), SymValue::Const(y)) => {
132                        SymValue::Const(((*x as u128 + *y as u128) % GOLDILOCKS_P as u128) as u64)
133                    }
134                    _ => SymValue::Add(Box::new(a), Box::new(b)),
135                }
136            }
137            SymValue::Mul(a, b) => {
138                let a = a.simplify();
139                let b = b.simplify();
140                match (&a, &b) {
141                    (SymValue::Const(0), _) | (_, SymValue::Const(0)) => SymValue::Const(0),
142                    (SymValue::Const(1), _) => b,
143                    (_, SymValue::Const(1)) => a,
144                    (SymValue::Const(x), SymValue::Const(y)) => {
145                        SymValue::Const(((*x as u128 * *y as u128) % GOLDILOCKS_P as u128) as u64)
146                    }
147                    _ => SymValue::Mul(Box::new(a), Box::new(b)),
148                }
149            }
150            SymValue::Sub(a, b) => {
151                let a = a.simplify();
152                let b = b.simplify();
153                match (&a, &b) {
154                    (_, SymValue::Const(0)) => a,
155                    (SymValue::Const(x), SymValue::Const(y)) => SymValue::Const(
156                        (((*x as u128 + GOLDILOCKS_P as u128) - *y as u128) % GOLDILOCKS_P as u128)
157                            as u64,
158                    ),
159                    _ if a == b => SymValue::Const(0),
160                    _ => SymValue::Sub(Box::new(a), Box::new(b)),
161                }
162            }
163            SymValue::Neg(a) => {
164                let a = a.simplify();
165                match &a {
166                    SymValue::Const(0) => SymValue::Const(0),
167                    SymValue::Const(v) => SymValue::Const(GOLDILOCKS_P - v),
168                    _ => SymValue::Neg(Box::new(a)),
169                }
170            }
171            SymValue::Eq(a, b) => {
172                let a = a.simplify();
173                let b = b.simplify();
174                if a == b {
175                    SymValue::Const(1)
176                } else {
177                    match (&a, &b) {
178                        (SymValue::Const(x), SymValue::Const(y)) => {
179                            SymValue::Const(if x == y { 1 } else { 0 })
180                        }
181                        _ => SymValue::Eq(Box::new(a), Box::new(b)),
182                    }
183                }
184            }
185            _ => self.clone(),
186        }
187    }
188}
189
/// A symbolic variable identified by its source name and SSA version.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SymVar {
    pub name: String,
    /// SSA version number; bumped each time a mutable variable is reassigned.
    pub version: u32,
}

impl std::fmt::Display for SymVar {
    /// Prints the bare name for version 0 and `name_version` for later versions.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.version {
            0 => f.write_str(&self.name),
            ver => write!(f, "{}_{}", self.name, ver),
        }
    }
}
207
208// ─── Constraints ───────────────────────────────────────────────────
209
210/// A constraint in the system.
211#[derive(Clone, Debug)]
212pub enum Constraint {
213    /// a == b (from `assert_eq` or `assert(a == b)`)
214    Equal(SymValue, SymValue),
215    /// a == 0 (from `assert(cond)` where cond is truthy)
216    AssertTrue(SymValue),
217    /// Conditional: if path_condition then constraint holds
218    Conditional(SymValue, Box<Constraint>),
219    /// Range check: value fits in U32 (from `as_u32`)
220    RangeU32(SymValue),
221    /// Digest equality: 5-element vector comparison
222    DigestEqual(Vec<SymValue>, Vec<SymValue>),
223}
224
225impl Constraint {
226    /// Check if this constraint is trivially satisfied.
227    pub fn is_trivial(&self) -> bool {
228        match self {
229            Constraint::Equal(a, b) => a == b,
230            Constraint::AssertTrue(v) => matches!(v, SymValue::Const(1)),
231            Constraint::RangeU32(v) => {
232                if let SymValue::Const(c) = v {
233                    *c <= u32::MAX as u64
234                } else {
235                    false
236                }
237            }
238            Constraint::DigestEqual(a, b) => a == b,
239            Constraint::Conditional(cond, inner) => {
240                matches!(cond, SymValue::Const(0)) || inner.is_trivial()
241            }
242        }
243    }
244
245    /// Check if this constraint is trivially violated.
246    pub fn is_violated(&self) -> bool {
247        match self {
248            Constraint::Equal(SymValue::Const(a), SymValue::Const(b)) => a != b,
249            Constraint::AssertTrue(SymValue::Const(0)) => true,
250            Constraint::RangeU32(SymValue::Const(c)) => *c > u32::MAX as u64,
251            _ => false,
252        }
253    }
254
255    /// Check if this constraint depends on a hash output.
256    ///
257    /// Hash-dependent constraints (e.g. `hash(secret) == expected`) require
258    /// specific witness values to satisfy — random testing will almost always
259    /// report them as violated. The solver should classify these as
260    /// "witness-required" rather than "violated".
261    pub fn is_hash_dependent(&self) -> bool {
262        match self {
263            Constraint::Equal(a, b) => a.contains_opaque() || b.contains_opaque(),
264            Constraint::AssertTrue(v) => v.contains_opaque(),
265            Constraint::Conditional(_, inner) => inner.is_hash_dependent(),
266            Constraint::DigestEqual(a, b) => {
267                a.iter().any(|v| v.contains_opaque()) || b.iter().any(|v| v.contains_opaque())
268            }
269            // RangeU32 constraints are input preconditions: as_u32(x) asserts x fits
270            // in 32 bits. Random field values almost never satisfy this, producing
271            // false positives. Only concrete violations (static analysis) matter.
272            Constraint::RangeU32(_) => true,
273        }
274    }
275}
276
277// ─── Constraint System ─────────────────────────────────────────────
278
279/// The complete constraint system for a program or function.
280#[derive(Clone, Debug)]
281pub struct ConstraintSystem {
282    /// All constraints that must hold.
283    pub constraints: Vec<Constraint>,
284    /// Symbolic variables introduced (name → latest version).
285    pub variables: BTreeMap<String, u32>,
286    /// Public inputs read (in order).
287    pub pub_inputs: Vec<SymVar>,
288    /// Public outputs written (in order).
289    pub pub_outputs: Vec<SymValue>,
290    /// Divine inputs consumed (in order).
291    pub divine_inputs: Vec<SymVar>,
292    /// Number of unique symbolic variables.
293    pub num_variables: u32,
294}
295
296impl ConstraintSystem {
297    pub fn new() -> Self {
298        Self {
299            constraints: Vec::new(),
300            variables: BTreeMap::new(),
301            pub_inputs: Vec::new(),
302            pub_outputs: Vec::new(),
303            divine_inputs: Vec::new(),
304            num_variables: 0,
305        }
306    }
307
308    /// Count of non-trivial constraints.
309    pub fn active_constraints(&self) -> usize {
310        self.constraints.iter().filter(|c| !c.is_trivial()).count()
311    }
312
313    /// Check for trivially violated constraints (static analysis).
314    pub fn violated_constraints(&self) -> Vec<&Constraint> {
315        self.constraints
316            .iter()
317            .filter(|c| c.is_violated())
318            .collect()
319    }
320
321    /// Summary for display.
322    pub fn summary(&self) -> String {
323        format!(
324            "Variables: {}, Constraints: {} ({} active), Inputs: {} pub + {} divine, Outputs: {}",
325            self.num_variables,
326            self.constraints.len(),
327            self.active_constraints(),
328            self.pub_inputs.len(),
329            self.divine_inputs.len(),
330            self.pub_outputs.len(),
331        )
332    }
333}
334
335// ─── Analysis Functions ────────────────────────────────────────────
336
337/// Analyze a file and return its constraint system (main function only).
338pub fn analyze(file: &File) -> ConstraintSystem {
339    SymExecutor::new().execute_file(file)
340}
341
342/// Analyze a single function by name, treating parameters as symbolic inputs.
343pub fn analyze_function(file: &File, fn_name: &str) -> ConstraintSystem {
344    SymExecutor::new().execute_function(file, fn_name)
345}
346
347/// Analyze all functions in a file, returning per-function constraint systems.
348/// For programs, analyzes `main`. For modules, analyzes every non-test function with a body.
349pub fn analyze_all(file: &File) -> Vec<(String, ConstraintSystem)> {
350    let mut results = Vec::new();
351    for item in &file.items {
352        if let Item::Fn(func) = &item.node {
353            if func.body.is_some() && !func.is_test && func.intrinsic.is_none() {
354                let system = SymExecutor::new().execute_function(file, &func.name.node);
355                results.push((func.name.node.clone(), system));
356            }
357        }
358    }
359    results
360}
361
/// Verification result for a function or program.
#[derive(Clone, Debug)]
pub struct VerificationResult {
    /// The function or program name.
    pub name: String,
    /// Total constraints.
    pub total_constraints: usize,
    /// Active (non-trivial) constraints.
    pub active_constraints: usize,
    /// Trivially violated constraints (definite bugs), rendered as text.
    pub violated: Vec<String>,
    /// Redundant (trivially satisfied) constraints.
    pub redundant_count: usize,
    /// Summary of the constraint system.
    pub system_summary: String,
}

impl VerificationResult {
    /// Returns `true` when no trivially violated constraint was recorded.
    ///
    /// This is not a proof of correctness: only constraints recorded in
    /// `violated` are taken into account.
    pub fn is_safe(&self) -> bool {
        self.violated.is_empty()
    }

    /// Render a human-readable, multi-line report; every line ends in `\n`.
    pub fn format_report(&self) -> String {
        use std::fmt::Write;

        let mut report = String::new();
        // `fmt::Write` for `String` never fails, so the `fmt::Result`s are ignored.
        let _ = writeln!(report, "Verification: {}", self.name);
        let _ = writeln!(report, "  {}", self.system_summary);
        let _ = writeln!(
            report,
            "  Constraints: {} total, {} active, {} redundant",
            self.total_constraints, self.active_constraints, self.redundant_count,
        );
        if self.violated.is_empty() {
            report.push_str("  Status: SAFE (no trivially violated assertions)\n");
        } else {
            let _ = writeln!(
                report,
                "  Status: VIOLATED ({} assertion(s) always fail)",
                self.violated.len()
            );
            for v in &self.violated {
                let _ = writeln!(report, "    - {}", v);
            }
        }
        report
    }
}
406
407/// Verify a file: analyze constraints and check for violations.
408/// For programs, checks `main`. For modules, checks all functions.
409pub fn verify_file(file: &File) -> VerificationResult {
410    let system = analyze(file);
411    let violated: Vec<String> = system
412        .violated_constraints()
413        .iter()
414        .map(|c| format!("{:?}", c))
415        .collect();
416    let redundant_count = system.constraints.iter().filter(|c| c.is_trivial()).count();
417
418    VerificationResult {
419        name: file.name.node.clone(),
420        total_constraints: system.constraints.len(),
421        active_constraints: system.active_constraints(),
422        violated,
423        redundant_count,
424        system_summary: system.summary(),
425    }
426}