1use std::collections::HashMap;
2use std::fmt::Display;
3
4use itertools::Itertools;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::RustShadowOutputState;
10use crate::dyn_malloc::DYN_MALLOC_ADDRESS;
11use crate::execute_test;
12use crate::execute_with_terminal_state;
13use crate::prelude::Tip5;
14use crate::traits::basic_snippet::SignedOffSnippet;
15use crate::traits::rust_shadow::RustShadow;
16
17pub fn rust_final_state<T: RustShadow>(
18 shadowed_snippet: &T,
19 stack: &[BFieldElement],
20 stdin: &[BFieldElement],
21 nondeterminism: &NonDeterminism,
22 sponge: &Option<Tip5>,
23) -> RustShadowOutputState {
24 let mut rust_memory = nondeterminism.ram.clone();
25 let mut rust_stack = stack.to_vec();
26 let mut rust_sponge = sponge.clone();
27
28 let output = shadowed_snippet.rust_shadow_wrapper(
30 stdin,
31 nondeterminism,
32 &mut rust_stack,
33 &mut rust_memory,
34 &mut rust_sponge,
35 );
36
37 RustShadowOutputState {
38 public_output: output,
39 stack: rust_stack,
40 ram: rust_memory,
41 sponge: rust_sponge,
42 }
43}
44
45pub fn tasm_final_state<T: RustShadow>(
46 shadowed_snippet: &T,
47 stack: &[BFieldElement],
48 stdin: &[BFieldElement],
49 nondeterminism: NonDeterminism,
50 sponge: &Option<Tip5>,
51) -> VMState {
52 link_and_run_tasm_for_test(
54 shadowed_snippet,
55 &mut stack.to_vec(),
56 stdin.to_vec(),
57 nondeterminism,
58 sponge.to_owned(),
59 )
60}
61
62pub fn verify_stack_equivalence(
64 stack_a_name: &str,
65 stack_a: &[BFieldElement],
66 stack_b_name: &str,
67 stack_b: &[BFieldElement],
68) {
69 let stack_a_name = format!("{stack_a_name}:");
70 let stack_b_name = format!("{stack_b_name}:");
71 let max_stack_name_len = stack_a_name.len().max(stack_b_name.len());
72
73 let stack_a = &stack_a[Digest::LEN..];
74 let stack_b = &stack_b[Digest::LEN..];
75 let display = |stack: &[BFieldElement]| stack.iter().join(",");
76 assert_eq!(
77 stack_a,
78 stack_b,
79 "{stack_a_name} stack must match {stack_b_name} stack\n\n\
80 {stack_a_name:<max_stack_name_len$} {}\n\n\
81 {stack_b_name:<max_stack_name_len$} {}",
82 display(stack_a),
83 display(stack_b),
84 );
85}
86
87pub(crate) fn verify_memory_equivalence(
89 a_name: &str,
90 a_memory: &HashMap<BFieldElement, BFieldElement>,
91 b_name: &str,
92 b_memory: &HashMap<BFieldElement, BFieldElement>,
93) {
94 let memory_without_dyn_malloc = |mem: HashMap<_, _>| -> HashMap<_, _> {
95 mem.into_iter()
96 .filter(|&(k, _)| k != DYN_MALLOC_ADDRESS)
97 .collect()
98 };
99 let a_memory = memory_without_dyn_malloc(a_memory.clone());
100 let b_memory = memory_without_dyn_malloc(b_memory.clone());
101 if a_memory == b_memory {
102 return;
103 }
104
105 fn format_hash_map_iterator<K, V>(map: impl Iterator<Item = (K, V)>) -> String
106 where
107 u64: From<K>,
108 K: Copy + Display,
109 V: Display,
110 {
111 map.sorted_by_key(|(k, _)| u64::from(*k))
112 .map(|(k, v)| format!("({k} => {v})"))
113 .join(", ")
114 }
115
116 let in_a_and_different_in_b = a_memory
117 .iter()
118 .filter(|&(k, v)| b_memory.get(k).map(|b| b != v).unwrap_or(true));
119 let in_b_and_different_in_a = b_memory
120 .iter()
121 .filter(|&(k, v)| a_memory.get(k).map(|b| b != v).unwrap_or(true));
122
123 let in_a_and_different_in_b = format_hash_map_iterator(in_a_and_different_in_b);
124 let in_b_and_different_in_a = format_hash_map_iterator(in_b_and_different_in_a);
125
126 panic!(
127 "Memory for both implementations must match after execution.\n\n\
128 In {b_name}, different in {a_name}: {in_b_and_different_in_a}\n\n\
129 In {a_name}, different in {b_name}: {in_a_and_different_in_b}"
130 );
131}
132
133pub fn verify_stack_growth<T: RustShadow>(
134 shadowed_snippet: &T,
135 initial_stack: &[BFieldElement],
136 final_stack: &[BFieldElement],
137) {
138 let observed_stack_growth: isize = final_stack.len() as isize - initial_stack.len() as isize;
139 let expected_stack_growth: isize = shadowed_snippet.inner().stack_diff();
140 assert_eq!(
141 expected_stack_growth,
142 observed_stack_growth,
143 "Stack must pop and push expected number of elements. Got input: {}\nGot output: {}",
144 initial_stack.iter().map(|x| x.to_string()).join(","),
145 final_stack.iter().map(|x| x.to_string()).join(",")
146 );
147}
148
149pub fn verify_sponge_equivalence(a: &Option<Tip5>, b: &Option<Tip5>) {
150 match (a, b) {
151 (Some(state_a), Some(state_b)) => assert_eq!(state_a.state, state_b.state),
152 (None, None) => (),
153 _ => panic!("{a:?} != {b:?}"),
154 };
155}
156
157pub fn test_rust_equivalence_given_complete_state<T: RustShadow>(
158 shadowed_snippet: &T,
159 stack: &[BFieldElement],
160 stdin: &[BFieldElement],
161 nondeterminism: &NonDeterminism,
162 sponge: &Option<Tip5>,
163 expected_final_stack: Option<&[BFieldElement]>,
164) -> VMState {
165 shadowed_snippet
166 .inner()
167 .assert_all_sign_offs_are_up_to_date();
168
169 let init_stack = stack.to_vec();
170
171 let rust = rust_final_state(shadowed_snippet, stack, stdin, nondeterminism, sponge);
172
173 let tasm = tasm_final_state(
175 shadowed_snippet,
176 stack,
177 stdin,
178 nondeterminism.clone(),
179 sponge,
180 );
181
182 assert_eq!(
183 rust.public_output, tasm.public_output,
184 "Rust shadowing and VM std out must agree"
185 );
186
187 verify_stack_equivalence(
188 "rust-shadow final stack",
189 &rust.stack,
190 "TASM final stack",
191 &tasm.op_stack.stack,
192 );
193 if let Some(expected) = expected_final_stack {
194 verify_stack_equivalence("expected", expected, "actual", &rust.stack);
195 }
196 verify_memory_equivalence("Rust-shadow", &rust.ram, "TVM", &tasm.ram);
197 verify_stack_growth(shadowed_snippet, &init_stack, &tasm.op_stack.stack);
198
199 tasm
200}
201
202pub fn link_and_run_tasm_for_test<T: RustShadow>(
203 snippet_struct: &T,
204 stack: &mut Vec<BFieldElement>,
205 std_in: Vec<BFieldElement>,
206 nondeterminism: NonDeterminism,
207 maybe_sponge: Option<Tip5>,
208) -> VMState {
209 let code = snippet_struct.inner().link_for_isolated_run();
210
211 execute_test(
212 &code,
213 stack,
214 snippet_struct.inner().stack_diff(),
215 std_in,
216 nondeterminism,
217 maybe_sponge,
218 )
219}
220
221pub fn test_rust_equivalence_given_execution_state<S: RustShadow>(
222 snippet_struct: &S,
223 execution_state: InitVmState,
224) -> VMState {
225 test_rust_equivalence_given_complete_state::<S>(
226 snippet_struct,
227 &execution_state.stack,
228 &execution_state.public_input,
229 &execution_state.nondeterminism,
230 &execution_state.sponge,
231 None,
232 )
233}
234
235pub fn negative_test<T: RustShadow>(
236 snippet: &T,
237 initial_state: InitVmState,
238 allowed_errors: &[InstructionError],
239) {
240 let err = instruction_error_from_failing_code(snippet, initial_state);
241 assert!(
242 allowed_errors.contains(&err),
243 "Triton VM execution must fail with one of the expected errors:\n- {}\n\n Got:\n{err}",
244 allowed_errors.iter().join("\n- ")
245 );
246}
247
248pub fn test_assertion_failure<S: RustShadow>(
249 snippet_struct: &S,
250 initial_state: InitVmState,
251 expected_error_ids: &[i128],
252) {
253 let err = instruction_error_from_failing_code(snippet_struct, initial_state);
254 let maybe_error_id = match err {
255 InstructionError::AssertionFailed(err)
256 | InstructionError::VectorAssertionFailed(_, err) => err.id,
257 _ => panic!("Triton VM execution failed, but not due to an assertion. Instead, got: {err}"),
258 };
259 let error_id = maybe_error_id.expect(
260 "Triton VM execution failed due to unfulfilled assertion, but that assertion has no \
261 error ID. See `tasm-lib/src/assertion_error_ids.md` to grab a unique ID.",
262 );
263 let expected_error_ids_str = expected_error_ids.iter().join(", ");
264 assert!(
265 expected_error_ids.contains(&error_id),
266 "error ID {error_id} ∉ {{{expected_error_ids_str}}}\nTriton VM execution failed due to \
267 unfulfilled assertion with error ID {error_id}, but expected one of the following IDs: \
268 {{{expected_error_ids_str}}}",
269 );
270}
271
272fn instruction_error_from_failing_code<S: RustShadow>(
273 snippet: &S,
274 init_state: InitVmState,
275) -> InstructionError {
276 let rust_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
278 let mut rust_stack = init_state.stack.clone();
279 let mut rust_memory = init_state.nondeterminism.ram.clone();
280 let mut rust_sponge = init_state.sponge.clone();
281 snippet.rust_shadow_wrapper(
282 &init_state.public_input,
283 &init_state.nondeterminism,
284 &mut rust_stack,
285 &mut rust_memory,
286 &mut rust_sponge,
287 )
288 }));
289 assert!(
290 rust_result.is_err(),
291 "Failed to fail: Rust-shadowing must panic in negative test case"
292 );
293
294 let code = snippet.inner().link_for_isolated_run();
295 let tvm_result = execute_with_terminal_state(
296 Program::new(&code),
297 &init_state.public_input,
298 &init_state.stack,
299 &init_state.nondeterminism,
300 init_state.sponge,
301 );
302
303 tvm_result.expect_err("Failed to fail: Triton VM execution must crash in negative test case")
304}
305
306pub fn prepend_program_with_stack_setup(
307 init_stack: &[BFieldElement],
308 program: &Program,
309) -> Program {
310 let stack_initialization_code = init_stack
311 .iter()
312 .skip(NUM_OP_STACK_REGISTERS)
313 .map(|&word| triton_instr!(push word))
314 .collect_vec();
315
316 Program::new(&[stack_initialization_code, program.labelled_instructions()].concat())
317}
318
319pub fn prepend_program_with_sponge_init(program: &Program) -> Program {
320 Program::new(&[triton_asm!(sponge_init), program.labelled_instructions()].concat())
321}
322
323pub fn maybe_write_tvm_output_to_disk(
326 stark: &Stark,
327 claim: &triton_vm::proof::Claim,
328 proof: &Proof,
329) {
330 use std::io::Write;
331 let Ok(_) = std::env::var("TASMLIB_STORE") else {
332 return;
333 };
334
335 let mut stark_file = std::fs::File::create("stark.json").unwrap();
336 let state = serde_json::to_string(stark).unwrap();
337 write!(stark_file, "{state}").unwrap();
338 let mut claim_file = std::fs::File::create("claim.json").unwrap();
339 let claim = serde_json::to_string(claim).unwrap();
340 write!(claim_file, "{claim}").unwrap();
341 let mut proof_file = std::fs::File::create("proof.json").unwrap();
342 let proof = serde_json::to_string(proof).unwrap();
343 write!(proof_file, "{proof}").unwrap();
344}