tasm_lib/traits/
function.rs

1use std::collections::HashMap;
2
3use rand::prelude::*;
4use triton_vm::prelude::*;
5
6use super::basic_snippet::BasicSnippet;
7use super::rust_shadow::RustShadow;
8use crate::linker::execute_bench;
9use crate::prelude::Tip5;
10use crate::snippet_bencher::write_benchmarks;
11use crate::snippet_bencher::BenchmarkCase;
12use crate::snippet_bencher::NamedBenchmarkResult;
13use crate::test_helpers::test_rust_equivalence_given_complete_state;
14use crate::InitVmState;
15
16/// A function can modify stack and extend memory.
17///
18/// A Function is a piece of tasm code that can modify the top of the stack, and can read
19/// and even extend memory. Specifically: any memory writes have to happen to addresses
20/// larger than the dynamic memory allocator and the dynamic memory allocator value has to
21/// be updated accordingly.
22///
23/// See also: [closure], [algorithm], [read_only_algorithm], [procedure],
24///           [accessor], [mem_preserver],
25///
26///
27/// [closure]: crate::traits::closure::Closure
28/// [algorithm]: crate::traits::algorithm::Algorithm
29/// [read_only_algorithm]: crate::traits::read_only_algorithm::ReadOnlyAlgorithm
30/// [procedure]: crate::traits::procedure::Procedure
31/// [accessor]: crate::traits::accessor::Accessor
32/// [mem_preserver]: crate::traits::mem_preserver::MemPreserver
33pub trait Function: BasicSnippet {
34    fn rust_shadow(
35        &self,
36        stack: &mut Vec<BFieldElement>,
37        memory: &mut HashMap<BFieldElement, BFieldElement>,
38    );
39
40    /// Return (init_stack, init_memory)
41    fn pseudorandom_initial_state(
42        &self,
43        seed: [u8; 32],
44        bench_case: Option<BenchmarkCase>,
45    ) -> FunctionInitialState;
46
47    fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
48        vec![]
49    }
50}
51
52#[derive(Debug, Clone, Default)]
53pub struct FunctionInitialState {
54    pub stack: Vec<BFieldElement>,
55    pub memory: HashMap<BFieldElement, BFieldElement>,
56}
57
58impl From<FunctionInitialState> for InitVmState {
59    fn from(value: FunctionInitialState) -> Self {
60        let nd = NonDeterminism::default().with_ram(value.memory);
61        Self {
62            stack: value.stack,
63            nondeterminism: nd,
64            ..Default::default()
65        }
66    }
67}
68
69pub struct ShadowedFunction<F: Function> {
70    function: F,
71}
72
73impl<F: Function> ShadowedFunction<F> {
74    pub fn new(function: F) -> Self {
75        Self { function }
76    }
77}
78
79impl<P: Function> ShadowedFunction<P> {
80    fn test_initial_state(&self, state: FunctionInitialState) {
81        let FunctionInitialState { stack, memory } = state;
82
83        let stdin = vec![];
84        let non_determinism = NonDeterminism {
85            individual_tokens: vec![],
86            digests: vec![],
87            ram: memory,
88        };
89        test_rust_equivalence_given_complete_state(
90            self,
91            &stack,
92            &stdin,
93            &non_determinism,
94            &None,
95            None,
96        );
97    }
98}
99
100impl<F> RustShadow for ShadowedFunction<F>
101where
102    F: Function,
103{
104    fn inner(&self) -> &dyn BasicSnippet {
105        &self.function
106    }
107
108    fn rust_shadow_wrapper(
109        &self,
110        _stdin: &[BFieldElement],
111        _nondeterminism: &NonDeterminism,
112        stack: &mut Vec<BFieldElement>,
113        memory: &mut HashMap<BFieldElement, BFieldElement>,
114        _sponge: &mut Option<Tip5>,
115    ) -> Vec<BFieldElement> {
116        self.function.rust_shadow(stack, memory);
117        vec![]
118    }
119
120    /// Test rust-tasm equivalence.
121    fn test(&self) {
122        for cornercase_state in self.function.corner_case_initial_states() {
123            self.test_initial_state(cornercase_state);
124        }
125
126        let num_rng_states = 5;
127        let mut rng = rand::rng();
128
129        for _ in 0..num_rng_states {
130            let initial_state = self.function.pseudorandom_initial_state(rng.random(), None);
131            self.test_initial_state(initial_state)
132        }
133    }
134
135    /// Count number of cycles and other performance indicators and save them in directory
136    /// benchmarks/.
137    fn bench(&self) {
138        let seed = hex::decode("73a24b6b8b32e4d7d563a4d9a85f476573a24b6b8b32e4d7d563a4d9a85f4765")
139            .unwrap()
140            .try_into()
141            .unwrap();
142        let mut rng = StdRng::from_seed(seed);
143        let mut benchmarks = Vec::with_capacity(2);
144
145        for bench_case in [BenchmarkCase::CommonCase, BenchmarkCase::WorstCase] {
146            let FunctionInitialState { stack, memory } = self
147                .function
148                .pseudorandom_initial_state(rng.random(), Some(bench_case));
149            let program = self.function.link_for_isolated_run();
150            let non_determinism = NonDeterminism::default().with_ram(memory);
151            let benchmark = execute_bench(&program, &stack, vec![], non_determinism, None);
152            let benchmark = NamedBenchmarkResult {
153                name: self.function.entrypoint(),
154                benchmark_result: benchmark,
155                case: bench_case,
156            };
157            benchmarks.push(benchmark);
158        }
159
160        write_benchmarks(benchmarks);
161    }
162}