Skip to main content

tasm_lib/traits/
function.rs

1use std::collections::HashMap;
2
3use rand::prelude::*;
4use triton_vm::prelude::*;
5
6use super::basic_snippet::BasicSnippet;
7use super::rust_shadow::RustShadow;
8use super::rust_shadow::RustShadowError;
9use crate::InitVmState;
10use crate::linker::execute_bench;
11use crate::prelude::Tip5;
12use crate::snippet_bencher::BenchmarkCase;
13use crate::snippet_bencher::NamedBenchmarkResult;
14use crate::snippet_bencher::write_benchmarks;
15use crate::test_helpers::test_rust_equivalence_given_complete_state;
16
17/// A function can modify stack and extend memory.
18///
19/// A Function is a piece of tasm code that can modify the top of the stack, and can read
20/// and even extend memory. Specifically: any memory writes have to happen to addresses
21/// larger than the dynamic memory allocator and the dynamic memory allocator value has to
22/// be updated accordingly.
23///
24/// See also: [closure], [algorithm], [read_only_algorithm], [procedure],
25///           [accessor], [mem_preserver],
26///
27///
28/// [closure]: crate::traits::closure::Closure
29/// [algorithm]: crate::traits::algorithm::Algorithm
30/// [read_only_algorithm]: crate::traits::read_only_algorithm::ReadOnlyAlgorithm
31/// [procedure]: crate::traits::procedure::Procedure
32/// [accessor]: crate::traits::accessor::Accessor
33/// [mem_preserver]: crate::traits::mem_preserver::MemPreserver
34pub trait Function: BasicSnippet {
35    fn rust_shadow(
36        &self,
37        stack: &mut Vec<BFieldElement>,
38        memory: &mut HashMap<BFieldElement, BFieldElement>,
39    ) -> Result<(), RustShadowError>;
40
41    /// Return (init_stack, init_memory)
42    fn pseudorandom_initial_state(
43        &self,
44        seed: [u8; 32],
45        bench_case: Option<BenchmarkCase>,
46    ) -> FunctionInitialState;
47
48    fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
49        Vec::new()
50    }
51}
52
53#[derive(Debug, Clone, Default)]
54pub struct FunctionInitialState {
55    pub stack: Vec<BFieldElement>,
56    pub memory: HashMap<BFieldElement, BFieldElement>,
57}
58
59impl From<FunctionInitialState> for InitVmState {
60    fn from(value: FunctionInitialState) -> Self {
61        let nd = NonDeterminism::default().with_ram(value.memory);
62        Self {
63            stack: value.stack,
64            nondeterminism: nd,
65            ..Default::default()
66        }
67    }
68}
69
70pub struct ShadowedFunction<F: Function> {
71    function: F,
72}
73
74impl<F: Function> ShadowedFunction<F> {
75    pub fn new(function: F) -> Self {
76        Self { function }
77    }
78}
79
80impl<P: Function> ShadowedFunction<P> {
81    fn test_initial_state(&self, state: FunctionInitialState) {
82        let FunctionInitialState { stack, memory } = state;
83
84        let stdin = vec![];
85        let non_determinism = NonDeterminism {
86            individual_tokens: vec![],
87            digests: vec![],
88            ram: memory,
89        };
90        test_rust_equivalence_given_complete_state(
91            self,
92            &stack,
93            &stdin,
94            &non_determinism,
95            &None,
96            None,
97        );
98    }
99}
100
101impl<F> RustShadow for ShadowedFunction<F>
102where
103    F: Function,
104{
105    fn inner(&self) -> &dyn BasicSnippet {
106        &self.function
107    }
108
109    fn rust_shadow_wrapper(
110        &self,
111        _stdin: &[BFieldElement],
112        _nondeterminism: &NonDeterminism,
113        stack: &mut Vec<BFieldElement>,
114        memory: &mut HashMap<BFieldElement, BFieldElement>,
115        _sponge: &mut Option<Tip5>,
116    ) -> Result<Vec<BFieldElement>, RustShadowError> {
117        self.function.rust_shadow(stack, memory)?;
118
119        Ok(Vec::new())
120    }
121
122    /// Test rust-tasm equivalence.
123    fn test(&self) {
124        for cornercase_state in self.function.corner_case_initial_states() {
125            self.test_initial_state(cornercase_state);
126        }
127
128        let num_rng_states = 5;
129        let mut rng = rand::rng();
130
131        for _ in 0..num_rng_states {
132            let initial_state = self.function.pseudorandom_initial_state(rng.random(), None);
133            self.test_initial_state(initial_state)
134        }
135    }
136
137    /// Count number of cycles and other performance indicators and save them in directory
138    /// benchmarks/.
139    fn bench(&self) {
140        let seed = hex::decode("73a24b6b8b32e4d7d563a4d9a85f476573a24b6b8b32e4d7d563a4d9a85f4765")
141            .unwrap()
142            .try_into()
143            .unwrap();
144        let mut rng = StdRng::from_seed(seed);
145        let mut benchmarks = Vec::with_capacity(2);
146
147        for bench_case in [BenchmarkCase::CommonCase, BenchmarkCase::WorstCase] {
148            let FunctionInitialState { stack, memory } = self
149                .function
150                .pseudorandom_initial_state(rng.random(), Some(bench_case));
151            let program = self.function.link_for_isolated_run();
152            let non_determinism = NonDeterminism::default().with_ram(memory);
153            let benchmark = execute_bench(&program, &stack, vec![], non_determinism, None);
154            let benchmark = NamedBenchmarkResult {
155                name: self.function.entrypoint(),
156                benchmark_result: benchmark,
157                case: bench_case,
158            };
159            benchmarks.push(benchmark);
160        }
161
162        write_benchmarks(benchmarks);
163    }
164}