1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
use std::collections::HashMap;

use itertools::Itertools;
use twenty_first::shared_math::b_field_element::BFieldElement;

use crate::dyn_malloc::DYN_MALLOC_ADDRESS;
use crate::snippet::Snippet;
use crate::snippet_state::SnippetState;
use crate::{
    exported_snippets, rust_shadowing_helper_functions, ExecutionState, VmOutputState,
    DIGEST_LENGTH,
};

/// Run the Rust-vs-TVM equivalence check on every input state that the snippet
/// generates through `gen_input_states`.
///
/// When `export_snippet` is set, the snippet is first looked up by its entrypoint
/// via `exported_snippets::name_to_snippet`, to verify that it is visible to the
/// outside. According to the original note on this check, that lookup panics when
/// the snippet is not registered; additionally the looked-up snippet must report
/// the same entrypoint as `snippet_struct`.
///
/// Returns one `VmOutputState` per generated input state, in generation order.
#[allow(dead_code)]
pub fn test_rust_equivalence_multiple<T: Snippet>(
    snippet_struct: &T,
    export_snippet: bool,
) -> Vec<VmOutputState> {
    if export_snippet {
        let entrypoint = snippet_struct.entrypoint();
        let exported = exported_snippets::name_to_snippet(&entrypoint);
        assert_eq!(
            entrypoint,
            exported.entrypoint(),
            "Looked up snippet must match self"
        );
    }

    // No expected final stack is supplied: only the agreement between the two
    // implementations is verified for each generated state.
    snippet_struct
        .gen_input_states()
        .into_iter()
        .map(|mut state| {
            test_rust_equivalence_given_input_values(
                snippet_struct,
                &state.stack,
                &state.std_in,
                &state.secret_in,
                &mut state.memory,
                state.words_allocated,
                None,
            )
        })
        .collect()
}

/// Run the Rust-vs-TVM equivalence check for a single, caller-provided
/// `ExecutionState`.
///
/// The state is taken by value, so whatever the check writes into the state's
/// memory is dropped together with the state when this function returns. No
/// expected final stack is supplied.
#[allow(dead_code)]
pub fn test_rust_equivalence_given_execution_state<T: Snippet>(
    snippet_struct: &T,
    mut execution_state: ExecutionState,
) -> VmOutputState {
    let words_allocated = execution_state.words_allocated;
    let memory = &mut execution_state.memory;

    test_rust_equivalence_given_input_values(
        snippet_struct,
        &execution_state.stack,
        &execution_state.std_in,
        &execution_state.secret_in,
        memory,
        words_allocated,
        None,
    )
}

/// Run the Rust-vs-TVM equivalence check from explicit input values.
///
/// Both implementations start from their own private copy of `stack` and of
/// `memory`; the actual comparison is delegated to
/// `test_rust_equivalence_given_input_values_and_initial_stacks_and_memories`,
/// which also receives `memory` itself and overwrites it with the memory
/// observed after the TVM run.
///
/// `expected_final_stack`, when given, is additionally compared against the
/// final TVM stack by the delegate.
#[allow(dead_code)]
pub fn test_rust_equivalence_given_input_values<T: Snippet>(
    snippet_struct: &T,
    stack: &[BFieldElement],
    stdin: &[BFieldElement],
    secret_in: &[BFieldElement],
    memory: &mut HashMap<BFieldElement, BFieldElement>,
    words_statically_allocated: usize,
    expected_final_stack: Option<&[BFieldElement]>,
) -> VmOutputState {
    // Independent working copies, one pair per implementation
    let (mut tasm_stack, mut rust_stack) = (stack.to_vec(), stack.to_vec());
    let (mut tasm_memory, mut rust_memory) = (memory.clone(), memory.clone());

    test_rust_equivalence_given_input_values_and_initial_stacks_and_memories(
        snippet_struct,
        stack,
        stdin,
        secret_in,
        memory,
        words_statically_allocated,
        expected_final_stack,
        &mut tasm_stack,
        &mut rust_stack,
        &mut tasm_memory,
        &mut rust_memory,
    )
}

/// Execute a snippet both through its Rust shadow and on the VM, and assert that
/// the two executions agree.
///
/// Steps, in order:
/// 1. If `words_statically_allocated > 0`, `rust_memory` is prepared with
///    `rust_dyn_malloc_initialize`.
/// 2. `rust_shadowing` is run on `rust_stack` / `rust_memory`.
/// 3. `link_and_run_tasm_for_test` is run on `tasm_stack` / `tasm_memory`.
/// 4. The two stacks are compared, ignoring their first `DIGEST_LENGTH` elements
///    (the program hash).
/// 5. If `expected_final_stack` is given, the TVM stack is compared against it
///    under the same rule.
/// 6. The two memories are compared after the entry at `DYN_MALLOC_ADDRESS` has
///    been removed from both.
/// 7. `memory` is overwritten with the resulting TVM memory, so individual tests
///    can probe it.
/// 8. The change in TVM stack height relative to `stack` is compared against
///    `outputs().len() - inputs().len()`.
///
/// Side effects visible to the caller: `tasm_stack`, `rust_stack`, `tasm_memory`
/// and `rust_memory` hold the post-execution state, with the `DYN_MALLOC_ADDRESS`
/// entry removed from both memories (and therefore from `memory` as well).
///
/// Returns the `VmOutputState` produced by the TVM run.
///
/// # Panics
///
/// Panics when any of the comparisons in steps 4, 5, 6 or 8 fails.
#[allow(dead_code)]
#[allow(clippy::ptr_arg)]
#[allow(clippy::too_many_arguments)]
pub fn test_rust_equivalence_given_input_values_and_initial_stacks_and_memories<T: Snippet>(
    snippet_struct: &T,
    stack: &[BFieldElement],
    stdin: &[BFieldElement],
    secret_in: &[BFieldElement],
    memory: &mut HashMap<BFieldElement, BFieldElement>,
    words_statically_allocated: usize,
    expected_final_stack: Option<&[BFieldElement]>,
    tasm_stack: &mut Vec<BFieldElement>,
    rust_stack: &mut Vec<BFieldElement>,
    tasm_memory: &mut HashMap<BFieldElement, BFieldElement>,
    rust_memory: &mut HashMap<BFieldElement, BFieldElement>,
) -> VmOutputState {
    // Comma-separated rendering shared by all failure messages below
    let join_elements = |elements: &[BFieldElement]| {
        elements.iter().map(ToString::to_string).join(",")
    };

    if words_statically_allocated > 0 {
        rust_shadowing_helper_functions::dyn_malloc::rust_dyn_malloc_initialize(
            rust_memory,
            words_statically_allocated,
        );
    }

    // run rust shadow
    snippet_struct.rust_shadowing(rust_stack, stdin.to_vec(), secret_in.to_vec(), rust_memory);

    // run tvm
    let vm_output_state = snippet_struct.link_and_run_tasm_for_test(
        tasm_stack,
        stdin.to_vec(),
        secret_in.to_vec(),
        tasm_memory,
        words_statically_allocated,
    );

    // assert stacks are equal, up to program hash. A stack shorter than the
    // program hash compares as empty, exactly like skipping past its end would.
    let tasm_stack_skip_program_hash = tasm_stack.get(DIGEST_LENGTH..).unwrap_or_default();
    let rust_stack_skip_program_hash = rust_stack.get(DIGEST_LENGTH..).unwrap_or_default();
    assert_eq!(
        tasm_stack_skip_program_hash,
        rust_stack_skip_program_hash,
        "Rust code must match TVM for `{}`\n\nTVM: {}\n\nRust: {}. Code was: {}",
        snippet_struct.entrypoint(),
        join_elements(tasm_stack_skip_program_hash),
        join_elements(rust_stack_skip_program_hash),
        snippet_struct.function_code(&mut SnippetState::default())
    );

    // if expected final stack is given, test against it
    if let Some(expected) = expected_final_stack {
        let expected_final_stack_skip_program_hash =
            expected.get(DIGEST_LENGTH..).unwrap_or_default();
        assert_eq!(
            tasm_stack_skip_program_hash,
            expected_final_stack_skip_program_hash,
            "TVM must produce expected stack `{}`. \n\nTVM:\n{}\nExpected:\n{}",
            snippet_struct.entrypoint(),
            join_elements(tasm_stack_skip_program_hash),
            join_elements(expected_final_stack_skip_program_hash),
        );
    }

    // Verify that memory behaves as expected, except for the dyn malloc initialization address which
    // is too cumbersome to monitor this way. Its behavior should be tested elsewhere.
    // Alternatively the rust shadowing trait function must take a `Library` argument as input
    // and statically allocate memory from there.
    // TODO: Check if we could perform this check on dyn malloc too
    let dyn_malloc_address = BFieldElement::new(DYN_MALLOC_ADDRESS as u64);
    rust_memory.remove(&dyn_malloc_address);
    tasm_memory.remove(&dyn_malloc_address);

    // All diagnostics are built only on the failure path.
    if *rust_memory != *tasm_memory {
        // Render a memory as `(address => value)` pairs, ordered by address.
        let render_sorted = |mem: &HashMap<BFieldElement, BFieldElement>| {
            let mut entries = mem.iter().collect_vec();
            entries.sort_unstable_by_key(|(address, _)| address.value());
            entries
                .iter()
                .map(|(address, value)| format!("({address} => {value})"))
                .join(",")
        };
        let tasm_mem_str = render_sorted(&*tasm_memory);
        let rust_mem_str = render_sorted(&*rust_memory);

        // Entries of either memory that are missing from, or different in, the other one
        let diff_str = rust_memory
            .iter()
            .filter(|(address, value)| tasm_memory.get(*address) != Some(*value))
            .chain(
                tasm_memory
                    .iter()
                    .filter(|(address, value)| rust_memory.get(*address) != Some(*value)),
            )
            .map(|(address, value)| format!("({address} => {value})"))
            .join(",");

        panic!(
            "Memory for both implementations must match after execution.\n\nTVM: {tasm_mem_str}\n\nRust: {rust_mem_str}\n\nDifference: {diff_str}\n\nCode was:\n\n {}",
            snippet_struct.function_code(&mut SnippetState::default())
        );
    }

    // Write back memory to be able to probe it in individual tests
    *memory = tasm_memory.clone();

    // Verify that stack grows with expected number of elements
    let observed_stack_growth: isize = tasm_stack.len() as isize - stack.len() as isize;
    let expected_stack_growth: isize =
        snippet_struct.outputs().len() as isize - snippet_struct.inputs().len() as isize;
    assert_eq!(
        expected_stack_growth,
        observed_stack_growth,
        "Stack must pop and push expected number of elements. Got input: {}\nGot output: {}",
        join_elements(stack),
        join_elements(tasm_stack.as_slice())
    );

    vm_output_state
}

#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use rand::random;
    use triton_vm::BFieldElement;
    use twenty_first::shared_math::tip5::DIGEST_LENGTH;

    use crate::{get_init_tvm_stack, hashing::sample_indices::SampleIndices, list::ListType};

    use super::test_rust_equivalence_given_input_values_and_initial_stacks_and_memories;

    /// Snippets are not standalone programs, so they have no well-defined program
    /// hash, whereas TIP6 places the program hash at the bottom of the stack.
    /// The equivalence check is therefore expected to compare the tasm and rust
    /// stacks only above those bottom `DIGEST_LENGTH` elements. This test
    /// randomizes exactly that region of the tasm stack and expects the check to
    /// pass regardless.
    #[test]
    fn test_program_hash_ignored() {
        let snippet_struct = SampleIndices {
            list_type: ListType::Safe,
        };

        let mut stack = get_init_tvm_stack();
        stack.extend([BFieldElement::new(45u64), BFieldElement::new(1u64 << 12)]);

        // Only the tasm stack gets a scrambled program-hash region; the rust
        // stack keeps the untouched initial values.
        let mut tasm_stack = stack.clone();
        tasm_stack
            .iter_mut()
            .take(DIGEST_LENGTH)
            .for_each(|element| *element = random());
        let mut rust_stack = stack.clone();

        let mut init_memory = HashMap::new();
        let mut tasm_memory = init_memory.clone();
        let mut rust_memory = init_memory.clone();

        test_rust_equivalence_given_input_values_and_initial_stacks_and_memories(
            &snippet_struct,
            &stack,
            &[],
            &[],
            &mut init_memory,
            1,
            None,
            &mut tasm_stack,
            &mut rust_stack,
            &mut tasm_memory,
            &mut rust_memory,
        );
    }
}