1use std::collections::BTreeMap;
22
23use crate::ast::*;
24use crate::span::Spanned;
25
/// Field modulus used for all constant folding in this module (see
/// `SymValue::simplify`); re-exported from `crate::field::goldilocks::MODULUS`.
pub const GOLDILOCKS_P: u64 = crate::field::goldilocks::MODULUS;
28
29mod executor;
30mod expr;
31#[cfg(test)]
32mod tests;
33
34pub use executor::*;
35
/// A symbolic value: an expression tree over constants, variables and program
/// inputs, as stored in a [`ConstraintSystem`].
///
/// Arithmetic nodes are interpreted modulo [`GOLDILOCKS_P`] when folded.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SymValue {
    /// A literal constant. Nothing in the type forces it below `GOLDILOCKS_P`.
    Const(u64),
    /// A named, versioned symbolic variable.
    Var(SymVar),
    /// Sum of two values.
    Add(Box<SymValue>, Box<SymValue>),
    /// Product of two values.
    Mul(Box<SymValue>, Box<SymValue>),
    /// Difference `lhs - rhs`.
    Sub(Box<SymValue>, Box<SymValue>),
    /// Additive negation.
    Neg(Box<SymValue>),
    /// Multiplicative inverse of the operand.
    /// NOTE(review): the meaning of `Inv` of zero is not defined here.
    Inv(Box<SymValue>),
    /// Equality test of two values (folds to `Const(1)` / `Const(0)`).
    Eq(Box<SymValue>, Box<SymValue>),
    /// Less-than comparison.
    /// NOTE(review): whether this orders raw words, canonical field elements
    /// or u32 values is not visible here — confirm in the executor.
    Lt(Box<SymValue>, Box<SymValue>),
    /// Hash of the input values. Always treated as opaque by
    /// `contains_opaque`.
    /// NOTE(review): the meaning of the `usize` (output index? width?) is not
    /// visible here — confirm in the executor.
    Hash(Vec<SymValue>, usize),
    /// A `divine` (non-deterministic) input; the `u32` is presumably its
    /// index — TODO confirm.
    Divine(u32),
    /// Access of a named field (the `String`) of a compound value.
    FieldAccess(Box<SymValue>, String),
    /// A public input; the `u32` is presumably its index — TODO confirm.
    PubInput(u32),
    /// Conditional value `(cond, then, else)`.
    /// NOTE(review): which condition values select `then` is not visible here.
    Ite(Box<SymValue>, Box<SymValue>, Box<SymValue>),
}
70
71impl SymValue {
72 pub fn is_const(&self) -> bool {
73 matches!(self, SymValue::Const(_))
74 }
75
76 pub fn as_const(&self) -> Option<u64> {
77 match self {
78 SymValue::Const(v) => Some(*v),
79 _ => None,
80 }
81 }
82
83 pub fn contains_opaque(&self) -> bool {
89 match self {
90 SymValue::Hash(_, _) => true,
91 SymValue::Var(var) => {
92 var.name.starts_with("__proj_")
93 || var.name.starts_with("__hash")
94 || var.name.starts_with("__divine")
95 }
96 SymValue::Add(a, b)
97 | SymValue::Mul(a, b)
98 | SymValue::Sub(a, b)
99 | SymValue::Eq(a, b)
100 | SymValue::Lt(a, b) => a.contains_opaque() || b.contains_opaque(),
101 SymValue::Neg(a) | SymValue::Inv(a) => a.contains_opaque(),
102 SymValue::Ite(c, t, e) => {
103 c.contains_opaque() || t.contains_opaque() || e.contains_opaque()
104 }
105 SymValue::FieldAccess(inner, _) => inner.contains_opaque(),
106 SymValue::Const(_) | SymValue::Divine(_) | SymValue::PubInput(_) => false,
107 }
108 }
109
110 pub fn is_external_input(&self) -> bool {
113 match self {
114 SymValue::Var(var) => {
115 var.name.starts_with("pub_in_") || var.name.starts_with("divine_")
116 }
117 SymValue::PubInput(_) | SymValue::Divine(_) => true,
118 _ => false,
119 }
120 }
121
122 pub fn simplify(&self) -> SymValue {
124 match self {
125 SymValue::Add(a, b) => {
126 let a = a.simplify();
127 let b = b.simplify();
128 match (&a, &b) {
129 (SymValue::Const(0), _) => b,
130 (_, SymValue::Const(0)) => a,
131 (SymValue::Const(x), SymValue::Const(y)) => {
132 SymValue::Const(((*x as u128 + *y as u128) % GOLDILOCKS_P as u128) as u64)
133 }
134 _ => SymValue::Add(Box::new(a), Box::new(b)),
135 }
136 }
137 SymValue::Mul(a, b) => {
138 let a = a.simplify();
139 let b = b.simplify();
140 match (&a, &b) {
141 (SymValue::Const(0), _) | (_, SymValue::Const(0)) => SymValue::Const(0),
142 (SymValue::Const(1), _) => b,
143 (_, SymValue::Const(1)) => a,
144 (SymValue::Const(x), SymValue::Const(y)) => {
145 SymValue::Const(((*x as u128 * *y as u128) % GOLDILOCKS_P as u128) as u64)
146 }
147 _ => SymValue::Mul(Box::new(a), Box::new(b)),
148 }
149 }
150 SymValue::Sub(a, b) => {
151 let a = a.simplify();
152 let b = b.simplify();
153 match (&a, &b) {
154 (_, SymValue::Const(0)) => a,
155 (SymValue::Const(x), SymValue::Const(y)) => SymValue::Const(
156 (((*x as u128 + GOLDILOCKS_P as u128) - *y as u128) % GOLDILOCKS_P as u128)
157 as u64,
158 ),
159 _ if a == b => SymValue::Const(0),
160 _ => SymValue::Sub(Box::new(a), Box::new(b)),
161 }
162 }
163 SymValue::Neg(a) => {
164 let a = a.simplify();
165 match &a {
166 SymValue::Const(0) => SymValue::Const(0),
167 SymValue::Const(v) => SymValue::Const(GOLDILOCKS_P - v),
168 _ => SymValue::Neg(Box::new(a)),
169 }
170 }
171 SymValue::Eq(a, b) => {
172 let a = a.simplify();
173 let b = b.simplify();
174 if a == b {
175 SymValue::Const(1)
176 } else {
177 match (&a, &b) {
178 (SymValue::Const(x), SymValue::Const(y)) => {
179 SymValue::Const(if x == y { 1 } else { 0 })
180 }
181 _ => SymValue::Eq(Box::new(a), Box::new(b)),
182 }
183 }
184 }
185 _ => self.clone(),
186 }
187 }
188}
189
/// A named symbolic variable together with a version number.
///
/// Displayed as `name` for version 0 and `name_version` otherwise.
/// NOTE(review): versions presumably distinguish successive assignments to
/// the same name — confirm against the executor.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SymVar {
    /// Source-level (or synthesized, e.g. `__hash…`) variable name.
    pub name: String,
    /// Version counter; 0 is rendered without a suffix.
    pub version: u32,
}
197
198impl std::fmt::Display for SymVar {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 if self.version == 0 {
201 write!(f, "{}", self.name)
202 } else {
203 write!(f, "{}_{}", self.name, self.version)
204 }
205 }
206}
207
/// A single constraint stored in a [`ConstraintSystem`].
#[derive(Clone, Debug)]
pub enum Constraint {
    /// The two values must be equal.
    Equal(SymValue, SymValue),
    /// The value must hold; `Const(1)` counts as trivially satisfied and
    /// `Const(0)` as always violated.
    AssertTrue(SymValue),
    /// The inner constraint applies only under the guard value; a `Const(0)`
    /// guard makes the whole constraint trivial.
    Conditional(SymValue, Box<Constraint>),
    /// The value must fit in a `u32` (`<= u32::MAX`).
    RangeU32(SymValue),
    /// Two digests (vectors of values) must be equal.
    DigestEqual(Vec<SymValue>, Vec<SymValue>),
}
224
225impl Constraint {
226 pub fn is_trivial(&self) -> bool {
228 match self {
229 Constraint::Equal(a, b) => a == b,
230 Constraint::AssertTrue(v) => matches!(v, SymValue::Const(1)),
231 Constraint::RangeU32(v) => {
232 if let SymValue::Const(c) = v {
233 *c <= u32::MAX as u64
234 } else {
235 false
236 }
237 }
238 Constraint::DigestEqual(a, b) => a == b,
239 Constraint::Conditional(cond, inner) => {
240 matches!(cond, SymValue::Const(0)) || inner.is_trivial()
241 }
242 }
243 }
244
245 pub fn is_violated(&self) -> bool {
247 match self {
248 Constraint::Equal(SymValue::Const(a), SymValue::Const(b)) => a != b,
249 Constraint::AssertTrue(SymValue::Const(0)) => true,
250 Constraint::RangeU32(SymValue::Const(c)) => *c > u32::MAX as u64,
251 _ => false,
252 }
253 }
254
255 pub fn is_hash_dependent(&self) -> bool {
262 match self {
263 Constraint::Equal(a, b) => a.contains_opaque() || b.contains_opaque(),
264 Constraint::AssertTrue(v) => v.contains_opaque(),
265 Constraint::Conditional(_, inner) => inner.is_hash_dependent(),
266 Constraint::DigestEqual(a, b) => {
267 a.iter().any(|v| v.contains_opaque()) || b.iter().any(|v| v.contains_opaque())
268 }
269 Constraint::RangeU32(_) => true,
273 }
274 }
275}
276
/// Output of symbolic execution: the collected constraints plus bookkeeping
/// about variables, inputs and outputs.
#[derive(Clone, Debug)]
pub struct ConstraintSystem {
    /// All constraints, including trivially satisfied ones.
    pub constraints: Vec<Constraint>,
    /// Per-name `u32` counter. NOTE(review): presumably the latest
    /// `SymVar::version` issued for each name — confirm in the executor.
    pub variables: BTreeMap<String, u32>,
    /// Public inputs (counted as "pub" inputs in `summary`).
    pub pub_inputs: Vec<SymVar>,
    /// Public output values (counted as "Outputs" in `summary`).
    pub pub_outputs: Vec<SymValue>,
    /// Divine inputs (counted as "divine" inputs in `summary`).
    pub divine_inputs: Vec<SymVar>,
    /// Variable count reported by `summary`.
    pub num_variables: u32,
}
295
296impl ConstraintSystem {
297 pub fn new() -> Self {
298 Self {
299 constraints: Vec::new(),
300 variables: BTreeMap::new(),
301 pub_inputs: Vec::new(),
302 pub_outputs: Vec::new(),
303 divine_inputs: Vec::new(),
304 num_variables: 0,
305 }
306 }
307
308 pub fn active_constraints(&self) -> usize {
310 self.constraints.iter().filter(|c| !c.is_trivial()).count()
311 }
312
313 pub fn violated_constraints(&self) -> Vec<&Constraint> {
315 self.constraints
316 .iter()
317 .filter(|c| c.is_violated())
318 .collect()
319 }
320
321 pub fn summary(&self) -> String {
323 format!(
324 "Variables: {}, Constraints: {} ({} active), Inputs: {} pub + {} divine, Outputs: {}",
325 self.num_variables,
326 self.constraints.len(),
327 self.active_constraints(),
328 self.pub_inputs.len(),
329 self.divine_inputs.len(),
330 self.pub_outputs.len(),
331 )
332 }
333}
334
/// Symbolically executes `file` with a fresh `SymExecutor` (via
/// `execute_file`) and returns the resulting constraint system.
///
/// NOTE(review): which entry point(s) `execute_file` starts from is defined in
/// the executor module, not here.
pub fn analyze(file: &File) -> ConstraintSystem {
    SymExecutor::new().execute_file(file)
}
341
/// Symbolically executes the single function named `fn_name` in `file` with a
/// fresh `SymExecutor` and returns its constraint system.
///
/// NOTE(review): behaviour when `fn_name` does not exist is decided by
/// `SymExecutor::execute_function` and is not visible here.
pub fn analyze_function(file: &File, fn_name: &str) -> ConstraintSystem {
    SymExecutor::new().execute_function(file, fn_name)
}
346
347pub fn analyze_all(file: &File) -> Vec<(String, ConstraintSystem)> {
350 let mut results = Vec::new();
351 for item in &file.items {
352 if let Item::Fn(func) = &item.node {
353 if func.body.is_some() && !func.is_test && func.intrinsic.is_none() {
354 let system = SymExecutor::new().execute_function(file, &func.name.node);
355 results.push((func.name.node.clone(), system));
356 }
357 }
358 }
359 results
360}
361
/// Outcome of `verify_file`: constraint statistics plus the list of
/// constraints that can never be satisfied.
#[derive(Clone, Debug)]
pub struct VerificationResult {
    /// Name of the verified file.
    pub name: String,
    /// Total number of constraints in the analyzed system.
    pub total_constraints: usize,
    /// Number of constraints that are not trivially satisfied.
    pub active_constraints: usize,
    /// `Debug` renderings of every always-violated constraint.
    pub violated: Vec<String>,
    /// Number of trivially satisfied constraints.
    pub redundant_count: usize,
    /// Output of `ConstraintSystem::summary` for the analyzed system.
    pub system_summary: String,
}
378
379impl VerificationResult {
380 pub fn is_safe(&self) -> bool {
381 self.violated.is_empty()
382 }
383
384 pub fn format_report(&self) -> String {
385 let mut report = String::new();
386 report.push_str(&format!("Verification: {}\n", self.name));
387 report.push_str(&format!(" {}\n", self.system_summary));
388 report.push_str(&format!(
389 " Constraints: {} total, {} active, {} redundant\n",
390 self.total_constraints, self.active_constraints, self.redundant_count,
391 ));
392 if self.violated.is_empty() {
393 report.push_str(" Status: SAFE (no trivially violated assertions)\n");
394 } else {
395 report.push_str(&format!(
396 " Status: VIOLATED ({} assertion(s) always fail)\n",
397 self.violated.len()
398 ));
399 for v in &self.violated {
400 report.push_str(&format!(" - {}\n", v));
401 }
402 }
403 report
404 }
405}
406
407pub fn verify_file(file: &File) -> VerificationResult {
410 let system = analyze(file);
411 let violated: Vec<String> = system
412 .violated_constraints()
413 .iter()
414 .map(|c| format!("{:?}", c))
415 .collect();
416 let redundant_count = system.constraints.iter().filter(|c| c.is_trivial()).count();
417
418 VerificationResult {
419 name: file.name.node.clone(),
420 total_constraints: system.constraints.len(),
421 active_constraints: system.active_constraints(),
422 violated,
423 redundant_count,
424 system_summary: system.summary(),
425 }
426}