tasm_lib/traits/
mem_preserver.rs1use std::collections::HashMap;
2use std::collections::VecDeque;
3
4use rand::prelude::*;
5use triton_vm::prelude::*;
6
7use super::basic_snippet::BasicSnippet;
8use super::rust_shadow::RustShadow;
9use super::rust_shadow::RustShadowError;
10use crate::InitVmState;
11use crate::linker::execute_bench;
12use crate::snippet_bencher::BenchmarkCase;
13use crate::snippet_bencher::NamedBenchmarkResult;
14use crate::snippet_bencher::write_benchmarks;
15use crate::test_helpers::test_rust_equivalence_given_complete_state;
16
17pub trait MemPreserver: BasicSnippet {
32 fn rust_shadow(
33 &self,
34 stack: &mut Vec<BFieldElement>,
35 memory: &HashMap<BFieldElement, BFieldElement>,
36 nd_tokens: VecDeque<BFieldElement>,
37 nd_digests: VecDeque<Digest>,
38 stdin: VecDeque<BFieldElement>,
39 sponge: &mut Option<Tip5>,
40 ) -> Result<Vec<BFieldElement>, RustShadowError>;
41
42 fn pseudorandom_initial_state(
43 &self,
44 seed: [u8; 32],
45 bench_case: Option<BenchmarkCase>,
46 ) -> MemPreserverInitialState;
47
48 fn corner_case_initial_states(&self) -> Vec<MemPreserverInitialState> {
49 Vec::new()
50 }
51}
52
53#[derive(Debug, Clone, Default)]
54pub struct MemPreserverInitialState {
55 pub stack: Vec<BFieldElement>,
56 pub nondeterminism: NonDeterminism,
57 pub public_input: VecDeque<BFieldElement>,
58 pub sponge_state: Option<Tip5>,
59}
60
61impl From<MemPreserverInitialState> for InitVmState {
62 fn from(value: MemPreserverInitialState) -> Self {
63 Self {
64 stack: value.stack,
65 nondeterminism: value.nondeterminism,
66 public_input: value.public_input.into(),
67 sponge: value.sponge_state,
68 }
69 }
70}
71
72pub struct ShadowedMemPreserver<T: MemPreserver> {
73 mem_preserver: T,
74}
75
76impl<T: MemPreserver> ShadowedMemPreserver<T> {
77 pub fn new(mem_preserver: T) -> Self {
78 Self { mem_preserver }
79 }
80}
81
82impl<T> RustShadow for ShadowedMemPreserver<T>
83where
84 T: MemPreserver,
85{
86 fn inner(&self) -> &dyn BasicSnippet {
87 &self.mem_preserver
88 }
89
90 fn rust_shadow_wrapper(
91 &self,
92 stdin: &[BFieldElement],
93 nondeterminism: &NonDeterminism,
94 stack: &mut Vec<BFieldElement>,
95 memory: &mut HashMap<BFieldElement, BFieldElement>,
96 sponge: &mut Option<Tip5>,
97 ) -> Result<Vec<BFieldElement>, RustShadowError> {
98 self.mem_preserver.rust_shadow(
99 stack,
100 memory,
101 nondeterminism.individual_tokens.to_owned().into(),
102 nondeterminism.digests.to_owned().into(),
103 stdin.to_vec().into(),
104 sponge,
105 )
106 }
107
108 fn test(&self) {
109 for corner_case in self.mem_preserver.corner_case_initial_states() {
110 let stdin: Vec<_> = corner_case.public_input.into();
111
112 test_rust_equivalence_given_complete_state(
113 self,
114 &corner_case.stack,
115 &stdin,
116 &corner_case.nondeterminism,
117 &corner_case.sponge_state,
118 None,
119 );
120 }
121
122 let num_states = 10;
123 let mut rng = StdRng::from_seed(rand::random());
124 for _ in 0..num_states {
125 let MemPreserverInitialState {
126 stack,
127 public_input,
128 sponge_state,
129 nondeterminism: non_determinism,
130 } = self
131 .mem_preserver
132 .pseudorandom_initial_state(rng.random(), None);
133
134 let stdin: Vec<_> = public_input.into();
135 test_rust_equivalence_given_complete_state(
136 self,
137 &stack,
138 &stdin,
139 &non_determinism,
140 &sponge_state,
141 None,
142 );
143 }
144 }
145
146 fn bench(&self) {
147 let mut rng = StdRng::from_seed(
148 hex::decode("73a24b6b8b32e4d7d563a4d9a85f476573a24b6b8b32e4d7d563a4d9a85f4765")
149 .unwrap()
150 .try_into()
151 .unwrap(),
152 );
153 let mut benchmarks = Vec::with_capacity(2);
154
155 for bench_case in [BenchmarkCase::CommonCase, BenchmarkCase::WorstCase] {
156 let MemPreserverInitialState {
157 stack,
158 public_input,
159 sponge_state,
160 nondeterminism: non_determinism,
161 } = self
162 .mem_preserver
163 .pseudorandom_initial_state(rng.random(), Some(bench_case));
164 let program = self.mem_preserver.link_for_isolated_run();
165 let benchmark = execute_bench(
166 &program,
167 &stack,
168 public_input.into(),
169 non_determinism,
170 sponge_state,
171 );
172 let benchmark = NamedBenchmarkResult {
173 name: self.mem_preserver.entrypoint(),
174 benchmark_result: benchmark,
175 case: bench_case,
176 };
177 benchmarks.push(benchmark);
178 }
179
180 write_benchmarks(benchmarks);
181 }
182}