Skip to main content

tasm_lib/traits/
mem_preserver.rs

1use std::collections::HashMap;
2use std::collections::VecDeque;
3
4use rand::prelude::*;
5use triton_vm::prelude::*;
6
7use super::basic_snippet::BasicSnippet;
8use super::rust_shadow::RustShadow;
9use super::rust_shadow::RustShadowError;
10use crate::InitVmState;
11use crate::linker::execute_bench;
12use crate::snippet_bencher::BenchmarkCase;
13use crate::snippet_bencher::NamedBenchmarkResult;
14use crate::snippet_bencher::write_benchmarks;
15use crate::test_helpers::test_rust_equivalence_given_complete_state;
16
17/// A MemPreserver cannot modify memory
18///
19/// An MemPreserver is a piece of tasm code that can do pretty much everything
20/// except modify memory, including static memory. It can read from any input
21/// and write to standard out. It can also modify the sponge state.
22/// See also: [closure], [function], [procedure], [algorithm],
23///           [read_only_algorithm], [accessor]
24///
25/// [closure]: crate::traits::closure::Closure
26/// [function]: crate::traits::function::Function
27/// [procedure]: crate::traits::procedure::Procedure
28/// [algorithm]: crate::traits::algorithm::Algorithm
29/// [read_only_algorithm]: crate::traits::read_only_algorithm::ReadOnlyAlgorithm
30/// [accessor]: crate::traits::accessor::Accessor
31pub trait MemPreserver: BasicSnippet {
32    fn rust_shadow(
33        &self,
34        stack: &mut Vec<BFieldElement>,
35        memory: &HashMap<BFieldElement, BFieldElement>,
36        nd_tokens: VecDeque<BFieldElement>,
37        nd_digests: VecDeque<Digest>,
38        stdin: VecDeque<BFieldElement>,
39        sponge: &mut Option<Tip5>,
40    ) -> Result<Vec<BFieldElement>, RustShadowError>;
41
42    fn pseudorandom_initial_state(
43        &self,
44        seed: [u8; 32],
45        bench_case: Option<BenchmarkCase>,
46    ) -> MemPreserverInitialState;
47
48    fn corner_case_initial_states(&self) -> Vec<MemPreserverInitialState> {
49        Vec::new()
50    }
51}
52
53#[derive(Debug, Clone, Default)]
54pub struct MemPreserverInitialState {
55    pub stack: Vec<BFieldElement>,
56    pub nondeterminism: NonDeterminism,
57    pub public_input: VecDeque<BFieldElement>,
58    pub sponge_state: Option<Tip5>,
59}
60
61impl From<MemPreserverInitialState> for InitVmState {
62    fn from(value: MemPreserverInitialState) -> Self {
63        Self {
64            stack: value.stack,
65            nondeterminism: value.nondeterminism,
66            public_input: value.public_input.into(),
67            sponge: value.sponge_state,
68        }
69    }
70}
71
/// Wrapper that exposes a [`MemPreserver`] through the [`RustShadow`]
/// testing and benchmarking interface.
///
/// The `T: MemPreserver` requirement is stated on the `impl` blocks rather
/// than on the struct itself (Rust API guideline C-STRUCT-BOUNDS), so naming
/// the type does not force the bound everywhere.
pub struct ShadowedMemPreserver<T> {
    mem_preserver: T,
}
75
76impl<T: MemPreserver> ShadowedMemPreserver<T> {
77    pub fn new(mem_preserver: T) -> Self {
78        Self { mem_preserver }
79    }
80}
81
82impl<T> RustShadow for ShadowedMemPreserver<T>
83where
84    T: MemPreserver,
85{
86    fn inner(&self) -> &dyn BasicSnippet {
87        &self.mem_preserver
88    }
89
90    fn rust_shadow_wrapper(
91        &self,
92        stdin: &[BFieldElement],
93        nondeterminism: &NonDeterminism,
94        stack: &mut Vec<BFieldElement>,
95        memory: &mut HashMap<BFieldElement, BFieldElement>,
96        sponge: &mut Option<Tip5>,
97    ) -> Result<Vec<BFieldElement>, RustShadowError> {
98        self.mem_preserver.rust_shadow(
99            stack,
100            memory,
101            nondeterminism.individual_tokens.to_owned().into(),
102            nondeterminism.digests.to_owned().into(),
103            stdin.to_vec().into(),
104            sponge,
105        )
106    }
107
108    fn test(&self) {
109        for corner_case in self.mem_preserver.corner_case_initial_states() {
110            let stdin: Vec<_> = corner_case.public_input.into();
111
112            test_rust_equivalence_given_complete_state(
113                self,
114                &corner_case.stack,
115                &stdin,
116                &corner_case.nondeterminism,
117                &corner_case.sponge_state,
118                None,
119            );
120        }
121
122        let num_states = 10;
123        let mut rng = StdRng::from_seed(rand::random());
124        for _ in 0..num_states {
125            let MemPreserverInitialState {
126                stack,
127                public_input,
128                sponge_state,
129                nondeterminism: non_determinism,
130            } = self
131                .mem_preserver
132                .pseudorandom_initial_state(rng.random(), None);
133
134            let stdin: Vec<_> = public_input.into();
135            test_rust_equivalence_given_complete_state(
136                self,
137                &stack,
138                &stdin,
139                &non_determinism,
140                &sponge_state,
141                None,
142            );
143        }
144    }
145
146    fn bench(&self) {
147        let mut rng = StdRng::from_seed(
148            hex::decode("73a24b6b8b32e4d7d563a4d9a85f476573a24b6b8b32e4d7d563a4d9a85f4765")
149                .unwrap()
150                .try_into()
151                .unwrap(),
152        );
153        let mut benchmarks = Vec::with_capacity(2);
154
155        for bench_case in [BenchmarkCase::CommonCase, BenchmarkCase::WorstCase] {
156            let MemPreserverInitialState {
157                stack,
158                public_input,
159                sponge_state,
160                nondeterminism: non_determinism,
161            } = self
162                .mem_preserver
163                .pseudorandom_initial_state(rng.random(), Some(bench_case));
164            let program = self.mem_preserver.link_for_isolated_run();
165            let benchmark = execute_bench(
166                &program,
167                &stack,
168                public_input.into(),
169                non_determinism,
170                sponge_state,
171            );
172            let benchmark = NamedBenchmarkResult {
173                name: self.mem_preserver.entrypoint(),
174                benchmark_result: benchmark,
175                case: bench_case,
176            };
177            benchmarks.push(benchmark);
178        }
179
180        write_benchmarks(benchmarks);
181    }
182}