tasm_lib/
test_helpers.rs

1use std::collections::HashMap;
2use std::fmt::Display;
3
4use itertools::Itertools;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::RustShadowOutputState;
10use crate::dyn_malloc::DYN_MALLOC_ADDRESS;
11use crate::execute_test;
12use crate::execute_with_terminal_state;
13use crate::prelude::Tip5;
14use crate::traits::basic_snippet::SignedOffSnippet;
15use crate::traits::rust_shadow::RustShadow;
16
17pub fn rust_final_state<T: RustShadow>(
18    shadowed_snippet: &T,
19    stack: &[BFieldElement],
20    stdin: &[BFieldElement],
21    nondeterminism: &NonDeterminism,
22    sponge: &Option<Tip5>,
23) -> RustShadowOutputState {
24    let mut rust_memory = nondeterminism.ram.clone();
25    let mut rust_stack = stack.to_vec();
26    let mut rust_sponge = sponge.clone();
27
28    // run rust shadow
29    let output = shadowed_snippet.rust_shadow_wrapper(
30        stdin,
31        nondeterminism,
32        &mut rust_stack,
33        &mut rust_memory,
34        &mut rust_sponge,
35    );
36
37    RustShadowOutputState {
38        public_output: output,
39        stack: rust_stack,
40        ram: rust_memory,
41        sponge: rust_sponge,
42    }
43}
44
45pub fn tasm_final_state<T: RustShadow>(
46    shadowed_snippet: &T,
47    stack: &[BFieldElement],
48    stdin: &[BFieldElement],
49    nondeterminism: NonDeterminism,
50    sponge: &Option<Tip5>,
51) -> VMState {
52    // run tvm
53    link_and_run_tasm_for_test(
54        shadowed_snippet,
55        &mut stack.to_vec(),
56        stdin.to_vec(),
57        nondeterminism,
58        sponge.to_owned(),
59    )
60}
61
62/// assert stacks are equal, up to program hash
63pub fn verify_stack_equivalence(
64    stack_a_name: &str,
65    stack_a: &[BFieldElement],
66    stack_b_name: &str,
67    stack_b: &[BFieldElement],
68) {
69    let stack_a_name = format!("{stack_a_name}:");
70    let stack_b_name = format!("{stack_b_name}:");
71    let max_stack_name_len = stack_a_name.len().max(stack_b_name.len());
72
73    let stack_a = &stack_a[Digest::LEN..];
74    let stack_b = &stack_b[Digest::LEN..];
75    let display = |stack: &[BFieldElement]| stack.iter().join(",");
76    assert_eq!(
77        stack_a,
78        stack_b,
79        "{stack_a_name} stack must match {stack_b_name} stack\n\n\
80         {stack_a_name:<max_stack_name_len$} {}\n\n\
81         {stack_b_name:<max_stack_name_len$} {}",
82        display(stack_a),
83        display(stack_b),
84    );
85}
86
87/// Verify equivalence of memory up to the value of dynamic allocator.
88pub(crate) fn verify_memory_equivalence(
89    a_name: &str,
90    a_memory: &HashMap<BFieldElement, BFieldElement>,
91    b_name: &str,
92    b_memory: &HashMap<BFieldElement, BFieldElement>,
93) {
94    let memory_without_dyn_malloc = |mem: HashMap<_, _>| -> HashMap<_, _> {
95        mem.into_iter()
96            .filter(|&(k, _)| k != DYN_MALLOC_ADDRESS)
97            .collect()
98    };
99    let a_memory = memory_without_dyn_malloc(a_memory.clone());
100    let b_memory = memory_without_dyn_malloc(b_memory.clone());
101    if a_memory == b_memory {
102        return;
103    }
104
105    fn format_hash_map_iterator<K, V>(map: impl Iterator<Item = (K, V)>) -> String
106    where
107        u64: From<K>,
108        K: Copy + Display,
109        V: Display,
110    {
111        map.sorted_by_key(|(k, _)| u64::from(*k))
112            .map(|(k, v)| format!("({k} => {v})"))
113            .join(", ")
114    }
115
116    let in_a_and_different_in_b = a_memory
117        .iter()
118        .filter(|&(k, v)| b_memory.get(k).map(|b| b != v).unwrap_or(true));
119    let in_b_and_different_in_a = b_memory
120        .iter()
121        .filter(|&(k, v)| a_memory.get(k).map(|b| b != v).unwrap_or(true));
122
123    let in_a_and_different_in_b = format_hash_map_iterator(in_a_and_different_in_b);
124    let in_b_and_different_in_a = format_hash_map_iterator(in_b_and_different_in_a);
125
126    panic!(
127        "Memory for both implementations must match after execution.\n\n\
128        In {b_name}, different in {a_name}: {in_b_and_different_in_a}\n\n\
129        In {a_name}, different in {b_name}: {in_a_and_different_in_b}"
130    );
131}
132
133pub fn verify_stack_growth<T: RustShadow>(
134    shadowed_snippet: &T,
135    initial_stack: &[BFieldElement],
136    final_stack: &[BFieldElement],
137) {
138    let observed_stack_growth: isize = final_stack.len() as isize - initial_stack.len() as isize;
139    let expected_stack_growth: isize = shadowed_snippet.inner().stack_diff();
140    assert_eq!(
141        expected_stack_growth,
142        observed_stack_growth,
143        "Stack must pop and push expected number of elements. Got input: {}\nGot output: {}",
144        initial_stack.iter().map(|x| x.to_string()).join(","),
145        final_stack.iter().map(|x| x.to_string()).join(",")
146    );
147}
148
149pub fn verify_sponge_equivalence(a: &Option<Tip5>, b: &Option<Tip5>) {
150    match (a, b) {
151        (Some(state_a), Some(state_b)) => assert_eq!(state_a.state, state_b.state),
152        (None, None) => (),
153        _ => panic!("{a:?} != {b:?}"),
154    };
155}
156
157pub fn test_rust_equivalence_given_complete_state<T: RustShadow>(
158    shadowed_snippet: &T,
159    stack: &[BFieldElement],
160    stdin: &[BFieldElement],
161    nondeterminism: &NonDeterminism,
162    sponge: &Option<Tip5>,
163    expected_final_stack: Option<&[BFieldElement]>,
164) -> VMState {
165    shadowed_snippet
166        .inner()
167        .assert_all_sign_offs_are_up_to_date();
168
169    let init_stack = stack.to_vec();
170
171    let rust = rust_final_state(shadowed_snippet, stack, stdin, nondeterminism, sponge);
172
173    // run tvm
174    let tasm = tasm_final_state(
175        shadowed_snippet,
176        stack,
177        stdin,
178        nondeterminism.clone(),
179        sponge,
180    );
181
182    assert_eq!(
183        rust.public_output, tasm.public_output,
184        "Rust shadowing and VM std out must agree"
185    );
186
187    verify_stack_equivalence(
188        "rust-shadow final stack",
189        &rust.stack,
190        "TASM final stack",
191        &tasm.op_stack.stack,
192    );
193    if let Some(expected) = expected_final_stack {
194        verify_stack_equivalence("expected", expected, "actual", &rust.stack);
195    }
196    verify_memory_equivalence("Rust-shadow", &rust.ram, "TVM", &tasm.ram);
197    verify_stack_growth(shadowed_snippet, &init_stack, &tasm.op_stack.stack);
198
199    tasm
200}
201
202pub fn link_and_run_tasm_for_test<T: RustShadow>(
203    snippet_struct: &T,
204    stack: &mut Vec<BFieldElement>,
205    std_in: Vec<BFieldElement>,
206    nondeterminism: NonDeterminism,
207    maybe_sponge: Option<Tip5>,
208) -> VMState {
209    let code = snippet_struct.inner().link_for_isolated_run();
210
211    execute_test(
212        &code,
213        stack,
214        snippet_struct.inner().stack_diff(),
215        std_in,
216        nondeterminism,
217        maybe_sponge,
218    )
219}
220
221pub fn test_rust_equivalence_given_execution_state<S: RustShadow>(
222    snippet_struct: &S,
223    execution_state: InitVmState,
224) -> VMState {
225    test_rust_equivalence_given_complete_state::<S>(
226        snippet_struct,
227        &execution_state.stack,
228        &execution_state.public_input,
229        &execution_state.nondeterminism,
230        &execution_state.sponge,
231        None,
232    )
233}
234
235pub fn negative_test<T: RustShadow>(
236    snippet: &T,
237    initial_state: InitVmState,
238    allowed_errors: &[InstructionError],
239) {
240    let err = instruction_error_from_failing_code(snippet, initial_state);
241    assert!(
242        allowed_errors.contains(&err),
243        "Triton VM execution must fail with one of the expected errors:\n- {}\n\n Got:\n{err}",
244        allowed_errors.iter().join("\n- ")
245    );
246}
247
248pub fn test_assertion_failure<S: RustShadow>(
249    snippet_struct: &S,
250    initial_state: InitVmState,
251    expected_error_ids: &[i128],
252) {
253    let err = instruction_error_from_failing_code(snippet_struct, initial_state);
254    let maybe_error_id = match err {
255        InstructionError::AssertionFailed(err)
256        | InstructionError::VectorAssertionFailed(_, err) => err.id,
257        _ => panic!("Triton VM execution failed, but not due to an assertion. Instead, got: {err}"),
258    };
259    let error_id = maybe_error_id.expect(
260        "Triton VM execution failed due to unfulfilled assertion, but that assertion has no \
261        error ID. See `tasm-lib/src/assertion_error_ids.md` to grab a unique ID.",
262    );
263    let expected_error_ids_str = expected_error_ids.iter().join(", ");
264    assert!(
265        expected_error_ids.contains(&error_id),
266        "error ID {error_id} ∉ {{{expected_error_ids_str}}}\nTriton VM execution failed due to \
267         unfulfilled assertion with error ID {error_id}, but expected one of the following IDs: \
268         {{{expected_error_ids_str}}}",
269    );
270}
271
272fn instruction_error_from_failing_code<S: RustShadow>(
273    snippet: &S,
274    init_state: InitVmState,
275) -> InstructionError {
276    // `AssertUnwindSafe` is fine because the caught panic is discarded immediately
277    let rust_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
278        let mut rust_stack = init_state.stack.clone();
279        let mut rust_memory = init_state.nondeterminism.ram.clone();
280        let mut rust_sponge = init_state.sponge.clone();
281        snippet.rust_shadow_wrapper(
282            &init_state.public_input,
283            &init_state.nondeterminism,
284            &mut rust_stack,
285            &mut rust_memory,
286            &mut rust_sponge,
287        )
288    }));
289    assert!(
290        rust_result.is_err(),
291        "Failed to fail: Rust-shadowing must panic in negative test case"
292    );
293
294    let code = snippet.inner().link_for_isolated_run();
295    let tvm_result = execute_with_terminal_state(
296        Program::new(&code),
297        &init_state.public_input,
298        &init_state.stack,
299        &init_state.nondeterminism,
300        init_state.sponge,
301    );
302
303    tvm_result.expect_err("Failed to fail: Triton VM execution must crash in negative test case")
304}
305
306pub fn prepend_program_with_stack_setup(
307    init_stack: &[BFieldElement],
308    program: &Program,
309) -> Program {
310    let stack_initialization_code = init_stack
311        .iter()
312        .skip(NUM_OP_STACK_REGISTERS)
313        .map(|&word| triton_instr!(push word))
314        .collect_vec();
315
316    Program::new(&[stack_initialization_code, program.labelled_instructions()].concat())
317}
318
319pub fn prepend_program_with_sponge_init(program: &Program) -> Program {
320    Program::new(&[triton_asm!(sponge_init), program.labelled_instructions()].concat())
321}
322
323/// Store the output from Triton VM's `proof` function as files, such that a deterministic
324/// proof can be used for debugging purposes.
325pub fn maybe_write_tvm_output_to_disk(
326    stark: &Stark,
327    claim: &triton_vm::proof::Claim,
328    proof: &Proof,
329) {
330    use std::io::Write;
331    let Ok(_) = std::env::var("TASMLIB_STORE") else {
332        return;
333    };
334
335    let mut stark_file = std::fs::File::create("stark.json").unwrap();
336    let state = serde_json::to_string(stark).unwrap();
337    write!(stark_file, "{state}").unwrap();
338    let mut claim_file = std::fs::File::create("claim.json").unwrap();
339    let claim = serde_json::to_string(claim).unwrap();
340    write!(claim_file, "{claim}").unwrap();
341    let mut proof_file = std::fs::File::create("proof.json").unwrap();
342    let proof = serde_json::to_string(proof).unwrap();
343    write!(proof_file, "{proof}").unwrap();
344}