tasm_lib/traits/
mem_preserver.rs1use std::collections::HashMap;
2use std::collections::VecDeque;
3
4use rand::prelude::*;
5use triton_vm::prelude::*;
6
7use super::basic_snippet::BasicSnippet;
8use super::rust_shadow::RustShadow;
9use crate::InitVmState;
10use crate::linker::execute_bench;
11use crate::snippet_bencher::BenchmarkCase;
12use crate::snippet_bencher::NamedBenchmarkResult;
13use crate::snippet_bencher::write_benchmarks;
14use crate::test_helpers::test_rust_equivalence_given_complete_state;
15
16pub trait MemPreserver: BasicSnippet {
31 fn rust_shadow(
32 &self,
33 stack: &mut Vec<BFieldElement>,
34 memory: &HashMap<BFieldElement, BFieldElement>,
35 nd_tokens: VecDeque<BFieldElement>,
36 nd_digests: VecDeque<Digest>,
37 stdin: VecDeque<BFieldElement>,
38 sponge: &mut Option<Tip5>,
39 ) -> Vec<BFieldElement>;
40
41 fn pseudorandom_initial_state(
42 &self,
43 seed: [u8; 32],
44 bench_case: Option<BenchmarkCase>,
45 ) -> MemPreserverInitialState;
46
47 fn corner_case_initial_states(&self) -> Vec<MemPreserverInitialState> {
48 vec![]
49 }
50}
51
52#[derive(Debug, Clone, Default)]
53pub struct MemPreserverInitialState {
54 pub stack: Vec<BFieldElement>,
55 pub nondeterminism: NonDeterminism,
56 pub public_input: VecDeque<BFieldElement>,
57 pub sponge_state: Option<Tip5>,
58}
59
60impl From<MemPreserverInitialState> for InitVmState {
61 fn from(value: MemPreserverInitialState) -> Self {
62 Self {
63 stack: value.stack,
64 nondeterminism: value.nondeterminism,
65 public_input: value.public_input.into(),
66 sponge: value.sponge_state,
67 }
68 }
69}
70
71pub struct ShadowedMemPreserver<T: MemPreserver> {
72 mem_preserver: T,
73}
74
75impl<T: MemPreserver> ShadowedMemPreserver<T> {
76 pub fn new(mem_preserver: T) -> Self {
77 Self { mem_preserver }
78 }
79}
80
81impl<T> RustShadow for ShadowedMemPreserver<T>
82where
83 T: MemPreserver,
84{
85 fn inner(&self) -> &dyn BasicSnippet {
86 &self.mem_preserver
87 }
88
89 fn rust_shadow_wrapper(
90 &self,
91 stdin: &[BFieldElement],
92 nondeterminism: &NonDeterminism,
93 stack: &mut Vec<BFieldElement>,
94 memory: &mut HashMap<BFieldElement, BFieldElement>,
95 sponge: &mut Option<Tip5>,
96 ) -> Vec<BFieldElement> {
97 self.mem_preserver.rust_shadow(
98 stack,
99 memory,
100 nondeterminism.individual_tokens.to_owned().into(),
101 nondeterminism.digests.to_owned().into(),
102 stdin.to_vec().into(),
103 sponge,
104 )
105 }
106
107 fn test(&self) {
108 for corner_case in self.mem_preserver.corner_case_initial_states() {
109 let stdin: Vec<_> = corner_case.public_input.into();
110
111 test_rust_equivalence_given_complete_state(
112 self,
113 &corner_case.stack,
114 &stdin,
115 &corner_case.nondeterminism,
116 &corner_case.sponge_state,
117 None,
118 );
119 }
120
121 let num_states = 10;
122 let mut rng = StdRng::from_seed(rand::random());
123 for _ in 0..num_states {
124 let MemPreserverInitialState {
125 stack,
126 public_input,
127 sponge_state,
128 nondeterminism: non_determinism,
129 } = self
130 .mem_preserver
131 .pseudorandom_initial_state(rng.random(), None);
132
133 let stdin: Vec<_> = public_input.into();
134 test_rust_equivalence_given_complete_state(
135 self,
136 &stack,
137 &stdin,
138 &non_determinism,
139 &sponge_state,
140 None,
141 );
142 }
143 }
144
145 fn bench(&self) {
146 let mut rng = StdRng::from_seed(
147 hex::decode("73a24b6b8b32e4d7d563a4d9a85f476573a24b6b8b32e4d7d563a4d9a85f4765")
148 .unwrap()
149 .try_into()
150 .unwrap(),
151 );
152 let mut benchmarks = Vec::with_capacity(2);
153
154 for bench_case in [BenchmarkCase::CommonCase, BenchmarkCase::WorstCase] {
155 let MemPreserverInitialState {
156 stack,
157 public_input,
158 sponge_state,
159 nondeterminism: non_determinism,
160 } = self
161 .mem_preserver
162 .pseudorandom_initial_state(rng.random(), Some(bench_case));
163 let program = self.mem_preserver.link_for_isolated_run();
164 let benchmark = execute_bench(
165 &program,
166 &stack,
167 public_input.into(),
168 non_determinism,
169 sponge_state,
170 );
171 let benchmark = NamedBenchmarkResult {
172 name: self.mem_preserver.entrypoint(),
173 benchmark_result: benchmark,
174 case: bench_case,
175 };
176 benchmarks.push(benchmark);
177 }
178
179 write_benchmarks(benchmarks);
180 }
181}