Skip to main content

tasm_lib/
test_helpers.rs

1use std::collections::HashMap;
2use std::fmt::Display;
3
4use itertools::Itertools;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::RustShadowOutputState;
10use crate::dyn_malloc::DYN_MALLOC_ADDRESS;
11use crate::execute_test;
12use crate::execute_with_terminal_state;
13use crate::prelude::Tip5;
14use crate::traits::basic_snippet::SignedOffSnippet;
15use crate::traits::rust_shadow::RustShadow;
16use crate::traits::rust_shadow::RustShadowError;
17
18pub fn rust_final_state<T: RustShadow>(
19    shadowed_snippet: &T,
20    stack: &[BFieldElement],
21    stdin: &[BFieldElement],
22    nondeterminism: &NonDeterminism,
23    sponge: &Option<Tip5>,
24) -> RustShadowOutputState {
25    let mut rust_memory = nondeterminism.ram.clone();
26    let mut rust_stack = stack.to_vec();
27    let mut rust_sponge = sponge.clone();
28
29    // run rust shadow
30    let output = shadowed_snippet
31        .rust_shadow_wrapper(
32            stdin,
33            nondeterminism,
34            &mut rust_stack,
35            &mut rust_memory,
36            &mut rust_sponge,
37        )
38        .unwrap();
39
40    RustShadowOutputState {
41        public_output: output,
42        stack: rust_stack,
43        ram: rust_memory,
44        sponge: rust_sponge,
45    }
46}
47
48pub fn tasm_final_state<T: RustShadow>(
49    shadowed_snippet: &T,
50    stack: &[BFieldElement],
51    stdin: &[BFieldElement],
52    nondeterminism: NonDeterminism,
53    sponge: &Option<Tip5>,
54) -> Result<VMState, RustShadowError> {
55    // run tvm
56    link_and_run_tasm_for_test(
57        shadowed_snippet,
58        &mut stack.to_vec(),
59        stdin.to_vec(),
60        nondeterminism,
61        sponge.to_owned(),
62    )
63}
64
65/// assert stacks are equal, up to program hash
66pub fn verify_stack_equivalence(
67    stack_a_name: &str,
68    stack_a: &[BFieldElement],
69    stack_b_name: &str,
70    stack_b: &[BFieldElement],
71) {
72    let stack_a_name = format!("{stack_a_name}:");
73    let stack_b_name = format!("{stack_b_name}:");
74    let max_stack_name_len = stack_a_name.len().max(stack_b_name.len());
75
76    let stack_a = &stack_a[Digest::LEN..];
77    let stack_b = &stack_b[Digest::LEN..];
78    let display = |stack: &[BFieldElement]| stack.iter().join(",");
79    assert_eq!(
80        stack_a,
81        stack_b,
82        "{stack_a_name} stack must match {stack_b_name} stack\n\n\
83         {stack_a_name:<max_stack_name_len$} {}\n\n\
84         {stack_b_name:<max_stack_name_len$} {}",
85        display(stack_a),
86        display(stack_b),
87    );
88}
89
90/// Verify equivalence of memory up to the value of dynamic allocator.
91pub(crate) fn verify_memory_equivalence(
92    a_name: &str,
93    a_memory: &HashMap<BFieldElement, BFieldElement>,
94    b_name: &str,
95    b_memory: &HashMap<BFieldElement, BFieldElement>,
96) {
97    let memory_without_dyn_malloc = |mem: HashMap<_, _>| -> HashMap<_, _> {
98        mem.into_iter()
99            .filter(|&(k, _)| k != DYN_MALLOC_ADDRESS)
100            .collect()
101    };
102    let a_memory = memory_without_dyn_malloc(a_memory.clone());
103    let b_memory = memory_without_dyn_malloc(b_memory.clone());
104    if a_memory == b_memory {
105        return;
106    }
107
108    fn format_hash_map_iterator<K, V>(map: impl Iterator<Item = (K, V)>) -> String
109    where
110        u64: From<K>,
111        K: Copy + Display,
112        V: Display,
113    {
114        map.sorted_by_key(|(k, _)| u64::from(*k))
115            .map(|(k, v)| format!("({k} => {v})"))
116            .join(", ")
117    }
118
119    let in_a_and_different_in_b = a_memory
120        .iter()
121        .filter(|&(k, v)| b_memory.get(k).map(|b| b != v).unwrap_or(true));
122    let in_b_and_different_in_a = b_memory
123        .iter()
124        .filter(|&(k, v)| a_memory.get(k).map(|b| b != v).unwrap_or(true));
125
126    let in_a_and_different_in_b = format_hash_map_iterator(in_a_and_different_in_b);
127    let in_b_and_different_in_a = format_hash_map_iterator(in_b_and_different_in_a);
128
129    panic!(
130        "Memory for both implementations must match after execution.\n\n\
131        In {b_name}, different in {a_name}: {in_b_and_different_in_a}\n\n\
132        In {a_name}, different in {b_name}: {in_a_and_different_in_b}"
133    );
134}
135
136pub fn verify_stack_growth<T: RustShadow>(
137    shadowed_snippet: &T,
138    initial_stack: &[BFieldElement],
139    final_stack: &[BFieldElement],
140) {
141    let observed_stack_growth: isize = final_stack.len() as isize - initial_stack.len() as isize;
142    let expected_stack_growth: isize = shadowed_snippet.inner().stack_diff();
143    assert_eq!(
144        expected_stack_growth,
145        observed_stack_growth,
146        "Stack must pop and push expected number of elements. Got input: {}\nGot output: {}",
147        initial_stack.iter().map(|x| x.to_string()).join(","),
148        final_stack.iter().map(|x| x.to_string()).join(",")
149    );
150}
151
152pub fn verify_sponge_equivalence(a: &Option<Tip5>, b: &Option<Tip5>) {
153    match (a, b) {
154        (Some(state_a), Some(state_b)) => assert_eq!(state_a.state, state_b.state),
155        (None, None) => (),
156        _ => panic!("{a:?} != {b:?}"),
157    };
158}
159
160pub fn test_rust_equivalence_given_complete_state<T: RustShadow>(
161    shadowed_snippet: &T,
162    stack: &[BFieldElement],
163    stdin: &[BFieldElement],
164    nondeterminism: &NonDeterminism,
165    sponge: &Option<Tip5>,
166    expected_final_stack: Option<&[BFieldElement]>,
167) -> VMState {
168    shadowed_snippet
169        .inner()
170        .assert_all_sign_offs_are_up_to_date();
171
172    let init_stack = stack.to_vec();
173
174    let rust = rust_final_state(shadowed_snippet, stack, stdin, nondeterminism, sponge);
175
176    // run tvm
177    let tasm = tasm_final_state(
178        shadowed_snippet,
179        stack,
180        stdin,
181        nondeterminism.clone(),
182        sponge,
183    )
184    .unwrap();
185
186    assert_eq!(
187        rust.public_output, tasm.public_output,
188        "Rust shadowing and VM std out must agree"
189    );
190
191    verify_stack_equivalence(
192        "rust-shadow final stack",
193        &rust.stack,
194        "TASM final stack",
195        &tasm.op_stack.stack,
196    );
197    if let Some(expected) = expected_final_stack {
198        verify_stack_equivalence("expected", expected, "actual", &rust.stack);
199    }
200    verify_memory_equivalence("Rust-shadow", &rust.ram, "TVM", &tasm.ram);
201    verify_stack_growth(shadowed_snippet, &init_stack, &tasm.op_stack.stack);
202
203    tasm
204}
205
206pub fn link_and_run_tasm_for_test<T: RustShadow>(
207    snippet_struct: &T,
208    stack: &mut Vec<BFieldElement>,
209    std_in: Vec<BFieldElement>,
210    nondeterminism: NonDeterminism,
211    maybe_sponge: Option<Tip5>,
212) -> Result<VMState, RustShadowError> {
213    let code = snippet_struct.inner().link_for_isolated_run();
214
215    execute_test(
216        &code,
217        stack,
218        snippet_struct.inner().stack_diff(),
219        std_in,
220        nondeterminism,
221        maybe_sponge,
222    )
223}
224
225pub fn test_rust_equivalence_given_execution_state<S: RustShadow>(
226    snippet_struct: &S,
227    execution_state: InitVmState,
228) -> VMState {
229    test_rust_equivalence_given_complete_state::<S>(
230        snippet_struct,
231        &execution_state.stack,
232        &execution_state.public_input,
233        &execution_state.nondeterminism,
234        &execution_state.sponge,
235        None,
236    )
237}
238
239pub fn negative_test<T: RustShadow>(
240    snippet: &T,
241    initial_state: InitVmState,
242    allowed_errors: &[InstructionError],
243) {
244    let err = instruction_error_from_failing_code(snippet, initial_state);
245    assert!(
246        allowed_errors.contains(&err),
247        "Triton VM execution must fail with one of the expected errors:\n- {}\n\n Got:\n{err}",
248        allowed_errors.iter().join("\n- ")
249    );
250}
251
252pub fn test_assertion_failure<S: RustShadow>(
253    snippet_struct: &S,
254    initial_state: InitVmState,
255    expected_error_ids: &[i128],
256) {
257    let err = instruction_error_from_failing_code(snippet_struct, initial_state);
258    let maybe_error_id = match err {
259        InstructionError::AssertionFailed(err)
260        | InstructionError::VectorAssertionFailed(_, err) => err.id,
261        _ => panic!("Triton VM execution failed, but not due to an assertion. Instead, got: {err}"),
262    };
263    let error_id = maybe_error_id.expect(
264        "Triton VM execution failed due to unfulfilled assertion, but that assertion has no \
265        error ID. See `tasm-lib/src/assertion_error_ids.md` to grab a unique ID.",
266    );
267    let expected_error_ids_str = expected_error_ids.iter().join(", ");
268    assert!(
269        expected_error_ids.contains(&error_id),
270        "error ID {error_id} ∉ {{{expected_error_ids_str}}}\nTriton VM execution failed due to \
271         unfulfilled assertion with error ID {error_id}, but expected one of the following IDs: \
272         {{{expected_error_ids_str}}}",
273    );
274}
275
276fn instruction_error_from_failing_code<S: RustShadow>(
277    snippet: &S,
278    init_state: InitVmState,
279) -> InstructionError {
280    let mut rust_stack = init_state.stack.clone();
281    let mut rust_memory = init_state.nondeterminism.ram.clone();
282    let mut rust_sponge = init_state.sponge.clone();
283    let rust_result = snippet.rust_shadow_wrapper(
284        &init_state.public_input,
285        &init_state.nondeterminism,
286        &mut rust_stack,
287        &mut rust_memory,
288        &mut rust_sponge,
289    );
290    rust_result.expect_err("Failed to fail: Rust-shadowing must panic in negative test case");
291
292    let code = snippet.inner().link_for_isolated_run();
293    let tvm_result = execute_with_terminal_state(
294        Program::new(&code),
295        &init_state.public_input,
296        &init_state.stack,
297        &init_state.nondeterminism,
298        init_state.sponge,
299    );
300    tvm_result.expect_err("Failed to fail: Triton VM execution must crash in negative test case")
301}
302
303pub fn prepend_program_with_stack_setup(
304    init_stack: &[BFieldElement],
305    program: &Program,
306) -> Program {
307    let stack_initialization_code = init_stack
308        .iter()
309        .skip(NUM_OP_STACK_REGISTERS)
310        .map(|&word| triton_instr!(push word))
311        .collect_vec();
312
313    Program::new(&[stack_initialization_code, program.labelled_instructions()].concat())
314}
315
316pub fn prepend_program_with_sponge_init(program: &Program) -> Program {
317    Program::new(&[triton_asm!(sponge_init), program.labelled_instructions()].concat())
318}
319
320/// Store the output from Triton VM's `proof` function as files, such that a deterministic
321/// proof can be used for debugging purposes.
322pub fn maybe_write_tvm_output_to_disk(stark: &Stark, claim: &Claim, proof: &Proof) {
323    use std::io::Write;
324    let Ok(_) = std::env::var("TASMLIB_STORE") else {
325        return;
326    };
327
328    let mut stark_file = std::fs::File::create("stark.json").unwrap();
329    let state = serde_json::to_string(stark).unwrap();
330    write!(stark_file, "{state}").unwrap();
331    let mut claim_file = std::fs::File::create("claim.json").unwrap();
332    let claim = serde_json::to_string(claim).unwrap();
333    write!(claim_file, "{claim}").unwrap();
334    let mut proof_file = std::fs::File::create("proof.json").unwrap();
335    let proof = serde_json::to_string(proof).unwrap();
336    write!(proof_file, "{proof}").unwrap();
337}